This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 89468545 [FLINK-31255] OperatorUtils#createWrappedOperatorConfig should update input and sideOutput serializers
89468545 is described below

commit 894685455d1c26fd45198857b7a96ee850725a59
Author: JiangXin <[email protected]>
AuthorDate: Tue Apr 18 13:47:52 2023 +0800

    [FLINK-31255] OperatorUtils#createWrappedOperatorConfig should update input and sideOutput serializers
    
    This closes #229.
---
 .../flink/iteration/operator/OperatorUtils.java    | 65 +++++++++++++++++++--
 flink-ml-tests/pom.xml                             |  7 +++
 .../iteration/UnboundedStreamIterationITCase.java  | 67 ++++++++++++++++++++++
 3 files changed, 135 insertions(+), 4 deletions(-)

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
index b3b629fe..3d67b7fa 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
@@ -18,21 +18,26 @@
 
 package org.apache.flink.iteration.operator;
 
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.iteration.IterationID;
 import org.apache.flink.iteration.config.IterationOptions;
 import org.apache.flink.iteration.proxy.ProxyKeySelector;
+import org.apache.flink.iteration.typeinfo.IterationRecordSerializer;
+import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
 import org.apache.flink.iteration.utils.ReflectionUtils;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackKey;
 import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.graph.StreamConfig.NetworkInputConfig;
 import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.OutputTag;
 import org.apache.flink.util.function.SupplierWithException;
 import org.apache.flink.util.function.ThrowingConsumer;
 
@@ -42,6 +47,7 @@ import java.util.Arrays;
 import java.util.Random;
 import java.util.UUID;
 import java.util.concurrent.Executor;
+import java.util.stream.Stream;
 
 import static org.apache.flink.util.Preconditions.checkState;
 
@@ -89,11 +95,10 @@ public class OperatorUtils {
         }
     }
 
