RNG-50: PoissonSampler speed improvements

The algorithms for small and large means have been separated into
dedicated classes. Constants used by each algorithm are now computed
once at construction and cached to increase speed.

Project: http://git-wip-us.apache.org/repos/asf/commons-rng/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-rng/commit/7b4a4142
Tree: http://git-wip-us.apache.org/repos/asf/commons-rng/tree/7b4a4142
Diff: http://git-wip-us.apache.org/repos/asf/commons-rng/diff/7b4a4142

Branch: refs/heads/1.1
Commit: 7b4a41428c2bed270261b4fedd3a562adaee04b5
Parents: f33ea24
Author: aherbert <a.herb...@sussex.ac.uk>
Authored: Wed Aug 1 12:09:28 2018 +0100
Committer: aherbert <a.herb...@sussex.ac.uk>
Committed: Wed Aug 1 12:09:28 2018 +0100

----------------------------------------------------------------------
 .../distribution/LargeMeanPoissonSampler.java   | 184 +++++++++++++++++++
 .../sampling/distribution/PoissonSampler.java   | 136 ++------------
 .../distribution/SmallMeanPoissonSampler.java   |  85 +++++++++
 .../distribution/DiscreteSamplersList.java      |  13 +-
 4 files changed, 293 insertions(+), 125 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-rng/blob/7b4a4142/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java
----------------------------------------------------------------------
diff --git 
a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java
 
b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java
new file mode 100644
index 0000000..c861802
--- /dev/null
+++ 
b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog;
+
+/**
+ * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>.
+ *
+ * <ul>
+ *  <li>
+ *   For large means, we use the rejection algorithm described in
+ *   <blockquote>
+ *    Devroye, Luc. (1981). <i>The Computer Generation of Poisson Random Variables</i><br>
+ *    <strong>Computing</strong> vol. 26 pp. 197-207.
+ *   </blockquote>
+ *  </li>
+ * </ul>
+ *
+ * This sampler is suitable for {@code mean >= 40}. A mean below 1 is rejected: it would make
+ * {@code floor(mean)} zero and the cached constants NaN, so sampling could never terminate.
+ */
+public class LargeMeanPoissonSampler
+    extends SamplerBase
+    implements DiscreteSampler {
+
+    /** Class to compute {@code log(n!)}. This has no cached values. */
+    private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG = FactorialLog.create();
+
+    /** Exponential. */
+    private final ContinuousSampler exponential;
+    /** Gaussian. */
+    private final ContinuousSampler gaussian;
+    /** Local class to compute {@code log(n!)}. This may have cached values. */
+    private final InternalUtils.FactorialLog factorialLog;
+
+    /** {@code floor(mean)}: the part of the mean sampled by the rejection algorithm. */
+    private final double lambda;
+    /** {@code mean - lambda}: the part of the mean sampled by the small mean sampler. */
+    private final double lambdaFractional;
+    private final double logLambda;
+    private final double logLambdaFactorial;
+    /** {@code sqrt(lambda * log(32 * lambda / pi + 1))}. */
+    private final double delta;
+    private final double halfDelta;
+    /** {@code 2 * lambda + delta}. */
+    private final double twolpd;
+    /** {@code a1 / (a1 + a2 + 1)}: uniform deviates up to this value select the Gaussian branch. */
+    private final double p1;
+    /** {@code a2 / (a1 + a2 + 1)}: the next slice of uniform deviates selects the exponential branch. */
+    private final double p2;
+    /** {@code 1 / (8 * lambda)}. */
+    private final double c1;
+
+    /** The internal Poisson sampler for the lambda fraction ({@code null} when the fraction is zero). */
+    private final DiscreteSampler smallMeanPoissonSampler;
+
+    /**
+     * @param rng  Generator of uniformly distributed random numbers.
+     * @param mean Mean.
+     * @throws IllegalArgumentException if {@code mean < 1} or {@code mean} is NaN.
+     */
+    public LargeMeanPoissonSampler(UniformRandomProvider rng, double mean) {
+        super(rng);
+        if (mean <= 0) {
+            throw new IllegalArgumentException(mean + " <= " + 0);
+        }
+        // floor(mean) == 0 leads to log(0) and 1 / 0 below: p1 and p2 become NaN and sample() loops forever.
+        if (!(mean >= 1)) {
+            throw new IllegalArgumentException("mean is not >= 1: " + mean);
+        }
+
+        gaussian = new ZigguratNormalizedGaussianSampler(rng);
+        exponential = new AhrensDieterExponentialSampler(rng, 1);
+        // Plain constructor uses the uncached function.
+        factorialLog = NO_CACHE_FACTORIAL_LOG;
+
+        // Cache values used in the algorithm
+        lambda = Math.floor(mean);
+        lambdaFractional = mean - lambda;
+        logLambda = Math.log(lambda);
+        logLambdaFactorial = factorialLog((int) lambda);
+        delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1));
+        halfDelta = delta / 2;
+        twolpd = 2 * lambda + delta;
+        c1 = 1 / (8 * lambda);
+        final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1);
+        final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd);
+        final double aSum = a1 + a2 + 1;
+        p1 = a1 / aSum;
+        p2 = a2 / aSum;
+
+        // The algorithm requires a Poisson sample from the remaining lambda fraction.
+        smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
+            null : // Not used.
+            new SmallMeanPoissonSampler(rng, lambdaFractional);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int sample() {
+        final int y2 = (smallMeanPoissonSampler == null) ?
+            0 : // No lambda fraction
+            smallMeanPoissonSampler.sample();
+
+        double x;
+        double y;
+        double v;
+        while (true) {
+            final double u = nextDouble();
+            if (u <= p1) {
+                final double n = gaussian.sample();
+                x = n * Math.sqrt(lambda + halfDelta) - 0.5d;
+                if (x > delta || x < -lambda) {
+                    continue;
+                }
+                y = x < 0 ? Math.floor(x) : Math.ceil(x);
+                final double e = exponential.sample();
+                v = -e - 0.5 * n * n + c1;
+            } else {
+                if (u > p1 + p2) {
+                    y = lambda;
+                    break;
+                }
+                x = delta + (twolpd / delta) * exponential.sample();
+                y = Math.ceil(x);
+                v = -exponential.sample() - delta * (x + 1) / twolpd;
+            }
+            final int a = x < 0 ? 1 : 0;
+            final double t = y * (y + 1) / (2 * lambda);
+            if (v < -t && a == 0) {
+                y = lambda + y;
+                break;
+            }
+            final double qr = t * ((2 * y + 1) / (6 * lambda) - 1);
+            final double qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
+            if (v < qa) {
+                y = lambda + y;
+                break;
+            }
+            if (v > qr) {
+                continue;
+            }
+            if (v < y * logLambda - factorialLog((int) (y + lambda)) + logLambdaFactorial) {
+                y = lambda + y;
+                break;
+            }
+        }
+
+        return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE);
+    }
+
+    /**
+     * Compute the natural logarithm of the factorial of {@code n}.
+     *
+     * @param n Argument.
+     * @return {@code log(n!)}
+     * @throws IllegalArgumentException if {@code n < 0}.
+     */
+    private double factorialLog(int n) {
+        return factorialLog.value(n);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public String toString() {
+        return "Large Mean Poisson deviate [" + super.toString() + "]";
+    }
+}

