Author: celestin
Date: Sat Dec 3 16:24:55 2011
New Revision: 1209942
URL: http://svn.apache.org/viewvc?rev=1209942&view=rev
Log:
New implementation of AbstractRealDistribution.inverseCumulativeProbability(double).
Solves MATH-699 and slightly reduces execution times.
Added:
    commons/proper/math/trunk/src/test/java/org/apache/commons/math/distribution/AbstractRealDistributionTest.java   (with props)
Modified:
    commons/proper/math/trunk/src/main/java/org/apache/commons/math/distribution/AbstractRealDistribution.java
Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math/distribution/AbstractRealDistribution.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math/distribution/AbstractRealDistribution.java?rev=1209942&r1=1209941&r2=1209942&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math/distribution/AbstractRealDistribution.java (original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math/distribution/AbstractRealDistribution.java Sat Dec 3 16:24:55 2011
@@ -20,7 +20,6 @@ import java.io.Serializable;
import org.apache.commons.math.analysis.UnivariateFunction;
import org.apache.commons.math.analysis.solvers.UnivariateRealSolverUtils;
-import org.apache.commons.math.exception.MathInternalError;
import org.apache.commons.math.exception.NotStrictlyPositiveException;
import org.apache.commons.math.exception.NumberIsTooLargeException;
import org.apache.commons.math.exception.OutOfRangeException;
@@ -69,50 +68,80 @@ implements RealDistribution, Serializabl
/** {@inheritDoc} */
public double inverseCumulativeProbability(final double p) throws
OutOfRangeException {
-
if (p < 0.0 || p > 1.0) {
throw new OutOfRangeException(p, 0, 1);
}
- // by default, do simple root finding using bracketing and default
solver.
- // subclasses can override if there is a better method.
- UnivariateFunction rootFindingFunction =
- new UnivariateFunction() {
- public double value(double x) {
+ double lowerBound = getSupportLowerBound();
+ if (p == 0.0) { // 0-quantile: the support lower bound (may be -infinity)
+ return lowerBound;
+ }
+
+ double upperBound = getSupportUpperBound();
+ if (p == 1.0) { // 1-quantile: the support upper bound (may be +infinity)
+ return upperBound;
+ }
+
+ final double mu = getNumericalMean();
+ final double sig = FastMath.sqrt(getNumericalVariance());
+ final boolean chebyshevApplies; // the Chebyshev brackets below need a finite mean and standard deviation
+ chebyshevApplies = !(Double.isInfinite(mu) || Double.isNaN(mu) ||
+ Double.isInfinite(sig) || Double.isNaN(sig));
+
+ if (lowerBound == Double.NEGATIVE_INFINITY) { // unbounded below: replace with a finite bracket end
+ if (chebyshevApplies) {
+ lowerBound = mu - sig * FastMath.sqrt((1. - p) / p); // one-sided Chebyshev (Cantelli): F(mu - k * sig) <= p with k = sqrt((1 - p) / p)
+ } else {
+ lowerBound = -1.0; // no usable moments: try -1, -2, -4, ... until F(lowerBound) < p
+ while (cumulativeProbability(lowerBound) >= p) {
+ lowerBound *= 2.0;
+ }
+ }
+ }
+
+ if (upperBound == Double.POSITIVE_INFINITY) { // unbounded above: replace with a finite bracket end
+ if (chebyshevApplies) {
+ upperBound = mu + sig * FastMath.sqrt(p / (1. - p)); // one-sided Chebyshev (Cantelli): F(mu + k * sig) >= p with k = sqrt(p / (1 - p))
+ } else {
+ upperBound = 1.0; // no usable moments: try 1, 2, 4, ... until F(upperBound) >= p
+ while (cumulativeProbability(upperBound) < p) {
+ upperBound *= 2.0;
+ }
+ }
+ }
+
+ final UnivariateFunction toSolve = new UnivariateFunction() { // the quantile is a root of F(x) - p
+
+ public double value(final double x) {
return cumulativeProbability(x) - p;
}
};
- // Try to bracket root, test domain endpoints if this fails
- double lowerBound = getDomainLowerBound(p);
- double upperBound = getDomainUpperBound(p);
- double[] bracket = null;
- try {
- bracket = UnivariateRealSolverUtils.bracket(
- rootFindingFunction, getInitialDomain(p),
- lowerBound, upperBound);
- } catch (NumberIsTooLargeException ex) {
- /*
- * Check domain endpoints to see if one gives value that is within
- * the default solver's defaultAbsoluteAccuracy of 0 (will be the
- * case if density has bounded support and p is 0 or 1).
- */
- if (FastMath.abs(rootFindingFunction.value(lowerBound)) <
getSolverAbsoluteAccuracy()) {
- return lowerBound;
- }
- if (FastMath.abs(rootFindingFunction.value(upperBound)) <
getSolverAbsoluteAccuracy()) {
- return upperBound;
+ double x = UnivariateRealSolverUtils.solve(toSolve, // root of F(x) - p in [lowerBound, upperBound]
+ lowerBound,
+ upperBound,
+
getSolverAbsoluteAccuracy());
+
+ if (!isSupportConnected()) { // F may be flat where the support has gaps
+ /* Test for plateau: the solver may return any point of a flat segment of F, not its left end. */
+ final double dx = getSolverAbsoluteAccuracy();
+ if (x - dx >= getSupportLowerBound()) { // keep the probe x - dx inside the support
+ double px = cumulativeProbability(x);
+ if (cumulativeProbability(x - dx) == px) { // F is flat just left of x
+ upperBound = x; // bisect [lowerBound, x] for where F first reaches F(x)
+ while (upperBound - lowerBound > dx) {
+ final double midPoint = 0.5 * (lowerBound +
upperBound);
+ if (cumulativeProbability(midPoint) < px) {
+ lowerBound = midPoint;
+ } else {
+ upperBound = midPoint;
+ }
+ }
+ return upperBound; // left end of the flat segment, to within dx
+ }
}
- // Failed bracket convergence was not because of corner solution
- throw new MathInternalError(ex);
}
-
- // find root
- double root = UnivariateRealSolverUtils.solve(rootFindingFunction,
- // override getSolverAbsoluteAccuracy() to use a Brent solver
with
- // absolute accuracy different from the default.
- bracket[0],bracket[1], getSolverAbsoluteAccuracy());
- return root;
+ return x;
}
/**
Added: commons/proper/math/trunk/src/test/java/org/apache/commons/math/distribution/AbstractRealDistributionTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math/distribution/AbstractRealDistributionTest.java?rev=1209942&view=auto
==============================================================================
--- commons/proper/math/trunk/src/test/java/org/apache/commons/math/distribution/AbstractRealDistributionTest.java (added)
+++ commons/proper/math/trunk/src/test/java/org/apache/commons/math/distribution/AbstractRealDistributionTest.java Sat Dec 3 16:24:55 2011
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math.distribution;
+
+import org.apache.commons.math.analysis.UnivariateFunction;
+import org.apache.commons.math.analysis.integration.RombergIntegrator;
+import org.apache.commons.math.analysis.integration.UnivariateRealIntegrator;
+import org.apache.commons.math.exception.OutOfRangeException;
+import org.junit.Assert;
+import org.junit.Test;
+
+/** Tests of inverseCumulativeProbability on CDFs with flat segments (MATH-699). */
+public class AbstractRealDistributionTest {
+
+ @Test
+ public void testContinuous() { // F: linear to p12 on [x0, x1], flat at p12 on (x1, x2], linear to 1 on (x2, x3]
+ final double x0 = 0.0;
+ final double x1 = 1.0;
+ final double x2 = 2.0;
+ final double x3 = 3.0;
+ final double p12 = 0.5;
+ final AbstractRealDistribution distribution;
+ distribution = new AbstractRealDistribution() {
+
+ public double cumulativeProbability(final double x) {
+ if ((x < x0) || (x > x3)) {
+ throw new OutOfRangeException(x, x0, x3);
+ }
+ if (x <= x1) {
+ return p12 * (x - x0) / (x1 - x0);
+ } else if (x <= x2) {
+ return p12;
+ } else if (x <= x3) {
+ return p12 + (1.0 - p12) * (x - x2) / (x3 - x2);
+ }
+ return 0.0; // reached only for NaN x (every comparison above is false)
+ }
+
+ public double density(final double x) {
+ if ((x < x0) || (x > x3)) {
+ throw new OutOfRangeException(x, x0, x3);
+ }
+ if (x <= x1) {
+ return p12 / (x1 - x0);
+ } else if (x <= x2) {
+ return 0.0;
+ } else if (x <= x3) {
+ return (1.0 - p12) / (x3 - x2);
+ }
+ return 0.0; // reached only for NaN x (every comparison above is false)
+ }
+
+ @Override
+ protected double getDomainLowerBound(final double p) { // legacy bracketing hook; throws so the test fails if it is still called
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ protected double getDomainUpperBound(final double p) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ protected double getInitialDomain(final double p) {
+ throw new UnsupportedOperationException();
+ }
+
+ public double getNumericalMean() { // mixture of uniforms on [x0, x1] and [x2, x3] with weights p12 and 1 - p12
+ return ((x0 + x1) * p12 + (x2 + x3) * (1.0 - p12)) / 2.0;
+ }
+
+ public double getNumericalVariance() { // E[X^2] - E[X]^2; a uniform on [a, b] has E[X^2] = (a^2 + ab + b^2) / 3
+ final double meanX = getNumericalMean();
+ final double meanX2;
+ meanX2 = ((x0 * x0 + x0 * x1 + x1 * x1) * p12 + (x2 * x2 + x2
+ * x3 + x3 * x3)
+ * (1.0 - p12)) / 3.0;
+ return meanX2 - meanX * meanX;
+ }
+
+ public double getSupportLowerBound() {
+ return x0;
+ }
+
+ public double getSupportUpperBound() {
+ return x3;
+ }
+
+ public boolean isSupportConnected() { // support [x0, x1] U [x2, x3] has a gap, so F has a plateau
+ return false;
+ }
+
+ public boolean isSupportLowerBoundInclusive() {
+ return true;
+ }
+
+ public boolean isSupportUpperBoundInclusive() {
+ return true;
+ }
+
+ public double probability(final double x) {
+ throw new UnsupportedOperationException();
+ }
+ };
+ final double expected = x1; // smallest x with F(x) >= p12 is x1, the left end of the plateau
+ final double actual = distribution.inverseCumulativeProbability(p12);
+ Assert.assertEquals("", expected, actual,
+ distribution.getSolverAbsoluteAccuracy());
+ }
+
+ @Test
+ public void testDiscontinuous() { // F: linear to p12 on [x0, x1], p12 on (x1, x2], p23 on (x2, x3], linear to 1 on (x3, x4]
+ final double x0 = 0.0;
+ final double x1 = 0.25;
+ final double x2 = 0.5;
+ final double x3 = 0.75;
+ final double x4 = 1.0;
+ final double p12 = 1.0 / 3.0;
+ final double p23 = 2.0 / 3.0;
+ final AbstractRealDistribution distribution;
+ distribution = new AbstractRealDistribution() {
+
+ public double cumulativeProbability(final double x) {
+ if ((x < x0) || (x > x4)) {
+ throw new OutOfRangeException(x, x0, x4);
+ }
+ if (x <= x1) {
+ return p12 * (x - x0) / (x1 - x0);
+ } else if (x <= x2) {
+ return p12;
+ } else if (x <= x3) {
+ return p23;
+ } else {
+ return (1.0 - p23) * (x - x3) / (x4 - x3) + p23;
+ }
+ }
+
+ public double density(final double x) {
+ if ((x < x0) || (x > x4)) {
+ throw new OutOfRangeException(x, x0, x4);
+ }
+ if (x <= x1) {
+ return p12 / (x1 - x0);
+ } else if (x <= x2) {
+ return 0.0;
+ } else if (x <= x3) {
+ return 0.0;
+ } else {
+ return (1.0 - p23) / (x4 - x3);
+ }
+ }
+
+ @Override
+ protected double getDomainLowerBound(final double p) { // legacy bracketing hook; throws so the test fails if it is still called
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ protected double getDomainUpperBound(final double p) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ protected double getInitialDomain(final double p) {
+ throw new UnsupportedOperationException();
+ }
+
+ public double getNumericalMean() { // NOTE(review): density() omits the 1/3 mass F jumps by at x2, so this is not the true mean
+ final UnivariateFunction f = new UnivariateFunction() {
+
+ public double value(final double x) {
+ return x * density(x);
+ }
+ };
+ final UnivariateRealIntegrator integrator = new
RombergIntegrator();
+ return integrator.integrate(Integer.MAX_VALUE, f, x0, x4);
+ }
+
+ public double getNumericalVariance() { // NOTE(review): same missing mass at x2 as in getNumericalMean()
+ final double meanX = getNumericalMean();
+ final UnivariateFunction f = new UnivariateFunction() {
+
+ public double value(final double x) {
+ return x * x * density(x);
+ }
+ };
+ final UnivariateRealIntegrator integrator = new
RombergIntegrator();
+ final double meanX2 = integrator.integrate(Integer.MAX_VALUE,
+ f, x0, x4);
+ return meanX2 - meanX * meanX;
+ }
+
+ public double getSupportLowerBound() {
+ return x0;
+ }
+
+ public double getSupportUpperBound() {
+ return x4;
+ }
+
+ public boolean isSupportConnected() { // F is flat on (x1, x2] and (x2, x3]
+ return false;
+ }
+
+ public boolean isSupportLowerBoundInclusive() {
+ return true;
+ }
+
+ public boolean isSupportUpperBoundInclusive() {
+ return true;
+ }
+
+ public double probability(final double x) {
+ throw new UnsupportedOperationException();
+ }
+ };
+ final double expected = x2; // inf{x : F(x) >= p23} = x2, where F jumps from p12 to p23
+ final double actual = distribution.inverseCumulativeProbability(p23);
+ Assert.assertEquals("", expected, actual,
+ distribution.getSolverAbsoluteAccuracy());
+
+ }
+}
Propchange: commons/proper/math/trunk/src/test/java/org/apache/commons/math/distribution/AbstractRealDistributionTest.java
------------------------------------------------------------------------------
    svn:eol-style = native
Propchange: commons/proper/math/trunk/src/test/java/org/apache/commons/math/distribution/AbstractRealDistributionTest.java
------------------------------------------------------------------------------
    svn:keywords = Author Date Id Revision