Author: erans
Date: Fri Mar 11 14:03:59 2011
New Revision: 1080571
URL: http://svn.apache.org/viewvc?rev=1080571&view=rev
Log:
MATH-503
Added parametric version of the "Logistic" function.
Modified:
commons/proper/math/trunk/src/main/java/org/apache/commons/math/analysis/function/Logistic.java
commons/proper/math/trunk/src/test/java/org/apache/commons/math/analysis/function/LogisticTest.java
Modified:
commons/proper/math/trunk/src/main/java/org/apache/commons/math/analysis/function/Logistic.java
URL:
http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math/analysis/function/Logistic.java?rev=1080571&r1=1080570&r2=1080571&view=diff
==============================================================================
---
commons/proper/math/trunk/src/main/java/org/apache/commons/math/analysis/function/Logistic.java
(original)
+++
commons/proper/math/trunk/src/main/java/org/apache/commons/math/analysis/function/Logistic.java
Fri Mar 11 14:03:59 2011
@@ -19,7 +19,10 @@ package org.apache.commons.math.analysis
import org.apache.commons.math.analysis.UnivariateRealFunction;
import org.apache.commons.math.analysis.DifferentiableUnivariateRealFunction;
+import org.apache.commons.math.analysis.ParametricUnivariateRealFunction;
import org.apache.commons.math.exception.NotStrictlyPositiveException;
+import org.apache.commons.math.exception.NullArgumentException;
+import org.apache.commons.math.exception.DimensionMismatchException;
import org.apache.commons.math.util.FastMath;
/**
@@ -76,7 +79,7 @@ public class Logistic implements Differe
/** {@inheritDoc} */
public double value(double x) {
- return a + (k - a) / FastMath.pow(1 + q * FastMath.exp(b * (m - x)),
oneOverN);
+ return value(m - x, k, b, q, a, oneOverN); // same static helper that Parametric evaluates
}
/** {@inheritDoc} */
@@ -94,4 +97,113 @@ public class Logistic implements Differe
}
};
}
+
+ /**
+ * Parametric function where the input array contains the parameters of
+ * the logistic function
+ * {@code a + (k - a) / (1 + q exp(b (m - x)))^(1 / n)},
+ * ordered as {@code k}, {@code m}, {@code b}, {@code q}, {@code a}, {@code n}.
+ * For {@code b > 0} and {@code q > 0}, {@code k} is the limit as {@code x}
+ * tends to {@code +inf} and {@code a} the limit as it tends to {@code -inf}.
+ */
+ public static class Parametric implements ParametricUnivariateRealFunction {
+ /**
+ * Computes the value of the logistic function at {@code x}.
+ *
+ * @param x Value for which the function must be computed.
+ * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
+ * {@code a} and {@code n}.
+ * @return the value of the function.
+ * @throws NullArgumentException if {@code param} is {@code null}.
+ * @throws DimensionMismatchException if the size of {@code param} is
+ * not 6.
+ * @throws NotStrictlyPositiveException if {@code n} is not strictly positive.
+ */
+ public double value(double x, double[] param) {
+ validateParameters(param);
+ return Logistic.value(param[1] - x, param[0],
+ param[2], param[3],
+ param[4], 1 / param[5]);
+ }
+
+ /**
+ * Computes the value of the gradient at {@code x}.
+ * The components of the gradient vector are the partial derivatives
+ * of the function with respect to each of the <em>parameters</em>,
+ * in the same order as {@code param}.
+ *
+ * @param x Value at which the gradient must be computed.
+ * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
+ * {@code a} and {@code n}.
+ * @return the gradient vector at {@code x}.
+ * @throws NullArgumentException if {@code param} is {@code null}.
+ * @throws DimensionMismatchException if the size of {@code param} is
+ * not 6.
+ * @throws NotStrictlyPositiveException if {@code n} is not strictly positive.
+ */
+ public double[] gradient(double x, double[] param) {
+ validateParameters(param);
+
+ final double b = param[2];
+ final double q = param[3];
+ final double mMinusX = param[1] - x;
+ final double oneOverN = 1 / param[5];
+ final double exp = FastMath.exp(b * mMinusX);
+ final double qExp = q * exp;
+ final double qExp1 = qExp + 1;
+ // factor1 = (k - a) / (n qExp1^(1/n)); factor2 = -factor1 / qExp1 is common to d/dm, d/db, d/dq.
+ final double factor1 = (param[0] - param[4]) * oneOverN / FastMath.pow(qExp1, oneOverN);
+ final double factor2 = -factor1 / qExp1;
+
+ // Components of the gradient, in parameter order (k, m, b, q, a, n).
+ final double gk = Logistic.value(mMinusX, 1, b, q, 0, oneOverN);
+ final double gm = factor2 * b * qExp;
+ final double gb = factor2 * mMinusX * qExp;
+ final double gq = factor2 * exp;
+ final double ga = Logistic.value(mMinusX, 0, b, q, 1, oneOverN);
+ final double gn = factor1 * FastMath.log(qExp1) * oneOverN;
+ return new double[] { gk, gm, gb, gq, ga, gn };
+ }
+
+ /**
+ * Validates parameters to ensure they are appropriate for the evaluation
+ * of the {@link #value(double,double[])} and {@link #gradient(double,double[])} methods.
+ *
+ * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
+ * {@code a} and {@code n}.
+ * @throws NullArgumentException if {@code param} is {@code null}.
+ * @throws DimensionMismatchException if the size of {@code param} is
+ * not 6.
+ * @throws NotStrictlyPositiveException if {@code n} is not strictly positive.
+ */
+ private static void validateParameters(double[] param) {
+ if (param == null) {
+ throw new NullArgumentException();
+ }
+ if (param.length != 6) {
+ throw new DimensionMismatchException(param.length, 6);
+ }
+ if (param[5] <= 0) {
+ throw new NotStrictlyPositiveException(param[5]);
+ }
+ }
+ }
+
+ /**
+ * Evaluates {@code a + (k - a) / (1 + q exp(b (m - x)))^(1 / n)} from the
+ * precomputed {@code m - x} and {@code 1 / n}; used by {@link #value(double)} and {@link Parametric}.
+ *
+ * @param mMinusX {@code m - x}.
+ * @param k {@code k}.
+ * @param b {@code b}.
+ * @param q {@code q}.
+ * @param a {@code a}.
+ * @param oneOverN {@code 1 / n}.
+ * @return the value of the function.
+ */
+ private static double value(double mMinusX, double k, double b,
+ double q, double a,
+ double oneOverN) {
+ return a + (k - a) / FastMath.pow(1 + q * FastMath.exp(b * mMinusX), oneOverN);
+ }
}
Modified:
commons/proper/math/trunk/src/test/java/org/apache/commons/math/analysis/function/LogisticTest.java
URL:
http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math/analysis/function/LogisticTest.java?rev=1080571&r1=1080570&r2=1080571&view=diff
==============================================================================
---
commons/proper/math/trunk/src/test/java/org/apache/commons/math/analysis/function/LogisticTest.java
(original)
+++
commons/proper/math/trunk/src/test/java/org/apache/commons/math/analysis/function/LogisticTest.java
Fri Mar 11 14:03:59 2011
@@ -19,6 +19,8 @@ package org.apache.commons.math.analysis
import org.apache.commons.math.analysis.UnivariateRealFunction;
import org.apache.commons.math.exception.NotStrictlyPositiveException;
+import org.apache.commons.math.exception.NullArgumentException;
+import org.apache.commons.math.exception.DimensionMismatchException;
import org.apache.commons.math.util.FastMath;
import org.junit.Assert;
@@ -97,4 +99,99 @@ public class LogisticTest {
Assert.assertEquals("x=" + x, dgdx.value(x), dfdx.value(x), EPS);
}
}
+
+ @Test(expected=NullArgumentException.class)
+ public void testParametricUsage1() {
+ final Logistic.Parametric g = new Logistic.Parametric();
+ g.value(0, null); // a null parameter array must be rejected
+ }
+
+ @Test(expected=DimensionMismatchException.class)
+ public void testParametricUsage2() {
+ final Logistic.Parametric g = new Logistic.Parametric();
+ g.value(0, new double[] {0}); // 1 parameter instead of the expected 6
+ }
+
+ @Test(expected=NullArgumentException.class)
+ public void testParametricUsage3() {
+ final Logistic.Parametric g = new Logistic.Parametric();
+ g.gradient(0, null); // a null parameter array must be rejected
+ }
+
+ @Test(expected=DimensionMismatchException.class)
+ public void testParametricUsage4() {
+ final Logistic.Parametric g = new Logistic.Parametric();
+ g.gradient(0, new double[] {0}); // 1 parameter instead of the expected 6
+ }
+
+ @Test(expected=NotStrictlyPositiveException.class)
+ public void testParametricUsage5() {
+ final Logistic.Parametric g = new Logistic.Parametric();
+ g.value(0, new double[] {1, 0, 1, 1, 0 ,0}); // n = param[5] = 0 is not strictly positive
+ }
+
+ @Test(expected=NotStrictlyPositiveException.class)
+ public void testParametricUsage6() {
+ final Logistic.Parametric g = new Logistic.Parametric();
+ g.gradient(0, new double[] {1, 0, 1, 1, 0 ,0}); // n = param[5] = 0 is not strictly positive
+ }
+
+ @Test
+ public void testGradientComponent0Component4() {
+ final double k = 3;
+ final double a = 2;
+
+ final Logistic.Parametric f = new Logistic.Parametric();
+ // With m = 0, b = 1, q = 1, n = 1 the function is a + (k - a) / (1 + exp(-x)); compare with "Sigmoid" (presumably lo = a, hi = k).
+ final Sigmoid.Parametric g = new Sigmoid.Parametric();
+
+ final double x = 0.12345;
+ final double[] gf = f.gradient(x, new double[] {k, 0, 1, 1, a, 1});
+ final double[] gg = g.gradient(x, new double[] {a, k});
+
+ Assert.assertEquals(gg[0], gf[4], EPS); // gf[4] is the derivative w.r.t. a
+ Assert.assertEquals(gg[1], gf[0], EPS); // gf[0] is the derivative w.r.t. k
+ }
+
+ @Test
+ public void testGradientComponent5() {
+ final double m = 1.2;
+ final double k = 3.4;
+ final double a = 2.3;
+ final double q = 0.567;
+ final double b = -FastMath.log(q); // so that q * exp(b) = 1
+ final double n = 3.4;
+
+ final Logistic.Parametric f = new Logistic.Parametric();
+
+ final double x = m - 1; // m - x = 1
+ final double qExp1 = 2; // 1 + q * exp(b * (m - x)) = 1 + 1
+
+ final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
+
+ Assert.assertEquals((k - a) * FastMath.log(qExp1) / (n * n *
FastMath.pow(qExp1, 1 / n)),
+ gf[5], EPS);
+ }
+
+ @Test
+ public void testGradientComponent1Component2Component3() {
+ final double m = 1.2;
+ final double k = 3.4;
+ final double a = 2.3;
+ final double b = 0.567;
+ final double q = 1 / FastMath.exp(b * m); // so that q * exp(b * m) = 1
+ final double n = 3.4;
+
+ final Logistic.Parametric f = new Logistic.Parametric();
+
+ final double x = 0; // m - x = m
+ final double qExp1 = 2; // 1 + q * exp(b * (m - x)) = 1 + 1
+
+ final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
+
+ final double factor = (a - k) / (n * FastMath.pow(qExp1, 1 / n + 1)); // common factor of d/dm, d/db, d/dq
+ Assert.assertEquals(factor * b, gf[1], EPS);
+ Assert.assertEquals(factor * m, gf[2], EPS);
+ Assert.assertEquals(factor / q, gf[3], EPS);
+ }
}