IGNITE-9055: [ML] SVM throws NPE in case of empty partitions

this closes #4412


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/d0e6def7
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/d0e6def7
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/d0e6def7

Branch: refs/heads/ignite-8446
Commit: d0e6def704f2d1646f9bf5b4b9a19243f6d48f75
Parents: 13e2a31
Author: dmitrievanthony <[email protected]>
Authored: Tue Jul 24 17:56:50 2018 +0300
Committer: Yury Babak <[email protected]>
Committed: Tue Jul 24 17:56:50 2018 +0300

----------------------------------------------------------------------
 .../dataset/impl/cache/CacheBasedDataset.java   |   3 +-
 .../ml/svm/SVMBinaryTrainerIntegrationTest.java | 102 +++++++++++++++++++
 .../org/apache/ignite/ml/svm/SVMTestSuite.java  |   3 +-
 3 files changed, 106 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/d0e6def7/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
index 1b492a7..67e0d56 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
@@ -163,7 +163,8 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
 
         R res = identity;
         for (R partRes : results)
-            res = reduce.apply(res, partRes);
+            if (partRes != null)
+                res = reduce.apply(res, partRes);
 
         return res;
     }

http://git-wip-us.apache.org/repos/asf/ignite/blob/d0e6def7/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerIntegrationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerIntegrationTest.java
new file mode 100644
index 0000000..d227de7
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerIntegrationTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.svm;
+
+import java.util.Arrays;
+import java.util.UUID;
+import java.util.concurrent.ThreadLocalRandom;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+
+/**
+ * Tests for {@link SVMLinearBinaryClassificationTrainer} that require to start the whole Ignite infrastructure.
+ */
+public class SVMBinaryTrainerIntegrationTest extends GridCommonAbstractTest {
+    /** Fixed size of Dataset. */
+    private static final int AMOUNT_OF_OBSERVATIONS = 1000;
+
+    /** Fixed size of columns in Dataset. */
+    private static final int AMOUNT_OF_FEATURES = 2;
+
+    /** Precision in test checks. */
+    private static final double PRECISION = 1e-2;
+
+    /** Number of nodes in grid */
+    private static final int NODE_COUNT = 3;
+
+    /** Ignite instance. */
+    private Ignite ignite;
+
+    /** {@inheritDoc} */
+    @Override protected void beforeTestsStarted() throws Exception {
+        for (int i = 1; i <= NODE_COUNT; i++)
+            startGrid(i);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected void afterTestsStopped() {
+        stopAllGrids();
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override protected void beforeTest() throws Exception {
+        /* Grid instance. */
+        ignite = grid(NODE_COUNT);
+        ignite.configuration().setPeerClassLoadingEnabled(true);
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+    }
+
+    /**
+     * Test trainer on classification model y = x.
+     */
+    public void testTrainWithTheLinearlySeparableCase() {
+        IgniteCache<Integer, double[]> data = ignite.getOrCreateCache(UUID.randomUUID().toString());
+
+        ThreadLocalRandom rndX = ThreadLocalRandom.current();
+        ThreadLocalRandom rndY = ThreadLocalRandom.current();
+
+        for (int i = 0; i < AMOUNT_OF_OBSERVATIONS; i++) {
+            double x = rndX.nextDouble(-1000, 1000);
+            double y = rndY.nextDouble(-1000, 1000);
+            double[] vec = new double[AMOUNT_OF_FEATURES + 1];
+            vec[0] = y - x > 0 ? 1 : -1; // assign label.
+            vec[1] = x;
+            vec[2] = y;
+            data.put(i, vec);
+        }
+
+        SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer();
+
+        SVMLinearBinaryClassificationModel mdl = trainer.fit(
+            ignite,
+            data,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        TestUtils.assertEquals(-1, mdl.apply(new DenseVector(new double[]{100, 10})), PRECISION);
+        TestUtils.assertEquals(1, mdl.apply(new DenseVector(new double[]{10, 100})), PRECISION);
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/d0e6def7/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java
index 8178e30..822ad18 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java
@@ -27,7 +27,8 @@ import org.junit.runners.Suite;
 @Suite.SuiteClasses({
     SVMModelTest.class,
     SVMBinaryTrainerTest.class,
-    SVMMultiClassTrainerTest.class
+    SVMMultiClassTrainerTest.class,
+    SVMBinaryTrainerIntegrationTest.class
 })
 public class SVMTestSuite {
     // No-op.

Reply via email to