guoweiM commented on a change in pull request #25:
URL: https://github.com/apache/flink-ml/pull/25#discussion_r746333683



##########
File path: flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/ReplayOperatorTest.java
##########
@@ -50,12 +69,206 @@ public void testReplaying() throws Exception {
         final int numRecords = 10;
         OperatorID operatorId = new OperatorID();
 
+        createHarnessAndRun(
+                operatorId,
+                null,
+                harness -> {
+                    // First round
+                    for (int i = 0; i < numRecords; ++i) {
+                        harness.processElement(
+                                new StreamRecord<>(IterationRecord.newRecord(i, 0)), 0, 0);
+                    }
+                    harness.endInput(0, true);
+                    harness.processElement(
+                            new StreamRecord<>(IterationRecord.newEpochWatermark(0, "sender1")),
+                            1,
+                            0);
+                    assertOutputAllRecordsAndEpochWatermark(
+                            harness.getOutput(), numRecords, operatorId, 0);
+                    harness.getOutput().clear();
+
+                    // The round 1
+                    harness.processElement(
+                            new StreamRecord<>(IterationRecord.newEpochWatermark(1, "sender1")),
+                            1,
+                            0);
+                    // The output would be done asynchronously inside the ReplayerOperator.
+                    while (harness.getOutput().size() < numRecords + 1) {
+                        Thread.sleep(500);
+                    }
+                    assertOutputAllRecordsAndEpochWatermark(
+                            harness.getOutput(), numRecords, operatorId, 1);
+                    harness.getOutput().clear();
+
+                    // The round 2
+                    harness.processElement(
+                            new StreamRecord<>(IterationRecord.newEpochWatermark(2, "sender1")),
+                            1,
+                            0);
+                    // The output would be done asynchronously inside the ReplayerOperator.
+                    while (harness.getOutput().size() < numRecords + 1) {
+                        Thread.sleep(500);
+                    }
+                    assertOutputAllRecordsAndEpochWatermark(
+                            harness.getOutput(), numRecords, operatorId, 2);
+                    return null;
+                });
+    }
+
+    @Test
+    public void testSnapshotAndRestoreOnFirstEpoch() throws Exception {
+        final int numRecords = 10;
+        OperatorID operatorId = new OperatorID();
+
+        List<Object> firstRoundOutput = new ArrayList<>();
+        List<Object> secondRoundOutput = new ArrayList<>();
+
+        TaskStateSnapshot snapshot =
+                createHarnessAndRun(
+                        operatorId,
+                        null,
+                        harness -> {
+                            harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+                            for (int i = 0; i < numRecords / 2; ++i) {
+                                harness.processElement(
+                                        new StreamRecord<>(IterationRecord.newRecord(i, 0)), 0, 0);
+                            }
+
+                            harness.getStreamTask()
+                                    .triggerCheckpointAsync(
+                                            new CheckpointMetaData(2, 1000),
+                                            CheckpointOptions.alignedNoTimeout(
+                                                    CheckpointType.CHECKPOINT,
+                                                    CheckpointStorageLocationReference
+                                                            .getDefault()));
+                            harness.processAll();
+
+                            firstRoundOutput.addAll(harness.getOutput());
+
+                            harness.getTaskStateManager().getWaitForReportLatch().await();
+                            return harness.getTaskStateManager()
+                                    .getLastJobManagerTaskStateSnapshot();
+                        });
+
+        createHarnessAndRun(
+                operatorId,
+                snapshot,
+                harness -> {
+                    for (int i = numRecords / 2; i < numRecords; ++i) {
+                        harness.processElement(
+                                new StreamRecord<>(IterationRecord.newRecord(i, 0)), 0, 0);
+                    }
+                    harness.endInput(0, true);
+                    harness.processElement(
+                            new StreamRecord<>(IterationRecord.newEpochWatermark(0, "send-0")),
+                            1,
+                            0);
+                    harness.processAll();
+                    firstRoundOutput.addAll(harness.getOutput());
+
+                    // The second round
+                    harness.getOutput().clear();
+                    harness.processElement(
+                            new StreamRecord<>(IterationRecord.newEpochWatermark(1, "send-0")),
+                            1,
+                            0);
+                    secondRoundOutput.addAll(harness.getOutput());
+
+                    return null;
+                });
+
+        assertOutputAllRecordsAndEpochWatermark(firstRoundOutput, numRecords, operatorId, 0);
+        assertOutputAllRecordsAndEpochWatermark(secondRoundOutput, numRecords, operatorId, 1);
+    }
+
+    @Test
+    public void testSnapshotAndRestoreOnSecondEpoch() throws Exception {
+        final int numRecords = 20;

Review comment:
       Could you enlighten me a little on why this test case sets `numRecords` to 20 while the other tests set it to 10?

##########
File path: flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/ReplayOperatorTest.java
##########
@@ -50,12 +69,206 @@ public void testReplaying() throws Exception {
         final int numRecords = 10;
         OperatorID operatorId = new OperatorID();
 
+        createHarnessAndRun(

Review comment:
       Could the test class extend `TestLogger`?
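
       For illustration, a minimal sketch of what the suggestion amounts to (the class body is elided; `TestLogger` is already imported elsewhere in this PR):

       ```java
       import org.apache.flink.util.TestLogger;

       /** Tests the behavior of the replay operator. */
       public class ReplayOperatorTest extends TestLogger {
           // ... existing tests such as testReplaying() ...
       }
       ```

       Extending `TestLogger` is the usual convention for Flink test classes; it logs each test's name and failures, which helps when diagnosing CI runs.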

##########
File path: flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
##########
@@ -100,6 +110,15 @@
 
     private transient KeySelector<?, ?> stateKeySelector2;
 
+    // --------------- state ---------------------------
+    private int latestEpochWatermark = -1;

Review comment:
       Is `latestEpochWatermark` not state? It sits under the state section, but it is a plain field rather than a managed state.
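
       If the intent is for this field to survive checkpoints, a hedged sketch of persisting it through operator list state (the descriptor name and this snapshot wiring are assumptions for illustration, not the PR's actual code):

       ```java
       // Assumed imports: java.util.Collections,
       // org.apache.flink.api.common.state.ListState,
       // org.apache.flink.runtime.state.StateSnapshotContext.
       private int latestEpochWatermark = -1;

       private ListState<Integer> latestEpochWatermarkState;

       @Override
       public void snapshotState(StateSnapshotContext context) throws Exception {
           super.snapshotState(context);
           // Overwrite the previously stored value with the current one.
           latestEpochWatermarkState.update(Collections.singletonList(latestEpochWatermark));
       }
       ```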

##########
File path: flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/ReplayOperator.java
##########
@@ -60,11 +73,19 @@
 
     private TypeSerializer<T> typeSerializer;
 
-    private Executor dataReplayerExecutor;
+    private MailboxExecutor mailboxExecutor;
 
     private DataCacheWriter<T> dataCacheWriter;
 
-    private AtomicReference<DataCacheReader<T>> currentDataCacheReader;
+    @Nullable private DataCacheReader<T> currentDataCacheReader;
+
+    private int currentEpoch;
+
+    // ------------- states -------------------
+
+    private ListState<Integer> parallelismState;

Review comment:
       Please add some comments for this state variable.
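
       For illustration, the kind of comment being asked for might read like this (the Javadoc wording is an assumption; the union-list-state usage and the parallelism check it describes appear in the `initializeState` hunk below):

       ```java
       // ------------- states -------------------

       /**
        * The parallelism recorded at checkpoint time, kept as a union list state so
        * that a parallelism change can be detected and rejected on restore.
        */
       private ListState<Integer> parallelismState;
       ```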

##########
File path: flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/ReplayOperator.java
##########
@@ -88,33 +109,108 @@ public void setup(
                     (IterationRecordSerializer)
                             config.getTypeSerializerOut(getClass().getClassLoader());
             typeSerializer = iterationRecordSerializer.getInnerSerializer();
-            dataReplayerExecutor =
-                    Executors.newSingleThreadExecutor(
-                            runnable -> {
-                                Thread thread = new Thread(runnable);
-                                thread.setName(
-                                        "Replay-"
-                                                + getOperatorID()
-                                                + "-"
-                                                + containingTask.getIndexInSubtaskGroup());
-                                return thread;
-                            });
+
+            mailboxExecutor =
+                    containingTask
+                            .getMailboxExecutorFactory()
+                            .createExecutor(TaskMailbox.MIN_PRIORITY);
+        } catch (Exception e) {
+            ExceptionUtils.rethrow(e);
+        }
+    }
+
+    @Override
+    public void initializeState(StateInitializationContext context) throws Exception {
+        super.initializeState(context);
+
+        parallelismState =
+                context.getOperatorStateStore()
+                        .getUnionListState(
+                                new ListStateDescriptor<>("parallelism", IntSerializer.INSTANCE));
+        OperatorStateUtils.getUniqueElement(parallelismState, "parallelism")
+                .ifPresent(
+                        oldParallelism ->
+                                checkState(
+                                        oldParallelism
+                                                == getRuntimeContext()
+                                                        .getNumberOfParallelSubtasks(),
+                                        "The Replay operator is recovered with parallelism changed from "
+                                                + oldParallelism
+                                                + " to "
+                                                + getRuntimeContext()
+                                                        .getNumberOfParallelSubtasks()));
+
+        currentEpochState =
+                context.getOperatorStateStore()
+                        .getListState(
+                                new ListStateDescriptor<Integer>("epoch", IntSerializer.INSTANCE));
+        OperatorStateUtils.getUniqueElement(currentEpochState, "epoch")
+                .ifPresent(epoch -> currentEpoch = epoch);
+
+        try {
+            SupplierWithException<Path, IOException> pathGenerator =
+                    OperatorUtils.createDataCacheFileGenerator(
+                            basePath, "replay", config.getOperatorID());
+
+            DataCacheSnapshot dataCacheSnapshot = null;
+            List<StatePartitionStreamProvider> rawStateInputs =
+                    IteratorUtils.toList(context.getRawOperatorStateInputs().iterator());
+            if (rawStateInputs.size() > 0) {
+                checkState(
+                        rawStateInputs.size() == 1,
+                        "Currently the replay operator does not support rescaling");
+
+                dataCacheSnapshot =
+                        DataCacheSnapshot.recover(
+                                rawStateInputs.get(0).getStream(), fileSystem, pathGenerator);
+            }
+
             dataCacheWriter =
                     new DataCacheWriter<>(
                             typeSerializer,
                             fileSystem,
-                            OperatorUtils.createDataCacheFileGenerator(
-                                    basePath, "replay", config.getOperatorID()));
+                            pathGenerator,
+                            dataCacheSnapshot == null
+                                    ? Collections.emptyList()
+                                    : dataCacheSnapshot.getSegments());
+
+            if (dataCacheSnapshot != null && dataCacheSnapshot.getReaderPosition() != null) {
+                currentDataCacheReader =
+                        new DataCacheReader<>(
+                                typeSerializer,
+                                fileSystem,
+                                dataCacheSnapshot.getSegments(),
+                                dataCacheSnapshot.getReaderPosition());
+            }
 
-            currentDataCacheReader = new AtomicReference<>();
         } catch (Exception e) {
-            ExceptionUtils.rethrow(e);
+            throw new FlinkRuntimeException("Failed to replay the records", e);
         }
     }
 
     @Override
-    public void initializeState(StateInitializationContext context) throws Exception {
-        super.initializeState(context);
+    public void snapshotState(StateSnapshotContext context) throws Exception {
+        super.snapshotState(context);
+
+        // Always clear the union list state before set value.

Review comment:
       clear -> clears?

##########
File path: flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
##########
@@ -100,6 +110,15 @@
 
     private transient KeySelector<?, ?> stateKeySelector2;
 
+    // --------------- state ---------------------------
+    private int latestEpochWatermark = -1;
+
+    private ListState<Integer> parallelismState;

Review comment:
       Would you like to add some comments for the following state variables?
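
       As a sketch of what such comments could say (the descriptions are inferred from the field names and initial values, so treat them as assumptions):

       ```java
       // --------------- state ---------------------------

       /** The largest epoch watermark seen so far; -1 means none has arrived yet. */
       private int latestEpochWatermark = -1;

       /** The parallelism recorded at checkpoint time, used to detect a parallelism change on restore. */
       private ListState<Integer> parallelismState;
       ```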

##########
File path: flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
##########
@@ -223,7 +246,44 @@ public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
 
     @Override
     public void initializeState(StateInitializationContext context) throws Exception {
-        // Do thing for now since we do not have states.
+        parallelismState =

Review comment:
       Do we need to call the super method?
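
       For comparison, `ReplayOperator.initializeState` in this PR delegates to the super method before restoring its own state; a minimal sketch of the same pattern here (whether the base class actually needs the call is exactly the open question):

       ```java
       @Override
       public void initializeState(StateInitializationContext context) throws Exception {
           // Let the base class restore anything it owns before touching our own state.
           super.initializeState(context);

           parallelismState =
                   context.getOperatorStateStore()
                           .getUnionListState(
                                   new ListStateDescriptor<>("parallelism", IntSerializer.INSTANCE));
       }
       ```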

##########
File path: flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/TwoInputReducePerRoundOperator.java
##########
@@ -18,11 +18,19 @@
 
 package org.apache.flink.test.iteration.operators;
 
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.util.OutputTag;
 
+import java.util.Collections;
+
 /**
  * An operators that reduce the received numbers and emit the result into the output, and also emit

Review comment:
       operators --> operator

##########
File path: flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundCheckpointITCase.java
##########
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.iteration;
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.streaming.api.CheckpointingMode;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.test.iteration.operators.EpochRecord;
+import org.apache.flink.test.iteration.operators.FailingMap;
+import org.apache.flink.test.iteration.operators.OutputRecord;
+import org.apache.flink.test.iteration.operators.SequenceSource;
+import org.apache.flink.test.iteration.operators.TwoInputReducePerRoundOperator;
+import org.apache.flink.testutils.junit.SharedObjects;
+import org.apache.flink.testutils.junit.SharedReference;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.test.iteration.UnboundedStreamIterationITCase.createMiniClusterConfiguration;
+import static org.junit.Assert.assertEquals;
+
+/** Tests checkpoints. */
+@RunWith(Parameterized.class)
+public class BoundedPerRoundCheckpointITCase extends TestLogger {
+
+    @Rule public final SharedObjects sharedObjects = SharedObjects.create();
+
+    private SharedReference<List<OutputRecord<Integer>>> result;
+
+    @Parameterized.Parameter(0)
+    public int failoverCount;
+
+    @Parameterized.Parameters(name = "failoverCount = {0}")
+    public static Collection<Object[]> params() {
+        return Arrays.asList(
+                new Object[] {1000},
+                new Object[] {4000},
+                new Object[] {6123},
+                new Object[] {8000},
+                new Object[] {10875},
+                new Object[] {15900});
+    }
+
+    @Before
+    public void setup() {
+        result = sharedObjects.add(new ArrayList<>());
+    }
+
+    @Test
+    public void testFailoverAndRestore() throws Exception {
+        try (MiniCluster miniCluster = new MiniCluster(createMiniClusterConfiguration(2, 2))) {
+            miniCluster.start();
+
+            // Create the test job

Review comment:
       Create --> Creates?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
