Repository: ignite Updated Branches: refs/heads/master 67023a88b -> df6356d5d
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java new file mode 100644 index 0000000..fa8fac4 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions.linear; + +import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; +import org.apache.ignite.ml.trainers.group.UpdatesStrategy; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link LinearRegressionSGDTrainer}. + */ +@RunWith(Parameterized.class) +public class LinearRegressionSGDTrainerTest { + /** Parameters. */ + @Parameterized.Parameters(name = "Data divided on {0} partitions") + public static Iterable<Integer[]> data() { + return Arrays.asList( + new Integer[] {1}, + new Integer[] {2}, + new Integer[] {3}, + new Integer[] {5}, + new Integer[] {7}, + new Integer[] {100} + ); + } + + /** Number of partitions. */ + @Parameterized.Parameter + public int parts; + + /** + * Tests {@code fit()} method on a simple small dataset. + */ + @Test + public void testSmallDataFit() { + Map<Integer, double[]> data = new HashMap<>(); + data.put(0, new double[] {-1.0915526, 1.81983527, -0.91409478, 0.70890712, -24.55724107}); + data.put(1, new double[] {-0.61072904, 0.37545517, 0.21705352, 0.09516495, -26.57226867}); + data.put(2, new double[] {0.05485406, 0.88219898, -0.80584547, 0.94668307, 61.80919728}); + data.put(3, new double[] {-0.24835094, -0.34000053, -1.69984651, -1.45902635, -161.65525991}); + data.put(4, new double[] {0.63675392, 0.31675535, 0.38837437, -1.1221971, -14.46432611}); + data.put(5, new double[] {0.14194017, 2.18158997, -0.28397346, -0.62090588, -3.2122197}); + data.put(6, new double[] {-0.53487507, 1.4454797, 0.21570443, -0.54161422, -46.5469012}); + data.put(7, new double[] {-1.58812173, -0.73216803, -2.15670676, -1.03195988, -247.23559889}); + data.put(8, new double[] {0.20702671, 0.92864654, 0.32721202, -0.09047503, 31.61484949}); + data.put(9, new double[] {-0.37890345, -0.04846179, -0.84122753, -1.14667474, -124.92598583}); + + LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(new UpdatesStrategy<>( + new RPropUpdateCalculator(), + RPropParameterUpdate::sumLocal, + RPropParameterUpdate::avg + ), 100000, 10, 100, 123L); + + LinearRegressionModel mdl = trainer.fit( + data, + parts, + (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), + (k, v) -> v[4] + ); + + assertArrayEquals( + new double[] {72.26948107, 15.95144674, 24.07403921, 66.73038781}, + mdl.getWeights().getStorage().data(), + 1e-1 + ); + + assertEquals(2.8421709430404007e-14, mdl.getIntercept(), 1e-1); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LocalLinearRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LocalLinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LocalLinearRegressionSGDTrainerTest.java deleted file mode 100644 index bea164d..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LocalLinearRegressionSGDTrainerTest.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions.linear; - -import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; - -/** - * Tests for {@link LinearRegressionSGDTrainer} on {@link DenseLocalOnHeapMatrix}. - */ -public class LocalLinearRegressionSGDTrainerTest extends GenericLinearRegressionTrainerTest { - /** */ - public LocalLinearRegressionSGDTrainerTest() { - super( - new LinearRegressionSGDTrainer(100_000, 1e-12), - DenseLocalOnHeapMatrix::new, - DenseLocalOnHeapVector::new, - 1e-2); - } -} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java index 26ba2fb..0befd9b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java @@ -17,14 +17,14 @@ package org.apache.ignite.ml.svm; +import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.junit.Test; + import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ThreadLocalRandom; -import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.junit.Test; /** * Tests for {@link SVMLinearBinaryClassificationTrainer}. @@ -62,7 +62,8 @@ public class SVMBinaryTrainerTest { SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer(); SVMLinearBinaryClassificationModel mdl = trainer.fit( - new LocalDatasetBuilder<>(data, 10), + data, + 10, (k, v) -> Arrays.copyOfRange(v, 1, v.length), (k, v) -> v[0] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java index ad95eb4..31ab4d7 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java @@ -17,14 +17,14 @@ package org.apache.ignite.ml.svm; +import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.junit.Test; + import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ThreadLocalRandom; -import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.junit.Test; /** * Tests for {@link SVMLinearBinaryClassificationTrainer}. @@ -65,7 +65,8 @@ public class SVMMultiClassTrainerTest { .withAmountOfIterations(20); SVMLinearMultiClassClassificationModel mdl = trainer.fit( - new LocalDatasetBuilder<>(data, 10), + data, + 10, (k, v) -> Arrays.copyOfRange(v, 1, v.length), (k, v) -> v[0] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java index 94bca3f..d5b0b86 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java @@ -17,16 +17,16 @@ package org.apache.ignite.ml.tree; -import java.util.Arrays; -import java.util.Random; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; +import java.util.Arrays; +import java.util.Random; + /** * Tests for {@link DecisionTreeClassificationTrainer} that require to start the whole Ignite infrastructure. */ @@ -77,7 +77,8 @@ public class DecisionTreeClassificationTrainerIntegrationTest extends GridCommon DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0); DecisionTreeNode tree = trainer.fit( - new CacheBasedDatasetBuilder<>(ignite, data), + ignite, + data, (k, v) -> Arrays.copyOf(v, v.length - 1), (k, v) -> v[v.length - 1] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java index 2599bfe..12ef698 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java @@ -17,17 +17,12 @@ package org.apache.ignite.ml.tree; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Random; -import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import java.util.*; + import static junit.framework.TestCase.assertEquals; import static junit.framework.TestCase.assertTrue; @@ -68,7 +63,8 @@ public class DecisionTreeClassificationTrainerTest { DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0); DecisionTreeNode tree = trainer.fit( - new LocalDatasetBuilder<>(data, parts), + data, + parts, (k, v) -> Arrays.copyOf(v, v.length - 1), (k, v) -> v[v.length - 1] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java index 754ff20..c2a4638 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java @@ -17,16 +17,16 @@ package org.apache.ignite.ml.tree; -import java.util.Arrays; -import java.util.Random; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; +import java.util.Arrays; +import java.util.Random; + /** * Tests for {@link DecisionTreeRegressionTrainer} that require to start the whole Ignite infrastructure. */ @@ -77,7 +77,8 @@ public class DecisionTreeRegressionTrainerIntegrationTest extends GridCommonAbst DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0); DecisionTreeNode tree = trainer.fit( - new CacheBasedDatasetBuilder<>(ignite, data), + ignite, + data, (k, v) -> Arrays.copyOf(v, v.length - 1), (k, v) -> v[v.length - 1] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java index 3bdbf60..bcfb53f 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java @@ -17,17 +17,12 @@ package org.apache.ignite.ml.tree; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Random; -import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import java.util.*; + import static junit.framework.TestCase.assertEquals; import static junit.framework.TestCase.assertTrue; @@ -68,7 +63,8 @@ public class DecisionTreeRegressionTrainerTest { DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0); DecisionTreeNode tree = trainer.fit( - new LocalDatasetBuilder<>(data, parts), + data, + parts, (k, v) -> Arrays.copyOf(v, v.length - 1), (k, v) -> v[v.length - 1] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java index b259ec9..35f805e 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTIntegrationTest.java @@ -17,13 +17,11 @@ package org.apache.ignite.ml.tree.performance; -import java.io.IOException; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.nn.performance.MnistMLPTestUtil; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; @@ -31,6 +29,8 @@ import org.apache.ignite.ml.tree.impurity.util.SimpleStepFunctionCompressor; import org.apache.ignite.ml.util.MnistUtils; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; +import java.io.IOException; + /** * Tests {@link DecisionTreeClassificationTrainer} on the MNIST dataset that require to start the whole Ignite * infrastructure. For manual run. @@ -81,7 +81,8 @@ public class DecisionTreeMNISTIntegrationTest extends GridCommonAbstractTest { new SimpleStepFunctionCompressor<>()); DecisionTreeNode mdl = trainer.fit( - new CacheBasedDatasetBuilder<>(ignite, trainingSet), + ignite, + trainingSet, (k, v) -> v.getPixels(), (k, v) -> (double) v.getLabel() ); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java index 6dbd44c..b40c7ac 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/performance/DecisionTreeMNISTTest.java @@ -17,10 +17,6 @@ package org.apache.ignite.ml.tree.performance; -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; -import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.nn.performance.MnistMLPTestUtil; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; @@ -28,6 +24,10 @@ import org.apache.ignite.ml.tree.impurity.util.SimpleStepFunctionCompressor; import org.apache.ignite.ml.util.MnistUtils; import org.junit.Test; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + import static junit.framework.TestCase.assertTrue; /** @@ -50,7 +50,8 @@ public class DecisionTreeMNISTTest { new SimpleStepFunctionCompressor<>()); DecisionTreeNode mdl = trainer.fit( - new LocalDatasetBuilder<>(trainingSet, 10), + trainingSet, + 10, (k, v) -> v.getPixels(), (k, v) -> (double) v.getLabel() );