zhipeng93 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r876709362


##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,49 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.Path;
 
+import java.io.IOException;
 import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
+/** A segment contains the information about a cache unit. */
 public class Segment implements Serializable {
 
-    private final Path path;
+    /** The pre-allocated path on disk to persist the records. */
+    Path path;
 
-    /** The count of the records in the file. */
-    private final int count;
+    /** The number of records in the file. */
+    int count;
 
-    /** The total length of file. */
-    private final long size;
+    /** The size of the records in file. */
+    long fsSize;
 
-    public Segment(Path path, int count, long size) {
+    /** The size of the records in memory. */
+    transient long inMemorySize;
+
+    /** The cached records in memory. */
+    transient List<Object> cache;
+
+    /** The serializer for the records. */
+    transient TypeSerializer<Object> serializer;
+
+    Segment() {}
+
+    Segment(Path path, int count, long fsSize) {
         this.path = path;
         this.count = count;
-        this.size = size;
-    }
-
-    public Path getPath() {
-        return path;
+        this.fsSize = fsSize;
     }
 
-    public int getCount() {
-        return count;
+    boolean isOnDisk() throws IOException {

Review Comment:
   What about using `MemorySegment` and `FsSegment`?
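   
   For illustration, a rough sketch of what such a split might look like (a hedged sketch only; the common interface, the field sets, and all names other than `MemorySegment` / `FsSegment` are assumptions, not the final design):
   
   ```
   import org.apache.flink.api.common.typeutils.TypeSerializer;
   import org.apache.flink.core.fs.Path;

   import java.io.Serializable;
   import java.util.List;

   /** Hypothetical common view of a cache unit; names are illustrative. */
   interface Segment extends Serializable {
       int getCount();
   }

   /** A cache unit persisted as a file on disk. */
   class FsSegment implements Segment {
       private final Path path;   // the file holding the serialized records
       private final int count;   // the number of records in the file
       private final long fsSize; // the size of the file in bytes

       FsSegment(Path path, int count, long fsSize) {
           this.path = path;
           this.count = count;
           this.fsSize = fsSize;
       }

       @Override
       public int getCount() {
           return count;
       }
   }

   /** A cache unit held on the heap, with what is needed to spill it to disk. */
   class MemorySegment implements Segment {
       private final List<Object> cache;                // the cached records
       private final TypeSerializer<Object> serializer; // used when spilling
       private final Path path;                         // the pre-allocated spill path

       MemorySegment(List<Object> cache, TypeSerializer<Object> serializer, Path path) {
           this.cache = cache;
           this.serializer = serializer;
           this.path = path;
       }

       @Override
       public int getCount() {
           return cache.size();
       }
   }
   ```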



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -94,6 +120,26 @@ public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> fu
         }
     }
 
+    /**
+     * Takes a randomly sampled subset of elements in a bounded data stream.
+     *
+     * <p>If the number of elements in the stream is smaller than the expected number of samples, all
+     * elements will be included in the sample.
+     *
+     * @param input The input data stream.
+     * @param numSamples The number of elements to be sampled.
+     * @param randomSeed The seed to randomly pick elements as sample.
+     * @return A data stream containing a list of the sampled elements.
+     */
+    public static <T> DataStream<List<T>> sample(
+            DataStream<T> input, int numSamples, long randomSeed) {
+        return input.transform(
+                        "samplingOperator",
+                        Types.LIST(input.getType()),
+                        new SamplingOperator<>(numSamples, randomSeed))
+                .setParallelism(1);

Review Comment:
   Semantically, `sample` should probably not change the parallelism of the operator, should it? Moreover, we should probably do distributed sampling for better performance.
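   
   One way to keep it distributed is a "bottom-k" scheme: each parallel subtask tags its records with uniform random keys and keeps the k records with the smallest keys, and a final parallelism-1 step merges the partial results by again keeping the k smallest; the k survivors are then a uniform random sample of the whole input. A minimal sketch of the per-subtask logic (operator wiring omitted; all class and method names here are made up for illustration):
   
   ```
   import java.util.ArrayList;
   import java.util.Comparator;
   import java.util.List;
   import java.util.PriorityQueue;
   import java.util.Random;

   /** Keeps the k records with the smallest random keys seen so far. */
   final class BottomKSampler<T> {
       private final int k;
       private final Random random;
       // A max-heap on the random key, so the record with the largest key
       // is evicted first once the heap exceeds k elements.
       private final PriorityQueue<KeyedRecord<T>> heap;

       BottomKSampler(int k, long seed) {
           this.k = k;
           this.random = new Random(seed);
           this.heap =
                   new PriorityQueue<>(
                           Comparator.comparingDouble((KeyedRecord<T> r) -> r.key).reversed());
       }

       /** Called for each input record on a parallel subtask. */
       void add(T value) {
           addKeyed(new KeyedRecord<>(random.nextDouble(), value));
       }

       /** Called with the partial results when merging at parallelism 1. */
       void addKeyed(KeyedRecord<T> record) {
           heap.add(record);
           if (heap.size() > k) {
               heap.poll(); // evict the record with the largest key
           }
       }

       List<KeyedRecord<T>> partialResult() {
           return new ArrayList<>(heap);
       }

       static final class KeyedRecord<T> {
           final double key; // uniform random key; the k smallest keys win
           final T value;

           KeyedRecord(double key, T value) {
               this.key = key;
               this.value = value;
           }
       }
   }
   ```
   
   Each subtask would emit its `partialResult()` on end of input, and the merge step would feed every received record through `addKeyed` before stripping the keys.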



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,106 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.util.Preconditions;
 
 import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
+/**
+ * A segment contains the information about a cache unit.
+ *
+ * <p>If the unit is persisted in a file on disk, this class provides the number of records in the
+ * unit, the path to the file, and the size of the file.
+ *
+ * <p>If the unit is cached in memory, this class provides the number of records, the cached
+ * objects, and information to persist them on disk, including the pre-allocated path, and the type
+ * serializer.
+ */
+@Internal
 public class Segment implements Serializable {
 
+    /** The pre-allocated path to persist records on disk. */
     private final Path path;
 
-    /** The count of the records in the file. */
+    /** The number of records in the file. */

Review Comment:
   nit: in the file --> in the segment



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheReader.java:
##########
@@ -20,120 +20,122 @@
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.fs.FileSystem;
-import org.apache.flink.core.memory.DataInputView;
-import org.apache.flink.core.memory.DataInputViewStreamWrapper;
-
-import javax.annotation.Nullable;
+import org.apache.flink.runtime.memory.MemoryManager;
 
 import java.io.IOException;
 import java.util.Iterator;
 import java.util.List;
 
-/** Reads the cached data from a list of paths. */
+/** Reads the cached data from a list of segments. */
 public class DataCacheReader<T> implements Iterator<T> {
 
-    private final TypeSerializer<T> serializer;
+    private final MemoryManager memoryManager;
 
-    private final FileSystem fileSystem;
+    private final TypeSerializer<T> serializer;
 
     private final List<Segment> segments;
 
-    @Nullable private SegmentReader currentSegmentReader;
+    private SegmentReader<T> currentReader;
+
+    private SegmentWriter<T> cacheWriter;
+
+    private int segmentIndex;
 
     public DataCacheReader(
-            TypeSerializer<T> serializer, FileSystem fileSystem, List<Segment> segments)
-            throws IOException {
-        this(serializer, fileSystem, segments, new Tuple2<>(0, 0));
+            TypeSerializer<T> serializer, MemoryManager memoryManager, List<Segment> segments) {
+        this(serializer, memoryManager, segments, new Tuple2<>(0, 0));
     }
 
     public DataCacheReader(
             TypeSerializer<T> serializer,
-            FileSystem fileSystem,
+            MemoryManager memoryManager,
             List<Segment> segments,
-            Tuple2<Integer, Integer> readerPosition)
-            throws IOException {
-
+            Tuple2<Integer, Integer> readerPosition) {
+        this.memoryManager = memoryManager;
         this.serializer = serializer;
-        this.fileSystem = fileSystem;
         this.segments = segments;
+        this.segmentIndex = readerPosition.f0;
+
+        createSegmentReaderAndCache(readerPosition.f0, readerPosition.f1);
+    }
+
+    private void createSegmentReaderAndCache(int index, int startOffset) {
+        try {
+            cacheWriter = null;
 
-        if (readerPosition.f0 < segments.size()) {
-            this.currentSegmentReader = new SegmentReader(readerPosition.f0, readerPosition.f1);
+            if (index >= segments.size()) {
+                currentReader = null;
+                return;
+            }
+
+            currentReader = SegmentReader.create(serializer, segments.get(index), startOffset);
+
+            boolean shouldCacheInMemory =
+                    startOffset == 0
+                            && currentReader instanceof FsSegmentReader
+                            && MemoryUtils.isMemoryEnoughForCache(memoryManager);
+
+            if (shouldCacheInMemory) {
+                cacheWriter =
+                        SegmentWriter.create(
+                                segments.get(index).getPath(),
+                                memoryManager,
+                                serializer,
+                                segments.get(index).getFsSize(),
+                                true,
+                                false);
+            }
+
+        } catch (IOException e) {
+            throw new RuntimeException(e);
         }
     }
 
     @Override
     public boolean hasNext() {
-        return currentSegmentReader != null && currentSegmentReader.hasNext();
+        return currentReader != null && currentReader.hasNext();
     }
 
     @Override
     public T next() {
         try {
-            T next = currentSegmentReader.next();
-
-            if (!currentSegmentReader.hasNext()) {
-                currentSegmentReader.close();
-                if (currentSegmentReader.index < segments.size() - 1) {
-                    currentSegmentReader = new SegmentReader(currentSegmentReader.index + 1, 0);
-                } else {
-                    currentSegmentReader = null;
+            T record = currentReader.next();
+
+            if (cacheWriter != null) {
+                if (!cacheWriter.addRecord(record)) {

Review Comment:
   I am a bit confused about adding a `cacheWriter` here. Could you explain this a bit?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/FsSegmentWriter.java:
##########
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+import java.util.Optional;
+
+/** A class that writes cache data to file system. */
+@Internal
+public class FsSegmentWriter<T> implements SegmentWriter<T> {
+    private final FileSystem fileSystem;
+
+    // TODO: adjust the file size limit automatically according to the provided file system.
+    private static final int CACHE_FILE_SIZE_LIMIT = 100 * 1024 * 1024; // 100MB
+
+    private final TypeSerializer<T> serializer;
+
+    private final Path path;
+
+    private final FSDataOutputStream outputStream;
+
+    private final ByteArrayOutputStream byteArrayOutputStream;
+
+    private final ObjectOutputStream objectOutputStream;
+
+    private final DataOutputView outputView;
+
+    private int count;
+
+    public FsSegmentWriter(TypeSerializer<T> serializer, Path path) throws IOException {
+        this.serializer = serializer;
+        this.path = path;
+        this.fileSystem = path.getFileSystem();
+        this.outputStream = fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE);
+        this.byteArrayOutputStream = new ByteArrayOutputStream();
+        this.objectOutputStream = new ObjectOutputStream(outputStream);

Review Comment:
   Using `BufferedOutputStream` to wrap the `FSDataOutputStream` can improve the performance here, without using the `objectOutputStream`. The intuition is that `objectOutputStream` uses a buffer for each record to reduce the number of calls to `DataOutputStream#write`, while `BufferedOutputStream` uses one buffer for (possibly) many records.
   
   A code example could be:
   
   ```
       void test1(final int numTries) throws IOException {
           Path path = new Path("/tmp/result1");
           FSDataOutputStream outputStream = path.getFileSystem().create(path, FileSystem.WriteMode.OVERWRITE);
           TypeSerializer<DenseVector> serializer = DenseVectorSerializer.INSTANCE;
           DenseVector record = Vectors.dense(new double[100]);

           // add the following line compared with the initial implementation.
           BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(outputStream);

           DataOutputView outputView = new DataOutputViewStreamWrapper(bufferedOutputStream);
           for (int i = 0; i < numTries; i++) {
               serializer.serialize(record, outputView);
           }
           bufferedOutputStream.flush();
       }

       void test2(final int numTries) throws IOException {
           Path path = new Path("/tmp/result2");
           FSDataOutputStream outputStream = path.getFileSystem().create(path, FileSystem.WriteMode.OVERWRITE);
           TypeSerializer<DenseVector> serializer = DenseVectorSerializer.INSTANCE;
           DenseVector record = Vectors.dense(new double[100]);

           ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
           DataOutputView outputView = new DataOutputViewStreamWrapper(byteArrayOutputStream);
           for (int i = 0; i < numTries; i++) {
               serializer.serialize(record, outputView);
               byte[] bytes = byteArrayOutputStream.toByteArray();
               ObjectOutputStream objectOutputStream = new ObjectOutputStream(outputStream);
               objectOutputStream.writeObject(bytes);
               byteArrayOutputStream.reset();
           }
       }

       @Test
       public void test() throws IOException {
           int numTries = 1000000;
           long time = System.currentTimeMillis();
           test1(numTries);
           System.out.println("Option-1: " + (System.currentTimeMillis() - time));

           time = System.currentTimeMillis();
           test2(numTries);
           System.out.println("Option-2: " + (System.currentTimeMillis() - time));
       }
   ```
   
   The result turns out to be:
   ```
   Option-1: 3005
   Option-2: 14898
   ```
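   
   Applied to `FsSegmentWriter`, the change would roughly amount to the following (a sketch only, assuming per-record framing is not needed; the class name is made up):
   
   ```
   import org.apache.flink.api.common.typeutils.TypeSerializer;
   import org.apache.flink.core.fs.FSDataOutputStream;
   import org.apache.flink.core.fs.FileSystem;
   import org.apache.flink.core.fs.Path;
   import org.apache.flink.core.memory.DataOutputView;
   import org.apache.flink.core.memory.DataOutputViewStreamWrapper;

   import java.io.BufferedOutputStream;
   import java.io.IOException;

   /** Sketch of a writer that shares one buffer across all records. */
   class BufferedFsSegmentWriter<T> {
       private final TypeSerializer<T> serializer;
       private final FSDataOutputStream outputStream;
       private final BufferedOutputStream bufferedStream;
       private final DataOutputView outputView;
       private int count;

       BufferedFsSegmentWriter(TypeSerializer<T> serializer, Path path) throws IOException {
           this.serializer = serializer;
           this.outputStream = path.getFileSystem().create(path, FileSystem.WriteMode.NO_OVERWRITE);
           // One buffer amortized over many records, instead of an
           // ObjectOutputStream header and byte[] copy per record.
           this.bufferedStream = new BufferedOutputStream(outputStream);
           this.outputView = new DataOutputViewStreamWrapper(bufferedStream);
       }

       void addRecord(T record) throws IOException {
           serializer.serialize(record, outputView);
           count++;
       }

       /** Flushes the buffer and returns the final file size in bytes. */
       long finish() throws IOException {
           bufferedStream.flush();
           long size = outputStream.getPos();
           outputStream.close();
           return size;
       }
   }
   ```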
   
   



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,106 @@
 
 package org.apache.flink.iteration.datacache.nonkeyed;
 
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.util.Preconditions;
 
 import java.io.Serializable;
+import java.util.List;
 import java.util.Objects;
 
-/** A segment represents a single file for the cache. */
+/**
+ * A segment contains the information about a cache unit.
+ *
+ * <p>If the unit is persisted in a file on disk, this class provides the number of records in the
+ * unit, the path to the file, and the size of the file.
+ *
+ * <p>If the unit is cached in memory, this class provides the number of records, the cached
+ * objects, and information to persist them on disk, including the pre-allocated path, and the type
+ * serializer.
+ */
+@Internal

Review Comment:
   Would using `MemorySegment` and `FsSegment` accordingly be clearer?



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.memory.MemoryReservationException;
+import org.apache.flink.util.Preconditions;
+
+import org.openjdk.jol.info.GraphLayout;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final Segment segment;
+
+    private final MemoryManager memoryManager;
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager)
+            throws MemoryAllocationException {
+        this(path, memoryManager, 0L);
+    }
+
+    public MemorySegmentWriter(Path path, MemoryManager memoryManager, long expectedSize)
+            throws MemoryAllocationException {
+        Preconditions.checkNotNull(memoryManager);
+        this.segment = new Segment();
+        this.segment.path = path;
+        this.segment.cache = new ArrayList<>();
+        this.segment.inMemorySize = 0L;
+        this.memoryManager = memoryManager;
+    }
+
+    @Override
+    public boolean addRecord(T record) throws IOException {
+        if (!MemoryUtils.isMemoryEnoughForCache(memoryManager)) {
+            return false;
+        }
+
+        long recordSize = GraphLayout.parseInstance(record).totalSize();
+
+        try {
+            memoryManager.reserveMemory(this, recordSize);
+        } catch (MemoryReservationException e) {
+            return false;
+        }
+
+        this.segment.cache.add(record);
+        segment.inMemorySize += recordSize;
+
+        this.segment.count++;
+        return true;
+    }
+
+    @Override
+    public Optional<Segment> finish() throws IOException {
+        if (segment.count > 0) {
+            return Optional.of(segment);
+        } else {
+            memoryManager.releaseMemory(segment.path, segment.inMemorySize);

Review Comment:
   Is `releaseMemory` needed here? It seems that `inMemorySize` should always be zero here.



##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/MemorySegmentWriter.java:
##########
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.memory.MemoryReservationException;
+import org.apache.flink.util.Preconditions;
+
+import org.openjdk.jol.info.GraphLayout;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** A class that writes cache data to memory segments. */
+@Internal
+public class MemorySegmentWriter<T> implements SegmentWriter<T> {
+    private final MemoryManager memoryManager;
+
+    private final Path path;
+
+    private final List<T> cache;
+
+    private final TypeSerializer<T> serializer;
+
+    private long inMemorySize;
+
+    private int count;
+
+    private long reservedMemorySize;
+
+    public MemorySegmentWriter(
+            Path path, MemoryManager memoryManager, TypeSerializer<T> serializer, long expectedSize)
+            throws MemoryReservationException {
+        this.serializer = serializer;
+        Preconditions.checkNotNull(memoryManager);
+        this.path = path;
+        this.cache = new ArrayList<>();
+        this.inMemorySize = 0L;
+        this.count = 0;
+        this.memoryManager = memoryManager;
+
+        if (expectedSize > 0) {
+            memoryManager.reserveMemory(this.path, expectedSize);
+        }
+        this.reservedMemorySize = expectedSize;
+    }
+
+    @Override
+    public boolean addRecord(T record) {

Review Comment:
   The way of using `MemoryManager` seems inappropriate to me after digging into its usage. [1][2]
   
   The code snippet here seems to be caching the records on the Java heap, but trying to reserve memory from off-heap memory. If I am understanding [1][2] correctly:
   - When using `MemoryManager` to manipulate managed memory, we are mostly dealing with off-heap memory.
   - The managed memory for each operator should be fixed once the job graph is generated, i.e., it is not dynamically allocated.
   - The usage of managed memory should be declared to the job graph explicitly and then be used by the operator. Otherwise it may lead to OOM when deployed in a container.
   
   As I see it, there are basically two options for caching the data:
   - Cache it in the task heap (i.e., cache it in a `list`): It is simple and easy to implement, but the downside is that we cannot statically control the size of the cached elements, and the program may not be robust --- the task heap is shared across the JVM and we have no idea how others are using the JVM heap memory. Moreover, we need to write the `list` to state for recovery.
   - Cache it off-heap (for example using managed memory). In this way, we need to declare the usage of the managed memory to the job graph via `Transformation#declareManagedMemoryUseCaseAtOperatorScope` or `Transformation#declareManagedMemoryUseCaseAtSlotScope` and get the fraction of the managed memory as in [3].
   
   I would suggest going with option 2, but this needs more discussion with the runtime folks.
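   
   For reference, a rough sketch of what the option-2 wiring involves, following the `SortOperator` pattern in [3] (the weight is an arbitrary example value, and the exact `StreamConfig` signature may differ between Flink versions):
   
   ```
   // 1) When building the job graph, declare the use case on the transformation:
   transformation.declareManagedMemoryUseCaseAtOperatorScope(ManagedMemoryUseCase.OPERATOR, 100);

   // 2) Inside the operator's open(), compute this operator's budget from the
   //    fraction of managed memory assigned to the declared use case:
   Environment env = getContainingTask().getEnvironment();
   double fraction =
           getOperatorConfig()
                   .getManagedMemoryFractionOperatorUseCaseOfSlot(
                           ManagedMemoryUseCase.OPERATOR,
                           env.getTaskManagerInfo().getConfiguration(),
                           env.getUserCodeClassLoader().asClassLoader());
   long memoryBudgetBytes = env.getMemoryManager().computeMemorySize(fraction);
   ```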
   
   [1] 
https://cwiki.apache.org/confluence/display/FLINK/FLIP-53%3A+Fine+Grained+Operator+Resource+Management
   [2] 
https://cwiki.apache.org/confluence/display/FLINK/FLIP-141%3A+Intra-Slot+Managed+Memory+Sharing
   [3] 
https://github.com/apache/flink/blob/18a967f8ad7b22c2942e227fb84f08f552660b5a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/sort/SortOperator.java#L79



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]