Author: tdunning
Date: Thu Sep 2 04:33:02 2010
New Revision: 991804
URL: http://svn.apache.org/viewvc?rev=991804&view=rev
Log:
MAHOUT-495 - Undeprecate Normal distribution. Extract common test patterns
into DistributionTest
Added:
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/DistributionTest.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/NormalTest.java
Modified:
mahout/trunk/math/pom.xml
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractContinousDistribution.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractDistribution.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/Normal.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/ExponentialTest.java
Modified: mahout/trunk/math/pom.xml
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/pom.xml?rev=991804&r1=991803&r2=991804&view=diff
==============================================================================
--- mahout/trunk/math/pom.xml (original)
+++ mahout/trunk/math/pom.xml Thu Sep 2 04:33:02 2010
@@ -87,6 +87,13 @@
<dependencies>
<dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-math</artifactId>
+ <version>2.1</version>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>r03</version>
Modified:
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractContinousDistribution.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractContinousDistribution.java?rev=991804&r1=991803&r2=991804&view=diff
==============================================================================
---
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractContinousDistribution.java
(original)
+++
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractContinousDistribution.java
Thu Sep 2 04:33:02 2010
@@ -27,8 +27,16 @@ It is provided "as is" without expressed
package org.apache.mahout.math.jet.random;
/**
- * Abstract base class for all continuous distributions.
+ * Abstract base class for all continuous distributions. Continuous distributions have
+ * a probability density function (pdf) and a cumulative distribution function (cdf).
*
*/
public abstract class AbstractContinousDistribution extends
AbstractDistribution {
+  /**
+   * Cumulative distribution function evaluated at x. This default implementation
+   * always throws; subclasses that can compute a cdf override it.
+   *
+   * @param x The point at which the cdf would be evaluated.
+   * @return Never returns normally.
+   * @throws UnsupportedOperationException always, naming the concrete class.
+   */
+  public double cdf(double x) {
+    // the message must name the cdf; it previously said "pdf" and misreported the failing call
+    throw new UnsupportedOperationException("Can't compute cdf for " + this.getClass().getName());
+  }
+
+  /**
+   * Probability density function evaluated at x. This default implementation
+   * always throws; subclasses that can compute a pdf override it.
+   *
+   * @param x The point at which the pdf would be evaluated.
+   * @return Never returns normally.
+   * @throws UnsupportedOperationException always, naming the concrete class.
+   */
+  public double pdf(double x) {
+    throw new UnsupportedOperationException("Can't compute pdf for " + this.getClass().getName());
+  }
}
Modified:
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractDistribution.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractDistribution.java?rev=991804&r1=991803&r2=991804&view=diff
==============================================================================
---
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractDistribution.java
(original)
+++
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractDistribution.java
Thu Sep 2 04:33:02 2010
@@ -31,7 +31,6 @@ import org.apache.mahout.math.function.I
import org.apache.mahout.math.jet.random.engine.RandomEngine;
public abstract class AbstractDistribution extends PersistentObject implements
UnaryFunction, IntFunction {
-
protected RandomEngine randomGenerator;
/** Makes this class non instantiable, but still let's others inherit from
it. */
@@ -93,22 +92,6 @@ public abstract class AbstractDistributi
return (int) Math.round(nextDouble());
}
- public byte nextByte() {
- return (byte)nextInt();
- }
-
- public char nextChar() {
- return (char)nextInt();
- }
-
- public long nextLong() {
- return Math.round(nextDouble());
- }
-
- public float nextFloat() {
- return (float)nextDouble();
- }
-
/** Sets the uniform random generator internally used. */
protected void setRandomGenerator(RandomEngine randomGenerator) {
this.randomGenerator = randomGenerator;
Modified:
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/Normal.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/Normal.java?rev=991804&r1=991803&r2=991804&view=diff
==============================================================================
---
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/Normal.java
(original)
+++
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/Normal.java
Thu Sep 2 04:33:02 2010
@@ -11,8 +11,11 @@ package org.apache.mahout.math.jet.rando
import org.apache.mahout.math.jet.random.engine.RandomEngine;
import org.apache.mahout.math.jet.stat.Probability;
-/** @deprecated until unit tests are in place. Until this time, this
class/interface is unsupported. */
-...@deprecated
+import java.util.Locale;
+
+/**
+ * Implements a normal distribution specified by its mean and standard deviation.
+ */
public class Normal extends AbstractContinousDistribution {
private double mean;
@@ -24,30 +27,39 @@ public class Normal extends AbstractCont
private double normalizer; // performance cache
- // The uniform random number generated shared by all <b>static</b> methods.
- private static final Normal shared = new Normal(0.0, 1.0,
makeDefaultGenerator());
-
- /** Constructs a normal (gauss) distribution. Example: mean=0.0,
standardDeviation=1.0. */
+ /**
+ * Constructs a normal distribution with the given mean and standard deviation.
+ *
+ * @param mean The mean of the resulting distribution.
+ * @param standardDeviation The standard deviation of the distribution.
+ * @param randomGenerator The random number generator to use. This can be null if you
+ * don't need to generate any numbers.
+ */
public Normal(double mean, double standardDeviation, RandomEngine
randomGenerator) {
setRandomGenerator(randomGenerator);
setState(mean, standardDeviation);
}
- /** Returns the cumulative distribution function. */
+ /**
+ * Returns the cumulative distribution function evaluated at x.
+ *
+ * @param x The point at which to evaluate the cdf.
+ * @return The value of Probability.normal for this distribution's mean and variance at x.
+ */
+ @Override
public double cdf(double x) {
+ // note: Probability.normal is handed the variance field, not the standard deviation
return Probability.normal(mean, variance, x);
}
- /** Returns a random number from the distribution. */
+ /**
+ * Returns the probability density function evaluated at x.
+ *
+ * @param x The point at which to evaluate the density.
+ * @return normalizer * exp(-(x - mean)^2 / (2 * variance)), where normalizer is the
+ * cached field of this distribution.
+ */
@Override
- public double nextDouble() {
- return nextDouble(this.mean, this.standardDeviation);
+ public double pdf(double x) {
+ double diff = x - mean;
+ return normalizer * Math.exp(-(diff * diff) / (2.0 * variance));
}
- /** Returns a random number from the distribution; bypasses the internal
state. */
- public double nextDouble(double mean, double standardDeviation) {
+ /**
+ * Returns a random number from the distribution.
+ */
+ @Override
+ public double nextDouble() {
// Uses polar Box-Muller transformation.
- if (cacheFilled && this.mean == mean && this.standardDeviation ==
standardDeviation) {
+ if (cacheFilled) {
cacheFilled = false;
return cache;
}
@@ -62,26 +74,23 @@ public class Normal extends AbstractCont
} while (r >= 1.0);
double z = Math.sqrt(-2.0 * Math.log(r) / r);
- cache = mean + standardDeviation * x * z;
+ cache = this.mean + this.standardDeviation * x * z;
cacheFilled = true;
- return mean + standardDeviation * y * z;
- }
-
- /** Returns the probability distribution function. */
- public double pdf(double x) {
- double diff = x - mean;
- return normalizer * Math.exp(-(diff * diff) / (2.0 * variance));
+ return this.mean + this.standardDeviation * y * z;
}
/** Sets the uniform random generator internally used. */
- @Override
- protected void setRandomGenerator(RandomEngine randomGenerator) {
+ public final void setRandomGenerator(RandomEngine randomGenerator) {
super.setRandomGenerator(randomGenerator);
+ // drop any sample cached by nextDouble() so it is not returned after the generator changes
this.cacheFilled = false;
}
- /** Sets the mean and variance. */
- public void setState(double mean, double standardDeviation) {
+ /**
+ * Sets the mean and standard deviation.
+ * @param mean The new value for the mean.
+ * @param standardDeviation The new value for the standard deviation.
+ */
+ public final void setState(double mean, double standardDeviation) {
if (mean != this.mean || standardDeviation != this.standardDeviation) {
this.mean = mean;
this.standardDeviation = standardDeviation;
@@ -92,16 +101,8 @@ public class Normal extends AbstractCont
}
}
- /** Returns a random number from the distribution with the given mean and
standard deviation. */
- public static double staticNextDouble(double mean, double standardDeviation)
{
- synchronized (shared) {
- return shared.nextDouble(mean, standardDeviation);
- }
- }
-
/** Returns a String representation of the receiver. */
public String toString() {
- return this.getClass().getName() + '(' + mean + ',' + standardDeviation +
')';
+ // Locale.ENGLISH keeps '.' as the decimal separator whatever the default locale is
+ return String.format(Locale.ENGLISH, "%s(m=%f, sd=%f)",
this.getClass().getName(), mean, standardDeviation);
}
-
}
Added:
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/DistributionTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/DistributionTest.java?rev=991804&view=auto
==============================================================================
---
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/DistributionTest.java
(added)
+++
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/DistributionTest.java
Thu Sep 2 04:33:02 2010
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.jet.random;
+
+import org.apache.commons.math.ConvergenceException;
+import org.apache.commons.math.FunctionEvaluationException;
+import org.apache.commons.math.analysis.UnivariateRealFunction;
+import org.apache.commons.math.analysis.integration.RombergIntegrator;
+import org.apache.commons.math.analysis.integration.UnivariateRealIntegrator;
+import org.junit.Assert;
+
+import java.util.Arrays;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Provides a consistency check for continuous distributions that relates the
pdf, cdf and
+ * samples. The pdf is checked against the cdf by quadrature. The sampling
is checked
+ * against the cdf using a G^2 (similar to chi^2) test.
+ */
+public class DistributionTest {
+  /**
+   * Checks that the samples, the cdf and the pdf of a distribution are mutually consistent.
+   * The reference points are mapped to x * scale + offset and sorted to form bin boundaries.
+   * The cdf mass of each interior bin is compared to the Romberg integral of the pdf over
+   * that bin, and the bin counts of n samples are compared to the cdf masses with a G^2 test.
+   *
+   * @param dist   The distribution to check; it is sampled n times via nextDouble().
+   * @param x      Reference break points before scaling; this array is not modified.
+   * @param offset Added to each scaled break point.
+   * @param scale  Multiplies each break point.
+   * @param n      The number of samples to draw.
+   * @throws ConvergenceException propagated from the numerical integration of the pdf.
+   * @throws FunctionEvaluationException propagated from the numerical integration of the pdf.
+   */
+  public void checkDistribution(final AbstractContinousDistribution dist, double[] x,
+                                double offset, double scale, int n)
+      throws ConvergenceException, FunctionEvaluationException {
+    // bin boundaries: the reference points mapped through scale and offset, in ascending order
+    double[] xs = Arrays.copyOf(x, x.length);
+    for (int i = 0; i < xs.length; i++) {
+      xs[i] = xs[i] * scale + offset;
+    }
+    Arrays.sort(xs);
+
+    // collect samples
+    double[] y = new double[n];
+    for (int i = 0; i < n; i++) {
+      y[i] = dist.nextDouble();
+    }
+    Arrays.sort(y);
+
+    // compute probabilities for bins; the last bin is everything above the top boundary
+    double[] p = new double[xs.length + 1];
+    double lastP = 0;
+    for (int i = 0; i < xs.length; i++) {
+      double thisP = dist.cdf(xs[i]);
+      p[i] = thisP - lastP;
+      lastP = thisP;
+    }
+    p[p.length - 1] = 1 - lastP;
+
+    // count samples in each bin.  Both y and xs are sorted, so the scan for each boundary
+    // resumes where the previous one stopped instead of restarting from the first sample.
+    int[] k = new int[xs.length + 1];
+    int j = 0;
+    for (int i = 0; i < xs.length; i++) {
+      int lastJ = j;
+      while (j < n && y[j] < xs[i]) {
+        j++;
+      }
+      k[i] = j - lastJ;
+    }
+    k[k.length - 1] = n - j;
+
+    // now verify probabilities by comparing to integral of pdf
+    UnivariateRealFunction pdf = new UnivariateRealFunction() {
+      public double value(double v) throws FunctionEvaluationException {
+        return dist.pdf(v);
+      }
+    };
+    UnivariateRealIntegrator integrator = new RombergIntegrator();
+    for (int i = 0; i < xs.length - 1; i++) {
+      double delta = integrator.integrate(pdf, xs[i], xs[i + 1]);
+      assertEquals(String.format("integral of pdf over [%.3f, %.3f]", xs[i], xs[i + 1]),
+          delta, p[i + 1], 1e-6);
+    }
+
+    // finally compute G^2 of observed versus predicted.  See http://en.wikipedia.org/wiki/G-test
+    double sum = 0;
+    for (int i = 0; i < k.length; i++) {
+      // empty bins contribute nothing (and would otherwise produce 0 * log(0))
+      if (k[i] != 0) {
+        sum += k[i] * Math.log(k[i] / p[i] / n);
+      }
+    }
+    sum *= 2;
+
+    // sum is chi^2 distributed with degrees of freedom equal to number of partitions - 1
+    int dof = k.length - 1;
+    // fisher's approximation is that sqrt(2*x) is approximately unit normal with mean sqrt(2*dof-1)
+    double z = Math.sqrt(2 * sum) - Math.sqrt(2 * dof - 1);
+    Assert.assertTrue(String.format("offset=%.3f scale=%.3f Z = %.1f", offset, scale, z),
+        Math.abs(z) < 3);
+  }
+
+  /**
+   * Checks that the cdf of a distribution takes the expected values at scaled break points.
+   *
+   * @param offset    Added to each scaled break point.
+   * @param scale     Multiplies each break point.
+   * @param dist      The distribution whose cdf is evaluated.
+   * @param breaks    Break points before scaling.
+   * @param quantiles Expected cdf values, one per break point, to within 1e-6.
+   */
+  protected void checkCdf(double offset, double scale, AbstractContinousDistribution dist,
+                          double[] breaks, double[] quantiles) {
+    int i = 0;
+    for (double x : breaks) {
+      assertEquals(String.format("m=%.3f sd=%.3f x=%.3f", offset, scale, x),
+          quantiles[i], dist.cdf(x * scale + offset), 1e-6);
+      i++;
+    }
+  }
+}
Modified:
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/ExponentialTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/ExponentialTest.java?rev=991804&r1=991803&r2=991804&view=diff
==============================================================================
---
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/ExponentialTest.java
(original)
+++
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/ExponentialTest.java
Thu Sep 2 04:33:02 2010
@@ -17,6 +17,8 @@
package org.apache.mahout.math.jet.random;
+import org.apache.commons.math.ConvergenceException;
+import org.apache.commons.math.FunctionEvaluationException;
import org.apache.mahout.math.jet.random.engine.MersenneTwister;
import org.junit.Test;
@@ -29,7 +31,7 @@ import static org.junit.Assert.assertEqu
* Created by IntelliJ IDEA. User: tdunning Date: Aug 31, 2010 Time: 7:14:19
PM To change this
* template use File | Settings | File Templates.
*/
-public class ExponentialTest {
+public class ExponentialTest extends DistributionTest {
@Test
public void testCdf() {
Exponential dist = new Exponential(5.0, new MersenneTwister(1));
@@ -65,10 +67,13 @@ public class ExponentialTest {
}
@Test
- public void testNextDouble() {
- for (double lambda : new double[] {13.0, 0.02, 1.6}) {
- Exponential dist = new Exponential(lambda, new MersenneTwister(1));
+ public void testNextDouble() throws ConvergenceException,
FunctionEvaluationException {
+ // x[1..9] equal -ln(1 - q) for q = 0.1, 0.2, ..., 0.9; x[0] is a negative break point
+ double[] x = {-0.01, 0.1053605, 0.2231436, 0.3566749, 0.5108256,
0.6931472, 0.9162907, 1.2039728, 1.6094379, 2.3025851};
+ Exponential dist = new Exponential(1, new MersenneTwister(1));
+ for (double lambda : new double[]{13.0, 0.02, 1.6}) {
+ // a single instance is re-parameterized for each lambda via setState
+ dist.setState(lambda);
checkEmpiricalDistribution(dist, 10000, lambda);
+ // break points are scaled by 1 / lambda with no offset
+ checkDistribution(dist, x, 0, 1 / lambda, 10000);
}
}
Added:
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/NormalTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/NormalTest.java?rev=991804&view=auto
==============================================================================
---
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/NormalTest.java
(added)
+++
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/NormalTest.java
Thu Sep 2 04:33:02 2010
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.jet.random;
+
+import org.apache.commons.math.ConvergenceException;
+import org.apache.commons.math.FunctionEvaluationException;
+import org.apache.mahout.math.jet.random.engine.MersenneTwister;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Locale;
+import java.util.Random;
+
+/**
+ * Tests the cdf, pdf, sampling and string form of {@link Normal} using the shared checks
+ * inherited from {@link DistributionTest}.
+ */
+public class NormalTest extends DistributionTest {
+  // break points paired index-by-index with the expected cdf values in quantiles
+  private final double[] breaks = {-1.2815516, -0.8416212, -0.5244005, -0.2533471,
+      0.0000000, 0.2533471, 0.5244005, 0.8416212, 1.2815516};
+  private final double[] quantiles = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
+
+  /** Checks the cdf at the break points for the unit normal and 19 random offset/scale pairs. */
+  @Test
+  public void testCdf() {
+    Random gen = new Random(1);
+    double offset = 0;
+    double scale = 1;
+    for (int k = 0; k < 20; k++) {
+      // no generator is needed because only the cdf is evaluated
+      Normal dist = new Normal(offset, scale, null);
+      checkCdf(offset, scale, dist, breaks, quantiles);
+      offset = gen.nextGaussian();
+      scale = Math.exp(3 * gen.nextGaussian());
+    }
+  }
+
+  /** Checks samples, cdf and pdf against each other for 20 offset/scale pairs. */
+  @Test
+  public void consistency() throws ConvergenceException, FunctionEvaluationException {
+    Random gen = new Random(1);
+    double offset = 0;
+    double scale = 1;
+    for (int k = 0; k < 20; k++) {
+      Normal dist = new Normal(offset, scale, new MersenneTwister());
+      checkDistribution(dist, breaks, offset, scale, 10000);
+      offset = gen.nextGaussian();
+      scale = Math.exp(3 * gen.nextGaussian());
+    }
+  }
+
+  /** Checks that a distribution re-parameterized with setState behaves like the new one. */
+  @Test
+  public void testSetState() throws ConvergenceException, FunctionEvaluationException {
+    Normal dist = new Normal(0, 1, new MersenneTwister());
+    dist.setState(1.3, 5.9);
+    checkDistribution(dist, breaks, 1.3, 5.9, 10000);
+  }
+
+  /** Checks that toString() does not depend on the default locale's decimal separator. */
+  @Test
+  public void testToString() {
+    Locale d = Locale.getDefault();
+    Locale.setDefault(Locale.GERMAN);
+    try {
+      Assert.assertEquals("org.apache.mahout.math.jet.random.Normal(m=1.300000, sd=5.900000)",
+          new Normal(1.3, 5.9, null).toString());
+    } finally {
+      // restore even when the assertion fails so later tests do not inherit the German locale
+      Locale.setDefault(d);
+    }
+  }
+}