This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-statistics.git
commit 0805920f3118e034fc957715e333a0a82d7b0a4a Author: Alex Herbert <aherb...@apache.org> AuthorDate: Tue Dec 12 10:54:59 2023 +0000 Update FirstMoment to use a half-representation. This maintains the overflow protection of downscaling but avoids re-upscaling the moment and stored deviations for each input value. Upscaling is only required when computing the final result. This has a performance gain of 30-40%. Performance is approximately the same as a rolling algorithm with no downscaling. Thus this modification allows the overflow protection with negligible cost. All sub-class moments that use the deviations must update their scaling factors by the appropriate powers of 2. --- .../statistics/descriptive/FirstMoment.java | 129 ++++++++++++++++----- .../descriptive/SumOfCubedDeviations.java | 8 +- .../descriptive/SumOfFourthDeviations.java | 9 +- .../descriptive/SumOfSquaredDeviations.java | 6 +- 4 files changed, 112 insertions(+), 40 deletions(-) diff --git a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/FirstMoment.java b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/FirstMoment.java index 932ebe3..18d4ed2 100644 --- a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/FirstMoment.java +++ b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/FirstMoment.java @@ -21,7 +21,7 @@ import java.util.function.DoubleConsumer; /** * Computes the first moment (arithmetic mean) using the definitional formula: * - * <p> mean = sum(x_i) / n + * <pre>mean = sum(x_i) / n</pre> * * <p> To limit numeric errors, the value of the statistic is computed using the * following recursive updating algorithm: @@ -59,12 +59,14 @@ import java.util.function.DoubleConsumer; * </ul> */ class FirstMoment implements DoubleConsumer { + /** The downscale constant. Used to avoid overflow for all finite input. 
*/ + private static final double DOWNSCALE = 0.5; + /** The rescale constant. */ + private static final double RESCALE = 2; + /** Count of values that have been added. */ protected long n; - /** First moment of values that have been added. */ - protected double m1; - /** * Half the deviation of most recently added value from the previous first moment. * Retained to prevent repeated computation in higher order moments. @@ -74,16 +76,25 @@ class FirstMoment implements DoubleConsumer { * * <p>This value is not used in the {@link #combine(FirstMoment)} method. */ - protected double halfDev; + protected double dev; /** - * Deviation of most recently added value from the previous first moment, + * Half the deviation of most recently added value from the previous first moment, * normalized by current sample size. Retained to prevent repeated * computation in higher order moments. + * + * <p>Note: This is (x - m1) / 2n. It is computed as a half value to prevent overflow + * when computing for any finite value x and m. + * * Note: This value is not used in the {@link #combine(FirstMoment)} method. */ protected double nDev; + /** First moment of values that have been added. + * This is stored as a half value to prevent overflow for any finite input. + * Benchmarks show this has negligible performance impact. */ + private double m1; + /** * Running sum of values seen so far. * This is not used in the computation of mean. 
Used as a return value for first moment when @@ -122,7 +133,7 @@ class FirstMoment implements DoubleConsumer { // "Corrected two-pass algorithm" // First pass - final FirstMoment m1 = Statistics.add(new FirstMoment(), values); + final FirstMoment m1 = create(values); final double xbar = m1.getFirstMoment(); if (!Double.isFinite(xbar)) { // Note: Also occurs when the input is empty @@ -135,11 +146,56 @@ class FirstMoment implements DoubleConsumer { } // Note: Correction may be infinite if (Double.isFinite(correction)) { - m1.m1 += correction / values.length; + // Down scale the correction to the half representation + m1.m1 += DOWNSCALE * correction / values.length; } return m1; } + /** + * Creates the first moment using a rolling algorithm. + * + * <p>This duplicates the algorithm in the {@link #accept(double)} method + * with optimisations due to the processing of an entire array: + * <ul> + * <li>Avoid updating (unused) class level working variables. + * <li>Only computing the non-finite value if required. + * </ul> + * + * @param values Values. + * @return the first moment + */ + private static FirstMoment create(double[] values) { + double m1 = 0; + int n = 0; + for (final double x : values) { + // Downscale to avoid overflow for all finite input + m1 += (x * DOWNSCALE - m1) / ++n; + } + final FirstMoment m = new FirstMoment(); + m.n = n; + m.m1 = m1; + // The non-finite value is only relevant if the data contains inf/nan + if (!Double.isFinite(m1 * RESCALE)) { + m.nonFiniteValue = sum(values); + } + return m; + } + + /** + * Compute the sum of the values. + * + * @param values Values. + * @return the sum + */ + private static double sum(double[] values) { + double sum = 0; + for (final double x : values) { + sum += x; + } + return sum; + } + /** * Updates the state of the statistic to reflect the addition of {@code value}. 
* @@ -151,14 +207,13 @@ class FirstMoment implements DoubleConsumer { // See: Chan et al (1983) Equation 1.3a // m_{i+1} = m_i + (x - m_i) / (i + 1) // This is modified with scaling to avoid overflow for all finite input. + // Scaling the input down by a factor of two ensures that the scaling is lossless. + // Sub-classes must alter their scaling factors when using the computed deviations. - n++; nonFiniteValue += value; - // To prevent overflow, dev is computed by scaling down and then scaling up. - // We choose to scale down by a factor of two to ensure that the scaling is lossless. - halfDev = value * 0.5 - m1 * 0.5; - // nDev cannot overflow as halfDev is <= MAX_VALUE when n > 1; or <= MAX_VALUE / 2 when n = 1 - nDev = (halfDev / n) * 2; + // Scale down the input + dev = value * DOWNSCALE - m1; + nDev = dev / ++n; m1 += nDev; } @@ -172,8 +227,10 @@ class FirstMoment implements DoubleConsumer { * {@code NaN} otherwise. */ double getFirstMoment() { - if (Double.isFinite(m1)) { - return n == 0 ? Double.NaN : m1; + // Scale back to the original magnitude + final double m = m1 * RESCALE; + if (Double.isFinite(m)) { + return n == 0 ? Double.NaN : m; } // A non-finite value must have been encountered, return nonFiniteValue which represents m1. return nonFiniteValue; @@ -194,22 +251,13 @@ class FirstMoment implements DoubleConsumer { n = n1 + n2; // Adjust the mean with the weighted difference: // m1 = m1 + (m2 - m1) * n2 / (n1 + n2) - // The difference between means can be 2 * MAX_VALUE so the computation optionally - // scales by a factor of 2. Avoiding scaling if possible preserves sub-normals. + // The half-representation ensures the difference of means is at most MAX_VALUE + // so the combine can avoid scaling. if (n1 == n2) { // Optimisation for equal sizes: m1 = (m1 + m2) / 2 - // Use scaling for a large sum - final double sum = mu1 + mu2; - m1 = Double.isFinite(sum) ? 
- sum * 0.5 : - mu1 * 0.5 + mu2 * 0.5; + m1 = (mu1 + mu2) * 0.5; } else { - // Use scaling for a large difference - if (Double.isFinite(mu2 - mu1)) { - m1 = combine(mu1, mu2, n1, n2); - } else { - m1 = 2 * combine(mu1 * 0.5, mu2 * 0.5, n1, n2); - } + m1 = combine(mu1, mu2, n1, n2); } return this; } @@ -231,4 +279,27 @@ class FirstMoment implements DoubleConsumer { m1 + (m2 - m1) * ((double) n2 / (n1 + n2)) : m2 + (m1 - m2) * ((double) n1 / (n1 + n2)); } + + /** + * Gets the difference of the first moment between {@code this} moment and the + * {@code other} moment. This is provided for sub-classes. + * + * @param other Other moment. + * @return the difference + */ + double getFirstMomentDifference(FirstMoment other) { + // Scale back to the original magnitude + return (m1 - other.m1) * RESCALE; + } + + /** + * Gets the half the difference of the first moment between {@code this} moment and + * the {@code other} moment. This is provided for sub-classes. + * + * @param other Other moment. + * @return the difference + */ + double getFirstMomentHalfDifference(FirstMoment other) { + return m1 - other.m1; + } } diff --git a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfCubedDeviations.java b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfCubedDeviations.java index 915f52d..e58610f 100644 --- a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfCubedDeviations.java +++ b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfCubedDeviations.java @@ -156,10 +156,10 @@ class SumOfCubedDeviations extends SumOfSquaredDeviations { // multiplication of later terms (nDev * 3 and nDev^2). // This handles initialisation when np in {0, 1) to zero // for any deviation (e.g. series MAX_VALUE, -MAX_VALUE). - // Note: account for the half-deviation representation. 
+ // Note: account for the half-deviation representation by scaling by 6=3*2; 8=2^3 sumCubedDev = sumCubedDev - - ss * nDev * 3 + - (np - 1.0) * np * nDev * nDev * halfDev * 2; + ss * nDev * 6 + + (np - 1.0) * np * nDev * nDev * dev * 8; } /** @@ -197,7 +197,7 @@ class SumOfCubedDeviations extends SumOfSquaredDeviations { // Avoid overflow to compute the difference. // This allows any samples of size n=1 to be combined as their SS=0. // The result is a SC=0 for the combined n=2. - final double halfDiffOfMean = m1 * 0.5 - other.m1 * 0.5; + final double halfDiffOfMean = getFirstMomentHalfDifference(other); sumCubedDev += other.sumCubedDev; // Add additional terms that do not cancel to zero if (halfDiffOfMean != 0) { diff --git a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfFourthDeviations.java b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfFourthDeviations.java index 79e8e33..b0bc155 100644 --- a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfFourthDeviations.java +++ b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfFourthDeviations.java @@ -141,11 +141,12 @@ class SumOfFourthDeviations extends SumOfCubedDeviations { // This handles initialisation when np in {0, 1) to zero // for any deviation (e.g. series MAX_VALUE, -MAX_VALUE). 
// Note: (np1 * np1 - 3 * np) = (np+1)^2 - 3np = np^2 - np + 1 + // Note: account for the half-deviation representation by scaling by 8=4*2; 24=6*2^2; 16=2^4 final double np1 = n; sumFourthDev = sumFourthDev - - sc * nDev * 4 + - ss * nDev * nDev * 6 + - np * (np1 * np1 - 3 * np) * nDev * nDev * nDev * nDev * n; + sc * nDev * 8 + + ss * nDev * nDev * 24 + + np * (np1 * np1 - 3 * np) * nDev * nDev * nDev * nDev * n * 16; } /** @@ -180,7 +181,7 @@ class SumOfFourthDeviations extends SumOfCubedDeviations { sumFourthDev = other.sumFourthDev; } else if (other.n != 0) { // Avoid overflow to compute the difference. - final double halfDiffOfMean = m1 * 0.5 - other.m1 * 0.5; + final double halfDiffOfMean = getFirstMomentHalfDifference(other); sumFourthDev += other.sumFourthDev; // Add additional terms that do not cancel to zero if (halfDiffOfMean != 0) { diff --git a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfSquaredDeviations.java b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfSquaredDeviations.java index 2e574f6..93cc421 100644 --- a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfSquaredDeviations.java +++ b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfSquaredDeviations.java @@ -133,8 +133,8 @@ class SumOfSquaredDeviations extends FirstMoment { // "Updating one-pass algorithm" // See: Chan et al (1983) Equation 1.3b super.accept(value); - // Note: account for the half-deviation representation - sumSquaredDev += (n - 1) * halfDev * nDev * 2; + // Note: account for the half-deviation representation by scaling by 4=2^2 + sumSquaredDev += (n - 1) * dev * nDev * 4; } /** @@ -159,7 +159,7 @@ class SumOfSquaredDeviations extends FirstMoment { } else if (m != 0) { // "Updating one-pass algorithm" // See: Chan et al (1983) Equation 1.5b (modified for the mean) - final double diffOfMean = other.m1 
- m1; + final double diffOfMean = getFirstMomentDifference(other); final double sqDiffOfMean = diffOfMean * diffOfMean; // Enforce symmetry sumSquaredDev = (sumSquaredDev + other.sumSquaredDev) +