This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit d1d5d00cfd9d12d47e89a7d765772993bd6b0d4a
Author: Yun Gao <[email protected]>
AuthorDate: Wed Sep 29 23:04:37 2021 +0800

    [FLINK-24652][iteration] Add per-round operator wrappers
    
    This closes #14.
---
 flink-ml-iteration/pom.xml                         |   8 +
 .../flink/iteration/operator/OperatorUtils.java    |  23 +-
 .../perround/AbstractPerRoundWrapperOperator.java  | 488 +++++++++++++++++++++
 .../MultipleInputPerRoundWrapperOperator.java      | 180 ++++++++
 .../perround/OneInputPerRoundWrapperOperator.java  |  89 ++++
 .../operator/perround/PerRoundOperatorWrapper.java |  85 ++++
 .../perround/TwoInputPerRoundWrapperOperator.java  | 142 ++++++
 .../apache/flink/iteration/proxy/ProxyOutput.java  |   2 +-
 .../state/ProxyInternalTimeServiceManager.java     |  62 +++
 .../proxy/state/ProxyKeyedStateBackend.java        | 236 ++++++++++
 .../proxy/state/ProxyOperatorStateBackend.java     | 128 ++++++
 .../proxy/state/ProxyStateSnapshotContext.java     |  55 +++
 .../state/ProxyStreamOperatorStateContext.java     |  86 ++++
 .../iteration/proxy/state/StateNamePrefix.java}    |  32 +-
 .../iteration/operator/allround/LifeCycle.java     |   2 +
 .../MultipleInputAllRoundWrapperOperatorTest.java  |  40 +-
 .../OneInputAllRoundWrapperOperatorTest.java       |  35 +-
 .../TwoInputAllRoundWrapperOperatorTest.java       |  42 +-
 .../MultipleInputPerRoundWrapperOperatorTest.java} | 120 +++--
 .../OneInputPerRoundWrapperOperatorTest.java}      |  83 ++--
 .../perround/PerRoundOperatorStateTest.java        | 266 +++++++++++
 .../TwoInputPerRoundWrapperOperatorTest.java}      |  97 ++--
 22 files changed, 2142 insertions(+), 159 deletions(-)

diff --git a/flink-ml-iteration/pom.xml b/flink-ml-iteration/pom.xml
index d2cdb36..74044a3 100644
--- a/flink-ml-iteration/pom.xml
+++ b/flink-ml-iteration/pom.xml
@@ -61,6 +61,14 @@ under the License.
             <scope>provided</scope>
         </dependency>
 
+        <!-- We have special treatment for the rocksdb state. -->
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            
<artifactId>flink-statebackend-rocksdb_${scala.binary.version}</artifactId>
+            <version>${flink.version}</version>
+            <scope>provided</scope>
+        </dependency>
+
         <!-- Required for feedback edge implementation -->
         <dependency>
             <groupId>org.apache.flink</groupId>
diff --git 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
index 992d715..1b2769e 100644
--- 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
+++ 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
@@ -26,10 +26,11 @@ import 
org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackKey;
 import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.function.ThrowingConsumer;
 
 import java.util.Arrays;
 import java.util.concurrent.Executor;
-import java.util.function.Consumer;
 
 /** Utility class for operators. */
 public class OperatorUtils {
@@ -58,14 +59,20 @@ public class OperatorUtils {
     }
 
     public static <T> void processOperatorOrUdfIfSatisfy(
-            StreamOperator<?> operator, Class<T> targetInterface, Consumer<T> 
action) {
-        if (targetInterface.isAssignableFrom(operator.getClass())) {
-            action.accept((T) operator);
-        } else if (operator instanceof AbstractUdfStreamOperator<?, ?>) {
-            Object udf = ((AbstractUdfStreamOperator<?, ?>) 
operator).getUserFunction();
-            if (targetInterface.isAssignableFrom(udf.getClass())) {
-                action.accept((T) udf);
+            StreamOperator<?> operator,
+            Class<T> targetInterface,
+            ThrowingConsumer<T, Exception> action) {
+        try {
+            if (targetInterface.isAssignableFrom(operator.getClass())) {
+                action.accept((T) operator);
+            } else if (operator instanceof AbstractUdfStreamOperator<?, ?>) {
+                Object udf = ((AbstractUdfStreamOperator<?, ?>) 
operator).getUserFunction();
+                if (targetInterface.isAssignableFrom(udf.getClass())) {
+                    action.accept((T) udf);
+                }
             }
+        } catch (Exception e) {
+            ExceptionUtils.rethrow(e);
         }
     }
 }
diff --git 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
new file mode 100644
index 0000000..3903340
--- /dev/null
+++ 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
@@ -0,0 +1,488 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.perround;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.MetricOptions;
+import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend;
+import org.apache.flink.core.memory.ManagedMemoryUseCase;
+import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.operator.AbstractWrapperOperator;
+import org.apache.flink.iteration.proxy.state.ProxyStateSnapshotContext;
+import org.apache.flink.iteration.proxy.state.ProxyStreamOperatorStateContext;
+import org.apache.flink.iteration.utils.ReflectionUtils;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.metrics.groups.OperatorMetricGroup;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
+import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
+import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
+import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler;
+import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.streaming.util.LatencyStats;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.InstantiationUtil;
+import org.apache.flink.util.function.BiConsumerWithException;
+
+import org.rocksdb.RocksDB;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Optional;
+
+/** The base class for all the per-round wrapper operators. */
+public abstract class AbstractPerRoundWrapperOperator<T, S extends 
StreamOperator<T>>
+        extends AbstractWrapperOperator<T>
+        implements StreamOperatorStateHandler.CheckpointedStreamOperator {
+
+    private static final Logger LOG =
+            LoggerFactory.getLogger(AbstractPerRoundWrapperOperator.class);
+
+    private static final String HEAP_KEYED_STATE_NAME =
+            "org.apache.flink.runtime.state.heap.HeapKeyedStateBackend";
+
+    private static final String ROCKSDB_KEYED_STATE_NAME =
+            
"org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend";
+
+    /** The wrapped operators for each round. */
+    private final Map<Integer, S> wrappedOperators;
+
+    protected final LatencyStats latencyStats;
+
+    private transient StreamOperatorStateContext streamOperatorStateContext;
+
+    private transient StreamOperatorStateHandler stateHandler;
+
+    private transient InternalTimeServiceManager<?> timeServiceManager;
+
+    private transient KeySelector<?, ?> stateKeySelector1;
+
+    private transient KeySelector<?, ?> stateKeySelector2;
+
+    public AbstractPerRoundWrapperOperator(
+            StreamOperatorParameters<IterationRecord<T>> parameters,
+            StreamOperatorFactory<T> operatorFactory) {
+        super(parameters, operatorFactory);
+
+        this.wrappedOperators = new HashMap<>();
+        this.latencyStats = initializeLatencyStats();
+    }
+
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    protected S getWrappedOperator(int round) {
+        S wrappedOperator = wrappedOperators.get(round);
+        if (wrappedOperator != null) {
+            return wrappedOperator;
+        }
+
+        // We need to clone the operator factory to also support 
SimpleOperatorFactory.
+        try {
+            StreamOperatorFactory<T> clonedOperatorFactory =
+                    InstantiationUtil.clone(operatorFactory);
+            wrappedOperator =
+                    (S)
+                            StreamOperatorFactoryUtil.<T, S>createOperator(
+                                            clonedOperatorFactory,
+                                            (StreamTask) 
parameters.getContainingTask(),
+                                            parameters.getStreamConfig(),
+                                            proxyOutput,
+                                            
parameters.getOperatorEventDispatcher())
+                                    .f0;
+            initializeStreamOperator(wrappedOperator, round);
+            wrappedOperators.put(round, wrappedOperator);
+            return wrappedOperator;
+        } catch (Exception e) {
+            ExceptionUtils.rethrow(e);
+        }
+
+        return wrappedOperator;
+    }
+
+    protected abstract void endInputAndEmitMaxWatermark(S operator, int round) 
throws Exception;
+
+    private void closeStreamOperator(S operator, int round) throws Exception {
+        setIterationContextRound(round);
+        endInputAndEmitMaxWatermark(operator, round);
+        operator.finish();
+        operator.close();
+        setIterationContextRound(null);
+
+        // Cleanup the states used by this operator.
+        cleanupOperatorStates(round);
+
+        if (stateHandler.getKeyedStateBackend() != null) {
+            cleanupKeyedStates(round);
+        }
+    }
+
+    @Override
+    public void onEpochWatermarkIncrement(int epochWatermark) throws 
IOException {
+        try {
+            // Destroys all the operators with round < epoch watermark. Note 
that
+            // the onEpochWatermarkIncrement must be from 0 and increment by 1 
each time.
+            if (wrappedOperators.containsKey(epochWatermark)) {
+                closeStreamOperator(wrappedOperators.get(epochWatermark), 
epochWatermark);
+                wrappedOperators.remove(epochWatermark);
+            }
+
+            super.onEpochWatermarkIncrement(epochWatermark);
+        } catch (Exception e) {
+            ExceptionUtils.rethrow(e);
+        }
+    }
+
+    protected void processForEachWrappedOperator(
+            BiConsumerWithException<Integer, S, Exception> consumer) throws 
Exception {
+        for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) {
+            consumer.accept(entry.getKey(), entry.getValue());
+        }
+    }
+
+    @Override
+    public void open() throws Exception {}
+
+    @Override
+    public void initializeState(StreamTaskStateInitializer 
streamTaskStateManager)
+            throws Exception {
+        final TypeSerializer<?> keySerializer =
+                
streamConfig.getStateKeySerializer(containingTask.getUserCodeClassLoader());
+
+        streamOperatorStateContext =
+                streamTaskStateManager.streamOperatorStateContext(
+                        getOperatorID(),
+                        getClass().getSimpleName(),
+                        parameters.getProcessingTimeService(),
+                        this,
+                        keySerializer,
+                        containingTask.getCancelables(),
+                        metrics,
+                        
streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot(
+                                ManagedMemoryUseCase.STATE_BACKEND,
+                                containingTask
+                                        .getEnvironment()
+                                        .getTaskManagerInfo()
+                                        .getConfiguration(),
+                                containingTask.getUserCodeClassLoader()),
+                        isUsingCustomRawKeyedState());
+
+        stateHandler =
+                new StreamOperatorStateHandler(
+                        streamOperatorStateContext,
+                        containingTask.getExecutionConfig(),
+                        containingTask.getCancelables());
+        stateHandler.initializeOperatorState(this);
+        this.timeServiceManager = 
streamOperatorStateContext.internalTimerServiceManager();
+
+        stateKeySelector1 =
+                streamConfig.getStatePartitioner(0, 
containingTask.getUserCodeClassLoader());
+        stateKeySelector2 =
+                streamConfig.getStatePartitioner(1, 
containingTask.getUserCodeClassLoader());
+    }
+
+    @Override
+    public void initializeState(StateInitializationContext context) throws 
Exception {
+        // Do nothing for now since we do not have states.
+    }
+
+    @Internal
+    protected boolean isUsingCustomRawKeyedState() {
+        return false;
+    }
+
+    @Override
+    public void finish() throws Exception {
+        for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) {
+            closeStreamOperator(entry.getValue(), entry.getKey());
+        }
+        wrappedOperators.clear();
+    }
+
+    @Override
+    public void close() throws Exception {
+        if (stateHandler != null) {
+            stateHandler.dispose();
+        }
+    }
+
+    @Override
+    public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
+        for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) {
+            entry.getValue().prepareSnapshotPreBarrier(checkpointId);
+        }
+    }
+
+    @Override
+    public OperatorSnapshotFutures snapshotState(
+            long checkpointId,
+            long timestamp,
+            CheckpointOptions checkpointOptions,
+            CheckpointStreamFactory factory)
+            throws Exception {
+        return stateHandler.snapshotState(
+                this,
+                Optional.ofNullable(timeServiceManager),
+                streamConfig.getOperatorName(),
+                checkpointId,
+                timestamp,
+                checkpointOptions,
+                factory,
+                isUsingCustomRawKeyedState());
+    }
+
+    @Override
+    public void snapshotState(StateSnapshotContext context) throws Exception {
+        for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) {
+            if 
(StreamOperatorStateHandler.CheckpointedStreamOperator.class.isAssignableFrom(
+                    entry.getValue().getClass())) {
+                ((StreamOperatorStateHandler.CheckpointedStreamOperator) 
entry.getValue())
+                        .snapshotState(new ProxyStateSnapshotContext(context));
+            }
+        }
+    }
+
+    @Override
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    public void setKeyContextElement1(StreamRecord record) throws Exception {
+        setKeyContextElement(record, stateKeySelector1);
+    }
+
+    @Override
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    public void setKeyContextElement2(StreamRecord record) throws Exception {
+        setKeyContextElement(record, stateKeySelector2);
+    }
+
+    private <T> void setKeyContextElement(StreamRecord<T> record, 
KeySelector<T, ?> selector)
+            throws Exception {
+        if (selector != null) {
+            Object key = selector.getKey(record.getValue());
+            setCurrentKey(key);
+        }
+    }
+
+    @Override
+    public OperatorMetricGroup getMetricGroup() {
+        return metrics;
+    }
+
+    @Override
+    public OperatorID getOperatorID() {
+        return streamConfig.getOperatorID();
+    }
+
+    @Override
+    public void notifyCheckpointComplete(long l) throws Exception {
+        for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) {
+            entry.getValue().notifyCheckpointComplete(l);
+        }
+    }
+
+    @Override
+    public void notifyCheckpointAborted(long checkpointId) throws Exception {
+        for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) {
+            entry.getValue().notifyCheckpointAborted(checkpointId);
+        }
+    }
+
+    @Override
+    public void setCurrentKey(Object key) {
+        stateHandler.setCurrentKey(key);
+    }
+
+    @Override
+    public Object getCurrentKey() {
+        if (stateHandler == null) {
+            return null;
+        }
+
+        return stateHandler.getKeyedStateStore().orElse(null);
+    }
+
+    protected void reportOrForwardLatencyMarker(LatencyMarker marker) {
+        // all operators are tracking latencies
+        this.latencyStats.reportLatency(marker);
+
+        // everything except sinks forwards latency markers
+        this.output.emitLatencyMarker(marker);
+    }
+
+    private LatencyStats initializeLatencyStats() {
+        try {
+            Configuration taskManagerConfig =
+                    
containingTask.getEnvironment().getTaskManagerInfo().getConfiguration();
+            int historySize = 
taskManagerConfig.getInteger(MetricOptions.LATENCY_HISTORY_SIZE);
+            if (historySize <= 0) {
+                LOG.warn(
+                        "{} has been set to a value equal or below 0: {}. 
Using default.",
+                        MetricOptions.LATENCY_HISTORY_SIZE,
+                        historySize);
+                historySize = 
MetricOptions.LATENCY_HISTORY_SIZE.defaultValue();
+            }
+
+            final String configuredGranularity =
+                    
taskManagerConfig.getString(MetricOptions.LATENCY_SOURCE_GRANULARITY);
+            LatencyStats.Granularity granularity;
+            try {
+                granularity =
+                        LatencyStats.Granularity.valueOf(
+                                
configuredGranularity.toUpperCase(Locale.ROOT));
+            } catch (IllegalArgumentException iae) {
+                granularity = LatencyStats.Granularity.OPERATOR;
+                LOG.warn(
+                        "Configured value {} option for {} is invalid. 
Defaulting to {}.",
+                        configuredGranularity,
+                        MetricOptions.LATENCY_SOURCE_GRANULARITY.key(),
+                        granularity);
+            }
+            MetricGroup jobMetricGroup = this.metrics.getJobMetricGroup();
+            return new LatencyStats(
+                    jobMetricGroup.addGroup("latency"),
+                    historySize,
+                    containingTask.getIndexInSubtaskGroup(),
+                    getOperatorID(),
+                    granularity);
+        } catch (Exception e) {
+            LOG.warn("An error occurred while instantiating latency metrics.", 
e);
+            return new LatencyStats(
+                    
UnregisteredMetricGroups.createUnregisteredTaskManagerJobMetricGroup()
+                            .addGroup("latency"),
+                    1,
+                    0,
+                    new OperatorID(),
+                    LatencyStats.Granularity.SINGLE);
+        }
+    }
+
+    private void initializeStreamOperator(S operator, int round) throws 
Exception {
+        operator.initializeState(
+                (operatorID,
+                        operatorClassName,
+                        processingTimeService,
+                        keyContext,
+                        keySerializer,
+                        streamTaskCloseableRegistry,
+                        metricGroup,
+                        managedMemoryFraction,
+                        isUsingCustomRawKeyedState) ->
+                        new ProxyStreamOperatorStateContext(
+                                streamOperatorStateContext, 
getRoundStatePrefix(round)));
+        operator.open();
+    }
+
+    private void cleanupOperatorStates(int round) {
+        String roundPrefix = getRoundStatePrefix(round);
+        OperatorStateBackend operatorStateBackend = 
stateHandler.getOperatorStateBackend();
+
+        if (operatorStateBackend instanceof DefaultOperatorStateBackend) {
+            for (String fieldNames :
+                    new String[] {
+                        "registeredOperatorStates",
+                        "registeredBroadcastStates",
+                        "accessedStatesByName",
+                        "accessedBroadcastStatesByName"
+                    }) {
+                Map<String, ?> field =
+                        ReflectionUtils.getFieldValue(
+                                operatorStateBackend,
+                                DefaultOperatorStateBackend.class,
+                                fieldNames);
+                field.entrySet().removeIf(entry -> 
entry.getKey().startsWith(roundPrefix));
+            }
+        } else {
+            LOG.warn("Unable to cleanup the operator state {}", 
operatorStateBackend);
+        }
+    }
+
+    private void cleanupKeyedStates(int round) {
+        String roundPrefix = getRoundStatePrefix(round);
+        KeyedStateBackend<?> keyedStateBackend = 
stateHandler.getKeyedStateBackend();
+        if 
(keyedStateBackend.getClass().getName().equals(HEAP_KEYED_STATE_NAME)) {
+            ReflectionUtils.<Map<String, ?>>getFieldValue(
+                            keyedStateBackend, HeapKeyedStateBackend.class, 
"registeredKVStates")
+                    .entrySet()
+                    .removeIf(entry -> entry.getKey().startsWith(roundPrefix));
+            ReflectionUtils.<Map<String, ?>>getFieldValue(
+                            keyedStateBackend,
+                            AbstractKeyedStateBackend.class,
+                            "keyValueStatesByName")
+                    .entrySet()
+                    .removeIf(entry -> entry.getKey().startsWith(roundPrefix));
+        } else if 
(keyedStateBackend.getClass().getName().equals(ROCKSDB_KEYED_STATE_NAME)) {
+            RocksDB db =
+                    ReflectionUtils.getFieldValue(
+                            keyedStateBackend, RocksDBKeyedStateBackend.class, 
"db");
+            HashMap<String, RocksDBKeyedStateBackend.RocksDbKvStateInfo> 
kvStateInformation =
+                    ReflectionUtils.getFieldValue(
+                            keyedStateBackend,
+                            RocksDBKeyedStateBackend.class,
+                            "kvStateInformation");
+            kvStateInformation.entrySet().stream()
+                    .filter(entry -> entry.getKey().startsWith(roundPrefix))
+                    .forEach(
+                            entry -> {
+                                try {
+                                    
db.dropColumnFamily(entry.getValue().columnFamilyHandle);
+                                } catch (Exception e) {
+                                    LOG.error(
+                                            "Failed to drop state {} for round 
{}",
+                                            entry.getKey(),
+                                            round);
+                                }
+                            });
+            kvStateInformation.entrySet().removeIf(entry -> 
entry.getKey().startsWith(roundPrefix));
+
+            Map<String, ?> field =
+                    ReflectionUtils.getFieldValue(
+                            keyedStateBackend,
+                            AbstractKeyedStateBackend.class,
+                            "keyValueStatesByName");
+            field.entrySet().removeIf(entry -> 
entry.getKey().startsWith(roundPrefix));
+        } else {
+            LOG.warn("Unable to cleanup the keyed state {}", 
keyedStateBackend);
+        }
+    }
+
+    private String getRoundStatePrefix(int round) {
+        return "r" + round + "-";
+    }
+}
diff --git 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java
 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java
