Repository: commons-rng
Updated Branches:
  refs/heads/master 01a2c09ce -> 6ec1d323e


RNG-47: Sampling from discrete probability distribution.


Project: http://git-wip-us.apache.org/repos/asf/commons-rng/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-rng/commit/6ec1d323
Tree: http://git-wip-us.apache.org/repos/asf/commons-rng/tree/6ec1d323
Diff: http://git-wip-us.apache.org/repos/asf/commons-rng/diff/6ec1d323

Branch: refs/heads/master
Commit: 6ec1d323e1747152411fa0d52128614bc8ea0f30
Parents: 01a2c09
Author: Gilles <[email protected]>
Authored: Wed Jan 17 13:47:24 2018 +0100
Committer: Gilles <[email protected]>
Committed: Wed Jan 17 13:47:24 2018 +0100

----------------------------------------------------------------------
 .../DiscreteProbabilityCollectionSampler.java   | 185 +++++++++++++++++++
 ...iscreteProbabilityCollectionSamplerTest.java |  87 +++++++++
 2 files changed, 272 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-rng/blob/6ec1d323/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
----------------------------------------------------------------------
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
new file mode 100644
index 0000000..8f87c15
--- /dev/null
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.rng.sampling;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.HashMap;
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import org.apache.commons.rng.UniformRandomProvider;
+
+/**
+ * Sampling from a {@link Collection} of items with user-defined
+ * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
+ * probabilities</a>.
+ * Note that if all unique items are assigned the same probability,
+ * it is much more efficient to use {@link CollectionSampler}.
+ *
+ * @param <T> Type of items in the collection.
+ *
+ * @since 1.1
+ */
+public class DiscreteProbabilityCollectionSampler<T> {
+    /** Collection to be sampled from. */
+    private final List<T> items;
+    /** RNG. */
+    private final UniformRandomProvider rng;
+    /** Cumulative probabilities. */
+    private final double[] cumulativeProbabilities;
+
+    /**
+     * Creates a sampler.
+     *
+     * @param rng Generator of uniformly distributed random numbers.
+     * @param collection Collection to be sampled, with the probabilities
+     * associated to each of its items.
+     * A (shallow) copy of the items will be stored in the created instance.
+     * The probabilities must be non-negative, but zero values are allowed
+     * and their sum does not have to equal one (input will be normalized
+     * to make the probabilities sum to one).
+     * @throws IllegalArgumentException if {@code collection} is empty, a
+     * probability is negative, infinite or {@code NaN}, or the sum of all
+     * probabilities is not strictly positive.
+     */
+    public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
+                                                Map<T, Double> collection) {
+        if (collection.isEmpty()) {
+            throw new IllegalArgumentException("Empty collection");
+        }
+
+        this.rng = rng;
+        final int size = collection.size();
+        items = new ArrayList<T>(size);
+        cumulativeProbabilities = new double[size];
+
+        double sumProb = 0;
+        int count = 0;
+        for (Map.Entry<T, Double> e : collection.entrySet()) {
+            items.add(e.getKey());
+
+            final double prob = e.getValue();
+            if (prob < 0 ||
+                Double.isInfinite(prob) ||
+                Double.isNaN(prob)) {
+                throw new IllegalArgumentException("Invalid probability: " +
+                                                   prob);
+            }
+
+            // Temporarily store probability.
+            cumulativeProbabilities[count++] = prob;
+            sumProb += prob;
+        }
+
+        if (!(sumProb > 0)) {
+            throw new IllegalArgumentException("Invalid sum of probabilities");
+        }
+
+        // Compute and store cumulative probability.
+        for (int i = 0; i < size; i++) {
+            cumulativeProbabilities[i] /= sumProb;
+            if (i > 0) {
+                cumulativeProbabilities[i] += cumulativeProbabilities[i - 1];
+            }
+        }
+    }
+
+    /**
+     * Creates a sampler.
+     *
+     * @param rng Generator of uniformly distributed random numbers.
+     * @param collection Collection to be sampled.
+     * A (shallow) copy of the items will be stored in the created instance.
+     * @param probabilities Probability associated to each item of the
+     * {@code collection}.
+     * The probabilities must be non-negative, but zero values are allowed
+     * and their sum does not have to equal one (input will be normalized
+     * to make the probabilities sum to one).
+     * @throws IllegalArgumentException if {@code collection} is empty or
+     * a probability is negative, infinite or {@code NaN}, or if the number
+     * of items in the {@code collection} is not equal to the number of
+     * provided {@code probabilities}.
+     */
+    public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
+                                                List<T> collection,
+                                                double[] probabilities) {
+        this(rng, consolidate(collection, probabilities));
+    }
+
+    /**
+     * Picks one of the items from the collection passed to the constructor.
+     *
+     * @return a random sample.
+     */
+    public T sample() {
+        final double rand = rng.nextDouble();
+
+        int index = Arrays.binarySearch(cumulativeProbabilities, rand);
+        if (index < 0) {
+            index = -index - 1;
+        }
+
+        if (index >= 0 &&
+            index < cumulativeProbabilities.length &&
+            rand < cumulativeProbabilities[index]) {
+            return items.get(index);
+        }
+
+        // This should never happen, but it ensures we will return a correct
+        // object in case there is some floating point inequality problem
+        // wrt the cumulative probabilities.
+        return items.get(items.size() - 1);
+    }
+
+    /**
+     * @param collection Collection to be sampled.
+     * @param probabilities Probability associated to each item of the
+     * {@code collection}.
+     * @return a consolidated map (where probabilities of equal items
+     * have been summed).
+     * @throws IllegalArgumentException if the number of items in the
+     * {@code collection} is not equal to the number of provided
+     * {@code probabilities}.
+     */
+    private static <T> Map<T, Double> consolidate(List<T> collection,
+                                                  double[] probabilities) {
+        final int len = probabilities.length;
+        if (len != collection.size()) {
+            throw new IllegalArgumentException("Size mismatch: " +
+                                               len + " != " +
+                                               collection.size());
+        }
+
+        final Map<T, Double> map = new HashMap<T, Double>();
+        for (int i = 0; i < len; i++) {
+            final T item = collection.get(i);
+            final Double prob = probabilities[i];
+
+            Double currentProb = map.get(item);
+            if (currentProb == null) {
+                currentProb = 0d;
+            }
+
+            map.put(item, currentProb + prob);
+        }
+
+        return map;
+    }
+}


