This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git

commit 6ff1f5411949c224f610b5beedc558ee0a77ae70
Author: Alex Herbert <[email protected]>
AuthorDate: Wed Nov 22 15:22:05 2023 +0000

    Cache the log factorial to avoid Math.log call
---
 .../AhrensDieterExponentialSampler.java            |  6 ++-
 .../rng/sampling/distribution/InternalUtils.java   | 54 +++++++++++++++-------
 .../sampling/distribution/InternalUtilsTest.java   | 13 +++---
 3 files changed, 49 insertions(+), 24 deletions(-)

diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.java
index 1a171121..86fac42a 100644
--- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.java
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.java
@@ -62,8 +62,12 @@ public class AhrensDieterExponentialSampler
         final double ln2 = Math.log(2);
         double qi = 0;
 
+        // Start with 0!
+        // This will not overflow a long as the length < 21
+        long factorial = 1;
         for (int i = 0; i < EXPONENTIAL_SA_QI.length; i++) {
-            qi += Math.pow(ln2, i + 1.0) / InternalUtils.factorial(i + 1);
+            factorial *= i + 1;
+            qi += Math.pow(ln2, i + 1.0) / factorial;
             EXPONENTIAL_SA_QI[i] = qi;
         }
     }
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java
index 8f69272f..14e9d3df 100644
--- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java
@@ -25,16 +25,36 @@ import org.apache.commons.rng.sampling.SharedStateSampler;
  * This class is not part of the public API, as it would be
  * better to group these utilities in a dedicated component.
  */
-final class InternalUtils { // Class is package-private on purpose; do not make it public.
-    /** All long-representable factorials. */
-    private static final long[] FACTORIALS = {
-        1L,                1L,                  2L,
-        6L,                24L,                 120L,
-        720L,              5040L,               40320L,
-        362880L,           3628800L,            39916800L,
-        479001600L,        6227020800L,         87178291200L,
-        1307674368000L,    20922789888000L,     355687428096000L,
-        6402373705728000L, 121645100408832000L, 2432902008176640000L };
+final class InternalUtils {
+    /** All long-representable factorials, precomputed as the natural
+     * logarithm using Matlab R2023a VPA: log(vpa(x)).
+     *
+     * <p>Note: This table could be any length. Previously this stored
+     * the long value of n!, not log(n!). Using the previous length
+     * maintains behaviour. */
+    private static final double[] LOG_FACTORIALS = {
+        0,
+        0,
+        0.69314718055994530941723212145818,
+        1.7917594692280550008124773583807,
+        3.1780538303479456196469416012971,
+        4.7874917427820459942477009345232,
+        6.5792512120101009950601782929039,
+        8.5251613610654143001655310363471,
+        10.604602902745250228417227400722,
+        12.801827480081469611207717874567,
+        15.104412573075515295225709329251,
+        17.502307845873885839287652907216,
+        19.987214495661886149517362387055,
+        22.55216385312342288557084982862,
+        25.191221182738681500093434693522,
+        27.89927138384089156608943926367,
+        30.671860106080672803758367749503,
+        33.505073450136888884007902367376,
+        36.39544520803305357621562496268,
+        39.339884187199494036224652394567,
+        42.33561646075348502965987597071
+    };
 
     /** The first array index with a non-zero log factorial. */
     private static final int BEGIN_LOG_FACTORIALS = 2;
@@ -54,8 +74,8 @@ final class InternalUtils { // Class is package-private on purpose; do not make
      * @throws IndexOutOfBoundsException if the result is too large to be represented
      * by a {@code long} (i.e. if {@code n > 20}), or {@code n} is negative.
      */
-    static long factorial(int n)  {
-        return FACTORIALS[n];
+    static double logFactorial(int n)  {
+        return LOG_FACTORIALS[n];
     }
 
     /**
@@ -283,8 +303,8 @@ final class InternalUtils { // Class is package-private on purpose; do not make
 
             // Compute remaining values.
             for (int i = endCopy; i < numValues; i++) {
-                if (i < FACTORIALS.length) {
-                    logFactorials[i] = Math.log(FACTORIALS[i]);
+                if (i < LOG_FACTORIALS.length) {
+                    logFactorials[i] = LOG_FACTORIALS[i];
                 } else {
                     logFactorials[i] = logFactorials[i - 1] + Math.log(i);
                 }
@@ -325,9 +345,9 @@ final class InternalUtils { // Class is package-private on purpose; do not make
                 return logFactorials[n];
             }
 
-            // Use cache of precomputed factorial values.
-            if (n < FACTORIALS.length) {
-                return Math.log(FACTORIALS[n]);
+            // Use cache of precomputed log factorial values.
+            if (n < LOG_FACTORIALS.length) {
+                return LOG_FACTORIALS[n];
             }
 
             // Delegate.
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InternalUtilsTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InternalUtilsTest.java
index f78c5362..512468dd 100644
--- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InternalUtilsTest.java
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InternalUtilsTest.java
@@ -32,24 +32,25 @@ class InternalUtilsTest {
 
     @Test
     void testFactorial() {
-        Assertions.assertEquals(1L, InternalUtils.factorial(0));
+        Assertions.assertEquals(0L, InternalUtils.logFactorial(0));
         long result = 1;
         for (int n = 1; n <= MAX_REPRESENTABLE; n++) {
             result *= n;
-            Assertions.assertEquals(result, InternalUtils.factorial(n));
+            final double expected = Math.log(result);
+            Assertions.assertEquals(expected, InternalUtils.logFactorial(n), Math.ulp(expected));
         }
     }
 
     @Test
     void testFactorialThrowsWhenNegative() {
         Assertions.assertThrows(IndexOutOfBoundsException.class,
-            () -> InternalUtils.factorial(-1));
+            () -> InternalUtils.logFactorial(-1));
     }
 
     @Test
     void testFactorialThrowsWhenNotRepresentableAsLong() {
         Assertions.assertThrows(IndexOutOfBoundsException.class,
-            () -> InternalUtils.factorial(MAX_REPRESENTABLE + 1));
+            () -> InternalUtils.logFactorial(MAX_REPRESENTABLE + 1));
     }
 
     @Test
@@ -60,7 +61,7 @@ class InternalUtilsTest {
         Assertions.assertEquals(0, factorialLog.value(0), 1e-10);
         for (int n = 1; n <= MAX_REPRESENTABLE + 5; n++) {
             // Use Commons math to compute logGamma(1 + n);
-            double expected = Gamma.logGamma(1 + n);
+            final double expected = Gamma.logGamma(1 + n);
             Assertions.assertEquals(expected, factorialLog.value(n), 1e-10);
         }
     }
@@ -71,7 +72,7 @@ class InternalUtilsTest {
         FactorialLog factorialLog = FactorialLog.create().withCache(limit);
         for (int n = MAX_REPRESENTABLE; n <= limit; n++) {
             // Use Commons math to compute logGamma(1 + n);
-            double expected = Gamma.logGamma(1 + n);
+            final double expected = Gamma.logGamma(1 + n);
             Assertions.assertEquals(expected, factorialLog.value(n), 1e-10);
         }
     }

Reply via email to