new file mode 100644
index 0000000..7c6240f
--- /dev/null
+++ 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.perround;
+
+import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.streaming.api.graph.StreamEdge;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.Input;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+import org.apache.flink.util.FlinkRuntimeException;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/** Per-round wrapper for the multiple-inputs operator. */
+public class MultipleInputPerRoundWrapperOperator<OUT>
+        extends AbstractPerRoundWrapperOperator<OUT, 
MultipleInputStreamOperator<OUT>>
+        implements MultipleInputStreamOperator<IterationRecord<OUT>> {
+
+    /** The number of total inputs. */
+    private final int numberOfInputs;
+
+    /**
+     * Cached inputs for each epoch. This is to avoid repeated calls to the 
{@link
+     * MultipleInputStreamOperator#getInputs()}, which might not return the 
same inputs for each
+     * call.
+     */
+    private final Map<Integer, List<Input>> operatorInputsByEpoch = new 
HashMap<>();
+
+    public MultipleInputPerRoundWrapperOperator(
+            StreamOperatorParameters<IterationRecord<OUT>> parameters,
+            StreamOperatorFactory<OUT> operatorFactory) {
+        super(parameters, operatorFactory);
+
+        // Determine how many inputs we have
+        List<StreamEdge> inEdges =
+                
streamConfig.getInPhysicalEdges(containingTask.getUserCodeClassLoader());
+        this.numberOfInputs =
+                
inEdges.stream().map(StreamEdge::getTypeNumber).collect(Collectors.toSet()).size();
+    }
+
+    @Override
+    protected void 
endInputAndEmitMaxWatermark(MultipleInputStreamOperator<OUT> operator, int 
round)
+            throws Exception {
+        OperatorUtils.processOperatorOrUdfIfSatisfy(
+                operator,
+                BoundedMultiInput.class,
+                boundedMultiInput -> {
+                    for (int i = 0; i < numberOfInputs; ++i) {
+                        boundedMultiInput.endInput(i + 1);
+                    }
+                });
+
+        for (int i = 0; i < numberOfInputs; ++i) {
+            operatorInputsByEpoch.get(round).get(i).processWatermark(new 
Watermark(Long.MAX_VALUE));
+        }
+    }
+
+    private <IN> void processElement(
+            int inputIndex,
+            Input<IN> input,
+            StreamRecord<IN> reusedInput,
+            StreamRecord<IterationRecord<IN>> element)
+            throws Exception {
+        switch (element.getValue().getType()) {
+            case RECORD:
+                reusedInput.replace(element.getValue().getValue(), 
element.getTimestamp());
+                setIterationContextRound(element.getValue().getEpoch());
+                input.processElement(reusedInput);
+                clearIterationContextRound();
+                break;
+            case EPOCH_WATERMARK:
+                onEpochWatermarkEvent(inputIndex, element.getValue());
+                break;
+            default:
+                throw new FlinkRuntimeException("Not supported iteration 
record type: " + element);
+        }
+    }
+
+    @Override
+    @SuppressWarnings({"rawtypes"})
+    public List<Input> getInputs() {
+        List<Input> proxyInputs = new ArrayList<>();
+
+        for (int i = 0; i < numberOfInputs; ++i) {
+            // TODO: Note that here we rely on the assumption that the
+            // stream graph generator labels the input from 1 to n for
+            // the input array, which we map to 0 to n - 1.
+            proxyInputs.add(new ProxyInput(i));
+        }
+        return proxyInputs;
+    }
+
+    private class ProxyInput<IN> implements Input<IterationRecord<IN>> {
+
+        private final int inputIndex;
+
+        private final StreamRecord<IN> reusedInput;
+
+        public ProxyInput(int inputIndex) {
+            this.inputIndex = inputIndex;
+            this.reusedInput = new StreamRecord<>(null, 0);
+        }
+
+        @Override
+        public void processElement(StreamRecord<IterationRecord<IN>> element) 
throws Exception {
+            if 
(!operatorInputsByEpoch.containsKey(element.getValue().getEpoch())) {
+                MultipleInputStreamOperator<OUT> operator =
+                        getWrappedOperator(element.getValue().getEpoch());
+                operatorInputsByEpoch.put(element.getValue().getEpoch(), 
operator.getInputs());
+            }
+
+            MultipleInputPerRoundWrapperOperator.this.processElement(
+                    inputIndex,
+                    
operatorInputsByEpoch.get(element.getValue().getEpoch()).get(inputIndex),
+                    reusedInput,
+                    element);
+        }
+
+        @Override
+        public void processWatermark(Watermark mark) throws Exception {
+            processForEachWrappedOperator(
+                    (round, wrappedOperator) -> {
+                        
operatorInputsByEpoch.get(round).get(inputIndex).processWatermark(mark);
+                    });
+        }
+
+        @Override
+        public void processWatermarkStatus(WatermarkStatus watermarkStatus) 
throws Exception {
+            processForEachWrappedOperator(
+                    (round, wrappedOperator) -> {
+                        operatorInputsByEpoch
+                                .get(round)
+                                .get(inputIndex)
+                                .processWatermarkStatus(watermarkStatus);
+                    });
+        }
+
+        @Override
+        public void processLatencyMarker(LatencyMarker latencyMarker) throws 
Exception {
+            reportOrForwardLatencyMarker(latencyMarker);
+        }
+
+        @Override
+        public void setKeyContextElement(StreamRecord<IterationRecord<IN>> 
record)
+                throws Exception {
+            MultipleInputStreamOperator<OUT> operator =
+                    getWrappedOperator(record.getValue().getEpoch());
+
+            reusedInput.replace(record.getValue(), record.getTimestamp());
+            
operator.getInputs().get(inputIndex).setKeyContextElement(reusedInput);
+        }
+    }
+}
diff --git 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperator.java
 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperator.java
new file mode 100644
index 0000000..ae85618
--- /dev/null
+++ 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperator.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.perround;
+
+import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+import org.apache.flink.util.FlinkRuntimeException;
+
+/** Per-round wrapper operator for the one-input operator. */
+public class OneInputPerRoundWrapperOperator<IN, OUT>
+        extends AbstractPerRoundWrapperOperator<OUT, 
OneInputStreamOperator<IN, OUT>>
+        implements OneInputStreamOperator<IterationRecord<IN>, 
IterationRecord<OUT>> {
+
+    private final StreamRecord<IN> reusedInput;
+
+    public OneInputPerRoundWrapperOperator(
+            StreamOperatorParameters<IterationRecord<OUT>> parameters,
+            StreamOperatorFactory<OUT> operatorFactory) {
+        super(parameters, operatorFactory);
+        this.reusedInput = new StreamRecord<>(null, 0);
+    }
+
+    @Override
+    protected void endInputAndEmitMaxWatermark(OneInputStreamOperator<IN, OUT> 
operator, int round)
+            throws Exception {
+        OperatorUtils.processOperatorOrUdfIfSatisfy(
+                operator, BoundedOneInput.class, BoundedOneInput::endInput);
+        operator.processWatermark(new Watermark(Long.MAX_VALUE));
+    }
+
+    @Override
+    public void processElement(StreamRecord<IterationRecord<IN>> element) 
throws Exception {
+        switch (element.getValue().getType()) {
+            case RECORD:
+                reusedInput.replace(element.getValue().getValue(), 
element.getTimestamp());
+                setIterationContextRound(element.getValue().getEpoch());
+                
getWrappedOperator(element.getValue().getEpoch()).processElement(reusedInput);
+                clearIterationContextRound();
+                break;
+            case EPOCH_WATERMARK:
+                onEpochWatermarkEvent(0, element.getValue());
+                break;
+            default:
+                throw new FlinkRuntimeException("Not supported iteration 
record type: " + element);
+        }
+    }
+
+    @Override
+    public void processWatermark(Watermark mark) throws Exception {
+        processForEachWrappedOperator(
+                (round, wrappedOperator) -> 
wrappedOperator.processWatermark(mark));
+    }
+
+    @Override
+    public void processWatermarkStatus(WatermarkStatus watermarkStatus) throws 
Exception {
+        processForEachWrappedOperator(
+                (round, wrappedOperator) ->
+                        
wrappedOperator.processWatermarkStatus(watermarkStatus));
+    }
+
+    @Override
+    public void processLatencyMarker(LatencyMarker latencyMarker) throws 
Exception {
+        reportOrForwardLatencyMarker(latencyMarker);
+    }
+}
diff --git 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java
 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java
new file mode 100644
index 0000000..ffa2221
--- /dev/null
+++ 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.perround;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.operator.OperatorWrapper;
+import org.apache.flink.iteration.proxy.ProxyKeySelector;
+import org.apache.flink.iteration.proxy.ProxyStreamPartitioner;
+import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+
+/** The operator wrapper implementation for per-round wrappers. */
+public class PerRoundOperatorWrapper<T> implements OperatorWrapper<T, 
IterationRecord<T>> {
+
+    @Override
+    public StreamOperator<IterationRecord<T>> wrap(
+            StreamOperatorParameters<IterationRecord<T>> operatorParameters,
+            StreamOperatorFactory<T> operatorFactory) {
+        Class<? extends StreamOperator> operatorClass =
+                
operatorFactory.getStreamOperatorClass(getClass().getClassLoader());
+        if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return new OneInputPerRoundWrapperOperator<>(operatorParameters, 
operatorFactory);
+        } else if 
(TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return new TwoInputPerRoundWrapperOperator<>(operatorParameters, 
operatorFactory);
+        } else if 
(MultipleInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return new 
MultipleInputPerRoundWrapperOperator<>(operatorParameters, operatorFactory);
+        } else {
+            throw new UnsupportedOperationException(
+                    "Unsupported operator class for all-round wrapper: " + 
operatorClass);
+        }
+    }
+
+    @Override
+    public <KEY> KeySelector<IterationRecord<T>, KEY> wrapKeySelector(
+            KeySelector<T, KEY> keySelector) {
+        return new ProxyKeySelector<>(keySelector);
+    }
+
+    @Override
+    public StreamPartitioner<IterationRecord<T>> wrapStreamPartitioner(
+            StreamPartitioner<T> streamPartitioner) {
+        if (streamPartitioner instanceof BroadcastPartitioner) {
+            return new BroadcastPartitioner<>();
+        }
+
+        return new ProxyStreamPartitioner<>(streamPartitioner);
+    }
+
+    @Override
+    public OutputTag<IterationRecord<T>> wrapOutputTag(OutputTag<T> outputTag) 
{
+        return new OutputTag<>(
+                outputTag.getId(), new 
IterationRecordTypeInfo<>(outputTag.getTypeInfo()));
+    }
+
+    @Override
+    public TypeInformation<IterationRecord<T>> 
getWrappedTypeInfo(TypeInformation<T> typeInfo) {
+        return new IterationRecordTypeInfo<>(typeInfo);
+    }
+}
diff --git 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperator.java
 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperator.java
new file mode 100644
index 0000000..39918ad
--- /dev/null
+++ 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperator.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.perround;
+
+import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.function.ThrowingConsumer;
+
+/** Per-round wrapper for the two-inputs operator. */
+public class TwoInputPerRoundWrapperOperator<IN1, IN2, OUT>
+        extends AbstractPerRoundWrapperOperator<OUT, 
TwoInputStreamOperator<IN1, IN2, OUT>>
+        implements TwoInputStreamOperator<
+                IterationRecord<IN1>, IterationRecord<IN2>, 
IterationRecord<OUT>> {
+
+    private final StreamRecord<IN1> reusedInput1;
+
+    private final StreamRecord<IN2> reusedInput2;
+
+    public TwoInputPerRoundWrapperOperator(
+            StreamOperatorParameters<IterationRecord<OUT>> parameters,
+            StreamOperatorFactory<OUT> operatorFactory) {
+        super(parameters, operatorFactory);
+
+        this.reusedInput1 = new StreamRecord<>(null, 0);
+        this.reusedInput2 = new StreamRecord<>(null, 0);
+    }
+
+    @Override
+    protected void endInputAndEmitMaxWatermark(
+            TwoInputStreamOperator<IN1, IN2, OUT> operator, int round) throws 
Exception {
+        OperatorUtils.processOperatorOrUdfIfSatisfy(
+                operator,
+                BoundedMultiInput.class,
+                boundedMultiInput -> {
+                    boundedMultiInput.endInput(1);
+                    boundedMultiInput.endInput(2);
+                });
+        operator.processWatermark1(new Watermark(Long.MAX_VALUE));
+        operator.processWatermark2(new Watermark(Long.MAX_VALUE));
+    }
+
+    @Override
+    public void processElement1(StreamRecord<IterationRecord<IN1>> element) 
throws Exception {
+        processElement(
+                element,
+                0,
+                reusedInput1,
+                
getWrappedOperator(element.getValue().getEpoch())::processElement1);
+    }
+
+    @Override
+    public void processElement2(StreamRecord<IterationRecord<IN2>> element) 
throws Exception {
+        processElement(
+                element,
+                1,
+                reusedInput2,
+                
getWrappedOperator(element.getValue().getEpoch())::processElement2);
+    }
+
+    private <IN> void processElement(
+            StreamRecord<IterationRecord<IN>> element,
+            int inputIndex,
+            StreamRecord<IN> reusedInput,
+            ThrowingConsumer<StreamRecord<IN>, Exception> processor)
+            throws Exception {
+
+        switch (element.getValue().getType()) {
+            case RECORD:
+                reusedInput.replace(element.getValue().getValue(), 
element.getTimestamp());
+                setIterationContextRound(element.getValue().getEpoch());
+                processor.accept(reusedInput);
+                clearIterationContextRound();
+                break;
+            case EPOCH_WATERMARK:
+                onEpochWatermarkEvent(inputIndex, element.getValue());
+                break;
+            default:
+                throw new FlinkRuntimeException("Not supported iteration 
record type: " + element);
+        }
+    }
+
+    @Override
+    public void processWatermark1(Watermark mark) throws Exception {
+        processForEachWrappedOperator(
+                (round, wrappedOperator) -> 
wrappedOperator.processWatermark1(mark));
+    }
+
+    @Override
+    public void processWatermark2(Watermark mark) throws Exception {
+        processForEachWrappedOperator(
+                (round, wrappedOperator) -> 
wrappedOperator.processWatermark2(mark));
+    }
+
+    @Override
+    public void processLatencyMarker1(LatencyMarker latencyMarker) throws 
Exception {
+        reportOrForwardLatencyMarker(latencyMarker);
+    }
+
+    @Override
+    public void processLatencyMarker2(LatencyMarker latencyMarker) throws 
Exception {
+        reportOrForwardLatencyMarker(latencyMarker);
+    }
+
+    @Override
+    public void processWatermarkStatus1(WatermarkStatus watermarkStatus) 
throws Exception {
+        processForEachWrappedOperator(
+                (round, wrappedOperator) ->
+                        
wrappedOperator.processWatermarkStatus1(watermarkStatus));
+    }
+
+    @Override
+    public void processWatermarkStatus2(WatermarkStatus watermarkStatus) 
throws Exception {
+        processForEachWrappedOperator(
+                (round, wrappedOperator) ->
+                        
wrappedOperator.processWatermarkStatus2(watermarkStatus));
+    }
+}
diff --git 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyOutput.java
 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyOutput.java
index 0d80f62..7f5cc8a 100644
--- 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyOutput.java
+++ 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyOutput.java
@@ -53,7 +53,7 @@ public class ProxyOutput<T> implements 
Output<StreamRecord<T>> {
 
     @Override
     public void emitWatermark(Watermark mark) {
-        output.emitWatermark(mark);
+        // For now, we only support the MAX_WATERMARK separately for each operator.
     }
 
     @Override
diff --git 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyInternalTimeServiceManager.java
 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyInternalTimeServiceManager.java
new file mode 100644
index 0000000..7377328
--- /dev/null
+++ 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyInternalTimeServiceManager.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.proxy.state;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream;
+import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
+import org.apache.flink.streaming.api.operators.InternalTimerService;
+import org.apache.flink.streaming.api.operators.Triggerable;
+import org.apache.flink.streaming.api.watermark.Watermark;
+
+/** Proxy {@link InternalTimeServiceManager} for the wrapped operators. */
+public class ProxyInternalTimeServiceManager<K> implements 
InternalTimeServiceManager<K> {
+
+    private final InternalTimeServiceManager<K> wrappedManager;
+
+    private final StateNamePrefix stateNamePrefix;
+
+    public ProxyInternalTimeServiceManager(
+            InternalTimeServiceManager<K> wrappedManager, StateNamePrefix 
stateNamePrefix) {
+        this.wrappedManager = wrappedManager;
+        this.stateNamePrefix = stateNamePrefix;
+    }
+
+    @Override
+    public <N> InternalTimerService<N> getInternalTimerService(
+            String name,
+            TypeSerializer<K> keySerializer,
+            TypeSerializer<N> namespaceSerializer,
+            Triggerable<K, N> triggerable) {
+        return wrappedManager.getInternalTimerService(
+                stateNamePrefix.prefix(name), keySerializer, 
namespaceSerializer, triggerable);
+    }
+
+    @Override
+    public void advanceWatermark(Watermark watermark) throws Exception {
+        wrappedManager.advanceWatermark(watermark);
+    }
+
+    @Override
+    public void snapshotToRawKeyedState(
+            KeyedStateCheckpointOutputStream stateCheckpointOutputStream, 
String operatorName)
+            throws Exception {
+        wrappedManager.snapshotToRawKeyedState(stateCheckpointOutputStream, 
operatorName);
+    }
+}
diff --git 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyKeyedStateBackend.java
 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyKeyedStateBackend.java
new file mode 100644
index 0000000..bc51f95
--- /dev/null
+++ 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyKeyedStateBackend.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.proxy.state;
+
+import org.apache.flink.api.common.state.AggregatingStateDescriptor;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.state.ReducingStateDescriptor;
+import org.apache.flink.api.common.state.State;
+import org.apache.flink.api.common.state.StateDescriptor;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.CheckpointableKeyedStateBackend;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
+import org.apache.flink.runtime.state.Keyed;
+import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.KeyedStateFunction;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.PriorityComparable;
+import org.apache.flink.runtime.state.SavepointResources;
+import org.apache.flink.runtime.state.SnapshotResult;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement;
+
+import javax.annotation.Nonnull;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.concurrent.RunnableFuture;
+import java.util.stream.Stream;
+
+/** Proxy {@link KeyedStateBackend} for the wrapped operators. */
/**
 * Proxy {@link KeyedStateBackend} for the wrapped operators.
 *
 * <p>All calls are delegated to the keyed state backend of the wrapping operator,
 * with state names rewritten through {@link StateNamePrefix} so that the states of
 * different wrapped operators do not collide in the shared backend.
 */
public class ProxyKeyedStateBackend<K> implements CheckpointableKeyedStateBackend<K> {

    // The real backend owned by the wrapping operator.
    private final CheckpointableKeyedStateBackend<K> wrappedBackend;

    // Rewrites state names to isolate this wrapped operator's states.
    private final StateNamePrefix stateNamePrefix;

    public ProxyKeyedStateBackend(
            CheckpointableKeyedStateBackend<K> wrappedBackend, StateNamePrefix stateNamePrefix) {
        this.wrappedBackend = wrappedBackend;
        this.stateNamePrefix = stateNamePrefix;
    }

    @Override
    public void setCurrentKey(K newKey) {
        wrappedBackend.setCurrentKey(newKey);
    }

    @Override
    public K getCurrentKey() {
        return wrappedBackend.getCurrentKey();
    }

    @Override
    public TypeSerializer<K> getKeySerializer() {
        return wrappedBackend.getKeySerializer();
    }

    @Override
    public <N, S extends State, T> void applyToAllKeys(
            N namespace,
            TypeSerializer<N> namespaceSerializer,
            StateDescriptor<S, T> stateDescriptor,
            KeyedStateFunction<K, S> function)
            throws Exception {
        // Rebuild the descriptor with the prefixed name before delegating.
        StateDescriptor<S, T> newDescriptor = createNewDescriptor(stateDescriptor);
        wrappedBackend.applyToAllKeys(namespace, namespaceSerializer, newDescriptor, function);
    }

    @Override
    public <N> Stream<K> getKeys(String state, N namespace) {
        return wrappedBackend.getKeys(stateNamePrefix.prefix(state), namespace);
    }

    @Override
    public <N> Stream<Tuple2<K, N>> getKeysAndNamespaces(String state) {
        return wrappedBackend.getKeysAndNamespaces(stateNamePrefix.prefix(state));
    }

    @Override
    public <N, S extends State, T> S getOrCreateKeyedState(
            TypeSerializer<N> namespaceSerializer, StateDescriptor<S, T> stateDescriptor)
            throws Exception {
        StateDescriptor<S, T> newDescriptor = createNewDescriptor(stateDescriptor);
        return wrappedBackend.getOrCreateKeyedState(namespaceSerializer, newDescriptor);
    }

    @Override
    public <N, S extends State> S getPartitionedState(
            N namespace,
            TypeSerializer<N> namespaceSerializer,
            StateDescriptor<S, ?> stateDescriptor)
            throws Exception {
        StateDescriptor<S, ?> newDescriptor = createNewDescriptor(stateDescriptor);
        return wrappedBackend.getPartitionedState(namespace, namespaceSerializer, newDescriptor);
    }

    @Override
    public void registerKeySelectionListener(KeySelectionListener<K> listener) {
        wrappedBackend.registerKeySelectionListener(listener);
    }

    @Override
    public boolean deregisterKeySelectionListener(KeySelectionListener<K> listener) {
        return wrappedBackend.deregisterKeySelectionListener(listener);
    }

    @Nonnull
    @Override
    public <N, SV, SEV, S extends State, IS extends S> IS createInternalState(
            @Nonnull TypeSerializer<N> namespaceSerializer,
            @Nonnull StateDescriptor<S, SV> stateDesc,
            @Nonnull
                    StateSnapshotTransformer.StateSnapshotTransformFactory<SEV>
                            snapshotTransformFactory)
            throws Exception {
        StateDescriptor<S, ?> newDescriptor = createNewDescriptor(stateDesc);
        return wrappedBackend.createInternalState(
                namespaceSerializer, newDescriptor, snapshotTransformFactory);
    }

    /**
     * Rebuilds the given descriptor with the state name rewritten through {@link
     * StateNamePrefix}, keeping the serializers (and, for reducing/aggregating
     * states, the user function) of the original descriptor.
     *
     * <p>NOTE(review): the rebuilt descriptors do not copy auxiliary settings of
     * the original descriptor (e.g. TTL config or queryable-state name) — confirm
     * whether wrapped operators may rely on those.
     */
    @SuppressWarnings("unchecked")
    protected <S extends State, T> StateDescriptor<S, T> createNewDescriptor(
            StateDescriptor<S, T> descriptor) {
        switch (descriptor.getType()) {
            case VALUE:
                {
                    return (StateDescriptor<S, T>)
                            new ValueStateDescriptor<>(
                                    stateNamePrefix.prefix(descriptor.getName()),
                                    descriptor.getSerializer());
                }
            case LIST:
                {
                    ListStateDescriptor<T> listStateDescriptor =
                            (ListStateDescriptor<T>) descriptor;
                    return (StateDescriptor<S, T>)
                            new ListStateDescriptor<>(
                                    stateNamePrefix.prefix(listStateDescriptor.getName()),
                                    listStateDescriptor.getElementSerializer());
                }
            case REDUCING:
                {
                    ReducingStateDescriptor<T> reducingStateDescriptor =
                            (ReducingStateDescriptor<T>) descriptor;
                    return (StateDescriptor<S, T>)
                            new ReducingStateDescriptor<>(
                                    stateNamePrefix.prefix(reducingStateDescriptor.getName()),
                                    reducingStateDescriptor.getReduceFunction(),
                                    reducingStateDescriptor.getSerializer());
                }
            case AGGREGATING:
                {
                    // Raw types are used here because the intermediate accumulator
                    // type of the aggregating descriptor is not visible at this
                    // point; covered by the method-level @SuppressWarnings.
                    AggregatingStateDescriptor<?, ?, T> aggregatingStateDescriptor =
                            (AggregatingStateDescriptor<?, ?, T>) descriptor;
                    return new AggregatingStateDescriptor(
                            stateNamePrefix.prefix(aggregatingStateDescriptor.getName()),
                            aggregatingStateDescriptor.getAggregateFunction(),
                            aggregatingStateDescriptor.getSerializer());
                }
            case MAP:
                {
                    // Raw types for the same reason: the key/value types of the map
                    // descriptor are not visible here.
                    MapStateDescriptor<?, Map<?, ?>> mapStateDescriptor =
                            (MapStateDescriptor<?, Map<?, ?>>) descriptor;
                    return new MapStateDescriptor(
                            stateNamePrefix.prefix(mapStateDescriptor.getName()),
                            mapStateDescriptor.getKeySerializer(),
                            mapStateDescriptor.getValueSerializer());
                }
            default:
                throw new UnsupportedOperationException("Unsupported state type");
        }
    }

    @Override
    public KeyGroupRange getKeyGroupRange() {
        return wrappedBackend.getKeyGroupRange();
    }

    @Nonnull
    @Override
    public SavepointResources<K> savepoint() throws Exception {
        return wrappedBackend.savepoint();
    }

    @Override
    public void dispose() {
        // Do not dispose for proxy: the wrapped backend is owned (and disposed)
        // by the wrapping operator.
    }

    @Override
    public void close() throws IOException {
        // Do not close for proxy: the wrapped backend is owned (and closed)
        // by the wrapping operator.
    }

    @Nonnull
    @Override
    public <T extends HeapPriorityQueueElement & PriorityComparable<? super T> & Keyed<?>>
            KeyGroupedInternalPriorityQueue<T> create(
                    @Nonnull String stateName,
                    @Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
        return wrappedBackend.create(
                stateNamePrefix.prefix(stateName), byteOrderedElementSerializer);
    }

    @Nonnull
    @Override
    public RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
            long checkpointId,
            long timestamp,
            @Nonnull CheckpointStreamFactory streamFactory,
            @Nonnull CheckpointOptions checkpointOptions)
            throws Exception {
        // Snapshots are taken on the shared backend; no name rewriting is needed
        // because the states were already registered with prefixed names.
        return wrappedBackend.snapshot(checkpointId, timestamp, streamFactory, checkpointOptions);
    }
}
diff --git 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java
 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java
new file mode 100644
index 0000000..a655886
--- /dev/null
+++ 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.proxy.state;
+
+import org.apache.flink.api.common.state.BroadcastState;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.SnapshotResult;
+
+import javax.annotation.Nonnull;
+
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.RunnableFuture;
+
+/** Proxy {@link OperatorStateBackend} for the wrapped Operator. */
+public class ProxyOperatorStateBackend implements OperatorStateBackend {
+
+    private final OperatorStateBackend wrappedBackend;
+
+    private final StateNamePrefix stateNamePrefix;
+
+    public ProxyOperatorStateBackend(
+            OperatorStateBackend wrappedBackend, StateNamePrefix 
stateNamePrefix) {
+        this.wrappedBackend = wrappedBackend;
+        this.stateNamePrefix = stateNamePrefix;
+    }
+
+    @Override
+    public <K, V> BroadcastState<K, V> getBroadcastState(MapStateDescriptor<K, 
V> stateDescriptor)
+            throws Exception {
+        MapStateDescriptor<K, V> newDescriptor =
+                new MapStateDescriptor<>(
+                        stateNamePrefix.prefix(stateDescriptor.getName()),
+                        stateDescriptor.getKeySerializer(),
+                        stateDescriptor.getValueSerializer());
+        return wrappedBackend.getBroadcastState(newDescriptor);
+    }
+
+    @Override
+    public <S> ListState<S> getListState(ListStateDescriptor<S> 
stateDescriptor) throws Exception {
+        ListStateDescriptor<S> newDescriptor =
+                new ListStateDescriptor<>(
+                        stateNamePrefix.prefix(stateDescriptor.getName()),
+                        stateDescriptor.getElementSerializer());
+        return wrappedBackend.getListState(newDescriptor);
+    }
+
+    @Override
+    public <S> ListState<S> getUnionListState(ListStateDescriptor<S> 
stateDescriptor)
+            throws Exception {
+        ListStateDescriptor<S> newDescriptor =
+                new ListStateDescriptor<S>(
+                        stateNamePrefix.prefix(stateDescriptor.getName()),
+                        stateDescriptor.getElementSerializer());
+        return wrappedBackend.getUnionListState(newDescriptor);
+    }
+
+    @Override
+    public Set<String> getRegisteredStateNames() {
+        Set<String> filteredNames = new HashSet<>();
+        Set<String> names = wrappedBackend.getRegisteredStateNames();
+
+        for (String name : names) {
+            if (name.startsWith(stateNamePrefix.getNamePrefix())) {
+                
filteredNames.add(name.substring(stateNamePrefix.getNamePrefix().length()));
+            }
+        }
+
+        return filteredNames;
+    }
+
+    @Override
+    public Set<String> getRegisteredBroadcastStateNames() {
+        Set<String> filteredNames = new HashSet<>();
+        Set<String> names = wrappedBackend.getRegisteredBroadcastStateNames();
+
+        for (String name : names) {
+            if (name.startsWith(stateNamePrefix.getNamePrefix())) {
+                
filteredNames.add(name.substring(stateNamePrefix.getNamePrefix().length()));
+            }
+        }
+
+        return filteredNames;
+    }
+
+    @Override
+    public void dispose() {
+        // Do not dispose for proxy.
+    }
+
+    @Override
+    public void close() throws IOException {
+        // Do not close for proxy.
+    }
+
+    @Nonnull
+    @Override
+    public RunnableFuture<SnapshotResult<OperatorStateHandle>> snapshot(
+            long checkpointId,
+            long timestamp,
+            @Nonnull CheckpointStreamFactory streamFactory,
+            @Nonnull CheckpointOptions checkpointOptions)
+            throws Exception {
+        return wrappedBackend.snapshot(checkpointId, timestamp, streamFactory, 
checkpointOptions);
+    }
+}
diff --git 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyStateSnapshotContext.java
 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyStateSnapshotContext.java
new file mode 100644
index 0000000..35d164c
--- /dev/null
+++ 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyStateSnapshotContext.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.proxy.state;
+
+import org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+
+/** Proxy {@link StateSnapshotContext} for the wrapped operators. */
+public class ProxyStateSnapshotContext implements StateSnapshotContext {
+
+    private final StateSnapshotContext wrappedContext;
+
+    public ProxyStateSnapshotContext(StateSnapshotContext wrappedContext) {
+        this.wrappedContext = wrappedContext;
+    }
+
+    @Override
+    public KeyedStateCheckpointOutputStream getRawKeyedOperatorStateOutput() 
throws Exception {
+        throw new UnsupportedOperationException(
+                "Currently we do not support the raw keyed state inside the iteration.");
+    }
+
+    @Override
+    public OperatorStateCheckpointOutputStream getRawOperatorStateOutput() 
throws Exception {
+        throw new UnsupportedOperationException(
+                "Currently we do not support the raw operator state inside the iteration.");
+    }
+
+    @Override
+    public long getCheckpointId() {
+        return wrappedContext.getCheckpointId();
+    }
+
+    @Override
+    public long getCheckpointTimestamp() {
+        return wrappedContext.getCheckpointTimestamp();
+    }
+}
diff --git 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyStreamOperatorStateContext.java
 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyStreamOperatorStateContext.java
new file mode 100644
index 0000000..a892cc5
--- /dev/null
+++ 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyStreamOperatorStateContext.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.proxy.state;
+
+import org.apache.flink.runtime.state.CheckpointableKeyedStateBackend;
+import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
+import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
+import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
+import org.apache.flink.util.CloseableIterable;
+
+import java.util.Objects;
+import java.util.OptionalLong;
+
+/** Proxy {@link StreamOperatorStateContext} for the wrapped operator. */
+public class ProxyStreamOperatorStateContext implements 
StreamOperatorStateContext {
+
+    private final StreamOperatorStateContext wrapped;
+
+    private final StateNamePrefix stateNamePrefix;
+
+    public ProxyStreamOperatorStateContext(
+            StreamOperatorStateContext wrapped, String stateNamePrefix) {
+        this.wrapped = Objects.requireNonNull(wrapped);
+        this.stateNamePrefix = new StateNamePrefix(stateNamePrefix);
+    }
+
+    @Override
+    public boolean isRestored() {
+        return wrapped.isRestored();
+    }
+
+    @Override
+    public OptionalLong getRestoredCheckpointId() {
+        return wrapped.getRestoredCheckpointId();
+    }
+
+    @Override
+    public OperatorStateBackend operatorStateBackend() {
+        return wrapped.operatorStateBackend() == null
+                ? null
+                : new 
ProxyOperatorStateBackend(wrapped.operatorStateBackend(), stateNamePrefix);
+    }
+
+    @Override
+    public CheckpointableKeyedStateBackend<?> keyedStateBackend() {
+        return wrapped.keyedStateBackend() == null
+                ? null
+                : new ProxyKeyedStateBackend<>(wrapped.keyedStateBackend(), 
stateNamePrefix);
+    }
+
+    @Override
+    public InternalTimeServiceManager<?> internalTimerServiceManager() {
+        return wrapped.internalTimerServiceManager() == null
+                ? null
+                : new ProxyInternalTimeServiceManager<>(
+                        wrapped.internalTimerServiceManager(), 
stateNamePrefix);
+    }
+
+    @Override
+    public CloseableIterable<StatePartitionStreamProvider> 
rawOperatorStateInputs() {
+        return CloseableIterable.empty();
+    }
+
+    @Override
+    public CloseableIterable<KeyGroupStatePartitionStreamProvider> 
rawKeyedStateInputs() {
+        return CloseableIterable.empty();
+    }
+}
diff --git 
a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java
 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/StateNamePrefix.java
similarity index 66%
copy from 
flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java
copy to 
flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/StateNamePrefix.java
index 3328a07..dbe507c 100644
--- 
a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java
+++ 
b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/StateNamePrefix.java
@@ -16,20 +16,22 @@
  * limitations under the License.
  */
 
-package org.apache.flink.iteration.operator.allround;
+package org.apache.flink.iteration.proxy.state;
 
-/** Utilities to track the life-cycle of the operators. */
-public enum LifeCycle {
-    SETUP,
-    OPEN,
-    INITIALIZE_STATE,
-    PROCESS_ELEMENT,
-    PROCESS_ELEMENT_1,
-    PROCESS_ELEMENT_2,
-    PREPARE_SNAPSHOT_PRE_BARRIER,
-    SNAPSHOT_STATE,
-    NOTIFY_CHECKPOINT_COMPLETE,
-    NOTIFY_CHECKPOINT_ABORT,
-    FINISH,
-    CLOSE,
+/** The prefix for the state name. */
+public class StateNamePrefix {
+
+    private final String namePrefix;
+
+    public StateNamePrefix(String namePrefix) {
+        this.namePrefix = namePrefix;
+    }
+
+    public String getNamePrefix() {
+        return namePrefix;
+    }
+
+    public String prefix(String stateName) {
+        return namePrefix + stateName;
+    }
 }
diff --git 
a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java
 
b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java
index 3328a07..552c933 100644
--- 
a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java
+++ 
b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java
@@ -30,6 +30,8 @@ public enum LifeCycle {
     SNAPSHOT_STATE,
     NOTIFY_CHECKPOINT_COMPLETE,
     NOTIFY_CHECKPOINT_ABORT,
+    END_INPUT,
+    MAX_WATERMARK,
     FINISH,
     CLOSE,
 }
diff --git 
a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
 
b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
index 6491369..9a2b72a 100644
--- 
a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
+++ 
b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
@@ -35,6 +35,7 @@ import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.streaming.api.operators.AbstractInput;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
 import org.apache.flink.streaming.api.operators.Input;
 import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperator;
@@ -57,7 +58,7 @@ import static org.junit.Assert.assertEquals;
 /** Tests the {@link OneInputAllRoundWrapperOperator}. */
 public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger {
 
-    private static List<LifeCycle> lifeCycles = new ArrayList<>();
+    private static final List<LifeCycle> LIFE_CYCLES = new ArrayList<>();
 
     @Test
     public void testProcessElementsAndEpochWatermarks() throws Exception {
@@ -86,7 +87,7 @@ public class MultipleInputAllRoundWrapperOperatorTest extends 
TestLogger {
             harness.processElement(
                     new StreamRecord<>(IterationRecord.newEpochWatermark(5, 
"only-one-2")), 2);
 
-            // Check the output
+            // Checks the output
             assertEquals(
                     Arrays.asList(
                             new StreamRecord<>(IterationRecord.newRecord(5, 
1), 2),
@@ -97,7 +98,7 @@ public class MultipleInputAllRoundWrapperOperatorTest extends 
TestLogger {
                                             5, 
OperatorUtils.getUniqueSenderId(operatorId, 0)))),
                     new ArrayList<>(harness.getOutput()));
 
-            // Check the other lifecycles.
+            // Checks the other lifecycles.
             harness.getStreamTask()
                     .triggerCheckpointOnBarrier(
                             new CheckpointMetaData(5, 2),
@@ -132,15 +133,21 @@ public class MultipleInputAllRoundWrapperOperatorTest 
extends TestLogger {
                             LifeCycle.SNAPSHOT_STATE,
                             LifeCycle.NOTIFY_CHECKPOINT_COMPLETE,
                             LifeCycle.NOTIFY_CHECKPOINT_ABORT,
+                            // The first input
+                            LifeCycle.END_INPUT,
+                            // The second input
+                            LifeCycle.END_INPUT,
+                            // The third input
+                            LifeCycle.END_INPUT,
                             LifeCycle.FINISH,
                             LifeCycle.CLOSE),
-                    lifeCycles);
+                    LIFE_CYCLES);
         }
     }
 
     private static class LifeCycleTrackingTwoInputStreamOperator
             extends AbstractStreamOperatorV2<Integer>
-            implements MultipleInputStreamOperator<Integer> {
+            implements MultipleInputStreamOperator<Integer>, BoundedMultiInput 
{
 
         private final int numberOfInputs;
 
@@ -159,7 +166,7 @@ public class MultipleInputAllRoundWrapperOperatorTest 
extends TestLogger {
                             @Override
                             public void processElement(StreamRecord element) 
throws Exception {
                                 output.collect(element);
-                                lifeCycles.add(LifeCycle.PROCESS_ELEMENT);
+                                LIFE_CYCLES.add(LifeCycle.PROCESS_ELEMENT);
                             }
                         });
             }
@@ -170,49 +177,54 @@ public class MultipleInputAllRoundWrapperOperatorTest 
extends TestLogger {
         @Override
         public void open() throws Exception {
             super.open();
-            lifeCycles.add(LifeCycle.OPEN);
+            LIFE_CYCLES.add(LifeCycle.OPEN);
         }
 
         @Override
         public void initializeState(StateInitializationContext context) throws 
Exception {
             super.initializeState(context);
-            lifeCycles.add(LifeCycle.INITIALIZE_STATE);
+            LIFE_CYCLES.add(LifeCycle.INITIALIZE_STATE);
         }
 
         @Override
         public void finish() throws Exception {
             super.finish();
-            lifeCycles.add(LifeCycle.FINISH);
+            LIFE_CYCLES.add(LifeCycle.FINISH);
         }
 
         @Override
         public void close() throws Exception {
             super.close();
-            lifeCycles.add(LifeCycle.CLOSE);
+            LIFE_CYCLES.add(LifeCycle.CLOSE);
         }
 
         @Override
         public void prepareSnapshotPreBarrier(long checkpointId) throws 
Exception {
             super.prepareSnapshotPreBarrier(checkpointId);
-            lifeCycles.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER);
+            LIFE_CYCLES.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER);
         }
 
         @Override
         public void snapshotState(StateSnapshotContext context) throws 
Exception {
             super.snapshotState(context);
-            lifeCycles.add(LifeCycle.SNAPSHOT_STATE);
+            LIFE_CYCLES.add(LifeCycle.SNAPSHOT_STATE);
         }
 
         @Override
         public void notifyCheckpointComplete(long checkpointId) throws 
Exception {
             super.notifyCheckpointComplete(checkpointId);
-            lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE);
+            LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE);
         }
 
         @Override
         public void notifyCheckpointAborted(long checkpointId) throws 
