This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-rng.git
commit dc779b4a919ae88b6094ec2c8f7811e38070c3bf Author: aherbert <[email protected]> AuthorDate: Wed May 4 13:25:35 2022 +0100 RNG-177: Add stream methods to the sampler API --- .../apache/commons/rng/sampling/ObjectSampler.java | 34 +++++++++- .../sampling/distribution/ContinuousSampler.java | 32 +++++++++- .../rng/sampling/distribution/DiscreteSampler.java | 32 +++++++++- .../rng/sampling/distribution/LongSampler.java | 32 +++++++++- .../commons/rng/sampling/ObjectSamplerTest.java | 72 ++++++++++++++++++++++ .../distribution/ContinuousSamplerTest.java | 71 +++++++++++++++++++++ .../sampling/distribution/DiscreteSamplerTest.java | 71 +++++++++++++++++++++ .../rng/sampling/distribution/LongSamplerTest.java | 71 +++++++++++++++++++++ 8 files changed, 410 insertions(+), 5 deletions(-) diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/ObjectSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/ObjectSampler.java index f69dd6ba..29cb0d85 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/ObjectSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/ObjectSampler.java @@ -17,6 +17,8 @@ package org.apache.commons.rng.sampling; +import java.util.stream.Stream; + /** * Sampler that generates values of a specified type. * @@ -25,9 +27,37 @@ package org.apache.commons.rng.sampling; */ public interface ObjectSampler<T> { /** - * Create a sample. + * Create an object sample. * - * @return a sample + * @return a sample. */ T sample(); + + /** + * Returns an effectively unlimited stream of object sample values. + * + * <p>The default implementation produces a sequential stream that repeatedly + * calls {@link #sample sample}(). + * + * @return a stream of object values. + * @since 1.5 + */ + default Stream<T> samples() { + return Stream.generate(this::sample).sequential(); + } + + /** + * Returns a stream producing the given {@code streamSize} number of object + * sample values. + * + * <p>The default implementation produces a sequential stream that repeatedly + * calls {@link #sample sample}(); the stream is limited to the given {@code streamSize}. + * + * @param streamSize Number of values to generate. + * @return a stream of object values. + * @since 1.5 + */ + default Stream<T> samples(long streamSize) { + return samples().limit(streamSize); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ContinuousSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ContinuousSampler.java index 8f3bc383..10d81ac2 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ContinuousSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ContinuousSampler.java @@ -16,6 +16,8 @@ */ package org.apache.commons.rng.sampling.distribution; +import java.util.stream.DoubleStream; + /** * Sampler that generates values of type {@code double}. * @@ -23,9 +25,37 @@ package org.apache.commons.rng.sampling.distribution; */ public interface ContinuousSampler { /** - * Creates a sample. + * Creates a {@code double} sample. * * @return a sample. */ double sample(); + + /** + * Returns an effectively unlimited stream of {@code double} sample values. + * + * <p>The default implementation produces a sequential stream that repeatedly + * calls {@link #sample sample}(). + * + * @return a stream of {@code double} values. + * @since 1.5 + */ + default DoubleStream samples() { + return DoubleStream.generate(this::sample).sequential(); + } + + /** + * Returns a stream producing the given {@code streamSize} number of {@code double} + * sample values. + * + * <p>The default implementation produces a sequential stream that repeatedly + * calls {@link #sample sample}(); the stream is limited to the given {@code streamSize}. + * + * @param streamSize Number of values to generate. + * @return a stream of {@code double} values. + * @since 1.5 + */ + default DoubleStream samples(long streamSize) { + return samples().limit(streamSize); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteSampler.java index 4c01a0af..0db90b4f 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteSampler.java @@ -16,6 +16,8 @@ */ package org.apache.commons.rng.sampling.distribution; +import java.util.stream.IntStream; + /** * Sampler that generates values of type {@code int}. * @@ -23,9 +25,37 @@ package org.apache.commons.rng.sampling.distribution; */ public interface DiscreteSampler { /** - * Creates a sample. + * Creates an {@code int} sample. * * @return a sample. */ int sample(); + + /** + * Returns an effectively unlimited stream of {@code int} sample values. + * + * <p>The default implementation produces a sequential stream that repeatedly + * calls {@link #sample sample}(). + * + * @return a stream of {@code int} values. + * @since 1.5 + */ + default IntStream samples() { + return IntStream.generate(this::sample).sequential(); + } + + /** + * Returns a stream producing the given {@code streamSize} number of {@code int} + * sample values. + * + * <p>The default implementation produces a sequential stream that repeatedly + * calls {@link #sample sample}(); the stream is limited to the given {@code streamSize}. + * + * @param streamSize Number of values to generate. + * @return a stream of {@code int} values. + * @since 1.5 + */ + default IntStream samples(long streamSize) { + return samples().limit(streamSize); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LongSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LongSampler.java index a3e2fac7..d0b99b8b 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LongSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LongSampler.java @@ -16,6 +16,8 @@ */ package org.apache.commons.rng.sampling.distribution; +import java.util.stream.LongStream; + /** * Sampler that generates values of type {@code long}. * @@ -23,9 +25,37 @@ package org.apache.commons.rng.sampling.distribution; */ public interface LongSampler { /** - * Creates a sample. + * Creates a {@code long} sample. * * @return a sample. */ long sample(); + + /** + * Returns an effectively unlimited stream of {@code long} sample values. + * + * <p>The default implementation produces a sequential stream that repeatedly + * calls {@link #sample sample}(). + * + * @return a stream of {@code long} values. + * @since 1.5 + */ + default LongStream samples() { + return LongStream.generate(this::sample).sequential(); + } + + /** + * Returns a stream producing the given {@code streamSize} number of {@code long} + * sample values. + * + * <p>The default implementation produces a sequential stream that repeatedly + * calls {@link #sample sample}(); the stream is limited to the given {@code streamSize}. + * + * @param streamSize Number of values to generate. + * @return a stream of {@code long} values. + * @since 1.5 + */ + default LongStream samples(long streamSize) { + return samples().limit(streamSize); + } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/ObjectSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/ObjectSamplerTest.java new file mode 100644 index 00000000..7fc14662 --- /dev/null +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/ObjectSamplerTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling; + +import java.util.concurrent.ThreadLocalRandom; +import org.apache.commons.rng.simple.RandomSource; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +/** + * Tests the default methods in the {@link ObjectSampler} interface. + */ +class ObjectSamplerTest { + @Test + void testSamplesUnlimitedSize() { + final ObjectSampler<Double> s = RandomSource.SPLIT_MIX_64.create()::nextDouble; + Assertions.assertEquals(Long.MAX_VALUE, s.samples().spliterator().estimateSize()); + } + + @RepeatedTest(value = 3) + void testSamples() { + final long seed = RandomSource.createLong(); + final ObjectSampler<Double> s1 = RandomSource.SPLIT_MIX_64.create(seed)::nextDouble; + final ObjectSampler<Double> s2 = RandomSource.SPLIT_MIX_64.create(seed)::nextDouble; + final int count = ThreadLocalRandom.current().nextInt(3, 13); + Assertions.assertArrayEquals(createSamples(s1, count), + s2.samples().limit(count).toArray()); + } + + @ParameterizedTest + @ValueSource(ints = {0, 1, 2, 5, 13}) + void testSamples(int streamSize) { + final long seed = RandomSource.createLong(); + final ObjectSampler<Double> s1 = RandomSource.SPLIT_MIX_64.create(seed)::nextDouble; + final ObjectSampler<Double> s2 = RandomSource.SPLIT_MIX_64.create(seed)::nextDouble; + Assertions.assertArrayEquals(createSamples(s1, streamSize), + s2.samples(streamSize).toArray()); + } + + /** + * Creates an array of samples. + * + * @param sampler Source of samples. + * @param count Number of samples. + * @return the samples + */ + private static Double[] createSamples(ObjectSampler<Double> sampler, int count) { + final Double[] data = new Double[count]; + for (int i = 0; i < count; i++) { + // Explicit boxing + data[i] = Double.valueOf(sampler.sample()); + } + return data; + } +} diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ContinuousSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ContinuousSamplerTest.java new file mode 100644 index 00000000..0222abfb --- /dev/null +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ContinuousSamplerTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import java.util.concurrent.ThreadLocalRandom; +import org.apache.commons.rng.simple.RandomSource; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +/** + * Tests the default methods in the {@link ContinuousSampler} interface. + */ +class ContinuousSamplerTest { + @Test + void testSamplesUnlimitedSize() { + final ContinuousSampler s = RandomSource.SPLIT_MIX_64.create()::nextDouble; + Assertions.assertEquals(Long.MAX_VALUE, s.samples().spliterator().estimateSize()); + } + + @RepeatedTest(value = 3) + void testSamples() { + final long seed = RandomSource.createLong(); + final ContinuousSampler s1 = RandomSource.SPLIT_MIX_64.create(seed)::nextDouble; + final ContinuousSampler s2 = RandomSource.SPLIT_MIX_64.create(seed)::nextDouble; + final int count = ThreadLocalRandom.current().nextInt(3, 13); + Assertions.assertArrayEquals(createSamples(s1, count), + s2.samples().limit(count).toArray()); + } + + @ParameterizedTest + @ValueSource(ints = {0, 1, 2, 5, 13}) + void testSamples(int streamSize) { + final long seed = RandomSource.createLong(); + final ContinuousSampler s1 = RandomSource.SPLIT_MIX_64.create(seed)::nextDouble; + final ContinuousSampler s2 = RandomSource.SPLIT_MIX_64.create(seed)::nextDouble; + Assertions.assertArrayEquals(createSamples(s1, streamSize), + s2.samples(streamSize).toArray()); + } + + /** + * Creates an array of samples. + * + * @param sampler Source of samples. + * @param count Number of samples. + * @return the samples + */ + private static double[] createSamples(ContinuousSampler sampler, int count) { + final double[] data = new double[count]; + for (int i = 0; i < count; i++) { + data[i] = sampler.sample(); + } + return data; + } +} diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerTest.java new file mode 100644 index 00000000..029d1a4e --- /dev/null +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import java.util.concurrent.ThreadLocalRandom; +import org.apache.commons.rng.simple.RandomSource; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +/** + * Tests the default methods in the {@link DiscreteSampler} interface. + */ +class DiscreteSamplerTest { + @Test + void testSamplesUnlimitedSize() { + final DiscreteSampler s = RandomSource.SPLIT_MIX_64.create()::nextInt; + Assertions.assertEquals(Long.MAX_VALUE, s.samples().spliterator().estimateSize()); + } + + @RepeatedTest(value = 3) + void testSamples() { + final long seed = RandomSource.createLong(); + final DiscreteSampler s1 = RandomSource.SPLIT_MIX_64.create(seed)::nextInt; + final DiscreteSampler s2 = RandomSource.SPLIT_MIX_64.create(seed)::nextInt; + final int count = ThreadLocalRandom.current().nextInt(3, 13); + Assertions.assertArrayEquals(createSamples(s1, count), + s2.samples().limit(count).toArray()); + } + + @ParameterizedTest + @ValueSource(ints = {0, 1, 2, 5, 13}) + void testSamples(int streamSize) { + final long seed = RandomSource.createLong(); + final DiscreteSampler s1 = RandomSource.SPLIT_MIX_64.create(seed)::nextInt; + final DiscreteSampler s2 = RandomSource.SPLIT_MIX_64.create(seed)::nextInt; + Assertions.assertArrayEquals(createSamples(s1, streamSize), + s2.samples(streamSize).toArray()); + } + + /** + * Creates an array of samples. + * + * @param sampler Source of samples. + * @param count Number of samples. + * @return the samples + */ + private static int[] createSamples(DiscreteSampler sampler, int count) { + final int[] data = new int[count]; + for (int i = 0; i < count; i++) { + data[i] = sampler.sample(); + } + return data; + } +} diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LongSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LongSamplerTest.java new file mode 100644 index 00000000..26ca85e5 --- /dev/null +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LongSamplerTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import java.util.concurrent.ThreadLocalRandom; +import org.apache.commons.rng.simple.RandomSource; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +/** + * Tests the default methods in the {@link LongSampler} interface. + */ +class LongSamplerTest { + @Test + void testSamplesUnlimitedSize() { + final LongSampler s = RandomSource.SPLIT_MIX_64.create()::nextLong; + Assertions.assertEquals(Long.MAX_VALUE, s.samples().spliterator().estimateSize()); + } + + @RepeatedTest(value = 3) + void testSamples() { + final long seed = RandomSource.createLong(); + final LongSampler s1 = RandomSource.SPLIT_MIX_64.create(seed)::nextLong; + final LongSampler s2 = RandomSource.SPLIT_MIX_64.create(seed)::nextLong; + final int count = ThreadLocalRandom.current().nextInt(3, 13); + Assertions.assertArrayEquals(createSamples(s1, count), + s2.samples().limit(count).toArray()); + } + + @ParameterizedTest + @ValueSource(ints = {0, 1, 2, 5, 13}) + void testSamples(int streamSize) { + final long seed = RandomSource.createLong(); + final LongSampler s1 = RandomSource.SPLIT_MIX_64.create(seed)::nextLong; + final LongSampler s2 = RandomSource.SPLIT_MIX_64.create(seed)::nextLong; + Assertions.assertArrayEquals(createSamples(s1, streamSize), + s2.samples(streamSize).toArray()); + } + + /** + * Creates an array of samples. + * + * @param sampler Source of samples. + * @param count Number of samples. + * @return the samples + */ + private static long[] createSamples(LongSampler sampler, int count) { + final long[] data = new long[count]; + for (int i = 0; i < count; i++) { + data[i] = sampler.sample(); + } + return data; + } +}
