This is an automated email from the ASF dual-hosted git repository.
zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new fe338b19 [FLINK-31623] Fix DataStreamUtils#sample with approximate
uniform sampling
fe338b19 is described below
commit fe338b194b73fd51218f4d842fa7b0065fb76c56
Author: Fan Hong <[email protected]>
AuthorDate: Mon Apr 3 15:15:15 2023 +0800
[FLINK-31623] Fix DataStreamUtils#sample with approximate uniform sampling
This closes #227.
---
.../flink/ml/common/datastream/DataStreamUtils.java | 19 ++++++++++++++-----
.../ml/common/datastream/DataStreamUtilsTest.java | 15 +++++++++++++++
2 files changed, 29 insertions(+), 5 deletions(-)
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
index 691e7704..eb4ec6ca 100644
---
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
+++
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
@@ -275,7 +275,10 @@ public class DataStreamUtils {
}
/**
- * Performs a uniform sampling over the elements in a bounded data stream.
+ * Performs an approximate uniform sampling over the elements in a bounded
data stream. The
+ * difference of probabilities of two data points being sampled is bounded
by O(numSamples * p *
+ * p / (M * M)), where p is the parallelism of the input stream, M is the
total number of data
+ * points that the input stream contains.
*
* <p>This method takes samples without replacement. If the number of
elements in the stream is
* smaller than expected number of samples, all elements will be included
in the sample.
@@ -288,13 +291,19 @@ public class DataStreamUtils {
public static <T> DataStream<T> sample(DataStream<T> input, int
numSamples, long randomSeed) {
int inputParallelism = input.getParallelism();
- return input.transform(
- "samplingOperator",
+ // The maximum difference in the number of data points between partitions
after calling
+ // `rebalance` is `inputParallelism`. As a result, extra
`inputParallelism` data points are
+ // sampled for each partition in the first round.
+ int firstRoundNumSamples =
+ Math.min((numSamples / inputParallelism) + inputParallelism,
numSamples);
+ return input.rebalance()
+ .transform(
+ "firstRoundSampling",
input.getType(),
- new SamplingOperator<>(numSamples, randomSeed))
+ new SamplingOperator<>(firstRoundNumSamples,
randomSeed))
.setParallelism(inputParallelism)
.transform(
- "samplingOperator",
+ "secondRoundSampling",
input.getType(),
new SamplingOperator<>(numSamples, randomSeed))
.setParallelism(1)
diff --git
a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
index d3f8a95e..7b3e8b3a 100644
---
a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
+++
b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
@@ -100,6 +100,21 @@ public class DataStreamUtilsTest {
assertEquals(Integer.toString(190 + env.getParallelism()),
stringSum.get(0));
}
+ @Test
+ public void testSample() throws Exception {
+ int numSamples = 10;
+ int[] totalMinusOneChoices = new int[] {0, 5, 9, 10, 11, 20, 30, 40,
200};
+ for (int totalMinusOne : totalMinusOneChoices) {
+ DataStream<Long> dataStream =
+ env.fromParallelCollection(
+ new NumberSequenceIterator(0L, totalMinusOne),
Types.LONG);
+ DataStream<Long> result = DataStreamUtils.sample(dataStream,
numSamples, 0);
+ //noinspection unchecked
+ List<String> sampled =
IteratorUtils.toList(result.executeAndCollect());
+ assertEquals(Math.min(numSamples, totalMinusOne + 1),
sampled.size());
+ }
+ }
+
@Test
public void testGenerateBatchData() throws Exception {
DataStream<Long> dataStream =