Exception {
             super.notifyCheckpointAborted(checkpointId);
-            lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT);
+            LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT);
+        }
+
+        @Override
+        public void endInput(int inputId) throws Exception {
+            LIFE_CYCLES.add(LifeCycle.END_INPUT);
         }
     }
 
diff --git 
a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
 
b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
index c58f431..f628b65 100644
--- 
a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
+++ 
b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
@@ -34,6 +34,7 @@ import 
org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
@@ -56,7 +57,7 @@ import static org.junit.Assert.assertEquals;
 /** Tests the {@link OneInputAllRoundWrapperOperator}. */
 public class OneInputAllRoundWrapperOperatorTest extends TestLogger {
 
-    private static List<LifeCycle> lifeCycles = new ArrayList<>();
+    private static final List<LifeCycle> LIFE_CYCLES = new ArrayList<>();
 
     @Test
     public void testProcessElementsAndEpochWatermarks() throws Exception {
@@ -78,7 +79,7 @@ public class OneInputAllRoundWrapperOperatorTest extends 
TestLogger {
             harness.processElement(
                     new StreamRecord<>(IterationRecord.newEpochWatermark(5, 
"only-one")));
 
-            // Check the output
+            // Checks the output
             assertEquals(
                     Arrays.asList(
                             new StreamRecord<>(IterationRecord.newRecord(5, 
1), 2),
@@ -121,15 +122,16 @@ public class OneInputAllRoundWrapperOperatorTest extends 
TestLogger {
                             LifeCycle.SNAPSHOT_STATE,
                             LifeCycle.NOTIFY_CHECKPOINT_COMPLETE,
                             LifeCycle.NOTIFY_CHECKPOINT_ABORT,
+                            LifeCycle.END_INPUT,
                             LifeCycle.FINISH,
                             LifeCycle.CLOSE),
-                    lifeCycles);
+                    LIFE_CYCLES);
         }
     }
 
     private static class LifeCycleTrackingOneInputStreamOperator
             extends AbstractStreamOperator<Integer>
-            implements OneInputStreamOperator<Integer, Integer> {
+            implements OneInputStreamOperator<Integer, Integer>, 
BoundedOneInput {
 
         @Override
         public void setup(
@@ -137,61 +139,66 @@ public class OneInputAllRoundWrapperOperatorTest extends 
TestLogger {
                 StreamConfig config,
                 Output<StreamRecord<Integer>> output) {
             super.setup(containingTask, config, output);
-            lifeCycles.add(LifeCycle.SETUP);
+            LIFE_CYCLES.add(LifeCycle.SETUP);
         }
 
         @Override
         public void open() throws Exception {
             super.open();
-            lifeCycles.add(LifeCycle.OPEN);
+            LIFE_CYCLES.add(LifeCycle.OPEN);
         }
 
         @Override
         public void initializeState(StateInitializationContext context) throws 
Exception {
             super.initializeState(context);
-            lifeCycles.add(LifeCycle.INITIALIZE_STATE);
+            LIFE_CYCLES.add(LifeCycle.INITIALIZE_STATE);
         }
 
         @Override
         public void finish() throws Exception {
             super.finish();
-            lifeCycles.add(LifeCycle.FINISH);
+            LIFE_CYCLES.add(LifeCycle.FINISH);
         }
 
         @Override
         public void close() throws Exception {
             super.close();
-            lifeCycles.add(LifeCycle.CLOSE);
+            LIFE_CYCLES.add(LifeCycle.CLOSE);
         }
 
         @Override
         public void prepareSnapshotPreBarrier(long checkpointId) throws 
Exception {
             super.prepareSnapshotPreBarrier(checkpointId);
-            lifeCycles.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER);
+            LIFE_CYCLES.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER);
         }
 
         @Override
         public void snapshotState(StateSnapshotContext context) throws 
Exception {
             super.snapshotState(context);
-            lifeCycles.add(LifeCycle.SNAPSHOT_STATE);
+            LIFE_CYCLES.add(LifeCycle.SNAPSHOT_STATE);
         }
 
         @Override
         public void notifyCheckpointComplete(long checkpointId) throws 
Exception {
             super.notifyCheckpointComplete(checkpointId);
-            lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE);
+            LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE);
         }
 
         @Override
         public void notifyCheckpointAborted(long checkpointId) throws 