-    public static StreamConfig createWrappedOperatorConfig(
-            StreamConfig wrapperConfig, ClassLoader cl) {
-        StreamConfig wrappedConfig = new 
StreamConfig(wrapperConfig.getConfiguration().clone());
+    public static StreamConfig createWrappedOperatorConfig(StreamConfig 
config, ClassLoader cl) {
+        StreamConfig wrappedConfig = new 
StreamConfig(config.getConfiguration().clone());
         for (int i = 0; i < wrappedConfig.getNumberOfNetworkInputs(); ++i) {
-            KeySelector keySelector = wrapperConfig.getStatePartitioner(i, cl);
+            KeySelector keySelector = config.getStatePartitioner(i, cl);
             if (keySelector != null) {
                 checkState(
                         keySelector instanceof ProxyKeySelector,
@@ -104,6 +109,58 @@ public class OperatorUtils {
             }
         }
 
+        StreamConfig.InputConfig[] inputs = config.getInputs(cl);
+        for (int i = 0; i < inputs.length; ++i) {
+            if (inputs[i] instanceof NetworkInputConfig) {
+                TypeSerializer<?> typeSerializerIn =
+                        ((NetworkInputConfig) inputs[i]).getTypeSerializer();
+                checkState(
+                        typeSerializerIn instanceof IterationRecordSerializer,
+                        "The serializer of input[%s] should be 
IterationRecordSerializer but it is %s.",
+                        i,
+                        typeSerializerIn);
+                inputs[i] =
+                        new NetworkInputConfig(
+                                ((IterationRecordSerializer<?>) 
typeSerializerIn)
+                                        .getInnerSerializer(),
+                                i);
+            }
+        }
+        wrappedConfig.setInputs(inputs);
+
+        TypeSerializer<?> typeSerializerOut = config.getTypeSerializerOut(cl);
+        checkState(
+                typeSerializerOut instanceof IterationRecordSerializer,
+                "The serializer of output should be IterationRecordSerializer 
but it is %s.",
+                typeSerializerOut);
+        wrappedConfig.setTypeSerializerOut(
+                ((IterationRecordSerializer<?>) 
typeSerializerOut).getInnerSerializer());
+
+        Stream.concat(
+                        config.getChainedOutputs(cl).stream(),
+                        config.getNonChainedOutputs(cl).stream())
+                .forEach(
+                        edge -> {
+                            OutputTag<?> outputTag = edge.getOutputTag();
+                            if (outputTag != null) {
+                                TypeSerializer<?> typeSerializerSideOut =
+                                        
config.getTypeSerializerSideOut(outputTag, cl);
+                                checkState(
+                                        typeSerializerSideOut instanceof 
IterationRecordSerializer,
+                                        "The serializer of side output with 
tag[%s] should be IterationRecordSerializer but it is %s.",
+                                        outputTag,
+                                        typeSerializerSideOut);
+                                wrappedConfig.setTypeSerializerSideOut(
+                                        new OutputTag<>(
+                                                outputTag.getId(),
+                                                ((IterationRecordTypeInfo<?>)
+                                                                
outputTag.getTypeInfo())
+                                                        .getInnerTypeInfo()),
+                                        ((IterationRecordSerializer) 
typeSerializerSideOut)
+                                                .getInnerSerializer());
+                            }
+                        });
+
         return wrappedConfig;
     }
 
diff --git a/flink-ml-tests/pom.xml b/flink-ml-tests/pom.xml
index 153f89c6..534b9ca7 100644
--- a/flink-ml-tests/pom.xml
+++ b/flink-ml-tests/pom.xml
@@ -60,5 +60,12 @@ under the License.
             <version>${flink.version}</version>
             <scope>test</scope>
         </dependency>
+
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-clients</artifactId>
+            <version>${flink.version}</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
 </project>
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
index 206ba21b..54e7eefd 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
@@ -18,6 +18,8 @@
 
 package org.apache.flink.test.iteration;
 
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.RestOptions;
@@ -32,6 +34,10 @@ import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
 import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.test.iteration.operators.CollectSink;
 import org.apache.flink.test.iteration.operators.EpochRecord;
 import org.apache.flink.test.iteration.operators.IncrementEpochMap;
@@ -45,11 +51,13 @@ import org.apache.flink.testutils.junit.SharedReference;
 import org.apache.flink.util.OutputTag;
 import org.apache.flink.util.TestLogger;
 
+import org.apache.commons.collections.IteratorUtils;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.BlockingQueue;
@@ -152,6 +160,39 @@ public class UnboundedStreamIterationITCase extends TestLogger {
         assertEquals(OutputRecord.Event.TERMINATED, 
result.get().take().getEvent());
     }
 
+    @Test
+    public void testBoundedIterationWithSideOutput() throws Exception { // Records a body operator emits to a side output must reach the iteration output.
+        StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(1); // single subtask keeps record order deterministic for the ordered assertEquals at the end
+        env.getConfig().enableObjectReuse();
+
+        final OutputTag<Integer> outputTag = new OutputTag("0", Types.INT) {}; // NOTE(review): raw OutputTag (unchecked warning); new OutputTag<Integer>("0", Types.INT) {} would avoid it
+        final Integer[] sourceData = new Integer[] {1, 2, 3};
+
+        DataStream<Integer> variableStream =
+                env.addSource(new 
DraftExecutionEnvironment.EmptySource<Integer>() {});
+        DataStream<Integer> dataStream = env.fromElements(sourceData); // records routed through SideOutputOperator inside the body
+
+        DataStreamList result =
+                Iterations.iterateUnboundedStreams(
+                        DataStreamList.of(variableStream),
+                        DataStreamList.of(dataStream),
+                        (variableStreams, dataStreams) -> { // iteration body
+                            SingleOutputStreamOperator transformed = // NOTE(review): raw type; SingleOutputStreamOperator<Integer> would be type-safe
+                                    dataStreams
+                                            .<Integer>get(0)
+                                            .transform(
+                                                    "side-output",
+                                                    Types.INT,
+                                                    new 
SideOutputOperator(outputTag));
+                            return new IterationBodyResult(
+                                    DataStreamList.of(variableStreams.get(0)), // feedback: the variable stream, unchanged
+                                    
DataStreamList.of(transformed.getSideOutput(outputTag))); // iteration output: the side output of the transform
+                        });
+        assertEquals(
+                Arrays.asList(sourceData), 
IteratorUtils.toList(result.get(0).executeAndCollect())); // every source record must come back, in order, via the side output
+    }
+
     public static MiniClusterConfiguration createMiniClusterConfiguration(int 
numTm, int numSlot) {
         Configuration configuration = new Configuration();
         configuration.set(RestOptions.BIND_PORT, "18081-19091");
@@ -270,6 +311,32 @@ public class UnboundedStreamIterationITCase extends TestLogger {
         return env.getStreamGraph().getJobGraph();
     }
 
+    private static class SideOutputOperator extends 
AbstractStreamOperator<Integer> // Test operator: forwards every input record to a side output.
+            implements OneInputStreamOperator<Integer, Integer> {
+
+        private final OutputTag<Integer> outputTag; // side output that processElement writes to
+
+        public SideOutputOperator(OutputTag<Integer> outputTag) {
+            this.outputTag = outputTag;
+        }
+
+        @Override
+        public void open() throws Exception { // expects plain IntSerializer.INSTANCE for input 0, the main output and the side output
+            super.open();
+            StreamConfig config = getOperatorConfig();
+            ClassLoader cl = getClass().getClassLoader(); // loader used to read serializers back out of the config
+
+            assertEquals(IntSerializer.INSTANCE, config.getTypeSerializerIn(0, 
cl));
+            assertEquals(IntSerializer.INSTANCE, 
config.getTypeSerializerOut(cl));
+            assertEquals(IntSerializer.INSTANCE, 
config.getTypeSerializerSideOut(outputTag, cl));
+        }
+
+        @Override
+        public void processElement(StreamRecord<Integer> element) { // nothing is emitted on the main output
+            output.collect(outputTag, element);
+        }
+    }
+
     static Map<Integer, Tuple2<Integer, Integer>> computeRoundStat(
             BlockingQueue<OutputRecord<Integer>> result,
             OutputRecord.Event event,

Reply via email to