Repository: mahout Updated Branches: refs/heads/branch-0.14.0 49ad8cb45 -> 545648f6a
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/decomposer/lanczos/TestLanczosSolver.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/decomposer/lanczos/TestLanczosSolver.java b/core/src/test/java/org/apache/mahout/math/decomposer/lanczos/TestLanczosSolver.java new file mode 100644 index 0000000..5a0c660 --- /dev/null +++ b/core/src/test/java/org/apache/mahout/math/decomposer/lanczos/TestLanczosSolver.java @@ -0,0 +1,97 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.decomposer.lanczos; + +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.decomposer.SolverTest; +import org.apache.mahout.math.solver.EigenDecomposition; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public final class TestLanczosSolver extends SolverTest { + private static final Logger log = LoggerFactory.getLogger(TestLanczosSolver.class); + + private static final double ERROR_TOLERANCE = 0.05; + + @Test + public void testEigenvalueCheck() throws Exception { + int size = 100; + Matrix m = randomHierarchicalSymmetricMatrix(size); + + Vector initialVector = new DenseVector(size); + initialVector.assign(1.0 / Math.sqrt(size)); + LanczosSolver solver = new LanczosSolver(); + int desiredRank = 80; + LanczosState state = new LanczosState(m, desiredRank, initialVector); + // set initial vector? + solver.solve(state, desiredRank, true); + + EigenDecomposition decomposition = new EigenDecomposition(m); + Vector eigenvalues = decomposition.getRealEigenvalues(); + + float fractionOfEigensExpectedGood = 0.6f; + for (int i = 0; i < fractionOfEigensExpectedGood * desiredRank; i++) { + double s = state.getSingularValue(i); + double e = eigenvalues.get(i); + log.info("{} : L = {}, E = {}", i, s, e); + assertTrue("Singular value differs from eigenvalue", Math.abs((s-e)/e) < ERROR_TOLERANCE); + Vector v = state.getRightSingularVector(i); + Vector v2 = decomposition.getV().viewColumn(i); + double error = 1 - Math.abs(v.dot(v2)/(v.norm(2) * v2.norm(2))); + log.info("error: {}", error); + assertTrue(i + ": 1 - cosAngle = " + error, error < ERROR_TOLERANCE); + } + } + + + @Test + public void testLanczosSolver() throws Exception { + int numRows = 800; + int numColumns = 500; + Matrix corpus = randomHierarchicalMatrix(numRows, numColumns, false); + Vector initialVector = new DenseVector(numColumns); + initialVector.assign(1.0 / Math.sqrt(numColumns)); + int rank = 50; + LanczosState state = new LanczosState(corpus, rank, initialVector); + LanczosSolver solver = new LanczosSolver(); + solver.solve(state, rank, false); + assertOrthonormal(state); + for (int i = 0; i < rank/2; i++) { + assertEigen(i, state.getRightSingularVector(i), corpus, ERROR_TOLERANCE, false); + } + //assertEigen(eigens, corpus, rank / 2, ERROR_TOLERANCE, false); + } + + @Test + public void testLanczosSolverSymmetric() throws Exception { + int numCols = 500; + Matrix corpus = randomHierarchicalSymmetricMatrix(numCols); + Vector initialVector = new DenseVector(numCols); + initialVector.assign(1.0 / Math.sqrt(numCols)); + int rank = 30; + LanczosState state = new LanczosState(corpus, rank, initialVector); + LanczosSolver solver = new LanczosSolver(); + solver.solve(state, rank, true); + //assertOrthonormal(state); + //assertEigen(state, rank / 2, ERROR_TOLERANCE, true); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/jet/stat/GammaTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/jet/stat/GammaTest.java b/core/src/test/java/org/apache/mahout/math/jet/stat/GammaTest.java new file mode 100644 index 0000000..b5280b4 --- /dev/null +++ b/core/src/test/java/org/apache/mahout/math/jet/stat/GammaTest.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.jet.stat; + +import com.google.common.base.Charsets; +import com.google.common.base.Splitter; +import com.google.common.collect.Iterables; +import com.google.common.io.CharStreams; +import com.google.common.io.InputSupplier; +import com.google.common.io.Resources; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.MahoutTestCase; +import org.junit.Test; + +import java.io.IOException; +import java.io.InputStreamReader; +import java.util.Random; + +public final class GammaTest extends MahoutTestCase { + + @Test + public void testGamma() { + double[] x = {1, 2, 5, 10, 20, 50, 100}; + double[] expected = { + 1.000000e+00, 1.000000e+00, 2.400000e+01, 3.628800e+05, 1.216451e+17, 6.082819e+62, 9.332622e+155 + }; + + for (int i = 0; i < x.length; i++) { + assertEquals(expected[i], Gamma.gamma(x[i]), expected[i] * 1.0e-5); + assertEquals(gammaInteger(x[i]), Gamma.gamma(x[i]), expected[i] * 1.0e-5); + assertEquals(gammaInteger(x[i]), Math.exp(Gamma.logGamma(x[i])), expected[i] * 1.0e-5); + } + } + + @Test + public void testNegativeArgForGamma() { + double[] x = {-30.3, -20.7, -10.5, -1.1, 0.5, 0.99, -0.999}; + double[] expected = { + -5.243216e-33, -1.904051e-19, -2.640122e-07, 9.714806e+00, 1.772454e+00, 1.005872e+00, -1.000424e+03 + }; + + for (int i = 0; i < x.length; i++) { + assertEquals(expected[i], Gamma.gamma(x[i]), Math.abs(expected[i] * 1.0e-5)); + assertEquals(Math.abs(expected[i]), Math.abs(Math.exp(Gamma.logGamma(x[i]))), Math.abs(expected[i] * 1.0e-5)); + } + } + + private static double gammaInteger(double x) { + double r = 1; + for (int i = 2; i < x; i++) { + r *= i; + } + return r; + } + + @Test + public void testBigX() { + assertEquals(factorial(4), 4 * 3 * 2, 0); + assertEquals(factorial(4), Gamma.gamma(5), 0); + assertEquals(factorial(14), Gamma.gamma(15), 0); + assertEquals(factorial(34), Gamma.gamma(35), 1.0e-15 * factorial(34)); + assertEquals(factorial(44), Gamma.gamma(45), 1.0e-15 * factorial(44)); + + assertEquals(-6.884137e-40 + 3.508309e-47, Gamma.gamma(-35.1), 1.0e-52); + assertEquals(-3.915646e-41 - 3.526813e-48 - 1.172516e-55, Gamma.gamma(-35.9), 1.0e-52); + assertEquals(-2000000000.577215, Gamma.gamma(-0.5e-9), 1.0e-15 * 2000000000.577215); + assertEquals(1999999999.422784, Gamma.gamma(0.5e-9), 1.0e-15 * 1999999999.422784); + assertEquals(1.324296658017984e+252, Gamma.gamma(146.1), 1.0e-10 * 1.324296658017984e+252); + + for (double x : new double[]{5, 15, 35, 45, -35.1, -35.9, -0.5e-9, 0.5e-9, 146.1}) { + double ref = Math.log(Math.abs(Gamma.gamma(x))); + double actual = Gamma.logGamma(x); + double diff = Math.abs(ref - actual) / ref; + assertEquals("gamma versus logGamma at " + x + " (diff = " + diff + ')', 0, (ref - actual) / ref, 1.0e-8); + } + } + + private static double factorial(int n) { + double r = 1; + for (int i = 2; i <= n; i++) { + r *= i; + } + return r; + } + + @Test + public void beta() { + Random r = RandomUtils.getRandom(); + for (int i = 0; i < 200; i++) { + double alpha = -50 * Math.log1p(-r.nextDouble()); + double beta = -50 * Math.log1p(-r.nextDouble()); + double ref = Math.exp(Gamma.logGamma(alpha) + Gamma.logGamma(beta) - Gamma.logGamma(alpha + beta)); + double actual = Gamma.beta(alpha, beta); + double err = (ref - actual) / ref; + assertEquals("beta at (" + alpha + ", " + beta + ") relative error = " + err, 0, err, 1.0e-10); + } + } + + @Test + public void incompleteBeta() throws IOException { + Splitter onComma = Splitter.on(",").trimResults(); + + InputSupplier<InputStreamReader> input = + Resources.newReaderSupplier(Resources.getResource("beta-test-data.csv"), Charsets.UTF_8); + boolean header = true; + for (String line : CharStreams.readLines(input)) { + if (header) { + // skip + header = false; + } else { + Iterable<String> values = onComma.split(line); + double alpha = Double.parseDouble(Iterables.get(values, 0)); + double beta = Double.parseDouble(Iterables.get(values, 1)); + double x = Double.parseDouble(Iterables.get(values, 2)); + double ref = Double.parseDouble(Iterables.get(values, 3)); + double actual = Gamma.incompleteBeta(alpha, beta, x); + assertEquals(alpha + "," + beta + ',' + x, ref, actual, ref * 1.0e-5); + } + } + } + + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/jet/stat/ProbabilityTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/jet/stat/ProbabilityTest.java b/core/src/test/java/org/apache/mahout/math/jet/stat/ProbabilityTest.java new file mode 100644 index 0000000..47bbb0e --- /dev/null +++ b/core/src/test/java/org/apache/mahout/math/jet/stat/ProbabilityTest.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.jet.stat; + +import org.apache.mahout.math.MahoutTestCase; +import org.junit.Test; + +import java.util.Locale; + +public final class ProbabilityTest extends MahoutTestCase { + + @Test + public void testNormalCdf() { + // computed by R + // pnorm(seq(-5,5, length.out=100)) + double[] ref = { + 2.866516e-07, 4.816530e-07, 8.013697e-07, 1.320248e-06, 2.153811e-06, + 3.479323e-06, 5.565743e-06, 8.816559e-06, 1.383023e-05, 2.148428e-05, + 3.305072e-05, 5.035210e-05, 7.596947e-05, 1.135152e-04, 1.679855e-04, + 2.462079e-04, 3.574003e-04, 5.138562e-04, 7.317683e-04, 1.032198e-03, + 1.442193e-03, 1.996034e-03, 2.736602e-03, 3.716808e-03, 5.001037e-03, + 6.666521e-03, 8.804535e-03, 1.152131e-02, 1.493850e-02, 1.919309e-02, + 2.443656e-02, 3.083320e-02, 3.855748e-02, 4.779035e-02, 5.871452e-02, + 7.150870e-02, 8.634102e-02, 1.033618e-01, 1.226957e-01, 1.444345e-01, + 1.686293e-01, 1.952845e-01, 2.243525e-01, 2.557301e-01, 2.892574e-01, + 3.247181e-01, 3.618436e-01, 4.003175e-01, 4.397847e-01, 4.798600e-01, + 5.201400e-01, 5.602153e-01, 5.996825e-01, 6.381564e-01, 6.752819e-01, + 7.107426e-01, 7.442699e-01, 7.756475e-01, 8.047155e-01, 8.313707e-01, + 8.555655e-01, 8.773043e-01, 8.966382e-01, 9.136590e-01, 9.284913e-01, + 9.412855e-01, 9.522096e-01, 9.614425e-01, 9.691668e-01, 9.755634e-01, + 9.808069e-01, 9.850615e-01, 9.884787e-01, 9.911955e-01, 9.933335e-01, + 9.949990e-01, 9.962832e-01, 9.972634e-01, 9.980040e-01, 9.985578e-01, + 9.989678e-01, 9.992682e-01, 9.994861e-01, 9.996426e-01, 9.997538e-01, + 9.998320e-01, 9.998865e-01, 9.999240e-01, 9.999496e-01, 9.999669e-01, + 9.999785e-01, 9.999862e-01, 9.999912e-01, 9.999944e-01, 9.999965e-01, + 9.999978e-01, 9.999987e-01, 9.999992e-01, 9.999995e-01, 9.999997e-01 + }; + assertEquals(0.682689492137 / 2 + 0.5, Probability.normal(1), 1.0e-7); + + int i = 0; + for (double x = -5; x <= 5.005; x += 10.0 / 99) { + assertEquals("Test 1 cdf function at " + x, ref[i], Probability.normal(x), 1.0e-6); + assertEquals("Test 2 cdf function at " + x, ref[i], Probability.normal(12, 1, x + 12), 1.0e-6); + assertEquals("Test 3 cdf function at " + x, ref[i], Probability.normal(12, 0.25, x / 2.0 + 12), 1.0e-6); + i++; + } + } + + @Test + public void testBetaCdf() { + // values computed using: + //> pbeta(seq(0, 1, length.out=100), 1, 1) + //> pbeta(seq(0, 1, length.out=100), 2, 1) + //> pbeta(seq(0, 1, length.out=100), 2, 5) + //> pbeta(seq(0, 1, length.out=100), 0.2, 5) + //> pbeta(seq(0, 1, length.out=100), 0.2, 0.01) + + double[][] ref = new double[5][]; + + ref[0] = new double[]{ + 0.00000000, 0.01010101, 0.02020202, 0.03030303, 0.04040404, 0.05050505, + 0.06060606, 0.07070707, 0.08080808, 0.09090909, 0.10101010, 0.11111111, + 0.12121212, 0.13131313, 0.14141414, 0.15151515, 0.16161616, 0.17171717, + 0.18181818, 0.19191919, 0.20202020, 0.21212121, 0.22222222, 0.23232323, + 0.24242424, 0.25252525, 0.26262626, 0.27272727, 0.28282828, 0.29292929, + 0.30303030, 0.31313131, 0.32323232, 0.33333333, 0.34343434, 0.35353535, + 0.36363636, 0.37373737, 0.38383838, 0.39393939, 0.40404040, 0.41414141, + 0.42424242, 0.43434343, 0.44444444, 0.45454545, 0.46464646, 0.47474747, + 0.48484848, 0.49494949, 0.50505051, 0.51515152, 0.52525253, 0.53535354, + 0.54545455, 0.55555556, 0.56565657, 0.57575758, 0.58585859, 0.59595960, + 0.60606061, 0.61616162, 0.62626263, 0.63636364, 0.64646465, 0.65656566, + 0.66666667, 0.67676768, 0.68686869, 0.69696970, 0.70707071, 0.71717172, + 0.72727273, 0.73737374, 0.74747475, 0.75757576, 0.76767677, 0.77777778, + 0.78787879, 0.79797980, 0.80808081, 0.81818182, 0.82828283, 0.83838384, + 0.84848485, 0.85858586, 0.86868687, 0.87878788, 0.88888889, 0.89898990, + 0.90909091, 0.91919192, 0.92929293, 0.93939394, 0.94949495, 0.95959596, + 0.96969697, 0.97979798, 0.98989899, 1.00000000 + }; + ref[1] = new double[]{ + 0.0000000000, 0.0001020304, 0.0004081216, 0.0009182736, 0.0016324865, + 0.0025507601, 0.0036730946, 0.0049994898, 0.0065299459, 0.0082644628, + 0.0102030405, 0.0123456790, 0.0146923783, 0.0172431385, 0.0199979594, + 0.0229568411, 0.0261197837, 0.0294867871, 0.0330578512, 0.0368329762, + 0.0408121620, 0.0449954086, 0.0493827160, 0.0539740843, 0.0587695133, + 0.0637690032, 0.0689725538, 0.0743801653, 0.0799918376, 0.0858075707, + 0.0918273646, 0.0980512193, 0.1044791348, 0.1111111111, 0.1179471483, + 0.1249872462, 0.1322314050, 0.1396796245, 0.1473319049, 0.1551882461, + 0.1632486481, 0.1715131109, 0.1799816345, 0.1886542190, 0.1975308642, + 0.2066115702, 0.2158963371, 0.2253851648, 0.2350780533, 0.2449750026, + 0.2550760127, 0.2653810836, 0.2758902153, 0.2866034078, 0.2975206612, + 0.3086419753, 0.3199673503, 0.3314967860, 0.3432302826, 0.3551678400, + 0.3673094582, 0.3796551372, 0.3922048771, 0.4049586777, 0.4179165391, + 0.4310784614, 0.4444444444, 0.4580144883, 0.4717885930, 0.4857667585, + 0.4999489848, 0.5143352719, 0.5289256198, 0.5437200286, 0.5587184981, + 0.5739210285, 0.5893276196, 0.6049382716, 0.6207529844, 0.6367717580, + 0.6529945924, 0.6694214876, 0.6860524436, 0.7028874605, 0.7199265381, + 0.7371696766, 0.7546168758, 0.7722681359, 0.7901234568, 0.8081828385, + 0.8264462810, 0.8449137843, 0.8635853484, 0.8824609734, 0.9015406591, + 0.9208244057, 0.9403122130, 0.9600040812, 0.9799000102, 1.0000000000 + }; + ref[2] = new double[]{ + 0.000000000, 0.001489698, 0.005799444, 0.012698382, 0.021966298, 0.033393335, + 0.046779694, 0.061935356, 0.078679798, 0.096841712, 0.116258735, 0.136777178, + 0.158251755, 0.180545326, 0.203528637, 0.227080061, 0.251085352, 0.275437393, + 0.300035957, 0.324787463, 0.349604743, 0.374406809, 0.399118623, 0.423670875, + 0.447999763, 0.472046772, 0.495758466, 0.519086275, 0.541986291, 0.564419069, + 0.586349424, 0.607746242, 0.628582288, 0.648834019, 0.668481403, 0.687507740, + 0.705899486, 0.723646086, 0.740739801, 0.757175549, 0.772950746, 0.788065147, + 0.802520695, 0.816321377, 0.829473074, 0.841983426, 0.853861691, 0.865118615, + 0.875766302, 0.885818092, 0.895288433, 0.904192771, 0.912547431, 0.920369513, + 0.927676778, 0.934487554, 0.940820632, 0.946695177, 0.952130629, 0.957146627, + 0.961762916, 0.965999275, 0.969875437, 0.973411020, 0.976625460, 0.979537944, + 0.982167353, 0.984532203, 0.986650598, 0.988540173, 0.990218056, 0.991700827, + 0.993004475, 0.994144371, 0.995135237, 0.995991117, 0.996725360, 0.997350600, + 0.997878739, 0.998320942, 0.998687627, 0.998988463, 0.999232371, 0.999427531, + 0.999581387, 0.999700663, 0.999791377, 0.999858864, 0.999907798, 0.999942219, + 0.999965567, 0.999980718, 0.999990021, 0.999995342, 0.999998111, 0.999999376, + 0.999999851, 0.999999980, 0.999999999, 1.000000000 + }; + ref[3] = new double[]{ + 0.0000000, 0.5858072, 0.6684658, 0.7201859, 0.7578936, 0.7873991, 0.8114552, + 0.8316029, 0.8487998, 0.8636849, 0.8767081, 0.8881993, 0.8984080, 0.9075280, + 0.9157131, 0.9230876, 0.9297536, 0.9357958, 0.9412856, 0.9462835, 0.9508414, + 0.9550044, 0.9588113, 0.9622963, 0.9654896, 0.9684178, 0.9711044, 0.9735707, + 0.9758356, 0.9779161, 0.9798276, 0.9815839, 0.9831977, 0.9846805, 0.9860426, + 0.9872936, 0.9884422, 0.9894965, 0.9904638, 0.9913509, 0.9921638, 0.9929085, + 0.9935900, 0.9942134, 0.9947832, 0.9953034, 0.9957779, 0.9962104, 0.9966041, + 0.9969621, 0.9972872, 0.9975821, 0.9978492, 0.9980907, 0.9983088, 0.9985055, + 0.9986824, 0.9988414, 0.9989839, 0.9991113, 0.9992251, 0.9993265, 0.9994165, + 0.9994963, 0.9995668, 0.9996288, 0.9996834, 0.9997311, 0.9997727, 0.9998089, + 0.9998401, 0.9998671, 0.9998901, 0.9999098, 0.9999265, 0.9999406, 0.9999524, + 0.9999622, 0.9999703, 0.9999769, 0.9999823, 0.9999866, 0.9999900, 0.9999927, + 0.9999947, 0.9999963, 0.9999975, 0.9999983, 0.9999989, 0.9999993, 0.9999996, + 0.9999998, 0.9999999, 0.9999999, 1.0000000, 1.0000000, 1.0000000, 1.0000000, + 1.0000000, 1.0000000 + }; + ref[4] = new double[]{ + 0.00000000, 0.01908202, 0.02195656, 0.02385194, 0.02530810, 0.02650923, + 0.02754205, 0.02845484, 0.02927741, 0.03002959, 0.03072522, 0.03137444, + 0.03198487, 0.03256240, 0.03311171, 0.03363655, 0.03414001, 0.03462464, + 0.03509259, 0.03554568, 0.03598550, 0.03641339, 0.03683054, 0.03723799, + 0.03763667, 0.03802739, 0.03841091, 0.03878787, 0.03915890, 0.03952453, + 0.03988529, 0.04024162, 0.04059396, 0.04094272, 0.04128827, 0.04163096, + 0.04197113, 0.04230909, 0.04264515, 0.04297958, 0.04331268, 0.04364471, + 0.04397592, 0.04430658, 0.04463693, 0.04496722, 0.04529770, 0.04562860, + 0.04596017, 0.04629265, 0.04662629, 0.04696134, 0.04729804, 0.04763666, + 0.04797747, 0.04832073, 0.04866673, 0.04901578, 0.04936816, 0.04972422, + 0.05008428, 0.05044871, 0.05081789, 0.05119222, 0.05157213, 0.05195809, + 0.05235059, 0.05275018, 0.05315743, 0.05357298, 0.05399753, 0.05443184, + 0.05487673, 0.05533315, 0.05580212, 0.05628480, 0.05678247, 0.05729660, + 0.05782885, 0.05838111, 0.05895557, 0.05955475, 0.06018161, 0.06083965, + 0.06153300, 0.06226670, 0.06304685, 0.06388102, 0.06477877, 0.06575235, + 0.06681788, 0.06799717, 0.06932077, 0.07083331, 0.07260394, 0.07474824, + 0.07748243, 0.08129056, 0.08771055, 1.00000000 + }; + + double[] alpha = {1.0, 2.0, 2.0, 0.2, 0.2}; + double[] beta = {1.0, 1.0, 5.0, 5.0, 0.01}; + for (int j = 0; j < 4; j++) { + for (int i = 0; i < 100; i++) { + double x = i / 99.0; + String p = String.format(Locale.ENGLISH, + "pbeta(q=%6.4f, shape1=%5.3f shape2=%5.3f) = %.8f", + x, alpha[j], beta[j], ref[j][i]); + assertEquals(p, ref[j][i], Probability.beta(alpha[j], beta[j], x), 1.0e-7); + } + } + } + + @Test + public void testLogGamma() { + double[] xValues = {1.1, 2.1, 3.1, 4.1, 5.1, 20.1, 100.1, -0.9}; + double[] ref = { + -0.04987244, 0.04543774, 0.78737508, 1.91877719, 3.32976417, 39.63719250, 359.59427179, 2.35807317 + }; + for (int i = 0; i < xValues.length; i++) { + double x = xValues[i]; + assertEquals(ref[i], Gamma.logGamma(x), 1.0e-7); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/random/ChineseRestaurantTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/random/ChineseRestaurantTest.java b/core/src/test/java/org/apache/mahout/math/random/ChineseRestaurantTest.java new file mode 100644 index 0000000..b8c4624 --- /dev/null +++ b/core/src/test/java/org/apache/mahout/math/random/ChineseRestaurantTest.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.random; + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; +import com.google.common.collect.Multiset; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.MahoutTestCase; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.QRDecomposition; +import org.junit.Test; + +import java.util.Collections; +import java.util.List; +import java.util.Set; + +public final class ChineseRestaurantTest extends MahoutTestCase { + + @Test + public void testDepth() { + List<Integer> totals = Lists.newArrayList(); + for (int i = 0; i < 1000; i++) { + ChineseRestaurant x = new ChineseRestaurant(10); + Multiset<Integer> counts = HashMultiset.create(); + for (int j = 0; j < 100; j++) { + counts.add(x.sample()); + } + List<Integer> tmp = Lists.newArrayList(); + for (Integer k : counts.elementSet()) { + tmp.add(counts.count(k)); + } + Collections.sort(tmp, Collections.reverseOrder()); + while (totals.size() < tmp.size()) { + totals.add(0); + } + int j = 0; + for (Integer k : tmp) { + totals.set(j, totals.get(j) + k); + j++; + } + } + + // these are empirically derived values, not principled ones + assertEquals(25000.0, (double) totals.get(0), 1000); + assertEquals(24000.0, (double) totals.get(1), 1000); + assertEquals(8000.0, (double) totals.get(2), 200); + assertEquals(1000.0, (double) totals.get(15), 50); + assertEquals(1000.0, (double) totals.get(20), 40); + } + + @Test + public void testExtremeDiscount() { + ChineseRestaurant x = new ChineseRestaurant(100, 1); + Multiset<Integer> counts = HashMultiset.create(); + for (int i = 0; i < 10000; i++) { + counts.add(x.sample()); + } + assertEquals(10000, x.size()); + for (int i = 0; i < 10000; i++) { + assertEquals(1, x.count(i)); + } + } + + @Test + public void testGrowth() { + ChineseRestaurant s0 = new ChineseRestaurant(10, 0.0); + ChineseRestaurant s5 = new ChineseRestaurant(10, 0.5); + ChineseRestaurant s9 = new ChineseRestaurant(10, 0.9); + Set<Double> splits = ImmutableSet.of(1.0, 1.5, 2.0, 3.0, 5.0, 8.0); + + double offset0 = 0; + int k = 0; + int i = 0; + Matrix m5 = new DenseMatrix(20, 3); + Matrix m9 = new DenseMatrix(20, 3); + while (i <= 200000) { + double n = i / Math.pow(10, Math.floor(Math.log10(i))); + if (splits.contains(n)) { + //System.out.printf("%d\t%d\t%d\t%d\n", i, s0.size(), s5.size(), s9.size()); + if (i > 900) { + double predict5 = predictSize(m5.viewPart(0, k, 0, 3), i, 0.5); + assertEquals(predict5, Math.log(s5.size()), 1); + + double predict9 = predictSize(m9.viewPart(0, k, 0, 3), i, 0.9); + assertEquals(predict9, Math.log(s9.size()), 1); + + //assertEquals(10.5 * Math.log(i) - offset0, s0.size(), 10); + } else if (i > 50) { + double x = 10.5 * Math.log(i) - s0.size(); + m5.viewRow(k).assign(new double[]{Math.log(s5.size()), Math.log(i), 1}); + m9.viewRow(k).assign(new double[]{Math.log(s9.size()), Math.log(i), 1}); + + k++; + offset0 += (x - offset0) / k; + } + if (i > 10000) { + assertEquals(0.0, (double) hapaxCount(s0) / s0.size(), 0.25); + assertEquals(0.5, (double) hapaxCount(s5) / s5.size(), 0.1); + assertEquals(0.9, (double) hapaxCount(s9) / s9.size(), 0.05); + } + } + s0.sample(); + s5.sample(); + s9.sample(); + i++; + } + } + + /** + * Predict the power law growth in number of unique samples from the first few data points. + * Also check that the fitted growth coefficient is about right. + * + * @param m + * @param currentIndex Total data points seen so far. Unique values should be log(currentIndex)*expectedCoefficient + offset. + * @param expectedCoefficient What slope do we expect. + * @return The predicted value for log(currentIndex) + */ + private static double predictSize(Matrix m, int currentIndex, double expectedCoefficient) { + int rows = m.rowSize(); + Matrix a = m.viewPart(0, rows, 1, 2); + Matrix b = m.viewPart(0, rows, 0, 1); + + Matrix ata = a.transpose().times(a); + Matrix atb = a.transpose().times(b); + QRDecomposition s = new QRDecomposition(ata); + Matrix r = s.solve(atb).transpose(); + assertEquals(expectedCoefficient, r.get(0, 0), 0.2); + return r.times(new DenseVector(new double[]{Math.log(currentIndex), 1})).get(0); + } + + private static int hapaxCount(ChineseRestaurant s) { + int r = 0; + for (int i = 0; i < s.size(); i++) { + if (s.count(i) == 1) { + r++; + } + } + return r; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/randomized/RandomBlasting.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/randomized/RandomBlasting.java b/core/src/test/java/org/apache/mahout/math/randomized/RandomBlasting.java new file mode 100644 index 0000000..120b25f --- /dev/null +++ b/core/src/test/java/org/apache/mahout/math/randomized/RandomBlasting.java @@ -0,0 +1,355 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.randomized; + +import java.lang.reflect.Field; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.apache.mahout.math.map.OpenIntIntHashMap; +import org.apache.mahout.math.map.OpenIntObjectHashMap; +import org.apache.mahout.math.map.OpenObjectIntHashMap; +import org.apache.mahout.math.set.AbstractIntSet; +import org.apache.mahout.math.set.OpenHashSet; +import org.apache.mahout.math.set.OpenIntHashSet; +import org.junit.Test; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import com.carrotsearch.randomizedtesting.annotations.Repeat; +import com.carrotsearch.randomizedtesting.annotations.Seed; + +/** + * Some randomized tests against Java Util Collections. + */ +public class RandomBlasting extends RandomizedTest { + private static enum Operation { + ADD, REMOVE, CLEAR, INDEXOF, ISEMPTY, SIZE + } + + @Test + @Repeat(iterations = 20) + public void testAgainstReferenceOpenObjectIntHashMap() { + OpenObjectIntHashMap<Integer> base = new OpenObjectIntHashMap<>(); + Map<Integer, Integer> reference = new HashMap<>(); + + List<Operation> ops = Lists.newArrayList(); + addOp(ops, Operation.ADD, 60); + addOp(ops, Operation.REMOVE, 30); + addOp(ops, Operation.INDEXOF, 30); + addOp(ops, Operation.CLEAR, 5); + addOp(ops, Operation.ISEMPTY, 2); + addOp(ops, Operation.SIZE, 2); + + int max = randomIntBetween(1000, 20000); + for (int reps = 0; reps < max; reps++) { + // Ensure some collisions among keys. + int k = randomIntBetween(0, max / 4); + int v = randomInt(); + switch (randomFrom(ops)) { + case ADD: + assertEquals(reference.put(k, v) == null, base.put(k, v)); + break; + + case REMOVE: + assertEquals(reference.remove(k) != null, base.removeKey(k)); + break; + + case INDEXOF: + assertEquals(reference.containsKey(k), base.containsKey(k)); + break; + + case CLEAR: + reference.clear(); + base.clear(); + break; + + case ISEMPTY: + assertEquals(reference.isEmpty(), base.isEmpty()); + break; + + case SIZE: + assertEquals(reference.size(), base.size()); + break; + + default: + throw new RuntimeException(); + } + } + } + + @Test + @Repeat(iterations = 20) + public void testAgainstReferenceOpenIntObjectHashMap() { + OpenIntObjectHashMap<Integer> base = new OpenIntObjectHashMap<>(); + Map<Integer, Integer> reference = new HashMap<>(); + + List<Operation> ops = Lists.newArrayList(); + addOp(ops, Operation.ADD, 60); + addOp(ops, Operation.REMOVE, 30); + addOp(ops, Operation.INDEXOF, 30); + addOp(ops, Operation.CLEAR, 5); + addOp(ops, Operation.ISEMPTY, 2); + addOp(ops, Operation.SIZE, 2); + + int max = randomIntBetween(1000, 20000); + for (int reps = 0; reps < max; reps++) { + // Ensure some collisions among keys. + int k = randomIntBetween(0, max / 4); + int v = randomInt(); + switch (randomFrom(ops)) { + case ADD: + assertEquals(reference.put(k, v) == null, base.put(k, v)); + break; + + case REMOVE: + assertEquals(reference.remove(k) != null, base.removeKey(k)); + break; + + case INDEXOF: + assertEquals(reference.containsKey(k), base.containsKey(k)); + break; + + case CLEAR: + reference.clear(); + base.clear(); + break; + + case ISEMPTY: + assertEquals(reference.isEmpty(), base.isEmpty()); + break; + + case SIZE: + assertEquals(reference.size(), base.size()); + break; + + default: + throw new RuntimeException(); + } + } + } + + @Test + @Repeat(iterations = 20) + public void testAgainstReferenceOpenIntIntHashMap() { + OpenIntIntHashMap base = new OpenIntIntHashMap(); + HashMap<Integer, Integer> reference = new HashMap<>(); + + List<Operation> ops = Lists.newArrayList(); + addOp(ops, Operation.ADD, 60); + addOp(ops, Operation.REMOVE, 30); + addOp(ops, Operation.INDEXOF, 30); + addOp(ops, Operation.CLEAR, 5); + addOp(ops, Operation.ISEMPTY, 2); + addOp(ops, Operation.SIZE, 2); + + int max = randomIntBetween(1000, 20000); + for (int reps = 0; reps < max; reps++) { + // Ensure some collisions among keys. + int k = randomIntBetween(0, max / 4); + int v = randomInt(); + switch (randomFrom(ops)) { + case ADD: + Integer prevValue = reference.put(k, v); + + if (prevValue == null) { + assertEquals(true, base.put(k, v)); + } else { + assertEquals(prevValue.intValue(), base.get(k)); + assertEquals(false, base.put(k, v)); + } + break; + + case REMOVE: + assertEquals(reference.containsKey(k), base.containsKey(k)); + + Integer removed = reference.remove(k); + if (removed == null) { + assertEquals(false, base.removeKey(k)); + } else { + assertEquals(removed.intValue(), base.get(k)); + assertEquals(true, base.removeKey(k)); + } + break; + + case INDEXOF: + assertEquals(reference.containsKey(k), base.containsKey(k)); + break; + + case CLEAR: + reference.clear(); + base.clear(); + break; + + case ISEMPTY: + assertEquals(reference.isEmpty(), base.isEmpty()); + break; + + case SIZE: + assertEquals(reference.size(), base.size()); + break; + + default: + throw new RuntimeException(); + } + } + } + + @Test + @Repeat(iterations = 20) + public void testAgainstReferenceOpenIntHashSet() { + AbstractIntSet base = new OpenIntHashSet(); + HashSet<Integer> reference = Sets.newHashSet(); + + List<Operation> ops = Lists.newArrayList(); + addOp(ops, Operation.ADD, 60); + addOp(ops, Operation.REMOVE, 30); + addOp(ops, Operation.INDEXOF, 30); + addOp(ops, Operation.CLEAR, 5); + addOp(ops, Operation.ISEMPTY, 2); + addOp(ops, Operation.SIZE, 2); + + int max = randomIntBetween(1000, 20000); + for (int reps = 0; reps < max; reps++) { + // Ensure some collisions among keys. + int k = randomIntBetween(0, max / 4); + switch (randomFrom(ops)) { + case ADD: + assertEquals(reference.add(k), base.add(k)); + break; + + case REMOVE: + assertEquals(reference.remove(k), base.remove(k)); + break; + + case INDEXOF: + assertEquals(reference.contains(k), base.contains(k)); + break; + + case CLEAR: + reference.clear(); + base.clear(); + break; + + case ISEMPTY: + assertEquals(reference.isEmpty(), base.isEmpty()); + break; + + case SIZE: + assertEquals(reference.size(), base.size()); + break; + + default: + throw new RuntimeException(); + } + } + } + + @Seed("deadbeef") + @Test + @Repeat(iterations = 20) + public void testAgainstReferenceOpenHashSet() { + Set<Integer> base = new OpenHashSet<>(); + Set<Integer> reference = Sets.newHashSet(); + + List<Operation> ops = Lists.newArrayList(); + addOp(ops, Operation.ADD, 60); + addOp(ops, Operation.REMOVE, 30); + addOp(ops, Operation.INDEXOF, 30); + addOp(ops, Operation.CLEAR, 5); + addOp(ops, Operation.ISEMPTY, 2); + addOp(ops, Operation.SIZE, 2); + + int max = randomIntBetween(1000, 20000); + for (int reps = 0; reps < max; reps++) { + // Ensure some collisions among keys. + int k = randomIntBetween(0, max / 4); + switch (randomFrom(ops)) { + case ADD: + assertEquals(reference.contains(k), base.contains(k)); + break; + + case REMOVE: + assertEquals(reference.remove(k), base.remove(k)); + break; + + case INDEXOF: + assertEquals(reference.contains(k), base.contains(k)); + break; + + case CLEAR: + reference.clear(); + base.clear(); + break; + + case ISEMPTY: + assertEquals(reference.isEmpty(), base.isEmpty()); + break; + + case SIZE: + assertEquals(reference.size(), base.size()); + break; + + default: + throw new RuntimeException(); + } + } + } + + /** + * @see "https://issues.apache.org/jira/browse/MAHOUT-1225" + */ + @Test + public void testMahout1225() { + AbstractIntSet s = new OpenIntHashSet(); + s.clear(); + s.add(23); + s.add(46); + s.clear(); + s.add(70); + s.add(93); + s.contains(100); + } + + /** */ + @Test + public void testClearTable() throws Exception { + OpenObjectIntHashMap<Integer> m = new OpenObjectIntHashMap<>(); + m.clear(); // rehash from the default capacity to the next prime after 1 (3). + m.put(1, 2); + m.clear(); // Should clear internal references. + + Field tableField = m.getClass().getDeclaredField("table"); + tableField.setAccessible(true); + Object[] table = (Object[]) tableField.get(m); + + assertEquals(Sets.newHashSet(Arrays.asList(new Object[] {null})), Sets.newHashSet(Arrays.asList(table))); + } + + /** Add multiple repetitions of op to the list. */ + private static void addOp(List<Operation> ops, Operation op, int reps) { + for (int i = 0; i < reps; i++) { + ops.add(op); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/ssvd/SequentialBigSvdTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/ssvd/SequentialBigSvdTest.java b/core/src/test/java/org/apache/mahout/math/ssvd/SequentialBigSvdTest.java new file mode 100644 index 0000000..92fd5bb --- /dev/null +++ b/core/src/test/java/org/apache/mahout/math/ssvd/SequentialBigSvdTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.ssvd; + +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.DiagonalMatrix; +import org.apache.mahout.math.MahoutTestCase; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.RandomTrinaryMatrix; +import org.apache.mahout.math.SingularValueDecomposition; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.Functions; +import org.junit.Test; + +public final class SequentialBigSvdTest extends MahoutTestCase { + + @Test + public void testSingularValues() { + Matrix A = lowRankMatrix(); + + SequentialBigSvd s = new SequentialBigSvd(A, 8); + SingularValueDecomposition svd = new SingularValueDecomposition(A); + + Vector reference = new DenseVector(svd.getSingularValues()).viewPart(0, 8); + assertEquals(reference, s.getSingularValues()); + + assertEquals(A, s.getU().times(new DiagonalMatrix(s.getSingularValues())).times(s.getV().transpose())); + } + + @Test + public void testLeftVectors() { + Matrix A = lowRankMatrix(); + + SequentialBigSvd s = new SequentialBigSvd(A, 8); + SingularValueDecomposition svd = new SingularValueDecomposition(A); + + // can only check first few singular vectors because once the singular values + // go to zero, the singular vectors are not uniquely determined + Matrix u1 = svd.getU().viewPart(0, 20, 0, 4).assign(Functions.ABS); + Matrix u2 = s.getU().viewPart(0, 20, 0, 4).assign(Functions.ABS); + assertEquals(0, u1.minus(u2).aggregate(Functions.PLUS, Functions.ABS), 1.0e-9); + } + + private static void assertEquals(Matrix u1, Matrix u2) { + assertEquals(0, u1.minus(u2).aggregate(Functions.MAX, Functions.ABS), 1.0e-10); + } + + private static void assertEquals(Vector u1, Vector u2) { + assertEquals(0, u1.minus(u2).aggregate(Functions.MAX, Functions.ABS), 1.0e-10); + } + + @Test + public void testRightVectors() { + Matrix A = lowRankMatrix(); + + SequentialBigSvd s = new SequentialBigSvd(A, 6); + SingularValueDecomposition svd = new SingularValueDecomposition(A); + + Matrix v1 = svd.getV().viewPart(0, 20, 0, 3).assign(Functions.ABS); + Matrix v2 = s.getV().viewPart(0, 20, 0, 3).assign(Functions.ABS); + assertEquals(v1, v2); + } + + private static Matrix lowRankMatrix() { + Matrix u = new RandomTrinaryMatrix(1, 20, 4, false); + Matrix d = new DiagonalMatrix(new double[]{5, 3, 1, 0.5}); + Matrix v = new RandomTrinaryMatrix(2, 23, 4, false); + + return u.times(d).times(v.transpose()); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/stats/OnlineSummarizerTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/stats/OnlineSummarizerTest.java b/core/src/test/java/org/apache/mahout/math/stats/OnlineSummarizerTest.java new file mode 100644 index 0000000..681213b --- /dev/null +++ b/core/src/test/java/org/apache/mahout/math/stats/OnlineSummarizerTest.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.stats; + +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.MahoutTestCase; +import org.apache.mahout.math.jet.random.AbstractContinousDistribution; +import org.apache.mahout.math.jet.random.Gamma; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Random; + +public final class OnlineSummarizerTest extends MahoutTestCase { + + @Test + public void testStats() { + /** + the reference limits here were derived using a numerical simulation where I took + 10,000 samples from the distribution in question and computed the stats from that + sample to get min, 25%-ile, median and so on. I did this 1000 times to get 5% and + 95% confidence limits for those values. + */ + + //symmetrical, well behaved + System.out.printf("normal\n"); + check(normal(10000)); + + //asymmetrical, well behaved. The range for the maximum was fudged slightly to all this to pass. + System.out.printf("exp\n"); + check(exp(10000)); + + //asymmetrical, wacko distribution where mean/median is about 200 + System.out.printf("gamma\n"); + check(gamma(10000, 0.1)); + } + + private static void check(double[] samples) { + OnlineSummarizer s = new OnlineSummarizer(); + double mean = 0; + double sd = 0; + int n = 1; + for (double x : samples) { + s.add(x); + double old = mean; + mean += (x - mean) / n; + sd += (x - old) * (x - mean); + n++; + } + sd = Math.sqrt(sd / samples.length); + + Arrays.sort(samples); + +// for (int i = 0; i < 5; i++) { +// int index = Math.abs(Arrays.binarySearch(samples, s.getQuartile(i))); +// assertEquals("quartile " + i, i * (samples.length - 1) / 4.0, index, 10); +// } +// assertEquals(s.getQuartile(2), s.getMedian(), 0); + + assertEquals("mean", s.getMean(), mean, 0); + assertEquals("sd", s.getSD(), sd, 1e-8); + } + + private static double[] normal(int n) { + double[] r = new double[n]; + Random gen = RandomUtils.getRandom(1L); + for (int i = 0; i < n; i++) { + r[i] = gen.nextGaussian(); + } + return r; + } + + private static double[] exp(int n) { + double[] r = new double[n]; + Random gen = RandomUtils.getRandom(1L); + for (int i = 0; i < n; i++) { + r[i] = -Math.log1p(-gen.nextDouble()); + } + return r; + } + + private static double[] gamma(int n, double shape) { + double[] r = new double[n]; + Random gen = RandomUtils.getRandom(); + AbstractContinousDistribution gamma = new Gamma(shape, shape, gen); + for (int i = 0; i < n; i++) { + r[i] = gamma.nextDouble(); + } + return r; + } +} + +
