This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 89468545 [FLINK-31255] OperatorUtils#createWrappedOperatorConfig should update input and sideOutput serializers
89468545 is described below
commit 894685455d1c26fd45198857b7a96ee850725a59
Author: JiangXin <[email protected]>
AuthorDate: Tue Apr 18 13:47:52 2023 +0800
[FLINK-31255] OperatorUtils#createWrappedOperatorConfig should update input and sideOutput serializers
This closes #229.
---
.../flink/iteration/operator/OperatorUtils.java | 65 +++++++++++++++++++--
flink-ml-tests/pom.xml | 7 +++
.../iteration/UnboundedStreamIterationITCase.java | 67 ++++++++++++++++++++++
3 files changed, 135 insertions(+), 4 deletions(-)
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
index b3b629fe..3d67b7fa 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
@@ -18,21 +18,26 @@
package org.apache.flink.iteration.operator;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.apache.flink.iteration.IterationID;
import org.apache.flink.iteration.config.IterationOptions;
import org.apache.flink.iteration.proxy.ProxyKeySelector;
+import org.apache.flink.iteration.typeinfo.IterationRecordSerializer;
+import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
import org.apache.flink.iteration.utils.ReflectionUtils;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
import org.apache.flink.statefun.flink.core.feedback.FeedbackKey;
import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.graph.StreamConfig.NetworkInputConfig;
import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.OutputTag;
import org.apache.flink.util.function.SupplierWithException;
import org.apache.flink.util.function.ThrowingConsumer;
@@ -42,6 +47,7 @@ import java.util.Arrays;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.Executor;
+import java.util.stream.Stream;
import static org.apache.flink.util.Preconditions.checkState;
@@ -89,11 +95,10 @@ public class OperatorUtils {
}
}
- public static StreamConfig createWrappedOperatorConfig(
- StreamConfig wrapperConfig, ClassLoader cl) {
- StreamConfig wrappedConfig = new
StreamConfig(wrapperConfig.getConfiguration().clone());
+ public static StreamConfig createWrappedOperatorConfig(StreamConfig
config, ClassLoader cl) { /* Clones the given config and replaces every IterationRecordSerializer found on network inputs, the main output and side outputs with its inner serializer, so the wrapped operator is configured with the un-wrapped record serializers. */
+ StreamConfig wrappedConfig = new
StreamConfig(config.getConfiguration().clone());
for (int i = 0; i < wrappedConfig.getNumberOfNetworkInputs(); ++i) {
- KeySelector keySelector = wrapperConfig.getStatePartitioner(i, cl);
+ KeySelector keySelector = config.getStatePartitioner(i, cl);
if (keySelector != null) {
checkState(
keySelector instanceof ProxyKeySelector,
@@ -104,6 +109,58 @@ public class OperatorUtils {
}
}
+ StreamConfig.InputConfig[] inputs = config.getInputs(cl); /* Only NetworkInputConfig entries are rewritten below; any other input config kind is written back unchanged. */
+ for (int i = 0; i < inputs.length; ++i) {
+ if (inputs[i] instanceof NetworkInputConfig) {
+ TypeSerializer<?> typeSerializerIn =
+ ((NetworkInputConfig) inputs[i]).getTypeSerializer();
+ checkState(
+ typeSerializerIn instanceof IterationRecordSerializer,
+ "The serializer of input[%s] should be
IterationRecordSerializer but it is %s.",
+ i,
+ typeSerializerIn);
+ inputs[i] =
+ new NetworkInputConfig(
+ ((IterationRecordSerializer<?>)
typeSerializerIn)
+ .getInnerSerializer(),
+ i); /* NOTE(review): the loop index is used as the input gate index and the original input requirement is not copied; reusing getInputGateIndex()/getInputRequirement() of the old NetworkInputConfig would preserve them. Confirm they can never differ here. */
+ }
+ }
+ wrappedConfig.setInputs(inputs);
+
+ TypeSerializer<?> typeSerializerOut = config.getTypeSerializerOut(cl); /* The main output is required to be IterationRecordSerializer as well. */
+ checkState(
+ typeSerializerOut instanceof IterationRecordSerializer,
+ "The serializer of output should be IterationRecordSerializer
but it is %s.",
+ typeSerializerOut);
+ wrappedConfig.setTypeSerializerOut(
+ ((IterationRecordSerializer<?>)
typeSerializerOut).getInnerSerializer());
+
+ Stream.concat( /* Side-output serializers are looked up via the OutputTag of every chained and non-chained output edge. */
+ config.getChainedOutputs(cl).stream(),
+ config.getNonChainedOutputs(cl).stream())
+ .forEach(
+ edge -> {
+ OutputTag<?> outputTag = edge.getOutputTag();
+ if (outputTag != null) { /* edges without an OutputTag (presumably main-output edges) are skipped */
+ TypeSerializer<?> typeSerializerSideOut =
+
config.getTypeSerializerSideOut(outputTag, cl);
+ checkState(
+ typeSerializerSideOut instanceof
IterationRecordSerializer,
+ "The serializer of side output with
tag[%s] should be IterationRecordSerializer but it is %s.",
+ outputTag,
+ typeSerializerSideOut);
+ wrappedConfig.setTypeSerializerSideOut(
+ new OutputTag<>(
+ outputTag.getId(),
+ ((IterationRecordTypeInfo<?>) /* NOTE(review): unchecked cast not covered by the checkState above; a tag whose type info is not IterationRecordTypeInfo fails with ClassCastException. */
+
outputTag.getTypeInfo())
+ .getInnerTypeInfo()),
+ ((IterationRecordSerializer) /* NOTE(review): raw type; IterationRecordSerializer<?> would match the casts used for inputs and the main output. */
typeSerializerSideOut)
+ .getInnerSerializer());
+ }
+ });
+
return wrappedConfig;
}
diff --git a/flink-ml-tests/pom.xml b/flink-ml-tests/pom.xml
index 153f89c6..534b9ca7 100644
--- a/flink-ml-tests/pom.xml
+++ b/flink-ml-tests/pom.xml
@@ -60,5 +60,12 @@ under the License.
<version>${flink.version}</version>
<scope>test</scope>
</dependency>
+
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-clients</artifactId>
+ <version>${flink.version}</version>
+ <scope>test</scope>
+ </dependency>
</dependencies>
</project>
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
index 206ba21b..54e7eefd 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
@@ -18,6 +18,8 @@
package org.apache.flink.test.iteration;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.RestOptions;
@@ -32,6 +34,10 @@ import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.test.iteration.operators.CollectSink;
import org.apache.flink.test.iteration.operators.EpochRecord;
import org.apache.flink.test.iteration.operators.IncrementEpochMap;
@@ -45,11 +51,13 @@ import org.apache.flink.testutils.junit.SharedReference;
import org.apache.flink.util.OutputTag;
import org.apache.flink.util.TestLogger;
+import org.apache.commons.collections.IteratorUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
@@ -152,6 +160,39 @@ public class UnboundedStreamIterationITCase extends TestLogger {
assertEquals(OutputRecord.Event.TERMINATED,
result.get().take().getEvent());
}
+ @Test
+ public void testBoundedIterationWithSideOutput() throws Exception { /* Runs an iteration whose body emits records only through a side output and checks that the collected side output equals the source data; the serializer checks live in SideOutputOperator.open(). NOTE(review): the name says Bounded but the test uses Iterations.iterateUnboundedStreams. */
+ StreamExecutionEnvironment env =
StreamExecutionEnvironment.getExecutionEnvironment();
+ env.setParallelism(1);
+ env.getConfig().enableObjectReuse();
+
+ final OutputTag<Integer> outputTag = new OutputTag("0", Types.INT) {}; /* NOTE(review): raw OutputTag; new OutputTag<Integer>("0", Types.INT) {} avoids the unchecked conversion. */
+ final Integer[] sourceData = new Integer[] {1, 2, 3};
+
+ DataStream<Integer> variableStream =
+ env.addSource(new
DraftExecutionEnvironment.EmptySource<Integer>() {});
+ DataStream<Integer> dataStream = env.fromElements(sourceData);
+
+ DataStreamList result =
+ Iterations.iterateUnboundedStreams(
+ DataStreamList.of(variableStream),
+ DataStreamList.of(dataStream),
+ (variableStreams, dataStreams) -> {
+ SingleOutputStreamOperator transformed = /* NOTE(review): raw type; SingleOutputStreamOperator<Integer> would be type-safe. */
+ dataStreams
+ .<Integer>get(0)
+ .transform(
+ "side-output",
+ Types.INT,
+ new
SideOutputOperator(outputTag));
+ return new IterationBodyResult(
+ DataStreamList.of(variableStreams.get(0)),
+
DataStreamList.of(transformed.getSideOutput(outputTag)));
+ });
+ assertEquals(
+ Arrays.asList(sourceData),
IteratorUtils.toList(result.get(0).executeAndCollect()));
+ }
+
public static MiniClusterConfiguration createMiniClusterConfiguration(int
numTm, int numSlot) {
Configuration configuration = new Configuration();
configuration.set(RestOptions.BIND_PORT, "18081-19091");
@@ -270,6 +311,32 @@ public class UnboundedStreamIterationITCase extends TestLogger {
return env.getStreamGraph().getJobGraph();
}
+ private static class SideOutputOperator extends
AbstractStreamOperator<Integer>
+ implements OneInputStreamOperator<Integer, Integer> { /* Test operator that forwards every input record to the given side output and, in open(), asserts that its StreamConfig uses IntSerializer.INSTANCE for input 0, the main output and the side output. */
+
+ private final OutputTag<Integer> outputTag;
+
+ public SideOutputOperator(OutputTag<Integer> outputTag) {
+ this.outputTag = outputTag;
+ }
+
+ @Override
+ public void open() throws Exception { /* Verifies the serializers seen by the operator inside the iteration body are the plain element serializers. */
+ super.open();
+ StreamConfig config = getOperatorConfig();
+ ClassLoader cl = getClass().getClassLoader();
+
+ assertEquals(IntSerializer.INSTANCE, config.getTypeSerializerIn(0,
cl));
+ assertEquals(IntSerializer.INSTANCE,
config.getTypeSerializerOut(cl));
+ assertEquals(IntSerializer.INSTANCE,
config.getTypeSerializerSideOut(outputTag, cl));
+ }
+
+ @Override
+ public void processElement(StreamRecord<Integer> element) {
+ output.collect(outputTag, element); /* only the side output is written; nothing is emitted on the main output */
+ }
+ }
+
static Map<Integer, Tuple2<Integer, Integer>> computeRoundStat(
BlockingQueue<OutputRecord<Integer>> result,
OutputRecord.Event event,