Exception {
             super.notifyCheckpointAborted(checkpointId);
-            lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT);
+            LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT);
         }
 
         @Override
         public void processElement(StreamRecord<Integer> element) throws 
Exception {
             output.collect(element);
-            lifeCycles.add(LifeCycle.PROCESS_ELEMENT);
+            LIFE_CYCLES.add(LifeCycle.PROCESS_ELEMENT);
+        }
+
+        @Override
+        public void endInput() throws Exception {
+            LIFE_CYCLES.add(LifeCycle.END_INPUT);
         }
     }
 }
diff --git 
a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
 
b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
index e7e2604..82d5854 100644
--- 
a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
+++ 
b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
@@ -34,6 +34,7 @@ import 
org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
@@ -56,7 +57,7 @@ import static org.junit.Assert.assertEquals;
 /** Tests the {@link OneInputAllRoundWrapperOperator}. */
 public class TwoInputAllRoundWrapperOperatorTest extends TestLogger {
 
-    private static List<LifeCycle> lifeCycles = new ArrayList<>();
+    private static final List<LifeCycle> LIFE_CYCLES = new ArrayList<>();
 
     @Test
     public void testProcessElementsAndEpochWatermarks() throws Exception {
@@ -81,7 +82,7 @@ public class TwoInputAllRoundWrapperOperatorTest extends 
TestLogger {
             harness.processElement(
                     new StreamRecord<>(IterationRecord.newEpochWatermark(5, 
"only-one-1")), 1);
 
-            // Check the output
+            // Checks the output
             assertEquals(
                     Arrays.asList(
                             new StreamRecord<>(IterationRecord.newRecord(5, 
1), 2),
@@ -91,7 +92,7 @@ public class TwoInputAllRoundWrapperOperatorTest extends 
TestLogger {
                                             5, 
OperatorUtils.getUniqueSenderId(operatorId, 0)))),
                     new ArrayList<>(harness.getOutput()));
 
-            // Check the other lifecycles.
+            // Checks the other lifecycles.
             harness.getStreamTask()
                     .triggerCheckpointOnBarrier(
                             new CheckpointMetaData(5, 2),
@@ -125,15 +126,19 @@ public class TwoInputAllRoundWrapperOperatorTest extends 
TestLogger {
                             LifeCycle.SNAPSHOT_STATE,
                             LifeCycle.NOTIFY_CHECKPOINT_COMPLETE,
                             LifeCycle.NOTIFY_CHECKPOINT_ABORT,
+                            // The first input
+                            LifeCycle.END_INPUT,
+                            // The second input
+                            LifeCycle.END_INPUT,
                             LifeCycle.FINISH,
                             LifeCycle.CLOSE),
-                    lifeCycles);
+                    LIFE_CYCLES);
         }
     }
 
     private static class LifeCycleTrackingTwoInputStreamOperator
             extends AbstractStreamOperator<Integer>
