This is an automated email from the ASF dual-hosted git repository. erans pushed a commit to branch modularized_master in repository https://gitbox.apache.org/repos/asf/commons-math.git
commit 1d9670cb12613a2d8c27ad318237e000668f8836 Author: Gilles Sadowski <[email protected]> AuthorDate: Sat May 29 00:34:28 2021 +0200 MATH-1172: "SimpleCurveFitter" as parent class for curve fitter implementations. --- .../math4/legacy/fitting/GaussianCurveFitter.java | 274 ++------------------- .../math4/legacy/fitting/HarmonicCurveFitter.java | 146 ++--------- .../legacy/fitting/PolynomialCurveFitter.java | 70 +----- .../math4/legacy/fitting/SimpleCurveFitter.java | 213 +++++++++++++++- .../legacy/fitting/GaussianCurveFitterTest.java | 20 +- .../legacy/fitting/HarmonicCurveFitterTest.java | 12 +- .../legacy/fitting/PolynomialCurveFitterTest.java | 10 +- 7 files changed, 262 insertions(+), 483 deletions(-) diff --git a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/GaussianCurveFitter.java b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/GaussianCurveFitter.java index 85378c9..69a4802 100644 --- a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/GaussianCurveFitter.java +++ b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/GaussianCurveFitter.java @@ -16,22 +16,15 @@ */ package org.apache.commons.math4.legacy.fitting; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.Comparator; import java.util.List; +import java.util.Collection; import org.apache.commons.math4.legacy.analysis.function.Gaussian; import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException; import org.apache.commons.math4.legacy.exception.NullArgumentException; import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException; import org.apache.commons.math4.legacy.exception.OutOfRangeException; -import org.apache.commons.math4.legacy.exception.ZeroException; import org.apache.commons.math4.legacy.exception.util.LocalizedFormats; -import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder; -import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem; -import org.apache.commons.math4.legacy.linear.DiagonalMatrix; import org.apache.commons.math4.legacy.util.FastMath; /** @@ -69,7 +62,7 @@ import org.apache.commons.math4.legacy.util.FastMath; * * @since 3.3 */ -public class GaussianCurveFitter extends AbstractCurveFitter { +public class GaussianCurveFitter extends SimpleCurveFitter { /** Parametric function to be fitted. */ private static final Gaussian.Parametric FUNCTION = new Gaussian.Parametric() { /** {@inheritDoc} */ @@ -98,10 +91,6 @@ public class GaussianCurveFitter extends AbstractCurveFitter { return v; } }; - /** Initial guess. */ - private final double[] initialGuess; - /** Maximum number of iterations of the optimization algorithm. */ - private final int maxIter; /** * Constructor used by the factory methods. @@ -112,8 +101,7 @@ public class GaussianCurveFitter extends AbstractCurveFitter { */ private GaussianCurveFitter(double[] initialGuess, int maxIter) { - this.initialGuess = initialGuess; - this.maxIter = maxIter; + super(FUNCTION, initialGuess, new ParameterGuesser(), maxIter); } /** @@ -132,86 +120,27 @@ public class GaussianCurveFitter extends AbstractCurveFitter { } /** - * Configure the start point (initial guess). - * @param newStart new start point (initial guess) - * @return a new instance. - */ - public GaussianCurveFitter withStartPoint(double[] newStart) { - return new GaussianCurveFitter(newStart.clone(), - maxIter); - } - - /** - * Configure the maximum number of iterations. - * @param newMaxIter maximum number of iterations - * @return a new instance. - */ - public GaussianCurveFitter withMaxIterations(int newMaxIter) { - return new GaussianCurveFitter(initialGuess, - newMaxIter); - } - - /** {@inheritDoc} */ - @Override - protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) { - - // Prepare least-squares problem. - final int len = observations.size(); - final double[] target = new double[len]; - final double[] weights = new double[len]; - - int i = 0; - for (WeightedObservedPoint obs : observations) { - target[i] = obs.getY(); - weights[i] = obs.getWeight(); - ++i; - } - - final AbstractCurveFitter.TheoreticalValuesFunction model = - new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations); - - final double[] startPoint = initialGuess != null ? - initialGuess : - // Compute estimation. - new ParameterGuesser(observations).guess(); - - // Return a new least squares problem set up to fit a Gaussian curve to the - // observed points. - return new LeastSquaresBuilder(). - maxEvaluations(Integer.MAX_VALUE). - maxIterations(maxIter). - start(startPoint). - target(target). - weight(new DiagonalMatrix(weights)). - model(model.getModelFunction(), model.getModelFunctionJacobian()). - build(); - - } - - /** * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma} * of a {@link org.apache.commons.math4.legacy.analysis.function.Gaussian.Parametric} * based on the specified observed points. */ - public static class ParameterGuesser { - /** Normalization factor. */ - private final double norm; - /** Mean. */ - private final double mean; - /** Standard deviation. */ - private final double sigma; - + public static class ParameterGuesser extends SimpleCurveFitter.ParameterGuesser { /** - * Constructs instance with the specified observed points. + * {@inheritDoc} * - * @param observations Observed points from which to guess the - * parameters of the Gaussian. + * @return the guessed parameters, in the following order: + * <ul> + * <li>Normalization factor</li> + * <li>Mean</li> + * <li>Standard deviation</li> + * </ul> * @throws NullArgumentException if {@code observations} is * {@code null}. * @throws NumberIsTooSmallException if there are less than 3 * observations. */ - public ParameterGuesser(Collection<WeightedObservedPoint> observations) { + @Override + public double[] guess(Collection<WeightedObservedPoint> observations) { if (observations == null) { throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY); } @@ -220,68 +149,7 @@ public class GaussianCurveFitter extends AbstractCurveFitter { } final List<WeightedObservedPoint> sorted = sortObservations(observations); - final double[] params = basicGuess(sorted.toArray(new WeightedObservedPoint[0])); - - norm = params[0]; - mean = params[1]; - sigma = params[2]; - } - - /** - * Gets an estimation of the parameters. - * - * @return the guessed parameters, in the following order: - * <ul> - * <li>Normalization factor</li> - * <li>Mean</li> - * <li>Standard deviation</li> - * </ul> - */ - public double[] guess() { - return new double[] { norm, mean, sigma }; - } - - /** - * Sort the observations. - * - * @param unsorted Input observations. - * @return the input observations, sorted. - */ - private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) { - final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted); - - final Comparator<WeightedObservedPoint> cmp = new Comparator<WeightedObservedPoint>() { - /** {@inheritDoc} */ - @Override - public int compare(WeightedObservedPoint p1, - WeightedObservedPoint p2) { - if (p1 == null && p2 == null) { - return 0; - } - if (p1 == null) { - return -1; - } - if (p2 == null) { - return 1; - } - int comp = Double.compare(p1.getX(), p2.getX()); - if (comp != 0) { - return comp; - } - comp = Double.compare(p1.getY(), p2.getY()); - if (comp != 0) { - return comp; - } - comp = Double.compare(p1.getWeight(), p2.getWeight()); - if (comp != 0) { - return comp; - } - return 0; - } - }; - - Collections.sort(observations, cmp); - return observations; + return basicGuess(sorted.toArray(new WeightedObservedPoint[0])); } /** @@ -309,119 +177,5 @@ public class GaussianCurveFitter extends AbstractCurveFitter { return new double[] { n, points[maxYIdx].getX(), s }; } - - /** - * Finds index of point in specified points with the largest Y. - * - * @param points Points to search. - * @return the index in specified points array. - */ - private int findMaxY(WeightedObservedPoint[] points) { - int maxYIdx = 0; - for (int i = 1; i < points.length; i++) { - if (points[i].getY() > points[maxYIdx].getY()) { - maxYIdx = i; - } - } - return maxYIdx; - } - - /** - * Interpolates using the specified points to determine X at the - * specified Y. - * - * @param points Points to use for interpolation. - * @param startIdx Index within points from which to start the search for - * interpolation bounds points. - * @param idxStep Index step for searching interpolation bounds points. - * @param y Y value for which X should be determined. - * @return the value of X for the specified Y. - * @throws ZeroException if {@code idxStep} is 0. - * @throws OutOfRangeException if specified {@code y} is not within the - * range of the specified {@code points}. - */ - private double interpolateXAtY(WeightedObservedPoint[] points, - int startIdx, - int idxStep, - double y) - throws OutOfRangeException { - if (idxStep == 0) { - throw new ZeroException(); - } - final WeightedObservedPoint[] twoPoints - = getInterpolationPointsForY(points, startIdx, idxStep, y); - final WeightedObservedPoint p1 = twoPoints[0]; - final WeightedObservedPoint p2 = twoPoints[1]; - if (p1.getY() == y) { - return p1.getX(); - } - if (p2.getY() == y) { - return p2.getX(); - } - return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) / - (p2.getY() - p1.getY())); - } - - /** - * Gets the two bounding interpolation points from the specified points - * suitable for determining X at the specified Y. - * - * @param points Points to use for interpolation. - * @param startIdx Index within points from which to start search for - * interpolation bounds points. - * @param idxStep Index step for search for interpolation bounds points. - * @param y Y value for which X should be determined. - * @return the array containing two points suitable for determining X at - * the specified Y. - * @throws ZeroException if {@code idxStep} is 0. - * @throws OutOfRangeException if specified {@code y} is not within the - * range of the specified {@code points}. - */ - private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points, - int startIdx, - int idxStep, - double y) - throws OutOfRangeException { - if (idxStep == 0) { - throw new ZeroException(); - } - for (int i = startIdx; - idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length; - i += idxStep) { - final WeightedObservedPoint p1 = points[i]; - final WeightedObservedPoint p2 = points[i + idxStep]; - if (isBetween(y, p1.getY(), p2.getY())) { - if (idxStep < 0) { - return new WeightedObservedPoint[] { p2, p1 }; - } else { - return new WeightedObservedPoint[] { p1, p2 }; - } - } - } - - // Boundaries are replaced by dummy values because the raised - // exception is caught and the message never displayed. - // TODO: Exceptions should not be used for flow control. - throw new OutOfRangeException(y, - Double.NEGATIVE_INFINITY, - Double.POSITIVE_INFINITY); - } - - /** - * Determines whether a value is between two other values. - * - * @param value Value to test whether it is between {@code boundary1} - * and {@code boundary2}. - * @param boundary1 One end of the range. - * @param boundary2 Other end of the range. - * @return {@code true} if {@code value} is between {@code boundary1} and - * {@code boundary2} (inclusive), {@code false} otherwise. - */ - private boolean isBetween(double value, - double boundary1, - double boundary2) { - return (value >= boundary1 && value <= boundary2) || - (value >= boundary2 && value <= boundary1); - } } } diff --git a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/HarmonicCurveFitter.java b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/HarmonicCurveFitter.java index b1b0af3..51c6b67 100644 --- a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/HarmonicCurveFitter.java +++ b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/HarmonicCurveFitter.java @@ -25,9 +25,6 @@ import org.apache.commons.math4.legacy.exception.MathIllegalStateException; import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException; import org.apache.commons.math4.legacy.exception.ZeroException; import org.apache.commons.math4.legacy.exception.util.LocalizedFormats; -import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder; -import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem; -import org.apache.commons.math4.legacy.linear.DiagonalMatrix; import org.apache.commons.math4.legacy.util.FastMath; /** @@ -46,13 +43,9 @@ import org.apache.commons.math4.legacy.util.FastMath; * * @since 3.3 */ -public class HarmonicCurveFitter extends AbstractCurveFitter { +public class HarmonicCurveFitter extends SimpleCurveFitter { /** Parametric function to be fitted. */ private static final HarmonicOscillator.Parametric FUNCTION = new HarmonicOscillator.Parametric(); - /** Initial guess. */ - private final double[] initialGuess; - /** Maximum number of iterations of the optimization algorithm. */ - private final int maxIter; /** * Constructor used by the factory methods. @@ -63,8 +56,7 @@ public class HarmonicCurveFitter extends AbstractCurveFitter { */ private HarmonicCurveFitter(double[] initialGuess, int maxIter) { - this.initialGuess = initialGuess; - this.maxIter = maxIter; + super(FUNCTION, initialGuess, new ParameterGuesser(), maxIter); } /** @@ -83,63 +75,6 @@ public class HarmonicCurveFitter extends AbstractCurveFitter { } /** - * Configure the start point (initial guess). - * @param newStart new start point (initial guess) - * @return a new instance. - */ - public HarmonicCurveFitter withStartPoint(double[] newStart) { - return new HarmonicCurveFitter(newStart.clone(), - maxIter); - } - - /** - * Configure the maximum number of iterations. - * @param newMaxIter maximum number of iterations - * @return a new instance. - */ - public HarmonicCurveFitter withMaxIterations(int newMaxIter) { - return new HarmonicCurveFitter(initialGuess, - newMaxIter); - } - - /** {@inheritDoc} */ - @Override - protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) { - // Prepare least-squares problem. - final int len = observations.size(); - final double[] target = new double[len]; - final double[] weights = new double[len]; - - int i = 0; - for (WeightedObservedPoint obs : observations) { - target[i] = obs.getY(); - weights[i] = obs.getWeight(); - ++i; - } - - final AbstractCurveFitter.TheoreticalValuesFunction model - = new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, - observations); - - final double[] startPoint = initialGuess != null ? - initialGuess : - // Compute estimation. - new ParameterGuesser(observations).guess(); - - // Return a new optimizer set up to fit a Gaussian curve to the - // observed points. - return new LeastSquaresBuilder(). - maxEvaluations(Integer.MAX_VALUE). - maxIterations(maxIter). - start(startPoint). - target(target). - weight(new DiagonalMatrix(weights)). - model(model.getModelFunction(), model.getModelFunctionJacobian()). - build(); - - } - - /** * This class guesses harmonic coefficients from a sample. * <p>The algorithm used to guess the coefficients is as follows:</p> * @@ -238,24 +173,22 @@ public class HarmonicCurveFitter extends AbstractCurveFitter { * estimations, these operations run in \(O(n)\) time, where \(n\) is the * number of measurements.</p> */ - public static class ParameterGuesser { - /** Amplitude. */ - private final double a; - /** Angular frequency. */ - private final double omega; - /** Phase. */ - private final double phi; - + public static class ParameterGuesser extends SimpleCurveFitter.ParameterGuesser { /** - * Simple constructor. + * {@inheritDoc} * - * @param observations Sampled observations. + * @return the guessed parameters, in the following order: + * <ul> + * <li>Amplitude</li> + * <li>Angular frequency</li> + * <li>Phase</li> + * </ul> * @throws NumberIsTooSmallException if the sample is too short. * @throws ZeroException if the abscissa range is zero. * @throws MathIllegalStateException when the guessing procedure cannot * produce sensible results. */ - public ParameterGuesser(Collection<WeightedObservedPoint> observations) { + public double[] guess(Collection<WeightedObservedPoint> observations) { if (observations.size() < 4) { throw new NumberIsTooSmallException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE, observations.size(), 4, true); @@ -265,62 +198,15 @@ public class HarmonicCurveFitter extends AbstractCurveFitter { = sortObservations(observations).toArray(new WeightedObservedPoint[0]); final double aOmega[] = guessAOmega(sorted); - a = aOmega[0]; - omega = aOmega[1]; + final double a = aOmega[0]; + final double omega = aOmega[1]; - phi = guessPhi(sorted); - } + final double phi = guessPhi(sorted, omega); - /** - * Gets an estimation of the parameters. - * - * @return the guessed parameters, in the following order: - * <ul> - * <li>Amplitude</li> - * <li>Angular frequency</li> - * <li>Phase</li> - * </ul> - */ - public double[] guess() { return new double[] { a, omega, phi }; } /** - * Sort the observations with respect to the abscissa. - * - * @param unsorted Input observations. - * @return the input observations, sorted. - */ - private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) { - final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted); - - // Since the samples are almost always already sorted, this - // method is implemented as an insertion sort that reorders the - // elements in place. Insertion sort is very efficient in this case. - WeightedObservedPoint curr = observations.get(0); - final int len = observations.size(); - for (int j = 1; j < len; j++) { - WeightedObservedPoint prec = curr; - curr = observations.get(j); - if (curr.getX() < prec.getX()) { - // the current element should be inserted closer to the beginning - int i = j - 1; - WeightedObservedPoint mI = observations.get(i); - while ((i >= 0) && (curr.getX() < mI.getX())) { - observations.set(i + 1, mI); - if (i-- != 0) { - mI = observations.get(i); - } - } - observations.set(i + 1, curr); - curr = observations.get(j); - } - } - - return observations; - } - - /** * Estimate a first guess of the amplitude and angular frequency. * * @param observations Observations, sorted w.r.t. abscissa. @@ -415,9 +301,11 @@ public class HarmonicCurveFitter extends AbstractCurveFitter { * Estimate a first guess of the phase. * * @param observations Observations, sorted w.r.t. abscissa. + * @param omega Angular frequency. * @return the guessed phase. */ - private double guessPhi(WeightedObservedPoint[] observations) { + private double guessPhi(WeightedObservedPoint[] observations, + double omega) { // initialize the means double fcMean = 0; double fsMean = 0; diff --git a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/PolynomialCurveFitter.java b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/PolynomialCurveFitter.java index 325097e..9360b80 100644 --- a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/PolynomialCurveFitter.java +++ b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/PolynomialCurveFitter.java @@ -19,10 +19,6 @@ package org.apache.commons.math4.legacy.fitting; import java.util.Collection; import org.apache.commons.math4.legacy.analysis.polynomials.PolynomialFunction; -import org.apache.commons.math4.legacy.exception.MathInternalError; -import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder; -import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem; -import org.apache.commons.math4.legacy.linear.DiagonalMatrix; /** * Fits points to a {@link @@ -36,25 +32,19 @@ import org.apache.commons.math4.legacy.linear.DiagonalMatrix; * * @since 3.3 */ -public class PolynomialCurveFitter extends AbstractCurveFitter { +public class PolynomialCurveFitter extends SimpleCurveFitter { /** Parametric function to be fitted. */ private static final PolynomialFunction.Parametric FUNCTION = new PolynomialFunction.Parametric(); - /** Initial guess. */ - private final double[] initialGuess; - /** Maximum number of iterations of the optimization algorithm. */ - private final int maxIter; /** * Constructor used by the factory methods. * * @param initialGuess Initial guess. * @param maxIter Maximum number of iterations of the optimization algorithm. - * @throws MathInternalError if {@code initialGuess} is {@code null}. */ private PolynomialCurveFitter(double[] initialGuess, int maxIter) { - this.initialGuess = initialGuess; - this.maxIter = maxIter; + super(FUNCTION, initialGuess, null, maxIter); } /** @@ -72,60 +62,4 @@ public class PolynomialCurveFitter extends AbstractCurveFitter { public static PolynomialCurveFitter create(int degree) { return new PolynomialCurveFitter(new double[degree + 1], Integer.MAX_VALUE); } - - /** - * Configure the start point (initial guess). - * @param newStart new start point (initial guess) - * @return a new instance. - */ - public PolynomialCurveFitter withStartPoint(double[] newStart) { - return new PolynomialCurveFitter(newStart.clone(), - maxIter); - } - - /** - * Configure the maximum number of iterations. - * @param newMaxIter maximum number of iterations - * @return a new instance. - */ - public PolynomialCurveFitter withMaxIterations(int newMaxIter) { - return new PolynomialCurveFitter(initialGuess, - newMaxIter); - } - - /** {@inheritDoc} */ - @Override - protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) { - // Prepare least-squares problem. - final int len = observations.size(); - final double[] target = new double[len]; - final double[] weights = new double[len]; - - int i = 0; - for (WeightedObservedPoint obs : observations) { - target[i] = obs.getY(); - weights[i] = obs.getWeight(); - ++i; - } - - final AbstractCurveFitter.TheoreticalValuesFunction model = - new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations); - - if (initialGuess == null) { - throw new MathInternalError(); - } - - // Return a new least squares problem set up to fit a polynomial curve to the - // observed points. - return new LeastSquaresBuilder(). - maxEvaluations(Integer.MAX_VALUE). - maxIterations(maxIter). - start(initialGuess). - target(target). - weight(new DiagonalMatrix(weights)). - model(model.getModelFunction(), model.getModelFunctionJacobian()). - build(); - - } - } diff --git a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/SimpleCurveFitter.java b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/SimpleCurveFitter.java index 9ad65a4..832168f 100644 --- a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/SimpleCurveFitter.java +++ b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/SimpleCurveFitter.java @@ -16,8 +16,14 @@ */ package org.apache.commons.math4.legacy.fitting; +import java.util.Collections; import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.ArrayList; +import org.apache.commons.math4.legacy.exception.ZeroException; +import org.apache.commons.math4.legacy.exception.OutOfRangeException; import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction; import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder; import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem; @@ -33,6 +39,8 @@ public class SimpleCurveFitter extends AbstractCurveFitter { private final ParametricUnivariateFunction function; /** Initial guess for the parameters. */ private final double[] initialGuess; + /** Parameter guesser. */ + private final ParameterGuesser guesser; /** Maximum number of iterations of the optimization algorithm. */ private final int maxIter; @@ -42,13 +50,17 @@ public class SimpleCurveFitter extends AbstractCurveFitter { * @param function Function to fit. * @param initialGuess Initial guess. Cannot be {@code null}. Its length must * be consistent with the number of parameters of the {@code function} to fit. + * @param guesser Method for providing an initial guess (if {@code initialGuess} + * is {@code null}). * @param maxIter Maximum number of iterations of the optimization algorithm. */ - private SimpleCurveFitter(ParametricUnivariateFunction function, - double[] initialGuess, - int maxIter) { + protected SimpleCurveFitter(ParametricUnivariateFunction function, + double[] initialGuess, + ParameterGuesser guesser, + int maxIter) { this.function = function; this.initialGuess = initialGuess; + this.guesser = guesser; this.maxIter = maxIter; } @@ -68,7 +80,24 @@ public class SimpleCurveFitter extends AbstractCurveFitter { */ public static SimpleCurveFitter create(ParametricUnivariateFunction f, double[] start) { - return new SimpleCurveFitter(f, start, Integer.MAX_VALUE); + return new SimpleCurveFitter(f, start, null, Integer.MAX_VALUE); + } + + /** + * Creates a curve fitter. + * The maximum number of iterations of the optimization algorithm is set + * to {@link Integer#MAX_VALUE}. + * + * @param f Function to fit. + * @param guesser Method for providing an initial guess. + * @return a curve fitter. + * + * @see #withStartPoint(double[]) + * @see #withMaxIterations(int) + */ + public static SimpleCurveFitter create(ParametricUnivariateFunction f, + ParameterGuesser guesser) { + return new SimpleCurveFitter(f, null, guesser, Integer.MAX_VALUE); } /** @@ -79,6 +108,7 @@ public class SimpleCurveFitter extends AbstractCurveFitter { public SimpleCurveFitter withStartPoint(double[] newStart) { return new SimpleCurveFitter(function, newStart.clone(), + null, maxIter); } @@ -90,6 +120,7 @@ public class SimpleCurveFitter extends AbstractCurveFitter { public SimpleCurveFitter withMaxIterations(int newMaxIter) { return new SimpleCurveFitter(function, initialGuess, + guesser, newMaxIter); } @@ -112,14 +143,186 @@ public class SimpleCurveFitter extends AbstractCurveFitter { = new AbstractCurveFitter.TheoreticalValuesFunction(function, observations); + final double[] startPoint = initialGuess != null ? + initialGuess : + // Compute estimation. + guesser.guess(observations); + // Create an optimizer for fitting the curve to the observed points. return new LeastSquaresBuilder(). maxEvaluations(Integer.MAX_VALUE). maxIterations(maxIter). - start(initialGuess). + start(startPoint). target(target). weight(new DiagonalMatrix(weights)). model(model.getModelFunction(), model.getModelFunctionJacobian()). build(); } + + /** + * Guesses the parameters. + */ + public static abstract class ParameterGuesser { + private final Comparator<WeightedObservedPoint> CMP = new Comparator<WeightedObservedPoint>() { + /** {@inheritDoc} */ + @Override + public int compare(WeightedObservedPoint p1, + WeightedObservedPoint p2) { + if (p1 == null && p2 == null) { + return 0; + } + if (p1 == null) { + return -1; + } + if (p2 == null) { + return 1; + } + int comp = Double.compare(p1.getX(), p2.getX()); + if (comp != 0) { + return comp; + } + comp = Double.compare(p1.getY(), p2.getY()); + if (comp != 0) { + return comp; + } + comp = Double.compare(p1.getWeight(), p2.getWeight()); + if (comp != 0) { + return comp; + } + return 0; + } + }; + + /** + * Computes an estimation of the parameters. + * + * @param obs Observations. + * @return the guessed parameters. + */ + public abstract double[] guess(Collection<WeightedObservedPoint> obs); + + /** + * Sort the observations. + * + * @param unsorted Input observations. + * @return the input observations, sorted. + */ + protected List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) { + final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted); + Collections.sort(observations, CMP); + return observations; + } + + /** + * Finds index of point in specified points with the largest Y. + * + * @param points Points to search. + * @return the index in specified points array. + */ + protected int findMaxY(WeightedObservedPoint[] points) { + int maxYIdx = 0; + for (int i = 1; i < points.length; i++) { + if (points[i].getY() > points[maxYIdx].getY()) { + maxYIdx = i; + } + } + return maxYIdx; + } + + /** + * Interpolates using the specified points to determine X at the + * specified Y. + * + * @param points Points to use for interpolation. + * @param startIdx Index within points from which to start the search for + * interpolation bounds points. + * @param idxStep Index step for searching interpolation bounds points. + * @param y Y value for which X should be determined. + * @return the value of X for the specified Y. + * @throws ZeroException if {@code idxStep} is 0. + * @throws OutOfRangeException if specified {@code y} is not within the + * range of the specified {@code points}. + */ + protected double interpolateXAtY(WeightedObservedPoint[] points, + int startIdx, + int idxStep, + double y) { + if (idxStep == 0) { + throw new ZeroException(); + } + final WeightedObservedPoint[] twoPoints + = getInterpolationPointsForY(points, startIdx, idxStep, y); + final WeightedObservedPoint p1 = twoPoints[0]; + final WeightedObservedPoint p2 = twoPoints[1]; + if (p1.getY() == y) { + return p1.getX(); + } + if (p2.getY() == y) { + return p2.getX(); + } + return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) / + (p2.getY() - p1.getY())); + } + + /** + * Gets the two bounding interpolation points from the specified points + * suitable for determining X at the specified Y. + * + * @param points Points to use for interpolation. + * @param startIdx Index within points from which to start search for + * interpolation bounds points. + * @param idxStep Index step for search for interpolation bounds points. + * @param y Y value for which X should be determined. + * @return the array containing two points suitable for determining X at + * the specified Y. + * @throws ZeroException if {@code idxStep} is 0. + * @throws OutOfRangeException if specified {@code y} is not within the + * range of the specified {@code points}. + */ + private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points, + int startIdx, + int idxStep, + double y) { + if (idxStep == 0) { + throw new ZeroException(); + } + for (int i = startIdx; + idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length; + i += idxStep) { + final WeightedObservedPoint p1 = points[i]; + final WeightedObservedPoint p2 = points[i + idxStep]; + if (isBetween(y, p1.getY(), p2.getY())) { + if (idxStep < 0) { + return new WeightedObservedPoint[] { p2, p1 }; + } else { + return new WeightedObservedPoint[] { p1, p2 }; + } + } + } + + // Boundaries are replaced by dummy values because the raised + // exception is caught and the message never displayed. + // TODO: Exceptions should not be used for flow control. + throw new OutOfRangeException(y, + Double.NEGATIVE_INFINITY, + Double.POSITIVE_INFINITY); + } + + /** + * Determines whether a value is between two other values. + * + * @param value Value to test whether it is between {@code boundary1} + * and {@code boundary2}. + * @param boundary1 One end of the range. + * @param boundary2 Other end of the range. + * @return {@code true} if {@code value} is between {@code boundary1} and + * {@code boundary2} (inclusive), {@code false} otherwise. + */ + private boolean isBetween(double value, + double boundary1, + double boundary2) { + return (value >= boundary1 && value <= boundary2) || + (value >= boundary2 && value <= boundary1); + } + } } diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/GaussianCurveFitterTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/GaussianCurveFitterTest.java index 7ce5a00..94cbe24 100644 --- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/GaussianCurveFitterTest.java +++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/GaussianCurveFitterTest.java @@ -180,7 +180,7 @@ public class GaussianCurveFitterTest { */ @Test public void testFit01() { - GaussianCurveFitter fitter = GaussianCurveFitter.create(); + SimpleCurveFitter fitter = GaussianCurveFitter.create(); double[] parameters = fitter.fit(createDataset(DATASET1).toList()); Assert.assertEquals(3496978.1837704973, parameters[0], 1e-7); @@ -190,7 +190,7 @@ public class GaussianCurveFitterTest { @Test public void testDataset1LargeXShift() { - final GaussianCurveFitter fitter = GaussianCurveFitter.create(); + final SimpleCurveFitter fitter = GaussianCurveFitter.create(); final double xShift = 1e8; final double[] parameters = fitter.fit(createDataset(DATASET1, xShift, 0).toList()); @@ -204,7 +204,7 @@ public class GaussianCurveFitterTest { final int maxIter = 20; final double[] init = { 3.5e6, 4.2, 0.1 }; - GaussianCurveFitter fitter = GaussianCurveFitter.create(); + SimpleCurveFitter fitter = GaussianCurveFitter.create(); double[] parameters = fitter .withMaxIterations(maxIter) .withStartPoint(init) @@ -220,7 +220,7 @@ public class GaussianCurveFitterTest { final int maxIter = 1; // Too few iterations. final double[] init = { 3.5e6, 4.2, 0.1 }; - GaussianCurveFitter fitter = GaussianCurveFitter.create(); + SimpleCurveFitter fitter = GaussianCurveFitter.create(); fitter.withMaxIterations(maxIter) .withStartPoint(init) .fit(createDataset(DATASET1).toList()); @@ -230,7 +230,7 @@ public class GaussianCurveFitterTest { public void testWithStartPoint() { final double[] init = { 3.5e6, 4.2, 0.1 }; - GaussianCurveFitter fitter = GaussianCurveFitter.create(); + SimpleCurveFitter fitter = GaussianCurveFitter.create(); double[] parameters = fitter .withStartPoint(init) .fit(createDataset(DATASET1).toList()); @@ -253,7 +253,7 @@ public class GaussianCurveFitterTest { */ @Test(expected=MathIllegalArgumentException.class) public void testFit03() { - GaussianCurveFitter fitter = GaussianCurveFitter.create(); + SimpleCurveFitter fitter = GaussianCurveFitter.create(); fitter.fit(createDataset(new double[][] { {4.0254623, 531026.0}, {4.02804905, 664002.0} @@ -265,7 +265,7 @@ public class GaussianCurveFitterTest { */ @Test public void testFit04() { - GaussianCurveFitter fitter = GaussianCurveFitter.create(); + SimpleCurveFitter fitter = GaussianCurveFitter.create(); double[] parameters = fitter.fit(createDataset(DATASET2).toList()); Assert.assertEquals(233003.2967252038, parameters[0], 1e-4); @@ -278,7 +278,7 @@ public class GaussianCurveFitterTest { */ @Test public void testFit05() { - GaussianCurveFitter fitter = GaussianCurveFitter.create(); + SimpleCurveFitter fitter = GaussianCurveFitter.create(); double[] parameters = fitter.fit(createDataset(DATASET3).toList()); Assert.assertEquals(283863.81929180305, parameters[0], 1e-4); @@ -291,7 +291,7 @@ public class GaussianCurveFitterTest { */ @Test public void testFit06() { - GaussianCurveFitter fitter = GaussianCurveFitter.create(); + SimpleCurveFitter fitter = GaussianCurveFitter.create(); double[] parameters = fitter.fit(createDataset(DATASET4).toList()); Assert.assertEquals(285250.66754309234, parameters[0], 1e-4); @@ -304,7 +304,7 @@ public class GaussianCurveFitterTest { */ @Test public void testFit07() { - GaussianCurveFitter fitter = GaussianCurveFitter.create(); + SimpleCurveFitter fitter = GaussianCurveFitter.create(); double[] parameters = fitter.fit(createDataset(DATASET5).toList()); Assert.assertEquals(3514384.729342235, parameters[0], 1e-4); diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/HarmonicCurveFitterTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/HarmonicCurveFitterTest.java index c044a6a..05c3f82 100644 --- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/HarmonicCurveFitterTest.java +++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/HarmonicCurveFitterTest.java @@ -49,7 +49,7 @@ public class HarmonicCurveFitterTest { points.add(1, x, f.value(x)); } - final HarmonicCurveFitter fitter = HarmonicCurveFitter.create(); + final SimpleCurveFitter fitter = HarmonicCurveFitter.create(); final double[] fitted = fitter.fit(points.toList()); Assert.assertEquals(a, fitted[0], 1.0e-13); Assert.assertEquals(w, fitted[1], 1.0e-13); @@ -74,7 +74,7 @@ public class HarmonicCurveFitterTest { points.add(1, x, f.value(x) + 0.01 * randomizer.nextGaussian()); } - final HarmonicCurveFitter fitter = HarmonicCurveFitter.create(); + final SimpleCurveFitter fitter = HarmonicCurveFitter.create(); final double[] fitted = fitter.fit(points.toList()); Assert.assertEquals(a, fitted[0], 7.6e-4); Assert.assertEquals(w, fitted[1], 2.7e-3); @@ -90,7 +90,7 @@ public class HarmonicCurveFitterTest { points.add(1, x, 1e-7 * randomizer.nextGaussian()); } - final HarmonicCurveFitter fitter = HarmonicCurveFitter.create(); + final SimpleCurveFitter fitter = HarmonicCurveFitter.create(); fitter.fit(points.toList()); // This test serves to cover the part of the code of "guessAOmega" @@ -110,7 +110,7 @@ public class HarmonicCurveFitterTest { points.add(1, x, f.value(x) + 0.01 * randomizer.nextGaussian()); } - final HarmonicCurveFitter fitter = HarmonicCurveFitter.create() + final SimpleCurveFitter fitter = HarmonicCurveFitter.create() .withStartPoint(new double[] { 0.15, 3.6, 4.5 }); final double[] fitted = fitter.fit(points.toList()); Assert.assertEquals(a, fitted[0], 1.2e-3); @@ -153,7 +153,7 @@ public class HarmonicCurveFitterTest { points.add(1, xTab[i], yTab[i]); } - final HarmonicCurveFitter fitter = HarmonicCurveFitter.create(); + final SimpleCurveFitter fitter = HarmonicCurveFitter.create(); final double[] fitted = fitter.fit(points.toList()); Assert.assertEquals(a, fitted[0], 7.6e-4); Assert.assertEquals(w, fitted[1], 3.5e-3); @@ -177,6 +177,6 @@ public class HarmonicCurveFitterTest { // and period 12, and all sample points are taken at integer abscissae // so function values all belong to the integer subset {-3, -2, -1, 0, // 1, 2, 3}. - new HarmonicCurveFitter.ParameterGuesser(points); + new HarmonicCurveFitter.ParameterGuesser().guess(points); } } diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/PolynomialCurveFitterTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/PolynomialCurveFitterTest.java index 5004eb2..c540e24 100644 --- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/PolynomialCurveFitterTest.java +++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/PolynomialCurveFitterTest.java @@ -48,7 +48,7 @@ public class PolynomialCurveFitterTest { } // Start fit from initial guesses that are far from the optimal values. - final PolynomialCurveFitter fitter + final SimpleCurveFitter fitter = PolynomialCurveFitter.create(0).withStartPoint(new double[] { -1e-20, 3e15, -5e25 }); final double[] best = fitter.fit(obs.toList()); @@ -60,7 +60,7 @@ public class PolynomialCurveFitterTest { final Random randomizer = new Random(64925784252l); for (int degree = 1; degree < 10; ++degree) { final PolynomialFunction p = buildRandomPolynomial(degree, randomizer); - final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(degree); + final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree); final WeightedObservedPoints obs = new WeightedObservedPoints(); for (int i = 0; i <= degree; ++i) { @@ -83,7 +83,7 @@ public class PolynomialCurveFitterTest { double maxError = 0; for (int degree = 0; degree < 10; ++degree) { final PolynomialFunction p = buildRandomPolynomial(degree, randomizer); - final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(degree); + final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree); final WeightedObservedPoints obs = new WeightedObservedPoints(); for (double x = -1.0; x < 1.0; x += 0.01) { @@ -114,7 +114,7 @@ public class PolynomialCurveFitterTest { double maxError = 0; for (int degree = 0; degree < 10; ++degree) { final PolynomialFunction p = buildRandomPolynomial(degree, randomizer); - final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(degree); + final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree); final WeightedObservedPoints obs = new WeightedObservedPoints(); for (int i = 0; i < 40000; ++i) { @@ -138,7 +138,7 @@ public class PolynomialCurveFitterTest { for (int degree = 0; degree < 10; ++degree) { final PolynomialFunction p = buildRandomPolynomial(degree, randomizer); - final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(degree); + final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree); final WeightedObservedPoints obs = new WeightedObservedPoints(); // reusing the same point over and over again does not bring // information, the problem cannot be solved in this case for
