lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r890241966


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -113,32 +162,119 @@ public static <T> DataStream<T> reduce(DataStream<T> 
input, ReduceFunction<T> fu
             extends AbstractUdfStreamOperator<OUT, MapPartitionFunction<IN, 
OUT>>
             implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
 
-        private ListState<IN> valuesState;
+        private final TypeInformation<IN> inputType;
+
+        private Path basePath;
+
+        private StreamConfig config;
+
+        private StreamTask<?, ?> containingTask;
 
-        public MapPartitionOperator(MapPartitionFunction<IN, OUT> 
mapPartitionFunc) {
+        private DataCacheWriter<IN> dataCacheWriter;
+
+        public MapPartitionOperator(
+                MapPartitionFunction<IN, OUT> mapPartitionFunc, 
TypeInformation<IN> inputType) {
             super(mapPartitionFunc);
+            this.inputType = inputType;
+        }
+
+        @Override
+        public void setup(
+                StreamTask<?, ?> containingTask,
+                StreamConfig config,
+                Output<StreamRecord<OUT>> output) {
+            super.setup(containingTask, config, output);
+
+            basePath =
+                    OperatorUtils.getDataCachePath(
+                            
containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(),
+                            containingTask
+                                    .getEnvironment()
+                                    .getIOManager()
+                                    .getSpillingDirectoriesPaths());
+            this.config = config;

Review Comment:
   nit: I think we typically put simple assignments (e.g. `this.config = config`)
before non-trivial initializations (e.g. `basePath = ...`). Could you update the
code to follow this convention?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java:
##########
@@ -19,127 +19,162 @@
 package org.apache.flink.iteration.datacache.nonkeyed;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
 /** Records the data received and replayed them on required. */
 public class DataCacheWriter<T> {
 
+    /** A soft limit on the max allowed size of a single segment. */
+    static final long MAX_SEGMENT_SIZE = 1L << 30; // 1GB
+
+    /** The tool to serialize received records into bytes. */
     private final TypeSerializer<T> serializer;
 
+    /** The file system that contains the cache files. */
     private final FileSystem fileSystem;
 
+    /** A generator to generate paths of cache files. */
     private final SupplierWithException<Path, IOException> pathGenerator;
 
-    private final List<Segment> finishSegments;
+    /** An optional pool that provide memory segments to hold cached records 
in memory. */
+    @Nullable private final MemorySegmentPool segmentPool;
+
+    /** The segments that contain previously added records. */
+    private final List<Segment> finishedSegments;
 
-    private SegmentWriter currentSegment;
+    /** The current writer for new records. */
+    @Nullable private SegmentWriter<T> currentSegmentWriter;
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator)
             throws IOException {
-        this(serializer, fileSystem, pathGenerator, Collections.emptyList());
+        this(serializer, fileSystem, pathGenerator, null, 
Collections.emptyList());
     }
 
     public DataCacheWriter(
             TypeSerializer<T> serializer,
             FileSystem fileSystem,
             SupplierWithException<Path, IOException> pathGenerator,
-            List<Segment> priorFinishedSegments)
+            MemorySegmentPool segmentPool)
             throws IOException {
-        this.serializer = serializer;
-        this.fileSystem = fileSystem;
-        this.pathGenerator = pathGenerator;
-
-        this.finishSegments = new ArrayList<>(priorFinishedSegments);
+        this(serializer, fileSystem, pathGenerator, segmentPool, 
Collections.emptyList());
+    }
 
-        this.currentSegment = new SegmentWriter(pathGenerator.get());
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            List<Segment> finishedSegments)
+            throws IOException {
+        this(serializer, fileSystem, pathGenerator, null, finishedSegments);
     }
 
-    public void addRecord(T record) throws IOException {
-        currentSegment.addRecord(record);
+    public DataCacheWriter(
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator,
+            @Nullable MemorySegmentPool segmentPool,
+            List<Segment> finishedSegments)

Review Comment:
   nit: Would it be better to keep the previous name `priorFinishedSegments` for this parameter?



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -113,32 +162,119 @@ public static <T> DataStream<T> reduce(DataStream<T> 
input, ReduceFunction<T> fu
             extends AbstractUdfStreamOperator<OUT, MapPartitionFunction<IN, 
OUT>>
             implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
 
-        private ListState<IN> valuesState;
+        private final TypeInformation<IN> inputType;
+
+        private Path basePath;
+
+        private StreamConfig config;

Review Comment:
   It appears that we only need `config.getOperatorID()` from this config. 
Would it be simpler to just save the `OperatorID`?



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -256,4 +392,81 @@ public void flatMap(T[] values, Collector<Tuple2<Integer, 
T[]>> collector) {
             }
         }
     }
+
+    /*
+     * A stream operator that takes a randomly sampled subset of elements in a 
bounded data stream.
+     */
+    private static class SamplingOperator<T> extends AbstractStreamOperator<T>
+            implements OneInputStreamOperator<T, T>, BoundedOneInput {
+        private final int numSamples;
+
+        private final Random random;
+
+        private ListState<T> samplesState;
+
+        private List<T> samples;
+
+        private ListState<Integer> countState;
+
+        private int count;
+
+        SamplingOperator(int numSamples, long randomSeed) {
+            this.numSamples = numSamples;
+            this.random = new Random(randomSeed);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+
+            ListStateDescriptor<T> samplesDescriptor =
+                    new ListStateDescriptor<>(
+                            "samplesState",
+                            getOperatorConfig()
+                                    .getTypeSerializerIn(0, 
getClass().getClassLoader()));
+            samplesState = 
context.getOperatorStateStore().getListState(samplesDescriptor);
+            samples = new ArrayList<>(numSamples);
+            samplesState.get().forEach(samples::add);
+
+            ListStateDescriptor<Integer> countDescriptor =
+                    new ListStateDescriptor<>("countState", 
IntSerializer.INSTANCE);
+            countState = 
context.getOperatorStateStore().getListState(countDescriptor);
+            Iterator<Integer> countIterator = countState.get().iterator();
+            if (countIterator.hasNext()) {
+                count = countIterator.next();
+            } else {
+                count = 0;
+            }
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws 
Exception {
+            super.snapshotState(context);
+            samplesState.update(samples);
+            countState.update(Collections.singletonList(count));
+        }
+
+        @Override
+        public void processElement(StreamRecord<T> streamRecord) throws 
Exception {
+            T sample = streamRecord.getValue();

Review Comment:
   nit: A "sample" is something that has been chosen from a collection. Since we do
not yet know whether this value will be chosen, would it be simpler to just name it `value`?



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -113,32 +162,119 @@ public static <T> DataStream<T> reduce(DataStream<T> 
input, ReduceFunction<T> fu
             extends AbstractUdfStreamOperator<OUT, MapPartitionFunction<IN, 
OUT>>
             implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
 
-        private ListState<IN> valuesState;
+        private final TypeInformation<IN> inputType;
+
+        private Path basePath;
+
+        private StreamConfig config;
+
+        private StreamTask<?, ?> containingTask;
 
-        public MapPartitionOperator(MapPartitionFunction<IN, OUT> 
mapPartitionFunc) {
+        private DataCacheWriter<IN> dataCacheWriter;
+
+        public MapPartitionOperator(
+                MapPartitionFunction<IN, OUT> mapPartitionFunc, 
TypeInformation<IN> inputType) {
             super(mapPartitionFunc);
+            this.inputType = inputType;
+        }
+
+        @Override
+        public void setup(
+                StreamTask<?, ?> containingTask,
+                StreamConfig config,
+                Output<StreamRecord<OUT>> output) {
+            super.setup(containingTask, config, output);
+
+            basePath =
+                    OperatorUtils.getDataCachePath(
+                            
containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(),
+                            containingTask
+                                    .getEnvironment()
+                                    .getIOManager()
+                                    .getSpillingDirectoriesPaths());
+            this.config = config;
+            this.containingTask = containingTask;
         }
 
         @Override
         public void initializeState(StateInitializationContext context) throws 
Exception {
             super.initializeState(context);
-            ListStateDescriptor<IN> descriptor =
-                    new ListStateDescriptor<>(
-                            "inputState",
-                            getOperatorConfig()
-                                    .getTypeSerializerIn(0, 
getClass().getClassLoader()));
-            valuesState = 
context.getOperatorStateStore().getListState(descriptor);
+
+            List<StatePartitionStreamProvider> inputs =
+                    
IteratorUtils.toList(context.getRawOperatorStateInputs().iterator());
+            Preconditions.checkState(
+                    inputs.size() < 2, "The input from raw operator state 
should be one or zero.");
+
+            List<Segment> priorFinishedSegments = new ArrayList<>();
+            if (inputs.size() > 0) {
+
+                InputStream inputStream = inputs.get(0).getStream();
+
+                DataCacheSnapshot dataCacheSnapshot =
+                        DataCacheSnapshot.recover(
+                                inputStream,
+                                basePath.getFileSystem(),
+                                OperatorUtils.createDataCacheFileGenerator(
+                                        basePath, "cache", 
config.getOperatorID()));
+
+                priorFinishedSegments = dataCacheSnapshot.getSegments();
+            }
+
+            dataCacheWriter =
+                    new DataCacheWriter<>(
+                            
inputType.createSerializer(containingTask.getExecutionConfig()),
+                            basePath.getFileSystem(),
+                            OperatorUtils.createDataCacheFileGenerator(
+                                    basePath, "cache", config.getOperatorID()),
+                            priorFinishedSegments);
         }
 
         @Override
-        public void endInput() throws Exception {
-            userFunction.mapPartition(valuesState.get(), new 
TimestampedCollector<>(output));
-            valuesState.clear();
+        public void snapshotState(StateSnapshotContext context) throws 
Exception {
+            super.snapshotState(context);
+
+            dataCacheWriter.writeSegmentsToFiles();
+            DataCacheSnapshot dataCacheSnapshot =
+                    new DataCacheSnapshot(
+                            basePath.getFileSystem(), null, 
dataCacheWriter.getSegments());
+            context.getRawOperatorStateOutput().startNewPartition();
+            dataCacheSnapshot.writeTo(context.getRawOperatorStateOutput());
         }
 
         @Override
         public void processElement(StreamRecord<IN> input) throws Exception {
-            valuesState.add(input.getValue());
+            dataCacheWriter.addRecord(input.getValue());
+        }
+
+        @Override
+        public void endInput() throws Exception {
+            List<Segment> pendingSegments = dataCacheWriter.getSegments();

Review Comment:
   nit: It is not clear what `pending` means in this context. Would it be
simpler to just name it `segments`?



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -113,32 +162,119 @@ public static <T> DataStream<T> reduce(DataStream<T> 
input, ReduceFunction<T> fu
             extends AbstractUdfStreamOperator<OUT, MapPartitionFunction<IN, 
OUT>>
             implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
 
-        private ListState<IN> valuesState;
+        private final TypeInformation<IN> inputType;
+
+        private Path basePath;
+
+        private StreamConfig config;
+
+        private StreamTask<?, ?> containingTask;

Review Comment:
   It appears that we only need `containingTask.getExecutionConfig()` from the containing task.
Would it be simpler to just save the `ExecutionConfig`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to