-            implements TwoInputStreamOperator<Integer, Integer, Integer> {
+            implements TwoInputStreamOperator<Integer, Integer, Integer>, 
BoundedMultiInput {
 
         @Override
         public void setup(
@@ -141,67 +146,72 @@ public class TwoInputAllRoundWrapperOperatorTest extends 
TestLogger {
                 StreamConfig config,
                 Output<StreamRecord<Integer>> output) {
             super.setup(containingTask, config, output);
-            lifeCycles.add(LifeCycle.SETUP);
+            LIFE_CYCLES.add(LifeCycle.SETUP);
         }
 
         @Override
         public void open() throws Exception {
             super.open();
-            lifeCycles.add(LifeCycle.OPEN);
+            LIFE_CYCLES.add(LifeCycle.OPEN);
         }
 
         @Override
         public void initializeState(StateInitializationContext context) throws 
Exception {
             super.initializeState(context);
-            lifeCycles.add(LifeCycle.INITIALIZE_STATE);
+            LIFE_CYCLES.add(LifeCycle.INITIALIZE_STATE);
         }
 
         @Override
         public void finish() throws Exception {
             super.finish();
-            lifeCycles.add(LifeCycle.FINISH);
+            LIFE_CYCLES.add(LifeCycle.FINISH);
         }
 
         @Override
         public void close() throws Exception {
             super.close();
-            lifeCycles.add(LifeCycle.CLOSE);
+            LIFE_CYCLES.add(LifeCycle.CLOSE);
         }
 
         @Override
         public void prepareSnapshotPreBarrier(long checkpointId) throws 
Exception {
             super.prepareSnapshotPreBarrier(checkpointId);
-            lifeCycles.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER);
+            LIFE_CYCLES.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER);
         }
 
         @Override
         public void snapshotState(StateSnapshotContext context) throws 
Exception {
             super.snapshotState(context);
-            lifeCycles.add(LifeCycle.SNAPSHOT_STATE);
+            LIFE_CYCLES.add(LifeCycle.SNAPSHOT_STATE);
         }
 
         @Override
         public void notifyCheckpointComplete(long checkpointId) throws 
Exception {
             super.notifyCheckpointComplete(checkpointId);
-            lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE);
+            LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE);
         }
 
         @Override
         public void notifyCheckpointAborted(long checkpointId) throws 
