zhipeng93 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r891126951


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +307,79 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * A stream operator that takes a randomly sampled subset of elements in a bounded data stream.
+     */
+    private static class SamplingOperator<T> extends AbstractStreamOperator<List<T>>
+            implements OneInputStreamOperator<T, List<T>>, BoundedOneInput {
+        private final int numSamples;
+
+        private final Random random;
+
+        private ListState<T> samplesState;
+
+        private List<T> samples;
+
+        private ListState<Integer> countState;
+
+        private int count;
+
+        SamplingOperator(int numSamples, long randomSeed) {
+            this.numSamples = numSamples;
+            this.random = new Random(randomSeed);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            ListStateDescriptor<T> samplesDescriptor =
+                    new ListStateDescriptor<>(
+                            "samplesState",
+                            getOperatorConfig()
+                                    .getTypeSerializerIn(0, getClass().getClassLoader()));
+            samplesState = context.getOperatorStateStore().getListState(samplesDescriptor);
+            samples = new ArrayList<>();
+            samplesState.get().forEach(samples::add);
+
+            ListStateDescriptor<Integer> countDescriptor =
+                    new ListStateDescriptor<>("countState", IntSerializer.INSTANCE);
+            countState = context.getOperatorStateStore().getListState(countDescriptor);
+            Iterator<Integer> countIterator = countState.get().iterator();
+            if (countIterator.hasNext()) {
+                count = countIterator.next();
+            } else {
+                count = 0;
+            }
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            samplesState.update(samples);
+            countState.update(Collections.singletonList(count));
+        }
+
+        @Override
+        public void processElement(StreamRecord<T> streamRecord) throws Exception {
+            T sample = streamRecord.getValue();
+            count++;
+
+            // Code below is inspired by the Reservoir Sampling algorithm.
+            if (samples.size() < numSamples) {
+                samples.add(sample);
+            } else {
+                if (random.nextInt(count) < numSamples) {
+                    samples.set(random.nextInt(numSamples), sample);
+                }
+            }
+        }
+
+        @Override
+        public void endInput() throws Exception {
+            Collections.shuffle(samples, random);

Review Comment:
   > the first arriving element, if sampled, will always be the first returning element
   
   What is the problem with this situation? I think it is okay if the sampled elements preserve their order within each worker. If we look at spark#sample [1], the order within each partition is also preserved.
   
   [1] https://github.com/apache/spark/blob/6026dd25748fd79caeedc083f99d5c954fb3a19f/core/src/main/scala/org/apache/spark/rdd/RDD.scala#L554
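
   For context, the order-preservation behavior under discussion can be reproduced in a standalone sketch of the reservoir-sampling step (plain Java, independent of Flink; the class name `ReservoirSketch` and method `sample` are illustrative, not part of the PR):

   ```java
   import java.util.ArrayList;
   import java.util.Collections;
   import java.util.List;
   import java.util.Random;

   public class ReservoirSketch {
       /** Reservoir-samples numSamples elements, mirroring the processElement() logic above. */
       static <T> List<T> sample(List<T> input, int numSamples, long seed) {
           Random random = new Random(seed);
           List<T> samples = new ArrayList<>();
           int count = 0;
           for (T element : input) {
               count++;
               if (samples.size() < numSamples) {
                   // Fill the reservoir with the first numSamples elements.
                   samples.add(element);
               } else if (random.nextInt(count) < numSamples) {
                   // Each later element replaces a random slot with probability numSamples / count.
                   samples.set(random.nextInt(numSamples), element);
               }
           }
           return samples;
       }

       public static void main(String[] args) {
           List<Integer> input = new ArrayList<>();
           for (int i = 0; i < 1000; i++) {
               input.add(i);
           }
           List<Integer> samples = sample(input, 10, 42L);
           System.out.println("Before shuffle: " + samples);
           // Replacements only overwrite slots, so an early element that survives keeps
           // its arrival index (e.g. element 0, if sampled, stays at index 0). The
           // Collections.shuffle call in endInput() removes this positional bias.
           Collections.shuffle(samples, new Random(42L));
           System.out.println("After shuffle:  " + samples);
       }
   }
   ```

   Note that for a fixed seed the sampled set is deterministic; only the in-reservoir positions are affected by the final shuffle.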



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
