IGNITE-9055: [ML] SVM throws NPE in case of empty partitions. This closes #4412
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/d0e6def7 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/d0e6def7 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/d0e6def7 Branch: refs/heads/ignite-8446 Commit: d0e6def704f2d1646f9bf5b4b9a19243f6d48f75 Parents: 13e2a31 Author: dmitrievanthony <[email protected]> Authored: Tue Jul 24 17:56:50 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Tue Jul 24 17:56:50 2018 +0300 ---------------------------------------------------------------------- .../dataset/impl/cache/CacheBasedDataset.java | 3 +- .../ml/svm/SVMBinaryTrainerIntegrationTest.java | 102 +++++++++++++++++++ .../org/apache/ignite/ml/svm/SVMTestSuite.java | 3 +- 3 files changed, 106 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/d0e6def7/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java index 1b492a7..67e0d56 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java @@ -163,7 +163,8 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose R res = identity; for (R partRes : results) - res = reduce.apply(res, partRes); + if (partRes != null) + res = reduce.apply(res, partRes); return res; } http://git-wip-us.apache.org/repos/asf/ignite/blob/d0e6def7/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerIntegrationTest.java 
---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerIntegrationTest.java new file mode 100644 index 0000000..d227de7 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerIntegrationTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.svm; + +import java.util.Arrays; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; +import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; + +/** + * Tests for {@link SVMLinearBinaryClassificationTrainer} that require to start the whole Ignite infrastructure. + */ +public class SVMBinaryTrainerIntegrationTest extends GridCommonAbstractTest { + /** Fixed size of Dataset. 
*/ + private static final int AMOUNT_OF_OBSERVATIONS = 1000; + + /** Fixed size of columns in Dataset. */ + private static final int AMOUNT_OF_FEATURES = 2; + + /** Precision in test checks. */ + private static final double PRECISION = 1e-2; + + /** Number of nodes in grid */ + private static final int NODE_COUNT = 3; + + /** Ignite instance. */ + private Ignite ignite; + + /** {@inheritDoc} */ + @Override protected void beforeTestsStarted() throws Exception { + for (int i = 1; i <= NODE_COUNT; i++) + startGrid(i); + } + + /** {@inheritDoc} */ + @Override protected void afterTestsStopped() { + stopAllGrids(); + } + + /** + * {@inheritDoc} + */ + @Override protected void beforeTest() throws Exception { + /* Grid instance. */ + ignite = grid(NODE_COUNT); + ignite.configuration().setPeerClassLoadingEnabled(true); + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + } + + /** + * Test trainer on classification model y = x. + */ + public void testTrainWithTheLinearlySeparableCase() { + IgniteCache<Integer, double[]> data = ignite.getOrCreateCache(UUID.randomUUID().toString()); + + ThreadLocalRandom rndX = ThreadLocalRandom.current(); + ThreadLocalRandom rndY = ThreadLocalRandom.current(); + + for (int i = 0; i < AMOUNT_OF_OBSERVATIONS; i++) { + double x = rndX.nextDouble(-1000, 1000); + double y = rndY.nextDouble(-1000, 1000); + double[] vec = new double[AMOUNT_OF_FEATURES + 1]; + vec[0] = y - x > 0 ? 1 : -1; // assign label. 
+ vec[1] = x; + vec[2] = y; + data.put(i, vec); + } + + SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer(); + + SVMLinearBinaryClassificationModel mdl = trainer.fit( + ignite, + data, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + TestUtils.assertEquals(-1, mdl.apply(new DenseVector(new double[]{100, 10})), PRECISION); + TestUtils.assertEquals(1, mdl.apply(new DenseVector(new double[]{10, 100})), PRECISION); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/d0e6def7/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java index 8178e30..822ad18 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java @@ -27,7 +27,8 @@ import org.junit.runners.Suite; @Suite.SuiteClasses({ SVMModelTest.class, SVMBinaryTrainerTest.class, - SVMMultiClassTrainerTest.class + SVMMultiClassTrainerTest.class, + SVMBinaryTrainerIntegrationTest.class }) public class SVMTestSuite { // No-op.