Exception {
             super.notifyCheckpointAborted(checkpointId);
-            lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT);
+            LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT);
         }
 
         @Override
         public void processElement1(StreamRecord<Integer> element) throws 
Exception {
             output.collect(element);
-            lifeCycles.add(LifeCycle.PROCESS_ELEMENT_1);
+            LIFE_CYCLES.add(LifeCycle.PROCESS_ELEMENT_1);
         }
 
         @Override
         public void processElement2(StreamRecord<Integer> element) throws 
Exception {
             output.collect(element);
-            lifeCycles.add(LifeCycle.PROCESS_ELEMENT_2);
+            LIFE_CYCLES.add(LifeCycle.PROCESS_ELEMENT_2);
+        }
+
+        @Override
+        public void endInput(int inputId) throws Exception {
+            LIFE_CYCLES.add(LifeCycle.END_INPUT);
         }
     }
 }
diff --git 
a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
 
b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java
similarity index 68%
copy from 
flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
copy to 
flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java
index 6491369..9dba5fa 100644
--- 
a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
+++ 
b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java
@@ -16,12 +16,14 @@
  * limitations under the License.
  */
 
-package org.apache.flink.iteration.operator.allround;
+package org.apache.flink.iteration.operator.perround;
 
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.iteration.operator.WrapperOperatorFactory;
+import org.apache.flink.iteration.operator.allround.LifeCycle;
+import 
org.apache.flink.iteration.operator.allround.OneInputAllRoundWrapperOperator;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
@@ -35,6 +37,7 @@ import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.streaming.api.operators.AbstractInput;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
 import org.apache.flink.streaming.api.operators.Input;
 import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperator;
@@ -55,16 +58,16 @@ import java.util.List;
 import static org.junit.Assert.assertEquals;
 
 /** Tests the {@link OneInputAllRoundWrapperOperator}. */
-public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger {
+public class MultipleInputPerRoundWrapperOperatorTest extends TestLogger {
 
-    private static List<LifeCycle> lifeCycles = new ArrayList<>();
+    private static final List<LifeCycle> LIFE_CYCLES = new ArrayList<>();
 
     @Test
     public void testProcessElementsAndEpochWatermarks() throws Exception {
         StreamOperatorFactory<IterationRecord<Integer>> wrapperFactory =
                 new WrapperOperatorFactory<>(
-                        new LifeCycleTrackingTwoInputStreamOperatorFactory(),
-                        new AllRoundOperatorWrapper<>());
+                        new LifeCycleTrackingMultiInputStreamOperatorFactory(),
+                        new PerRoundOperatorWrapper<>());
         OperatorID operatorId = new OperatorID();
 
         try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness =
@@ -77,27 +80,16 @@ public class MultipleInputAllRoundWrapperOperatorTest 
extends TestLogger {
                         .setupOutputForSingletonOperatorChain(wrapperFactory, 
operatorId)
                         .build()) {
             harness.processElement(new 
StreamRecord<>(IterationRecord.newRecord(5, 1), 2), 0);
-            harness.processElement(new 
StreamRecord<>(IterationRecord.newRecord(6, 2), 3), 1);
-            harness.processElement(new 
StreamRecord<>(IterationRecord.newRecord(7, 3), 4), 2);
-            harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(5, 
"only-one-0")), 0);
-            harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(5, 
"only-one-1")), 1);
-            harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(5, 
"only-one-2")), 2);
+            harness.processElement(new 
StreamRecord<>(IterationRecord.newRecord(6, 2), 3), 2);
 
-            // Check the output
+            // Checks the output
             assertEquals(
                     Arrays.asList(
                             new StreamRecord<>(IterationRecord.newRecord(5, 
1), 2),
-                            new StreamRecord<>(IterationRecord.newRecord(6, 
2), 3),
-                            new StreamRecord<>(IterationRecord.newRecord(7, 
3), 4),
-                            new StreamRecord<>(
-                                    IterationRecord.newEpochWatermark(
-                                            5, 
OperatorUtils.getUniqueSenderId(operatorId, 0)))),
+                            new StreamRecord<>(IterationRecord.newRecord(6, 
2), 3)),
                     new ArrayList<>(harness.getOutput()));
 
-            // Check the other lifecycles.
+            // Checks the other lifecycles.
             harness.getStreamTask()
                     .triggerCheckpointOnBarrier(
                             new CheckpointMetaData(5, 2),
@@ -115,6 +107,31 @@ public class MultipleInputAllRoundWrapperOperatorTest 
extends TestLogger {
             harness.getStreamTask().notifyCheckpointAbortAsync(6, 5);
             harness.processAll();
 
+            harness.getOutput().clear();
+            harness.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(1, 
"only-one-0")), 0);
+            harness.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(1, 
"only-one-1")), 1);
+            harness.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(1, 
"only-one-2")), 2);
+            harness.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(2, 
"only-one-0")), 0);
+            harness.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(2, 
"only-one-1")), 1);
+            harness.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(2, 
"only-one-2")), 2);
+
+            // Checks the output
+            assertEquals(
+                    Arrays.asList(
+                            new StreamRecord<>(
+                                    IterationRecord.newEpochWatermark(
+                                            1, 
OperatorUtils.getUniqueSenderId(operatorId, 0))),
+                            new StreamRecord<>(
+                                    IterationRecord.newEpochWatermark(
+                                            2, 
OperatorUtils.getUniqueSenderId(operatorId, 0)))),
+                    new ArrayList<>(harness.getOutput()));
+
             harness.processEvent(EndOfData.INSTANCE, 0);
             harness.processEvent(EndOfData.INSTANCE, 1);
             harness.processEvent(EndOfData.INSTANCE, 2);
@@ -123,28 +140,50 @@ public class MultipleInputAllRoundWrapperOperatorTest 
extends TestLogger {
 
             assertEquals(
                     Arrays.asList(
+                            /* First wrapped operator */
                             LifeCycle.INITIALIZE_STATE,
                             LifeCycle.OPEN,
                             LifeCycle.PROCESS_ELEMENT,
+                            /* Second wrapped operator */
+                            LifeCycle.INITIALIZE_STATE,
+                            LifeCycle.OPEN,
                             LifeCycle.PROCESS_ELEMENT,
-                            LifeCycle.PROCESS_ELEMENT,
+                            /* states */
+                            LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER,
                             LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER,
                             LifeCycle.SNAPSHOT_STATE,
+                            LifeCycle.SNAPSHOT_STATE,
+                            LifeCycle.NOTIFY_CHECKPOINT_COMPLETE,
                             LifeCycle.NOTIFY_CHECKPOINT_COMPLETE,
                             LifeCycle.NOTIFY_CHECKPOINT_ABORT,
+                            LifeCycle.NOTIFY_CHECKPOINT_ABORT,
+                            // The first input
+                            LifeCycle.END_INPUT,
+                            // The second input
+                            LifeCycle.END_INPUT,
+                            // The third input
+                            LifeCycle.END_INPUT,
+                            LifeCycle.FINISH,
+                            LifeCycle.CLOSE,
+                            // The first input
+                            LifeCycle.END_INPUT,
+                            // The second input
+                            LifeCycle.END_INPUT,
+                            // The third input
+                            LifeCycle.END_INPUT,
                             LifeCycle.FINISH,
                             LifeCycle.CLOSE),
-                    lifeCycles);
+                    LIFE_CYCLES);
         }
     }
 
-    private static class LifeCycleTrackingTwoInputStreamOperator
+    private static class LifeCycleTrackingMultiInputStreamOperator
             extends AbstractStreamOperatorV2<Integer>
-            implements MultipleInputStreamOperator<Integer> {
+            implements MultipleInputStreamOperator<Integer>, BoundedMultiInput 
{
 
         private final int numberOfInputs;
 
-        public LifeCycleTrackingTwoInputStreamOperator(
+        public LifeCycleTrackingMultiInputStreamOperator(
                 StreamOperatorParameters<Integer> parameters, int 
numberOfInputs) {
             super(parameters, numberOfInputs);
             this.numberOfInputs = numberOfInputs;
@@ -159,7 +198,7 @@ public class MultipleInputAllRoundWrapperOperatorTest 
extends TestLogger {
                             @Override
                             public void processElement(StreamRecord element) 
throws Exception {
                                 output.collect(element);
-                                lifeCycles.add(LifeCycle.PROCESS_ELEMENT);
+                                LIFE_CYCLES.add(LifeCycle.PROCESS_ELEMENT);
                             }
                         });
             }
@@ -170,65 +209,70 @@ public class MultipleInputAllRoundWrapperOperatorTest 
extends TestLogger {
         @Override
         public void open() throws Exception {
             super.open();
-            lifeCycles.add(LifeCycle.OPEN);
+            LIFE_CYCLES.add(LifeCycle.OPEN);
         }
 
         @Override
         public void initializeState(StateInitializationContext context) throws 
Exception {
             super.initializeState(context);
-            lifeCycles.add(LifeCycle.INITIALIZE_STATE);
+            LIFE_CYCLES.add(LifeCycle.INITIALIZE_STATE);
         }
 
         @Override
         public void finish() throws Exception {
             super.finish();
-            lifeCycles.add(LifeCycle.FINISH);
+            LIFE_CYCLES.add(LifeCycle.FINISH);
         }
 
         @Override
         public void close() throws Exception {
             super.close();
-            lifeCycles.add(LifeCycle.CLOSE);
+            LIFE_CYCLES.add(LifeCycle.CLOSE);
         }
 
         @Override
         public void prepareSnapshotPreBarrier(long checkpointId) throws 
Exception {
             super.prepareSnapshotPreBarrier(checkpointId);
-            lifeCycles.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER);
+            LIFE_CYCLES.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER);
         }
 
         @Override
         public void snapshotState(StateSnapshotContext context) throws 
Exception {
             super.snapshotState(context);
-            lifeCycles.add(LifeCycle.SNAPSHOT_STATE);
+            LIFE_CYCLES.add(LifeCycle.SNAPSHOT_STATE);
         }
 
         @Override
         public void notifyCheckpointComplete(long checkpointId) throws 
Exception {
             super.notifyCheckpointComplete(checkpointId);
-            lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE);
+            LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE);
         }
 
         @Override
         public void notifyCheckpointAborted(long checkpointId) throws 
