http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTest.java deleted file mode 100644 index 2774028..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTest.java +++ /dev/null @@ -1,820 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions; - -import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.math.Matrix; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException; -import org.apache.ignite.ml.math.exceptions.NullArgumentException; -import org.apache.ignite.ml.math.exceptions.SingularMatrixException; -import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.apache.ignite.ml.math.util.MatrixUtil; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -/** - * Tests for {@link OLSMultipleLinearRegression}. 
- */ -public class OLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegressionTest { - /** */ - private double[] y; - - /** */ - private double[][] x; - - /** */ - @Before - @Override public void setUp() { - y = new double[] {11.0, 12.0, 13.0, 14.0, 15.0, 16.0}; - x = new double[6][]; - x[0] = new double[] {0, 0, 0, 0, 0}; - x[1] = new double[] {2.0, 0, 0, 0, 0}; - x[2] = new double[] {0, 3.0, 0, 0, 0}; - x[3] = new double[] {0, 0, 4.0, 0, 0}; - x[4] = new double[] {0, 0, 0, 5.0, 0}; - x[5] = new double[] {0, 0, 0, 0, 6.0}; - super.setUp(); - } - - /** */ - @Override protected OLSMultipleLinearRegression createRegression() { - OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression(); - regression.newSampleData(new DenseLocalOnHeapVector(y), new DenseLocalOnHeapMatrix(x)); - return regression; - } - - /** */ - @Override protected int getNumberOfRegressors() { - return x[0].length + 1; - } - - /** */ - @Override protected int getSampleSize() { - return y.length; - } - - /** */ - @Test(expected = MathIllegalArgumentException.class) - public void cannotAddSampleDataWithSizeMismatch() { - double[] y = new double[] {1.0, 2.0}; - double[][] x = new double[1][]; - x[0] = new double[] {1.0, 0}; - createRegression().newSampleData(new DenseLocalOnHeapVector(y), new DenseLocalOnHeapMatrix(x)); - } - - /** */ - @Test - public void testPerfectFit() { - double[] betaHat = regression.estimateRegressionParameters(); - TestUtils.assertEquals(new double[] {11.0, 1.0 / 2.0, 2.0 / 3.0, 3.0 / 4.0, 4.0 / 5.0, 5.0 / 6.0}, - betaHat, - 1e-13); - double[] residuals = regression.estimateResiduals(); - TestUtils.assertEquals(new double[] {0d, 0d, 0d, 0d, 0d, 0d}, residuals, - 1e-13); - Matrix errors = regression.estimateRegressionParametersVariance(); - final double[] s = {1.0, -1.0 / 2.0, -1.0 / 3.0, -1.0 / 4.0, -1.0 / 5.0, -1.0 / 6.0}; - Matrix refVar = new DenseLocalOnHeapMatrix(s.length, s.length); - for (int i = 0; i < refVar.rowSize(); i++) - for (int j = 0; j < refVar.columnSize(); j++) { - if (i == 0) { - refVar.setX(i, j, s[j]); - continue; - } - double x = s[i] * s[j]; - refVar.setX(i, j, (i == j) ? 2 * x : x); - } - Assert.assertEquals(0.0, - TestUtils.maximumAbsoluteRowSum(errors.minus(refVar)), - 5.0e-16 * TestUtils.maximumAbsoluteRowSum(refVar)); - Assert.assertEquals(1, ((OLSMultipleLinearRegression)regression).calculateRSquared(), 1E-12); - } - - /** - * Test Longley dataset against certified values provided by NIST. - * Data Source: J. Longley (1967) "An Appraisal of Least Squares - * Programs for the Electronic Computer from the Point of View of the User" - * Journal of the American Statistical Association, vol. 62. September, - * pp. 819-841. 
- * - * Certified values (and data) are from NIST: - * http://www.itl.nist.gov/div898/strd/lls/data/LINKS/DATA/Longley.dat - */ - @Test - public void testLongly() { - // Y values are first, then independent vars - // Each row is one observation - double[] design = new double[] { - 60323, 83.0, 234289, 2356, 1590, 107608, 1947, - 61122, 88.5, 259426, 2325, 1456, 108632, 1948, - 60171, 88.2, 258054, 3682, 1616, 109773, 1949, - 61187, 89.5, 284599, 3351, 1650, 110929, 1950, - 63221, 96.2, 328975, 2099, 3099, 112075, 1951, - 63639, 98.1, 346999, 1932, 3594, 113270, 1952, - 64989, 99.0, 365385, 1870, 3547, 115094, 1953, - 63761, 100.0, 363112, 3578, 3350, 116219, 1954, - 66019, 101.2, 397469, 2904, 3048, 117388, 1955, - 67857, 104.6, 419180, 2822, 2857, 118734, 1956, - 68169, 108.4, 442769, 2936, 2798, 120445, 1957, - 66513, 110.8, 444546, 4681, 2637, 121950, 1958, - 68655, 112.6, 482704, 3813, 2552, 123366, 1959, - 69564, 114.2, 502601, 3931, 2514, 125368, 1960, - 69331, 115.7, 518173, 4806, 2572, 127852, 1961, - 70551, 116.9, 554894, 4007, 2827, 130081, 1962 - }; - - final int nobs = 16; - final int nvars = 6; - - // Estimate the model - OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(); - mdl.newSampleData(design, nobs, nvars, new DenseLocalOnHeapMatrix()); - - // Check expected beta values from NIST - double[] betaHat = mdl.estimateRegressionParameters(); - TestUtils.assertEquals(betaHat, - new double[] { - -3482258.63459582, 15.0618722713733, - -0.358191792925910E-01, -2.02022980381683, - -1.03322686717359, -0.511041056535807E-01, - 1829.15146461355}, 2E-6); // - - // Check expected residuals from R - double[] residuals = mdl.estimateResiduals(); - TestUtils.assertEquals(residuals, new double[] { - 267.340029759711, -94.0139423988359, 46.28716775752924, - -410.114621930906, 309.7145907602313, -249.3112153297231, - -164.0489563956039, -13.18035686637081, 14.30477260005235, - 455.394094551857, -17.26892711483297, -39.0550425226967, - -155.5499735953195, -85.6713080421283, 341.9315139607727, - -206.7578251937366}, - 1E-7); - - // Check standard errors from NIST - double[] errors = mdl.estimateRegressionParametersStandardErrors(); - TestUtils.assertEquals(new double[] { - 890420.383607373, - 84.9149257747669, - 0.334910077722432E-01, - 0.488399681651699, - 0.214274163161675, - 0.226073200069370, - 455.478499142212}, errors, 1E-6); - - // Check regression standard error against R - Assert.assertEquals(304.8540735619638, mdl.estimateRegressionStandardError(), 1E-8); - - // Check R-Square statistics against R - Assert.assertEquals(0.995479004577296, mdl.calculateRSquared(), 1E-12); - Assert.assertEquals(0.992465007628826, mdl.calculateAdjustedRSquared(), 1E-12); - - // TODO: IGNITE-5826, uncomment. 
- // checkVarianceConsistency(model); - - // Estimate model without intercept - mdl.setNoIntercept(true); - mdl.newSampleData(design, nobs, nvars, new DenseLocalOnHeapMatrix()); - - // Check expected beta values from R - betaHat = mdl.estimateRegressionParameters(); - TestUtils.assertEquals(betaHat, - new double[] { - -52.99357013868291, 0.07107319907358, - -0.42346585566399, -0.57256866841929, - -0.41420358884978, 48.41786562001326}, 1E-8); - - // Check standard errors from R - errors = mdl.estimateRegressionParametersStandardErrors(); - TestUtils.assertEquals(new double[] { - 129.54486693117232, 0.03016640003786, - 0.41773654056612, 0.27899087467676, 0.32128496193363, - 17.68948737819961}, errors, 1E-11); - - // Check expected residuals from R - residuals = mdl.estimateResiduals(); - TestUtils.assertEquals(residuals, new double[] { - 279.90274927293092, -130.32465380836874, 90.73228661967445, -401.31252201634948, - -440.46768772620027, -543.54512853774793, 201.32111639536299, 215.90889365977932, - 73.09368242049943, 913.21694494481869, 424.82484953610174, -8.56475876776709, - -361.32974610842876, 27.34560497213464, 151.28955976355002, -492.49937355336846}, - 1E-8); - - // Check regression standard error against R - Assert.assertEquals(475.1655079819517, mdl.estimateRegressionStandardError(), 1E-10); - - // Check R-Square statistics against R - Assert.assertEquals(0.9999670130706, mdl.calculateRSquared(), 1E-12); - Assert.assertEquals(0.999947220913, mdl.calculateAdjustedRSquared(), 1E-12); - - } - - /** - * Test R Swiss fertility dataset against R. - * Data Source: R datasets package - */ - @Test - public void testSwissFertility() { - double[] design = new double[] { - 80.2, 17.0, 15, 12, 9.96, - 83.1, 45.1, 6, 9, 84.84, - 92.5, 39.7, 5, 5, 93.40, - 85.8, 36.5, 12, 7, 33.77, - 76.9, 43.5, 17, 15, 5.16, - 76.1, 35.3, 9, 7, 90.57, - 83.8, 70.2, 16, 7, 92.85, - 92.4, 67.8, 14, 8, 97.16, - 82.4, 53.3, 12, 7, 97.67, - 82.9, 45.2, 16, 13, 91.38, - 87.1, 64.5, 14, 6, 98.61, - 64.1, 62.0, 21, 12, 8.52, - 66.9, 67.5, 14, 7, 2.27, - 68.9, 60.7, 19, 12, 4.43, - 61.7, 69.3, 22, 5, 2.82, - 68.3, 72.6, 18, 2, 24.20, - 71.7, 34.0, 17, 8, 3.30, - 55.7, 19.4, 26, 28, 12.11, - 54.3, 15.2, 31, 20, 2.15, - 65.1, 73.0, 19, 9, 2.84, - 65.5, 59.8, 22, 10, 5.23, - 65.0, 55.1, 14, 3, 4.52, - 56.6, 50.9, 22, 12, 15.14, - 57.4, 54.1, 20, 6, 4.20, - 72.5, 71.2, 12, 1, 2.40, - 74.2, 58.1, 14, 8, 5.23, - 72.0, 63.5, 6, 3, 2.56, - 60.5, 60.8, 16, 10, 7.72, - 58.3, 26.8, 25, 19, 18.46, - 65.4, 49.5, 15, 8, 6.10, - 75.5, 85.9, 3, 2, 99.71, - 69.3, 84.9, 7, 6, 99.68, - 77.3, 89.7, 5, 2, 100.00, - 70.5, 78.2, 12, 6, 98.96, - 79.4, 64.9, 7, 3, 98.22, - 65.0, 75.9, 9, 9, 99.06, - 92.2, 84.6, 3, 3, 99.46, - 79.3, 63.1, 13, 13, 96.83, - 70.4, 38.4, 26, 12, 5.62, - 65.7, 7.7, 29, 11, 13.79, - 72.7, 16.7, 22, 13, 11.22, - 64.4, 17.6, 35, 32, 16.92, - 77.6, 37.6, 15, 7, 4.97, - 67.6, 18.7, 25, 7, 8.65, - 35.0, 1.2, 37, 53, 42.34, - 44.7, 46.6, 16, 29, 50.43, - 42.8, 27.7, 22, 29, 58.33 - }; - - final int nobs = 47; - final int nvars = 4; - - // Estimate the model - OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(); - mdl.newSampleData(design, nobs, nvars, new DenseLocalOnHeapMatrix()); - - // Check expected beta values from R - double[] betaHat = mdl.estimateRegressionParameters(); - TestUtils.assertEquals(betaHat, - new double[] { - 91.05542390271397, - -0.22064551045715, - -0.26058239824328, - -0.96161238456030, - 0.12441843147162}, 1E-12); - - // Check expected residuals from R - double[] residuals = 
mdl.estimateResiduals(); - TestUtils.assertEquals(residuals, new double[] { - 7.1044267859730512, 1.6580347433531366, - 4.6944952770029644, 8.4548022690166160, 13.6547432343186212, - -9.3586864458500774, 7.5822446330520386, 15.5568995563859289, - 0.8113090736598980, 7.1186762732484308, 7.4251378771228724, - 2.6761316873234109, 0.8351584810309354, 7.1769991119615177, - -3.8746753206299553, -3.1337779476387251, -0.1412575244091504, - 1.1186809170469780, -6.3588097346816594, 3.4039270429434074, - 2.3374058329820175, -7.9272368576900503, -7.8361010968497959, - -11.2597369269357070, 0.9445333697827101, 6.6544245101380328, - -0.9146136301118665, -4.3152449403848570, -4.3536932047009183, - -3.8907885169304661, -6.3027643926302188, -7.8308982189289091, - -3.1792280015332750, -6.7167298771158226, -4.8469946718041754, - -10.6335664353633685, 11.1031134362036958, 6.0084032641811733, - 5.4326230830188482, -7.2375578629692230, 2.1671550814448222, - 15.0147574652763112, 4.8625103516321015, -7.1597256413907706, - -0.4515205619767598, -10.2916870903837587, -15.7812984571900063}, - 1E-12); - - // Check standard errors from R - double[] errors = mdl.estimateRegressionParametersStandardErrors(); - TestUtils.assertEquals(new double[] { - 6.94881329475087, - 0.07360008972340, - 0.27410957467466, - 0.19454551679325, - 0.03726654773803}, errors, 1E-10); - - // Check regression standard error against R - Assert.assertEquals(7.73642194433223, mdl.estimateRegressionStandardError(), 1E-12); - - // Check R-Square statistics against R - Assert.assertEquals(0.649789742860228, mdl.calculateRSquared(), 1E-12); - Assert.assertEquals(0.6164363850373927, mdl.calculateAdjustedRSquared(), 1E-12); - - // TODO: IGNITE-5826, uncomment. - // checkVarianceConsistency(model); - - // Estimate the model with no intercept - mdl = new OLSMultipleLinearRegression(); - mdl.setNoIntercept(true); - mdl.newSampleData(design, nobs, nvars, new DenseLocalOnHeapMatrix()); - - // Check expected beta values from R - betaHat = mdl.estimateRegressionParameters(); - TestUtils.assertEquals(betaHat, - new double[] { - 0.52191832900513, - 2.36588087917963, - -0.94770353802795, - 0.30851985863609}, 1E-12); - - // Check expected residuals from R - residuals = mdl.estimateResiduals(); - TestUtils.assertEquals(residuals, new double[] { - 44.138759883538249, 27.720705122356215, 35.873200836126799, - 34.574619581211977, 26.600168342080213, 15.074636243026923, -12.704904871199814, - 1.497443824078134, 2.691972687079431, 5.582798774291231, -4.422986561283165, - -9.198581600334345, 4.481765170730647, 2.273520207553216, -22.649827853221336, - -17.747900013943308, 20.298314638496436, 6.861405135329779, -8.684712790954924, - -10.298639278062371, -9.896618896845819, 4.568568616351242, -15.313570491727944, - -13.762961360873966, 7.156100301980509, 16.722282219843990, 26.716200609071898, - -1.991466398777079, -2.523342564719335, 9.776486693095093, -5.297535127628603, - -16.639070567471094, -10.302057295211819, -23.549487860816846, 1.506624392156384, - -17.939174438345930, 13.105792202765040, -1.943329906928462, -1.516005841666695, - -0.759066561832886, 20.793137744128977, -2.485236153005426, 27.588238710486976, - 2.658333257106881, -15.998337823623046, -5.550742066720694, -14.219077806826615}, - 1E-12); - - // Check standard errors from R - errors = mdl.estimateRegressionParametersStandardErrors(); - TestUtils.assertEquals(new double[] { - 0.10470063765677, 0.41684100584290, - 0.43370143099691, 0.07694953606522}, errors, 1E-10); - - // Check regression standard error 
against R - Assert.assertEquals(17.24710630547, mdl.estimateRegressionStandardError(), 1E-10); - - // Check R-Square statistics against R - Assert.assertEquals(0.946350722085, mdl.calculateRSquared(), 1E-12); - Assert.assertEquals(0.9413600915813, mdl.calculateAdjustedRSquared(), 1E-12); - } - - /** - * Test hat matrix computation - */ - @Test - public void testHat() { - - /* - * This example is from "The Hat Matrix in Regression and ANOVA", - * David C. Hoaglin and Roy E. Welsch, - * The American Statistician, Vol. 32, No. 1 (Feb., 1978), pp. 17-22. - * - */ - double[] design = new double[] { - 11.14, .499, 11.1, - 12.74, .558, 8.9, - 13.13, .604, 8.8, - 11.51, .441, 8.9, - 12.38, .550, 8.8, - 12.60, .528, 9.9, - 11.13, .418, 10.7, - 11.7, .480, 10.5, - 11.02, .406, 10.5, - 11.41, .467, 10.7 - }; - - int nobs = 10; - int nvars = 2; - - // Estimate the model - OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(); - mdl.newSampleData(design, nobs, nvars, new DenseLocalOnHeapMatrix()); - - Matrix hat = mdl.calculateHat(); - - - // Reference data is upper half of symmetric hat matrix - double[] refData = new double[] { - .418, -.002, .079, -.274, -.046, .181, .128, .222, .050, .242, - .242, .292, .136, .243, .128, -.041, .033, -.035, .004, - .417, -.019, .273, .187, -.126, .044, -.153, .004, - .604, .197, -.038, .168, -.022, .275, -.028, - .252, .111, -.030, .019, -.010, -.010, - .148, .042, .117, .012, .111, - .262, .145, .277, .174, - .154, .120, .168, - .315, .148, - .187 - }; - - // Check against reference data and verify symmetry - int k = 0; - for (int i = 0; i < 10; i++) { - for (int j = i; j < 10; j++) { - Assert.assertEquals(refData[k], hat.getX(i, j), 10e-3); - Assert.assertEquals(hat.getX(i, j), hat.getX(j, i), 10e-12); - k++; - } - } - - /* - * Verify that residuals computed using the hat matrix are close to - * what we get from direct computation, i.e. r = (I - H) y - */ - double[] residuals = mdl.estimateResiduals(); - Matrix id = MatrixUtil.identityLike(hat, 10); - double[] hatResiduals = id.minus(hat).times(mdl.getY()).getStorage().data(); - TestUtils.assertEquals(residuals, hatResiduals, 10e-12); - } - - /** - * test calculateYVariance - */ - @Test - public void testYVariance() { - // assumes: y = new double[]{11.0, 12.0, 13.0, 14.0, 15.0, 16.0}; - OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(); - mdl.newSampleData(new DenseLocalOnHeapVector(y), new DenseLocalOnHeapMatrix(x)); - TestUtils.assertEquals(mdl.calculateYVariance(), 3.5, 0); - } - - /** - * Verifies that setting X and Y separately has the same effect as newSample(X,Y). 
- */ - @Test - public void testNewSample2() { - double[] y = new double[] {1, 2, 3, 4}; - double[][] x = new double[][] { - {19, 22, 33}, - {20, 30, 40}, - {25, 35, 45}, - {27, 37, 47} - }; - OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression(); - regression.newSampleData(new DenseLocalOnHeapVector(y), new DenseLocalOnHeapMatrix(x)); - Matrix combinedX = regression.getX().copy(); - Vector combinedY = regression.getY().copy(); - regression.newXSampleData(new DenseLocalOnHeapMatrix(x)); - regression.newYSampleData(new DenseLocalOnHeapVector(y)); - Assert.assertEquals(combinedX, regression.getX()); - Assert.assertEquals(combinedY, regression.getY()); - - // No intercept - regression.setNoIntercept(true); - regression.newSampleData(new DenseLocalOnHeapVector(y), new DenseLocalOnHeapMatrix(x)); - combinedX = regression.getX().copy(); - combinedY = regression.getY().copy(); - regression.newXSampleData(new DenseLocalOnHeapMatrix(x)); - regression.newYSampleData(new DenseLocalOnHeapVector(y)); - Assert.assertEquals(combinedX, regression.getX()); - Assert.assertEquals(combinedY, regression.getY()); - } - - /** */ - @Test(expected = NullArgumentException.class) - public void testNewSampleDataYNull() { - createRegression().newSampleData(null, new DenseLocalOnHeapMatrix(new double[][] {{1}})); - } - - /** */ - @Test(expected = NullArgumentException.class) - public void testNewSampleDataXNull() { - createRegression().newSampleData(new DenseLocalOnHeapVector(new double[] {}), null); - } - - /** - * This is a test based on the Wampler1 data set - * http://www.itl.nist.gov/div898/strd/lls/data/Wampler1.shtml - */ - @Test - public void testWampler1() { - double[] data = new double[] { - 1, 0, - 6, 1, - 63, 2, - 364, 3, - 1365, 4, - 3906, 5, - 9331, 6, - 19608, 7, - 37449, 8, - 66430, 9, - 111111, 10, - 177156, 11, - 271453, 12, - 402234, 13, - 579195, 14, - 813616, 15, - 1118481, 16, - 1508598, 17, - 2000719, 18, - 2613660, 19, - 3368421, 20}; - OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(); - - final int nvars = 5; - final int nobs = 21; - double[] tmp = new double[(nvars + 1) * nobs]; - int off = 0; - int off2 = 0; - for (int i = 0; i < nobs; i++) { - tmp[off2] = data[off]; - tmp[off2 + 1] = data[off + 1]; - tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1]; - tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2]; - tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3]; - tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4]; - off2 += (nvars + 1); - off += 2; - } - mdl.newSampleData(tmp, nobs, nvars, new DenseLocalOnHeapMatrix()); - double[] betaHat = mdl.estimateRegressionParameters(); - TestUtils.assertEquals(betaHat, - new double[] { - 1.0, - 1.0, 1.0, - 1.0, 1.0, - 1.0}, 1E-8); - - double[] se = mdl.estimateRegressionParametersStandardErrors(); - TestUtils.assertEquals(se, - new double[] { - 0.0, - 0.0, 0.0, - 0.0, 0.0, - 0.0}, 1E-8); - - TestUtils.assertEquals(1.0, mdl.calculateRSquared(), 1.0e-10); - TestUtils.assertEquals(0, mdl.estimateErrorVariance(), 1.0e-7); - TestUtils.assertEquals(0.00, mdl.calculateResidualSumOfSquares(), 1.0e-6); - } - - /** - * This is a test based on the Wampler2 data set - * http://www.itl.nist.gov/div898/strd/lls/data/Wampler2.shtml - */ - @Test - public void testWampler2() { - double[] data = new double[] { - 1.00000, 0, - 1.11111, 1, - 1.24992, 2, - 1.42753, 3, - 1.65984, 4, - 1.96875, 5, - 2.38336, 6, - 2.94117, 7, - 3.68928, 8, - 4.68559, 9, - 6.00000, 10, - 7.71561, 11, - 9.92992, 12, - 12.75603, 13, - 16.32384, 14, - 20.78125, 15, - 26.29536, 16, 
- 33.05367, 17, - 41.26528, 18, - 51.16209, 19, - 63.00000, 20}; - OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(); - - final int nvars = 5; - final int nobs = 21; - double[] tmp = new double[(nvars + 1) * nobs]; - int off = 0; - int off2 = 0; - for (int i = 0; i < nobs; i++) { - tmp[off2] = data[off]; - tmp[off2 + 1] = data[off + 1]; - tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1]; - tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2]; - tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3]; - tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4]; - off2 += (nvars + 1); - off += 2; - } - mdl.newSampleData(tmp, nobs, nvars, new DenseLocalOnHeapMatrix()); - double[] betaHat = mdl.estimateRegressionParameters(); - TestUtils.assertEquals(betaHat, - new double[] { - 1.0, - 1.0e-1, - 1.0e-2, - 1.0e-3, 1.0e-4, - 1.0e-5}, 1E-8); - - double[] se = mdl.estimateRegressionParametersStandardErrors(); - TestUtils.assertEquals(se, - new double[] { - 0.0, - 0.0, 0.0, - 0.0, 0.0, - 0.0}, 1E-8); - TestUtils.assertEquals(1.0, mdl.calculateRSquared(), 1.0e-10); - TestUtils.assertEquals(0, mdl.estimateErrorVariance(), 1.0e-7); - TestUtils.assertEquals(0.00, mdl.calculateResidualSumOfSquares(), 1.0e-6); - } - - /** - * This is a test based on the Wampler3 data set - * http://www.itl.nist.gov/div898/strd/lls/data/Wampler3.shtml - */ - @Test - public void testWampler3() { - double[] data = new double[] { - 760, 0, - -2042, 1, - 2111, 2, - -1684, 3, - 3888, 4, - 1858, 5, - 11379, 6, - 17560, 7, - 39287, 8, - 64382, 9, - 113159, 10, - 175108, 11, - 273291, 12, - 400186, 13, - 581243, 14, - 811568, 15, - 1121004, 16, - 1506550, 17, - 2002767, 18, - 2611612, 19, - 3369180, 20}; - - OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(); - final int nvars = 5; - final int nobs = 21; - double[] tmp = new double[(nvars + 1) * nobs]; - int off = 0; - int off2 = 0; - for (int i = 0; i < nobs; i++) { - tmp[off2] = data[off]; - tmp[off2 + 1] = data[off + 1]; - tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1]; - tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2]; - tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3]; - tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4]; - off2 += (nvars + 1); - off += 2; - } - mdl.newSampleData(tmp, nobs, nvars, new DenseLocalOnHeapMatrix()); - double[] betaHat = mdl.estimateRegressionParameters(); - TestUtils.assertEquals(betaHat, - new double[] { - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0}, 1E-8); - - double[] se = mdl.estimateRegressionParametersStandardErrors(); - TestUtils.assertEquals(se, - new double[] { - 2152.32624678170, - 2363.55173469681, 779.343524331583, - 101.475507550350, 5.64566512170752, - 0.112324854679312}, 1E-8); // - - TestUtils.assertEquals(.999995559025820, mdl.calculateRSquared(), 1.0e-10); - TestUtils.assertEquals(5570284.53333333, mdl.estimateErrorVariance(), 1.0e-6); - TestUtils.assertEquals(83554268.0000000, mdl.calculateResidualSumOfSquares(), 1.0e-5); - } - - /** - * This is a test based on the Wampler4 data set - * http://www.itl.nist.gov/div898/strd/lls/data/Wampler4.shtml - */ - @Test - public void testWampler4() { - double[] data = new double[] { - 75901, 0, - -204794, 1, - 204863, 2, - -204436, 3, - 253665, 4, - -200894, 5, - 214131, 6, - -185192, 7, - 221249, 8, - -138370, 9, - 315911, 10, - -27644, 11, - 455253, 12, - 197434, 13, - 783995, 14, - 608816, 15, - 1370781, 16, - 1303798, 17, - 2205519, 18, - 2408860, 19, - 3444321, 20}; - - OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(); - final int nvars = 5; - final int nobs = 21; - double[] tmp 
= new double[(nvars + 1) * nobs]; - int off = 0; - int off2 = 0; - for (int i = 0; i < nobs; i++) { - tmp[off2] = data[off]; - tmp[off2 + 1] = data[off + 1]; - tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1]; - tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2]; - tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3]; - tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4]; - off2 += (nvars + 1); - off += 2; - } - mdl.newSampleData(tmp, nobs, nvars, new DenseLocalOnHeapMatrix()); - double[] betaHat = mdl.estimateRegressionParameters(); - TestUtils.assertEquals(betaHat, - new double[] { - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0}, 1E-6); - - double[] se = mdl.estimateRegressionParametersStandardErrors(); - TestUtils.assertEquals(se, - new double[] { - 215232.624678170, - 236355.173469681, 77934.3524331583, - 10147.5507550350, 564.566512170752, - 11.2324854679312}, 1E-8); - - TestUtils.assertEquals(.957478440825662, mdl.calculateRSquared(), 1.0e-10); - TestUtils.assertEquals(55702845333.3333, mdl.estimateErrorVariance(), 1.0e-4); - TestUtils.assertEquals(835542680000.000, mdl.calculateResidualSumOfSquares(), 1.0e-3); - } - - /** - * Anything requiring beta calculation should advertise SME. - */ - @Test(expected = SingularMatrixException.class) - public void testSingularCalculateBeta() { - OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(1e-15); - mdl.newSampleData(new double[] {1, 2, 3, 1, 2, 3, 1, 2, 3}, 3, 2, new DenseLocalOnHeapMatrix()); - mdl.calculateBeta(); - } - - /** */ - @Test(expected = NullPointerException.class) - public void testNoDataNPECalculateBeta() { - OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(); - mdl.calculateBeta(); - } - - /** */ - @Test(expected = NullPointerException.class) - public void testNoDataNPECalculateHat() { - OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(); - mdl.calculateHat(); - } - - /** */ - @Test(expected = NullPointerException.class) - public void testNoDataNPESSTO() { - OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(); - mdl.calculateTotalSumOfSquares(); - } - - /** */ - @Test(expected = MathIllegalArgumentException.class) - public void testMathIllegalArgumentException() { - OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression(); - mdl.validateSampleData(new DenseLocalOnHeapMatrix(1, 2), new DenseLocalOnHeapVector(1)); - } -}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java index be71934..5c79c8f 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java @@ -17,6 +17,13 @@ package org.apache.ignite.ml.regressions; +import org.apache.ignite.ml.regressions.linear.BlockDistributedLinearRegressionQRTrainerTest; +import org.apache.ignite.ml.regressions.linear.BlockDistributedLinearRegressionSGDTrainerTest; +import org.apache.ignite.ml.regressions.linear.DistributedLinearRegressionQRTrainerTest; +import org.apache.ignite.ml.regressions.linear.DistributedLinearRegressionSGDTrainerTest; +import org.apache.ignite.ml.regressions.linear.LinearRegressionModelTest; +import org.apache.ignite.ml.regressions.linear.LocalLinearRegressionQRTrainerTest; +import org.apache.ignite.ml.regressions.linear.LocalLinearRegressionSGDTrainerTest; import org.junit.runner.RunWith; import org.junit.runners.Suite; @@ -25,11 +32,14 @@ import org.junit.runners.Suite; */ @RunWith(Suite.class) @Suite.SuiteClasses({ - OLSMultipleLinearRegressionTest.class, - DistributedOLSMultipleLinearRegressionTest.class, - DistributedBlockOLSMultipleLinearRegressionTest.class, - OLSMultipleLinearRegressionModelTest.class + LinearRegressionModelTest.class, + LocalLinearRegressionQRTrainerTest.class, + LocalLinearRegressionSGDTrainerTest.class, + DistributedLinearRegressionQRTrainerTest.class, + DistributedLinearRegressionSGDTrainerTest.class, + BlockDistributedLinearRegressionQRTrainerTest.class, + BlockDistributedLinearRegressionSGDTrainerTest.class }) public class RegressionsTestSuite { // No-op. -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/ArtificialRegressionDatasets.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/ArtificialRegressionDatasets.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/ArtificialRegressionDatasets.java new file mode 100644 index 0000000..ed6bf36 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/ArtificialRegressionDatasets.java @@ -0,0 +1,404 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.regressions.linear; + +/** + * Artificial regression datasets to be used in regression trainers tests. These datasets were generated by scikit-learn + * tools, {@code sklearn.datasets.make_regression} procedure. + */ +public class ArtificialRegressionDatasets { + /** + * Artificial dataset with 10 observations described by 1 feature. + */ + public static final TestDataset regression10x1 = new TestDataset(new double[][] { + {1.97657990214, 0.197725444973}, + {-5.0835948878, -0.279921224228}, + {-5.09032600779, -0.352291245969}, + {9.67660993007, 0.755464872441}, + {4.95927629958, 0.451981771462}, + {29.2635107429, 2.2277440173}, + {-18.3122588459, -1.25363275369}, + {-3.61729307199, -0.273362913982}, + {-7.19042139249, -0.473846634967}, + {3.68008403347, 0.353883097536} + }, new double[] {13.554054703}, -0.808655936776); + + /** + * Artificial dataset with 10 observations described by 5 features. + */ + public static final TestDataset regression10x5 = new TestDataset(new double[][] { + {118.635647237, 0.687593385888, -1.18956185502, -0.305420702986, 1.98794097418, -0.776629036361}, + {-18.2808432286, -0.165921853684, -0.156162539573, 1.56284391134, -0.198876782109, -0.0921618505605}, + {22.6110523992, 0.0268106268606, 0.702141470035, -0.41503615392, -1.09726502337, 1.30830482813}, + {209.820435262, 0.379809113402, -0.192097238579, -1.27460497119, 2.48052002019, -0.574430888865}, + {-253.750024054, -1.48044570917, -0.331747484523, 0.387993627712, 0.372583756237, -2.27404065923}, + {-24.6467766166, -0.66991474156, 0.269042238935, -0.271412703096, -0.561166818525, 1.37067541854}, + {-311.903650717, 0.268274438122, -1.10491275353, -1.06738703543, -2.24387799735, -0.207431467989}, + {74.2055323536, -0.329489531894, -0.493350762533, -0.644851462227, 0.661220945573, 1.65950140864}, + {57.0312289904, -1.07266578457, 0.80375035572, -0.45207210139, 1.69314420969, -1.10526080856}, + {12.149399645, 1.46504629281, -1.05843246079, 0.266225365277, -0.0113100353869, -0.983495425471} + }, new double[] {99.8393653561, 82.4948224094, 20.2087724072, 97.3306384162, 55.7502297387}, 3.98444039189); + + /** + * Artificial dataset with 100 observations described by 5 features. 
+ */ + public static final TestDataset regression100x5 = new TestDataset(new double[][] { + {-44.2310642946, -0.0331360137605, -0.5290800706, -0.634340342338, -0.428433927151, 0.830582347183}, + {76.2539139721, -0.216200869652, 0.513212019048, -0.693404511747, 0.132995973133, 1.28470259833}, + {293.369799914, 2.90735870802, 0.457740818846, -0.490470696097, -0.442343455187, 0.584038258781}, + {124.258807314, 1.64158129148, 0.0616936820145, 1.24082841519, -1.20126518593, -0.542298907742}, + {13.6610807249, -1.10834821778, 0.545508208111, 1.81361288715, -0.786543112444, 0.250772626496}, + {101.924582305, -0.433526394969, 0.257594734335, 1.22333193911, 0.76626554927, -0.0400734567005}, + {25.5963186303, -0.202003301507, 0.717101151637, -0.486881225605, 1.15215024807, -0.921615554612}, + {75.7959681263, -0.604173187402, 0.0364386836472, 1.67544714536, 0.394743148877, 0.0237966550759}, + {-97.539357166, -0.774517689169, -0.0966902473883, -0.152250704254, -0.325472625458, 0.0720711851256}, + {0.394748999236, -0.559303402754, -0.0493339259273, -1.10840277768, -0.0800969523557, 1.80939282066}, + {-62.0138166431, 0.062614716778, -0.844143618016, 0.55269949861, -2.32580899335, 1.58020577369}, + {584.427692931, 2.13184767906, 1.22222461994, 1.71894070494, 2.69512281718, 0.294123497874}, + {-59.8323709765, 1.00006112818, -1.54481230765, -0.781282316493, 0.0255925284853, -0.0821173744608}, + {101.565711925, -0.38699836725, 1.06934591441, -0.260429311097, 1.02628949564, 0.0431473245174}, + {-141.592607814, 0.993279116267, -0.371768203378, -0.851483217286, -1.96241293548, -0.612279404296}, + {34.8038723379, -0.0182719243972, 0.306367604506, -0.650526589206, 1.30693112283, -0.587465952557}, + {-16.9554534069, -0.703006786668, -0.770718401931, 0.748423272307, 0.502544067819, 0.346625621533}, + {-76.2896177709, -0.16440174812, -1.77431555198, 0.195326723837, 2.01240994405, -1.19559207119}, + {-3.23827624818, -0.674138419631, -1.62238580284, 2.02235607862, 0.679194838679, 0.150203732584}, + {-21.962456854, -0.766271014206, 0.958599712131, -0.313045794728, 0.232655576106, -0.360950549871}, + {349.583669646, 1.75976166947, 1.47271612346, 0.0346005603489, 0.474907228495, 0.61379496381}, + {-418.397356757, -1.83395936566, -0.911702678716, -0.532478094882, -2.03835348133, -0.423005552518}, + {55.0298153952, -0.0301384716096, -0.0137929430966, -0.348583692759, 0.986486580719, 0.154436524434}, + {127.150063206, 1.92682560465, -0.434844790414, 0.1082898967, -0.00723338222402, -0.513199251824}, + {89.6172507626, 1.02463790902, 0.744369837717, 1.250323683, -1.58252612128, -0.588242778808}, + {92.5124829355, -0.403298547743, 0.0422774545428, -0.175000467434, 1.61110066857, 0.422330077287}, + {-303.040366788, 0.611569308879, -1.21926246291, -2.49250330276, -0.789166929605, -1.30166501196}, + {-17.4020602839, 1.72337202371, -1.83540537288, 0.731588761841, -0.338642535062, -1.11053518125}, + {114.918701324, 0.437385758628, 0.975885170381, 0.439444038872, 1.51666514156, -1.93095020264}, + {-8.43548064928, -0.799507968686, -0.00842968328782, -0.154994093964, 1.09169753491, -0.0114818657732}, + {109.209286025, 2.56472965015, -2.07047248035, -0.46764001177, 0.845267147375, -0.236767841427}, + {61.5259982971, -0.379391870148, -0.131017762354, -0.220275015864, 1.82097825699, -0.0568354876403}, + {-71.3872099588, 0.642138455414, -1.00242489879, 0.536780074488, 0.350977275771, -1.8204862883}, + {-21.2768078629, -0.454268998895, 0.0992324274219, 0.0363496803224, 0.281940751723, -0.198435570828}, + {-8.07838891387, -0.331642089041, 
-0.494067341253, 0.386035842816, -0.738221128298, 1.18236299649}, + {30.4818041751, 0.099206096537, 0.150688905006, 0.332932621949, 0.194845631964, -0.446717875795}, + {237.209150991, 1.12560447042, 0.448488431264, -0.724623711259, 0.401868257097, 1.67129001163}, + {185.172816475, 0.36594142556, -0.0796476435741, 0.473836257, 1.30890722633, 0.592415068693}, + {19.8830237044, 1.52497319332, 0.466906090264, -0.716635613964, -1.19532276745, -0.697663531684}, + {209.396793626, 0.368478789658, 0.699162303982, 1.96702434462, -0.815379139879, 0.863369634396}, + {-215.100514168, -1.83902416164, -1.14966820385, -1.01044860587, 1.76881340629, -0.32165916241}, + {-33.4687353426, -0.0451102002703, 0.642212950033, 0.580822065219, -1.02341504063, -0.781229325942}, + {150.251474823, 0.220170650298, 0.224858901011, 0.541299425328, 1.15151550963, 0.0329044069571}, + {92.2160506097, 1.86450932451, -0.991150940533, -1.49137866968, 1.02113774105, 0.0544762857136}, + {41.2138467595, -0.778892265105, 0.714957464344, 1.79833618993, -0.335322825621, -0.397548301803}, + {13.151262759, 0.301745607362, 0.129778280739, 0.260094818273, -0.10587841585, -0.599330307629}, + {-367.864703951, -1.68695981263, -0.611957677512, -0.0362971579679, -1.2169760515, -1.43224375134}, + {-57.218869838, 0.428806849751, 0.654302177028, -1.31651788496, 0.363857431276, -1.49953703016}, + {53.0877462955, -0.411907760185, -0.192634094071, -0.275879375023, 0.603562526571, 1.16508196734}, + {-8.11860742896, 1.00263982158, -0.157031169267, -1.11795623393, 0.35711440521, -0.851124640982}, + {-49.1878248403, -0.0253797866589, -0.574767070714, 0.200339045636, -0.0107042446803, -0.351288977927}, + {-73.8835407053, -2.07980276724, 1.12235566491, -0.917150593536, 0.741384768556, 0.56229424235}, + {143.163604045, 0.33627769945, 1.07948757447, 0.894869929963, 1.18688316974, -1.54722487849}, + {92.7045830908, 0.944091525689, 0.693296229491, 0.700097596814, -1.23666276942, -0.203890113084}, + {79.1878852355, -0.221973023853, -0.566066329011, 1.57683748648, 0.52854717911, 0.147924782476}, + {30.6547392801, -1.03466213359, 0.606784904328, -0.298096511956, 0.83332987683, 0.636339018254}, + {-329.128386019, -1.41363866598, -1.34966434823, -0.989010564149, 0.46889477248, -1.20493210784}, + {121.190205512, 0.0393914245697, 1.98392444232, -0.65310705226, -0.385899987099, 0.444982471471}, + {-97.0333075649, 0.264325871992, -0.43074811924, -1.14737761316, -0.453134140655, -0.038507405311}, + {158.273624516, 0.302255432981, -0.292046617818, 1.0704087606, 0.815965268115, 0.470631083546}, + {8.24795061818, -1.15155524496, 1.29538707184, -0.4650881541, 0.805123486308, -0.134706887329}, + {87.1140049059, -0.103540823781, -0.192259440773, 1.79648860085, -1.07525447993, 1.06985127941}, + {-25.1300772481, -0.97140742052, 0.033393948794, -0.698311192672, 0.74417168942, 0.752776770225}, + {-285.477057638, -0.480612406803, -1.46081500036, -1.92518386336, -0.426454066275, -0.0539099489597}, + {-65.1269988498, -1.22733468764, 0.121538452336, 0.752958777557, -0.40643211762, 0.257674949803}, + {-17.1813504942, 0.823753836891, 0.445142465255, 0.185644700144, -1.99733367514, -0.247899323048}, + {-46.7543447303, 0.183482778928, -0.934858705943, -1.21961947396, 0.460921844744, 0.571388077177}, + {-1.7536190499, -0.107517908181, 0.0334282610968, -0.556676121428, -0.485957577159, 0.943570398164}, + {-42.8460452689, 0.944999215632, 0.00530052154909, -0.348526283976, -1.724125354, -0.122649339813}, + {62.6291497267, 0.249619894002, 1.3139125969, -1.5644227783, 0.117605482783, 
0.304844650662}, + {97.4552176343, 1.59332799639, -1.17868305562, 1.02998378902, -0.31959491258, -0.183038322076}, + {-6.19358885758, 0.437951016253, 0.373339269494, -0.204072768495, 0.477969349931, -1.52176449389}, + {34.0350630099, 0.839319087287, -0.610157662489, 1.73881448393, -1.89200107709, 0.204946415522}, + {54.9790822536, -0.191792583114, 0.989791127554, -0.502154080064, 0.469939512389, -0.102304071079}, + {58.8272402843, 0.0769623906454, 0.501297284297, -0.410054999243, 0.595712387781, -0.0968329050729}, + {95.3620983209, 0.0661481959314, 0.0935137309086, 1.11823292347, -0.612960777903, 0.767865072757}, + {62.4278196648, 0.78350610065, -1.09977017652, 0.526824784479, 1.41310104196, -0.887902707319}, + {57.6298676729, 0.60084172954, -0.785932027202, 0.0271301584637, -0.134109499719, 0.877256170191}, + {5.14112905382, -0.738359365006, 1.40242539359, -0.852833010305, -0.68365080837, 0.88561193696}, + {11.6057244034, -0.958911227571, 1.15715937023, 1.20108425431, 0.882980929338, -1.77404120156}, + {-265.758185272, -1.2092434823, -0.0550151798639, 0.00703735243613, -1.01767244359, -1.40616581707}, + {180.625928828, -0.139091127126, 0.243250756129, 2.17509702585, -0.541735827898, 1.2109459934}, + {-183.604103216, -0.324555097769, -1.71317286749, 1.03645005723, 0.497569347608, -1.96688185911}, + {9.93237328848, 0.825483591345, 0.910287997312, -1.64938108528, 0.98964075968, -1.65748940528}, + {-88.6846949813, -0.0759295112746, -0.593311990101, -0.578711915019, 0.256298822361, -0.429322890198}, + {175.367391479, 0.9361754906, -0.0172852897292, 1.04078658833, 0.919566407184, -0.554923019093}, + {-175.538247146, -1.43498590417, 0.37233438556, -0.897205352198, -0.339309952316, -0.0321624527843}, + {-126.331680318, 0.160446617623, 0.816642363249, -1.39863371652, 0.199747744327, -2.13493607457}, + {116.677107593, 1.19300905847, -0.404409346893, 0.646338976096, -0.534204093869, 0.36692724765}, + {-181.675962893, -1.57613169533, -0.41549571451, -0.956673746013, 0.35723782515, 0.318317395128}, + {-55.1457877823, 0.63723030991, -0.324480386466, 0.296028333894, -1.68117515658, -0.131945601375}, + {25.2534791013, 0.594818219911, -0.0247380403547, -0.101492246071, -0.0745619242015, -0.370837128867}, + {63.6006283756, -1.53493473818, 0.946464097439, 0.637741397831, 0.938866921166, 0.54405291856}, + {-69.6245547661, 0.328482934094, -0.776881060846, -0.285133098443, -1.06107824512, 0.49952182341}, + {233.425957233, 3.10582399189, -0.0854710508706, 0.455873479133, -0.0974589364949, -1.18914783551}, + {-86.5564290626, -0.819839276484, 0.584745927593, -0.544737106102, -1.21927675581, 0.758502626434}, + {425.357285631, 1.70712253847, 1.19892647853, 1.60619661301, 0.36832665241, 0.880791322709}, + {111.797225426, 0.558940594145, -0.746492420236, 1.90172101792, 0.853590062366, -0.867970723941}, + {-253.616801014, -0.426513440051, 0.0388582291888, -1.18576061365, -2.70895868242, 0.26982210287}, + {-394.801501024, -1.65087241498, 0.735525201393, -2.02413077052, -0.96492749037, -1.89014065613} + }, new double[] {93.3843533037, 72.3610889215, 57.5295295915, 63.7287541653, 65.2263084024}, 6.85683020686); + + /** + * Artificial dataset with 100 observations described by 10 features. 
+ */ + public static final TestDataset regression100x10 = new TestDataset(new double[][] { + {69.5794204114, -0.684238565877, 0.175665643732, 0.882115894035, 0.612844187624, + -0.685301720572, -0.8266500007, -0.0383407025118, 1.7105205222, 0.457436379836, -0.291563926494}, + {80.1390102826, -1.80708821811, 0.811271788195, 0.30248512861, 0.910658009566, + -1.61869762501, -0.148325085362, -0.0714164596509, 0.671646742271, 2.15160094956, -0.0495754979721}, + {-156.975447515, 0.170702943934, -0.973403372054, -0.093974528453, 1.54577255871, + -0.0969022857972, -1.10639617368, 1.51752480948, -2.86016865032, 1.24063030602, -0.521785751026}, + {-158.134931891, 0.0890071395055, -0.0811824442353, -0.737354274843, -1.7575255492, + 0.265777246641, 0.0745347238144, -0.457603542683, -1.37034043839, 1.86011799875, 0.651214189491}, + {-131.465820263, 0.0767565260375, 0.651724194978, 0.142113799753, 0.244367469855, + -0.334395162837, -0.069092305876, -0.691806779713, -1.28386786177, -1.43647491141, 0.00721053414234}, + {-125.468890054, 0.43361925912, -0.800231440065, -0.576001094593, 0.0783664516431, + -1.33613252233, -0.968385062126, -1.22077801286, 0.193456109638, -3.09372314386, 0.817979620215}, + {-44.1113403874, -0.595796803171, 1.29482131972, -0.784513985654, 0.364702038003, + -3.2452492093, -0.451605560847, 0.988546607514, 0.492096628873, -0.343018842342, -0.519231306954}, + {61.2269707872, -0.0289059337716, -1.00409238976, 0.329908621635, 1.41965097539, + 0.0395065997587, -0.477939549336, 0.842336765911, -0.808790019648, 1.70241718768, -0.117194118865}, + {301.434286126, 0.430005308515, 1.01290089725, -0.228221561554, 0.463405921629, + -0.602413489517, 1.13832440088, 0.930949226185, -0.196440161506, 1.46304624346, 1.23831509056}, + {-270.454814681, -1.43805412632, -0.256309572507, -0.358047601174, 0.265151660237, + 1.07087986377, -1.93784654681, -0.854440691754, 0.665691996289, -1.87508012738, -0.387092423365}, + {-97.6198688184, -1.67658167161, -0.170246709551, -2.26863722189, 0.280289356338, + -0.690038347855, -1.69282684019, 0.978606053022, 1.28237852256, -1.2941998486, 0.766405365374}, + {-29.5630902399, -1.75615633921, 0.633927486329, -1.24117311555, -0.15884687004, + 0.31296863712, -1.29513272039, 0.344090683606, 1.19598425093, -1.96195019104, 1.81415061059}, + {-130.896377427, 0.577719366939, -0.087267771748, -0.060088767013, 0.469803880788, + -1.03078212088, -1.41547398887, 1.38980586981, -0.37118000595, -1.81689513712, -0.3099432567}, + {79.6300698059, 1.23408625633, 1.06464588017, 1.23403332691, -1.10993859098, + 0.874825200577, 0.589337796957, -1.10266185141, 0.842960469618, -0.89231962021, 0.284074900504}, + {-154.712112815, -1.64474237898, -0.328581696933, 0.38834343178, 0.02682160335, + -0.251167527796, -0.199330632103, -0.0405837345525, -0.908200250794, -1.3283756975, 0.540894408264}, + {233.447381562, 0.395156450609, 0.156412599781, 0.126453148554, 2.40829068933, + 1.01623530754, -0.0856520211145, -0.874970377099, 0.280617145254, -0.307070438514, 0.4599616054}, + {209.012380432, -0.848646647675, 0.558383548084, -0.259628264419, 1.1624126549, + -0.0755949979572, -0.373930759448, 0.985903312667, 0.435839508011, -0.760916312668, 1.89847574116}, + {-39.8987262091, 0.176656582642, 0.508538223618, 0.995038391204, -2.08809409812, + 0.743926580134, 0.246007971514, -0.458288599906, -0.579976479473, 0.0591577146017, 1.64321662761}, + {222.078510236, -0.24031989218, -0.168104260522, -0.727838425954, 0.557181757624, + -0.164906646307, 2.01559331734, 0.897263594222, 0.0921535309562, 
0.351910490325, -0.018228500121}, + {-250.916272061, -2.71504637339, 0.498966191294, -3.16410707344, -0.842488891776, + 1.27425275951, 0.0141733666756, 0.695942743199, 0.0917995810179, -0.501447196978, -0.355738068451}, + {134.07259088, 0.0845637591619, 0.237410106679, -0.291458113729, 1.39418566986, + -1.18813057956, -0.683117067763, -0.518910379335, 1.35998426879, -1.28404562245, 0.489131754943}, + {104.988440209, 0.00770925058526, 0.47113239214, -0.606231247854, 0.310679840217, + 0.146297599928, 0.732013998647, -0.284544010865, 0.402622530153, -0.0217367745613, 0.0742970687987}, + {155.558071031, 1.11171654653, 0.726629222799, -0.195820863177, 0.801333855535, + 0.744034755544, 1.11377275513, -0.75673532139, -0.114117607244, -0.158966474923, -0.29701120385}, + {90.7600194013, -0.104364079622, -0.0165109945217, 0.933002972987, -1.80652594466, + -1.34760892883, -0.304511906801, 0.0584734540581, 1.5332169392, 0.478835797824, 1.71534051065}, + {-313.910553214, 0.149908925551, 0.232806828559, -0.0708920471592, -0.0649553559745, + 0.377753357707, -0.957292311668, 0.545360522582, -1.37905464371, -0.940702110994, -1.53620430047}, + {-80.9380113754, 0.135586606896, 0.95759558815, -1.36879020479, 0.735413996144, + 0.637984100201, -1.79563152885, 1.55025691631, 0.634702068786, -0.203690334141, -0.83954824721}, + {-244.336816695, -0.179127343947, -2.12396005014, -0.431179356484, -0.860562153749, + -1.10270688639, -0.986886012982, -0.945091656162, -0.445428453767, 1.32269756209, -0.223712672168}, + {123.069612745, 0.703857129626, 0.291605144784, 1.40233051946, 0.278603787802, + -0.693567967466, -0.15587953395, 2.10213915684, 0.130663329174, -0.393184478882, 0.0874812844555}, + {-148.274944223, 1.66294967732, 0.0830002694123, 0.32492930502, 1.11864359687, + -0.381901627785, -1.06367037132, -0.392583620174, -1.16283326187, 0.104931461025, -1.64719611405}, + {-82.0018788235, 0.497118817453, 0.731125358012, -0.00976413646786, -0.0178930713492, + -0.814978582886, 0.0602834712523, -0.661940479055, -0.957902899386, -1.34489251111, 0.22166518707}, + {-35.742996986, 0.0661349516701, -0.204314495629, 1.17101314753, -2.53846825562, + -0.560282479298, -0.393442894828, 0.988953809491, -0.911281277704, 0.86862242698, 2.59576940486}, + {-109.588885664, -0.0793151346628, -0.408962434518, -0.598817776528, 0.0277205469561, + 0.116291018958, 0.0280416838086, -0.72544170676, -0.669302814774, 0.0751898759816, -0.311002356179}, + {57.8285173441, 0.53753903532, 0.676340503752, -2.10608342721, 0.477714987751, + 0.465695114442, 0.245966562421, -1.05230350808, -0.309794163113, -1.12067331828, 1.07841453304}, + {204.660622582, -0.717565166685, 0.295179660279, -0.377579912697, 1.88425526905, + 0.251875238436, -0.900214103232, -1.02877401105, 0.291693915093, 1.24889067987, 1.78506220081}, + {350.949109103, 2.82276814452, -0.429358342127, 1.12140362367, 1.18120725208, + -1.63913834939, 1.61441562446, -0.364003766916, -0.258752942225, -0.808124680189, 0.556463488303}, + {170.960252153, 0.147245922081, 0.3257117575, 0.211749283649, -0.0150701808404, + -0.888523132148, 0.777862088798, 0.296729270892, -0.332927550718, 0.888968144245, 1.20913118467}, + {112.192270383, 0.129846138824, -0.934371449036, -0.595825303214, 1.74749214629, + -0.0500069421443, -0.161976298602, -2.54100791613, 1.99632530735, -0.0691582773758, -0.863939367415}, + {-56.7847711121, 0.0950532853751, -0.467349228201, -0.26457152362, -0.422134692317, + -0.0734763062127, 0.90128235602, -1.68470856275, -0.0699692697335, -0.463335845504, -0.301754321169}, + 
{-37.9223252258, -1.40835827778, 0.566142056244, -3.22393318933, 0.228823495106, + -1.8480727782, 0.129468321643, -1.77392686536, 0.0112549619662, 0.146433267822, 1.29379901303}, + {-59.7303066136, 0.835675535576, -0.552173157548, 1.90730898966, -0.520145317195, + 1.55174485912, -1.37531768692, -0.408165743742, 0.0939675842223, 0.318004128812, 0.324378038446}, + {-0.916090786983, 0.425763794043, -0.295541268984, -0.066619586336, 2.03494974978, + -0.197109278058, -0.823307883209, 0.895531446352, -0.276435938737, -1.54580056755, -0.820051830246}, + {-20.3601082842, 0.56420556369, 0.741234589387, -0.565853617392, -0.311399905686, + 2.24066463251, -0.071704904286, -1.22796531596, 0.186020404046, -0.786874824874, 0.23140277151}, + {-22.9342855182, -0.0682789648279, -1.30680909143, 0.0486490588348, 0.890275695028, + -0.257961411112, -0.381531755985, 1.56251482581, -2.11808219232, 0.741828675202, 0.696388901165}, + {-157.251026807, -2.3120966502, 0.183734662375, 1.02192264962, 0.591272941061, + -0.0132855098339, -1.02016546348, 1.19642432892, 0.867653154846, -1.37600041722, -1.08542822792}, + {-68.6110752055, -1.2429968179, -0.950064269349, -0.332379873336, 0.25793632341, + 0.145780713577, -0.512109283074, -0.477887632032, 0.448960776324, -0.190215737958, 0.219578347563}, + {-56.1204152481, -0.811729480846, -0.647410362207, 0.934547463984, -0.390943346216, + -0.409981308474, 0.0923465893049, 1.9281242912, -0.624713581674, -0.0599353282306, -0.0188591746808}, + {348.530651658, 2.51721790231, 0.7560998114, -2.69620396681, 0.5174276585, + 0.403570816695, 0.901648571306, 0.269313230294, 1.07811463589, 0.986649559679, 0.514710327657}, + {-105.719065924, 0.679016972998, 0.341319363316, -0.515209647377, 0.800000866847, + -0.795474442628, -0.866849274801, -1.32927961486, 0.17679343917, -1.93744422464, -0.476447619273}, + {-197.389429553, -1.98585668879, -0.962610549884, -2.48860863254, -0.545990524642, + -0.13005685654, -1.23413782366, 1.17443427507, 1.4785554038, -0.193717671824, -0.466403609229}, + {-23.9625285402, -0.392164367603, 1.07583388583, -0.412686712477, -0.89339030785, + -0.774862334739, -0.186491999529, -0.300162444329, 0.177377235999, 0.134038296039, 0.957945226616}, + {-91.145725943, -0.154640540119, 0.732911957939, -0.206326119636, -0.569816760116, + 0.249393336416, -1.02762332953, 0.25096708081, 0.386927162941, -0.346382299592, 0.243099162109}, + {-80.7295722208, -1.72670707303, 0.138139045677, 0.0648055728598, 0.186182854422, + 1.07226527747, -1.26133459043, 0.213883744163, 1.47115466163, -1.54791582859, 0.170924664865}, + {-317.060323531, -0.349785690206, -0.740759426066, -0.407970845617, -0.689282767277, + -1.25608665316, -0.772546119412, -2.02925712813, 0.132949072522, -0.191465137244, -1.29079690284}, + {-252.491508279, -1.24643122869, 1.55335609203, 0.356613424877, 0.817434495353, + -1.74503747683, -0.818046363088, -1.58284235058, 0.357919389759, -1.18942962791, -1.91728745247}, + {-66.8121363157, -0.584246455697, -0.104254351782, 1.17911687508, -0.29288167882, + 0.891836132692, 0.232853863255, 0.423294355343, -0.669493690103, -1.15783890498, 0.188213983735}, + {140.681464689, 1.33156046873, -1.8847915949, -0.666528837988, -0.513356191443, + 0.281290031669, -1.07815005006, 1.22384196227, 1.39093631269, 0.527644817197, 1.21595221509}, + {-174.22326767, 0.475428766034, 0.856847216768, -0.734282773151, -0.923514989791, + 0.917510828772, 0.674878068543, 0.0644776431114, -0.607796192908, 0.867740011912, -1.97799769281}, + {74.3899799579, 0.00915743526294, 0.553578683413, 1.66930486354, 
0.15562803404, + 1.8455840688, -0.371704942927, 1.11228894843, -0.37464389118, -0.48789151589, 0.79553866342}, + {70.1167175897, 0.154877045187, 1.47803572976, -0.0355743163524, -2.47914644675, + 0.672384381837, 1.63160379529, 1.81874583854, 1.22797339421, -0.0131258061634, -0.390265963676}, + {-11.0364788877, 0.173049156249, -1.78140521797, -1.29982707214, -0.48025663179, + -0.469112922302, -1.98718063269, 0.585086542043, 0.264611327837, 1.48855512579, 2.00672263496}, + {-112.711292736, -1.59239636827, -0.600613018822, -0.0209667499746, -1.81872893331, + -0.739893084955, 0.140261888569, -0.498107678308, 2.53664045504, -0.536385019089, -0.608755809378}, + {-198.064468217, 0.737175509877, -2.01835515547, -2.18045950065, 0.428584922529, + -1.01848835019, -0.470645361539, -0.00703630153547, -2.2341302754, 1.51483167022, -0.410184418418}, + {70.2747963991, 1.49474111532, -0.19517712503, 0.7392852909, -0.326060871666, + -0.566710349675, 0.14053094122, -0.562830341306, 0.22931613446, -0.0344439061448, 0.175150510551}, + {207.909021337, 0.839887009159, 0.268826583246, -0.313047158862, 1.12009996015, + 0.214209976971, -0.396147338251, 2.16039704403, 0.699141312749, 0.756192350992, -0.145368196901}, + {169.428609429, -1.13702350819, 1.23964530597, -0.864443556622, -0.885630795949, + -0.523872327352, 0.467159824748, 0.476596383923, 0.4343735578, 1.4075417896, 2.22939328991}, + {-176.909833405, 0.0875512760866, -0.455542269288, 0.539742307764, -0.762003092788, + 0.41829123457, -0.818116139644, -2.01761645956, 0.557395073218, 1.5823271814, -1.0168826293}, + {-27.734298611, -0.841257541979, 0.348961259301, 1.36935991472, -0.0694528057586, + -1.27303784913, 0.152155656569, 1.9279466651, 0.9589415766, -1.76634370106, -1.08831026428}, + {-55.8416853588, 0.927711536927, 0.157856746063, -0.295628714893, 0.0296602829783, + 1.75198587897, -0.38285446366, -0.253287154535, -1.64032395229, -0.842089054965, 1.00493779183}, + {56.0899797005, 0.326117761734, -1.93514762146, 1.0229172721, 0.125568968732, + 2.37760000658, -0.498532972011, -0.733375842271, -0.757445726993, -0.49515057432, 2.01559891524}, + {-176.220234909, 1.571129843, -0.867707605929, -0.709690799512, -1.51535538937, + 1.27424225477, -0.109513704468, -1.46822183, 0.281077088939, -1.97084024232, -0.322309524179}, + {37.7155152941, 0.363383774219, -0.0240881298641, -1.60692745228, -1.26961656439, + -0.41299134216, 1.2890099968, -1.34101694629, -0.455387485256, -0.14055003482, 1.5407059956}, + {-102.163416997, -2.05927378316, -0.470182865756, -0.875528863204, 0.0361720859253, + -1.03713912263, 0.417362606334, 0.707587625276, -0.0591627772581, -2.58905252006, 0.516573345216}, + {-206.47095321, 0.270030584651, 1.85544202116, -0.144189208964, -0.696400687327, + 0.0226388634283, -0.490952489106, -1.69209527849, 0.00973614309272, -0.484105876992, -0.991474668217}, + {201.50637416, 0.513659215697, -0.335630132208, -0.140006500483, 0.149679720127, + -1.89526167503, -0.0614973894156, 0.0813221153552, 0.630952530848, 2.40201011339, 0.997708264073}, + {-72.0667371571, 0.0841570292899, -0.216125859013, -1.77155215764, 2.15081767322, + 0.00953341785443, -1.0826077946, -0.791135571106, -0.989393577892, -0.791485083644, -0.063560999686}, + {-162.903837815, -0.273764637097, 0.282387854873, -1.39881596931, 0.554941097854, + -0.88790718926, -0.693189960902, 0.398762630571, -1.61878562893, -0.345976341096, 0.138298909959}, + {-34.3291926715, -0.499883755911, -0.847296893019, -0.323673126437, 0.531205373462, + -0.0204345595983, 0.284954510306, 0.565031773028, 
-0.272049818708, -0.130369799738, -0.617572026201}, + {76.1272883187, -0.908810282403, -1.04139421904, 0.890678872055, 1.32990256154, + -0.0150445428835, 0.593918101047, 0.356897732999, 0.824651162423, -1.54544256217, -0.795703905296}, + {171.833705285, -0.0425219657568, -0.884042952325, 1.91202504537, 0.381908223898, + -0.205693527739, 1.53656598237, 0.534880398015, 0.291950716831, -1.1258051056, -0.0612803476297}, + {-235.445792009, 0.261252102941, -0.170931758001, 1.67878144235, 0.0278283741792, + -1.23194408479, -0.190931886594, 1.0000157972, -2.18792142659, -0.230654984288, -1.36626493512}, + {348.968834231, 1.35713154434, 0.950377770072, 0.0700577471848, 0.96907140156, + 2.00890422081, 0.0896405239806, 0.614309607351, 1.07723409067, 2.58506968136, 0.202889806148}, + {-61.0128039201, 0.465438505031, -1.31448530533, 0.374781933416, -0.0118298606041, + -0.477338357738, -0.587656108109, 1.66449545077, 0.435836048385, -0.287027953004, -1.06613472784}, + {-50.687090469, 0.382331825989, -0.597140322197, 1.1276065465, -1.35593777887, + 1.14949964423, -0.858742432885, -0.563211485633, -0.57167161928, 0.0294891749132, 1.9571639493}, + {-186.653649045, -0.00981380006029, 1.0371088941, -1.25319048981, -0.694043021068, + 1.7280802541, -0.191210409232, -0.866039238001, -0.0791927416078, -0.232228656558, -0.93723545053}, + {34.5395591744, 0.680943971029, -0.075875481801, -0.144408300848, -0.869070791528, + 0.496870904214, 1.0940401388, -0.510489750436, -0.47562728601, 0.951406841944, 0.12983846382}, + {-23.7618645627, 0.527032820313, -0.58295129357, -0.3894567306, -0.0547905472556, + -1.86103603537, 0.0506988360667, 1.02778539291, -0.0613720063422, 0.411280841442, -0.665810811374}, + {116.007776415, 0.441750249008, 0.549342185228, 0.731558201455, -0.903624700864, + -2.13208328824, 0.381223328983, 0.283479210749, 1.17705098922, -2.38800904207, 1.32108350152}, + {-148.479593311, -0.814604260049, -0.821204361946, -1.08768677334, -0.0659445766599, + 0.583741297405, 0.669345853296, -0.0935352010726, -0.254906787938, -0.394599725657, -1.26305927257}, + {244.865845084, 0.776784257443, 0.267205388558, 2.37746488031, -0.379275360853, + -0.157454754411, -0.359580726073, 0.886887721861, 1.53707627973, 0.634390546684, 0.984864824122}, + {-81.9954096721, 0.594841146008, -1.22273253129, 0.532466794358, 1.69864239257, + -0.12293671327, -2.06645974171, 0.611808231703, -1.32291985291, 0.722066660478, -0.0021343848511}, + {-245.715046329, -1.77850303496, -0.176518810079, 1.20463434525, -0.597826204963, + -1.45842350123, -0.765730251727, -2.17764204443, 0.12996635702, -0.705509516482, 0.170639846082}, + {123.011946043, -0.909707162714, 0.92357208515, 0.373251929121, 1.24629576577, + 0.0662688299998, -0.372240547929, -0.739353735168, 0.323495756066, 0.954154005738, 0.69606859977}, + {-70.4564963177, 0.650682297051, 0.378131376232, 1.37860253614, -0.924042783872, + 0.802851073842, -0.450299927542, 0.235646185302, -0.148779896161, 1.01308126122, -0.48206889502}, + {21.5288687935, 0.290876355386, 0.0765702960599, 0.905225489744, 0.252841861521, + 1.26729272819, 0.315397441908, -2.00317261368, -0.250990653758, 0.425615332405, 0.0875320802483}, + {231.370169905, 0.535138021352, -1.07151617232, 0.824383756287, 1.84428896701, + -0.890892034494, 0.0480296332924, -0.59251208055, 0.267564961845, -0.230698441998, 0.857077278291}, + {38.8318274023, 2.63547217711, -0.585553060394, 0.430550920323, -0.532619160993, + 1.25335488136, -1.65265278435, 0.0433880112291, -0.166143379872, 0.534066441314, 1.18929937797}, + {116.362219013, 
-0.275949982433, 0.468069787645, -0.879814121059, 0.862799331322, + 1.18464846725, 0.747084253268, 1.39202500691, -1.23374181275, 0.0949815110503, 0.696546907194}, + {260.540154731, 1.13798788241, -0.0991903174656, 0.1241636043, -0.201415073037, + 1.57683389508, 1.81535629587, 1.07873616646, -0.355800782882, 2.18333193195, 0.0711071144615}, + {-165.835194521, -2.76613178307, 0.805314338858, 0.81526046683, -0.710489036197, + -1.20189542317, -0.692110074722, -0.117239516622, 1.0431459458, -0.111898596299, -0.0775811519297}, + {-341.189958588, 0.668555635008, -1.0940034941, -0.497881262778, -0.603682823779, + -0.396875163796, -0.849144848521, 0.403936807183, -1.82076277475, -0.137500972546, -1.22769896568} + }, new double[] {45.8685095528, 11.9400336005, 16.3984976652, 79.9069814034, 5.65486853464, + 83.6427296424, 27.4571268153, 73.5881193584, 27.1465364511, 79.4095449062}, -5.14077007134); + + /** */ + public static class TestDataset { + + /** */ + private final double[][] data; + + /** */ + private final double[] expWeights; + + /** */ + private final double expIntercept; + + /** */ + TestDataset(double[][] data, double[] expWeights, double expIntercept) { + this.data = data; + this.expWeights = expWeights; + this.expIntercept = expIntercept; + } + + /** */ + public double[][] getData() { + return data; + } + + /** */ + public double[] getExpWeights() { + return expWeights; + } + + /** */ + public double getExpIntercept() { + return expIntercept; + } + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionQRTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionQRTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionQRTrainerTest.java new file mode 100644 index 0000000..0c09d75 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionQRTrainerTest.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions.linear; + +import org.apache.ignite.ml.math.impls.matrix.SparseBlockDistributedMatrix; +import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector; + +/** + * Tests for {@link LinearRegressionQRTrainer} on {@link SparseBlockDistributedMatrix}. 
+ */ +public class BlockDistributedLinearRegressionQRTrainerTest extends GridAwareAbstractLinearRegressionTrainerTest { + /** */ + public BlockDistributedLinearRegressionQRTrainerTest() { + super( + new LinearRegressionQRTrainer(), + SparseBlockDistributedMatrix::new, + SparseBlockDistributedVector::new, + 1e-6 + ); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionSGDTrainerTest.java new file mode 100644 index 0000000..58037e2 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionSGDTrainerTest.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions.linear; + +import org.apache.ignite.ml.math.impls.matrix.SparseBlockDistributedMatrix; +import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector; + +/** + * Tests for {@link LinearRegressionSGDTrainer} on {@link SparseBlockDistributedMatrix}. + */ +public class BlockDistributedLinearRegressionSGDTrainerTest extends GridAwareAbstractLinearRegressionTrainerTest { + /** */ + public BlockDistributedLinearRegressionSGDTrainerTest() { + super( + new LinearRegressionSGDTrainer(100_000, 1e-12), + SparseBlockDistributedMatrix::new, + SparseBlockDistributedVector::new, + 1e-2); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionQRTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionQRTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionQRTrainerTest.java new file mode 100644 index 0000000..2a506d9 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionQRTrainerTest.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions.linear; + +import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; +import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector; + +/** + * Tests for {@link LinearRegressionQRTrainer} on {@link SparseDistributedMatrix}. + */ +public class DistributedLinearRegressionQRTrainerTest extends GridAwareAbstractLinearRegressionTrainerTest { + /** */ + public DistributedLinearRegressionQRTrainerTest() { + super( + new LinearRegressionQRTrainer(), + SparseDistributedMatrix::new, + SparseDistributedVector::new, + 1e-6 + ); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionSGDTrainerTest.java new file mode 100644 index 0000000..71d3b3b --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionSGDTrainerTest.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions.linear; + +import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; +import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector; + +/** + * Tests for {@link LinearRegressionSGDTrainer} on {@link SparseDistributedMatrix}. 
+ */ +public class DistributedLinearRegressionSGDTrainerTest extends GridAwareAbstractLinearRegressionTrainerTest { + /** */ + public DistributedLinearRegressionSGDTrainerTest() { + super( + new LinearRegressionSGDTrainer(100_000, 1e-12), + SparseDistributedMatrix::new, + SparseDistributedVector::new, + 1e-2); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GenericLinearRegressionTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GenericLinearRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GenericLinearRegressionTrainerTest.java new file mode 100644 index 0000000..a55623c --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GenericLinearRegressionTrainerTest.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions.linear; + +import java.util.Scanner; +import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.Trainer; +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.junit.Test; + +/** + * Base class for all linear regression trainers. + */ +public class GenericLinearRegressionTrainerTest { + /** */ + private final Trainer<LinearRegressionModel, Matrix> trainer; + + /** */ + private final IgniteFunction<double[][], Matrix> matrixCreator; + + /** */ + private final IgniteFunction<double[], Vector> vectorCreator; + + /** */ + private final double precision; + + /** */ + public GenericLinearRegressionTrainerTest( + Trainer<LinearRegressionModel, Matrix> trainer, + IgniteFunction<double[][], Matrix> matrixCreator, + IgniteFunction<double[], Vector> vectorCreator, + double precision) { + this.trainer = trainer; + this.matrixCreator = matrixCreator; + this.vectorCreator = vectorCreator; + this.precision = precision; + } + + /** + * Test trainer on regression model y = 2 * x. + */ + @Test + public void testTrainWithoutIntercept() { + Matrix data = matrixCreator.apply(new double[][] { + {2.0, 1.0}, + {4.0, 2.0} + }); + + LinearRegressionModel mdl = trainer.train(data); + + TestUtils.assertEquals(4, mdl.apply(vectorCreator.apply(new double[] {2})), precision); + TestUtils.assertEquals(6, mdl.apply(vectorCreator.apply(new double[] {3})), precision); + TestUtils.assertEquals(8, mdl.apply(vectorCreator.apply(new double[] {4})), precision); + } + + /** + * Test trainer on regression model y = -1 * x + 1. 
+ */ + @Test + public void testTrainWithIntercept() { + Matrix data = matrixCreator.apply(new double[][] { + {1.0, 0.0}, + {0.0, 1.0} + }); + + LinearRegressionModel mdl = trainer.train(data); + + TestUtils.assertEquals(0.5, mdl.apply(vectorCreator.apply(new double[] {0.5})), precision); + TestUtils.assertEquals(2, mdl.apply(vectorCreator.apply(new double[] {-1})), precision); + TestUtils.assertEquals(-1, mdl.apply(vectorCreator.apply(new double[] {2})), precision); + } + + /** + * Test trainer on diabetes dataset. + */ + @Test + public void testTrainOnDiabetesDataset() { + Matrix data = loadDataset("datasets/regression/diabetes.csv", 442, 10); + + LinearRegressionModel mdl = trainer.train(data); + + Vector expWeights = vectorCreator.apply(new double[] { + -10.01219782, -239.81908937, 519.83978679, 324.39042769, -792.18416163, + 476.74583782, 101.04457032, 177.06417623, 751.27932109, 67.62538639 + }); + + double expIntercept = 152.13348416; + + TestUtils.assertEquals("Wrong weights", expWeights, mdl.getWeights(), precision); + TestUtils.assertEquals("Wrong intercept", expIntercept, mdl.getIntercept(), precision); + } + + /** + * Test trainer on boston dataset. + */ + @Test + public void testTrainOnBostonDataset() { + Matrix data = loadDataset("datasets/regression/boston.csv", 506, 13); + + LinearRegressionModel mdl = trainer.train(data); + + Vector expWeights = vectorCreator.apply(new double[] { + -1.07170557e-01, 4.63952195e-02, 2.08602395e-02, 2.68856140e+00, -1.77957587e+01, 3.80475246e+00, + 7.51061703e-04, -1.47575880e+00, 3.05655038e-01, -1.23293463e-02, -9.53463555e-01, 9.39251272e-03, + -5.25466633e-01 + }); + + double expIntercept = 36.4911032804; + + TestUtils.assertEquals("Wrong weights", expWeights, mdl.getWeights(), precision); + TestUtils.assertEquals("Wrong intercept", expIntercept, mdl.getIntercept(), precision); + } + + /** + * Tests trainer on artificial dataset with 10 observations described by 1 feature. + */ + @Test + public void testTrainOnArtificialDataset10x1() { + ArtificialRegressionDatasets.TestDataset dataset = ArtificialRegressionDatasets.regression10x1; + + LinearRegressionModel mdl = trainer.train(matrixCreator.apply(dataset.getData())); + + TestUtils.assertEquals("Wrong weights", dataset.getExpWeights(), mdl.getWeights(), precision); + TestUtils.assertEquals("Wrong intercept", dataset.getExpIntercept(), mdl.getIntercept(), precision); + } + + /** + * Tests trainer on artificial dataset with 10 observations described by 5 features. + */ + @Test + public void testTrainOnArtificialDataset10x5() { + ArtificialRegressionDatasets.TestDataset dataset = ArtificialRegressionDatasets.regression10x5; + + LinearRegressionModel mdl = trainer.train(matrixCreator.apply(dataset.getData())); + + TestUtils.assertEquals("Wrong weights", dataset.getExpWeights(), mdl.getWeights(), precision); + TestUtils.assertEquals("Wrong intercept", dataset.getExpIntercept(), mdl.getIntercept(), precision); + } + + /** + * Tests trainer on artificial dataset with 100 observations described by 5 features. 
+ */ + @Test + public void testTrainOnArtificialDataset100x5() { + ArtificialRegressionDatasets.TestDataset dataset = ArtificialRegressionDatasets.regression100x5; + + LinearRegressionModel mdl = trainer.train(matrixCreator.apply(dataset.getData())); + + TestUtils.assertEquals("Wrong weights", dataset.getExpWeights(), mdl.getWeights(), precision); + TestUtils.assertEquals("Wrong intercept", dataset.getExpIntercept(), mdl.getIntercept(), precision); + } + + /** + * Tests trainer on artificial dataset with 100 observations described by 10 features. + */ + @Test + public void testTrainOnArtificialDataset100x10() { + ArtificialRegressionDatasets.TestDataset dataset = ArtificialRegressionDatasets.regression100x10; + + LinearRegressionModel mdl = trainer.train(matrixCreator.apply(dataset.getData())); + + TestUtils.assertEquals("Wrong weights", dataset.getExpWeights(), mdl.getWeights(), precision); + TestUtils.assertEquals("Wrong intercept", dataset.getExpIntercept(), mdl.getIntercept(), precision); + } + + /** + * Loads dataset file and returns corresponding matrix. + * + * @param fileName Dataset file name + * @param nobs Number of observations + * @param nvars Number of features + * @return Data matrix + */ + private Matrix loadDataset(String fileName, int nobs, int nvars) { + double[][] matrix = new double[nobs][nvars + 1]; + Scanner scanner = new Scanner(this.getClass().getClassLoader().getResourceAsStream(fileName)); + int i = 0; + while (scanner.hasNextLine()) { + String row = scanner.nextLine(); + int j = 0; + for (String feature : row.split(",")) { + matrix[i][j] = Double.parseDouble(feature); + j++; + } + i++; + } + return matrixCreator.apply(matrix); + } +}
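For context only (not part of this patch), below is a minimal local sketch of the trainer/model API that the new tests exercise. It assumes the local on-heap implementations used elsewhere in the module (DenseLocalOnHeapMatrix / DenseLocalOnHeapVector) and the same row layout as the tests, i.e. the target value in column 0 and the features in the remaining columns; the class name LinearRegressionQRTrainerSketch is illustrative and does not exist in the repository.

import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.regressions.linear.LinearRegressionQRTrainer;

public class LinearRegressionQRTrainerSketch {
    public static void main(String[] args) {
        // Row layout mirrors the tests above: column 0 is the target y, the rest are features.
        // These two observations encode y = 2 * x exactly.
        Matrix data = new DenseLocalOnHeapMatrix(new double[][] {
            {2.0, 1.0},
            {4.0, 2.0}
        });

        // Train an OLS model via the QR-decomposition based trainer.
        LinearRegressionModel mdl = new LinearRegressionQRTrainer().train(data);

        // For a perfect fit the weight should be ~2 and the intercept ~0, so f(3) should be ~6.
        Vector x = new DenseLocalOnHeapVector(new double[] {3.0});
        System.out.println("weights = " + mdl.getWeights()
            + ", intercept = " + mdl.getIntercept()
            + ", f(3) = " + mdl.apply(x));
    }
}

The distributed test classes added above reuse exactly this flow through GridAwareAbstractLinearRegressionTrainerTest, swapping in the SparseDistributedMatrix/SparseBlockDistributedMatrix and corresponding vector factories. Note that the QR-based tests assert with a tight 1e-6 precision, while the SGD-based tests construct the trainer as new LinearRegressionSGDTrainer(100_000, 1e-12) and accept a looser 1e-2 precision, which is consistent with stochastic gradient descent only approximating the exact least-squares solution.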