http://git-wip-us.apache.org/repos/asf/commons-rng/blob/7b4a4142/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java
----------------------------------------------------------------------
diff --git 
a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java
 
b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java
index 195e725..cf7d9ac 100644
--- 
a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java
+++ 
b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java
@@ -30,51 +30,38 @@ import org.apache.commons.rng.UniformRandomProvider;
  *  <li>
  *   For large means, we use the rejection algorithm described in
  *   <blockquote>
- *    Devroye, Luc. (1981).<i>The Computer Generation of Poisson Random 
Variables</i><br>
+ *    Devroye, Luc. (1981). <i>The Computer Generation of Poisson Random 
Variables</i><br>
  *    <strong>Computing</strong> vol. 26 pp. 197-207.
  *   </blockquote>
  *  </li>
  * </ul>
  */
 public class PoissonSampler
-    extends SamplerBase
     implements DiscreteSampler {
+
     /** Value for switching sampling algorithm. */
-    private static final double PIVOT = 40;
-    /** Mean of the distribution. */
-    private final double mean;
-    /** Exponential. */
-    private final ContinuousSampler exponential;
-    /** Gaussian. */
-    private final NormalizedGaussianSampler gaussian;
-    /** {@code log(n!)}. */
-    private final InternalUtils.FactorialLog factorialLog;
+    static final double PIVOT = 40;
+    /** The internal Poisson sampler. */
+    private final DiscreteSampler poissonSampler;
 
     /**
-     * @param rng Generator of uniformly distributed random numbers.
+     * @param rng  Generator of uniformly distributed random numbers.
      * @param mean Mean.
      * @throws IllegalArgumentException if {@code mean <= 0}.
      */
     public PoissonSampler(UniformRandomProvider rng,
                           double mean) {
-        super(rng);
-        if (mean <= 0) {
-            throw new IllegalArgumentException(mean + " <= " + 0);
-        }
-
-        this.mean = mean;
-
-        gaussian = new ZigguratNormalizedGaussianSampler(rng);
-        exponential = new AhrensDieterExponentialSampler(rng, 1);
-        factorialLog = mean < PIVOT ?
-            null : // Not used.
-            InternalUtils.FactorialLog.create().withCache((int) Math.min(mean, 
2 * PIVOT));
+        // Delegate all work to specialised samplers.
+        // These should check the input arguments.
+        poissonSampler = mean < PIVOT ?
+            new SmallMeanPoissonSampler(rng, mean) :
+            new LargeMeanPoissonSampler(rng, mean);
     }
 
     /** {@inheritDoc} */
     @Override
     public int sample() {
-        return (int) Math.min(nextPoisson(mean), Integer.MAX_VALUE);
+        return poissonSampler.sample();
     }
 
     /** {@inheritDoc} */
@@ -82,103 +69,4 @@ public class PoissonSampler
     public String toString() {
-        return "Poisson deviate [" + super.toString() + "]";
+        return poissonSampler.toString(); // super is now Object, so describe the delegate sampler instead
     }
-
-    /**
-     * @param meanPoisson Mean.
-     * @return the next sample.
-     */
-    private long nextPoisson(double meanPoisson) {
-        if (meanPoisson < PIVOT) {
-            double p = Math.exp(-meanPoisson);
-            long n = 0;
-            double r = 1;
-
-            while (n < 1000 * meanPoisson) {
-                r *= nextDouble();
-                if (r >= p) {
-                    n++;
-                } else {
-                    break;
-                }
-            }
-            return n;
-        } else {
-            final double lambda = Math.floor(meanPoisson);
-            final double lambdaFractional = meanPoisson - lambda;
-            final double logLambda = Math.log(lambda);
-            final double logLambdaFactorial = factorialLog((int) lambda);
-            final long y2 = lambdaFractional < Double.MIN_VALUE ? 0 : 
nextPoisson(lambdaFractional);
-            final double delta = Math.sqrt(lambda * Math.log(32 * lambda / 
Math.PI + 1));
-            final double halfDelta = delta / 2;
-            final double twolpd = 2 * lambda + delta;
-            final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(1 / (8 * 
lambda));
-            final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) 
/ twolpd);
-            final double aSum = a1 + a2 + 1;
-            final double p1 = a1 / aSum;
-            final double p2 = a2 / aSum;
-            final double c1 = 1 / (8 * lambda);
-
-            double x;
-            double y;
-            double v;
-            int a;
-            double t;
-            double qr;
-            double qa;
-            while (true) {
-                final double u = nextDouble();
-                if (u <= p1) {
-                    final double n = gaussian.sample();
-                    x = n * Math.sqrt(lambda + halfDelta) - 0.5;
-                    if (x > delta ||
-                        x < -lambda) {
-                        continue;
-                    }
-                    y = x < 0 ? Math.floor(x) : Math.ceil(x);
-                    final double e = exponential.sample();
-                    v = -e - 0.5 * n * n + c1;
-                } else {
-                    if (u > p1 + p2) {
-                        y = lambda;
-                        break;
-                    } else {
-                        x = delta + (twolpd / delta) * exponential.sample();
-                        y = Math.ceil(x);
-                        v = -exponential.sample() - delta * (x + 1) / twolpd;
-                    }
-                }
-                a = x < 0 ? 1 : 0;
-                t = y * (y + 1) / (2 * lambda);
-                if (v < -t && a == 0) {
-                    y = lambda + y;
-                    break;
-                }
-                qr = t * ((2 * y + 1) / (6 * lambda) - 1);
-                qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
-                if (v < qa) {
-                    y = lambda + y;
-                    break;
-                }
-                if (v > qr) {
-                    continue;
-                }
-                if (v < y * logLambda - factorialLog((int) (y + lambda)) + 
logLambdaFactorial) {
-                    y = lambda + y;
-                    break;
-                }
-            }
-            return y2 + (long) y;
-        }
-    }
-
-    /**
-     * Compute the natural logarithm of the factorial of {@code n}.
-     *
-     * @param n Argument.
-     * @return {@code log(n!)}
-     * @throws IllegalArgumentException if {@code n < 0}.
-     */
-    private double factorialLog(int n) {
-        return factorialLog.value(n);
-    }
 }