Exception {
             super.notifyCheckpointAborted(checkpointId);
-            lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT);
+            LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT);
+        }
+
+        @Override
+        public void endInput(int inputId) throws Exception {
+            LIFE_CYCLES.add(LifeCycle.END_INPUT);
         }
     }
 
-    /** The operator factory for the lifecycle-tracking operator. */
-    public static class LifeCycleTrackingTwoInputStreamOperatorFactory
+    /** Life-cycle tracking stream operator factory. */
+    private static class LifeCycleTrackingMultiInputStreamOperatorFactory
             extends AbstractStreamOperatorFactory<Integer> {
 
         @Override
         public <T extends StreamOperator<Integer>> T createStreamOperator(
                 StreamOperatorParameters<Integer> parameters) {
-            return (T) new LifeCycleTrackingTwoInputStreamOperator(parameters, 
3);
+            return (T) new 
LifeCycleTrackingMultiInputStreamOperator(parameters, 3);
         }
 
         @Override
         public Class<? extends StreamOperator> 
getStreamOperatorClass(ClassLoader classLoader) {
-            return LifeCycleTrackingTwoInputStreamOperator.class;
+            return LifeCycleTrackingMultiInputStreamOperator.class;
         }
     }
 }
diff --git 
a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
 
b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java
similarity index 73%
copy from 
flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
copy to 
flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java
index c58f431..787dd0f 100644
--- 
a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
+++ 
b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java
@@ -16,12 +16,13 @@
  * limitations under the License.
  */
 
-package org.apache.flink.iteration.operator.allround;
+package org.apache.flink.iteration.operator.perround;
 
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.iteration.operator.WrapperOperatorFactory;
+import org.apache.flink.iteration.operator.allround.LifeCycle;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
@@ -34,6 +35,7 @@ import 
org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
@@ -53,17 +55,17 @@ import java.util.List;
 
 import static org.junit.Assert.assertEquals;
 
-/** Tests the {@link OneInputAllRoundWrapperOperator}. */
-public class OneInputAllRoundWrapperOperatorTest extends TestLogger {
+/** Tests the {@link OneInputPerRoundWrapperOperator}. */
+public class OneInputPerRoundWrapperOperatorTest extends TestLogger {
 
-    private static List<LifeCycle> lifeCycles = new ArrayList<>();
+    private static final List<LifeCycle> LIFE_CYCLES = new ArrayList<>();
 
     @Test
     public void testProcessElementsAndEpochWatermarks() throws Exception {
         StreamOperatorFactory<IterationRecord<Integer>> wrapperFactory =
                 new WrapperOperatorFactory<>(
                         SimpleOperatorFactory.of(new 
LifeCycleTrackingOneInputStreamOperator()),
-                        new AllRoundOperatorWrapper<>());
+                        new PerRoundOperatorWrapper<>());
         OperatorID operatorId = new OperatorID();
 
         try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness =
@@ -75,20 +77,15 @@ public class OneInputAllRoundWrapperOperatorTest extends 
TestLogger {
                         .build()) {
             harness.processElement(new 
StreamRecord<>(IterationRecord.newRecord(5, 1), 2));
             harness.processElement(new 
StreamRecord<>(IterationRecord.newRecord(6, 2), 3));
-            harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(5, 
"only-one")));
 
-            // Check the output
+            // Checks the output
             assertEquals(
                     Arrays.asList(
                             new StreamRecord<>(IterationRecord.newRecord(5, 
1), 2),
-                            new StreamRecord<>(IterationRecord.newRecord(6, 
2), 3),
-                            new StreamRecord<>(
-                                    IterationRecord.newEpochWatermark(
-                                            5, 
OperatorUtils.getUniqueSenderId(operatorId, 0)))),
+                            new StreamRecord<>(IterationRecord.newRecord(6, 
2), 3)),
                     new ArrayList<>(harness.getOutput()));
 
-            // Check the other lifecycles.
+            // Checks the other lifecycles.
             harness.getStreamTask()
                     .triggerCheckpointOnBarrier(
                             new CheckpointMetaData(5, 2),
@@ -106,30 +103,61 @@ public class OneInputAllRoundWrapperOperatorTest extends 
TestLogger {
             harness.getStreamTask().notifyCheckpointAbortAsync(6, 5);
             harness.processAll();
 
+            harness.getOutput().clear();
+            harness.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(1, 
"only-one")));
+            harness.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(2, 
"only-one")));
+
+            // Checks the output
+            assertEquals(
+                    Arrays.asList(
+                            new StreamRecord<>(
+                                    IterationRecord.newEpochWatermark(
+                                            1, 
OperatorUtils.getUniqueSenderId(operatorId, 0))),
+                            new StreamRecord<>(
+                                    IterationRecord.newEpochWatermark(
+                                            2, 
OperatorUtils.getUniqueSenderId(operatorId, 0)))),
+                    new ArrayList<>(harness.getOutput()));
+
             harness.processEvent(EndOfData.INSTANCE, 0);
             harness.endInput();
             harness.finishProcessing();
 
             assertEquals(
                     Arrays.asList(
+                            /* First wrapped operator */
                             LifeCycle.SETUP,
                             LifeCycle.INITIALIZE_STATE,
                             LifeCycle.OPEN,
                             LifeCycle.PROCESS_ELEMENT,
+                            /* Second wrapped operator */
+                            LifeCycle.SETUP,
+                            LifeCycle.INITIALIZE_STATE,
+                            LifeCycle.OPEN,
                             LifeCycle.PROCESS_ELEMENT,
+                            /* states */
+                            LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER,
                             LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER,
                             LifeCycle.SNAPSHOT_STATE,
+                            LifeCycle.SNAPSHOT_STATE,
                             LifeCycle.NOTIFY_CHECKPOINT_COMPLETE,
+                            LifeCycle.NOTIFY_CHECKPOINT_COMPLETE,
+                            LifeCycle.NOTIFY_CHECKPOINT_ABORT,
                             LifeCycle.NOTIFY_CHECKPOINT_ABORT,
+                            LifeCycle.END_INPUT,
+                            LifeCycle.FINISH,
+                            LifeCycle.CLOSE,
+                            LifeCycle.END_INPUT,
                             LifeCycle.FINISH,
                             LifeCycle.CLOSE),
-                    lifeCycles);
+                    LIFE_CYCLES);
         }
     }
 
     private static class LifeCycleTrackingOneInputStreamOperator
             extends AbstractStreamOperator<Integer>
-            implements OneInputStreamOperator<Integer, Integer> {
+            implements OneInputStreamOperator<Integer, Integer>, 
BoundedOneInput {
 
         @Override
         public void setup(
@@ -137,61 +165,66 @@ public class OneInputAllRoundWrapperOperatorTest extends 
TestLogger {
                 StreamConfig config,
                 Output<StreamRecord<Integer>> output) {
             super.setup(containingTask, config, output);
-            lifeCycles.add(LifeCycle.SETUP);
+            LIFE_CYCLES.add(LifeCycle.SETUP);
         }
 
         @Override
         public void open() throws Exception {
             super.open();
-            lifeCycles.add(LifeCycle.OPEN);
+            LIFE_CYCLES.add(LifeCycle.OPEN);
         }
 
         @Override
         public void initializeState(StateInitializationContext context) throws 
Exception {
             super.initializeState(context);
-            lifeCycles.add(LifeCycle.INITIALIZE_STATE);
+            LIFE_CYCLES.add(LifeCycle.INITIALIZE_STATE);
         }
 
         @Override
         public void finish() throws Exception {
             super.finish();
-            lifeCycles.add(LifeCycle.FINISH);
+            LIFE_CYCLES.add(LifeCycle.FINISH);
         }
 
         @Override
         public void close() throws Exception {
             super.close();
-            lifeCycles.add(LifeCycle.CLOSE);
+            LIFE_CYCLES.add(LifeCycle.CLOSE);
         }
 
         @Override
         public void prepareSnapshotPreBarrier(long checkpointId) throws 
Exception {
             super.prepareSnapshotPreBarrier(checkpointId);
-            lifeCycles.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER);
+            LIFE_CYCLES.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER);
         }
 
         @Override
         public void snapshotState(StateSnapshotContext context) throws 
Exception {
             super.snapshotState(context);
-            lifeCycles.add(LifeCycle.SNAPSHOT_STATE);
+            LIFE_CYCLES.add(LifeCycle.SNAPSHOT_STATE);
         }
 
         @Override
         public void notifyCheckpointComplete(long checkpointId) throws 
Exception {
             super.notifyCheckpointComplete(checkpointId);
-            lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE);
+            LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE);
         }
 
         @Override
         public void notifyCheckpointAborted(long checkpointId) throws 
Exception {
             super.notifyCheckpointAborted(checkpointId);
-            lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT);
+            LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT);
         }
 
         @Override
         public void processElement(StreamRecord<Integer> element) throws 
Exception {
             output.collect(element);
-            lifeCycles.add(LifeCycle.PROCESS_ELEMENT);
+            LIFE_CYCLES.add(LifeCycle.PROCESS_ELEMENT);
+        }
+
+        @Override
+        public void endInput() throws Exception {
+            LIFE_CYCLES.add(LifeCycle.END_INPUT);
         }
     }
 }
diff --git 
a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorStateTest.java
 
