Modified: commons/proper/math/trunk/src/test/org/apache/commons/math/optimization/direct/MultiDirectionalTest.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/org/apache/commons/math/optimization/direct/MultiDirectionalTest.java?rev=749850&r1=749849&r2=749850&view=diff ============================================================================== --- commons/proper/math/trunk/src/test/org/apache/commons/math/optimization/direct/MultiDirectionalTest.java (original) +++ commons/proper/math/trunk/src/test/org/apache/commons/math/optimization/direct/MultiDirectionalTest.java Wed Mar 4 00:07:51 2009 @@ -17,14 +17,17 @@ package org.apache.commons.math.optimization.direct; +import junit.framework.Test; +import junit.framework.TestCase; +import junit.framework.TestSuite; + +import org.apache.commons.math.ConvergenceException; import org.apache.commons.math.linear.decomposition.NotPositiveDefiniteMatrixException; -import org.apache.commons.math.optimization.ConvergenceChecker; +import org.apache.commons.math.optimization.GoalType; import org.apache.commons.math.optimization.ObjectiveException; import org.apache.commons.math.optimization.ObjectiveFunction; import org.apache.commons.math.optimization.PointValuePair; -import org.apache.commons.math.ConvergenceException; - -import junit.framework.*; +import org.apache.commons.math.optimization.ObjectiveValueChecker; public class MultiDirectionalTest extends TestCase { @@ -48,8 +51,8 @@ } }; try { - new MultiDirectional(1.9, 0.4).optimize(wrong, 10, new ValueChecker(1.0e-3), true, - new double[] { -0.5 }, new double[] { 0.5 }); + MultiDirectional optimizer = new MultiDirectional(0.9, 1.9); + optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { -1.0 }); fail("an exception should have been thrown"); } catch (ObjectiveException ce) { // expected behavior @@ -58,8 +61,8 @@ fail("wrong exception caught: " + e.getMessage()); } try { - new MultiDirectional(1.9, 0.4).optimize(wrong, 10, new ValueChecker(1.0e-3), true, - new double[] { 0.5 }, new double[] { 1.5 }); + MultiDirectional optimizer = new MultiDirectional(0.9, 1.9); + optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { +2.0 }); fail("an exception should have been thrown"); } catch (ObjectiveException ce) { // expected behavior @@ -86,69 +89,45 @@ public double objective(double[] variables) { final double x = variables[0]; final double y = variables[1]; - return Math.atan(x) * Math.atan(x + 2) * Math.atan(y) * Math.atan(y) / (x * y); + return ((x == 0) || (y == 0)) ? 0 : (Math.atan(x) * Math.atan(x + 2) * Math.atan(y) * Math.atan(y) / (x * y)); } }; - MultiDirectional md = new MultiDirectional(); + MultiDirectional optimizer = new MultiDirectional(); + optimizer.setConvergenceChecker(new ObjectiveValueChecker(1.0e-10, 1.0e-30)); + optimizer.setMaxEvaluations(200); + optimizer.setStartConfiguration(new double[] { 0.2, 0.2 }); + PointValuePair optimum; // minimization - md.optimize(fourExtrema, 200, new ValueChecker(1.0e-8), true, - new double[] { -4, -2 }, new double[] { 1, 2 }, 10, 38821113105892l); - PointValuePair[] optima = md.getOptima(); - assertEquals(10, optima.length); - int localCount = 0; - int globalCount = 0; - for (PointValuePair optimum : optima) { - if (optimum != null) { - if (optimum.getPoint()[0] < 0) { - // this should be the local minimum - ++localCount; - assertEquals(xM, optimum.getPoint()[0], 1.0e-3); - assertEquals(yP, optimum.getPoint()[1], 1.0e-3); - assertEquals(valueXmYp, optimum.getValue(), 3.0e-8); - } else { - // this should be the global minimum - ++globalCount; - assertEquals(xP, optimum.getPoint()[0], 1.0e-3); - assertEquals(yM, optimum.getPoint()[1], 1.0e-3); - assertEquals(valueXpYm, optimum.getValue(), 3.0e-8); - } - } - } - assertTrue(localCount > 0); - assertTrue(globalCount > 0); - assertTrue(md.getTotalEvaluations() > 1400); - assertTrue(md.getTotalEvaluations() < 1700); - - // minimization - md.optimize(fourExtrema, 200, new ValueChecker(1.0e-8), false, - new double[] { -3.5, -1 }, new double[] { 0.5, 1.5 }, 10, 38821113105892l); - optima = md.getOptima(); - assertEquals(10, optima.length); - localCount = 0; - globalCount = 0; - for (PointValuePair optimum : optima) { - if (optimum != null) { - if (optimum.getPoint()[0] < 0) { - // this should be the local maximum - ++localCount; - assertEquals(xM, optimum.getPoint()[0], 1.0e-3); - assertEquals(yM, optimum.getPoint()[1], 1.0e-3); - assertEquals(valueXmYm, optimum.getValue(), 4.0e-8); - } else { - // this should be the global maximum - ++globalCount; - assertEquals(xP, optimum.getPoint()[0], 1.0e-3); - assertEquals(yP, optimum.getPoint()[1], 1.0e-3); - assertEquals(valueXpYp, optimum.getValue(), 4.0e-8); - } - } - } - assertTrue(localCount > 0); - assertTrue(globalCount > 0); - assertTrue(md.getTotalEvaluations() > 1400); - assertTrue(md.getTotalEvaluations() < 1700); + optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { -3.0, 0 }); + assertEquals(xM, optimum.getPoint()[0], 4.0e-6); + assertEquals(yP, optimum.getPoint()[1], 3.0e-6); + assertEquals(valueXmYp, optimum.getValue(), 8.0e-13); + assertTrue(optimizer.getEvaluations() > 120); + assertTrue(optimizer.getEvaluations() < 150); + + optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { +1, 0 }); + assertEquals(xP, optimum.getPoint()[0], 2.0e-8); + assertEquals(yM, optimum.getPoint()[1], 3.0e-6); + assertEquals(valueXpYm, optimum.getValue(), 2.0e-12); + assertTrue(optimizer.getEvaluations() > 120); + assertTrue(optimizer.getEvaluations() < 150); + + // maximization + optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { -3.0, 0.0 }); + assertEquals(xM, optimum.getPoint()[0], 7.0e-7); + assertEquals(yM, optimum.getPoint()[1], 3.0e-7); + assertEquals(valueXmYm, optimum.getValue(), 2.0e-14); + assertTrue(optimizer.getEvaluations() > 120); + assertTrue(optimizer.getEvaluations() < 150); + + optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { +1, 0 }); + assertEquals(xP, optimum.getPoint()[0], 2.0e-8); + assertEquals(yP, optimum.getPoint()[1], 3.0e-6); + assertEquals(valueXpYp, optimum.getValue(), 2.0e-12); + assertTrue(optimizer.getEvaluations() > 120); + assertTrue(optimizer.getEvaluations() < 150); } @@ -167,14 +146,19 @@ }; count = 0; + MultiDirectional optimizer = new MultiDirectional(); + optimizer.setConvergenceChecker(new ObjectiveValueChecker(-1, 1.0e-3)); + optimizer.setMaxEvaluations(100); + optimizer.setStartConfiguration(new double[][] { + { -1.2, 1.0 }, { 0.9, 1.2 } , { 3.5, -2.3 } + }); PointValuePair optimum = - new MultiDirectional().optimize(rosenbrock, 100, new ValueChecker(1.0e-3), true, - new double[][] { - { -1.2, 1.0 }, { 0.9, 1.2 } , { 3.5, -2.3 } - }); + optimizer.optimize(rosenbrock, GoalType.MINIMIZE, new double[] { -1.2, 1.0 }); - assertTrue(count > 60); - assertTrue(optimum.getValue() > 0.01); + assertEquals(count, optimizer.getEvaluations()); + assertTrue(optimizer.getEvaluations() > 70); + assertTrue(optimizer.getEvaluations() < 100); + assertTrue(optimum.getValue() > 1.0e-2); } @@ -195,31 +179,18 @@ }; count = 0; + MultiDirectional optimizer = new MultiDirectional(); + optimizer.setConvergenceChecker(new ObjectiveValueChecker(-1.0, 1.0e-3)); + optimizer.setMaxEvaluations(1000); PointValuePair optimum = - new MultiDirectional().optimize(powell, 1000, new ValueChecker(1.0e-3), true, - new double[] { 3.0, -1.0, 0.0, 1.0 }, - new double[] { 4.0, 0.0, 1.0, 2.0 }); - assertTrue(count > 850); - assertTrue(optimum.getValue() > 0.015); + optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 }); + assertEquals(count, optimizer.getEvaluations()); + assertTrue(optimizer.getEvaluations() > 800); + assertTrue(optimizer.getEvaluations() < 900); + assertTrue(optimum.getValue() > 1.0e-2); } - private static class ValueChecker implements ConvergenceChecker { - - public ValueChecker(double threshold) { - this.threshold = threshold; - } - - public boolean converged(PointValuePair[] simplex) { - PointValuePair smallest = simplex[0]; - PointValuePair largest = simplex[simplex.length - 1]; - return (largest.getValue() - smallest.getValue()) < threshold; - } - - private double threshold; - - }; - public static Test suite() { return new TestSuite(MultiDirectionalTest.class); }
Modified: commons/proper/math/trunk/src/test/org/apache/commons/math/optimization/direct/NelderMeadTest.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/org/apache/commons/math/optimization/direct/NelderMeadTest.java?rev=749850&r1=749849&r2=749850&view=diff ============================================================================== --- commons/proper/math/trunk/src/test/org/apache/commons/math/optimization/direct/NelderMeadTest.java (original) +++ commons/proper/math/trunk/src/test/org/apache/commons/math/optimization/direct/NelderMeadTest.java Wed Mar 4 00:07:51 2009 @@ -17,19 +17,17 @@ package org.apache.commons.math.optimization.direct; +import junit.framework.Test; +import junit.framework.TestCase; +import junit.framework.TestSuite; + +import org.apache.commons.math.ConvergenceException; import org.apache.commons.math.linear.decomposition.NotPositiveDefiniteMatrixException; -import org.apache.commons.math.optimization.ConvergenceChecker; +import org.apache.commons.math.optimization.GoalType; import org.apache.commons.math.optimization.ObjectiveException; import org.apache.commons.math.optimization.ObjectiveFunction; import org.apache.commons.math.optimization.PointValuePair; -import org.apache.commons.math.ConvergenceException; -import org.apache.commons.math.random.JDKRandomGenerator; -import org.apache.commons.math.random.RandomGenerator; -import org.apache.commons.math.random.RandomVectorGenerator; -import org.apache.commons.math.random.UncorrelatedRandomVectorGenerator; -import org.apache.commons.math.random.UniformRandomGenerator; - -import junit.framework.*; +import org.apache.commons.math.optimization.ObjectiveValueChecker; public class NelderMeadTest extends TestCase { @@ -41,7 +39,7 @@ public void testObjectiveExceptions() throws ConvergenceException { ObjectiveFunction wrong = new ObjectiveFunction() { - private static final long serialVersionUID = 2624035220997628868L; + private static final long serialVersionUID = 4751314470965489371L; public double objective(double[] x) throws ObjectiveException { if (x[0] < 0) { throw new ObjectiveException("{0}", "oops"); @@ -53,8 +51,8 @@ } }; try { - new NelderMead(0.9, 1.9, 0.4, 0.6).optimize(wrong, 10, new ValueChecker(1.0e-3), true, - new double[] { -0.5 }, new double[] { 0.5 }); + NelderMead optimizer = new NelderMead(0.9, 1.9, 0.4, 0.6); + optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { -1.0 }); fail("an exception should have been thrown"); } catch (ObjectiveException ce) { // expected behavior @@ -63,8 +61,8 @@ fail("wrong exception caught: " + e.getMessage()); } try { - new NelderMead(0.9, 1.9, 0.4, 0.6).optimize(wrong, 10, new ValueChecker(1.0e-3), true, - new double[] { 0.5 }, new double[] { 1.5 }); + NelderMead optimizer = new NelderMead(0.9, 1.9, 0.4, 0.6); + optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { +2.0 }); fail("an exception should have been thrown"); } catch (ObjectiveException ce) { // expected behavior @@ -87,82 +85,58 @@ final double valueXpYm = -0.7290400707055187115322; // global minimum final double valueXpYp = -valueXpYm; // global maximum ObjectiveFunction fourExtrema = new ObjectiveFunction() { - private static final long serialVersionUID = -7039124064449091152L; - public double objective(double[] variables) { + private static final long serialVersionUID = -7039124064449091152L; + public double objective(double[] variables) { final double x = variables[0]; final double y = variables[1]; - return Math.atan(x) * Math.atan(x + 2) * Math.atan(y) * Math.atan(y) / (x * y); + return ((x == 0) || (y == 0)) ? 0 : (Math.atan(x) * Math.atan(x + 2) * Math.atan(y) * Math.atan(y) / (x * y)); } }; - NelderMead nm = new NelderMead(); - - // minimization - nm.optimize(fourExtrema, 100, new ValueChecker(1.0e-8), true, - new double[] { -5, -5 }, new double[] { 5, 5 }, 10, 38821113105892l); - PointValuePair[] optima = nm.getOptima(); - assertEquals(10, optima.length); - int localCount = 0; - int globalCount = 0; - for (PointValuePair optimum : optima) { - if (optimum != null) { - if (optimum.getPoint()[0] < 0) { - // this should be the local minimum - ++localCount; - assertEquals(xM, optimum.getPoint()[0], 1.0e-3); - assertEquals(yP, optimum.getPoint()[1], 1.0e-3); - assertEquals(valueXmYp, optimum.getValue(), 2.0e-8); - } else { - // this should be the global minimum - ++globalCount; - assertEquals(xP, optimum.getPoint()[0], 1.0e-3); - assertEquals(yM, optimum.getPoint()[1], 1.0e-3); - assertEquals(valueXpYm, optimum.getValue(), 2.0e-8); - } - } - } - assertTrue(localCount > 0); - assertTrue(globalCount > 0); - assertTrue(nm.getTotalEvaluations() > 600); - assertTrue(nm.getTotalEvaluations() < 800); + NelderMead optimizer = new NelderMead(); + optimizer.setConvergenceChecker(new ObjectiveValueChecker(1.0e-10, 1.0e-30)); + optimizer.setMaxEvaluations(100); + optimizer.setStartConfiguration(new double[] { 0.2, 0.2 }); + PointValuePair optimum; // minimization - nm.optimize(fourExtrema, 100, new ValueChecker(1.0e-8), false, - new double[] { -5, -5 }, new double[] { 5, 5 }, 10, 38821113105892l); - optima = nm.getOptima(); - assertEquals(10, optima.length); - localCount = 0; - globalCount = 0; - for (PointValuePair optimum : optima) { - if (optimum != null) { - if (optimum.getPoint()[0] < 0) { - // this should be the local maximum - ++localCount; - assertEquals(xM, optimum.getPoint()[0], 1.0e-3); - assertEquals(yM, optimum.getPoint()[1], 1.0e-3); - assertEquals(valueXmYm, optimum.getValue(), 2.0e-8); - } else { - // this should be the global maximum - ++globalCount; - assertEquals(xP, optimum.getPoint()[0], 1.0e-3); - assertEquals(yP, optimum.getPoint()[1], 1.0e-3); - assertEquals(valueXpYp, optimum.getValue(), 2.0e-8); - } - } - } - assertTrue(localCount > 0); - assertTrue(globalCount > 0); - assertTrue(nm.getTotalEvaluations() > 600); - assertTrue(nm.getTotalEvaluations() < 800); + optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { -3.0, 0 }); + assertEquals(xM, optimum.getPoint()[0], 2.0e-7); + assertEquals(yP, optimum.getPoint()[1], 2.0e-5); + assertEquals(valueXmYp, optimum.getValue(), 6.0e-12); + assertTrue(optimizer.getEvaluations() > 60); + assertTrue(optimizer.getEvaluations() < 90); + + optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { +1, 0 }); + assertEquals(xP, optimum.getPoint()[0], 5.0e-6); + assertEquals(yM, optimum.getPoint()[1], 6.0e-6); + assertEquals(valueXpYm, optimum.getValue(), 1.0e-11); + assertTrue(optimizer.getEvaluations() > 60); + assertTrue(optimizer.getEvaluations() < 90); + + // maximization + optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { -3.0, 0.0 }); + assertEquals(xM, optimum.getPoint()[0], 1.0e-5); + assertEquals(yM, optimum.getPoint()[1], 3.0e-6); + assertEquals(valueXmYm, optimum.getValue(), 3.0e-12); + assertTrue(optimizer.getEvaluations() > 60); + assertTrue(optimizer.getEvaluations() < 90); + + optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { +1, 0 }); + assertEquals(xP, optimum.getPoint()[0], 4.0e-6); + assertEquals(yP, optimum.getPoint()[1], 5.0e-6); + assertEquals(valueXpYp, optimum.getValue(), 7.0e-12); + assertTrue(optimizer.getEvaluations() > 60); + assertTrue(optimizer.getEvaluations() < 90); } public void testRosenbrock() - throws ObjectiveException, ConvergenceException, NotPositiveDefiniteMatrixException { + throws ObjectiveException, ConvergenceException { ObjectiveFunction rosenbrock = new ObjectiveFunction() { - private static final long serialVersionUID = -7039124064449091152L; + private static final long serialVersionUID = -9044950469615237490L; public double objective(double[] x) { ++count; double a = x[1] - x[0] * x[0]; @@ -172,61 +146,19 @@ }; count = 0; - NelderMead nm = new NelderMead(); - try { - nm.optimize(rosenbrock, 100, new ValueChecker(1.0e-3), true, - new double[][] { - { -1.2, 1.0 }, { 3.5, -2.3 }, { 0.4, 1.5 } - }, 1, 5384353l); - fail("an exception should have been thrown"); - } catch (ConvergenceException ce) { - // expected behavior - } catch (Exception e) { - e.printStackTrace(System.err); - fail("wrong exception caught: " + e.getMessage()); - } - - count = 0; + NelderMead optimizer = new NelderMead(); + optimizer.setConvergenceChecker(new ObjectiveValueChecker(-1, 1.0e-3)); + optimizer.setMaxEvaluations(100); + optimizer.setStartConfiguration(new double[][] { + { -1.2, 1.0 }, { 0.9, 1.2 } , { 3.5, -2.3 } + }); PointValuePair optimum = - nm.optimize(rosenbrock, 100, new ValueChecker(1.0e-3), true, - new double[][] { - { -1.2, 1.0 }, { 0.9, 1.2 }, { 3.5, -2.3 } - }, 10, 1642738l); - - assertTrue(count > 700); - assertTrue(count < 800); - assertEquals(0.0, optimum.getValue(), 5.0e-5); - assertEquals(1.0, optimum.getPoint()[0], 0.01); - assertEquals(1.0, optimum.getPoint()[1], 0.01); - - PointValuePair[] minima = nm.getOptima(); - assertEquals(10, minima.length); - assertNotNull(minima[0]); - assertNull(minima[minima.length - 1]); - for (int i = 0; i < minima.length; ++i) { - if (minima[i] == null) { - if ((i + 1) < minima.length) { - assertTrue(minima[i+1] == null); - } - } else { - if (i > 0) { - assertTrue(minima[i-1].getValue() <= minima[i].getValue()); - } - } - } + optimizer.optimize(rosenbrock, GoalType.MINIMIZE, new double[] { -1.2, 1.0 }); - RandomGenerator rg = new JDKRandomGenerator(); - rg.setSeed(64453353l); - RandomVectorGenerator rvg = - new UncorrelatedRandomVectorGenerator(new double[] { 0.9, 1.1 }, - new double[] { 0.2, 0.2 }, - new UniformRandomGenerator(rg)); - optimum = - nm.optimize(rosenbrock, 100, new ValueChecker(1.0e-3), true, rvg); - assertEquals(0.0, optimum.getValue(), 2.0e-4); - optimum = - nm.optimize(rosenbrock, 100, new ValueChecker(1.0e-3), true, rvg, 3); - assertEquals(0.0, optimum.getValue(), 3.0e-5); + assertEquals(count, optimizer.getEvaluations()); + assertTrue(optimizer.getEvaluations() > 40); + assertTrue(optimizer.getEvaluations() < 50); + assertTrue(optimum.getValue() < 8.0e-4); } @@ -235,7 +167,7 @@ ObjectiveFunction powell = new ObjectiveFunction() { - private static final long serialVersionUID = -7681075710859391520L; + private static final long serialVersionUID = -832162886102041840L; public double objective(double[] x) { ++count; double a = x[0] + 10 * x[1]; @@ -247,37 +179,18 @@ }; count = 0; - NelderMead nm = new NelderMead(); + NelderMead optimizer = new NelderMead(); + optimizer.setConvergenceChecker(new ObjectiveValueChecker(-1.0, 1.0e-3)); + optimizer.setMaxEvaluations(200); PointValuePair optimum = - nm.optimize(powell, 200, new ValueChecker(1.0e-3), true, - new double[] { 3.0, -1.0, 0.0, 1.0 }, - new double[] { 4.0, 0.0, 1.0, 2.0 }, - 1, 1642738l); - assertTrue(count < 150); - assertEquals(0.0, optimum.getValue(), 6.0e-4); - assertEquals(0.0, optimum.getPoint()[0], 0.07); - assertEquals(0.0, optimum.getPoint()[1], 0.07); - assertEquals(0.0, optimum.getPoint()[2], 0.07); - assertEquals(0.0, optimum.getPoint()[3], 0.07); + optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 }); + assertEquals(count, optimizer.getEvaluations()); + assertTrue(optimizer.getEvaluations() > 110); + assertTrue(optimizer.getEvaluations() < 130); + assertTrue(optimum.getValue() < 2.0e-3); } - private static class ValueChecker implements ConvergenceChecker { - - public ValueChecker(double threshold) { - this.threshold = threshold; - } - - public boolean converged(PointValuePair[] simplex) { - PointValuePair smallest = simplex[0]; - PointValuePair largest = simplex[simplex.length - 1]; - return (largest.getValue() - smallest.getValue()) < threshold; - } - - private double threshold; - - }; - public static Test suite() { return new TestSuite(NelderMeadTest.class); }
