lindong28 commented on code in PR #97:
URL: https://github.com/apache/flink-ml/pull/97#discussion_r891014714
##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,80 @@
package org.apache.flink.iteration.datacache.nonkeyed;
+import org.apache.flink.annotation.Internal;
import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.MemorySegment;
-import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Objects;
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+/** A segment contains the information about a cache unit. */
+@Internal
+public class Segment {
+
+ /** The path to the file containing persisted records. */
private final Path path;
- /** The count of the records in the file. */
+ /**
+ * The count of records in the file at the path if the file size is not zero,
+ * otherwise the count of records in the cache.
+ */
private final int count;
- /** The total length of file. */
- private final long size;
+ /**
+ * The total length of the file containing persisted records. Its value is 0 iff
+ * the segment has not been written to the given path.
+ */
+ private long fsSize = 0L;
+
+ /**
+ * The memory segments containing cached records. This list is empty iff the
+ * segment has not been cached in memory.
+ */
+ private List<MemorySegment> cache = new ArrayList<>();
+
+ Segment(Path path, int count, long fsSize) {
+ this.path = checkNotNull(path);
+ checkArgument(count > 0);
+ this.count = count;
+ checkArgument(fsSize > 0);
+ this.fsSize = fsSize;
+ }
- public Segment(Path path, int count, long size) {
- this.path = path;
+ Segment(Path path, int count, List<MemorySegment> cache) {
+ this.path = checkNotNull(path);
+ checkArgument(count > 0);
this.count = count;
- this.size = size;
+ this.cache = checkNotNull(cache);
+ }
+
+ void setCache(List<MemorySegment> cache) {
+ this.cache = checkNotNull(cache);
Review Comment:
nits: we typically don't explicitly check whether the input argument is null
in such cases. Could you update the code for consistency and simplicity?
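For example, something like this (a sketch of the suggested style; it assumes no caller depends on the eager null check):

```java
// Hypothetical rewrite per the suggestion above: rely on callers to pass a
// non-null list instead of checking eagerly.
void setCache(List<MemorySegment> cache) {
    this.cache = cache;
}
```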
##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheSnapshot.java:
##########
@@ -167,26 +183,69 @@ public static DataCacheSnapshot recover(
if (isDistributedFS) {
segments = deserializeSegments(dis);
} else {
- int totalRecords = dis.readInt();
- long totalSize = dis.readLong();
-
- Path path = pathGenerator.get();
- try (FSDataOutputStream outputStream =
- fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE)) {
-
- BoundedInputStream inputStream =
- new BoundedInputStream(checkpointInputStream, totalSize);
- inputStream.setPropagateClose(false);
- IOUtils.copyBytes(inputStream, outputStream, false);
- inputStream.close();
+ int segmentNum = dis.readInt();
+ segments = new ArrayList<>(segmentNum);
+ for (int i = 0; i < segmentNum; i++) {
+ int count = dis.readInt();
+ long fsSize = dis.readLong();
+ Path path = pathGenerator.get();
+ try (FSDataOutputStream outputStream =
Review Comment:
Prior to this PR, when we recovered from a snapshot of multiple smaller
segments, we might merge these segments into one segment. We no longer do this
after this PR. Could you double-check its performance impact with @gaoyunhaii?
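For reference, a merge-on-recovery path could look roughly like this (a sketch, not a request for this exact code; it assumes each segment's bytes follow its count/fsSize header in the stream, and it reuses the variables from the hunk above plus `java.util.Collections`):

```java
// Sketch: copy every recovered segment's bytes into one file and return a
// single merged Segment, as the pre-PR code effectively did.
int segmentNum = dis.readInt();
int totalCount = 0;
long totalSize = 0L;
Path path = pathGenerator.get();
try (FSDataOutputStream outputStream =
        fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE)) {
    for (int i = 0; i < segmentNum; i++) {
        totalCount += dis.readInt();
        long fsSize = dis.readLong();
        totalSize += fsSize;

        // Copy exactly fsSize bytes of this segment into the merged file.
        BoundedInputStream inputStream =
                new BoundedInputStream(checkpointInputStream, fsSize);
        inputStream.setPropagateClose(false);
        IOUtils.copyBytes(inputStream, outputStream, false);
        inputStream.close();
    }
}
segments = Collections.singletonList(new Segment(path, totalCount, totalSize));
```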
##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java:
##########
@@ -18,38 +18,80 @@
package org.apache.flink.iteration.datacache.nonkeyed;
+import org.apache.flink.annotation.Internal;
import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.MemorySegment;
-import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Objects;
-/** A segment represents a single file for the cache. */
-public class Segment implements Serializable {
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+/** A segment contains the information about a cache unit. */
+@Internal
+public class Segment {
+
+ /** The path to the file containing persisted records. */
private final Path path;
- /** The count of the records in the file. */
+ /**
+ * The count of records in the file at the path if the file size is not zero,
+ * otherwise the count of records in the cache.
+ */
private final int count;
- /** The total length of file. */
- private final long size;
+ /**
+ * The total length of the file containing persisted records. Its value is 0 iff
+ * the segment has not been written to the given path.
+ */
+ private long fsSize = 0L;
+
+ /**
+ * The memory segments containing cached records. This list is empty iff the
+ * segment has not been cached in memory.
+ */
+ private List<MemorySegment> cache = new ArrayList<>();
+
+ Segment(Path path, int count, long fsSize) {
+ this.path = checkNotNull(path);
Review Comment:
nits: the code probably looks nicer if we assign all variables before
checking their values.
Same for the other constructors.
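For example (a sketch that only reorders the existing statements; it also drops the checkNotNull per the earlier comment):

```java
// Hypothetical reordering per the suggestion above: assign all fields first,
// then validate their values.
Segment(Path path, int count, long fsSize) {
    this.path = path;
    this.count = count;
    this.fsSize = fsSize;
    checkArgument(count > 0);
    checkArgument(fsSize > 0);
}
```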
##########
flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/ListStateWithCache.java:
##########
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.ManagedMemoryUseCase;
+import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.table.runtime.util.LazyMemorySegmentPool;
+import org.apache.flink.table.runtime.util.MemorySegmentPool;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * A {@link ListState} subclass that records data and replays it when required.
+ *
+ * <p>This class basically stores data in the file system, and provides the option to
+ * cache it in memory. In order to use the memory caching option, users need to allocate
+ * certain managed memory for the wrapper operator through {@link
+ * org.apache.flink.api.dag.Transformation#declareManagedMemoryUseCaseAtOperatorScope}.
+ *
+ * <p>NOTE: Users need to explicitly invoke this class's {@link
+ * #snapshotState(StateSnapshotContext)} method in order to store the recorded data in
+ * snapshots.
+ */
+public class ListStateWithCache<T> implements ListState<T> {
+
+ /** The tool to serialize/deserialize records. */
+ private final TypeSerializer<T> serializer;
+
+ /** The path of the directory that holds the files containing recorded data. */
+ private final Path basePath;
+
+ /** The data cache writer for the received records. */
+ private final DataCacheWriter<T> dataCacheWriter;
+
+ @SuppressWarnings("unchecked")
+ public ListStateWithCache(
+ TypeSerializer<T> serializer,
+ StreamTask<?, ?> containingTask,
+ StreamingRuntimeContext runtimeContext,
+ StateInitializationContext stateInitializationContext,
+ OperatorID operatorID)
+ throws IOException {
+ this.serializer = serializer;
+
+ MemorySegmentPool segmentPool = null;
+ double fraction =
+ containingTask
+ .getConfiguration()
+ .getManagedMemoryFractionOperatorUseCaseOfSlot(
+ ManagedMemoryUseCase.OPERATOR,
+ runtimeContext.getTaskManagerRuntimeInfo().getConfiguration(),
+ runtimeContext.getUserCodeClassLoader());
+ if (fraction > 0) {
+ MemoryManager memoryManager = containingTask.getEnvironment().getMemoryManager();
+ segmentPool =
+ new LazyMemorySegmentPool(
+ containingTask,
+ memoryManager,
+ memoryManager.computeNumberOfPages(fraction));
+ }
+
+ basePath =
+ OperatorUtils.getDataCachePath(
+ containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(),
+ containingTask
+ .getEnvironment()
+ .getIOManager()
+ .getSpillingDirectoriesPaths());
+
+ List<StatePartitionStreamProvider> inputs =
+ IteratorUtils.toList(
+ stateInitializationContext.getRawOperatorStateInputs().iterator());
+ Preconditions.checkState(
+ inputs.size() < 2, "The input from raw operator state should be one or zero.");
+
+ List<Segment> priorFinishedSegments = new ArrayList<>();
+ if (inputs.size() > 0) {
+ DataCacheSnapshot dataCacheSnapshot =
+ DataCacheSnapshot.recover(
+ inputs.get(0).getStream(),
+ basePath.getFileSystem(),
+ OperatorUtils.createDataCacheFileGenerator(
+ basePath, "cache", operatorID));
+
+ if (segmentPool != null) {
+ dataCacheSnapshot.tryReadSegmentsToMemory(serializer, segmentPool);
+ }
+
+ priorFinishedSegments = dataCacheSnapshot.getSegments();
+ }
+
+ this.dataCacheWriter =
+ new DataCacheWriter<>(
+ serializer,
+ basePath.getFileSystem(),
+ OperatorUtils.createDataCacheFileGenerator(basePath, "cache", operatorID),
+ segmentPool,
+ priorFinishedSegments);
+ }
+
+ public void snapshotState(StateSnapshotContext context) throws Exception {
Review Comment:
Since snapshotState() and add() reuse the same serializer, and the serializer is
not thread-safe, could you double-check that snapshotState() and add() won't be
invoked concurrently?
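If concurrent invocation turns out to be possible, one defensive option (a sketch, not necessarily needed here) is to give the snapshot path its own serializer instance via TypeSerializer#duplicate(), which exists precisely because serializers may be stateful and not thread-safe:

```java
// Hypothetical setup in the constructor: keep two serializer instances so
// add() and snapshotState() never share one. duplicate() returns `this` for
// stateless serializers, so the copy is free when it is not needed.
this.addSerializer = serializer;                  // hypothetical field used by add()
this.snapshotSerializer = serializer.duplicate(); // hypothetical field used by snapshotState()
```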
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java:
##########
@@ -253,59 +258,76 @@ public Tuple3<Integer, DenseVector, Long> map(Tuple2<Integer, DenseVector> value
implements TwoInputStreamOperator<
DenseVector, DenseVector[], Tuple2<Integer, DenseVector>>,
IterationListener<Tuple2<Integer, DenseVector>> {
+
private final DistanceMeasure distanceMeasure;
- private ListState<DenseVector> points;
- private ListState<DenseVector[]> centroids;
+
+ private ListState<DenseVector[]> centroidsState;
Review Comment:
nits: could we keep the original name `centroids` for simplicity? Same for
`points`.
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java:
##########
@@ -160,6 +162,9 @@ public IterationBodyResult process(
BasicTypeInfo.INT_TYPE_INFO,
DenseVectorTypeInfo.INSTANCE),
new SelectNearestCentroidOperator(distanceMeasure));
+ centroidIdAndPoints
+ .getTransformation()
+ .declareManagedMemoryUseCaseAtOperatorScope(ManagedMemoryUseCase.OPERATOR, 64);
Review Comment:
Instead of explicitly calling this API, could this be specified in the Flink job
configuration, according to this documentation:
https://nightlies.apache.org/flink/flink-docs-master/docs/deployment/memory/mem_setup_tm/#consumer-weights
?
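For reference, the linked section splits managed memory between consumer kinds by weight, so a flink-conf.yaml fragment like the following might work (the key is from the linked docs; the weights here are made up for illustration):

```yaml
# Managed memory is divided among consumer kinds proportionally to these
# weights; OPERATOR covers ManagedMemoryUseCase.OPERATOR consumers such as
# the data cache here.
taskmanager.memory.managed.consumer-weights: OPERATOR:70,STATE_BACKEND:70,PYTHON:30
```

That said, if the docs' requirement that operators still declare the use case applies here, the configuration would tune the weights rather than fully replace the declareManagedMemoryUseCaseAtOperatorScope call.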