This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-rng.git
commit 6ff1f5411949c224f610b5beedc558ee0a77ae70 Author: Alex Herbert <[email protected]> AuthorDate: Wed Nov 22 15:22:05 2023 +0000 Cache the log factorial to avoid Math.log call --- .../AhrensDieterExponentialSampler.java | 6 ++- .../rng/sampling/distribution/InternalUtils.java | 54 +++++++++++++++------- .../sampling/distribution/InternalUtilsTest.java | 13 +++--- 3 files changed, 49 insertions(+), 24 deletions(-) diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.java index 1a171121..86fac42a 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.java @@ -62,8 +62,12 @@ public class AhrensDieterExponentialSampler final double ln2 = Math.log(2); double qi = 0; + // Start with 0! + // This will not overflow a long as the length < 21 + long factorial = 1; for (int i = 0; i < EXPONENTIAL_SA_QI.length; i++) { - qi += Math.pow(ln2, i + 1.0) / InternalUtils.factorial(i + 1); + factorial *= i + 1; + qi += Math.pow(ln2, i + 1.0) / factorial; EXPONENTIAL_SA_QI[i] = qi; } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java index 8f69272f..14e9d3df 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java @@ -25,16 +25,36 @@ import org.apache.commons.rng.sampling.SharedStateSampler; * This class is not part of the public API, as it would be * better to group these utilities in a dedicated component. */ -final class InternalUtils { // Class is package-private on purpose; do not make it public. - /** All long-representable factorials. */ - private static final long[] FACTORIALS = { - 1L, 1L, 2L, - 6L, 24L, 120L, - 720L, 5040L, 40320L, - 362880L, 3628800L, 39916800L, - 479001600L, 6227020800L, 87178291200L, - 1307674368000L, 20922789888000L, 355687428096000L, - 6402373705728000L, 121645100408832000L, 2432902008176640000L }; +final class InternalUtils { + /** All long-representable factorials, precomputed as the natural + * logarithm using Matlab R2023a VPA: log(vpa(x)). + * + * <p>Note: This table could be any length. Previously this stored + * the long value of n!, not log(n!). Using the previous length + * maintains behaviour. */ + private static final double[] LOG_FACTORIALS = { + 0, + 0, + 0.69314718055994530941723212145818, + 1.7917594692280550008124773583807, + 3.1780538303479456196469416012971, + 4.7874917427820459942477009345232, + 6.5792512120101009950601782929039, + 8.5251613610654143001655310363471, + 10.604602902745250228417227400722, + 12.801827480081469611207717874567, + 15.104412573075515295225709329251, + 17.502307845873885839287652907216, + 19.987214495661886149517362387055, + 22.55216385312342288557084982862, + 25.191221182738681500093434693522, + 27.89927138384089156608943926367, + 30.671860106080672803758367749503, + 33.505073450136888884007902367376, + 36.39544520803305357621562496268, + 39.339884187199494036224652394567, + 42.33561646075348502965987597071 + }; /** The first array index with a non-zero log factorial. */ private static final int BEGIN_LOG_FACTORIALS = 2; @@ -54,8 +74,8 @@ final class InternalUtils { // Class is package-private on purpose; do not make * @throws IndexOutOfBoundsException if the result is too large to be represented * by a {@code long} (i.e. if {@code n > 20}), or {@code n} is negative. */ - static long factorial(int n) { - return FACTORIALS[n]; + static double logFactorial(int n) { + return LOG_FACTORIALS[n]; } /** @@ -283,8 +303,8 @@ final class InternalUtils { // Class is package-private on purpose; do not make // Compute remaining values. for (int i = endCopy; i < numValues; i++) { - if (i < FACTORIALS.length) { - logFactorials[i] = Math.log(FACTORIALS[i]); + if (i < LOG_FACTORIALS.length) { + logFactorials[i] = LOG_FACTORIALS[i]; } else { logFactorials[i] = logFactorials[i - 1] + Math.log(i); } @@ -325,9 +345,9 @@ final class InternalUtils { // Class is package-private on purpose; do not make return logFactorials[n]; } - // Use cache of precomputed factorial values. - if (n < FACTORIALS.length) { - return Math.log(FACTORIALS[n]); + // Use cache of precomputed log factorial values. + if (n < LOG_FACTORIALS.length) { + return LOG_FACTORIALS[n]; } // Delegate. diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InternalUtilsTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InternalUtilsTest.java index f78c5362..512468dd 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InternalUtilsTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InternalUtilsTest.java @@ -32,24 +32,25 @@ class InternalUtilsTest { @Test void testFactorial() { - Assertions.assertEquals(1L, InternalUtils.factorial(0)); + Assertions.assertEquals(0L, InternalUtils.logFactorial(0)); long result = 1; for (int n = 1; n <= MAX_REPRESENTABLE; n++) { result *= n; - Assertions.assertEquals(result, InternalUtils.factorial(n)); + final double expected = Math.log(result); + Assertions.assertEquals(expected, InternalUtils.logFactorial(n), Math.ulp(expected)); } } @Test void testFactorialThrowsWhenNegative() { Assertions.assertThrows(IndexOutOfBoundsException.class, - () -> InternalUtils.factorial(-1)); + () -> InternalUtils.logFactorial(-1)); } @Test void testFactorialThrowsWhenNotRepresentableAsLong() { Assertions.assertThrows(IndexOutOfBoundsException.class, - () -> InternalUtils.factorial(MAX_REPRESENTABLE + 1)); + () -> InternalUtils.logFactorial(MAX_REPRESENTABLE + 1)); } @Test @@ -60,7 +61,7 @@ class InternalUtilsTest { Assertions.assertEquals(0, factorialLog.value(0), 1e-10); for (int n = 1; n <= MAX_REPRESENTABLE + 5; n++) { // Use Commons math to compute logGamma(1 + n); - double expected = Gamma.logGamma(1 + n); + final double expected = Gamma.logGamma(1 + n); Assertions.assertEquals(expected, factorialLog.value(n), 1e-10); } } @@ -71,7 +72,7 @@ class InternalUtilsTest { FactorialLog factorialLog = FactorialLog.create().withCache(limit); for (int n = MAX_REPRESENTABLE; n <= limit; n++) { // Use Commons math to compute logGamma(1 + n); - double expected = Gamma.logGamma(1 + n); + final double expected = Gamma.logGamma(1 + n); Assertions.assertEquals(expected, factorialLog.value(n), 1e-10); } }