http://git-wip-us.apache.org/repos/asf/commons-rng/blob/6ec1d323/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
----------------------------------------------------------------------
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
new file mode 100644
index 0000000..757d44e
--- /dev/null
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.rng.sampling;
+
+import java.util.Arrays;
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.simple.RandomSource;
+
+/**
+ * Test class for {@link DiscreteProbabilityCollectionSampler}.
+ */
+public class DiscreteProbabilityCollectionSamplerTest {
+    /** RNG. */
+    private static final UniformRandomProvider rng = RandomSource.create(RandomSource.WELL_1024_A);
+
+    @Test(expected=IllegalArgumentException.class)
+    public void testPrecondition1() {
+        new DiscreteProbabilityCollectionSampler<Double>(rng,
+                                                         Arrays.asList(new Double[] {1d, 2d}),
+                                                         new double[] {0d});
+    }
+    @Test(expected=IllegalArgumentException.class)
+    public void testPrecondition2() {
+        new DiscreteProbabilityCollectionSampler<Double>(rng,
+                                                         Arrays.asList(new Double[] {1d, 2d}),
+                                                         new double[] {0d, -1d});
+    }
+    @Test(expected=IllegalArgumentException.class)
+    public void testPrecondition3() {
+        new DiscreteProbabilityCollectionSampler<Double>(rng,
+                                                         Arrays.asList(new Double[] {1d, 2d}),
+                                                         new double[] {0d, 0d});
+    }
+    @Test(expected=IllegalArgumentException.class)
+    public void testPrecondition4() {
+        new DiscreteProbabilityCollectionSampler<Double>(rng,
+                                                         Arrays.asList(new Double[] {1d, 2d}),
+                                                         new double[] {0d, Double.NaN});
+    }
+    @Test(expected=IllegalArgumentException.class)
+    public void testPrecondition5() {
+        new DiscreteProbabilityCollectionSampler<Double>(rng,
+                                                         Arrays.asList(new Double[] {1d, 2d}),
+                                                         new double[] {0d, Double.POSITIVE_INFINITY});
+    }
+
+    @Test
+    public void testSample() {
+        final DiscreteProbabilityCollectionSampler<Double> sampler =
+            new DiscreteProbabilityCollectionSampler<Double>(rng,
+                                                             Arrays.asList(new Double[] {3d, -1d, 3d, 7d, -2d, 8d}),
+                                                             new double[] {0.2, 0.2, 0.3, 0.3, 0, 0});
+        final double expectedMean = 3.4;
+        final double expectedVariance = 7.84;
+
+        final int n = 100000000;
+        double sum = 0;
+        double sumOfSquares = 0;
+        for (int i = 0; i < n; i++) {
+            final double rand = sampler.sample();
+            sum += rand;
+            sumOfSquares += rand * rand;
+        }
+
+        final double mean = sum / n;
+        Assert.assertEquals(expectedMean, mean, 1e-3);
+        final double variance = sumOfSquares / n - mean * mean;
+        Assert.assertEquals(expectedVariance, variance, 1e-3);
+    }
+}
