This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-rng.git
commit 153b11633b1c4fa9bbfec6c5ac79635b91600400 Author: Alex Herbert <aherb...@apache.org> AuthorDate: Sat Jul 6 20:20:02 2019 +0100 RNG-102: Implement the SharedStateSampler interface. --- .../commons/rng/sampling/CollectionSampler.java | 18 +++- .../commons/rng/sampling/CombinationSampler.java | 27 +++++- .../DiscreteProbabilityCollectionSampler.java | 20 +++- .../commons/rng/sampling/PermutationSampler.java | 26 +++++- .../commons/rng/sampling/UnitSphereSampler.java | 19 +++- .../AhrensDieterExponentialSampler.java | 20 +++- .../AhrensDieterMarsagliaTsangGammaSampler.java | 70 +++++++++++++- .../distribution/AliasMethodDiscreteSampler.java | 20 +++- .../BoxMullerNormalizedGaussianSampler.java | 9 +- .../sampling/distribution/ChengBetaSampler.java | 21 ++++- .../distribution/ContinuousUniformSampler.java | 9 +- .../distribution/DiscreteUniformSampler.java | 33 ++++++- .../rng/sampling/distribution/GaussianSampler.java | 35 ++++++- .../sampling/distribution/GeometricSampler.java | 46 +++++++++- .../distribution/GuideTableDiscreteSampler.java | 20 +++- .../InverseTransformContinuousSampler.java | 14 ++- .../InverseTransformDiscreteSampler.java | 14 ++- .../InverseTransformParetoSampler.java | 21 ++++- .../distribution/KempSmallMeanPoissonSampler.java | 20 +++- .../distribution/LargeMeanPoissonSampler.java | 45 ++++++++- .../sampling/distribution/LogNormalSampler.java | 35 ++++++- .../MarsagliaNormalizedGaussianSampler.java | 9 +- .../MarsagliaTsangWangDiscreteSampler.java | 101 +++++++++++++++++++-- .../rng/sampling/distribution/PoissonSampler.java | 26 +++++- .../RejectionInversionZipfSampler.java | 24 ++++- .../distribution/SmallMeanPoissonSampler.java | 20 +++- .../ZigguratNormalizedGaussianSampler.java | 9 +- .../rng/sampling/CollectionSamplerTest.java | 30 +++++- .../rng/sampling/CombinationSamplerTest.java | 27 ++++++ .../DiscreteProbabilityCollectionSamplerTest.java | 28 ++++++ .../rng/sampling/PermutationSamplerTest.java | 27 ++++++ 
.../apache/commons/rng/sampling/RandomAssert.java | 3 +- .../rng/sampling/UnitSphereSamplerTest.java | 26 ++++++ .../AhrensDieterExponentialSamplerTest.java | 16 ++++ ...AhrensDieterMarsagliaTsangGammaSamplerTest.java | 33 +++++++ .../AliasMethodDiscreteSamplerTest.java | 32 +++++++ ...=> BoxMullerNormalisedGaussianSamplerTest.java} | 22 +++-- .../distribution/ChengBetaSamplerTest.java | 17 ++++ .../distribution/ContinuousUniformSamplerTest.java | 16 ++++ .../distribution/DiscreteUniformSamplerTest.java | 32 +++++++ .../sampling/distribution/GaussianSamplerTest.java | 74 +++++++++++++++ .../distribution/GeometricSamplerTest.java | 32 +++++++ .../GuideTableDiscreteSamplerTest.java | 15 +++ .../InverseTransformContinuousSamplerTest.java | 47 ++++++++++ .../InverseTransformDiscreteSamplerTest.java | 47 ++++++++++ .../InverseTransformParetoSamplerTest.java | 25 ++++- .../KempSmallMeanPoissonSamplerTest.java | 16 ++++ .../distribution/LargeMeanPoissonSamplerTest.java | 33 +++++++ .../distribution/LogNormalSamplerTest.java | 74 +++++++++++++++ ...=> MarsagliaNormalisedGaussianSamplerTest.java} | 22 +++-- .../MarsagliaTsangWangDiscreteSamplerTest.java | 89 ++++++++++++++++-- .../sampling/distribution/PoissonSamplerTest.java | 58 ++++++++++++ .../RejectionInversionZipfSamplerTest.java | 17 ++++ .../distribution/SmallMeanPoissonSamplerTest.java | 15 +++ .../ZigguratNormalizedGaussianSamplerTest.java | 15 +++ 55 files changed, 1539 insertions(+), 80 deletions(-) diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CollectionSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CollectionSampler.java index 54f9ee9..12c8b93 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CollectionSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CollectionSampler.java @@ -32,7 +32,7 @@ import org.apache.commons.rng.UniformRandomProvider; * * @since 1.0 */ -public class 
CollectionSampler<T> { +public class CollectionSampler<T> implements SharedStateSampler<CollectionSampler<T>> { /** Collection to be sampled from. */ private final List<T> items; /** RNG. */ @@ -57,6 +57,16 @@ public class CollectionSampler<T> { } /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. + */ + private CollectionSampler(UniformRandomProvider rng, + CollectionSampler<T> source) { + this.rng = rng; + items = source.items; + } + + /** * Picks one of the items from the * {@link #CollectionSampler(UniformRandomProvider,Collection) * collection passed to the constructor}. @@ -66,4 +76,10 @@ public class CollectionSampler<T> { public T sample() { return items.get(rng.nextInt(items.size())); } + + /** {@inheritDoc} */ + @Override + public CollectionSampler<T> withUniformRandomProvider(UniformRandomProvider rng) { + return new CollectionSampler<T>(rng, this); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CombinationSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CombinationSampler.java index eeae0d5..848f737 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CombinationSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CombinationSampler.java @@ -39,7 +39,7 @@ import org.apache.commons.rng.UniformRandomProvider; * * @see PermutationSampler */ -public class CombinationSampler { +public class CombinationSampler implements SharedStateSampler<CombinationSampler> { /** Domain of the combination. */ private final int[] domain; /** The number of steps of a full shuffle to perform. */ @@ -90,6 +90,25 @@ public class CombinationSampler { } /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. + */ + private CombinationSampler(UniformRandomProvider rng, + CombinationSampler source) { + // Do not clone the domain. This ensures: + // 1. 
Thread safety as the domain may be shuffled during the clone + // and a shuffle swap step can result in duplicates and missing elements + // in the array. + // 2. If the argument RNG is an exact match for the RNG in the source + // then the output sequence will differ unless the domain is currently + // in natural order. + domain = PermutationSampler.natural(source.domain.length); + steps = source.steps; + upper = source.upper; + this.rng = rng; + } + + /** * Return a combination of {@code k} whose entries are selected randomly, * without repetition, from the integers 0, 1, ..., {@code n}-1 (inclusive). * @@ -101,4 +120,10 @@ public class CombinationSampler { public int[] sample() { return SubsetSamplerUtils.partialSample(domain, steps, rng, upper); } + + /** {@inheritDoc} */ + @Override + public CombinationSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new CombinationSampler(rng, this); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java index 06f108a..4bc50b4 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java @@ -38,7 +38,8 @@ import org.apache.commons.rng.UniformRandomProvider; * * @since 1.1 */ -public class DiscreteProbabilityCollectionSampler<T> { +public class DiscreteProbabilityCollectionSampler<T> + implements SharedStateSampler<DiscreteProbabilityCollectionSampler<T>> { /** Collection to be sampled from. */ private final List<T> items; /** RNG. */ @@ -125,6 +126,17 @@ public class DiscreteProbabilityCollectionSampler<T> { } /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. 
+ */ + private DiscreteProbabilityCollectionSampler(UniformRandomProvider rng, + DiscreteProbabilityCollectionSampler<T> source) { + this.rng = rng; + this.items = source.items; + this.cumulativeProbabilities = source.cumulativeProbabilities; + } + + /** * Picks one of the items from the collection passed to the constructor. * * @return a random sample. @@ -148,6 +160,12 @@ public class DiscreteProbabilityCollectionSampler<T> { return items.get(items.size() - 1); } + /** {@inheritDoc} */ + @Override + public DiscreteProbabilityCollectionSampler<T> withUniformRandomProvider(UniformRandomProvider rng) { + return new DiscreteProbabilityCollectionSampler<T>(rng, this); + } + /** * @param collection Collection to be sampled. * @param probabilities Probability associated to each item of the diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/PermutationSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/PermutationSampler.java index 7cb72cc..37aa7bc 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/PermutationSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/PermutationSampler.java @@ -27,7 +27,7 @@ import org.apache.commons.rng.UniformRandomProvider; * * <p>This class also contains utilities for shuffling an {@code int[]} array in-place.</p> */ -public class PermutationSampler { +public class PermutationSampler implements SharedStateSampler<PermutationSampler> { /** Domain of the permutation. */ private final int[] domain; /** Size of the permutation. */ @@ -60,6 +60,24 @@ public class PermutationSampler { } /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. + */ + private PermutationSampler(UniformRandomProvider rng, + PermutationSampler source) { + // Do not clone the domain. This ensures: + // 1. 
Thread safety as the domain may be shuffled during the clone + // and an incomplete shuffle swap step can result in duplicates and + // missing elements in the array. + // 2. If the argument RNG is an exact match for the RNG in the source + // then the output sequence will differ unless the domain is currently + // in natural order. + domain = PermutationSampler.natural(source.domain.length); + size = source.size; + this.rng = rng; + } + + /** * @return a random permutation. * * @see #PermutationSampler(UniformRandomProvider,int,int) @@ -68,6 +86,12 @@ public class PermutationSampler { return SubsetSamplerUtils.partialSample(domain, size, rng, true); } + /** {@inheritDoc} */ + @Override + public PermutationSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new PermutationSampler(rng, this); + } + /** * Shuffles the entries of the given array. * diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/UnitSphereSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/UnitSphereSampler.java index 661702b..67cecd1 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/UnitSphereSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/UnitSphereSampler.java @@ -34,7 +34,7 @@ import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSa * * @since 1.1 */ -public class UnitSphereSampler { +public class UnitSphereSampler implements SharedStateSampler<UnitSphereSampler> { /** Sampler used for generating the individual components of the vectors. */ private final NormalizedGaussianSampler sampler; /** Space dimension. */ @@ -57,6 +57,17 @@ public class UnitSphereSampler { } /** + * @param rng Generator for the individual components of the vectors. + * @param source Source to copy. 
+ */ + private UnitSphereSampler(UniformRandomProvider rng, + UnitSphereSampler source) { + // The Gaussian sampler has no shared state so create a new instance + sampler = new ZigguratNormalizedGaussianSampler(rng); + dimension = source.dimension; + } + + /** * @return a random normalized Cartesian vector. */ public double[] nextVector() { @@ -87,4 +98,10 @@ public class UnitSphereSampler { return v; } + + /** {@inheritDoc} */ + @Override + public UnitSphereSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new UnitSphereSampler(rng, this); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.java index 3f8a2e7..2d93c28 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * Sampling from an <a href="http://mathworld.wolfram.com/ExponentialDistribution.html">exponential distribution</a>. @@ -27,7 +28,7 @@ import org.apache.commons.rng.UniformRandomProvider; */ public class AhrensDieterExponentialSampler extends SamplerBase - implements ContinuousSampler { + implements ContinuousSampler, SharedStateSampler<AhrensDieterExponentialSampler> { /** * Table containing the constants * \( q_i = sum_{j=1}^i (\ln 2)^j / j! = \ln 2 + (\ln 2)^2 / 2 + ... + (\ln 2)^i / i! \) @@ -78,6 +79,17 @@ public class AhrensDieterExponentialSampler this.mean = mean; } + /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. 
+ */ + private AhrensDieterExponentialSampler(UniformRandomProvider rng, + AhrensDieterExponentialSampler source) { + super(null); + this.rng = rng; + this.mean = source.mean; + } + /** {@inheritDoc} */ @Override public double sample() { @@ -124,4 +136,10 @@ public class AhrensDieterExponentialSampler public String toString() { return "Ahrens-Dieter Exponential deviate [" + rng.toString() + "]"; } + + /** {@inheritDoc} */ + @Override + public AhrensDieterExponentialSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new AhrensDieterExponentialSampler(rng, this); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.java index 219c848..08901ff 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * Sampling from the <a href="http://mathworld.wolfram.com/GammaDistribution.html">Gamma distribution</a>. @@ -52,7 +53,7 @@ import org.apache.commons.rng.UniformRandomProvider; */ public class AhrensDieterMarsagliaTsangGammaSampler extends SamplerBase - implements ContinuousSampler { + implements ContinuousSampler, SharedStateSampler<AhrensDieterMarsagliaTsangGammaSampler> { /** The appropriate gamma sampler for the parameters. */ private final ContinuousSampler delegate; @@ -60,7 +61,7 @@ public class AhrensDieterMarsagliaTsangGammaSampler * Base class for a sampler from the Gamma distribution. 
*/ private abstract static class BaseGammaSampler - implements ContinuousSampler { + implements ContinuousSampler, SharedStateSampler<ContinuousSampler> { /** Underlying source of randomness. */ protected final UniformRandomProvider rng; @@ -89,6 +90,17 @@ public class AhrensDieterMarsagliaTsangGammaSampler this.theta = theta; } + /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. + */ + BaseGammaSampler(UniformRandomProvider rng, + BaseGammaSampler source) { + this.rng = rng; + this.alpha = source.alpha; + this.theta = source.theta; + } + /** {@inheritDoc} */ @Override public String toString() { @@ -127,6 +139,17 @@ public class AhrensDieterMarsagliaTsangGammaSampler bGSOptim = 1 + alpha / Math.E; } + /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. + */ + AhrensDieterGammaSampler(UniformRandomProvider rng, + AhrensDieterGammaSampler source) { + super(rng, source); + oneOverAlpha = source.oneOverAlpha; + bGSOptim = source.bGSOptim; + } + @Override public double sample() { // [1]: p. 228, Algorithm GS. @@ -158,6 +181,11 @@ public class AhrensDieterMarsagliaTsangGammaSampler // Reject and continue. } } + + @Override + public ContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new AhrensDieterGammaSampler(rng, this); + } } /** @@ -189,14 +217,26 @@ public class AhrensDieterMarsagliaTsangGammaSampler * @throws IllegalArgumentException if {@code alpha <= 0} or {@code theta <= 0} */ MarsagliaTsangGammaSampler(UniformRandomProvider rng, - double alpha, - double theta) { + double alpha, + double theta) { super(rng, alpha, theta); gaussian = new ZigguratNormalizedGaussianSampler(rng); dOptim = alpha - ONE_THIRD; cOptim = ONE_THIRD / Math.sqrt(dOptim); } + /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. 
+ */ + MarsagliaTsangGammaSampler(UniformRandomProvider rng, + MarsagliaTsangGammaSampler source) { + super(rng, source); + gaussian = new ZigguratNormalizedGaussianSampler(rng); + dOptim = source.dOptim; + cOptim = source.cOptim; + } + @Override public double sample() { while (true) { @@ -221,6 +261,11 @@ public class AhrensDieterMarsagliaTsangGammaSampler } } } + + @Override + public ContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new MarsagliaTsangGammaSampler(rng, this); + } } /** @@ -238,6 +283,17 @@ public class AhrensDieterMarsagliaTsangGammaSampler new MarsagliaTsangGammaSampler(rng, alpha, theta); } + /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. + */ + @SuppressWarnings("unchecked") + private AhrensDieterMarsagliaTsangGammaSampler(UniformRandomProvider rng, + AhrensDieterMarsagliaTsangGammaSampler source) { + super(null); + delegate = ((SharedStateSampler<ContinuousSampler>)(source.delegate)).withUniformRandomProvider(rng); + } + /** {@inheritDoc} */ @Override public double sample() { @@ -249,4 +305,10 @@ public class AhrensDieterMarsagliaTsangGammaSampler public String toString() { return delegate.toString(); } + + /** {@inheritDoc} */ + @Override + public AhrensDieterMarsagliaTsangGammaSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new AhrensDieterMarsagliaTsangGammaSampler(rng, this); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AliasMethodDiscreteSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AliasMethodDiscreteSampler.java index ac93c9b..c5018e0 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AliasMethodDiscreteSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AliasMethodDiscreteSampler.java @@ -17,6 +17,7 @@ package 
org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; import java.util.Arrays; @@ -69,7 +70,8 @@ import java.util.Arrays; * @see <a href="https://ieeexplore.ieee.org/document/92917">Vose (1991) IEEE Transactions * on Software Engineering 17, 972-975.</a> */ -public class AliasMethodDiscreteSampler implements DiscreteSampler { +public class AliasMethodDiscreteSampler + implements DiscreteSampler, SharedStateSampler<AliasMethodDiscreteSampler> { /** * The default alpha factor for zero-padding an input probability table. The default * value will pad the probabilities by to the next power-of-2. @@ -182,8 +184,8 @@ public class AliasMethodDiscreteSampler implements DiscreteSampler { * @param alias Alias table. */ SmallTableAliasMethodDiscreteSampler(final UniformRandomProvider rng, - final long[] probability, - final int[] alias) { + final long[] probability, + final int[] alias) { super(rng, probability, alias); // Assume the table size is a power of 2 and create the mask mask = alias.length - 1; @@ -208,6 +210,12 @@ public class AliasMethodDiscreteSampler implements DiscreteSampler { // Choose between the two. Use a 53-bit long for the probability. return (longBits >>> 11) < probability[j] ? j : alias[j]; } + + /** {@inheritDoc} */ + @Override + public SmallTableAliasMethodDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new SmallTableAliasMethodDiscreteSampler(rng, probability, alias); + } } /** @@ -270,6 +278,12 @@ public class AliasMethodDiscreteSampler implements DiscreteSampler { return "Alias method [" + rng.toString() + "]"; } + /** {@inheritDoc} */ + @Override + public AliasMethodDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new AliasMethodDiscreteSampler(rng, probability, alias); + } + /** * Creates a sampler. 
* diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/BoxMullerNormalizedGaussianSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/BoxMullerNormalizedGaussianSampler.java index 616e768..5910060 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/BoxMullerNormalizedGaussianSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/BoxMullerNormalizedGaussianSampler.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * <a href="https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform"> @@ -28,7 +29,7 @@ import org.apache.commons.rng.UniformRandomProvider; * @since 1.1 */ public class BoxMullerNormalizedGaussianSampler - implements NormalizedGaussianSampler { + implements NormalizedGaussianSampler, SharedStateSampler<BoxMullerNormalizedGaussianSampler> { /** Next gaussian. */ private double nextGaussian = Double.NaN; /** Underlying source of randomness. 
*/ @@ -75,4 +76,10 @@ public class BoxMullerNormalizedGaussianSampler public String toString() { return "Box-Muller normalized Gaussian deviate [" + rng.toString() + "]"; } + + /** {@inheritDoc} */ + @Override + public BoxMullerNormalizedGaussianSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new BoxMullerNormalizedGaussianSampler(rng); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ChengBetaSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ChengBetaSampler.java index 5a64bf4..682ae9c 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ChengBetaSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ChengBetaSampler.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * Utility class implementing Cheng's algorithms for beta distribution sampling. @@ -35,7 +36,7 @@ import org.apache.commons.rng.UniformRandomProvider; */ public class ChengBetaSampler extends SamplerBase - implements ContinuousSampler { + implements ContinuousSampler, SharedStateSampler<ChengBetaSampler> { /** 1/2. */ private static final double ONE_HALF = 1d / 2; /** 1/4. */ @@ -71,6 +72,18 @@ public class ChengBetaSampler betaShape = beta; } + /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. 
+ */ + private ChengBetaSampler(UniformRandomProvider rng, + ChengBetaSampler source) { + super(null); + this.rng = rng; + alphaShape = source.alphaShape; + betaShape = source.betaShape; + } + /** {@inheritDoc} */ @Override public double sample() { @@ -90,6 +103,12 @@ public class ChengBetaSampler return "Cheng Beta deviate [" + rng.toString() + "]"; } + /** {@inheritDoc} */ + @Override + public ChengBetaSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new ChengBetaSampler(rng, this); + } + /** * Computes one sample using Cheng's BB algorithm, when \( \alpha \) and * \( \beta \) are both larger than 1. diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ContinuousUniformSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ContinuousUniformSampler.java index e7d749e..2a3a9ea 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ContinuousUniformSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ContinuousUniformSampler.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * Sampling from a uniform distribution. @@ -27,7 +28,7 @@ import org.apache.commons.rng.UniformRandomProvider; */ public class ContinuousUniformSampler extends SamplerBase - implements ContinuousSampler { + implements ContinuousSampler, SharedStateSampler<ContinuousUniformSampler> { /** Lower bound. */ private final double lo; /** Higher bound. 
*/ @@ -61,4 +62,10 @@ public class ContinuousUniformSampler public String toString() { return "Uniform deviate [" + rng.toString() + "]"; } + + /** {@inheritDoc} */ + @Override + public ContinuousUniformSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new ContinuousUniformSampler(rng, lo, hi); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSampler.java index 2b5ea52..1809c4c 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSampler.java @@ -18,6 +18,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * Discrete uniform distribution sampler. @@ -30,7 +31,7 @@ import org.apache.commons.rng.UniformRandomProvider; */ public class DiscreteUniformSampler extends SamplerBase - implements DiscreteSampler { + implements DiscreteSampler, SharedStateSampler<DiscreteUniformSampler> { /** The appropriate uniform sampler for the parameters. */ private final DiscreteSampler delegate; @@ -39,7 +40,7 @@ public class DiscreteUniformSampler * Base class for a sampler from a discrete uniform distribution. */ private abstract static class AbstractDiscreteUniformSampler - implements DiscreteSampler { + implements DiscreteSampler, SharedStateSampler<DiscreteSampler> { /** Underlying source of randomness. 
*/ protected final UniformRandomProvider rng; @@ -56,7 +57,6 @@ public class DiscreteUniformSampler this.lower = lower; } - /** {@inheritDoc} */ @Override public String toString() { return "Uniform deviate [" + rng.toString() + "]"; @@ -89,6 +89,11 @@ public class DiscreteUniformSampler public int sample() { return lower + rng.nextInt(range); } + + @Override + public DiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new SmallRangeDiscreteUniformSampler(rng, lower, range); + } } /** @@ -127,6 +132,11 @@ public class DiscreteUniformSampler } } } + + @Override + public DiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new LargeRangeDiscreteUniformSampler(rng, lower, upper); + } } /** @@ -152,6 +162,17 @@ public class DiscreteUniformSampler new SmallRangeDiscreteUniformSampler(rng, lower, range); } + /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. + */ + @SuppressWarnings("unchecked") + private DiscreteUniformSampler(UniformRandomProvider rng, + DiscreteUniformSampler source) { + super(null); + delegate = ((SharedStateSampler<DiscreteSampler>)(source.delegate)).withUniformRandomProvider(rng); + } + /** {@inheritDoc} */ @Override public int sample() { @@ -163,4 +184,10 @@ public class DiscreteUniformSampler public String toString() { return delegate.toString(); } + + /** {@inheritDoc} */ + @Override + public DiscreteUniformSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new DiscreteUniformSampler(rng, this); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java index 9c99742..e8cb095 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java +++ 
b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java @@ -16,13 +16,16 @@ */ package org.apache.commons.rng.sampling.distribution; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; + /** * Sampling from a Gaussian distribution with given mean and * standard deviation. * * @since 1.1 */ -public class GaussianSampler implements ContinuousSampler { +public class GaussianSampler implements ContinuousSampler, SharedStateSampler<GaussianSampler> { /** Mean. */ private final double mean; /** standardDeviation. */ @@ -48,6 +51,21 @@ public class GaussianSampler implements ContinuousSampler { this.standardDeviation = standardDeviation; } + /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. + */ + private GaussianSampler(UniformRandomProvider rng, + GaussianSampler source) { + if (!(source.normalized instanceof SharedStateSampler<?>)) { + throw new UnsupportedOperationException("The underlying sampler is not a SharedStateSampler"); + } + this.mean = source.mean; + this.standardDeviation = source.standardDeviation; + this.normalized = (NormalizedGaussianSampler) + ((SharedStateSampler<?>)source.normalized).withUniformRandomProvider(rng); + } + /** {@inheritDoc} */ @Override public double sample() { @@ -59,4 +77,19 @@ public class GaussianSampler implements ContinuousSampler { public String toString() { return "Gaussian deviate [" + normalized.toString() + "]"; } + + /** + * {@inheritDoc} + * + * <p>Note: This function is available if the underlying {@link NormalizedGaussianSampler} + * is a {@link SharedStateSampler}. Otherwise a run-time exception is thrown.</p> + * + * @throws UnsupportedOperationException if the underlying sampler is not a {@link SharedStateSampler}. + * @throws ClassCastException if the underlying {@link SharedStateSampler} does not return a + * {@link NormalizedGaussianSampler}. 
+ */ + @Override + public GaussianSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new GaussianSampler(rng, this); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GeometricSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GeometricSampler.java index f4683ac..f1e3470 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GeometricSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GeometricSampler.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * Sampling from a <a href="https://en.wikipedia.org/wiki/Geometric_distribution">geometric @@ -45,14 +46,15 @@ import org.apache.commons.rng.UniformRandomProvider; * * @since 1.3 */ -public class GeometricSampler implements DiscreteSampler { +public class GeometricSampler implements DiscreteSampler, SharedStateSampler<GeometricSampler> { /** The appropriate geometric sampler for the parameters. */ private final DiscreteSampler delegate; /** * Sample from the geometric distribution when the probability of success is 1. */ - private static class GeometricP1Sampler implements DiscreteSampler { + private static class GeometricP1Sampler + implements DiscreteSampler, SharedStateSampler<DiscreteSampler> { /** The single instance. 
*/ static final GeometricP1Sampler INSTANCE = new GeometricP1Sampler(); @@ -62,17 +64,23 @@ public class GeometricSampler implements DiscreteSampler { return 0; } - /** {@inheritDoc} */ @Override public String toString() { return "Geometric(p=1) deviate"; } + + @Override + public DiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { + // No requirement for a new instance + return this; + } } /** * Sample from the geometric distribution by using a related exponential distribution. */ - private static class GeometricExponentialSampler implements DiscreteSampler { + private static class GeometricExponentialSampler + implements DiscreteSampler, SharedStateSampler<DiscreteSampler> { /** Underlying source of randomness. Used only for the {@link #toString()} method. */ private final UniformRandomProvider rng; /** The related exponential sampler for the geometric distribution. */ @@ -99,17 +107,30 @@ public class GeometricSampler implements DiscreteSampler { exponentialSampler = new AhrensDieterExponentialSampler(rng, exponentialMean); } + /** + * @param rng Generator of uniformly distributed random numbers + * @param source Source to copy. + */ + GeometricExponentialSampler(UniformRandomProvider rng, GeometricExponentialSampler source) { + this.rng = rng; + exponentialSampler = source.exponentialSampler.withUniformRandomProvider(rng); + } + @Override public int sample() { // Return the floor of the exponential sample return (int) Math.floor(exponentialSampler.sample()); } - /** {@inheritDoc} */ @Override public String toString() { return "Geometric deviate [" + rng.toString() + "]"; } + + @Override + public DiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new GeometricExponentialSampler(rng, this); + } } /** @@ -134,6 +155,15 @@ public class GeometricSampler implements DiscreteSampler { } /** + * @param rng Generator of uniformly distributed random numbers + * @param source Source to copy. 
+ */ + @SuppressWarnings("unchecked") + private GeometricSampler(UniformRandomProvider rng, GeometricSampler source) { + delegate = ((SharedStateSampler<DiscreteSampler>)(source.delegate)).withUniformRandomProvider(rng); + } + + /** * Create a sample from a geometric distribution. * * <p>The sample will take the values in the set {@code [0, 1, 2, ...]}, equivalent to the @@ -149,4 +179,10 @@ public class GeometricSampler implements DiscreteSampler { public String toString() { return delegate.toString(); } + + /** {@inheritDoc} */ + @Override + public GeometricSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new GeometricSampler(rng, this); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSampler.java index beb7a5f..cb02cbb 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSampler.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * Compute a sample from a discrete probability distribution. The cumulative probability @@ -40,7 +41,7 @@ import org.apache.commons.rng.UniformRandomProvider; * @since 1.3 */ public class GuideTableDiscreteSampler - implements DiscreteSampler { + implements DiscreteSampler, SharedStateSampler<GuideTableDiscreteSampler> { /** The default value for {@code alpha}. */ private static final double DEFAULT_ALPHA = 1.0; /** Underlying source of randomness. */ @@ -142,6 +143,17 @@ public class GuideTableDiscreteSampler } /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. 
+ */ + private GuideTableDiscreteSampler(UniformRandomProvider rng, + GuideTableDiscreteSampler source) { + this.rng = rng; + cumulativeProbabilities = source.cumulativeProbabilities; + guideTable = source.guideTable; + } + + /** * Validate the parameters. * * @param probabilities The probabilities. @@ -198,4 +210,10 @@ public class GuideTableDiscreteSampler public String toString() { return "Guide table deviate [" + rng.toString() + "]"; } + + /** {@inheritDoc} */ + @Override + public GuideTableDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new GuideTableDiscreteSampler(rng, this); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InverseTransformContinuousSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InverseTransformContinuousSampler.java index 5ba32fd..f60e1d0 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InverseTransformContinuousSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InverseTransformContinuousSampler.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * Distribution sampler that uses the @@ -57,7 +58,7 @@ import org.apache.commons.rng.UniformRandomProvider; */ public class InverseTransformContinuousSampler extends SamplerBase - implements ContinuousSampler { + implements ContinuousSampler, SharedStateSampler<InverseTransformContinuousSampler> { /** Inverse cumulative probability function. */ private final ContinuousInverseCumulativeProbabilityFunction function; /** Underlying source of randomness. 
*/ @@ -85,4 +86,15 @@ public class InverseTransformContinuousSampler public String toString() { return function.toString() + " (inverse method) [" + rng.toString() + "]"; } + + /** + * {@inheritDoc} + * + * <p>Note: The new sampler will share the inverse cumulative probability function. This + * must be suitable for concurrent use to ensure thread safety.</p> + */ + @Override + public InverseTransformContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new InverseTransformContinuousSampler(rng, function); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InverseTransformDiscreteSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InverseTransformDiscreteSampler.java index 276fb8e..e10e2c4 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InverseTransformDiscreteSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InverseTransformDiscreteSampler.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * Distribution sampler that uses the @@ -57,7 +58,7 @@ import org.apache.commons.rng.UniformRandomProvider; */ public class InverseTransformDiscreteSampler extends SamplerBase - implements DiscreteSampler { + implements DiscreteSampler, SharedStateSampler<InverseTransformDiscreteSampler> { /** Inverse cumulative probability function. */ private final DiscreteInverseCumulativeProbabilityFunction function; /** Underlying source of randomness. */ @@ -85,4 +86,15 @@ public class InverseTransformDiscreteSampler public String toString() { return function.toString() + " (inverse method) [" + rng.toString() + "]"; } + + /** + * {@inheritDoc} + * + * <p>Note: The new sampler will share the inverse cumulative probability function. 
This + * must be suitable for concurrent use to ensure thread safety.</p> + */ + @Override + public InverseTransformDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new InverseTransformDiscreteSampler(rng, function); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InverseTransformParetoSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InverseTransformParetoSampler.java index 69f7a11..4068b49 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InverseTransformParetoSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InverseTransformParetoSampler.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * Sampling from a <a href="https://en.wikipedia.org/wiki/Pareto_distribution">Pareto distribution</a>. @@ -27,7 +28,7 @@ import org.apache.commons.rng.UniformRandomProvider; */ public class InverseTransformParetoSampler extends SamplerBase - implements ContinuousSampler { + implements ContinuousSampler, SharedStateSampler<InverseTransformParetoSampler> { /** Scale. */ private final double scale; /** 1 / Shape. */ @@ -56,6 +57,18 @@ public class InverseTransformParetoSampler this.oneOverShape = 1 / shape; } + /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. 
+ */ + private InverseTransformParetoSampler(UniformRandomProvider rng, + InverseTransformParetoSampler source) { + super(null); + this.rng = rng; + scale = source.scale; + oneOverShape = source.oneOverShape; + } + /** {@inheritDoc} */ @Override public double sample() { @@ -67,4 +80,10 @@ public class InverseTransformParetoSampler public String toString() { return "[Inverse method for Pareto distribution " + rng.toString() + "]"; } + + /** {@inheritDoc} */ + @Override + public InverseTransformParetoSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new InverseTransformParetoSampler(rng, this); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.java index b0692b0..b3203d3 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson @@ -46,7 +47,7 @@ import org.apache.commons.rng.UniformRandomProvider; * 249-253</a> */ public class KempSmallMeanPoissonSampler - implements DiscreteSampler { + implements DiscreteSampler, SharedStateSampler<KempSmallMeanPoissonSampler> { /** Underlying source of randomness. */ private final UniformRandomProvider rng; /** @@ -81,6 +82,17 @@ public class KempSmallMeanPoissonSampler } } + /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. 
+ */ + private KempSmallMeanPoissonSampler(UniformRandomProvider rng, + KempSmallMeanPoissonSampler source) { + this.rng = rng; + p0 = source.p0; + mean = source.mean; + } + /** {@inheritDoc} */ @Override public int sample() { @@ -114,4 +126,10 @@ public class KempSmallMeanPoissonSampler public String toString() { return "Kemp Small Mean Poisson deviate [" + rng.toString() + "]"; } + + /** {@inheritDoc} */ + @Override + public KempSmallMeanPoissonSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new KempSmallMeanPoissonSampler(rng, this); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java index 3421db9..ce8d9ae 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog; /** @@ -44,13 +45,13 @@ import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog; * @since 1.1 */ public class LargeMeanPoissonSampler - implements DiscreteSampler { + implements DiscreteSampler, SharedStateSampler<LargeMeanPoissonSampler> { /** Upper bound to avoid truncation. */ private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE; /** Class to compute {@code log(n!)}. This has no cached values. */ private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG; /** Used when there is no requirement for a small mean Poisson sampler. 
*/ - private static final DiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER = null; + private static final KempSmallMeanPoissonSampler NO_SMALL_MEAN_POISSON_SAMPLER = null; static { // Create without a cache. @@ -60,7 +61,7 @@ public class LargeMeanPoissonSampler /** Underlying source of randomness. */ private final UniformRandomProvider rng; /** Exponential. */ - private final ContinuousSampler exponential; + private final AhrensDieterExponentialSampler exponential; /** Gaussian. */ private final ContinuousSampler gaussian; /** Local class to compute {@code log(n!)}. This may have cached values. */ @@ -100,7 +101,7 @@ public class LargeMeanPoissonSampler private final double c1; /** The internal Poisson sampler for the lambda fraction. */ - private final DiscreteSampler smallMeanPoissonSampler; + private final KempSmallMeanPoissonSampler smallMeanPoissonSampler; /** * @param rng Generator of uniformly distributed random numbers. @@ -186,6 +187,36 @@ public class LargeMeanPoissonSampler new KempSmallMeanPoissonSampler(rng, lambdaFractional); } + /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. + */ + private LargeMeanPoissonSampler(UniformRandomProvider rng, + LargeMeanPoissonSampler source) { + this.rng = rng; + + // The Gaussian sampler has no shared state + gaussian = new ZigguratNormalizedGaussianSampler(rng); + exponential = source.exponential.withUniformRandomProvider(rng); + // Reuse the cache + factorialLog = source.factorialLog; + + lambda = source.lambda; + logLambda = source.logLambda; + logLambdaFactorial = source.logLambdaFactorial; + delta = source.delta; + halfDelta = source.halfDelta; + twolpd = source.twolpd; + p1 = source.p1; + p2 = source.p2; + c1 = source.c1; + + // Share the state of the small sampler + smallMeanPoissonSampler = source.smallMeanPoissonSampler == null ? + NO_SMALL_MEAN_POISSON_SAMPLER : // Not used. 
+ source.smallMeanPoissonSampler.withUniformRandomProvider(rng); + } + /** {@inheritDoc} */ @Override public int sample() { @@ -262,6 +293,12 @@ public class LargeMeanPoissonSampler return "Large Mean Poisson deviate [" + rng.toString() + "]"; } + /** {@inheritDoc} */ + @Override + public LargeMeanPoissonSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new LargeMeanPoissonSampler(rng, this); + } + /** * Gets the initialisation state of the sampler. * diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LogNormalSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LogNormalSampler.java index 258f2f5..3879cbc 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LogNormalSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LogNormalSampler.java @@ -16,12 +16,15 @@ */ package org.apache.commons.rng.sampling.distribution; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; + /** * Sampling from a log-normal distribution. * * @since 1.1 */ -public class LogNormalSampler implements ContinuousSampler { +public class LogNormalSampler implements ContinuousSampler, SharedStateSampler<LogNormalSampler> { /** Scale. */ private final double scale; /** Shape. */ @@ -49,6 +52,21 @@ public class LogNormalSampler implements ContinuousSampler { this.gaussian = gaussian; } + /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. 
+ */ + private LogNormalSampler(UniformRandomProvider rng, + LogNormalSampler source) { + if (!(source.gaussian instanceof SharedStateSampler<?>)) { + throw new UnsupportedOperationException("The underlying sampler is not a SharedStateSampler"); + } + this.scale = source.scale; + this.shape = source.shape; + this.gaussian = (NormalizedGaussianSampler) + ((SharedStateSampler<?>)source.gaussian).withUniformRandomProvider(rng); + } + /** {@inheritDoc} */ @Override public double sample() { @@ -60,4 +78,19 @@ public class LogNormalSampler implements ContinuousSampler { public String toString() { return "Log-normal deviate [" + gaussian.toString() + "]"; } + + /** + * {@inheritDoc} + * + * <p>Note: This function is available if the underlying {@link NormalizedGaussianSampler} + * is a {@link SharedStateSampler}. Otherwise a run-time exception is thrown.</p> + * + * @throws UnsupportedOperationException if the underlying sampler is not a {@link SharedStateSampler}. + * @throws ClassCastException if the underlying {@link SharedStateSampler} does not return a + * {@link NormalizedGaussianSampler}. 
+ */ + @Override + public LogNormalSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new LogNormalSampler(rng, this); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaNormalizedGaussianSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaNormalizedGaussianSampler.java index af50b55..421288a 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaNormalizedGaussianSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaNormalizedGaussianSampler.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * <a href="https://en.wikipedia.org/wiki/Marsaglia_polar_method"> @@ -30,7 +31,7 @@ import org.apache.commons.rng.UniformRandomProvider; * @since 1.1 */ public class MarsagliaNormalizedGaussianSampler - implements NormalizedGaussianSampler { + implements NormalizedGaussianSampler, SharedStateSampler<MarsagliaNormalizedGaussianSampler> { /** Next gaussian. */ private double nextGaussian = Double.NaN; /** Underlying source of randomness. 
*/ @@ -84,4 +85,10 @@ public class MarsagliaNormalizedGaussianSampler public String toString() { return "Box-Muller (with rejection) normalized Gaussian deviate [" + rng.toString() + "]"; } + + /** {@inheritDoc} */ + @Override + public MarsagliaNormalizedGaussianSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new MarsagliaNormalizedGaussianSampler(rng); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java index 55b82d5..c7528f4 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * Sampler for a discrete distribution using an optimised look-up table. @@ -49,7 +50,8 @@ import org.apache.commons.rng.UniformRandomProvider; * @see <a href="http://dx.doi.org/10.18637/jss.v011.i03">Margsglia, et al (2004) JSS Vol. * 11, Issue 3</a> */ -public abstract class MarsagliaTsangWangDiscreteSampler implements DiscreteSampler { +public abstract class MarsagliaTsangWangDiscreteSampler + implements DiscreteSampler, SharedStateSampler<MarsagliaTsangWangDiscreteSampler> { /** The value 2<sup>8</sup> as an {@code int}. */ private static final int INT_8 = 1 << 8; /** The value 2<sup>16</sup> as an {@code int}. */ @@ -193,6 +195,24 @@ public abstract class MarsagliaTsangWangDiscreteSampler implements DiscreteSampl } /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. 
+ */ + private MarsagliaTsangWangBase64Int8DiscreteSampler(UniformRandomProvider rng, + MarsagliaTsangWangBase64Int8DiscreteSampler source) { + super(rng, source); + t1 = source.t1; + t2 = source.t2; + t3 = source.t3; + t4 = source.t4; + table1 = source.table1; + table2 = source.table2; + table3 = source.table3; + table4 = source.table4; + table5 = source.table5; + } + + /** * Fill the table with the value. * * @param table Table. @@ -206,7 +226,6 @@ public abstract class MarsagliaTsangWangDiscreteSampler implements DiscreteSampl } } - /** {@inheritDoc} */ @Override public int sample() { final int j = rng.nextInt() >>> 2; @@ -227,6 +246,11 @@ public abstract class MarsagliaTsangWangDiscreteSampler implements DiscreteSampl // difference. So the tables *must* be constructed correctly. return table5[j - t4] & MASK; } + + @Override + public MarsagliaTsangWangDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new MarsagliaTsangWangBase64Int8DiscreteSampler(rng, this); + } } /** @@ -311,6 +335,24 @@ public abstract class MarsagliaTsangWangDiscreteSampler implements DiscreteSampl } /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. + */ + private MarsagliaTsangWangBase64Int16DiscreteSampler(UniformRandomProvider rng, + MarsagliaTsangWangBase64Int16DiscreteSampler source) { + super(rng, source); + t1 = source.t1; + t2 = source.t2; + t3 = source.t3; + t4 = source.t4; + table1 = source.table1; + table2 = source.table2; + table3 = source.table3; + table4 = source.table4; + table5 = source.table5; + } + + /** * Fill the table with the value. * * @param table Table. @@ -324,7 +366,6 @@ public abstract class MarsagliaTsangWangDiscreteSampler implements DiscreteSampl } } - /** {@inheritDoc} */ @Override public int sample() { final int j = rng.nextInt() >>> 2; @@ -345,6 +386,11 @@ public abstract class MarsagliaTsangWangDiscreteSampler implements DiscreteSampl // difference. 
So the tables *must* be constructed correctly. return table5[j - t4] & MASK; } + + @Override + public MarsagliaTsangWangDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new MarsagliaTsangWangBase64Int16DiscreteSampler(rng, this); + } } /** @@ -353,7 +399,6 @@ public abstract class MarsagliaTsangWangDiscreteSampler implements DiscreteSampl */ private static class MarsagliaTsangWangBase64Int32DiscreteSampler extends MarsagliaTsangWangDiscreteSampler { - /** Limit for look-up table 1. */ private final int t1; /** Limit for look-up table 2. */ @@ -426,6 +471,24 @@ public abstract class MarsagliaTsangWangDiscreteSampler implements DiscreteSampl } /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. + */ + private MarsagliaTsangWangBase64Int32DiscreteSampler(UniformRandomProvider rng, + MarsagliaTsangWangBase64Int32DiscreteSampler source) { + super(rng, source); + t1 = source.t1; + t2 = source.t2; + t3 = source.t3; + t4 = source.t4; + table1 = source.table1; + table2 = source.table2; + table3 = source.table3; + table4 = source.table4; + table5 = source.table5; + } + + /** * Fill the table with the value. * * @param table Table. @@ -439,7 +502,6 @@ public abstract class MarsagliaTsangWangDiscreteSampler implements DiscreteSampl } } - /** {@inheritDoc} */ @Override public int sample() { final int j = rng.nextInt() >>> 2; @@ -460,6 +522,11 @@ public abstract class MarsagliaTsangWangDiscreteSampler implements DiscreteSampl // difference. So the tables *must* be constructed correctly. 
return table5[j - t4]; } + + @Override + public MarsagliaTsangWangDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new MarsagliaTsangWangBase64Int32DiscreteSampler(rng, this); + } } /** @@ -484,11 +551,16 @@ public abstract class MarsagliaTsangWangDiscreteSampler implements DiscreteSampl return result; } - /** {@inheritDoc} */ @Override public String toString() { return BINOMIAL_NAME + " deviate"; } + + @Override + public MarsagliaTsangWangDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { + // No shared state + return this; + } } /** @@ -522,11 +594,16 @@ public abstract class MarsagliaTsangWangDiscreteSampler implements DiscreteSampl return trials - sampler.sample(); } - /** {@inheritDoc} */ @Override public String toString() { return sampler.toString(); } + + @Override + public MarsagliaTsangWangDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new MarsagliaTsangWangInversionBinomialSampler(this.trials, + this.sampler.withUniformRandomProvider(rng)); + } } /** @@ -539,6 +616,16 @@ public abstract class MarsagliaTsangWangDiscreteSampler implements DiscreteSampl this.distributionName = distributionName; } + /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. 
+ */ + MarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng, + MarsagliaTsangWangDiscreteSampler source) { + this.rng = rng; + this.distributionName = source.distributionName; + } + /** {@inheritDoc} */ @Override public String toString() { diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java index 56472ad..804a9e3 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>. @@ -51,7 +52,7 @@ import org.apache.commons.rng.UniformRandomProvider; */ public class PoissonSampler extends SamplerBase - implements DiscreteSampler { + implements DiscreteSampler, SharedStateSampler<PoissonSampler> { /** * Value for switching sampling algorithm. @@ -79,6 +80,23 @@ public class PoissonSampler new LargeMeanPoissonSampler(rng, mean); } + /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. 
+ */ + private PoissonSampler(UniformRandomProvider rng, + PoissonSampler source) { + super(null); + + if (source.poissonSamplerDelegate instanceof SmallMeanPoissonSampler) { + poissonSamplerDelegate = + ((SmallMeanPoissonSampler)source.poissonSamplerDelegate).withUniformRandomProvider(rng); + } else { + poissonSamplerDelegate = + ((LargeMeanPoissonSampler)source.poissonSamplerDelegate).withUniformRandomProvider(rng); + } + } + /** {@inheritDoc} */ @Override public int sample() { @@ -90,4 +108,10 @@ public class PoissonSampler public String toString() { return poissonSamplerDelegate.toString(); } + + /** {@inheritDoc} */ + @Override + public PoissonSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new PoissonSampler(rng, this); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/RejectionInversionZipfSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/RejectionInversionZipfSampler.java index 3a477e3..112f3fe 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/RejectionInversionZipfSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/RejectionInversionZipfSampler.java @@ -18,6 +18,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * Implementation of the <a href="https://en.wikipedia.org/wiki/Zipf's_law">Zipf distribution</a>. @@ -28,7 +29,7 @@ import org.apache.commons.rng.UniformRandomProvider; */ public class RejectionInversionZipfSampler extends SamplerBase - implements DiscreteSampler { + implements DiscreteSampler, SharedStateSampler<RejectionInversionZipfSampler> { /** Threshold below which Taylor series will be used. */ private static final double TAYLOR_THRESHOLD = 1e-8; /** 1/2. 
*/ @@ -77,6 +78,21 @@ public class RejectionInversionZipfSampler } /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. + */ + private RejectionInversionZipfSampler(UniformRandomProvider rng, + RejectionInversionZipfSampler source) { + super(null); + this.rng = rng; + this.numberOfElements = source.numberOfElements; + this.exponent = source.exponent; + this.hIntegralX1 = source.hIntegralX1; + this.hIntegralNumberOfElements = source.hIntegralNumberOfElements; + this.s = source.s; + } + + /** * Rejection inversion sampling method for a discrete, bounded Zipf * distribution that is based on the method described in * <blockquote> @@ -177,6 +193,12 @@ public class RejectionInversionZipfSampler return "Rejection inversion Zipf deviate [" + rng.toString() + "]"; } + /** {@inheritDoc} */ + @Override + public RejectionInversionZipfSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new RejectionInversionZipfSampler(rng, this); + } + /** * {@code H(x)} is defined as * <ul> diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.java index 22b7864..e6a9722 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>. 
@@ -42,7 +43,7 @@ import org.apache.commons.rng.UniformRandomProvider; * @since 1.1 */ public class SmallMeanPoissonSampler - implements DiscreteSampler { + implements DiscreteSampler, SharedStateSampler<SmallMeanPoissonSampler> { /** * Pre-compute {@code Math.exp(-mean)}. * Note: This is the probability of the Poisson sample {@code P(n=0)}. @@ -74,6 +75,17 @@ public class SmallMeanPoissonSampler } } + /** + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. + */ + private SmallMeanPoissonSampler(UniformRandomProvider rng, + SmallMeanPoissonSampler source) { + this.rng = rng; + p0 = source.p0; + limit = source.limit; + } + /** {@inheritDoc} */ @Override public int sample() { @@ -96,4 +108,10 @@ public class SmallMeanPoissonSampler public String toString() { return "Small Mean Poisson deviate [" + rng.toString() + "]"; } + + /** {@inheritDoc} */ + @Override + public SmallMeanPoissonSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new SmallMeanPoissonSampler(rng, this); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSampler.java index bb9b52a..a1ba383 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSampler.java @@ -18,6 +18,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.SharedStateSampler; /** * <a href="https://en.wikipedia.org/wiki/Ziggurat_algorithm"> @@ -38,7 +39,7 @@ import org.apache.commons.rng.UniformRandomProvider; * @since 1.1 */ public class ZigguratNormalizedGaussianSampler - implements 
NormalizedGaussianSampler { + implements NormalizedGaussianSampler, SharedStateSampler<ZigguratNormalizedGaussianSampler> { /** Start of tail. */ private static final double R = 3.442619855899; /** Inverse of R. */ @@ -160,4 +161,10 @@ public class ZigguratNormalizedGaussianSampler private static double gauss(double x) { return Math.exp(-0.5 * x * x); } + + /** {@inheritDoc} */ + @Override + public ZigguratNormalizedGaussianSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new ZigguratNormalizedGaussianSampler(rng); + } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/CollectionSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/CollectionSamplerTest.java index 3ea70c1..9d83631 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/CollectionSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/CollectionSamplerTest.java @@ -17,10 +17,12 @@ package org.apache.commons.rng.sampling; import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import org.junit.Assert; import org.junit.Test; - +import org.apache.commons.rng.UniformRandomProvider; import org.apache.commons.rng.simple.RandomSource; /** @@ -53,4 +55,30 @@ public class CollectionSamplerTest { new CollectionSampler<String>(RandomSource.create(RandomSource.MT), new ArrayList<String>()); } + + /** + * Test the SharedStateSampler implementation. 
+ */ + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final List<String> list = Arrays.asList("Apache", "Commons", "RNG"); + final CollectionSampler<String> sampler1 = + new CollectionSampler<String>(rng1, list); + final CollectionSampler<String> sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence( + new RandomAssert.Sampler<String>() { + @Override + public String sample() { + return sampler1.sample(); + } + }, + new RandomAssert.Sampler<String>() { + @Override + public String sample() { + return sampler2.sample(); + } + }); + } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/CombinationSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/CombinationSamplerTest.java index 3472a89..928ff1b 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/CombinationSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/CombinationSamplerTest.java @@ -113,6 +113,33 @@ public class CombinationSamplerTest { new CombinationSampler(rng, n, k); } + /** + * Test the SharedStateSampler implementation. 
+ */ + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final int n = 17; + final int k = 3; + final CombinationSampler sampler1 = + new CombinationSampler(rng1, n, k); + final CombinationSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence( + new RandomAssert.Sampler<int[]>() { + @Override + public int[] sample() { + return sampler1.sample(); + } + }, + new RandomAssert.Sampler<int[]>() { + @Override + public int[] sample() { + return sampler2.sample(); + } + }); + } + //// Support methods. /** diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java index cae03cf..78a4391 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java @@ -152,4 +152,32 @@ public class DiscreteProbabilityCollectionSamplerTest { // Test the two samples are different items Assert.assertNotSame("Item1 and 2 should be different", item1, item2); } + + /** + * Test the SharedStateSampler implementation. 
+ */ + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final List<Double> items = Arrays.asList(new Double[] {1d, 2d, 3d, 4d}); + final DiscreteProbabilityCollectionSampler<Double> sampler1 = + new DiscreteProbabilityCollectionSampler<Double>(rng1, + items, + new double[] {0.1, 0.2, 0.3, 0.4}); + final DiscreteProbabilityCollectionSampler<Double> sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence( + new RandomAssert.Sampler<Double>() { + @Override + public Double sample() { + return sampler1.sample(); + } + }, + new RandomAssert.Sampler<Double>() { + @Override + public Double sample() { + return sampler2.sample(); + } + }); + } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/PermutationSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/PermutationSamplerTest.java index 21c9c42..9b88fce 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/PermutationSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/PermutationSamplerTest.java @@ -186,6 +186,33 @@ public class PermutationSamplerTest { Assert.assertTrue(ok); } + /** + * Test the SharedStateSampler implementation. 
+ */ + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final int n = 17; + final int k = 13; + final PermutationSampler sampler1 = + new PermutationSampler(rng1, n, k); + final PermutationSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence( + new RandomAssert.Sampler<int[]>() { + @Override + public int[] sample() { + return sampler1.sample(); + } + }, + new RandomAssert.Sampler<int[]>() { + @Override + public int[] sample() { + return sampler2.sample(); + } + }); + } + //// Support methods. private void runSampleChiSquareTest(int n, diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/RandomAssert.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/RandomAssert.java index 113471e..5860501 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/RandomAssert.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/RandomAssert.java @@ -97,7 +97,8 @@ public final class RandomAssert { final T value1 = sampler1.sample(); final T value2 = sampler2.sample(); if (isArray(value1) && isArray(value2)) { - Assert.assertArrayEquals((Object[]) value1, (Object[]) value2); + // JUnit assertArrayEquals will handle nested primitive arrays + Assert.assertArrayEquals(new Object[] {value1}, new Object[] {value2}); } else { Assert.assertEquals(value1, value2); } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/UnitSphereSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/UnitSphereSamplerTest.java index 311f87e..3c9df0b 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/UnitSphereSamplerTest.java +++ 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/UnitSphereSamplerTest.java @@ -115,6 +115,32 @@ public class UnitSphereSamplerTest { } /** + * Test the SharedStateSampler implementation. + */ + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final int n = 3; + final UnitSphereSampler sampler1 = + new UnitSphereSampler(n, rng1); + final UnitSphereSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence( + new RandomAssert.Sampler<double[]>() { + @Override + public double[] sample() { + return sampler1.nextVector(); + } + }, + new RandomAssert.Sampler<double[]>() { + @Override + public double[] sample() { + return sampler2.nextVector(); + } + }); + } + + /** * @return the length (L2-norm) of given vector. */ private static double length(double[] vector) { diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSamplerTest.java index 1dceaeb..739a178 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSamplerTest.java @@ -17,6 +17,8 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.RestorableUniformRandomProvider; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; import org.apache.commons.rng.simple.RandomSource; import org.junit.Test; @@ -36,4 +38,18 @@ public class AhrensDieterExponentialSamplerTest { final AhrensDieterExponentialSampler sampler = new 
AhrensDieterExponentialSampler(rng, mean); } + + /** + * Test the SharedStateSampler implementation. + */ + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final double mean = 1.23; + final AhrensDieterExponentialSampler sampler1 = + new AhrensDieterExponentialSampler(rng1, mean); + final AhrensDieterExponentialSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSamplerTest.java index eee689b..11ba878 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSamplerTest.java @@ -17,6 +17,8 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.RestorableUniformRandomProvider; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; import org.apache.commons.rng.simple.RandomSource; import org.junit.Test; @@ -51,4 +53,35 @@ public class AhrensDieterMarsagliaTsangGammaSamplerTest { final AhrensDieterMarsagliaTsangGammaSampler sampler = new AhrensDieterMarsagliaTsangGammaSampler(rng, alpha, theta); } + + /** + * Test the SharedStateSampler implementation. + */ + @Test + public void testSharedStateSamplerWithAlphaBelowOne() { + testSharedStateSampler(0.5, 3.456); + } + + /** + * Test the SharedStateSampler implementation. 
+ */ + @Test + public void testSharedStateSamplerWithAlphaAboveOne() { + testSharedStateSampler(3.5, 3.456); + } + + /** + * Test the SharedStateSampler implementation. + * + * @param alpha Alpha. + * @param theta Theta. + */ + private static void testSharedStateSampler(double alpha, double theta) { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final AhrensDieterMarsagliaTsangGammaSampler sampler1 = + new AhrensDieterMarsagliaTsangGammaSampler(rng1, alpha, theta); + final AhrensDieterMarsagliaTsangGammaSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/AliasMethodDiscreteSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/AliasMethodDiscreteSamplerTest.java index 19f5451..e162f19 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/AliasMethodDiscreteSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/AliasMethodDiscreteSamplerTest.java @@ -20,6 +20,7 @@ import org.apache.commons.math3.distribution.BinomialDistribution; import org.apache.commons.math3.distribution.PoissonDistribution; import org.apache.commons.math3.stat.inference.ChiSquareTest; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; import org.apache.commons.rng.simple.RandomSource; import org.junit.Assert; import org.junit.Test; @@ -248,4 +249,35 @@ public class AliasMethodDiscreteSamplerTest { // Pass if we cannot reject null hypothesis that the distributions are the same. 
Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001)); } + + /** + * Test the SharedStateSampler implementation for the specialised power-of-2 table size. + */ + @Test + public void testSharedStateSamplerWithPowerOf2TableSize() { + testSharedStateSampler(new double[] {0.1, 0.2, 0.3, 0.4}); + } + + /** + * Test the SharedStateSampler implementation for the generic non power-of-2 table size. + */ + @Test + public void testSharedStateSamplerWithNonPowerOf2TableSize() { + testSharedStateSampler(new double[] {0.1, 0.2, 0.3}); + } + + /** + * Test the SharedStateSampler implementation. + * + * @param probabilities The probabilities. + */ + private static void testSharedStateSampler(double[] probabilities) { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + // Use negative alpha to disable padding + final AliasMethodDiscreteSampler sampler1 = + AliasMethodDiscreteSampler.create(rng1, probabilities, -1); + final AliasMethodDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/BoxMullerNormalisedGaussianSamplerTest.java similarity index 57% copy from commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java copy to commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/BoxMullerNormalisedGaussianSamplerTest.java index a0ab018..191cbc5 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java +++ 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/BoxMullerNormalisedGaussianSamplerTest.java @@ -17,22 +17,24 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; import org.apache.commons.rng.simple.RandomSource; import org.junit.Test; /** - * Test for the {@link DiscreteUniformSampler}. The tests hit edge cases for the sampler. + * Test for the {@link BoxMullerNormalizedGaussianSampler}. */ -public class DiscreteUniformSamplerTest { +public class BoxMullerNormalisedGaussianSamplerTest { /** - * Test the constructor with a bad range. + * Test the SharedStateSampler implementation. */ - @Test(expected = IllegalArgumentException.class) - public void testConstructorThrowsWithLowerAboveUpper() { - final int upper = 55; - final int lower = upper + 1; - final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64); - @SuppressWarnings("unused") - DiscreteUniformSampler sampler = new DiscreteUniformSampler(rng, lower, upper); + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final BoxMullerNormalizedGaussianSampler sampler1 = + new BoxMullerNormalizedGaussianSampler(rng1); + final BoxMullerNormalizedGaussianSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ChengBetaSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ChengBetaSamplerTest.java index ea1f6ef..24325b0 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ChengBetaSamplerTest.java +++ 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ChengBetaSamplerTest.java @@ -17,6 +17,8 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.RestorableUniformRandomProvider; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; import org.apache.commons.rng.simple.RandomSource; import org.junit.Test; @@ -51,4 +53,19 @@ public class ChengBetaSamplerTest { final ChengBetaSampler sampler = new ChengBetaSampler(rng, alpha, beta); } + + /** + * Test the SharedStateSampler implementation. + */ + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final double alpha = 1.23; + final double beta = 4.56; + final ChengBetaSampler sampler1 = + new ChengBetaSampler(rng1, alpha, beta); + final ChengBetaSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ContinuousUniformSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ContinuousUniformSamplerTest.java index ba69c3b..06bf53f 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ContinuousUniformSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ContinuousUniformSamplerTest.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; import org.apache.commons.rng.simple.RandomSource; import org.junit.Assert; import org.junit.Test; @@ -47,4 +48,19 @@ public class ContinuousUniformSamplerTest { 
Assert.assertTrue("Value not in range", value >= min && value <= max); } } + + /** + * Test the SharedStateSampler implementation. + */ + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final double low = 1.23; + final double high = 4.56; + final ContinuousUniformSampler sampler1 = + new ContinuousUniformSampler(rng1, low, high); + final ContinuousUniformSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java index a0ab018..3a0a34c 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; import org.apache.commons.rng.simple.RandomSource; import org.junit.Test; @@ -35,4 +36,35 @@ public class DiscreteUniformSamplerTest { @SuppressWarnings("unused") DiscreteUniformSampler sampler = new DiscreteUniformSampler(rng, lower, upper); } + + /** + * Test the SharedStateSampler implementation. + */ + @Test + public void testSharedStateSamplerWithSmallRange() { + testSharedStateSampler(5, 67); + } + + /** + * Test the SharedStateSampler implementation. + */ + @Test + public void testSharedStateSamplerWithLargeRange() { + testSharedStateSampler(-99999999, Integer.MAX_VALUE); + } + + /** + * Test the SharedStateSampler implementation. 
+ * + * @param lower Lower. + * @param upper Upper. + */ + private static void testSharedStateSampler(int lower, int upper) { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final DiscreteUniformSampler sampler1 = + new DiscreteUniformSampler(rng1, lower, upper); + final DiscreteUniformSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GaussianSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GaussianSamplerTest.java index 0cc0a61..63bbf88 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GaussianSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GaussianSamplerTest.java @@ -17,6 +17,9 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.RestorableUniformRandomProvider; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; +import org.apache.commons.rng.sampling.SharedStateSampler; import org.apache.commons.rng.simple.RandomSource; import org.junit.Test; @@ -38,4 +41,75 @@ public class GaussianSamplerTest { final GaussianSampler sampler = new GaussianSampler(gauss, mean, standardDeviation); } + + /** + * Test the SharedStateSampler implementation. 
+ */ + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final NormalizedGaussianSampler gauss = new ZigguratNormalizedGaussianSampler(rng1); + final double mean = 1.23; + final double standardDeviation = 4.56; + final GaussianSampler sampler1 = + new GaussianSampler(gauss, mean, standardDeviation); + final GaussianSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } + + /** + * Test the SharedStateSampler implementation throws if the underlying sampler is + * not a SharedStateSampler. + */ + @Test(expected = UnsupportedOperationException.class) + public void testSharedStateSamplerThrowsIfUnderlyingSamplerDoesNotShareState() { + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final NormalizedGaussianSampler gauss = new NormalizedGaussianSampler() { + @Override + public double sample() { + return 0; + } + }; + final double mean = 1.23; + final double standardDeviation = 4.56; + final GaussianSampler sampler1 = + new GaussianSampler(gauss, mean, standardDeviation); + sampler1.withUniformRandomProvider(rng2); + } + + /** + * Test the SharedStateSampler implementation throws if the underlying sampler is + * a SharedStateSampler that returns an incorrect type. 
+ */ + @Test(expected = ClassCastException.class) + public void testSharedStateSamplerThrowsIfUnderlyingSamplerReturnsWrongSharedState() { + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final NormalizedGaussianSampler gauss = new BadSharedStateNormalizedGaussianSampler(); + final double mean = 1.23; + final double standardDeviation = 4.56; + final GaussianSampler sampler1 = + new GaussianSampler(gauss, mean, standardDeviation); + sampler1.withUniformRandomProvider(rng2); + } + + /** + * Test class to return an incorrect sampler from the SharedStateSampler method. + * + * <p>Note that due to type erasure the type returned by the SharedStateSampler is not + * available at run-time and the GaussianSampler has to assume it is the correct type.</p> + */ + private static class BadSharedStateNormalizedGaussianSampler + implements NormalizedGaussianSampler, SharedStateSampler<Integer> { + @Override + public double sample() { + return 0; + } + + @Override + public Integer withUniformRandomProvider(UniformRandomProvider rng) { + // Something that is not a NormalizedGaussianSampler + return Integer.valueOf(44); + } + } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GeometricSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GeometricSamplerTest.java index b3bab01..66eaee4 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GeometricSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GeometricSamplerTest.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; import org.apache.commons.rng.simple.RandomSource; import org.junit.Assert; import org.junit.Test; @@ -111,4 +112,35 @@ public class GeometricSamplerTest { 
@SuppressWarnings("unused") final GeometricSampler sampler = new GeometricSampler(unusedRng, probabilityOfSuccess); } + + /** + * Test the SharedStateSampler implementation. + */ + @Test + public void testSharedStateSampler() { + testSharedStateSampler(0.5); + } + + /** + * Test the SharedStateSampler implementation with the edge case when the probability of + * success is {@code 1.0}. + */ + @Test + public void testSharedStateSamplerWithProbabilityOfSuccessOne() { + testSharedStateSampler(1.0); + } + + /** + * Test the SharedStateSampler implementation. + * + * @param probabilityOfSuccess Probability of success. + */ + private static void testSharedStateSampler(double probabilityOfSuccess) { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final GeometricSampler sampler1 = + new GeometricSampler(rng1, probabilityOfSuccess); + final GeometricSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSamplerTest.java index f312c93..ec89352 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSamplerTest.java @@ -20,6 +20,7 @@ import org.apache.commons.math3.distribution.BinomialDistribution; import org.apache.commons.math3.distribution.PoissonDistribution; import org.apache.commons.math3.stat.inference.ChiSquareTest; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; import 
org.apache.commons.rng.simple.RandomSource; import org.junit.Assert; import org.junit.Test; @@ -234,4 +235,18 @@ public class GuideTableDiscreteSamplerTest { // Pass if we cannot reject null hypothesis that the distributions are the same. Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001)); } + + /** + * Test the SharedStateSampler implementation. + */ + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final double[] probabilities = {0.1, 0, 0.2, 0.3, 0.1, 0.3, 0}; + final GuideTableDiscreteSampler sampler1 = + new GuideTableDiscreteSampler(rng1, probabilities); + final GuideTableDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InverseTransformContinuousSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InverseTransformContinuousSamplerTest.java new file mode 100644 index 0000000..6ddbb07 --- /dev/null +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InverseTransformContinuousSamplerTest.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; +import org.apache.commons.rng.simple.RandomSource; +import org.junit.Test; + +/** + * Test for the {@link InverseTransformContinuousSampler}. + */ +public class InverseTransformContinuousSamplerTest { + /** + * Test the SharedStateSampler implementation. + */ + @Test + public void testSharedStateSampler() { + ContinuousInverseCumulativeProbabilityFunction function = + new ContinuousInverseCumulativeProbabilityFunction() { + @Override + public double inverseCumulativeProbability(double p) { + return 456.99 * p; + } + }; + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final InverseTransformContinuousSampler sampler1 = + new InverseTransformContinuousSampler(rng1, function); + final InverseTransformContinuousSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } +} diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InverseTransformDiscreteSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InverseTransformDiscreteSamplerTest.java new file mode 100644 index 0000000..b9d3f89 --- /dev/null +++ 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InverseTransformDiscreteSamplerTest.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; +import org.apache.commons.rng.simple.RandomSource; +import org.junit.Test; + +/** + * Test for the {@link InverseTransformDiscreteSampler}. + */ +public class InverseTransformDiscreteSamplerTest { + /** + * Test the SharedStateSampler implementation. 
+ */ + @Test + public void testSharedStateSampler() { + DiscreteInverseCumulativeProbabilityFunction function = + new DiscreteInverseCumulativeProbabilityFunction() { + @Override + public int inverseCumulativeProbability(double p) { + return (int) Math.round(789 * p); + } + }; + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final InverseTransformDiscreteSampler sampler1 = + new InverseTransformDiscreteSampler(rng1, function); + final InverseTransformDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } +} diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InverseTransformParetoSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InverseTransformParetoSamplerTest.java index 0119ed2..b4314c7 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InverseTransformParetoSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/InverseTransformParetoSamplerTest.java @@ -17,6 +17,8 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.RestorableUniformRandomProvider; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; import org.apache.commons.rng.simple.RandomSource; import org.junit.Test; @@ -31,11 +33,11 @@ public class InverseTransformParetoSamplerTest { public void testConstructorThrowsWithZeroScale() { final RestorableUniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64); - final double shape = 1; final double scale = 0; + final double shape = 1; @SuppressWarnings("unused") final InverseTransformParetoSampler sampler = - new InverseTransformParetoSampler(rng, shape, scale); + new 
InverseTransformParetoSampler(rng, scale, shape); } /** @@ -45,10 +47,25 @@ public class InverseTransformParetoSamplerTest { public void testConstructorThrowsWithZeroShape() { final RestorableUniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64); - final double shape = 0; final double scale = 1; + final double shape = 0; @SuppressWarnings("unused") final InverseTransformParetoSampler sampler = - new InverseTransformParetoSampler(rng, shape, scale); + new InverseTransformParetoSampler(rng, scale, shape); + } + + /** + * Test the SharedStateSampler implementation. + */ + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final double scale = 1.23; + final double shape = 4.56; + final InverseTransformParetoSampler sampler1 = + new InverseTransformParetoSampler(rng1, scale, shape); + final InverseTransformParetoSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSamplerTest.java index 013c1c0..04b8e7e 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSamplerTest.java @@ -18,6 +18,8 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.math3.distribution.PoissonDistribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; +import org.apache.commons.rng.simple.RandomSource; import org.junit.Assert; import 
org.junit.Test; @@ -145,6 +147,20 @@ public class KempSmallMeanPoissonSamplerTest { } /** + * Test the SharedStateSampler implementation. + */ + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final double mean = 1.23; + final KempSmallMeanPoissonSampler sampler1 = + new KempSmallMeanPoissonSampler(rng1, mean); + final KempSmallMeanPoissonSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } + + /** * Test a sample from the Poisson distribution at the given cumulative probability. * * @param rng the fixed random generator backing the sampler diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java index 63c1f55..89022d1 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java @@ -18,6 +18,8 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.RandomProviderState; import org.apache.commons.rng.RestorableUniformRandomProvider; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; import org.apache.commons.rng.sampling.distribution.LargeMeanPoissonSampler.LargeMeanPoissonSamplerState; import org.apache.commons.rng.simple.RandomSource; import org.junit.Assert; @@ -141,4 +143,35 @@ public class LargeMeanPoissonSamplerTest { Assert.assertEquals("Not the same sample", s1.sample(), s2.sample()); } } + + /** + * Test the SharedStateSampler implementation. 
+ */ + @Test + public void testSharedStateSamplerWithFractionalMean() { + testSharedStateSampler(34.5); + } + + /** + * Test the SharedStateSampler implementation with the edge case when there is no + * small mean sampler (i.e. no fraction part to the mean). + */ + @Test + public void testSharedStateSamplerWithIntegerMean() { + testSharedStateSampler(34.0); + } + + /** + * Test the SharedStateSampler implementation. + * + * @param mean Mean. + */ + private static void testSharedStateSampler(double mean) { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final LargeMeanPoissonSampler sampler1 = + new LargeMeanPoissonSampler(rng1, mean); + final LargeMeanPoissonSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LogNormalSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LogNormalSamplerTest.java index 104e552..e708fc4 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LogNormalSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LogNormalSamplerTest.java @@ -17,6 +17,9 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.RestorableUniformRandomProvider; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; +import org.apache.commons.rng.sampling.SharedStateSampler; import org.apache.commons.rng.simple.RandomSource; import org.junit.Test; @@ -53,4 +56,75 @@ public class LogNormalSamplerTest { final LogNormalSampler sampler = new LogNormalSampler(gauss, scale, shape); } + + /** + * Test the SharedStateSampler implementation. 
+ */ + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final NormalizedGaussianSampler gauss = new ZigguratNormalizedGaussianSampler(rng1); + final double scale = 1.23; + final double shape = 4.56; + final LogNormalSampler sampler1 = + new LogNormalSampler(gauss, scale, shape); + final LogNormalSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } + + /** + * Test the SharedStateSampler implementation throws if the underlying sampler is + * not a SharedStateSampler. + */ + @Test(expected = UnsupportedOperationException.class) + public void testSharedStateSamplerThrowsIfUnderlyingSamplerDoesNotShareState() { + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final NormalizedGaussianSampler gauss = new NormalizedGaussianSampler() { + @Override + public double sample() { + return 0; + } + }; + final double scale = 1.23; + final double shape = 4.56; + final LogNormalSampler sampler1 = + new LogNormalSampler(gauss, scale, shape); + sampler1.withUniformRandomProvider(rng2); + } + + /** + * Test the SharedStateSampler implementation throws if the underlying sampler is + * a SharedStateSampler that returns an incorrect type. 
+ */ + @Test(expected = ClassCastException.class) + public void testSharedStateSamplerThrowsIfUnderlyingSamplerReturnsWrongSharedState() { + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final NormalizedGaussianSampler gauss = new BadSharedStateNormalizedGaussianSampler(); + final double scale = 1.23; + final double shape = 4.56; + final LogNormalSampler sampler1 = + new LogNormalSampler(gauss, scale, shape); + sampler1.withUniformRandomProvider(rng2); + } + + /** + * Test class to return an incorrect sampler from the SharedStateSampler method. + * + * <p>Note that due to type erasure the type returned by the SharedStateSampler is not + * available at run-time and the LogNormalSampler has to assume it is the correct type.</p> + */ + private static class BadSharedStateNormalizedGaussianSampler + implements NormalizedGaussianSampler, SharedStateSampler<Integer> { + @Override + public double sample() { + return 0; + } + + @Override + public Integer withUniformRandomProvider(UniformRandomProvider rng) { + // Something that is not a NormalizedGaussianSampler + return Integer.valueOf(44); + } + } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaNormalisedGaussianSamplerTest.java similarity index 57% copy from commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java copy to commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaNormalisedGaussianSamplerTest.java index a0ab018..89b7339 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaNormalisedGaussianSamplerTest.java @@ -17,22 +17,24 @@ package 
org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; import org.apache.commons.rng.simple.RandomSource; import org.junit.Test; /** - * Test for the {@link DiscreteUniformSampler}. The tests hit edge cases for the sampler. + * Test for the {@link MarsagliaNormalizedGaussianSampler}. */ -public class DiscreteUniformSamplerTest { +public class MarsagliaNormalisedGaussianSamplerTest { /** - * Test the constructor with a bad range. + * Test the SharedStateSampler implementation. */ - @Test(expected = IllegalArgumentException.class) - public void testConstructorThrowsWithLowerAboveUpper() { - final int upper = 55; - final int lower = upper + 1; - final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64); - @SuppressWarnings("unused") - DiscreteUniformSampler sampler = new DiscreteUniformSampler(rng, lower, upper); + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final MarsagliaNormalizedGaussianSampler sampler1 = + new MarsagliaNormalizedGaussianSampler(rng1); + final MarsagliaNormalizedGaussianSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSamplerTest.java index 5ec43df..37164f7 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSamplerTest.java @@ -20,6 +20,7 
@@ import org.apache.commons.math3.stat.inference.ChiSquareTest; import org.apache.commons.rng.UniformRandomProvider; import org.apache.commons.rng.core.source32.IntProvider; import org.apache.commons.rng.core.source64.SplitMix64; +import org.apache.commons.rng.sampling.RandomAssert; import org.apache.commons.rng.simple.RandomSource; import org.junit.Assert; import org.junit.Test; @@ -162,8 +163,8 @@ public class MarsagliaTsangWangDiscreteSamplerTest { /** * Creates the probabilities using zero padding below the values. * - * @param offset the offset - * @param prob the probability values + * @param offset Offset for first given probability (i.e. the zero padding size). + * @param prob Probability values. * @return the zero-padded probabilities */ private static double[] createProbabilities(int offset, int[] prob) { @@ -241,7 +242,7 @@ public class MarsagliaTsangWangDiscreteSamplerTest { * tests the limits described in the class Javadoc is correct. * * @param k Base is 2^k. - * @param expectedLimitMB the expected limit in MB + * @param expectedLimitMB Expected limit in MB. */ private static void checkStorageRequirements(int k, double expectedLimitMB) { // Worst case scenario is a uniform distribution of 2^k samples each with the highest @@ -283,8 +284,8 @@ public class MarsagliaTsangWangDiscreteSamplerTest { /** * Gets the k<sup>th</sup> base 64 digit of {@code m}. * - * @param m the value m. - * @param k the digit. + * @param m Value m. + * @param k Digit. * @return the base 64 digit */ private static int getBase64Digit(int m, int k) { @@ -489,8 +490,8 @@ public class MarsagliaTsangWangDiscreteSamplerTest { /** * Gets the p(0) value for the Binomial distribution. * - * @param trials the trials - * @param probabilityOfSuccess the probability of success + * @param trials Number of trials. + * @param probabilityOfSuccess Probability of success. 
* @return the p(0) value */ private static double getBinomialP0(int trials, double probabilityOfSuccess) { @@ -575,6 +576,80 @@ public class MarsagliaTsangWangDiscreteSamplerTest { } /** + * Test the SharedStateSampler implementation with the 8-bit storage implementation. + */ + @Test + public void testSharedStateSamplerWith8bitStorage() { + testSharedStateSampler(0, new int[] {1, 2, 3, 4, 5}); + } + + /** + * Test the SharedStateSampler implementation with the 16-bit storage implementation. + */ + @Test + public void testSharedStateSamplerWith16bitStorage() { + testSharedStateSampler(1 << 8, new int[] {1, 2, 3, 4, 5}); + } + + /** + * Test the SharedStateSampler implementation with the 32-bit storage implementation. + */ + @Test + public void testSharedStateSamplerWith32bitStorage() { + testSharedStateSampler(1 << 16, new int[] {1, 2, 3, 4, 5}); + } + + /** + * Test the SharedStateSampler implementation using zero padded probabilities to force + * different storage implementations. + * + * @param offset Offset for first given probability (i.e. the zero padding size). + * @param prob Probability values. + */ + private static void testSharedStateSampler(int offset, int[] prob) { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + double[] probabilities = createProbabilities(offset, prob); + final MarsagliaTsangWangDiscreteSampler sampler1 = + MarsagliaTsangWangDiscreteSampler.createDiscreteDistribution(rng1, probabilities); + final MarsagliaTsangWangDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } + + /** + * Test the SharedStateSampler implementation with a Binomial distribution with a fixed result. 
+ */ + @Test + public void testSharedStateSamplerWithFixedBinomialDistribution() { + testSharedStateSampler(10, 1.0); + } + + /** + * Test the SharedStateSampler implementation with a Binomial distribution that requires + * inversion (probability of success > 0.5). + */ + @Test + public void testSharedStateSamplerWithInvertedBinomialDistribution() { + testSharedStateSampler(10, 0.999); + } + + /** + * Test the SharedStateSampler implementation using a binomial distribution to exercise + * special implementations. + * + * @param trials Number of trials. + * @param probabilityOfSuccess Probability of success. + */ + private static void testSharedStateSampler(int trials, double probabilityOfSuccess) { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final MarsagliaTsangWangDiscreteSampler sampler1 = + MarsagliaTsangWangDiscreteSampler.createBinomialDistribution(rng1, trials, probabilityOfSuccess); + final MarsagliaTsangWangDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } + + /** * Return a fixed sequence of {@code int} output. */ private static class FixedSequenceIntProvider extends IntProvider { diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerTest.java new file mode 100644 index 0000000..9899b61 --- /dev/null +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerTest.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; +import org.apache.commons.rng.simple.RandomSource; +import org.junit.Test; + +/** + * Test for the {@link PoissonSampler}. The tests check the + * SharedStateSampler implementation. + */ +public class PoissonSamplerTest { + /** + * Test the SharedStateSampler implementation with a mean below 40. + */ + @Test + public void testSharedStateSamplerWithSmallMean() { + testSharedStateSampler(34.5); + } + + /** + * Test the SharedStateSampler implementation with a mean above 40. + */ + @Test + public void testSharedStateSamplerWithLargeMean() { + testSharedStateSampler(67.8); + } + + /** + * Test the SharedStateSampler implementation. + * + * @param mean Mean. 
+ */ + private static void testSharedStateSampler(double mean) { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final PoissonSampler sampler1 = + new PoissonSampler(rng1, mean); + final PoissonSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } +} diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/RejectionInversionZipfSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/RejectionInversionZipfSamplerTest.java index d022e3e..d7e3654 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/RejectionInversionZipfSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/RejectionInversionZipfSamplerTest.java @@ -17,6 +17,8 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.RestorableUniformRandomProvider; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; import org.apache.commons.rng.simple.RandomSource; import org.junit.Test; @@ -51,4 +53,19 @@ public class RejectionInversionZipfSamplerTest { final RejectionInversionZipfSampler sampler = new RejectionInversionZipfSampler(rng, numberOfElements, exponent); } + + /** + * Test the SharedStateSampler implementation. 
+ */ + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final int numberOfElements = 7; + final double exponent = 1.23; + final RejectionInversionZipfSampler sampler1 = + new RejectionInversionZipfSampler(rng1, numberOfElements, exponent); + final RejectionInversionZipfSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSamplerTest.java index 5b22785..8d34fac 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSamplerTest.java @@ -17,6 +17,7 @@ package org.apache.commons.rng.sampling.distribution; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; import org.apache.commons.rng.simple.RandomSource; import org.junit.Assert; import org.junit.Test; @@ -77,4 +78,18 @@ public class SmallMeanPoissonSamplerTest { Assert.assertEquals(expected, sampler.sample()); } } + + /** + * Test the SharedStateSampler implementation. 
+ */ + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final double mean = 1.23; + final SmallMeanPoissonSampler sampler1 = + new SmallMeanPoissonSampler(rng1, mean); + final SmallMeanPoissonSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSamplerTest.java index e9a3657..d4a2e16 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSamplerTest.java @@ -18,6 +18,8 @@ package org.apache.commons.rng.sampling.distribution; import org.junit.Test; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; +import org.apache.commons.rng.simple.RandomSource; /** * Test for {@link ZigguratNormalizedGaussianSampler}. @@ -45,4 +47,17 @@ public class ZigguratNormalizedGaussianSamplerTest { // Infinite loop (in v1.1). new ZigguratNormalizedGaussianSampler(bad).sample(); } + + /** + * Test the SharedStateSampler implementation. 
+ */ + @Test + public void testSharedStateSampler() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final ZigguratNormalizedGaussianSampler sampler1 = + new ZigguratNormalizedGaussianSampler(rng1); + final ZigguratNormalizedGaussianSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } }