This is an automated email from the ASF dual-hosted git repository. gaoyunhaii pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit d1d5d00cfd9d12d47e89a7d765772993bd6b0d4a Author: Yun Gao <[email protected]> AuthorDate: Wed Sep 29 23:04:37 2021 +0800 [FLINK-24652][iteration] Add per-round operator wrappers This closes #14. --- flink-ml-iteration/pom.xml | 8 + .../flink/iteration/operator/OperatorUtils.java | 23 +- .../perround/AbstractPerRoundWrapperOperator.java | 488 +++++++++++++++++++++ .../MultipleInputPerRoundWrapperOperator.java | 180 ++++++++ .../perround/OneInputPerRoundWrapperOperator.java | 89 ++++ .../operator/perround/PerRoundOperatorWrapper.java | 85 ++++ .../perround/TwoInputPerRoundWrapperOperator.java | 142 ++++++ .../apache/flink/iteration/proxy/ProxyOutput.java | 2 +- .../state/ProxyInternalTimeServiceManager.java | 62 +++ .../proxy/state/ProxyKeyedStateBackend.java | 236 ++++++++++ .../proxy/state/ProxyOperatorStateBackend.java | 128 ++++++ .../proxy/state/ProxyStateSnapshotContext.java | 55 +++ .../state/ProxyStreamOperatorStateContext.java | 86 ++++ .../iteration/proxy/state/StateNamePrefix.java} | 32 +- .../iteration/operator/allround/LifeCycle.java | 2 + .../MultipleInputAllRoundWrapperOperatorTest.java | 40 +- .../OneInputAllRoundWrapperOperatorTest.java | 35 +- .../TwoInputAllRoundWrapperOperatorTest.java | 42 +- .../MultipleInputPerRoundWrapperOperatorTest.java} | 120 +++-- .../OneInputPerRoundWrapperOperatorTest.java} | 83 ++-- .../perround/PerRoundOperatorStateTest.java | 266 +++++++++++ .../TwoInputPerRoundWrapperOperatorTest.java} | 97 ++-- 22 files changed, 2142 insertions(+), 159 deletions(-) diff --git a/flink-ml-iteration/pom.xml b/flink-ml-iteration/pom.xml index d2cdb36..74044a3 100644 --- a/flink-ml-iteration/pom.xml +++ b/flink-ml-iteration/pom.xml @@ -61,6 +61,14 @@ under the License. <scope>provided</scope> </dependency> + <!-- We have special treatment with the rocksdb state. 
--> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-statebackend-rocksdb_${scala.binary.version}</artifactId> + <version>${flink.version}</version> + <scope>provided</scope> + </dependency> + <!-- Required for feedback edge implementation --> <dependency> <groupId>org.apache.flink</groupId> diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java index 992d715..1b2769e 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java @@ -26,10 +26,11 @@ import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer; import org.apache.flink.statefun.flink.core.feedback.FeedbackKey; import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.function.ThrowingConsumer; import java.util.Arrays; import java.util.concurrent.Executor; -import java.util.function.Consumer; /** Utility class for operators. 
*/ public class OperatorUtils { @@ -58,14 +59,20 @@ public class OperatorUtils { } public static <T> void processOperatorOrUdfIfSatisfy( - StreamOperator<?> operator, Class<T> targetInterface, Consumer<T> action) { - if (targetInterface.isAssignableFrom(operator.getClass())) { - action.accept((T) operator); - } else if (operator instanceof AbstractUdfStreamOperator<?, ?>) { - Object udf = ((AbstractUdfStreamOperator<?, ?>) operator).getUserFunction(); - if (targetInterface.isAssignableFrom(udf.getClass())) { - action.accept((T) udf); + StreamOperator<?> operator, + Class<T> targetInterface, + ThrowingConsumer<T, Exception> action) { + try { + if (targetInterface.isAssignableFrom(operator.getClass())) { + action.accept((T) operator); + } else if (operator instanceof AbstractUdfStreamOperator<?, ?>) { + Object udf = ((AbstractUdfStreamOperator<?, ?>) operator).getUserFunction(); + if (targetInterface.isAssignableFrom(udf.getClass())) { + action.accept((T) udf); + } } + } catch (Exception e) { + ExceptionUtils.rethrow(e); } } } diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java new file mode 100644 index 0000000..3903340 --- /dev/null +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java @@ -0,0 +1,488 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.iteration.operator.perround; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.MetricOptions; +import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend; +import org.apache.flink.core.memory.ManagedMemoryUseCase; +import org.apache.flink.iteration.IterationRecord; +import org.apache.flink.iteration.operator.AbstractWrapperOperator; +import org.apache.flink.iteration.proxy.state.ProxyStateSnapshotContext; +import org.apache.flink.iteration.proxy.state.ProxyStreamOperatorStateContext; +import org.apache.flink.iteration.utils.ReflectionUtils; +import org.apache.flink.metrics.MetricGroup; +import org.apache.flink.metrics.groups.OperatorMetricGroup; +import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.CheckpointStreamFactory; +import org.apache.flink.runtime.state.DefaultOperatorStateBackend; +import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import 
org.apache.flink.runtime.state.heap.HeapKeyedStateBackend; +import org.apache.flink.streaming.api.operators.InternalTimeServiceManager; +import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures; +import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil; +import org.apache.flink.streaming.api.operators.StreamOperatorParameters; +import org.apache.flink.streaming.api.operators.StreamOperatorStateContext; +import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler; +import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer; +import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.streaming.util.LatencyStats; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.InstantiationUtil; +import org.apache.flink.util.function.BiConsumerWithException; + +import org.rocksdb.RocksDB; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; + +/** The base class for all the per-round wrapper operators. 
*/ +public abstract class AbstractPerRoundWrapperOperator<T, S extends StreamOperator<T>> + extends AbstractWrapperOperator<T> + implements StreamOperatorStateHandler.CheckpointedStreamOperator { + + private static final Logger LOG = + LoggerFactory.getLogger(AbstractPerRoundWrapperOperator.class); + + private static final String HEAP_KEYED_STATE_NAME = + "org.apache.flink.runtime.state.heap.HeapKeyedStateBackend"; + + private static final String ROCKSDB_KEYED_STATE_NAME = + "org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend"; + + /** The wrapped operators for each round. */ + private final Map<Integer, S> wrappedOperators; + + protected final LatencyStats latencyStats; + + private transient StreamOperatorStateContext streamOperatorStateContext; + + private transient StreamOperatorStateHandler stateHandler; + + private transient InternalTimeServiceManager<?> timeServiceManager; + + private transient KeySelector<?, ?> stateKeySelector1; + + private transient KeySelector<?, ?> stateKeySelector2; + + public AbstractPerRoundWrapperOperator( + StreamOperatorParameters<IterationRecord<T>> parameters, + StreamOperatorFactory<T> operatorFactory) { + super(parameters, operatorFactory); + + this.wrappedOperators = new HashMap<>(); + this.latencyStats = initializeLatencyStats(); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + protected S getWrappedOperator(int round) { + S wrappedOperator = wrappedOperators.get(round); + if (wrappedOperator != null) { + return wrappedOperator; + } + + // We need to clone the operator factory to also support SimpleOperatorFactory. 
+        try {
+            StreamOperatorFactory<T> clonedOperatorFactory =
+                    InstantiationUtil.clone(operatorFactory);
+            wrappedOperator =
+                    (S)
+                            StreamOperatorFactoryUtil.<T, S>createOperator(
+                                            clonedOperatorFactory,
+                                            (StreamTask) parameters.getContainingTask(),
+                                            parameters.getStreamConfig(),
+                                            proxyOutput,
+                                            parameters.getOperatorEventDispatcher())
+                                    .f0;
+            initializeStreamOperator(wrappedOperator, round);
+            wrappedOperators.put(round, wrappedOperator);
+            return wrappedOperator;
+        } catch (Exception e) {
+            ExceptionUtils.rethrow(e);
+        }
+
+        return wrappedOperator;
+    }
+
+    protected abstract void endInputAndEmitMaxWatermark(S operator, int round) throws Exception;
+
+    private void closeStreamOperator(S operator, int round) throws Exception {
+        setIterationContextRound(round);
+        endInputAndEmitMaxWatermark(operator, round);
+        operator.finish();
+        operator.close();
+        setIterationContextRound(null);
+
+        // Cleanup the states used by this operator.
+        cleanupOperatorStates(round);
+
+        if (stateHandler.getKeyedStateBackend() != null) {
+            cleanupKeyedStates(round);
+        }
+    }
+
+    @Override
+    public void onEpochWatermarkIncrement(int epochWatermark) throws IOException {
+        try {
+            // Destroys all the operators with round < epoch watermark. Note that
+            // the onEpochWatermarkIncrement must be from 0 and increment by 1 each time.
+            if (wrappedOperators.containsKey(epochWatermark)) {
+                closeStreamOperator(wrappedOperators.get(epochWatermark), epochWatermark);
+                wrappedOperators.remove(epochWatermark);
+            }
+
+            super.onEpochWatermarkIncrement(epochWatermark);
+        } catch (Exception e) {
+            ExceptionUtils.rethrow(e);
+        }
+    }
+
+    protected void processForEachWrappedOperator(
+            BiConsumerWithException<Integer, S, Exception> consumer) throws Exception {
+        for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) {
+            consumer.accept(entry.getKey(), entry.getValue());
+        }
+    }
+
+    @Override
+    public void open() throws Exception {}
+
+    @Override
+    public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
+            throws Exception {
+        final TypeSerializer<?> keySerializer =
+                streamConfig.getStateKeySerializer(containingTask.getUserCodeClassLoader());
+
+        streamOperatorStateContext =
+                streamTaskStateManager.streamOperatorStateContext(
+                        getOperatorID(),
+                        getClass().getSimpleName(),
+                        parameters.getProcessingTimeService(),
+                        this,
+                        keySerializer,
+                        containingTask.getCancelables(),
+                        metrics,
+                        streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot(
+                                ManagedMemoryUseCase.STATE_BACKEND,
+                                containingTask
+                                        .getEnvironment()
+                                        .getTaskManagerInfo()
+                                        .getConfiguration(),
+                                containingTask.getUserCodeClassLoader()),
+                        isUsingCustomRawKeyedState());
+
+        stateHandler =
+                new StreamOperatorStateHandler(
+                        streamOperatorStateContext,
+                        containingTask.getExecutionConfig(),
+                        containingTask.getCancelables());
+        stateHandler.initializeOperatorState(this);
+        this.timeServiceManager = streamOperatorStateContext.internalTimerServiceManager();
+
+        stateKeySelector1 =
+                streamConfig.getStatePartitioner(0, containingTask.getUserCodeClassLoader());
+        stateKeySelector2 =
+                streamConfig.getStatePartitioner(1, containingTask.getUserCodeClassLoader());
+    }
+
+    @Override
+    public void initializeState(StateInitializationContext context) throws Exception {
+        // Do nothing for now since we do not have
states. + } + + @Internal + protected boolean isUsingCustomRawKeyedState() { + return false; + } + + @Override + public void finish() throws Exception { + for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) { + closeStreamOperator(entry.getValue(), entry.getKey()); + } + wrappedOperators.clear(); + } + + @Override + public void close() throws Exception { + if (stateHandler != null) { + stateHandler.dispose(); + } + } + + @Override + public void prepareSnapshotPreBarrier(long checkpointId) throws Exception { + for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) { + entry.getValue().prepareSnapshotPreBarrier(checkpointId); + } + } + + @Override + public OperatorSnapshotFutures snapshotState( + long checkpointId, + long timestamp, + CheckpointOptions checkpointOptions, + CheckpointStreamFactory factory) + throws Exception { + return stateHandler.snapshotState( + this, + Optional.ofNullable(timeServiceManager), + streamConfig.getOperatorName(), + checkpointId, + timestamp, + checkpointOptions, + factory, + isUsingCustomRawKeyedState()); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) { + if (StreamOperatorStateHandler.CheckpointedStreamOperator.class.isAssignableFrom( + entry.getValue().getClass())) { + ((StreamOperatorStateHandler.CheckpointedStreamOperator) entry.getValue()) + .snapshotState(new ProxyStateSnapshotContext(context)); + } + } + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public void setKeyContextElement1(StreamRecord record) throws Exception { + setKeyContextElement(record, stateKeySelector1); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public void setKeyContextElement2(StreamRecord record) throws Exception { + setKeyContextElement(record, stateKeySelector2); + } + + private <T> void setKeyContextElement(StreamRecord<T> record, KeySelector<T, ?> selector) + throws Exception { 
+ if (selector != null) { + Object key = selector.getKey(record.getValue()); + setCurrentKey(key); + } + } + + @Override + public OperatorMetricGroup getMetricGroup() { + return metrics; + } + + @Override + public OperatorID getOperatorID() { + return streamConfig.getOperatorID(); + } + + @Override + public void notifyCheckpointComplete(long l) throws Exception { + for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) { + entry.getValue().notifyCheckpointComplete(l); + } + } + + @Override + public void notifyCheckpointAborted(long checkpointId) throws Exception { + for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) { + entry.getValue().notifyCheckpointAborted(checkpointId); + } + } + + @Override + public void setCurrentKey(Object key) { + stateHandler.setCurrentKey(key); + } + + @Override + public Object getCurrentKey() { + if (stateHandler == null) { + return null; + } + + return stateHandler.getKeyedStateStore().orElse(null); + } + + protected void reportOrForwardLatencyMarker(LatencyMarker marker) { + // all operators are tracking latencies + this.latencyStats.reportLatency(marker); + + // everything except sinks forwards latency markers + this.output.emitLatencyMarker(marker); + } + + private LatencyStats initializeLatencyStats() { + try { + Configuration taskManagerConfig = + containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(); + int historySize = taskManagerConfig.getInteger(MetricOptions.LATENCY_HISTORY_SIZE); + if (historySize <= 0) { + LOG.warn( + "{} has been set to a value equal or below 0: {}. 
Using default.", + MetricOptions.LATENCY_HISTORY_SIZE, + historySize); + historySize = MetricOptions.LATENCY_HISTORY_SIZE.defaultValue(); + } + + final String configuredGranularity = + taskManagerConfig.getString(MetricOptions.LATENCY_SOURCE_GRANULARITY); + LatencyStats.Granularity granularity; + try { + granularity = + LatencyStats.Granularity.valueOf( + configuredGranularity.toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException iae) { + granularity = LatencyStats.Granularity.OPERATOR; + LOG.warn( + "Configured value {} option for {} is invalid. Defaulting to {}.", + configuredGranularity, + MetricOptions.LATENCY_SOURCE_GRANULARITY.key(), + granularity); + } + MetricGroup jobMetricGroup = this.metrics.getJobMetricGroup(); + return new LatencyStats( + jobMetricGroup.addGroup("latency"), + historySize, + containingTask.getIndexInSubtaskGroup(), + getOperatorID(), + granularity); + } catch (Exception e) { + LOG.warn("An error occurred while instantiating latency metrics.", e); + return new LatencyStats( + UnregisteredMetricGroups.createUnregisteredTaskManagerJobMetricGroup() + .addGroup("latency"), + 1, + 0, + new OperatorID(), + LatencyStats.Granularity.SINGLE); + } + } + + private void initializeStreamOperator(S operator, int round) throws Exception { + operator.initializeState( + (operatorID, + operatorClassName, + processingTimeService, + keyContext, + keySerializer, + streamTaskCloseableRegistry, + metricGroup, + managedMemoryFraction, + isUsingCustomRawKeyedState) -> + new ProxyStreamOperatorStateContext( + streamOperatorStateContext, getRoundStatePrefix(round))); + operator.open(); + } + + private void cleanupOperatorStates(int round) { + String roundPrefix = getRoundStatePrefix(round); + OperatorStateBackend operatorStateBackend = stateHandler.getOperatorStateBackend(); + + if (operatorStateBackend instanceof DefaultOperatorStateBackend) { + for (String fieldNames : + new String[] { + "registeredOperatorStates", + "registeredBroadcastStates", + 
"accessedStatesByName", + "accessedBroadcastStatesByName" + }) { + Map<String, ?> field = + ReflectionUtils.getFieldValue( + operatorStateBackend, + DefaultOperatorStateBackend.class, + fieldNames); + field.entrySet().removeIf(entry -> entry.getKey().startsWith(roundPrefix)); + } + } else { + LOG.warn("Unable to cleanup the operator state {}", operatorStateBackend); + } + } + + private void cleanupKeyedStates(int round) { + String roundPrefix = getRoundStatePrefix(round); + KeyedStateBackend<?> keyedStateBackend = stateHandler.getKeyedStateBackend(); + if (keyedStateBackend.getClass().getName().equals(HEAP_KEYED_STATE_NAME)) { + ReflectionUtils.<Map<String, ?>>getFieldValue( + keyedStateBackend, HeapKeyedStateBackend.class, "registeredKVStates") + .entrySet() + .removeIf(entry -> entry.getKey().startsWith(roundPrefix)); + ReflectionUtils.<Map<String, ?>>getFieldValue( + keyedStateBackend, + AbstractKeyedStateBackend.class, + "keyValueStatesByName") + .entrySet() + .removeIf(entry -> entry.getKey().startsWith(roundPrefix)); + } else if (keyedStateBackend.getClass().getName().equals(ROCKSDB_KEYED_STATE_NAME)) { + RocksDB db = + ReflectionUtils.getFieldValue( + keyedStateBackend, RocksDBKeyedStateBackend.class, "db"); + HashMap<String, RocksDBKeyedStateBackend.RocksDbKvStateInfo> kvStateInformation = + ReflectionUtils.getFieldValue( + keyedStateBackend, + RocksDBKeyedStateBackend.class, + "kvStateInformation"); + kvStateInformation.entrySet().stream() + .filter(entry -> entry.getKey().startsWith(roundPrefix)) + .forEach( + entry -> { + try { + db.dropColumnFamily(entry.getValue().columnFamilyHandle); + } catch (Exception e) { + LOG.error( + "Failed to drop state {} for round {}", + entry.getKey(), + round); + } + }); + kvStateInformation.entrySet().removeIf(entry -> entry.getKey().startsWith(roundPrefix)); + + Map<String, ?> field = + ReflectionUtils.getFieldValue( + keyedStateBackend, + AbstractKeyedStateBackend.class, + "keyValueStatesByName"); + 
field.entrySet().removeIf(entry -> entry.getKey().startsWith(roundPrefix)); + } else { + LOG.warn("Unable to cleanup the keyed state {}", keyedStateBackend); + } + } + + private String getRoundStatePrefix(int round) { + return "r" + round + "-"; + } +} diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java new file mode 100644 index 0000000..7c6240f --- /dev/null +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.flink.iteration.operator.perround;
+
+import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.streaming.api.graph.StreamEdge;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.Input;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+import org.apache.flink.util.FlinkRuntimeException;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/** Per-round wrapper for the multiple-inputs operator. */
+public class MultipleInputPerRoundWrapperOperator<OUT>
+        extends AbstractPerRoundWrapperOperator<OUT, MultipleInputStreamOperator<OUT>>
+        implements MultipleInputStreamOperator<IterationRecord<OUT>> {
+
+    /** The number of total inputs. */
+    private final int numberOfInputs;
+
+    /**
+     * Cached inputs for each epoch. This is to avoid repeated calls to the {@link
+     * MultipleInputStreamOperator#getInputs()}, which might not return the same inputs for each
+     * call.
+     */
+    private final Map<Integer, List<Input>> operatorInputsByEpoch = new HashMap<>();
+
+    public MultipleInputPerRoundWrapperOperator(
+            StreamOperatorParameters<IterationRecord<OUT>> parameters,
+            StreamOperatorFactory<OUT> operatorFactory) {
+        super(parameters, operatorFactory);
+
+        // Determine how many inputs we have
+        List<StreamEdge> inEdges =
+                streamConfig.getInPhysicalEdges(containingTask.getUserCodeClassLoader());
+        this.numberOfInputs =
+                inEdges.stream().map(StreamEdge::getTypeNumber).collect(Collectors.toSet()).size();
+    }
+
+    @Override
+    protected void endInputAndEmitMaxWatermark(MultipleInputStreamOperator<OUT> operator, int round)
+            throws Exception {
+        OperatorUtils.processOperatorOrUdfIfSatisfy(
+                operator,
+                BoundedMultiInput.class,
+                boundedMultiInput -> {
+                    for (int i = 0; i < numberOfInputs; ++i) {
+                        boundedMultiInput.endInput(i + 1);
+                    }
+                });
+
+        for (int i = 0; i < numberOfInputs; ++i) {
+            operatorInputsByEpoch.get(round).get(i).processWatermark(new Watermark(Long.MAX_VALUE));
+        }
+    }
+
+    private <IN> void processElement(
+            int inputIndex,
+            Input<IN> input,
+            StreamRecord<IN> reusedInput,
+            StreamRecord<IterationRecord<IN>> element)
+            throws Exception {
+        switch (element.getValue().getType()) {
+            case RECORD:
+                reusedInput.replace(element.getValue().getValue(), element.getTimestamp());
+                setIterationContextRound(element.getValue().getEpoch());
+                input.processElement(reusedInput);
+                clearIterationContextRound();
+                break;
+            case EPOCH_WATERMARK:
+                onEpochWatermarkEvent(inputIndex, element.getValue());
+                break;
+            default:
+                throw new FlinkRuntimeException("Not supported iteration record type: " + element);
+        }
+    }
+
+    @Override
+    @SuppressWarnings({"rawtypes"})
+    public List<Input> getInputs() {
+        List<Input> proxyInputs = new ArrayList<>();
+
+        for (int i = 0; i < numberOfInputs; ++i) {
+            // TODO: Note that here we rely on the assumption that the
+            // stream graph generator labels the input from 1 to n for
+            // the input array, and we
map them from 0 to n - 1. + proxyInputs.add(new ProxyInput(i)); + } + return proxyInputs; + } + + private class ProxyInput<IN> implements Input<IterationRecord<IN>> { + + private final int inputIndex; + + private final StreamRecord<IN> reusedInput; + + public ProxyInput(int inputIndex) { + this.inputIndex = inputIndex; + this.reusedInput = new StreamRecord<>(null, 0); + } + + @Override + public void processElement(StreamRecord<IterationRecord<IN>> element) throws Exception { + if (!operatorInputsByEpoch.containsKey(element.getValue().getEpoch())) { + MultipleInputStreamOperator<OUT> operator = + getWrappedOperator(element.getValue().getEpoch()); + operatorInputsByEpoch.put(element.getValue().getEpoch(), operator.getInputs()); + } + + MultipleInputPerRoundWrapperOperator.this.processElement( + inputIndex, + operatorInputsByEpoch.get(element.getValue().getEpoch()).get(inputIndex), + reusedInput, + element); + } + + @Override + public void processWatermark(Watermark mark) throws Exception { + processForEachWrappedOperator( + (round, wrappedOperator) -> { + operatorInputsByEpoch.get(round).get(inputIndex).processWatermark(mark); + }); + } + + @Override + public void processWatermarkStatus(WatermarkStatus watermarkStatus) throws Exception { + processForEachWrappedOperator( + (round, wrappedOperator) -> { + operatorInputsByEpoch + .get(round) + .get(inputIndex) + .processWatermarkStatus(watermarkStatus); + }); + } + + @Override + public void processLatencyMarker(LatencyMarker latencyMarker) throws Exception { + reportOrForwardLatencyMarker(latencyMarker); + } + + @Override + public void setKeyContextElement(StreamRecord<IterationRecord<IN>> record) + throws Exception { + MultipleInputStreamOperator<OUT> operator = + getWrappedOperator(record.getValue().getEpoch()); + + reusedInput.replace(record.getValue(), record.getTimestamp()); + operator.getInputs().get(inputIndex).setKeyContextElement(reusedInput); + } + } +} diff --git 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperator.java new file mode 100644 index 0000000..ae85618 --- /dev/null +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperator.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.iteration.operator.perround; + +import org.apache.flink.iteration.IterationRecord; +import org.apache.flink.iteration.operator.OperatorUtils; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperatorParameters; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus; +import org.apache.flink.util.FlinkRuntimeException; + +/** Per-round wrapper operator for the one-input operator. */ +public class OneInputPerRoundWrapperOperator<IN, OUT> + extends AbstractPerRoundWrapperOperator<OUT, OneInputStreamOperator<IN, OUT>> + implements OneInputStreamOperator<IterationRecord<IN>, IterationRecord<OUT>> { + + private final StreamRecord<IN> reusedInput; + + public OneInputPerRoundWrapperOperator( + StreamOperatorParameters<IterationRecord<OUT>> parameters, + StreamOperatorFactory<OUT> operatorFactory) { + super(parameters, operatorFactory); + this.reusedInput = new StreamRecord<>(null, 0); + } + + @Override + protected void endInputAndEmitMaxWatermark(OneInputStreamOperator<IN, OUT> operator, int round) + throws Exception { + OperatorUtils.processOperatorOrUdfIfSatisfy( + operator, BoundedOneInput.class, BoundedOneInput::endInput); + operator.processWatermark(new Watermark(Long.MAX_VALUE)); + } + + @Override + public void processElement(StreamRecord<IterationRecord<IN>> element) throws Exception { + switch (element.getValue().getType()) { + case RECORD: + reusedInput.replace(element.getValue().getValue(), element.getTimestamp()); + setIterationContextRound(element.getValue().getEpoch()); + 
getWrappedOperator(element.getValue().getEpoch()).processElement(reusedInput); + clearIterationContextRound(); + break; + case EPOCH_WATERMARK: + onEpochWatermarkEvent(0, element.getValue()); + break; + default: + throw new FlinkRuntimeException("Not supported iteration record type: " + element); + } + } + + @Override + public void processWatermark(Watermark mark) throws Exception { + processForEachWrappedOperator( + (round, wrappedOperator) -> wrappedOperator.processWatermark(mark)); + } + + @Override + public void processWatermarkStatus(WatermarkStatus watermarkStatus) throws Exception { + processForEachWrappedOperator( + (round, wrappedOperator) -> + wrappedOperator.processWatermarkStatus(watermarkStatus)); + } + + @Override + public void processLatencyMarker(LatencyMarker latencyMarker) throws Exception { + reportOrForwardLatencyMarker(latencyMarker); + } +} diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java new file mode 100644 index 0000000..ffa2221 --- /dev/null +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.iteration.operator.perround; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.iteration.IterationRecord; +import org.apache.flink.iteration.operator.OperatorWrapper; +import org.apache.flink.iteration.proxy.ProxyKeySelector; +import org.apache.flink.iteration.proxy.ProxyStreamPartitioner; +import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo; +import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperatorParameters; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; +import org.apache.flink.util.OutputTag; + +/** The operator wrapper implementation for per-round wrappers. */ +public class PerRoundOperatorWrapper<T> implements OperatorWrapper<T, IterationRecord<T>> { + + @Override + public StreamOperator<IterationRecord<T>> wrap( + StreamOperatorParameters<IterationRecord<T>> operatorParameters, + StreamOperatorFactory<T> operatorFactory) { + Class<? 
extends StreamOperator> operatorClass = + operatorFactory.getStreamOperatorClass(getClass().getClassLoader()); + if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) { + return new OneInputPerRoundWrapperOperator<>(operatorParameters, operatorFactory); + } else if (TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) { + return new TwoInputPerRoundWrapperOperator<>(operatorParameters, operatorFactory); + } else if (MultipleInputStreamOperator.class.isAssignableFrom(operatorClass)) { + return new MultipleInputPerRoundWrapperOperator<>(operatorParameters, operatorFactory); + } else { + throw new UnsupportedOperationException( + "Unsupported operator class for per-round wrapper: " + operatorClass); + } + } + + @Override + public <KEY> KeySelector<IterationRecord<T>, KEY> wrapKeySelector( + KeySelector<T, KEY> keySelector) { + return new ProxyKeySelector<>(keySelector); + } + + @Override + public StreamPartitioner<IterationRecord<T>> wrapStreamPartitioner( + StreamPartitioner<T> streamPartitioner) { + if (streamPartitioner instanceof BroadcastPartitioner) { + return new BroadcastPartitioner<>(); + } + + return new ProxyStreamPartitioner<>(streamPartitioner); + } + + @Override + public OutputTag<IterationRecord<T>> wrapOutputTag(OutputTag<T> outputTag) { + return new OutputTag<>( + outputTag.getId(), new IterationRecordTypeInfo<>(outputTag.getTypeInfo())); + } + + @Override + public TypeInformation<IterationRecord<T>> getWrappedTypeInfo(TypeInformation<T> typeInfo) { + return new IterationRecordTypeInfo<>(typeInfo); + } +} diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperator.java new file mode 100644 index 0000000..39918ad --- /dev/null +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperator.java @@ -0,0 
+1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.iteration.operator.perround; + +import org.apache.flink.iteration.IterationRecord; +import org.apache.flink.iteration.operator.OperatorUtils; +import org.apache.flink.streaming.api.operators.BoundedMultiInput; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperatorParameters; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus; +import org.apache.flink.util.FlinkRuntimeException; +import org.apache.flink.util.function.ThrowingConsumer; + +/** Per-round wrapper for the two-inputs operator. 
*/ +public class TwoInputPerRoundWrapperOperator<IN1, IN2, OUT> + extends AbstractPerRoundWrapperOperator<OUT, TwoInputStreamOperator<IN1, IN2, OUT>> + implements TwoInputStreamOperator< + IterationRecord<IN1>, IterationRecord<IN2>, IterationRecord<OUT>> { + + private final StreamRecord<IN1> reusedInput1; + + private final StreamRecord<IN2> reusedInput2; + + public TwoInputPerRoundWrapperOperator( + StreamOperatorParameters<IterationRecord<OUT>> parameters, + StreamOperatorFactory<OUT> operatorFactory) { + super(parameters, operatorFactory); + + this.reusedInput1 = new StreamRecord<>(null, 0); + this.reusedInput2 = new StreamRecord<>(null, 0); + } + + @Override + protected void endInputAndEmitMaxWatermark( + TwoInputStreamOperator<IN1, IN2, OUT> operator, int round) throws Exception { + OperatorUtils.processOperatorOrUdfIfSatisfy( + operator, + BoundedMultiInput.class, + boundedMultiInput -> { + boundedMultiInput.endInput(1); + boundedMultiInput.endInput(2); + }); + operator.processWatermark1(new Watermark(Long.MAX_VALUE)); + operator.processWatermark2(new Watermark(Long.MAX_VALUE)); + } + + @Override + public void processElement1(StreamRecord<IterationRecord<IN1>> element) throws Exception { + processElement( + element, + 0, + reusedInput1, + getWrappedOperator(element.getValue().getEpoch())::processElement1); + } + + @Override + public void processElement2(StreamRecord<IterationRecord<IN2>> element) throws Exception { + processElement( + element, + 1, + reusedInput2, + getWrappedOperator(element.getValue().getEpoch())::processElement2); + } + + private <IN> void processElement( + StreamRecord<IterationRecord<IN>> element, + int inputIndex, + StreamRecord<IN> reusedInput, + ThrowingConsumer<StreamRecord<IN>, Exception> processor) + throws Exception { + + switch (element.getValue().getType()) { + case RECORD: + reusedInput.replace(element.getValue().getValue(), element.getTimestamp()); + setIterationContextRound(element.getValue().getEpoch()); + 
processor.accept(reusedInput); + clearIterationContextRound(); + break; + case EPOCH_WATERMARK: + onEpochWatermarkEvent(inputIndex, element.getValue()); + break; + default: + throw new FlinkRuntimeException("Not supported iteration record type: " + element); + } + } + + @Override + public void processWatermark1(Watermark mark) throws Exception { + processForEachWrappedOperator( + (round, wrappedOperator) -> wrappedOperator.processWatermark1(mark)); + } + + @Override + public void processWatermark2(Watermark mark) throws Exception { + processForEachWrappedOperator( + (round, wrappedOperator) -> wrappedOperator.processWatermark2(mark)); + } + + @Override + public void processLatencyMarker1(LatencyMarker latencyMarker) throws Exception { + reportOrForwardLatencyMarker(latencyMarker); + } + + @Override + public void processLatencyMarker2(LatencyMarker latencyMarker) throws Exception { + reportOrForwardLatencyMarker(latencyMarker); + } + + @Override + public void processWatermarkStatus1(WatermarkStatus watermarkStatus) throws Exception { + processForEachWrappedOperator( + (round, wrappedOperator) -> + wrappedOperator.processWatermarkStatus1(watermarkStatus)); + } + + @Override + public void processWatermarkStatus2(WatermarkStatus watermarkStatus) throws Exception { + processForEachWrappedOperator( + (round, wrappedOperator) -> + wrappedOperator.processWatermarkStatus2(watermarkStatus)); + } +} diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyOutput.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyOutput.java index 0d80f62..7f5cc8a 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyOutput.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyOutput.java @@ -53,7 +53,7 @@ public class ProxyOutput<T> implements Output<StreamRecord<T>> { @Override public void emitWatermark(Watermark mark) { - output.emitWatermark(mark); + // For now, we only supports 
the MAX_WATERMARK separately for each operator. } @Override diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyInternalTimeServiceManager.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyInternalTimeServiceManager.java new file mode 100644 index 0000000..7377328 --- /dev/null +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyInternalTimeServiceManager.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.iteration.proxy.state; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream; +import org.apache.flink.streaming.api.operators.InternalTimeServiceManager; +import org.apache.flink.streaming.api.operators.InternalTimerService; +import org.apache.flink.streaming.api.operators.Triggerable; +import org.apache.flink.streaming.api.watermark.Watermark; + +/** Proxy {@link InternalTimeServiceManager} for the wrapped operators. 
*/ +public class ProxyInternalTimeServiceManager<K> implements InternalTimeServiceManager<K> { + + private final InternalTimeServiceManager<K> wrappedManager; + + private final StateNamePrefix stateNamePrefix; + + public ProxyInternalTimeServiceManager( + InternalTimeServiceManager<K> wrappedManager, StateNamePrefix stateNamePrefix) { + this.wrappedManager = wrappedManager; + this.stateNamePrefix = stateNamePrefix; + } + + @Override + public <N> InternalTimerService<N> getInternalTimerService( + String name, + TypeSerializer<K> keySerializer, + TypeSerializer<N> namespaceSerializer, + Triggerable<K, N> triggerable) { + return wrappedManager.getInternalTimerService( + stateNamePrefix.prefix(name), keySerializer, namespaceSerializer, triggerable); + } + + @Override + public void advanceWatermark(Watermark watermark) throws Exception { + wrappedManager.advanceWatermark(watermark); + } + + @Override + public void snapshotToRawKeyedState( + KeyedStateCheckpointOutputStream stateCheckpointOutputStream, String operatorName) + throws Exception { + wrappedManager.snapshotToRawKeyedState(stateCheckpointOutputStream, operatorName); + } +} diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyKeyedStateBackend.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyKeyedStateBackend.java new file mode 100644 index 0000000..bc51f95 --- /dev/null +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyKeyedStateBackend.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.iteration.proxy.state; + +import org.apache.flink.api.common.state.AggregatingStateDescriptor; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ReducingStateDescriptor; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.state.CheckpointStreamFactory; +import org.apache.flink.runtime.state.CheckpointableKeyedStateBackend; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue; +import org.apache.flink.runtime.state.Keyed; +import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.KeyedStateFunction; +import org.apache.flink.runtime.state.KeyedStateHandle; +import org.apache.flink.runtime.state.PriorityComparable; +import org.apache.flink.runtime.state.SavepointResources; +import org.apache.flink.runtime.state.SnapshotResult; +import org.apache.flink.runtime.state.StateSnapshotTransformer; +import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement; + +import javax.annotation.Nonnull; + +import java.io.IOException; +import java.util.Map; +import 
java.util.concurrent.RunnableFuture; +import java.util.stream.Stream; + +/** Proxy {@link KeyedStateBackend} for the wrapped operators. */ +public class ProxyKeyedStateBackend<K> implements CheckpointableKeyedStateBackend<K> { + + private final CheckpointableKeyedStateBackend<K> wrappedBackend; + + private final StateNamePrefix stateNamePrefix; + + public ProxyKeyedStateBackend( + CheckpointableKeyedStateBackend<K> wrappedBackend, StateNamePrefix stateNamePrefix) { + this.wrappedBackend = wrappedBackend; + this.stateNamePrefix = stateNamePrefix; + } + + @Override + public void setCurrentKey(K newKey) { + wrappedBackend.setCurrentKey(newKey); + } + + @Override + public K getCurrentKey() { + return wrappedBackend.getCurrentKey(); + } + + @Override + public TypeSerializer<K> getKeySerializer() { + return wrappedBackend.getKeySerializer(); + } + + @Override + public <N, S extends State, T> void applyToAllKeys( + N namespace, + TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, T> stateDescriptor, + KeyedStateFunction<K, S> function) + throws Exception { + StateDescriptor<S, T> newDescriptor = createNewDescriptor(stateDescriptor); + wrappedBackend.applyToAllKeys(namespace, namespaceSerializer, newDescriptor, function); + } + + @Override + public <N> Stream<K> getKeys(String state, N namespace) { + return wrappedBackend.getKeys(stateNamePrefix.prefix(state), namespace); + } + + @Override + public <N> Stream<Tuple2<K, N>> getKeysAndNamespaces(String state) { + return wrappedBackend.getKeysAndNamespaces(stateNamePrefix.prefix(state)); + } + + @Override + public <N, S extends State, T> S getOrCreateKeyedState( + TypeSerializer<N> namespaceSerializer, StateDescriptor<S, T> stateDescriptor) + throws Exception { + StateDescriptor<S, T> newDescriptor = createNewDescriptor(stateDescriptor); + return wrappedBackend.getOrCreateKeyedState(namespaceSerializer, newDescriptor); + } + + @Override + public <N, S extends State> S getPartitionedState( + N namespace, + 
TypeSerializer<N> namespaceSerializer, + StateDescriptor<S, ?> stateDescriptor) + throws Exception { + StateDescriptor<S, ?> newDescriptor = createNewDescriptor(stateDescriptor); + return wrappedBackend.getPartitionedState(namespace, namespaceSerializer, newDescriptor); + } + + @Override + public void registerKeySelectionListener(KeySelectionListener<K> listener) { + wrappedBackend.registerKeySelectionListener(listener); + } + + @Override + public boolean deregisterKeySelectionListener(KeySelectionListener<K> listener) { + return wrappedBackend.deregisterKeySelectionListener(listener); + } + + @Nonnull + @Override + public <N, SV, SEV, S extends State, IS extends S> IS createInternalState( + @Nonnull TypeSerializer<N> namespaceSerializer, + @Nonnull StateDescriptor<S, SV> stateDesc, + @Nonnull + StateSnapshotTransformer.StateSnapshotTransformFactory<SEV> + snapshotTransformFactory) + throws Exception { + StateDescriptor<S, ?> newDescriptor = createNewDescriptor(stateDesc); + return wrappedBackend.createInternalState( + namespaceSerializer, newDescriptor, snapshotTransformFactory); + } + + @SuppressWarnings("unchecked") + protected <S extends State, T> StateDescriptor<S, T> createNewDescriptor( + StateDescriptor<S, T> descriptor) { + switch (descriptor.getType()) { + case VALUE: + { + return (StateDescriptor<S, T>) + new ValueStateDescriptor<>( + stateNamePrefix.prefix(descriptor.getName()), + descriptor.getSerializer()); + } + case LIST: + { + ListStateDescriptor<T> listStateDescriptor = + (ListStateDescriptor<T>) descriptor; + return (StateDescriptor<S, T>) + new ListStateDescriptor<>( + stateNamePrefix.prefix(listStateDescriptor.getName()), + listStateDescriptor.getElementSerializer()); + } + case REDUCING: + { + ReducingStateDescriptor<T> reducingStateDescriptor = + (ReducingStateDescriptor<T>) descriptor; + return (StateDescriptor<S, T>) + new ReducingStateDescriptor<>( + stateNamePrefix.prefix(reducingStateDescriptor.getName()), + 
reducingStateDescriptor.getReduceFunction(), + reducingStateDescriptor.getSerializer()); + } + case AGGREGATING: + { + AggregatingStateDescriptor<?, ?, T> aggregatingStateDescriptor = + (AggregatingStateDescriptor<?, ?, T>) descriptor; + return new AggregatingStateDescriptor( + stateNamePrefix.prefix(aggregatingStateDescriptor.getName()), + aggregatingStateDescriptor.getAggregateFunction(), + aggregatingStateDescriptor.getSerializer()); + } + case MAP: + { + MapStateDescriptor<?, Map<?, ?>> mapStateDescriptor = + (MapStateDescriptor<?, Map<?, ?>>) descriptor; + return new MapStateDescriptor( + stateNamePrefix.prefix(mapStateDescriptor.getName()), + mapStateDescriptor.getKeySerializer(), + mapStateDescriptor.getValueSerializer()); + } + default: + throw new UnsupportedOperationException("Unsupported state type"); + } + } + + @Override + public KeyGroupRange getKeyGroupRange() { + return wrappedBackend.getKeyGroupRange(); + } + + @Nonnull + @Override + public SavepointResources<K> savepoint() throws Exception { + return wrappedBackend.savepoint(); + } + + @Override + public void dispose() { + // Do not dispose for proxy. + } + + @Override + public void close() throws IOException { + // Do not close for proxy. + } + + @Nonnull + @Override + public <T extends HeapPriorityQueueElement & PriorityComparable<? 
super T> & Keyed<?>> + KeyGroupedInternalPriorityQueue<T> create( + @Nonnull String stateName, + @Nonnull TypeSerializer<T> byteOrderedElementSerializer) { + return wrappedBackend.create( + stateNamePrefix.prefix(stateName), byteOrderedElementSerializer); + } + + @Nonnull + @Override + public RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot( + long checkpointId, + long timestamp, + @Nonnull CheckpointStreamFactory streamFactory, + @Nonnull CheckpointOptions checkpointOptions) + throws Exception { + return wrappedBackend.snapshot(checkpointId, timestamp, streamFactory, checkpointOptions); + } +} diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java new file mode 100644 index 0000000..a655886 --- /dev/null +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.iteration.proxy.state; + +import org.apache.flink.api.common.state.BroadcastState; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.state.CheckpointStreamFactory; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.SnapshotResult; + +import javax.annotation.Nonnull; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.RunnableFuture; + +/** Proxy {@link OperatorStateBackend} for the wrapped Operator. */ +public class ProxyOperatorStateBackend implements OperatorStateBackend { + + private final OperatorStateBackend wrappedBackend; + + private final StateNamePrefix stateNamePrefix; + + public ProxyOperatorStateBackend( + OperatorStateBackend wrappedBackend, StateNamePrefix stateNamePrefix) { + this.wrappedBackend = wrappedBackend; + this.stateNamePrefix = stateNamePrefix; + } + + @Override + public <K, V> BroadcastState<K, V> getBroadcastState(MapStateDescriptor<K, V> stateDescriptor) + throws Exception { + MapStateDescriptor<K, V> newDescriptor = + new MapStateDescriptor<>( + stateNamePrefix.prefix(stateDescriptor.getName()), + stateDescriptor.getKeySerializer(), + stateDescriptor.getValueSerializer()); + return wrappedBackend.getBroadcastState(newDescriptor); + } + + @Override + public <S> ListState<S> getListState(ListStateDescriptor<S> stateDescriptor) throws Exception { + ListStateDescriptor<S> newDescriptor = + new ListStateDescriptor<>( + stateNamePrefix.prefix(stateDescriptor.getName()), + stateDescriptor.getElementSerializer()); + return wrappedBackend.getListState(newDescriptor); + } + + @Override + public <S> ListState<S> 
getUnionListState(ListStateDescriptor<S> stateDescriptor) + throws Exception { + ListStateDescriptor<S> newDescriptor = + new ListStateDescriptor<S>( + stateNamePrefix.prefix(stateDescriptor.getName()), + stateDescriptor.getElementSerializer()); + return wrappedBackend.getUnionListState(newDescriptor); + } + + @Override + public Set<String> getRegisteredStateNames() { + Set<String> filteredNames = new HashSet<>(); + Set<String> names = wrappedBackend.getRegisteredStateNames(); + + for (String name : names) { + if (name.startsWith(stateNamePrefix.getNamePrefix())) { + filteredNames.add(name.substring(stateNamePrefix.getNamePrefix().length())); + } + } + + return filteredNames; + } + + @Override + public Set<String> getRegisteredBroadcastStateNames() { + Set<String> filteredNames = new HashSet<>(); + Set<String> names = wrappedBackend.getRegisteredBroadcastStateNames(); + + for (String name : names) { + if (name.startsWith(stateNamePrefix.getNamePrefix())) { + filteredNames.add(name.substring(stateNamePrefix.getNamePrefix().length())); + } + } + + return filteredNames; + } + + @Override + public void dispose() { + // Do not dispose for proxy. + } + + @Override + public void close() throws IOException { + // Do not close for proxy. 
+ } + + @Nonnull + @Override + public RunnableFuture<SnapshotResult<OperatorStateHandle>> snapshot( + long checkpointId, + long timestamp, + @Nonnull CheckpointStreamFactory streamFactory, + @Nonnull CheckpointOptions checkpointOptions) + throws Exception { + return wrappedBackend.snapshot(checkpointId, timestamp, streamFactory, checkpointOptions); + } +} diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyStateSnapshotContext.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyStateSnapshotContext.java new file mode 100644 index 0000000..35d164c --- /dev/null +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyStateSnapshotContext.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.iteration.proxy.state; + +import org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream; +import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream; +import org.apache.flink.runtime.state.StateSnapshotContext; + +/** Proxy {@link StateSnapshotContext} for the wrapped operators. 
*/ +public class ProxyStateSnapshotContext implements StateSnapshotContext { + + private final StateSnapshotContext wrappedContext; + + public ProxyStateSnapshotContext(StateSnapshotContext wrappedContext) { + this.wrappedContext = wrappedContext; + } + + @Override + public KeyedStateCheckpointOutputStream getRawKeyedOperatorStateOutput() throws Exception { + throw new UnsupportedOperationException( + "Currently we do not support the raw keyed state inside the iteration."); + } + + @Override + public OperatorStateCheckpointOutputStream getRawOperatorStateOutput() throws Exception { + throw new UnsupportedOperationException( + "Currently we do not support the raw operator state inside the iteration."); + } + + @Override + public long getCheckpointId() { + return wrappedContext.getCheckpointId(); + } + + @Override + public long getCheckpointTimestamp() { + return wrappedContext.getCheckpointTimestamp(); + } +} diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyStreamOperatorStateContext.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyStreamOperatorStateContext.java new file mode 100644 index 0000000..a892cc5 --- /dev/null +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyStreamOperatorStateContext.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.iteration.proxy.state; + +import org.apache.flink.runtime.state.CheckpointableKeyedStateBackend; +import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider; +import org.apache.flink.runtime.state.OperatorStateBackend; +import org.apache.flink.runtime.state.StatePartitionStreamProvider; +import org.apache.flink.streaming.api.operators.InternalTimeServiceManager; +import org.apache.flink.streaming.api.operators.StreamOperatorStateContext; +import org.apache.flink.util.CloseableIterable; + +import java.util.Objects; +import java.util.OptionalLong; + +/** Proxy {@link StreamOperatorStateContext} for the wrapped operator. */ +public class ProxyStreamOperatorStateContext implements StreamOperatorStateContext { + + private final StreamOperatorStateContext wrapped; + + private final StateNamePrefix stateNamePrefix; + + public ProxyStreamOperatorStateContext( + StreamOperatorStateContext wrapped, String stateNamePrefix) { + this.wrapped = Objects.requireNonNull(wrapped); + this.stateNamePrefix = new StateNamePrefix(stateNamePrefix); + } + + @Override + public boolean isRestored() { + return wrapped.isRestored(); + } + + @Override + public OptionalLong getRestoredCheckpointId() { + return wrapped.getRestoredCheckpointId(); + } + + @Override + public OperatorStateBackend operatorStateBackend() { + return wrapped.operatorStateBackend() == null + ? 
null + : new ProxyOperatorStateBackend(wrapped.operatorStateBackend(), stateNamePrefix); + } + + @Override + public CheckpointableKeyedStateBackend<?> keyedStateBackend() { + return wrapped.keyedStateBackend() == null + ? null + : new ProxyKeyedStateBackend<>(wrapped.keyedStateBackend(), stateNamePrefix); + } + + @Override + public InternalTimeServiceManager<?> internalTimerServiceManager() { + return wrapped.internalTimerServiceManager() == null + ? null + : new ProxyInternalTimeServiceManager<>( + wrapped.internalTimerServiceManager(), stateNamePrefix); + } + + @Override + public CloseableIterable<StatePartitionStreamProvider> rawOperatorStateInputs() { + return CloseableIterable.empty(); + } + + @Override + public CloseableIterable<KeyGroupStatePartitionStreamProvider> rawKeyedStateInputs() { + return CloseableIterable.empty(); + } +} diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/StateNamePrefix.java similarity index 66% copy from flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java copy to flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/StateNamePrefix.java index 3328a07..dbe507c 100644 --- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/StateNamePrefix.java @@ -16,20 +16,22 @@ * limitations under the License. */ -package org.apache.flink.iteration.operator.allround; +package org.apache.flink.iteration.proxy.state; -/** Utilities to track the life-cycle of the operators. 
*/ -public enum LifeCycle { - SETUP, - OPEN, - INITIALIZE_STATE, - PROCESS_ELEMENT, - PROCESS_ELEMENT_1, - PROCESS_ELEMENT_2, - PREPARE_SNAPSHOT_PRE_BARRIER, - SNAPSHOT_STATE, - NOTIFY_CHECKPOINT_COMPLETE, - NOTIFY_CHECKPOINT_ABORT, - FINISH, - CLOSE, +/** The prefix for the state name. */ +public class StateNamePrefix { + + private final String namePrefix; + + public StateNamePrefix(String namePrefix) { + this.namePrefix = namePrefix; + } + + public String getNamePrefix() { + return namePrefix; + } + + public String prefix(String stateName) { + return namePrefix + stateName; + } } diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java index 3328a07..552c933 100644 --- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java +++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java @@ -30,6 +30,8 @@ public enum LifeCycle { SNAPSHOT_STATE, NOTIFY_CHECKPOINT_COMPLETE, NOTIFY_CHECKPOINT_ABORT, + END_INPUT, + MAX_WATERMARK, FINISH, CLOSE, } diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java index 6491369..9a2b72a 100644 --- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java +++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java @@ -35,6 +35,7 @@ import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.AbstractInput; import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory; import 
org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2; +import org.apache.flink.streaming.api.operators.BoundedMultiInput; import org.apache.flink.streaming.api.operators.Input; import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator; import org.apache.flink.streaming.api.operators.StreamOperator; @@ -57,7 +58,7 @@ import static org.junit.Assert.assertEquals; /** Tests the {@link OneInputAllRoundWrapperOperator}. */ public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger { - private static List<LifeCycle> lifeCycles = new ArrayList<>(); + private static final List<LifeCycle> LIFE_CYCLES = new ArrayList<>(); @Test public void testProcessElementsAndEpochWatermarks() throws Exception { @@ -86,7 +87,7 @@ public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger { harness.processElement( new StreamRecord<>(IterationRecord.newEpochWatermark(5, "only-one-2")), 2); - // Check the output + // Checks the output assertEquals( Arrays.asList( new StreamRecord<>(IterationRecord.newRecord(5, 1), 2), @@ -97,7 +98,7 @@ public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger { 5, OperatorUtils.getUniqueSenderId(operatorId, 0)))), new ArrayList<>(harness.getOutput())); - // Check the other lifecycles. + // Checks the other lifecycles. 
harness.getStreamTask() .triggerCheckpointOnBarrier( new CheckpointMetaData(5, 2), @@ -132,15 +133,21 @@ public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger { LifeCycle.SNAPSHOT_STATE, LifeCycle.NOTIFY_CHECKPOINT_COMPLETE, LifeCycle.NOTIFY_CHECKPOINT_ABORT, + // The first input + LifeCycle.END_INPUT, + // The second input + LifeCycle.END_INPUT, + // The third input + LifeCycle.END_INPUT, LifeCycle.FINISH, LifeCycle.CLOSE), - lifeCycles); + LIFE_CYCLES); } } private static class LifeCycleTrackingTwoInputStreamOperator extends AbstractStreamOperatorV2<Integer> - implements MultipleInputStreamOperator<Integer> { + implements MultipleInputStreamOperator<Integer>, BoundedMultiInput { private final int numberOfInputs; @@ -159,7 +166,7 @@ public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger { @Override public void processElement(StreamRecord element) throws Exception { output.collect(element); - lifeCycles.add(LifeCycle.PROCESS_ELEMENT); + LIFE_CYCLES.add(LifeCycle.PROCESS_ELEMENT); } }); } @@ -170,49 +177,54 @@ public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger { @Override public void open() throws Exception { super.open(); - lifeCycles.add(LifeCycle.OPEN); + LIFE_CYCLES.add(LifeCycle.OPEN); } @Override public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - lifeCycles.add(LifeCycle.INITIALIZE_STATE); + LIFE_CYCLES.add(LifeCycle.INITIALIZE_STATE); } @Override public void finish() throws Exception { super.finish(); - lifeCycles.add(LifeCycle.FINISH); + LIFE_CYCLES.add(LifeCycle.FINISH); } @Override public void close() throws Exception { super.close(); - lifeCycles.add(LifeCycle.CLOSE); + LIFE_CYCLES.add(LifeCycle.CLOSE); } @Override public void prepareSnapshotPreBarrier(long checkpointId) throws Exception { super.prepareSnapshotPreBarrier(checkpointId); - lifeCycles.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER); + 
LIFE_CYCLES.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER); } @Override public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); - lifeCycles.add(LifeCycle.SNAPSHOT_STATE); + LIFE_CYCLES.add(LifeCycle.SNAPSHOT_STATE); } @Override public void notifyCheckpointComplete(long checkpointId) throws Exception { super.notifyCheckpointComplete(checkpointId); - lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE); + LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE); } @Override public void notifyCheckpointAborted(long checkpointId) throws Exception { super.notifyCheckpointAborted(checkpointId); - lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT); + LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT); + } + + @Override + public void endInput(int inputId) throws Exception { + LIFE_CYCLES.add(LifeCycle.END_INPUT); } } diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java index c58f431..f628b65 100644 --- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java +++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java @@ -34,6 +34,7 @@ import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.Output; import org.apache.flink.streaming.api.operators.SimpleOperatorFactory; @@ -56,7 +57,7 @@ import static org.junit.Assert.assertEquals; 
/** Tests the {@link OneInputAllRoundWrapperOperator}. */ public class OneInputAllRoundWrapperOperatorTest extends TestLogger { - private static List<LifeCycle> lifeCycles = new ArrayList<>(); + private static final List<LifeCycle> LIFE_CYCLES = new ArrayList<>(); @Test public void testProcessElementsAndEpochWatermarks() throws Exception { @@ -78,7 +79,7 @@ public class OneInputAllRoundWrapperOperatorTest extends TestLogger { harness.processElement( new StreamRecord<>(IterationRecord.newEpochWatermark(5, "only-one"))); - // Check the output + // Checks the output assertEquals( Arrays.asList( new StreamRecord<>(IterationRecord.newRecord(5, 1), 2), @@ -121,15 +122,16 @@ public class OneInputAllRoundWrapperOperatorTest extends TestLogger { LifeCycle.SNAPSHOT_STATE, LifeCycle.NOTIFY_CHECKPOINT_COMPLETE, LifeCycle.NOTIFY_CHECKPOINT_ABORT, + LifeCycle.END_INPUT, LifeCycle.FINISH, LifeCycle.CLOSE), - lifeCycles); + LIFE_CYCLES); } } private static class LifeCycleTrackingOneInputStreamOperator extends AbstractStreamOperator<Integer> - implements OneInputStreamOperator<Integer, Integer> { + implements OneInputStreamOperator<Integer, Integer>, BoundedOneInput { @Override public void setup( @@ -137,61 +139,66 @@ public class OneInputAllRoundWrapperOperatorTest extends TestLogger { StreamConfig config, Output<StreamRecord<Integer>> output) { super.setup(containingTask, config, output); - lifeCycles.add(LifeCycle.SETUP); + LIFE_CYCLES.add(LifeCycle.SETUP); } @Override public void open() throws Exception { super.open(); - lifeCycles.add(LifeCycle.OPEN); + LIFE_CYCLES.add(LifeCycle.OPEN); } @Override public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - lifeCycles.add(LifeCycle.INITIALIZE_STATE); + LIFE_CYCLES.add(LifeCycle.INITIALIZE_STATE); } @Override public void finish() throws Exception { super.finish(); - lifeCycles.add(LifeCycle.FINISH); + LIFE_CYCLES.add(LifeCycle.FINISH); } @Override public void close() 
throws Exception { super.close(); - lifeCycles.add(LifeCycle.CLOSE); + LIFE_CYCLES.add(LifeCycle.CLOSE); } @Override public void prepareSnapshotPreBarrier(long checkpointId) throws Exception { super.prepareSnapshotPreBarrier(checkpointId); - lifeCycles.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER); + LIFE_CYCLES.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER); } @Override public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); - lifeCycles.add(LifeCycle.SNAPSHOT_STATE); + LIFE_CYCLES.add(LifeCycle.SNAPSHOT_STATE); } @Override public void notifyCheckpointComplete(long checkpointId) throws Exception { super.notifyCheckpointComplete(checkpointId); - lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE); + LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE); } @Override public void notifyCheckpointAborted(long checkpointId) throws Exception { super.notifyCheckpointAborted(checkpointId); - lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT); + LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT); } @Override public void processElement(StreamRecord<Integer> element) throws Exception { output.collect(element); - lifeCycles.add(LifeCycle.PROCESS_ELEMENT); + LIFE_CYCLES.add(LifeCycle.PROCESS_ELEMENT); + } + + @Override + public void endInput() throws Exception { + LIFE_CYCLES.add(LifeCycle.END_INPUT); } } } diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java index e7e2604..82d5854 100644 --- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java +++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java @@ -34,6 +34,7 @@ import org.apache.flink.runtime.state.StateInitializationContext; import 
org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedMultiInput; import org.apache.flink.streaming.api.operators.Output; import org.apache.flink.streaming.api.operators.SimpleOperatorFactory; import org.apache.flink.streaming.api.operators.StreamOperatorFactory; @@ -56,7 +57,7 @@ import static org.junit.Assert.assertEquals; /** Tests the {@link OneInputAllRoundWrapperOperator}. */ public class TwoInputAllRoundWrapperOperatorTest extends TestLogger { - private static List<LifeCycle> lifeCycles = new ArrayList<>(); + private static final List<LifeCycle> LIFE_CYCLES = new ArrayList<>(); @Test public void testProcessElementsAndEpochWatermarks() throws Exception { @@ -81,7 +82,7 @@ public class TwoInputAllRoundWrapperOperatorTest extends TestLogger { harness.processElement( new StreamRecord<>(IterationRecord.newEpochWatermark(5, "only-one-1")), 1); - // Check the output + // Checks the output assertEquals( Arrays.asList( new StreamRecord<>(IterationRecord.newRecord(5, 1), 2), @@ -91,7 +92,7 @@ public class TwoInputAllRoundWrapperOperatorTest extends TestLogger { 5, OperatorUtils.getUniqueSenderId(operatorId, 0)))), new ArrayList<>(harness.getOutput())); - // Check the other lifecycles. + // Checks the other lifecycles. 
harness.getStreamTask() .triggerCheckpointOnBarrier( new CheckpointMetaData(5, 2), @@ -125,15 +126,19 @@ public class TwoInputAllRoundWrapperOperatorTest extends TestLogger { LifeCycle.SNAPSHOT_STATE, LifeCycle.NOTIFY_CHECKPOINT_COMPLETE, LifeCycle.NOTIFY_CHECKPOINT_ABORT, + // The first input + LifeCycle.END_INPUT, + // The second input + LifeCycle.END_INPUT, LifeCycle.FINISH, LifeCycle.CLOSE), - lifeCycles); + LIFE_CYCLES); } } private static class LifeCycleTrackingTwoInputStreamOperator extends AbstractStreamOperator<Integer> - implements TwoInputStreamOperator<Integer, Integer, Integer> { + implements TwoInputStreamOperator<Integer, Integer, Integer>, BoundedMultiInput { @Override public void setup( @@ -141,67 +146,72 @@ public class TwoInputAllRoundWrapperOperatorTest extends TestLogger { StreamConfig config, Output<StreamRecord<Integer>> output) { super.setup(containingTask, config, output); - lifeCycles.add(LifeCycle.SETUP); + LIFE_CYCLES.add(LifeCycle.SETUP); } @Override public void open() throws Exception { super.open(); - lifeCycles.add(LifeCycle.OPEN); + LIFE_CYCLES.add(LifeCycle.OPEN); } @Override public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - lifeCycles.add(LifeCycle.INITIALIZE_STATE); + LIFE_CYCLES.add(LifeCycle.INITIALIZE_STATE); } @Override public void finish() throws Exception { super.finish(); - lifeCycles.add(LifeCycle.FINISH); + LIFE_CYCLES.add(LifeCycle.FINISH); } @Override public void close() throws Exception { super.close(); - lifeCycles.add(LifeCycle.CLOSE); + LIFE_CYCLES.add(LifeCycle.CLOSE); } @Override public void prepareSnapshotPreBarrier(long checkpointId) throws Exception { super.prepareSnapshotPreBarrier(checkpointId); - lifeCycles.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER); + LIFE_CYCLES.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER); } @Override public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); - 
lifeCycles.add(LifeCycle.SNAPSHOT_STATE); + LIFE_CYCLES.add(LifeCycle.SNAPSHOT_STATE); } @Override public void notifyCheckpointComplete(long checkpointId) throws Exception { super.notifyCheckpointComplete(checkpointId); - lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE); + LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE); } @Override public void notifyCheckpointAborted(long checkpointId) throws Exception { super.notifyCheckpointAborted(checkpointId); - lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT); + LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT); } @Override public void processElement1(StreamRecord<Integer> element) throws Exception { output.collect(element); - lifeCycles.add(LifeCycle.PROCESS_ELEMENT_1); + LIFE_CYCLES.add(LifeCycle.PROCESS_ELEMENT_1); } @Override public void processElement2(StreamRecord<Integer> element) throws Exception { output.collect(element); - lifeCycles.add(LifeCycle.PROCESS_ELEMENT_2); + LIFE_CYCLES.add(LifeCycle.PROCESS_ELEMENT_2); + } + + @Override + public void endInput(int inputId) throws Exception { + LIFE_CYCLES.add(LifeCycle.END_INPUT); } } } diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java similarity index 68% copy from flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java copy to flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java index 6491369..9dba5fa 100644 --- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java +++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java @@ -16,12 +16,14 @@ * limitations under the 
License. */ -package org.apache.flink.iteration.operator.allround; +package org.apache.flink.iteration.operator.perround; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.iteration.IterationRecord; import org.apache.flink.iteration.operator.OperatorUtils; import org.apache.flink.iteration.operator.WrapperOperatorFactory; +import org.apache.flink.iteration.operator.allround.LifeCycle; +import org.apache.flink.iteration.operator.allround.OneInputAllRoundWrapperOperator; import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder; @@ -35,6 +37,7 @@ import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.AbstractInput; import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory; import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2; +import org.apache.flink.streaming.api.operators.BoundedMultiInput; import org.apache.flink.streaming.api.operators.Input; import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator; import org.apache.flink.streaming.api.operators.StreamOperator; @@ -55,16 +58,16 @@ import java.util.List; import static org.junit.Assert.assertEquals; /** Tests the {@link OneInputAllRoundWrapperOperator}. 
*/ -public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger { +public class MultipleInputPerRoundWrapperOperatorTest extends TestLogger { - private static List<LifeCycle> lifeCycles = new ArrayList<>(); + private static final List<LifeCycle> LIFE_CYCLES = new ArrayList<>(); @Test public void testProcessElementsAndEpochWatermarks() throws Exception { StreamOperatorFactory<IterationRecord<Integer>> wrapperFactory = new WrapperOperatorFactory<>( - new LifeCycleTrackingTwoInputStreamOperatorFactory(), - new AllRoundOperatorWrapper<>()); + new LifeCycleTrackingMultiInputStreamOperatorFactory(), + new PerRoundOperatorWrapper<>()); OperatorID operatorId = new OperatorID(); try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness = @@ -77,27 +80,16 @@ public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger { .setupOutputForSingletonOperatorChain(wrapperFactory, operatorId) .build()) { harness.processElement(new StreamRecord<>(IterationRecord.newRecord(5, 1), 2), 0); - harness.processElement(new StreamRecord<>(IterationRecord.newRecord(6, 2), 3), 1); - harness.processElement(new StreamRecord<>(IterationRecord.newRecord(7, 3), 4), 2); - harness.processElement( - new StreamRecord<>(IterationRecord.newEpochWatermark(5, "only-one-0")), 0); - harness.processElement( - new StreamRecord<>(IterationRecord.newEpochWatermark(5, "only-one-1")), 1); - harness.processElement( - new StreamRecord<>(IterationRecord.newEpochWatermark(5, "only-one-2")), 2); + harness.processElement(new StreamRecord<>(IterationRecord.newRecord(6, 2), 3), 2); - // Check the output + // Checks the output assertEquals( Arrays.asList( new StreamRecord<>(IterationRecord.newRecord(5, 1), 2), - new StreamRecord<>(IterationRecord.newRecord(6, 2), 3), - new StreamRecord<>(IterationRecord.newRecord(7, 3), 4), - new StreamRecord<>( - IterationRecord.newEpochWatermark( - 5, OperatorUtils.getUniqueSenderId(operatorId, 0)))), + new StreamRecord<>(IterationRecord.newRecord(6, 2), 
3)), new ArrayList<>(harness.getOutput())); - // Check the other lifecycles. + // Checks the other lifecycles. harness.getStreamTask() .triggerCheckpointOnBarrier( new CheckpointMetaData(5, 2), @@ -115,6 +107,31 @@ public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger { harness.getStreamTask().notifyCheckpointAbortAsync(6, 5); harness.processAll(); + harness.getOutput().clear(); + harness.processElement( + new StreamRecord<>(IterationRecord.newEpochWatermark(1, "only-one-0")), 0); + harness.processElement( + new StreamRecord<>(IterationRecord.newEpochWatermark(1, "only-one-1")), 1); + harness.processElement( + new StreamRecord<>(IterationRecord.newEpochWatermark(1, "only-one-2")), 2); + harness.processElement( + new StreamRecord<>(IterationRecord.newEpochWatermark(2, "only-one-0")), 0); + harness.processElement( + new StreamRecord<>(IterationRecord.newEpochWatermark(2, "only-one-1")), 1); + harness.processElement( + new StreamRecord<>(IterationRecord.newEpochWatermark(2, "only-one-2")), 2); + + // Checks the output + assertEquals( + Arrays.asList( + new StreamRecord<>( + IterationRecord.newEpochWatermark( + 1, OperatorUtils.getUniqueSenderId(operatorId, 0))), + new StreamRecord<>( + IterationRecord.newEpochWatermark( + 2, OperatorUtils.getUniqueSenderId(operatorId, 0)))), + new ArrayList<>(harness.getOutput())); + harness.processEvent(EndOfData.INSTANCE, 0); harness.processEvent(EndOfData.INSTANCE, 1); harness.processEvent(EndOfData.INSTANCE, 2); @@ -123,28 +140,50 @@ public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger { assertEquals( Arrays.asList( + /* First wrapped operator */ LifeCycle.INITIALIZE_STATE, LifeCycle.OPEN, LifeCycle.PROCESS_ELEMENT, + /* second wrapped operator */ + LifeCycle.INITIALIZE_STATE, + LifeCycle.OPEN, LifeCycle.PROCESS_ELEMENT, - LifeCycle.PROCESS_ELEMENT, + /* states */ + LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER, LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER, LifeCycle.SNAPSHOT_STATE, + 
LifeCycle.SNAPSHOT_STATE, + LifeCycle.NOTIFY_CHECKPOINT_COMPLETE, LifeCycle.NOTIFY_CHECKPOINT_COMPLETE, LifeCycle.NOTIFY_CHECKPOINT_ABORT, + LifeCycle.NOTIFY_CHECKPOINT_ABORT, + // The first input + LifeCycle.END_INPUT, + // The second input + LifeCycle.END_INPUT, + // The third input + LifeCycle.END_INPUT, + LifeCycle.FINISH, + LifeCycle.CLOSE, + // The first input + LifeCycle.END_INPUT, + // The second input + LifeCycle.END_INPUT, + // The third input + LifeCycle.END_INPUT, LifeCycle.FINISH, LifeCycle.CLOSE), - lifeCycles); + LIFE_CYCLES); } } - private static class LifeCycleTrackingTwoInputStreamOperator + private static class LifeCycleTrackingMultiInputStreamOperator extends AbstractStreamOperatorV2<Integer> - implements MultipleInputStreamOperator<Integer> { + implements MultipleInputStreamOperator<Integer>, BoundedMultiInput { private final int numberOfInputs; - public LifeCycleTrackingTwoInputStreamOperator( + public LifeCycleTrackingMultiInputStreamOperator( StreamOperatorParameters<Integer> parameters, int numberOfInputs) { super(parameters, numberOfInputs); this.numberOfInputs = numberOfInputs; @@ -159,7 +198,7 @@ public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger { @Override public void processElement(StreamRecord element) throws Exception { output.collect(element); - lifeCycles.add(LifeCycle.PROCESS_ELEMENT); + LIFE_CYCLES.add(LifeCycle.PROCESS_ELEMENT); } }); } @@ -170,65 +209,70 @@ public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger { @Override public void open() throws Exception { super.open(); - lifeCycles.add(LifeCycle.OPEN); + LIFE_CYCLES.add(LifeCycle.OPEN); } @Override public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - lifeCycles.add(LifeCycle.INITIALIZE_STATE); + LIFE_CYCLES.add(LifeCycle.INITIALIZE_STATE); } @Override public void finish() throws Exception { super.finish(); - lifeCycles.add(LifeCycle.FINISH); + 
LIFE_CYCLES.add(LifeCycle.FINISH); } @Override public void close() throws Exception { super.close(); - lifeCycles.add(LifeCycle.CLOSE); + LIFE_CYCLES.add(LifeCycle.CLOSE); } @Override public void prepareSnapshotPreBarrier(long checkpointId) throws Exception { super.prepareSnapshotPreBarrier(checkpointId); - lifeCycles.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER); + LIFE_CYCLES.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER); } @Override public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); - lifeCycles.add(LifeCycle.SNAPSHOT_STATE); + LIFE_CYCLES.add(LifeCycle.SNAPSHOT_STATE); } @Override public void notifyCheckpointComplete(long checkpointId) throws Exception { super.notifyCheckpointComplete(checkpointId); - lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE); + LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE); } @Override public void notifyCheckpointAborted(long checkpointId) throws Exception { super.notifyCheckpointAborted(checkpointId); - lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT); + LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT); + } + + @Override + public void endInput(int inputId) throws Exception { + LIFE_CYCLES.add(LifeCycle.END_INPUT); } } - /** The operator factory for the lifecycle-tracking operator. */ - public static class LifeCycleTrackingTwoInputStreamOperatorFactory + /** Life-cycle tracking stream operator factory. */ + private static class LifeCycleTrackingMultiInputStreamOperatorFactory extends AbstractStreamOperatorFactory<Integer> { @Override public <T extends StreamOperator<Integer>> T createStreamOperator( StreamOperatorParameters<Integer> parameters) { - return (T) new LifeCycleTrackingTwoInputStreamOperator(parameters, 3); + return (T) new LifeCycleTrackingMultiInputStreamOperator(parameters, 3); } @Override public Class<? 
extends StreamOperator> getStreamOperatorClass(ClassLoader classLoader) { - return LifeCycleTrackingTwoInputStreamOperator.class; + return LifeCycleTrackingMultiInputStreamOperator.class; } } } diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java similarity index 73% copy from flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java copy to flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java index c58f431..787dd0f 100644 --- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java +++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java @@ -16,12 +16,13 @@ * limitations under the License. 
*/ -package org.apache.flink.iteration.operator.allround; +package org.apache.flink.iteration.operator.perround; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.iteration.IterationRecord; import org.apache.flink.iteration.operator.OperatorUtils; import org.apache.flink.iteration.operator.WrapperOperatorFactory; +import org.apache.flink.iteration.operator.allround.LifeCycle; import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder; @@ -34,6 +35,7 @@ import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.Output; import org.apache.flink.streaming.api.operators.SimpleOperatorFactory; @@ -53,17 +55,17 @@ import java.util.List; import static org.junit.Assert.assertEquals; -/** Tests the {@link OneInputAllRoundWrapperOperator}. */ -public class OneInputAllRoundWrapperOperatorTest extends TestLogger { +/** Tests the {@link OneInputPerRoundWrapperOperator}. 
*/ +public class OneInputPerRoundWrapperOperatorTest extends TestLogger { - private static List<LifeCycle> lifeCycles = new ArrayList<>(); + private static final List<LifeCycle> LIFE_CYCLES = new ArrayList<>(); @Test public void testProcessElementsAndEpochWatermarks() throws Exception { StreamOperatorFactory<IterationRecord<Integer>> wrapperFactory = new WrapperOperatorFactory<>( SimpleOperatorFactory.of(new LifeCycleTrackingOneInputStreamOperator()), - new AllRoundOperatorWrapper<>()); + new PerRoundOperatorWrapper<>()); OperatorID operatorId = new OperatorID(); try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness = @@ -75,20 +77,15 @@ public class OneInputAllRoundWrapperOperatorTest extends TestLogger { .build()) { harness.processElement(new StreamRecord<>(IterationRecord.newRecord(5, 1), 2)); harness.processElement(new StreamRecord<>(IterationRecord.newRecord(6, 2), 3)); - harness.processElement( - new StreamRecord<>(IterationRecord.newEpochWatermark(5, "only-one"))); - // Check the output + // Checks the output assertEquals( Arrays.asList( new StreamRecord<>(IterationRecord.newRecord(5, 1), 2), - new StreamRecord<>(IterationRecord.newRecord(6, 2), 3), - new StreamRecord<>( - IterationRecord.newEpochWatermark( - 5, OperatorUtils.getUniqueSenderId(operatorId, 0)))), + new StreamRecord<>(IterationRecord.newRecord(6, 2), 3)), new ArrayList<>(harness.getOutput())); - // Check the other lifecycles. + // Checks the other lifecycles. 
harness.getStreamTask() .triggerCheckpointOnBarrier( new CheckpointMetaData(5, 2), @@ -106,30 +103,61 @@ public class OneInputAllRoundWrapperOperatorTest extends TestLogger { harness.getStreamTask().notifyCheckpointAbortAsync(6, 5); harness.processAll(); + harness.getOutput().clear(); + harness.processElement( + new StreamRecord<>(IterationRecord.newEpochWatermark(1, "only-one"))); + harness.processElement( + new StreamRecord<>(IterationRecord.newEpochWatermark(2, "only-one"))); + + // Checks the output + assertEquals( + Arrays.asList( + new StreamRecord<>( + IterationRecord.newEpochWatermark( + 1, OperatorUtils.getUniqueSenderId(operatorId, 0))), + new StreamRecord<>( + IterationRecord.newEpochWatermark( + 2, OperatorUtils.getUniqueSenderId(operatorId, 0)))), + new ArrayList<>(harness.getOutput())); + harness.processEvent(EndOfData.INSTANCE, 0); harness.endInput(); harness.finishProcessing(); assertEquals( Arrays.asList( + /* First wrapped operator */ LifeCycle.SETUP, LifeCycle.INITIALIZE_STATE, LifeCycle.OPEN, LifeCycle.PROCESS_ELEMENT, + /* second wrapped operator */ + LifeCycle.SETUP, + LifeCycle.INITIALIZE_STATE, + LifeCycle.OPEN, LifeCycle.PROCESS_ELEMENT, + /* states */ + LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER, LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER, LifeCycle.SNAPSHOT_STATE, + LifeCycle.SNAPSHOT_STATE, LifeCycle.NOTIFY_CHECKPOINT_COMPLETE, + LifeCycle.NOTIFY_CHECKPOINT_COMPLETE, + LifeCycle.NOTIFY_CHECKPOINT_ABORT, LifeCycle.NOTIFY_CHECKPOINT_ABORT, + LifeCycle.END_INPUT, + LifeCycle.FINISH, + LifeCycle.CLOSE, + LifeCycle.END_INPUT, LifeCycle.FINISH, LifeCycle.CLOSE), - lifeCycles); + LIFE_CYCLES); } } private static class LifeCycleTrackingOneInputStreamOperator extends AbstractStreamOperator<Integer> - implements OneInputStreamOperator<Integer, Integer> { + implements OneInputStreamOperator<Integer, Integer>, BoundedOneInput { @Override public void setup( @@ -137,61 +165,66 @@ public class OneInputAllRoundWrapperOperatorTest extends TestLogger { 
StreamConfig config, Output<StreamRecord<Integer>> output) { super.setup(containingTask, config, output); - lifeCycles.add(LifeCycle.SETUP); + LIFE_CYCLES.add(LifeCycle.SETUP); } @Override public void open() throws Exception { super.open(); - lifeCycles.add(LifeCycle.OPEN); + LIFE_CYCLES.add(LifeCycle.OPEN); } @Override public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - lifeCycles.add(LifeCycle.INITIALIZE_STATE); + LIFE_CYCLES.add(LifeCycle.INITIALIZE_STATE); } @Override public void finish() throws Exception { super.finish(); - lifeCycles.add(LifeCycle.FINISH); + LIFE_CYCLES.add(LifeCycle.FINISH); } @Override public void close() throws Exception { super.close(); - lifeCycles.add(LifeCycle.CLOSE); + LIFE_CYCLES.add(LifeCycle.CLOSE); } @Override public void prepareSnapshotPreBarrier(long checkpointId) throws Exception { super.prepareSnapshotPreBarrier(checkpointId); - lifeCycles.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER); + LIFE_CYCLES.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER); } @Override public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); - lifeCycles.add(LifeCycle.SNAPSHOT_STATE); + LIFE_CYCLES.add(LifeCycle.SNAPSHOT_STATE); } @Override public void notifyCheckpointComplete(long checkpointId) throws Exception { super.notifyCheckpointComplete(checkpointId); - lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE); + LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE); } @Override public void notifyCheckpointAborted(long checkpointId) throws Exception { super.notifyCheckpointAborted(checkpointId); - lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT); + LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT); } @Override public void processElement(StreamRecord<Integer> element) throws Exception { output.collect(element); - lifeCycles.add(LifeCycle.PROCESS_ELEMENT); + LIFE_CYCLES.add(LifeCycle.PROCESS_ELEMENT); + } + + @Override + public void endInput() 
throws Exception { + LIFE_CYCLES.add(LifeCycle.END_INPUT); } } } diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorStateTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorStateTest.java new file mode 100644 index 0000000..c720e57 --- /dev/null +++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorStateTest.java @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.iteration.operator.perround; + +import org.apache.flink.api.common.state.BroadcastState; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.java.typeutils.EnumTypeInfo; +import org.apache.flink.contrib.streaming.state.EmbeddedRocksDBStateBackend; +import org.apache.flink.core.memory.ManagedMemoryUseCase; +import org.apache.flink.iteration.IterationRecord; +import org.apache.flink.iteration.operator.WrapperOperatorFactory; +import org.apache.flink.iteration.proxy.ProxyKeySelector; +import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.hashmap.HashMapStateBackend; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.functions.KeyedProcessFunction; +import org.apache.flink.streaming.api.functions.ProcessFunction; +import org.apache.flink.streaming.api.operators.KeyedProcessOperator; +import org.apache.flink.streaming.api.operators.ProcessOperator; +import org.apache.flink.streaming.api.operators.SimpleOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask; +import 
org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness; +import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder; +import org.apache.flink.util.Collector; +import org.apache.flink.util.TestLogger; + +import org.junit.Test; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; + +/** Tests the state isolation and cleanup for the per-round operators. */ +public class PerRoundOperatorStateTest extends TestLogger { + + @Test + public void testStateIsolationWithoutKeyedStateBackend() throws Exception { + testStateIsolation(null); + } + + @Test + public void testStateIsolationWithHashMapKeyedStateBackend() throws Exception { + testStateIsolation(new HashMapStateBackend()); + } + + @Test + public void testStateIsolationWithRocksDBKeyedStateBackend() throws Exception { + testStateIsolation(new EmbeddedRocksDBStateBackend()); + } + + private void testStateIsolation(@Nullable StateBackend stateBackend) throws Exception { + StreamOperatorFactory<IterationRecord<Integer>> wrapperFactory = + new WrapperOperatorFactory<>( + SimpleOperatorFactory.of( + stateBackend == null + ? 
new ProcessOperator<>(new StatefulProcessFunction()) + : new KeyedProcessOperator<>( + new KeyedStatefulProcessFunction())), + new PerRoundOperatorWrapper<>()); + OperatorID operatorId = new OperatorID(); + + try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness = + new StreamTaskMailboxTestHarnessBuilder<>( + OneInputStreamTask::new, + new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO)) + .modifyStreamConfig( + streamConfig -> { + if (stateBackend != null) { + streamConfig.setStateBackend(stateBackend); + streamConfig.setManagedMemoryFractionOperatorOfUseCase( + ManagedMemoryUseCase.STATE_BACKEND, 0.2); + streamConfig.setStateKeySerializer(IntSerializer.INSTANCE); + } + }) + .addInput( + new IterationRecordTypeInfo<>(new EnumTypeInfo<>(ActionType.class)), + 1, + stateBackend == null ? null : new ProxyKeySelector<>(x -> 10)) + .setupOutputForSingletonOperatorChain(wrapperFactory, operatorId) + .build()) { + + // Set round 0 + harness.processElement( + new StreamRecord<>(IterationRecord.newRecord(ActionType.SET, 0)), 0); + testGetRound(harness, Arrays.asList(10, 10, stateBackend == null ? -1 : 10), 0); + testGetRound(harness, Arrays.asList(-1, -1, -1), 1); + + // Set round 1 + harness.processElement( + new StreamRecord<>(IterationRecord.newRecord(ActionType.SET, 1)), 0); + testGetRound(harness, Arrays.asList(10, 10, stateBackend == null ? -1 : 10), 1); + + // Clear round 0. Although after round 0 we should not receive records for round 0 in + // realistic scenarios, we use this method to check the current value of states. + harness.processElement( + new StreamRecord<>(IterationRecord.newEpochWatermark(0, "sender")), 0); + testGetRound(harness, Arrays.asList(-1, -1, -1), 0); + testGetRound(harness, Arrays.asList(10, 10, stateBackend == null ? 
-1 : 10), 1); + + // Clear round 1 + harness.processElement( + new StreamRecord<>(IterationRecord.newEpochWatermark(1, "sender")), 0); + testGetRound(harness, Arrays.asList(-1, -1, -1), 0); + testGetRound(harness, Arrays.asList(-1, -1, -1), 1); + } + } + + private void testGetRound( + StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness, + List<Integer> expectedValues, + int round) + throws Exception { + harness.getOutput().clear(); + harness.processElement( + new StreamRecord<>(IterationRecord.newRecord(ActionType.GET, round)), 0); + + assertEquals( + expectedValues.stream() + .map(i -> IterationRecord.newRecord(i, round)) + .collect(Collectors.toList()), + harness.getOutput().stream() + .map(r -> ((StreamRecord<?>) r).getValue()) + .collect(Collectors.toList())); + } + + enum ActionType { + SET, + GET + } + + private static class StatefulProcessFunction extends ProcessFunction<ActionType, Integer> + implements CheckpointedFunction { + + private transient BroadcastState<Integer, Integer> broadcastState; + + private transient ListState<Integer> operatorState; + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + this.operatorState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>("opState", IntSerializer.INSTANCE)); + this.broadcastState = + context.getOperatorStateStore() + .getBroadcastState( + new MapStateDescriptor<>( + "broadState", + IntSerializer.INSTANCE, + IntSerializer.INSTANCE)); + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception {} + + @Override + public void processElement(ActionType value, Context ctx, Collector<Integer> out) + throws Exception { + switch (value) { + case SET: + operatorState.add(10); + broadcastState.put(10, 10); + break; + case GET: + out.collect( + operatorState.get().iterator().hasNext() + ? 
operatorState.get().iterator().next() + : -1); + out.collect(mapNullToMinusOne(broadcastState.get(10))); + // To keep the same amount of outputs with the keyed one. + out.collect(mapNullToMinusOne(null)); + break; + } + } + } + + private static class KeyedStatefulProcessFunction + extends KeyedProcessFunction<Integer, ActionType, Integer> + implements CheckpointedFunction { + + private transient BroadcastState<Integer, Integer> broadcastState; + + private transient ListState<Integer> operatorState; + + private transient ValueState<Integer> keyedState; + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + this.operatorState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>("opState", IntSerializer.INSTANCE)); + this.broadcastState = + context.getOperatorStateStore() + .getBroadcastState( + new MapStateDescriptor<>( + "broadState", + IntSerializer.INSTANCE, + IntSerializer.INSTANCE)); + this.keyedState = + context.getKeyedStateStore() + .getState( + new ValueStateDescriptor<>( + "keyedState", IntSerializer.INSTANCE)); + } + + @Override + public void snapshotState(FunctionSnapshotContext context) throws Exception {} + + @Override + public void processElement(ActionType value, Context ctx, Collector<Integer> out) + throws Exception { + switch (value) { + case SET: + operatorState.add(10); + broadcastState.put(10, 10); + keyedState.update(10); + break; + case GET: + out.collect( + operatorState.get().iterator().hasNext() + ? operatorState.get().iterator().next() + : -1); + out.collect(mapNullToMinusOne(broadcastState.get(10))); + out.collect(mapNullToMinusOne(keyedState == null ? null : keyedState.value())); + break; + } + } + } + + private static int mapNullToMinusOne(Integer value) { + return value == null ? 
-1 : value; + } +} diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java similarity index 71% copy from flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java copy to flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java index e7e2604..f2134b8 100644 --- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java +++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java @@ -16,12 +16,13 @@ * limitations under the License. */ -package org.apache.flink.iteration.operator.allround; +package org.apache.flink.iteration.operator.perround; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.iteration.IterationRecord; import org.apache.flink.iteration.operator.OperatorUtils; import org.apache.flink.iteration.operator.WrapperOperatorFactory; +import org.apache.flink.iteration.operator.allround.LifeCycle; import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder; @@ -34,6 +35,7 @@ import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedMultiInput; import org.apache.flink.streaming.api.operators.Output; import org.apache.flink.streaming.api.operators.SimpleOperatorFactory; import 
org.apache.flink.streaming.api.operators.StreamOperatorFactory; @@ -53,17 +55,17 @@ import java.util.List; import static org.junit.Assert.assertEquals; -/** Tests the {@link OneInputAllRoundWrapperOperator}. */ -public class TwoInputAllRoundWrapperOperatorTest extends TestLogger { +/** Tests the {@link TwoInputPerRoundWrapperOperator}. */ +public class TwoInputPerRoundWrapperOperatorTest extends TestLogger { - private static List<LifeCycle> lifeCycles = new ArrayList<>(); + private static final List<LifeCycle> LIFE_CYCLES = new ArrayList<>(); @Test public void testProcessElementsAndEpochWatermarks() throws Exception { StreamOperatorFactory<IterationRecord<Integer>> wrapperFactory = new WrapperOperatorFactory<>( SimpleOperatorFactory.of(new LifeCycleTrackingTwoInputStreamOperator()), - new AllRoundOperatorWrapper<>()); + new PerRoundOperatorWrapper<>()); OperatorID operatorId = new OperatorID(); try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness = @@ -76,22 +78,15 @@ public class TwoInputAllRoundWrapperOperatorTest extends TestLogger { .build()) { harness.processElement(new StreamRecord<>(IterationRecord.newRecord(5, 1), 2), 0); harness.processElement(new StreamRecord<>(IterationRecord.newRecord(6, 2), 3), 1); - harness.processElement( - new StreamRecord<>(IterationRecord.newEpochWatermark(5, "only-one-0")), 0); - harness.processElement( - new StreamRecord<>(IterationRecord.newEpochWatermark(5, "only-one-1")), 1); - // Check the output + // Checks the output assertEquals( Arrays.asList( new StreamRecord<>(IterationRecord.newRecord(5, 1), 2), - new StreamRecord<>(IterationRecord.newRecord(6, 2), 3), - new StreamRecord<>( - IterationRecord.newEpochWatermark( - 5, OperatorUtils.getUniqueSenderId(operatorId, 0)))), + new StreamRecord<>(IterationRecord.newRecord(6, 2), 3)), new ArrayList<>(harness.getOutput())); - // Check the other lifecycles. + // Checks the other lifecycles. 
harness.getStreamTask() .triggerCheckpointOnBarrier( new CheckpointMetaData(5, 2), @@ -109,6 +104,27 @@ public class TwoInputAllRoundWrapperOperatorTest extends TestLogger { harness.getStreamTask().notifyCheckpointAbortAsync(6, 5); harness.processAll(); + harness.getOutput().clear(); + harness.processElement( + new StreamRecord<>(IterationRecord.newEpochWatermark(1, "only-one")), 0); + harness.processElement( + new StreamRecord<>(IterationRecord.newEpochWatermark(1, "only-one")), 1); + harness.processElement( + new StreamRecord<>(IterationRecord.newEpochWatermark(2, "only-one")), 0); + harness.processElement( + new StreamRecord<>(IterationRecord.newEpochWatermark(2, "only-one")), 1); + + // Checks the output + assertEquals( + Arrays.asList( + new StreamRecord<>( + IterationRecord.newEpochWatermark( + 1, OperatorUtils.getUniqueSenderId(operatorId, 0))), + new StreamRecord<>( + IterationRecord.newEpochWatermark( + 2, OperatorUtils.getUniqueSenderId(operatorId, 0)))), + new ArrayList<>(harness.getOutput())); + harness.processEvent(EndOfData.INSTANCE, 0); harness.processEvent(EndOfData.INSTANCE, 1); harness.endInput(); @@ -116,24 +132,44 @@ public class TwoInputAllRoundWrapperOperatorTest extends TestLogger { assertEquals( Arrays.asList( + /* First wrapped operator */ LifeCycle.SETUP, LifeCycle.INITIALIZE_STATE, LifeCycle.OPEN, LifeCycle.PROCESS_ELEMENT_1, + /* second wrapped operator */ + LifeCycle.SETUP, + LifeCycle.INITIALIZE_STATE, + LifeCycle.OPEN, LifeCycle.PROCESS_ELEMENT_2, + /* states */ + LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER, LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER, LifeCycle.SNAPSHOT_STATE, + LifeCycle.SNAPSHOT_STATE, + LifeCycle.NOTIFY_CHECKPOINT_COMPLETE, LifeCycle.NOTIFY_CHECKPOINT_COMPLETE, LifeCycle.NOTIFY_CHECKPOINT_ABORT, + LifeCycle.NOTIFY_CHECKPOINT_ABORT, + // The first input + LifeCycle.END_INPUT, + // The second input + LifeCycle.END_INPUT, + LifeCycle.FINISH, + LifeCycle.CLOSE, + // The first input + LifeCycle.END_INPUT, + // The second input + 
LifeCycle.END_INPUT, LifeCycle.FINISH, LifeCycle.CLOSE), - lifeCycles); + LIFE_CYCLES); } } private static class LifeCycleTrackingTwoInputStreamOperator extends AbstractStreamOperator<Integer> - implements TwoInputStreamOperator<Integer, Integer, Integer> { + implements TwoInputStreamOperator<Integer, Integer, Integer>, BoundedMultiInput { @Override public void setup( @@ -141,67 +177,72 @@ public class TwoInputAllRoundWrapperOperatorTest extends TestLogger { StreamConfig config, Output<StreamRecord<Integer>> output) { super.setup(containingTask, config, output); - lifeCycles.add(LifeCycle.SETUP); + LIFE_CYCLES.add(LifeCycle.SETUP); } @Override public void open() throws Exception { super.open(); - lifeCycles.add(LifeCycle.OPEN); + LIFE_CYCLES.add(LifeCycle.OPEN); } @Override public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - lifeCycles.add(LifeCycle.INITIALIZE_STATE); + LIFE_CYCLES.add(LifeCycle.INITIALIZE_STATE); } @Override public void finish() throws Exception { super.finish(); - lifeCycles.add(LifeCycle.FINISH); + LIFE_CYCLES.add(LifeCycle.FINISH); } @Override public void close() throws Exception { super.close(); - lifeCycles.add(LifeCycle.CLOSE); + LIFE_CYCLES.add(LifeCycle.CLOSE); } @Override public void prepareSnapshotPreBarrier(long checkpointId) throws Exception { super.prepareSnapshotPreBarrier(checkpointId); - lifeCycles.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER); + LIFE_CYCLES.add(LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER); } @Override public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); - lifeCycles.add(LifeCycle.SNAPSHOT_STATE); + LIFE_CYCLES.add(LifeCycle.SNAPSHOT_STATE); } @Override public void notifyCheckpointComplete(long checkpointId) throws Exception { super.notifyCheckpointComplete(checkpointId); - lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE); + LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_COMPLETE); } @Override public void 
notifyCheckpointAborted(long checkpointId) throws Exception { super.notifyCheckpointAborted(checkpointId); - lifeCycles.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT); + LIFE_CYCLES.add(LifeCycle.NOTIFY_CHECKPOINT_ABORT); } @Override public void processElement1(StreamRecord<Integer> element) throws Exception { output.collect(element); - lifeCycles.add(LifeCycle.PROCESS_ELEMENT_1); + LIFE_CYCLES.add(LifeCycle.PROCESS_ELEMENT_1); } @Override public void processElement2(StreamRecord<Integer> element) throws Exception { output.collect(element); - lifeCycles.add(LifeCycle.PROCESS_ELEMENT_2); + LIFE_CYCLES.add(LifeCycle.PROCESS_ELEMENT_2); + } + + @Override + public void endInput(int inputId) throws Exception { + LIFE_CYCLES.add(LifeCycle.END_INPUT); } } }