b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorStateTest.java
new file mode 100644
index 0000000..c720e57
--- /dev/null
+++ 
b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorStateTest.java
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.perround;
+
+import org.apache.flink.api.common.state.BroadcastState;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.java.typeutils.EnumTypeInfo;
+import org.apache.flink.contrib.streaming.state.EmbeddedRocksDBStateBackend;
+import org.apache.flink.core.memory.ManagedMemoryUseCase;
+import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.operator.WrapperOperatorFactory;
+import org.apache.flink.iteration.proxy.ProxyKeySelector;
+import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.StateBackend;
+import org.apache.flink.runtime.state.hashmap.HashMapStateBackend;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.streaming.api.functions.ProcessFunction;
+import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
+import org.apache.flink.streaming.api.operators.ProcessOperator;
+import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
+import 
org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.Test;
+
+import javax.annotation.Nullable;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests the state isolation and cleanup for the per-round operators. */
+public class PerRoundOperatorStateTest extends TestLogger {
+
+    @Test
+    public void testStateIsolationWithoutKeyedStateBackend() throws Exception {
+        testStateIsolation(null);
+    }
+
+    @Test
+    public void testStateIsolationWithHashMapKeyedStateBackend() throws 
Exception {
+        testStateIsolation(new HashMapStateBackend());
+    }
+
+    @Test
+    public void testStateIsolationWithRocksDBKeyedStateBackend() throws 
Exception {
+        testStateIsolation(new EmbeddedRocksDBStateBackend());
+    }
+
+    private void testStateIsolation(@Nullable StateBackend stateBackend) 
throws Exception {
+        StreamOperatorFactory<IterationRecord<Integer>> wrapperFactory =
+                new WrapperOperatorFactory<>(
+                        SimpleOperatorFactory.of(
+                                stateBackend == null
+                                        ? new ProcessOperator<>(new 
StatefulProcessFunction())
+                                        : new KeyedProcessOperator<>(
+                                                new 
KeyedStatefulProcessFunction())),
+                        new PerRoundOperatorWrapper<>());
+        OperatorID operatorId = new OperatorID();
+
+        try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                OneInputStreamTask::new,
+                                new 
IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+                        .modifyStreamConfig(
+                                streamConfig -> {
+                                    if (stateBackend != null) {
+                                        
streamConfig.setStateBackend(stateBackend);
+                                        
streamConfig.setManagedMemoryFractionOperatorOfUseCase(
+                                                
ManagedMemoryUseCase.STATE_BACKEND, 0.2);
+                                        
streamConfig.setStateKeySerializer(IntSerializer.INSTANCE);
+                                    }
+                                })
+                        .addInput(
+                                new IterationRecordTypeInfo<>(new 
EnumTypeInfo<>(ActionType.class)),
+                                1,
+                                stateBackend == null ? null : new 
ProxyKeySelector<>(x -> 10))
+                        .setupOutputForSingletonOperatorChain(wrapperFactory, 
operatorId)
+                        .build()) {
+
+            // Set round 0
+            harness.processElement(
+                    new 
StreamRecord<>(IterationRecord.newRecord(ActionType.SET, 0)), 0);
+            testGetRound(harness, Arrays.asList(10, 10, stateBackend == null ? 
-1 : 10), 0);
+            testGetRound(harness, Arrays.asList(-1, -1, -1), 1);
+
+            // Set round 1
+            harness.processElement(
+                    new 
StreamRecord<>(IterationRecord.newRecord(ActionType.SET, 1)), 0);
+            testGetRound(harness, Arrays.asList(10, 10, stateBackend == null ? 
-1 : 10), 1);
+
+            // Clear round 0. Although after round 0 we should not receive 
records for round 0 in
+            // practice, we use this method to check the current value of 
states.
+            harness.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(0, 
"sender")), 0);
+            testGetRound(harness, Arrays.asList(-1, -1, -1), 0);
+            testGetRound(harness, Arrays.asList(10, 10, stateBackend == null ? 
-1 : 10), 1);
+
+            // Clear round 1
+            harness.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(1, 
"sender")), 0);
+            testGetRound(harness, Arrays.asList(-1, -1, -1), 0);
+            testGetRound(harness, Arrays.asList(-1, -1, -1), 1);
+        }
+    }
+
+    private void testGetRound(
+            StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness,
+            List<Integer> expectedValues,
+            int round)
+            throws Exception {
+        harness.getOutput().clear();
+        harness.processElement(
+                new StreamRecord<>(IterationRecord.newRecord(ActionType.GET, 
round)), 0);
+
+        assertEquals(
+                expectedValues.stream()
+                        .map(i -> IterationRecord.newRecord(i, round))
+                        .collect(Collectors.toList()),
+                harness.getOutput().stream()
+                        .map(r -> ((StreamRecord<?>) r).getValue())
+                        .collect(Collectors.toList()));
+    }
+
+    enum ActionType {
+        SET,
+        GET
+    }
+
+    private static class StatefulProcessFunction extends 
ProcessFunction<ActionType, Integer>
+            implements CheckpointedFunction {
+
+        private transient BroadcastState<Integer, Integer> broadcastState;
+
+        private transient ListState<Integer> operatorState;
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) 
throws Exception {
+            this.operatorState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("opState", 
IntSerializer.INSTANCE));
+            this.broadcastState =
+                    context.getOperatorStateStore()
+                            .getBroadcastState(
+                                    new MapStateDescriptor<>(
+                                            "broadState",
+                                            IntSerializer.INSTANCE,
+                                            IntSerializer.INSTANCE));
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext context) throws 
Exception {}
+
+        @Override
+        public void processElement(ActionType value, Context ctx, 
Collector<Integer> out)
+                throws Exception {
+            switch (value) {
+                case SET:
+                    operatorState.add(10);
+                    broadcastState.put(10, 10);
+                    break;
+                case GET:
+                    out.collect(
+                            operatorState.get().iterator().hasNext()
+                                    ? operatorState.get().iterator().next()
+                                    : -1);
+                    out.collect(mapNullToMinusOne(broadcastState.get(10)));
+                    // To keep the same amount of outputs with the keyed one.
+                    out.collect(mapNullToMinusOne(null));
+                    break;
+            }
+        }
+    }
+
+    private static class KeyedStatefulProcessFunction
+            extends KeyedProcessFunction<Integer, ActionType, Integer>
+            implements CheckpointedFunction {
+
+        private transient BroadcastState<Integer, Integer> broadcastState;
+
+        private transient ListState<Integer> operatorState;
+
+        private transient ValueState<Integer> keyedState;
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) 
throws Exception {
+            this.operatorState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("opState", 
IntSerializer.INSTANCE));
+            this.broadcastState =
+                    context.getOperatorStateStore()
+                            .getBroadcastState(
+                                    new MapStateDescriptor<>(
+                                            "broadState",
+                                            IntSerializer.INSTANCE,
+                                            IntSerializer.INSTANCE));
+            this.keyedState =
+                    context.getKeyedStateStore()
+                            .getState(
+                                    new ValueStateDescriptor<>(
+                                            "keyedState", 
IntSerializer.INSTANCE));
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext context) throws 
Exception {}
+
+        @Override
+        public void processElement(ActionType value, Context ctx, 
Collector<Integer> out)
+                throws Exception {
+            switch (value) {
+                case SET:
+                    operatorState.add(10);
+                    broadcastState.put(10, 10);
+                    keyedState.update(10);
+                    break;
+                case GET:
+                    out.collect(
+                            operatorState.get().iterator().hasNext()
+                                    ? operatorState.get().iterator().next()
+                                    : -1);
+                    out.collect(mapNullToMinusOne(broadcastState.get(10)));
+                    out.collect(mapNullToMinusOne(keyedState == null ? null : 
keyedState.value()));
+                    break;
+            }
+        }
+    }
+
+    private static int mapNullToMinusOne(Integer value) {
+        return value == null ? -1 : value;
+    }
+}
diff --git 
a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
 
b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java
similarity index 71%
copy from 
flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
copy to 
flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java
index e7e2604..f2134b8 100644
--- 
a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
+++ 
b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java
@@ -16,12 +16,13 @@
  * limitations under the License.
  */
 
-package org.apache.flink.iteration.operator.allround;
+package org.apache.flink.iteration.operator.perround;
 
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.iteration.operator.WrapperOperatorFactory;
+import org.apache.flink.iteration.operator.allround.LifeCycle;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
@@ -34,6 +35,7 @@ import 
org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
@@ -53,17 +55,17 @@ import java.util.List;
 
 import static org.junit.Assert.assertEquals;
 
-/** Tests the {@link OneInputAllRoundWrapperOperator}. */
-public class TwoInputAllRoundWrapperOperatorTest extends TestLogger {
+/** Tests the {@link TwoInputPerRoundWrapperOperator}. */
+public class TwoInputPerRoundWrapperOperatorTest extends TestLogger {
 
-    private static List<LifeCycle> lifeCycles = new ArrayList<>();
+    private static final List<LifeCycle> LIFE_CYCLES = new ArrayList<>();
 
     @Test
     public void testProcessElementsAndEpochWatermarks() throws Exception {
         StreamOperatorFactory<IterationRecord<Integer>> wrapperFactory =
                 new WrapperOperatorFactory<>(
                         SimpleOperatorFactory.of(new 
LifeCycleTrackingTwoInputStreamOperator()),
-                        new AllRoundOperatorWrapper<>());
+                        new PerRoundOperatorWrapper<>());
         OperatorID operatorId = new OperatorID();
 
         try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness =
@@ -76,22 +78,15 @@ public class TwoInputAllRoundWrapperOperatorTest extends 
TestLogger {
                         .build()) {
             harness.processElement(new 
StreamRecord<>(IterationRecord.newRecord(5, 1), 2), 0);
             harness.processElement(new 
StreamRecord<>(IterationRecord.newRecord(6, 2), 3), 1);
-            harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(5, 
"only-one-0")), 0);
-            harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(5, 
"only-one-1")), 1);
 
-            // Check the output
+            // Checks the output
             assertEquals(
                     Arrays.asList(
                             new StreamRecord<>(IterationRecord.newRecord(5, 
1), 2),
-                            new StreamRecord<>(IterationRecord.newRecord(6, 
2), 3),
-                            new StreamRecord<>(
-                                    IterationRecord.newEpochWatermark(
-                                            5, 
OperatorUtils.getUniqueSenderId(operatorId, 0)))),
+                            new StreamRecord<>(IterationRecord.newRecord(6, 
2), 3)),
                     new ArrayList<>(harness.getOutput()));
 
-            // Check the other lifecycles.
+            // Checks the other lifecycles.
             harness.getStreamTask()
                     .triggerCheckpointOnBarrier(
                             new CheckpointMetaData(5, 2),
@@ -109,6 +104,27 @@ public class TwoInputAllRoundWrapperOperatorTest extends 
TestLogger {
             harness.getStreamTask().notifyCheckpointAbortAsync(6, 5);
             harness.processAll();
 
+            harness.getOutput().clear();
+            harness.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(1, 
"only-one")), 0);
+            harness.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(1, 
"only-one")), 1);
+            harness.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(2, 
"only-one")), 0);
+            harness.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(2, 
"only-one")), 1);
+
+            // Checks the output
+            assertEquals(
+                    Arrays.asList(
+                            new StreamRecord<>(
+                                    IterationRecord.newEpochWatermark(
+                                            1, 
OperatorUtils.getUniqueSenderId(operatorId, 0))),
+                            new StreamRecord<>(
+                                    IterationRecord.newEpochWatermark(
+                                            2, 
OperatorUtils.getUniqueSenderId(operatorId, 0)))),
+                    new ArrayList<>(harness.getOutput()));
+
             harness.processEvent(EndOfData.INSTANCE, 0);
             harness.processEvent(EndOfData.INSTANCE, 1);
             harness.endInput();
@@ -116,24 +132,44 @@ public class TwoInputAllRoundWrapperOperatorTest extends 
TestLogger {
 
             assertEquals(
                     Arrays.asList(
+                            /* First wrapped operator */
                             LifeCycle.SETUP,
                             LifeCycle.INITIALIZE_STATE,
                             LifeCycle.OPEN,
                             LifeCycle.PROCESS_ELEMENT_1,
+                            /* Second wrapped operator */
+                            LifeCycle.SETUP,
+                            LifeCycle.INITIALIZE_STATE,
+                            LifeCycle.OPEN,
                             LifeCycle.PROCESS_ELEMENT_2,
+                            /* states */
+                            LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER,
                             LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER,
                             LifeCycle.SNAPSHOT_STATE,
+                            LifeCycle.SNAPSHOT_STATE,
+                            LifeCycle.NOTIFY_CHECKPOINT_COMPLETE,
                             LifeCycle.NOTIFY_CHECKPOINT_COMPLETE,
                             LifeCycle.NOTIFY_CHECKPOINT_ABORT,
+                            LifeCycle.NOTIFY_CHECKPOINT_ABORT,
+                            // The first input
+                            LifeCycle.END_INPUT,
+                            // The second input
+                            LifeCycle.END_INPUT,
+                            LifeCycle.FINISH,
+                            LifeCycle.CLOSE,
+                            // The first input
+                            LifeCycle.END_INPUT,
+                            // The second input
+                            LifeCycle.END_INPUT,
                             LifeCycle.FINISH,
                             LifeCycle.CLOSE),
-                    lifeCycles);
+                    LIFE_CYCLES);
         }
     }
 
     private static class LifeCycleTrackingTwoInputStreamOperator
             extends AbstractStreamOperator<Integer>
-            implements TwoInputStreamOperator<Integer, Integer, Integer> {
+            implements TwoInputStreamOperator<Integer, Integer, Integer>, 
BoundedMultiInput {
 
         @Override
         public void setup(
@@ -141,67 +177,72 @@ public class TwoInputAllRoundWrapperOperatorTest extends 
TestLogger {
                 StreamConfig config,
                 Output<StreamRecord<Integer>> output) {
             super.setup(containingTask, config, output);
-            lifeCycles.add(LifeCycle.SETUP);
+            LIFE_CYCLES.add(LifeCycle.SETUP);
         }
 
         @Override
         public void open() throws Exception {
             super.open();
-            lifeCycles.add(LifeCycle.OPEN);
+            LIFE_CYCLES.add(LifeCycle.OPEN);
         }
 
         @Override
         public void initializeState(StateInitializationContext context) throws 
Exception {
             super.initializeState(context);
-            lifeCycles.add(LifeCycle.INITIALIZE_STATE);
+            LIFE_CYCLES.add(LifeCycle.INITIALIZE_STATE);
         }
 
         @Override
         public void finish() throws Exception {
             super.finish();
-            lifeCycles.add(LifeCycle.FINISH);
+            LIFE_CYCLES.add(LifeCycle.FINISH);
         }
 
         @Override
         public void close() throws Exception {
             super.close();
-            lifeCycles.add(LifeCycle.CLOSE);
+            LIFE_CYCLES.add(LifeCycle.CLOSE);
         }
 
         @Override
         public void prepareSnapshotPreBarrier(long checkpointId) throws 
Exception {
             super.prepareSnapshotPreBarrier(checkpointId);
-            lifeCycles.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER);
+            LIFE_CYCLES.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER);
         }
 
         @Override
         public void snapshotState(StateSnapshotContext context) throws 
Exception {
             super.snapshotState(context);
-            lifeCycles.add(LifeCycle.SNAPSHOT_STATE);
+            LIFE_CYCLES.add(LifeCycle.SNAPSHOT_STATE);
         }
 
         @Override
         public void notifyCheckpointComplete(long checkpointId) throws 
Exception {
             super.notifyCheckpointComplete(checkpointId);
-            lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE);
+            LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE);
         }
 
         @Override
         public void notifyCheckpointAborted(long checkpointId) throws 
Exception {
             super.notifyCheckpointAborted(checkpointId);
-            lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT);
+            LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT);
         }
 
         @Override
         public void processElement1(StreamRecord<Integer> element) throws 
Exception {
             output.collect(element);
-            lifeCycles.add(LifeCycle.PROCESS_ELEMENT_1);
+            LIFE_CYCLES.add(LifeCycle.PROCESS_ELEMENT_1);
         }
 
         @Override
         public void processElement2(StreamRecord<Integer> element) throws 
Exception {
             output.collect(element);
-            lifeCycles.add(LifeCycle.PROCESS_ELEMENT_2);
+            LIFE_CYCLES.add(LifeCycle.PROCESS_ELEMENT_2);
+        }
+
+        @Override
+        public void endInput(int inputId) throws Exception {
+            LIFE_CYCLES.add(LifeCycle.END_INPUT);
         }
     }
 }

Reply via email to