This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-rng.git
commit 9d7bef003fb0278a92a91b55910af1698850e9f8 Author: Alex Herbert <[email protected]> AuthorDate: Wed Nov 22 14:37:17 2023 +0000 SpotBugs: Prevent Finalizer attack using validated constructor arguments Solves CT_CONSTRUCTOR_THROW. All arguments to constructors are validated before passing to a private constructor that cannot raise an exception. This moves many validation methods to InternalUtils to provide consistent behaviour. --- .../apache/commons/rng/core/source64/TwoCmres.java | 24 +++- .../commons/rng/sampling/CollectionSampler.java | 34 ++++-- .../DiscreteProbabilityCollectionSampler.java | 130 +++++++++++++++------ .../AhrensDieterExponentialSampler.java | 19 ++- .../AhrensDieterMarsagliaTsangGammaSampler.java | 20 +++- .../distribution/BoxMullerGaussianSampler.java | 15 ++- .../sampling/distribution/ChengBetaSampler.java | 17 +-- .../sampling/distribution/DirichletSampler.java | 5 +- .../rng/sampling/distribution/GaussianSampler.java | 33 +++--- .../distribution/GuideTableDiscreteSampler.java | 12 +- .../rng/sampling/distribution/InternalUtils.java | 126 +++++++++++++++++--- .../InverseTransformParetoSampler.java | 20 +++- .../distribution/KempSmallMeanPoissonSampler.java | 4 +- .../distribution/LargeMeanPoissonSampler.java | 49 +++++--- .../rng/sampling/distribution/LevySampler.java | 4 +- .../sampling/distribution/LogNormalSampler.java | 15 ++- .../MarsagliaTsangWangDiscreteSampler.java | 8 +- .../rng/sampling/distribution/PoissonSampler.java | 12 +- .../sampling/distribution/PoissonSamplerCache.java | 20 +++- .../RejectionInversionZipfSampler.java | 4 +- .../distribution/SmallMeanPoissonSampler.java | 48 ++++++-- .../rng/sampling/distribution/StableSampler.java | 9 +- .../rng/sampling/distribution/ZigguratSampler.java | 5 +- src/main/resources/pmd/pmd-ruleset.xml | 2 +- .../resources/spotbugs/spotbugs-exclude-filter.xml | 13 +++ 25 files changed, 452 insertions(+), 196 deletions(-) diff --git 
a/commons-rng-core/src/main/java/org/apache/commons/rng/core/source64/TwoCmres.java b/commons-rng-core/src/main/java/org/apache/commons/rng/core/source64/TwoCmres.java index 94ecd18b..52705bad 100644 --- a/commons-rng-core/src/main/java/org/apache/commons/rng/core/source64/TwoCmres.java +++ b/commons-rng-core/src/main/java/org/apache/commons/rng/core/source64/TwoCmres.java @@ -54,14 +54,10 @@ public class TwoCmres extends LongProvider { * @param seed Initial seed. * @param x First subcycle generator. * @param y Second subcycle generator. - * @throws IllegalArgumentException if {@code x == y}. */ private TwoCmres(int seed, Cmres x, Cmres y) { - if (x.equals(y)) { - throw new IllegalArgumentException("Subcycle generators must be different"); - } this.x = x; this.y = y; setSeedInternal(seed); @@ -91,7 +87,7 @@ public class TwoCmres extends LongProvider { public TwoCmres(Integer seed, int i, int j) { - this(seed, FACTORY.get(i), FACTORY.get(j)); + this(seed, FACTORY.getIfDifferent(i, j), FACTORY.get(j)); } /** {@inheritDoc} */ @@ -274,6 +270,24 @@ public class TwoCmres extends LongProvider { return TABLE.get(index); } + /** + * Get the generator at {@code index} if the {@code other} index is different. + * + * <p>This method exists to raise an exception before invocation of the + * private constructor; this mitigates Finalizer attacks + * (see SpotBugs CT_CONSTRUCTOR_THROW). + * + * @param index Index into the list of available generators. + * @param other Other index. + * @return the subcycle generator entry at index {@code index}. + */ + Cmres getIfDifferent(int index, int other) { + if (index == other) { + throw new IllegalArgumentException("Subcycle generators must be different"); + } + return get(index); + } + /** * Adds an entry to the {@link Factory#TABLE}. 
* diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CollectionSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CollectionSampler.java index 44bfee77..183cfae4 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CollectionSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CollectionSampler.java @@ -48,22 +48,17 @@ public class CollectionSampler<T> implements SharedStateObjectSampler<T> { */ public CollectionSampler(UniformRandomProvider rng, Collection<T> collection) { - if (collection.isEmpty()) { - throw new IllegalArgumentException("Empty collection"); - } - - this.rng = rng; - items = new ArrayList<>(collection); + this(rng, toList(collection)); } /** * @param rng Generator of uniformly distributed random numbers. - * @param source Source to copy. + * @param collection Collection to be sampled. */ private CollectionSampler(UniformRandomProvider rng, - CollectionSampler<T> source) { + List<T> collection) { this.rng = rng; - items = source.items; + items = collection; } /** @@ -85,6 +80,25 @@ public class CollectionSampler<T> implements SharedStateObjectSampler<T> { */ @Override public CollectionSampler<T> withUniformRandomProvider(UniformRandomProvider rng) { - return new CollectionSampler<>(rng, this); + return new CollectionSampler<>(rng, this.items); + } + + /** + * Convert the collection to a list (shallow) copy. + * + * <p>This method exists to raise an exception before invocation of the + * private constructor; this mitigates Finalizer attacks + * (see SpotBugs CT_CONSTRUCTOR_THROW). + * + * @param <T> Type of items in the collection. + * @param collection Collection. + * @return the list copy + * @throws IllegalArgumentException if {@code collection} is empty. 
+ */ + private static <T> List<T> toList(Collection<T> collection) { + if (collection.isEmpty()) { + throw new IllegalArgumentException("Empty collection"); + } + return new ArrayList<>(collection); } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java index ae2247df..325677c9 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java @@ -20,7 +20,6 @@ package org.apache.commons.rng.sampling; import java.util.List; import java.util.Map; import java.util.ArrayList; - import org.apache.commons.rng.UniformRandomProvider; import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler; import org.apache.commons.rng.sampling.distribution.SharedStateDiscreteSampler; @@ -62,23 +61,8 @@ public class DiscreteProbabilityCollectionSampler<T> implements SharedStateObjec */ public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng, Map<T, Double> collection) { - if (collection.isEmpty()) { - throw new IllegalArgumentException(EMPTY_COLLECTION); - } - - // Extract the items and probabilities - final int size = collection.size(); - items = new ArrayList<>(size); - final double[] probabilities = new double[size]; - - int count = 0; - for (final Map.Entry<T, Double> e : collection.entrySet()) { - items.add(e.getKey()); - probabilities[count++] = e.getValue(); - } - - // Delegate sampling - sampler = createSampler(rng, probabilities); + this(toList(collection), + createSampler(rng, toProbabilities(collection))); } /** @@ -100,29 +84,18 @@ public class DiscreteProbabilityCollectionSampler<T> implements SharedStateObjec public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng, List<T> 
collection, double[] probabilities) { - if (collection.isEmpty()) { - throw new IllegalArgumentException(EMPTY_COLLECTION); - } - final int len = probabilities.length; - if (len != collection.size()) { - throw new IllegalArgumentException("Size mismatch: " + - len + " != " + - collection.size()); - } - // Shallow copy the list - items = new ArrayList<>(collection); - // Delegate sampling - sampler = createSampler(rng, probabilities); + this(copyList(collection), + createSampler(rng, collection, probabilities)); } /** - * @param rng Generator of uniformly distributed random numbers. - * @param source Source to copy. + * @param items Collection to be sampled. + * @param sampler Sampler for the probabilities. */ - private DiscreteProbabilityCollectionSampler(UniformRandomProvider rng, - DiscreteProbabilityCollectionSampler<T> source) { - this.items = source.items; - this.sampler = source.sampler.withUniformRandomProvider(rng); + private DiscreteProbabilityCollectionSampler(List<T> items, + SharedStateDiscreteSampler sampler) { + this.items = items; + this.sampler = sampler; } /** @@ -142,7 +115,7 @@ public class DiscreteProbabilityCollectionSampler<T> implements SharedStateObjec */ @Override public DiscreteProbabilityCollectionSampler<T> withUniformRandomProvider(UniformRandomProvider rng) { - return new DiscreteProbabilityCollectionSampler<>(rng, this); + return new DiscreteProbabilityCollectionSampler<>(items, sampler.withUniformRandomProvider(rng)); } /** @@ -156,4 +129,85 @@ public class DiscreteProbabilityCollectionSampler<T> implements SharedStateObjec double[] probabilities) { return GuideTableDiscreteSampler.of(rng, probabilities); } + + /** + * Creates the sampler of the enumerated probability distribution. + * + * @param <T> Type of items in the collection. + * @param rng Generator of uniformly distributed random numbers. + * @param collection Collection to be sampled. + * @param probabilities Probability associated to each item. 
+ * @return the sampler + * @throws IllegalArgumentException if the number + * of items in the {@code collection} is not equal to the number of + * provided {@code probabilities}. + */ + private static <T> SharedStateDiscreteSampler createSampler(UniformRandomProvider rng, + List<T> collection, + double[] probabilities) { + if (probabilities.length != collection.size()) { + throw new IllegalArgumentException("Size mismatch: " + + probabilities.length + " != " + + collection.size()); + } + return GuideTableDiscreteSampler.of(rng, probabilities); + } + + // Validation methods exist to raise an exception before invocation of the + // private constructor; this mitigates Finalizer attacks + // (see SpotBugs CT_CONSTRUCTOR_THROW). + + /** + * Extract the items. + * + * @param <T> Type of items in the collection. + * @param collection Collection. + * @return the items + * @throws IllegalArgumentException if {@code collection} is empty. + */ + private static <T> List<T> toList(Map<T, Double> collection) { + if (collection.isEmpty()) { + throw new IllegalArgumentException(EMPTY_COLLECTION); + } + return new ArrayList<>(collection.keySet()); + } + + /** + * Extract the probabilities. + * + * @param <T> Type of items in the collection. + * @param collection Collection. + * @return the probabilities + */ + private static <T> double[] toProbabilities(Map<T, Double> collection) { + final int size = collection.size(); + final double[] probabilities = new double[size]; + int count = 0; + for (final Double e : collection.values()) { + final double probability = e; + if (probability < 0 || + Double.isInfinite(probability) || + Double.isNaN(probability)) { + throw new IllegalArgumentException("Invalid probability: " + + probability); + } + probabilities[count++] = probability; + } + return probabilities; + } + + /** + * Create a (shallow) copy of the collection. + * + * @param <T> Type of items in the collection. + * @param collection Collection. 
+ * @return the copy + * @throws IllegalArgumentException if {@code collection} is empty. + */ + private static <T> List<T> copyList(List<T> collection) { + if (collection.isEmpty()) { + throw new IllegalArgumentException(EMPTY_COLLECTION); + } + return new ArrayList<>(collection); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.java index 09097873..1a171121 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.java @@ -75,23 +75,19 @@ public class AhrensDieterExponentialSampler */ public AhrensDieterExponentialSampler(UniformRandomProvider rng, double mean) { - super(null); - if (mean <= 0) { - throw new IllegalArgumentException("mean is not strictly positive: " + mean); - } - this.rng = rng; - this.mean = mean; + // Validation before java.lang.Object constructor exits prevents partially initialized object + this(InternalUtils.requireStrictlyPositive(mean, "mean"), rng); } /** + * @param mean Mean. * @param rng Generator of uniformly distributed random numbers. - * @param source Source to copy. 
*/ - private AhrensDieterExponentialSampler(UniformRandomProvider rng, - AhrensDieterExponentialSampler source) { + private AhrensDieterExponentialSampler(double mean, + UniformRandomProvider rng) { super(null); this.rng = rng; - this.mean = source.mean; + this.mean = mean; } /** {@inheritDoc} */ @@ -149,7 +145,8 @@ public class AhrensDieterExponentialSampler */ @Override public SharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) { - return new AhrensDieterExponentialSampler(rng, this); + // Use private constructor without validation + return new AhrensDieterExponentialSampler(mean, rng); } /** diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.java index ebf6e298..8530626c 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.java @@ -78,12 +78,20 @@ public class AhrensDieterMarsagliaTsangGammaSampler BaseGammaSampler(UniformRandomProvider rng, double alpha, double theta) { - if (alpha <= 0) { - throw new IllegalArgumentException("alpha is not strictly positive: " + alpha); - } - if (theta <= 0) { - throw new IllegalArgumentException("theta is not strictly positive: " + theta); - } + // Validation before java.lang.Object constructor exits prevents partially initialized object + this(InternalUtils.requireStrictlyPositive(alpha, "alpha"), + InternalUtils.requireStrictlyPositive(theta, "theta"), + rng); + } + + /** + * @param alpha Alpha parameter of the distribution. + * @param theta Theta parameter of the distribution. + * @param rng Generator of uniformly distributed random numbers. 
+ */ + private BaseGammaSampler(double alpha, + double theta, + UniformRandomProvider rng) { this.rng = rng; this.alpha = alpha; this.theta = theta; diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/BoxMullerGaussianSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/BoxMullerGaussianSampler.java index 497362ad..1d3d6b9b 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/BoxMullerGaussianSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/BoxMullerGaussianSampler.java @@ -56,11 +56,18 @@ public class BoxMullerGaussianSampler public BoxMullerGaussianSampler(UniformRandomProvider rng, double mean, double standardDeviation) { + this(mean, InternalUtils.requireStrictlyPositiveFinite(standardDeviation, "standardDeviation"), rng); + } + + /** + * @param rng Generator of uniformly distributed random numbers. + * @param mean Mean of the Gaussian distribution. + * @param standardDeviation Standard deviation of the Gaussian distribution. 
+ */ + private BoxMullerGaussianSampler(double mean, + double standardDeviation, + UniformRandomProvider rng) { super(null); - if (standardDeviation <= 0) { - throw new IllegalArgumentException("standard deviation is not strictly positive: " + - standardDeviation); - } this.rng = rng; this.mean = mean; this.standardDeviation = standardDeviation; diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ChengBetaSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ChengBetaSampler.java index 15e9a5e9..5df1db0d 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ChengBetaSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ChengBetaSampler.java @@ -314,8 +314,15 @@ public class ChengBetaSampler public ChengBetaSampler(UniformRandomProvider rng, double alpha, double beta) { + this(of(rng, alpha, beta)); + } + + /** + * @param delegate Beta sampler. + */ + private ChengBetaSampler(SharedStateContinuousSampler delegate) { super(null); - delegate = of(rng, alpha, beta); + this.delegate = delegate; } /** {@inheritDoc} */ @@ -353,12 +360,8 @@ public class ChengBetaSampler public static SharedStateContinuousSampler of(UniformRandomProvider rng, double alpha, double beta) { - if (alpha <= 0) { - throw new IllegalArgumentException("alpha is not strictly positive: " + alpha); - } - if (beta <= 0) { - throw new IllegalArgumentException("beta is not strictly positive: " + beta); - } + InternalUtils.requireStrictlyPositive(alpha, "alpha"); + InternalUtils.requireStrictlyPositive(beta, "beta"); // Choose the algorithm. 
final double a = Math.min(alpha, beta); diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DirichletSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DirichletSampler.java index 86e6358f..db665f57 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DirichletSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DirichletSampler.java @@ -237,10 +237,7 @@ public abstract class DirichletSampler implements SharedStateObjectSampler<doubl */ private static SharedStateContinuousSampler createSampler(UniformRandomProvider rng, double alpha) { - // Negation of logic will detect NaN - if (!isNonZeroPositiveFinite(alpha)) { - throw new IllegalArgumentException("Invalid concentration: " + alpha); - } + InternalUtils.requireStrictlyPositiveFinite(alpha, "alpha concentration"); // Create a Gamma(shape=alpha, scale=1) sampler. if (alpha == 1) { // Special case diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java index 76019939..7a080e1f 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java @@ -50,27 +50,23 @@ public class GaussianSampler implements SharedStateContinuousSampler { public GaussianSampler(NormalizedGaussianSampler normalized, double mean, double standardDeviation) { - if (!(standardDeviation > 0 && standardDeviation < Double.POSITIVE_INFINITY)) { - throw new IllegalArgumentException( - "standard deviation is not strictly positive and finite: " + standardDeviation); - } - if (!Double.isFinite(mean)) { - throw new IllegalArgumentException("mean is not finite: " + mean); - } - 
this.normalized = normalized; - this.mean = mean; - this.standardDeviation = standardDeviation; + // Validation before java.lang.Object constructor exits prevents partially initialized object + this(InternalUtils.requireFinite(mean, "mean"), + InternalUtils.requireStrictlyPositiveFinite(standardDeviation, "standardDeviation"), + normalized); } /** - * @param rng Generator of uniformly distributed random numbers. - * @param source Source to copy. + * @param mean Mean of the Gaussian distribution. + * @param standardDeviation Standard deviation of the Gaussian distribution. + * @param normalized Generator of N(0,1) Gaussian distributed random numbers. */ - private GaussianSampler(UniformRandomProvider rng, - GaussianSampler source) { - this.mean = source.mean; - this.standardDeviation = source.standardDeviation; - this.normalized = InternalUtils.newNormalizedGaussianSampler(source.normalized, rng); + private GaussianSampler(double mean, + double standardDeviation, + NormalizedGaussianSampler normalized) { + this.normalized = normalized; + this.mean = mean; + this.standardDeviation = standardDeviation; } /** {@inheritDoc} */ @@ -100,7 +96,8 @@ public class GaussianSampler implements SharedStateContinuousSampler { */ @Override public SharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) { - return new GaussianSampler(rng, this); + return new GaussianSampler(mean, standardDeviation, + InternalUtils.newNormalizedGaussianSampler(normalized, rng)); } /** diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSampler.java index 3ad4218a..673674d0 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSampler.java +++ 
b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSampler.java @@ -154,16 +154,12 @@ public final class GuideTableDiscreteSampler double sumProb = 0; int count = 0; for (final double prob : probabilities) { - InternalUtils.validateProbability(prob); - // Compute and store cumulative probability. - sumProb += prob; + sumProb += InternalUtils.requirePositiveFinite(prob, "probability"); cumulativeProbabilities[count++] = sumProb; } - if (Double.isInfinite(sumProb) || sumProb <= 0) { - throw new IllegalArgumentException("Invalid sum of probabilities: " + sumProb); - } + InternalUtils.requireStrictlyPositiveFinite(sumProb, "sum of probabilities"); // Note: The guide table is at least length 1. Compute the size avoiding overflow // in case (alpha * size) is too large. @@ -209,9 +205,7 @@ public final class GuideTableDiscreteSampler if (probabilities == null || probabilities.length == 0) { throw new IllegalArgumentException("Probabilities must not be empty."); } - if (alpha <= 0) { - throw new IllegalArgumentException("Alpha must be strictly positive."); - } + InternalUtils.requireStrictlyPositive(alpha, "alpha"); } /** diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java index fa89e47f..8f69272f 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java @@ -74,29 +74,127 @@ final class InternalUtils { // Class is package-private on purpose; do not make double sumProb = 0; for (final double prob : probabilities) { - validateProbability(prob); - sumProb += prob; + sumProb += requirePositiveFinite(prob, "probability"); } - if (Double.isInfinite(sumProb) || sumProb <= 0) { - throw new 
IllegalArgumentException("Invalid sum of probabilities: " + sumProb); + return requireStrictlyPositiveFinite(sumProb, "sum of probabilities"); + } + + /** + * Checks the value {@code x} is finite. + * + * @param x Value. + * @param name Name of the value. + * @return x + * @throws IllegalArgumentException if {@code x} is non-finite + */ + static double requireFinite(double x, String name) { + if (!Double.isFinite(x)) { + throw new IllegalArgumentException(name + " is not finite: " + x); + } + return x; + } + + /** + * Checks the value {@code x >= 0} and is finite. + * Note: This method allows {@code x == -0.0}. + * + * @param x Value. + * @param name Name of the value. + * @return x + * @throws IllegalArgumentException if {@code x < 0} or is non-finite + */ + static double requirePositiveFinite(double x, String name) { + if (!(x >= 0 && x < Double.POSITIVE_INFINITY)) { + throw new IllegalArgumentException( + name + " is not positive and finite: " + x); + } + return x; + } + + /** + * Checks the value {@code x > 0} and is finite. + * + * @param x Value. + * @param name Name of the value. + * @return x + * @throws IllegalArgumentException if {@code x <= 0} or is non-finite + */ + static double requireStrictlyPositiveFinite(double x, String name) { + if (!(x > 0 && x < Double.POSITIVE_INFINITY)) { + throw new IllegalArgumentException( + name + " is not strictly positive and finite: " + x); + } + return x; + } + + /** + * Checks the value {@code x >= 0}. + * Note: This method allows {@code x == -0.0}. + * + * @param x Value. + * @param name Name of the value. + * @return x + * @throws IllegalArgumentException if {@code x < 0} + */ + static double requirePositive(double x, String name) { + // Logic inversion detects NaN + if (!(x >= 0)) { + throw new IllegalArgumentException(name + " is not positive: " + x); + } + return x; + } + + /** + * Checks the value {@code x > 0}. + * + * @param x Value. + * @param name Name of the value. 
+ * @return x + * @throws IllegalArgumentException if {@code x <= 0} + */ + static double requireStrictlyPositive(double x, String name) { + // Logic inversion detects NaN + if (!(x > 0)) { + throw new IllegalArgumentException(name + " is not strictly positive: " + x); + } + return x; + } + + /** + * Checks the value is within the range: {@code min <= x < max}. + * + * @param min Minimum (inclusive). + * @param max Maximum (exclusive). + * @param x Value. + * @param name Name of the value. + * @return x + * @throws IllegalArgumentException if {@code x < min || x >= max}. + */ + static double requireRange(double min, double max, double x, String name) { + if (!(min <= x && x < max)) { + throw new IllegalArgumentException( + String.format("%s not within range: %s <= %s < %s", name, min, x, max)); } - return sumProb; + return x; } /** - * Validate the probability is a finite positive number. + * Checks the value is within the closed range: {@code min <= x <= max}. * - * @param probability Probability. - * @throws IllegalArgumentException if {@code probability} is negative, infinite or {@code NaN}. + * @param min Minimum (inclusive). + * @param max Maximum (inclusive). + * @param x Value. + * @param name Name of the value. + * @return x + * @throws IllegalArgumentException if {@code x < min || x > max}. 
*/ - static void validateProbability(double probability) { - if (probability < 0 || - Double.isInfinite(probability) || - Double.isNaN(probability)) { - throw new IllegalArgumentException("Invalid probability: " + - probability); + static double requireRangeClosed(double min, double max, double x, String name) { + if (!(min <= x && x <= max)) { + throw new IllegalArgumentException( + String.format("%s not within closed range: %s <= %s <= %s", name, min, x, max)); } + return x; } /** diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InverseTransformParetoSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InverseTransformParetoSampler.java index 04acc293..2470033d 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InverseTransformParetoSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InverseTransformParetoSampler.java @@ -47,13 +47,21 @@ public class InverseTransformParetoSampler public InverseTransformParetoSampler(UniformRandomProvider rng, double scale, double shape) { + // Validation before java.lang.Object constructor exits prevents partially initialized object + this(InternalUtils.requireStrictlyPositive(scale, "scale"), + InternalUtils.requireStrictlyPositive(shape, "shape"), + rng); + } + + /** + * @param scale Scale of the distribution. + * @param shape Shape of the distribution. + * @param rng Generator of uniformly distributed random numbers. 
+ */ + private InverseTransformParetoSampler(double scale, + double shape, + UniformRandomProvider rng) { super(null); - if (scale <= 0) { - throw new IllegalArgumentException("scale is not strictly positive: " + scale); - } - if (shape <= 0) { - throw new IllegalArgumentException("shape is not strictly positive: " + shape); - } this.rng = rng; this.scale = scale; this.oneOverShape = 1 / shape; diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.java index 353651d3..07f2ec22 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.java @@ -123,9 +123,7 @@ public final class KempSmallMeanPoissonSampler */ public static SharedStateDiscreteSampler of(UniformRandomProvider rng, double mean) { - if (mean <= 0) { - throw new IllegalArgumentException("Mean is not strictly positive: " + mean); - } + InternalUtils.requireStrictlyPositive(mean, "mean"); final double p0 = Math.exp(-mean); diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java index a0e746ea..245ad889 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java @@ -117,6 +117,7 @@ public class LargeMeanPoissonSampler /** The internal Poisson sampler for the lambda fraction. 
*/ private final SharedStateDiscreteSampler smallMeanPoissonSampler; + /** * @param rng Generator of uniformly distributed random numbers. * @param mean Mean. @@ -125,13 +126,33 @@ public class LargeMeanPoissonSampler */ public LargeMeanPoissonSampler(UniformRandomProvider rng, double mean) { - if (mean < 1) { - throw new IllegalArgumentException("mean is not >= 1: " + mean); - } - // The algorithm is not valid if Math.floor(mean) is not an integer. - if (mean > MAX_MEAN) { - throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN); - } + // Validation before java.lang.Object constructor exits prevents partially initialized object + this(InternalUtils.requireRangeClosed(1, MAX_MEAN, mean, "mean"), rng); + } + + /** + * Instantiates a sampler using a precomputed state. + * + * @param rng Generator of uniformly distributed random numbers. + * @param state The state for {@code lambda = (int)Math.floor(mean)}. + * @param lambdaFractional The lambda fractional value + * ({@code mean - (int)Math.floor(mean))}. + * @throws IllegalArgumentException + * if {@code lambdaFractional < 0 || lambdaFractional >= 1}. + */ + LargeMeanPoissonSampler(UniformRandomProvider rng, + LargeMeanPoissonSamplerState state, + double lambdaFractional) { + // Validation before java.lang.Object constructor exits prevents partially initialized object + this(state, InternalUtils.requireRange(0, 1, lambdaFractional, "lambdaFractional"), rng); + } + + /** + * @param mean Mean. + * @param rng Generator of uniformly distributed random numbers. + */ + private LargeMeanPoissonSampler(double mean, + UniformRandomProvider rng) { this.rng = rng; gaussian = ZigguratSampler.NormalizedGaussian.of(rng); @@ -164,20 +185,14 @@ public class LargeMeanPoissonSampler /** * Instantiates a sampler using a precomputed state. * - * @param rng Generator of uniformly distributed random numbers. * @param state The state for {@code lambda = (int)Math.floor(mean)}. 
* @param lambdaFractional The lambda fractional value * ({@code mean - (int)Math.floor(mean))}. - * @throws IllegalArgumentException - * if {@code lambdaFractional < 0 || lambdaFractional >= 1}. + * @param rng Generator of uniformly distributed random numbers. */ - LargeMeanPoissonSampler(UniformRandomProvider rng, - LargeMeanPoissonSamplerState state, - double lambdaFractional) { - if (lambdaFractional < 0 || lambdaFractional >= 1) { - throw new IllegalArgumentException( - "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): " + lambdaFractional); - } + private LargeMeanPoissonSampler(LargeMeanPoissonSamplerState state, + double lambdaFractional, + UniformRandomProvider rng) { this.rng = rng; gaussian = ZigguratSampler.NormalizedGaussian.of(rng); diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LevySampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LevySampler.java index 0c9f05c1..28468f8b 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LevySampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LevySampler.java @@ -91,9 +91,7 @@ public final class LevySampler implements SharedStateContinuousSampler { public static LevySampler of(UniformRandomProvider rng, double location, double scale) { - if (scale <= 0) { - throw new IllegalArgumentException("scale is not strictly positive: " + scale); - } + InternalUtils.requireStrictlyPositive(scale, "scale"); return new LevySampler(rng, location, scale); } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LogNormalSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LogNormalSampler.java index f60d9fdb..b56b6ab6 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LogNormalSampler.java +++ 
b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LogNormalSampler.java @@ -40,9 +40,18 @@ public class LogNormalSampler implements SharedStateContinuousSampler { public LogNormalSampler(NormalizedGaussianSampler gaussian, double mu, double sigma) { - if (sigma <= 0) { - throw new IllegalArgumentException("sigma is not strictly positive: " + sigma); - } + // Validation before java.lang.Object constructor exits prevents partially initialized object + this(mu, InternalUtils.requireStrictlyPositive(sigma, "sigma"), gaussian); + } + + /** + * @param mu Mean of the natural logarithm of the distribution values. + * @param sigma Standard deviation of the natural logarithm of the distribution values. + * @param gaussian N(0,1) generator. + */ + private LogNormalSampler(double mu, + double sigma, + NormalizedGaussianSampler gaussian) { this.mu = mu; this.sigma = sigma; this.gaussian = gaussian; diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java index 8e1a9a07..15d282b3 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java @@ -759,9 +759,7 @@ public final class MarsagliaTsangWangDiscreteSampler { * @throws IllegalArgumentException if {@code mean <= 0} or {@code mean > 1024}. 
*/ private static void validatePoissonDistributionParameters(double mean) { - if (mean <= 0) { - throw new IllegalArgumentException("mean is not strictly positive: " + mean); - } + InternalUtils.requireStrictlyPositive(mean, "mean"); if (mean > MAX_MEAN) { throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN); } @@ -1032,9 +1030,7 @@ public final class MarsagliaTsangWangDiscreteSampler { if (trials < 0) { throw new IllegalArgumentException("Trials is not positive: " + trials); } - if (probabilityOfSuccess < 0 || probabilityOfSuccess > 1) { - throw new IllegalArgumentException("Probability is not in range [0,1]: " + probabilityOfSuccess); - } + InternalUtils.requireRangeClosed(0, 1, probabilityOfSuccess, "probability of success"); } /** diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java index f97a58e6..12290b87 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java @@ -73,10 +73,16 @@ public class PoissonSampler */ public PoissonSampler(UniformRandomProvider rng, double mean) { - super(null); - // Delegate all work to specialised samplers. - poissonSamplerDelegate = of(rng, mean); + this(of(rng, mean)); + } + + /** + * @param delegate Poisson sampler. 
+ */ + private PoissonSampler(SharedStateDiscreteSampler delegate) { + super(null); + poissonSamplerDelegate = delegate; } /** {@inheritDoc} */ diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java index d918e7d1..e5191106 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java @@ -85,9 +85,17 @@ public class PoissonSamplerCache { */ public PoissonSamplerCache(double minMean, double maxMean) { + this(checkMeanRange(minMean, maxMean), maxMean, false); + } - checkMeanRange(minMean, maxMean); - + /** + * @param minMean The minimum mean covered by the cache. + * @param maxMean The maximum mean covered by the cache. + * @param ignored Ignored value. + */ + private PoissonSamplerCache(double minMean, + double maxMean, + boolean ignored) { // The cache can only be used for the LargeMeanPoissonSampler. if (maxMean < PoissonSampler.PIVOT) { // The upper limit is too small so no cache will be used. @@ -121,11 +129,16 @@ public class PoissonSamplerCache { /** * Check the mean range. * + * <p>This method exists to raise an exception before invocation of the + * private constructor; this mitigates Finalizer attacks + * (see SpotBugs CT_CONSTRUCTOR_THROW). + * * @param minMean The minimum mean covered by the cache. * @param maxMean The maximum mean covered by the cache. + * @return the minimum mean * @throws IllegalArgumentException if {@code maxMean < minMean} */ - private static void checkMeanRange(double minMean, double maxMean) { + private static double checkMeanRange(double minMean, double maxMean) { // Note: // Although a mean of 0 is invalid for a Poisson sampler this case // is handled to make the cache user friendly. 
Any low means will @@ -138,6 +151,7 @@ public class PoissonSamplerCache { throw new IllegalArgumentException( "Max mean: " + maxMean + " < " + minMean); } + return minMean; } /** diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/RejectionInversionZipfSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/RejectionInversionZipfSampler.java index e8c5d470..3d74d59b 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/RejectionInversionZipfSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/RejectionInversionZipfSampler.java @@ -334,9 +334,7 @@ public class RejectionInversionZipfSampler if (numberOfElements <= 0) { throw new IllegalArgumentException("number of elements is not strictly positive: " + numberOfElements); } - if (exponent < 0) { - throw new IllegalArgumentException("exponent is not positive: " + exponent); - } + InternalUtils.requirePositive(exponent, "exponent"); // When the exponent is at the limit of 0 the distribution PMF reduces to 1 / n // and sampling can use a discrete uniform sampler. diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.java index e52beb43..7fb74e88 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.java @@ -60,18 +60,23 @@ public class SmallMeanPoissonSampler */ public SmallMeanPoissonSampler(UniformRandomProvider rng, double mean) { + this(rng, mean, computeP0(mean)); + } + + /** + * Instantiates a new small mean poisson sampler. + * + * @param rng Generator of uniformly distributed random numbers. 
+ * @param mean Mean. + * @param p0 {@code Math.exp(-mean)}. + */ + private SmallMeanPoissonSampler(UniformRandomProvider rng, + double mean, + double p0) { this.rng = rng; - if (mean <= 0) { - throw new IllegalArgumentException("mean is not strictly positive: " + mean); - } - p0 = Math.exp(-mean); - if (p0 > 0) { - // The returned sample is bounded by 1000 * mean - limit = (int) Math.ceil(1000 * mean); - } else { - // This excludes NaN values for the mean - throw new IllegalArgumentException("No p(x=0) probability for mean: " + mean); - } + this.p0 = p0; + // The returned sample is bounded by 1000 * mean + limit = (int) Math.ceil(1000 * mean); } /** @@ -131,4 +136,25 @@ public class SmallMeanPoissonSampler double mean) { return new SmallMeanPoissonSampler(rng, mean); } + + /** + * Compute {@code Math.exp(-mean)}. + * + * <p>This method exists to raise an exception before invocation of the + * private constructor; this mitigates Finalizer attacks + * (see SpotBugs CT_CONSTRUCTOR_THROW). + * + * @param mean Mean. 
+ * @return {@code Math.exp(-mean)}, the probability of a sample of zero + * @throws IllegalArgumentException if {@code mean <= 0} or {@code Math.exp(-mean) == 0} + */ + private static double computeP0(double mean) { + InternalUtils.requireStrictlyPositive(mean, "mean"); + final double p0 = Math.exp(-mean); + if (p0 > 0) { + return p0; + } + // This excludes NaN values for the mean + throw new IllegalArgumentException("No p(x=0) probability for mean: " + mean); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/StableSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/StableSampler.java index 167c75d9..276a72c7 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/StableSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/StableSampler.java @@ -1479,12 +1479,7 @@ public abstract class StableSampler implements SharedStateContinuousSampler { double gamma, double delta) { validateParameters(alpha, beta); - // Logic inversion will identify NaN - if (!(0 < gamma && gamma <= Double.MAX_VALUE)) { - throw new IllegalArgumentException("gamma is not strictly positive and finite: " + gamma); - } - if (!Double.isFinite(delta)) { - throw new IllegalArgumentException("delta is not finite: " + delta); - } + InternalUtils.requireStrictlyPositiveFinite(gamma, "gamma"); + InternalUtils.requireFinite(delta, "delta"); } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ZigguratSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ZigguratSampler.java index 9723626c..0c18a91f 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ZigguratSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ZigguratSampler.java @@ -727,10 +727,7 @@ public abstract class ZigguratSampler implements
SharedStateContinuousSampler { * @throws IllegalArgumentException if the mean is not strictly positive ({@code mean <= 0}) */ public static Exponential of(UniformRandomProvider rng, double mean) { - if (mean > 0) { - return new ExponentialMean(rng, mean); - } - throw new IllegalArgumentException("Mean is not strictly positive: " + mean); + return new ExponentialMean(rng, InternalUtils.requireStrictlyPositive(mean, "mean")); } } diff --git a/src/main/resources/pmd/pmd-ruleset.xml b/src/main/resources/pmd/pmd-ruleset.xml index 112c365d..544c7470 100644 --- a/src/main/resources/pmd/pmd-ruleset.xml +++ b/src/main/resources/pmd/pmd-ruleset.xml @@ -185,7 +185,7 @@ <properties> <!-- Logic inversion allows detection of NaN for parameters that are expected in a range --> <property name="violationSuppressXPath" - value="//ClassOrInterfaceDeclaration[@SimpleName='GaussianSampler' or @SimpleName='StableSampler']"/> + value="//ClassOrInterfaceDeclaration[@SimpleName='InternalUtils' or @SimpleName='StableSampler']"/> </properties> </rule> <rule ref="category/java/design.xml/ImmutableField"> diff --git a/src/main/resources/spotbugs/spotbugs-exclude-filter.xml b/src/main/resources/spotbugs/spotbugs-exclude-filter.xml index 660d892f..b2ff2378 100644 --- a/src/main/resources/spotbugs/spotbugs-exclude-filter.xml +++ b/src/main/resources/spotbugs/spotbugs-exclude-filter.xml @@ -149,4 +149,17 @@ <BugPattern name="FL_FLOATS_AS_LOOP_COUNTERS"/> </Match> + <!-- Code prevents Finalizer attacks using a private constructor that accepts + validated arguments. This solution is provided by: + https://wiki.sei.cmu.edu/confluence/display/java/OBJ11-J.+Be+wary+of+letting+constructors+throw+exceptions + It is not (always) detected by SpotBugs, e.g. where a validation method in the + same class returns a primitive value. 
--> + <Match> + <Or> + <Class name="org.apache.commons.rng.sampling.distribution.SmallMeanPoissonSampler"/> + <Class name="org.apache.commons.rng.sampling.distribution.PoissonSamplerCache"/> + </Or> + <BugPattern name="CT_CONSTRUCTOR_THROW"/> + </Match> + </FindBugsFilter>
