This is an automated email from the ASF dual-hosted git repository. lindong pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit 92ecb0e591f30ff7dc4bd4db027350ad4edf4000 Author: congzhou.zzp <[email protected]> AuthorDate: Tue Apr 18 20:35:27 2023 +0800 [FLINK-31173] Fix wrong typeinfo in ProxyOperatorStateBackend This closes #216. --- .../proxy/state/ProxyOperatorStateBackend.java | 58 ++++++++++++--- .../BoundedPerRoundStreamIterationITCase.java | 87 ++++++++++++++++++++++ 2 files changed, 134 insertions(+), 11 deletions(-) diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java index a6558861..58fc57a5 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java @@ -21,6 +21,11 @@ import org.apache.flink.api.common.state.BroadcastState; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.ListTypeInfo; +import org.apache.flink.api.java.typeutils.MapTypeInfo; +import org.apache.flink.iteration.utils.ReflectionUtils; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.OperatorStateBackend; @@ -50,20 +55,35 @@ public class ProxyOperatorStateBackend implements OperatorStateBackend { @Override public <K, V> BroadcastState<K, V> getBroadcastState(MapStateDescriptor<K, V> stateDescriptor) throws Exception { - MapStateDescriptor<K, V> newDescriptor = - new MapStateDescriptor<>( - stateNamePrefix.prefix(stateDescriptor.getName()), - stateDescriptor.getKeySerializer(), - stateDescriptor.getValueSerializer()); + MapStateDescriptor<K, V> newDescriptor; + if (stateDescriptor.isSerializerInitialized()) { + newDescriptor = + new MapStateDescriptor<>( + stateNamePrefix.prefix(stateDescriptor.getName()), + stateDescriptor.getKeySerializer(), + stateDescriptor.getValueSerializer()); + } else { + MapTypeInfo<K, V> mapTypeInfo = getMapTypeInfo(stateDescriptor); + newDescriptor = + new MapStateDescriptor<>( + stateNamePrefix.prefix(stateDescriptor.getName()), + mapTypeInfo.getKeyTypeInfo(), + mapTypeInfo.getValueTypeInfo()); + } return wrappedBackend.getBroadcastState(newDescriptor); } @Override public <S> ListState<S> getListState(ListStateDescriptor<S> stateDescriptor) throws Exception { ListStateDescriptor<S> newDescriptor = - new ListStateDescriptor<>( - stateNamePrefix.prefix(stateDescriptor.getName()), - stateDescriptor.getElementSerializer()); + stateDescriptor.isSerializerInitialized() + ? new ListStateDescriptor<>( + stateNamePrefix.prefix(stateDescriptor.getName()), + stateDescriptor.getElementSerializer()) + : new ListStateDescriptor<>( + stateNamePrefix.prefix(stateDescriptor.getName()), + getElementTypeInfo(stateDescriptor)); + return wrappedBackend.getListState(newDescriptor); } @@ -71,9 +91,13 @@ public class ProxyOperatorStateBackend implements OperatorStateBackend { public <S> ListState<S> getUnionListState(ListStateDescriptor<S> stateDescriptor) throws Exception { ListStateDescriptor<S> newDescriptor = - new ListStateDescriptor<S>( - stateNamePrefix.prefix(stateDescriptor.getName()), - stateDescriptor.getElementSerializer()); + stateDescriptor.isSerializerInitialized() + ? new ListStateDescriptor<>( + stateNamePrefix.prefix(stateDescriptor.getName()), + stateDescriptor.getElementSerializer()) + : new ListStateDescriptor<>( + stateNamePrefix.prefix(stateDescriptor.getName()), + getElementTypeInfo(stateDescriptor)); return wrappedBackend.getUnionListState(newDescriptor); } @@ -125,4 +149,16 @@ public class ProxyOperatorStateBackend implements OperatorStateBackend { throws Exception { return wrappedBackend.snapshot(checkpointId, timestamp, streamFactory, checkpointOptions); } + + @SuppressWarnings("unchecked,rawtypes") + private <S> TypeInformation<S> getElementTypeInfo(ListStateDescriptor<S> stateDescriptor) { + return ((ListTypeInfo) + ReflectionUtils.getFieldValue( + stateDescriptor, StateDescriptor.class, "typeInfo")) + .getElementTypeInfo(); + } + + private <K, V> MapTypeInfo<K, V> getMapTypeInfo(MapStateDescriptor<K, V> stateDescriptor) { + return ReflectionUtils.getFieldValue(stateDescriptor, StateDescriptor.class, "typeInfo"); + } } diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java index 825aa312..a1b5609f 100644 --- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java +++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java @@ -19,6 +19,10 @@ package org.apache.flink.test.iteration; import org.apache.flink.api.common.functions.JoinFunction; +import org.apache.flink.api.common.state.BroadcastState; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.tuple.Tuple2; @@ -26,12 +30,14 @@ import org.apache.flink.iteration.DataStreamList; import org.apache.flink.iteration.IterationBody; import org.apache.flink.iteration.IterationBodyResult; import org.apache.flink.iteration.IterationConfig; +import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle; import org.apache.flink.iteration.Iterations; import org.apache.flink.iteration.ReplayableDataStreamList; import org.apache.flink.ml.common.datastream.EndOfStreamWindows; import org.apache.flink.ml.common.iteration.TerminateOnMaxIter; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -55,6 +61,8 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; @@ -73,6 +81,7 @@ public class BoundedPerRoundStreamIterationITCase extends TestLogger { private SharedReference<BlockingQueue<OutputRecord<Integer>>> collectedOutputRecord; private SharedReference<BlockingQueue<Long>> collectedWatermarks; + private SharedReference<BlockingQueue<Long>> collectedOutputs; @Before public void setup() throws Exception { @@ -81,6 +90,7 @@ public class BoundedPerRoundStreamIterationITCase extends TestLogger { collectedOutputRecord = sharedObjects.add(new LinkedBlockingQueue<>()); collectedWatermarks = sharedObjects.add(new LinkedBlockingQueue<>()); + collectedOutputs = sharedObjects.add(new LinkedBlockingQueue<>()); } @After @@ -136,6 +146,33 @@ public class BoundedPerRoundStreamIterationITCase extends TestLogger { .forEachRemaining(x -> assertEquals(Long.MAX_VALUE, (long) x)); } + @Test + public void testPerRoundIterationWithState() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + DataStream<Long> broadcastStream = env.fromElements(1L); + DataStream<Long> inputStream = env.fromElements(1L); + DataStreamList outputStream = + Iterations.iterateBoundedStreamsUntilTermination( + DataStreamList.of(inputStream), + ReplayableDataStreamList.replay(broadcastStream), + IterationConfig.newBuilder() + .setOperatorLifeCycle(OperatorLifeCycle.PER_ROUND) + .build(), + new PerRoundIterationBodyWithState()); + + outputStream.<Long>get(0).addSink(new LongSink(collectedOutputs)); + JobGraph jobGraph = env.getStreamGraph().getJobGraph(); + miniCluster.executeJobBlocking(jobGraph); + + List<Long> result = new ArrayList<>(3); + collectedOutputs.get().drainTo(result); + assertEquals(3, result.size()); + for (long value : result) { + assertEquals(1L, value); + } + } + private static JobGraph createPerRoundJobGraph( int numSources, int numRecordsPerSource, @@ -229,6 +266,56 @@ public class BoundedPerRoundStreamIterationITCase extends TestLogger { } } + private static class PerRoundIterationBodyWithState implements IterationBody { + + @Override + public IterationBodyResult process( + DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream<Long> variableStream = variableStreams.get(0); + + DataStream<Long> feedback = + variableStream.transform("mapWithState", Types.LONG, new MapWithState()); + + DataStream<Integer> terminationCriteria = + feedback.<Long>flatMap(new TerminateOnMaxIter(2)).returns(Types.INT); + + return new IterationBodyResult( + DataStreamList.of(feedback), DataStreamList.of(feedback), terminationCriteria); + } + } + + private static class MapWithState extends AbstractStreamOperator<Long> + implements OneInputStreamOperator<Long, Long> { + private ListState<Long> listState; + private ListState<Long> unionState; + private BroadcastState<Long, Long> broadcastState; + + @Override + public void processElement(StreamRecord<Long> element) throws Exception { + long val = element.getValue(); + listState.add(val); + unionState.add(val); + broadcastState.put(val, val); + output.collect(element); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + listState = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>("longState", Types.LONG)); + unionState = + context.getOperatorStateStore() + .getUnionListState(new ListStateDescriptor<>("unionState", Types.LONG)); + broadcastState = + context.getOperatorStateStore() + .getBroadcastState( + new MapStateDescriptor<>( + "broadcastState", Types.LONG, Types.LONG)); + } + } + private static class LongSink implements SinkFunction<Long> { private final SharedReference<BlockingQueue<Long>> collectedLong;