http://git-wip-us.apache.org/repos/asf/commons-rng/blob/7b4a4142/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.java
----------------------------------------------------------------------
diff --git 
a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.java
 
b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.java
new file mode 100644
index 0000000..9f47c76
--- /dev/null
+++ 
b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import org.apache.commons.rng.UniformRandomProvider;
+
+/**
+ * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>.
+ *
+ * <ul>
+ *  <li>
+ *   For small means, a Poisson process is simulated using uniform deviates, as
+ *   described <a href="http://mathaa.epfl.ch/cours/PMMI2001/interactive/rng7.htm">here</a>.
+ *   The Poisson process (and hence, the returned value) is bounded by 1000 * mean.
+ *  </li>
+ * </ul>
+ *
+ * This sampler is suitable for {@code mean < 40}.
+ */
+public class SmallMeanPoissonSampler
+    extends SamplerBase
+    implements DiscreteSampler {
+
+    /**
+     * Pre-compute {@code Math.exp(-mean)}.
+     * Note: This is the probability of the Poisson sample {@code P(n=0)}.
+     */
+    private final double p0;
+    /** Pre-compute {@code 1000 * mean} as the upper limit of the sample. */
+    private final int limit;
+
+    /**
+     * @param rng  Generator of uniformly distributed random numbers.
+     * @param mean Mean.
+     * @throws IllegalArgumentException if {@code mean <= 0} or {@code Math.exp(-mean)} is not positive.
+     */
+    public SmallMeanPoissonSampler(UniformRandomProvider rng, double mean) {
+        super(rng);
+        if (mean <= 0) {
+            throw new IllegalArgumentException(mean + " <= " + 0);
+        }
+        p0 = Math.exp(-mean);
+        // If exp(-mean) underflows to zero, r >= p0 always holds and every sample would equal the limit.
+        if (!(p0 > 0)) {
+            throw new IllegalArgumentException("No p(x=0) probability for mean: " + mean);
+        }
+        // The returned sample is bounded by 1000 * mean or Integer.MAX_VALUE
+        limit = (int) Math.ceil(Math.min(1000 * mean, Integer.MAX_VALUE));
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int sample() {
+        int n = 0;
+        double r = 1;
+        while (n < limit) {
+            r *= nextDouble();
+            if (r < p0) {
+                break;
+            }
+            n++;
+        }
+        return n;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public String toString() {
+        return "Small Mean Poisson deviate [" + super.toString() + "]";
+    }
+}

http://git-wip-us.apache.org/repos/asf/commons-rng/blob/7b4a4142/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
----------------------------------------------------------------------
diff --git 
a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
index 2c02bdb..23122c7 100644
--- 
a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
+++ 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
@@ -16,7 +16,6 @@
  */
 package org.apache.commons.rng.sampling.distribution;
 
-import java.util.Arrays;
 import java.util.List;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -109,16 +108,28 @@ public class DiscreteSamplersList {
             add(LIST, new 
org.apache.commons.math3.distribution.PoissonDistribution(meanPoisson),
                 MathArrays.sequence(10, 0, 1),
                 new PoissonSampler(RandomSource.create(RandomSource.KISS), 
meanPoisson));
+            // Dedicated small mean poisson sampler
+            add(LIST, new 
org.apache.commons.math3.distribution.PoissonDistribution(meanPoisson),
+                MathArrays.sequence(10, 0, 1),
+                new 
SmallMeanPoissonSampler(RandomSource.create(RandomSource.KISS), meanPoisson));
             // Poisson (40 < mean < 80).
             final double largeMeanPoisson = 67.89;
             add(LIST, new 
org.apache.commons.math3.distribution.PoissonDistribution(largeMeanPoisson),
                 MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
                 new 
PoissonSampler(RandomSource.create(RandomSource.SPLIT_MIX_64), 
largeMeanPoisson));
+            // Dedicated large mean poisson sampler
+            add(LIST, new 
org.apache.commons.math3.distribution.PoissonDistribution(largeMeanPoisson),
+                MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
+                new 
LargeMeanPoissonSampler(RandomSource.create(RandomSource.SPLIT_MIX_64), 
largeMeanPoisson));
             // Poisson (mean >> 40).
             final double veryLargeMeanPoisson = 543.21;
             add(LIST, new 
org.apache.commons.math3.distribution.PoissonDistribution(veryLargeMeanPoisson),
                 MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
                 new 
PoissonSampler(RandomSource.create(RandomSource.SPLIT_MIX_64), 
veryLargeMeanPoisson));
+            // Dedicated large mean poisson sampler
+            add(LIST, new 
org.apache.commons.math3.distribution.PoissonDistribution(veryLargeMeanPoisson),
+                MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
+                new 
LargeMeanPoissonSampler(RandomSource.create(RandomSource.SPLIT_MIX_64), 
veryLargeMeanPoisson));
         } catch (Exception e) {
             System.err.println("Unexpected exception while creating the list 
of samplers: " + e);
             e.printStackTrace(System.err);

Reply via email to