http://git-wip-us.apache.org/repos/asf/ignite/blob/2f330a1c/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java new file mode 100644 index 0000000..4892ff8 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.isolve.lsqr; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.math.isolve.LinSysPartitionDataBuilderOnHeap; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Tests for {@link LSQROnHeap}. + */ +@RunWith(Parameterized.class) +public class LSQROnHeapTest { + /** Parameters. */ + @Parameterized.Parameters(name = "Data divided on {0} partitions") + public static Iterable<Integer[]> data() { + return Arrays.asList( + new Integer[] {1}, + new Integer[] {2}, + new Integer[] {3}, + new Integer[] {5}, + new Integer[] {7}, + new Integer[] {100}, + new Integer[] {1000} + ); + } + + /** Number of partitions. */ + @Parameterized.Parameter + public int parts; + + /** Tests solving simple linear system. */ + @Test + public void testSolveLinearSystem() { + Map<Integer, double[]> data = new HashMap<>(); + data.put(0, new double[]{3, 2, -1, 1}); + data.put(1, new double[]{2, -2, 4, -2}); + data.put(2, new double[]{-1, 0.5, -1, 0}); + + DatasetBuilder<Integer, double[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts); + + LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>( + datasetBuilder, + new LinSysPartitionDataBuilderOnHeap<>( + (k, v) -> Arrays.copyOf(v, v.length - 1), + (k, v) -> v[3], + 3 + ) + ); + + LSQRResult res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null); + + assertArrayEquals(new double[]{1, -2, -2}, res.getX(), 1e-6); + } + + /** Tests solving simple linear system with specified x0. */ + @Test + public void testSolveLinearSystemWithX0() { + Map<Integer, double[]> data = new HashMap<>(); + data.put(0, new double[]{3, 2, -1, 1}); + data.put(1, new double[]{2, -2, 4, -2}); + data.put(2, new double[]{-1, 0.5, -1, 0}); + + DatasetBuilder<Integer, double[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts); + + LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>( + datasetBuilder, + new LinSysPartitionDataBuilderOnHeap<>( + (k, v) -> Arrays.copyOf(v, v.length - 1), + (k, v) -> v[3], + 3 + ) + ); + + LSQRResult res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, + new double[] {999, 999, 999}); + + assertArrayEquals(new double[]{1, -2, -2}, res.getX(), 1e-6); + } + + /** Tests solving least squares problem. */ + @Test + public void testSolveLeastSquares() throws Exception { + Map<Integer, double[]> data = new HashMap<>(); + data.put(0, new double[] {-1.0915526, 1.81983527, -0.91409478, 0.70890712, -24.55724107}); + data.put(1, new double[] {-0.61072904, 0.37545517, 0.21705352, 0.09516495, -26.57226867}); + data.put(2, new double[] {0.05485406, 0.88219898, -0.80584547, 0.94668307, 61.80919728}); + data.put(3, new double[] {-0.24835094, -0.34000053, -1.69984651, -1.45902635, -161.65525991}); + data.put(4, new double[] {0.63675392, 0.31675535, 0.38837437, -1.1221971, -14.46432611}); + data.put(5, new double[] {0.14194017, 2.18158997, -0.28397346, -0.62090588, -3.2122197}); + data.put(6, new double[] {-0.53487507, 1.4454797, 0.21570443, -0.54161422, -46.5469012}); + data.put(7, new double[] {-1.58812173, -0.73216803, -2.15670676, -1.03195988, -247.23559889}); + data.put(8, new double[] {0.20702671, 0.92864654, 0.32721202, -0.09047503, 31.61484949}); + data.put(9, new double[] {-0.37890345, -0.04846179, -0.84122753, -1.14667474, -124.92598583}); + + DatasetBuilder<Integer, double[]> datasetBuilder = new LocalDatasetBuilder<>(data, 1); + + try (LSQROnHeap<Integer, double[]> lsqr = new LSQROnHeap<>( + datasetBuilder, + new LinSysPartitionDataBuilderOnHeap<>( + (k, v) -> Arrays.copyOf(v, v.length - 1), + (k, v) -> v[4], + 4 + ) + )) { + LSQRResult res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null); + + assertArrayEquals(new double[]{72.26948107, 15.95144674, 24.07403921, 66.73038781}, res.getX(), 1e-6); + } + } +}
http://git-wip-us.apache.org/repos/asf/ignite/blob/2f330a1c/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java index 5c79c8f..82b3a1b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java @@ -21,6 +21,7 @@ import org.apache.ignite.ml.regressions.linear.BlockDistributedLinearRegressionQ import org.apache.ignite.ml.regressions.linear.BlockDistributedLinearRegressionSGDTrainerTest; import org.apache.ignite.ml.regressions.linear.DistributedLinearRegressionQRTrainerTest; import org.apache.ignite.ml.regressions.linear.DistributedLinearRegressionSGDTrainerTest; +import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainerTest; import org.apache.ignite.ml.regressions.linear.LinearRegressionModelTest; import org.apache.ignite.ml.regressions.linear.LocalLinearRegressionQRTrainerTest; import org.apache.ignite.ml.regressions.linear.LocalLinearRegressionSGDTrainerTest; @@ -38,7 +39,8 @@ import org.junit.runners.Suite; DistributedLinearRegressionQRTrainerTest.class, DistributedLinearRegressionSGDTrainerTest.class, BlockDistributedLinearRegressionQRTrainerTest.class, - BlockDistributedLinearRegressionSGDTrainerTest.class + BlockDistributedLinearRegressionSGDTrainerTest.class, + LinearRegressionLSQRTrainerTest.class }) public class RegressionsTestSuite { // No-op. http://git-wip-us.apache.org/repos/asf/ignite/blob/2f330a1c/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java new file mode 100644 index 0000000..3bb3ee7 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions.linear; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link LinearRegressionLSQRTrainer}. + */ +@RunWith(Parameterized.class) +public class LinearRegressionLSQRTrainerTest { + /** Parameters. */ + @Parameterized.Parameters(name = "Data divided on {0} partitions") + public static Iterable<Integer[]> data() { + return Arrays.asList( + new Integer[] {1}, + new Integer[] {2}, + new Integer[] {3}, + new Integer[] {5}, + new Integer[] {7}, + new Integer[] {100}, + new Integer[] {1000} + ); + } + + /** Number of partitions. */ + @Parameterized.Parameter + public int parts; + + /** + * Tests {@code fit()} method on a simple small dataset. + */ + @Test + public void testSmallDataFit() { + Map<Integer, double[]> data = new HashMap<>(); + data.put(0, new double[] {-1.0915526, 1.81983527, -0.91409478, 0.70890712, -24.55724107}); + data.put(1, new double[] {-0.61072904, 0.37545517, 0.21705352, 0.09516495, -26.57226867}); + data.put(2, new double[] {0.05485406, 0.88219898, -0.80584547, 0.94668307, 61.80919728}); + data.put(3, new double[] {-0.24835094, -0.34000053, -1.69984651, -1.45902635, -161.65525991}); + data.put(4, new double[] {0.63675392, 0.31675535, 0.38837437, -1.1221971, -14.46432611}); + data.put(5, new double[] {0.14194017, 2.18158997, -0.28397346, -0.62090588, -3.2122197}); + data.put(6, new double[] {-0.53487507, 1.4454797, 0.21570443, -0.54161422, -46.5469012}); + data.put(7, new double[] {-1.58812173, -0.73216803, -2.15670676, -1.03195988, -247.23559889}); + data.put(8, new double[] {0.20702671, 0.92864654, 0.32721202, -0.09047503, 31.61484949}); + data.put(9, new double[] {-0.37890345, -0.04846179, -0.84122753, -1.14667474, -124.92598583}); + + LinearRegressionLSQRTrainer<Integer, double[]> trainer = new LinearRegressionLSQRTrainer<>(); + + LinearRegressionModel mdl = trainer.fit( + new LocalDatasetBuilder<>(data, parts), + (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), + (k, v) -> v[4], + 4 + ); + + assertArrayEquals( + new double[]{72.26948107, 15.95144674, 24.07403921, 66.73038781}, + mdl.getWeights().getStorage().data(), + 1e-6 + ); + + assertEquals(2.8421709430404007e-14, mdl.getIntercept(), 1e-6); + } + + /** + * Tests {@code fit()} method on a big (100000 x 100) dataset. + */ + @Test + public void testBigDataFit() { + Random rnd = new Random(0); + Map<Integer, double[]> data = new HashMap<>(); + double[] coef = new double[100]; + double intercept = rnd.nextDouble() * 10; + + for (int i = 0; i < 100000; i++) { + double[] x = new double[coef.length + 1]; + + for (int j = 0; j < coef.length; j++) + x[j] = rnd.nextDouble() * 10; + + x[coef.length] = intercept; + + data.put(i, x); + } + + LinearRegressionLSQRTrainer<Integer, double[]> trainer = new LinearRegressionLSQRTrainer<>(); + + LinearRegressionModel mdl = trainer.fit( + new LocalDatasetBuilder<>(data, parts), + (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), + (k, v) -> v[coef.length], + coef.length + ); + + assertArrayEquals(coef, mdl.getWeights().getStorage().data(), 1e-6); + + assertEquals(intercept, mdl.getIntercept(), 1e-6); + } +}