This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch release-1.17
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.17 by this push:
     new 7040af5b793 [FLINK-31185][python] Support side-output in broadcast processing
7040af5b793 is described below

commit 7040af5b7933905798ff6af0b35ac364b5fbe432
Author: Juntao Hu <[email protected]>
AuthorDate: Thu Feb 23 15:47:42 2023 +0800

    [FLINK-31185][python] Support side-output in broadcast processing
    
    This closes #22003.
---
 .../pyflink/datastream/tests/test_data_stream.py   |  77 +++++++++++++
 .../fn_execution/datastream/embedded/operations.py |  25 ++--
 .../chain/PythonOperatorChainingOptimizer.java     |   7 ++
 .../apache/flink/python/util/PythonConfigUtil.java |   9 ++
 .../python/DelegateOperatorTransformation.java     | 128 +++++++++++++++++++++
 .../python/PythonBroadcastStateTransformation.java |  10 +-
 .../PythonKeyedBroadcastStateTransformation.java   |  11 +-
 ...thonBroadcastStateTransformationTranslator.java |  11 +-
 ...eyedBroadcastStateTransformationTranslator.java |  11 +-
 9 files changed, 269 insertions(+), 20 deletions(-)

diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py
index 3b7bb626b1a..8d4e0bea1b8 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -589,6 +589,40 @@ class DataStreamTests(object):
         side_expected = ['0', '0', '1', '1', '2', '3']
         self.assert_equals_sorted(side_expected, side_sink.get_results())
 
+    def test_co_broadcast_side_output(self):
+        tag = OutputTag("side", Types.INT())
+
+        class MyBroadcastProcessFunction(BroadcastProcessFunction):
+
+            def process_element(self, value, ctx):
+                yield value[0]
+                yield tag, value[1]
+
+            def process_broadcast_element(self, value, ctx):
+                yield value[1]
+                yield tag, value[0]
+
+        self.env.set_parallelism(2)
+        ds = self.env.from_collection([('a', 0), ('b', 1), ('c', 2)],
+                                      type_info=Types.ROW([Types.STRING(), Types.INT()]))
+        ds_broadcast = self.env.from_collection([(3, 'd'), (4, 'f')],
+                                                type_info=Types.ROW([Types.INT(), Types.STRING()]))
+        map_state_desc = MapStateDescriptor(
+            "dummy", key_type_info=Types.INT(), value_type_info=Types.STRING()
+        )
+        ds = ds.connect(ds_broadcast.broadcast(map_state_desc)).process(
+            MyBroadcastProcessFunction(), output_type=Types.STRING()
+        )
+        side_sink = DataStreamTestSinkFunction()
+        ds.get_side_output(tag).add_sink(side_sink)
+        ds.add_sink(self.test_sink)
+
+        self.env.execute("test_co_broadcast_process_side_output")
+        main_expected = ['a', 'b', 'c', 'd', 'd', 'f', 'f']
+        self.assert_equals_sorted(main_expected, self.test_sink.get_results())
+        side_expected = ['0', '1', '2', '3', '3', '4', '4']
+        self.assert_equals_sorted(side_expected, side_sink.get_results())
+
     def test_keyed_process_side_output(self):
         tag = OutputTag("side", Types.INT())
 
@@ -665,6 +699,49 @@ class DataStreamTests(object):
         side_expected = ['1', '1', '2', '2', '3', '3', '4', '4']
         self.assert_equals_sorted(side_expected, side_sink.get_results())
 
+    def test_keyed_co_broadcast_side_output(self):
+        tag = OutputTag("side", Types.INT())
+
+        class MyKeyedBroadcastProcessFunction(KeyedBroadcastProcessFunction):
+
+            def __init__(self):
+                self.reducing_state = None  # type: ReducingState
+
+            def open(self, context: RuntimeContext):
+                self.reducing_state = context.get_reducing_state(
+                    ReducingStateDescriptor("reduce", lambda i, j: i+j, Types.INT())
+                )
+
+            def process_element(self, value, ctx):
+                self.reducing_state.add(value[1])
+                yield value[0]
+                yield tag, self.reducing_state.get()
+
+            def process_broadcast_element(self, value, ctx):
+                yield value[1]
+                yield tag, value[0]
+
+        self.env.set_parallelism(2)
+        ds = self.env.from_collection([('a', 0), ('b', 1), ('a', 2), ('b', 3)],
+                                      type_info=Types.ROW([Types.STRING(), Types.INT()]))
+        ds_broadcast = self.env.from_collection([(5, 'c'), (6, 'd')],
+                                                type_info=Types.ROW([Types.INT(), Types.STRING()]))
+        map_state_desc = MapStateDescriptor(
+            "dummy", key_type_info=Types.INT(), value_type_info=Types.STRING()
+        )
+        ds = ds.key_by(lambda e: e[0]).connect(ds_broadcast.broadcast(map_state_desc)).process(
+            MyKeyedBroadcastProcessFunction(), output_type=Types.STRING()
+        )
+        side_sink = DataStreamTestSinkFunction()
+        ds.get_side_output(tag).add_sink(side_sink)
+        ds.add_sink(self.test_sink)
+
+        self.env.execute("test_keyed_co_broadcast_process_side_output")
+        main_expected = ['a', 'a', 'b', 'b', 'c', 'c', 'd', 'd']
+        self.assert_equals_sorted(main_expected, self.test_sink.get_results())
+        side_expected = ['0', '1', '2', '4', '5', '5', '6', '6']
+        self.assert_equals_sorted(side_expected, side_sink.get_results())
+
     def test_side_output_stream_execute_and_collect(self):
         tag = OutputTag("side", Types.INT())
 
diff --git a/flink-python/pyflink/fn_execution/datastream/embedded/operations.py b/flink-python/pyflink/fn_execution/datastream/embedded/operations.py
index 7bac48cdf10..5160bc9de26 100644
--- a/flink-python/pyflink/fn_execution/datastream/embedded/operations.py
+++ b/flink-python/pyflink/fn_execution/datastream/embedded/operations.py
@@ -100,6 +100,8 @@ def extract_process_function(
         side_output_context = SideOutputContext(j_side_output_context)
 
         def process_func(values):
+            if values is None:
+                return
             for value in values:
                 if isinstance(value, tuple) and isinstance(value[0], OutputTag):
                     output_tag = value[0]  # type: OutputTag
@@ -108,6 +110,8 @@ def extract_process_function(
                     yield value
     else:
         def process_func(values):
+            if values is None:
+                return
             yield from values
 
     def open_func():
@@ -174,14 +178,10 @@ def extract_process_function(
         process_broadcast_element = user_defined_func.process_broadcast_element
 
         def process_element_func1(value):
-            elements = process_element(value, read_only_broadcast_ctx)
-            if elements:
-                yield from elements
+            yield from process_func(process_element(value, read_only_broadcast_ctx))
 
         def process_element_func2(value):
-            elements = process_broadcast_element(value, broadcast_ctx)
-            if elements:
-                yield from elements
+            yield from process_func(process_broadcast_element(value, broadcast_ctx))
 
         return TwoInputOperation(
             open_func, close_func, process_element_func1, process_element_func2)
@@ -221,19 +221,20 @@ def extract_process_function(
         timer_context = InternalKeyedBroadcastProcessFunctionOnTimerContext(
             j_timer_context, user_defined_function_proto.key_type_info, j_operator_state_backend)
 
+        keyed_state_backend = KeyedStateBackend(
+            read_only_broadcast_ctx,
+            j_keyed_state_backend)
+        runtime_context.set_keyed_state_backend(keyed_state_backend)
+
         process_element = user_defined_func.process_element
         process_broadcast_element = user_defined_func.process_broadcast_element
         on_timer = user_defined_func.on_timer
 
         def process_element_func1(value):
-            elements = process_element(value[1], read_only_broadcast_ctx)
-            if elements:
-                yield from elements
+            yield from process_func(process_element(value[1], read_only_broadcast_ctx))
 
         def process_element_func2(value):
-            elements = process_broadcast_element(value, broadcast_ctx)
-            if elements:
-                yield from elements
+            yield from process_func(process_broadcast_element(value, broadcast_ctx))
 
         def on_timer_func(timestamp):
             yield from on_timer(timestamp, timer_context)
diff --git a/flink-python/src/main/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizer.java b/flink-python/src/main/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizer.java
index 79d6ca8d213..fbbc63420f5 100644
--- a/flink-python/src/main/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizer.java
+++ b/flink-python/src/main/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizer.java
@@ -56,6 +56,8 @@ import 
org.apache.flink.streaming.api.transformations.SinkTransformation;
 import 
org.apache.flink.streaming.api.transformations.TimestampsAndWatermarksTransformation;
 import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
 import org.apache.flink.streaming.api.transformations.UnionTransformation;
+import 
org.apache.flink.streaming.api.transformations.python.PythonBroadcastStateTransformation;
+import 
org.apache.flink.streaming.api.transformations.python.PythonKeyedBroadcastStateTransformation;
 import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
 
 import org.apache.flink.shaded.guava30.com.google.common.collect.Lists;
@@ -409,6 +411,11 @@ public class PythonOperatorChainingOptimizer {
             return false;
         }
 
+        if (upTransform instanceof PythonBroadcastStateTransformation
+                || upTransform instanceof 
PythonKeyedBroadcastStateTransformation) {
+            return false;
+        }
+
         DataStreamPythonFunctionOperator<?> upOperator =
                 (DataStreamPythonFunctionOperator<?>)
                         ((SimpleOperatorFactory<?>) 
getOperatorFactory(upTransform)).getOperator();
diff --git 
a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java 
b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
index 8f942818d64..4666c5df790 100644
--- 
a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
+++ 
b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
@@ -42,6 +42,7 @@ import 
org.apache.flink.streaming.api.transformations.OneInputTransformation;
 import org.apache.flink.streaming.api.transformations.PartitionTransformation;
 import org.apache.flink.streaming.api.transformations.SideOutputTransformation;
 import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
+import 
org.apache.flink.streaming.api.transformations.python.DelegateOperatorTransformation;
 import 
org.apache.flink.streaming.api.transformations.python.PythonBroadcastStateTransformation;
 import 
org.apache.flink.streaming.api.transformations.python.PythonKeyedBroadcastStateTransformation;
 import org.apache.flink.streaming.api.utils.ByteArrayWrapper;
@@ -152,6 +153,8 @@ public class PythonConfigUtil {
             return ((TwoInputTransformation<?, ?, ?>) 
transform).getOperatorFactory();
         } else if (transform instanceof AbstractMultipleInputTransformation) {
             return ((AbstractMultipleInputTransformation<?>) 
transform).getOperatorFactory();
+        } else if (transform instanceof DelegateOperatorTransformation<?>) {
+            return ((DelegateOperatorTransformation<?>) 
transform).getOperatorFactory();
         } else {
             return null;
         }
@@ -214,6 +217,9 @@ public class PythonConfigUtil {
         } else if (transformation instanceof 
AbstractMultipleInputTransformation) {
             operatorFactory =
                     ((AbstractMultipleInputTransformation<?>) 
transformation).getOperatorFactory();
+        } else if (transformation instanceof DelegateOperatorTransformation) {
+            operatorFactory =
+                    ((DelegateOperatorTransformation<?>) 
transformation).getOperatorFactory();
         }
 
         if (operatorFactory instanceof SimpleOperatorFactory
@@ -260,6 +266,9 @@ public class PythonConfigUtil {
         } else if (transform instanceof TwoInputTransformation) {
             return isPythonDataStreamOperator(
                     ((TwoInputTransformation<?, ?, ?>) 
transform).getOperatorFactory());
+        } else if (transform instanceof PythonBroadcastStateTransformation
+                || transform instanceof 
PythonKeyedBroadcastStateTransformation) {
+            return true;
         } else {
             return false;
         }
diff --git 
a/flink-python/src/main/java/org/apache/flink/streaming/api/transformations/python/DelegateOperatorTransformation.java
 
b/flink-python/src/main/java/org/apache/flink/streaming/api/transformations/python/DelegateOperatorTransformation.java
new file mode 100644
index 00000000000..40623cb7e11
--- /dev/null
+++ 
b/flink-python/src/main/java/org/apache/flink/streaming/api/transformations/python/DelegateOperatorTransformation.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.transformations.python;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.python.env.PythonEnvironmentManager;
+import 
org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
+import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import 
org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator;
+import 
org.apache.flink.streaming.api.operators.python.DataStreamPythonFunctionOperator;
+import org.apache.flink.util.OutputTag;
+
+import javax.annotation.Nullable;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * For those {@link org.apache.flink.api.dag.Transformation} that don't have an operator entity,
+ * {@link DelegateOperatorTransformation} provides a {@link SimpleOperatorFactory} containing a
+ * {@link DelegateOperator} , which can hold special configurations during transformation
+ * preprocessing for Python jobs, and later be queried at translation stage. Currently, those
+ * configurations include {@link OutputTag}s, {@code numPartitions} and general {@link
+ * Configuration}.
+ */
+public interface DelegateOperatorTransformation<OUT> {
+
+    SimpleOperatorFactory<OUT> getOperatorFactory();
+
+    static void configureOperator(
+            DelegateOperatorTransformation<?> transformation,
+            AbstractPythonFunctionOperator<?> operator) {
+        DelegateOperator<?> delegateOperator =
+                (DelegateOperator<?>) 
transformation.getOperatorFactory().getOperator();
+
+        
operator.getConfiguration().addAll(delegateOperator.getConfiguration());
+
+        if (operator instanceof DataStreamPythonFunctionOperator) {
+            DataStreamPythonFunctionOperator<?> dataStreamOperator =
+                    (DataStreamPythonFunctionOperator<?>) operator;
+            
dataStreamOperator.addSideOutputTags(delegateOperator.getSideOutputTags());
+            if (delegateOperator.getNumPartitions() != null) {
+                
dataStreamOperator.setNumPartitions(delegateOperator.getNumPartitions());
+            }
+        }
+    }
+
+    /**
+     * {@link DelegateOperator} holds configurations, e.g. {@link OutputTag}s, which will be applied
+     * to the actual python operator at translation stage.
+     */
+    class DelegateOperator<OUT> extends AbstractPythonFunctionOperator<OUT>
+            implements DataStreamPythonFunctionOperator<OUT> {
+
+        private final Map<String, OutputTag<?>> sideOutputTags = new 
HashMap<>();
+        private @Nullable Integer numPartitions = null;
+
+        public DelegateOperator() {
+            super(new Configuration());
+        }
+
+        @Override
+        public void addSideOutputTags(Collection<OutputTag<?>> outputTags) {
+            for (OutputTag<?> outputTag : outputTags) {
+                sideOutputTags.put(outputTag.getId(), outputTag);
+            }
+        }
+
+        @Override
+        public Collection<OutputTag<?>> getSideOutputTags() {
+            return sideOutputTags.values();
+        }
+
+        @Override
+        public void setNumPartitions(int numPartitions) {
+            this.numPartitions = numPartitions;
+        }
+
+        @Nullable
+        public Integer getNumPartitions() {
+            return numPartitions;
+        }
+
+        @Override
+        public TypeInformation<OUT> getProducedType() {
+            throw new RuntimeException("This should not be invoked on a 
DelegateOperator!");
+        }
+
+        @Override
+        public DataStreamPythonFunctionInfo getPythonFunctionInfo() {
+            throw new RuntimeException("This should not be invoked on a 
DelegateOperator!");
+        }
+
+        @Override
+        public <T> DataStreamPythonFunctionOperator<T> copy(
+                DataStreamPythonFunctionInfo pythonFunctionInfo,
+                TypeInformation<T> outputTypeInfo) {
+            throw new RuntimeException("This should not be invoked on a 
DelegateOperator!");
+        }
+
+        @Override
+        protected void invokeFinishBundle() throws Exception {
+            throw new RuntimeException("This should not be invoked on a 
DelegateOperator!");
+        }
+
+        @Override
+        protected PythonEnvironmentManager createPythonEnvironmentManager() {
+            throw new RuntimeException("This should not be invoked on a 
DelegateOperator!");
+        }
+    }
+}
diff --git 
a/flink-python/src/main/java/org/apache/flink/streaming/api/transformations/python/PythonBroadcastStateTransformation.java
 
b/flink-python/src/main/java/org/apache/flink/streaming/api/transformations/python/PythonBroadcastStateTransformation.java
index b6c2be777a1..6b2544e21e0 100644
--- 
a/flink-python/src/main/java/org/apache/flink/streaming/api/transformations/python/PythonBroadcastStateTransformation.java
+++ 
b/flink-python/src/main/java/org/apache/flink/streaming/api/transformations/python/PythonBroadcastStateTransformation.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.dag.Transformation;
 import org.apache.flink.configuration.Configuration;
 import 
org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
+import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
 import 
org.apache.flink.streaming.api.transformations.AbstractBroadcastStateTransformation;
 
 import java.util.List;
@@ -34,10 +35,12 @@ import java.util.List;
  */
 @Internal
 public class PythonBroadcastStateTransformation<IN1, IN2, OUT>
-        extends AbstractBroadcastStateTransformation<IN1, IN2, OUT> {
+        extends AbstractBroadcastStateTransformation<IN1, IN2, OUT>
+        implements DelegateOperatorTransformation<OUT> {
 
     private final Configuration configuration;
     private final DataStreamPythonFunctionInfo dataStreamPythonFunctionInfo;
+    private final SimpleOperatorFactory<OUT> delegateOperatorFactory;
 
     public PythonBroadcastStateTransformation(
             String name,
@@ -57,6 +60,7 @@ public class PythonBroadcastStateTransformation<IN1, IN2, OUT>
                 parallelism);
         this.configuration = configuration;
         this.dataStreamPythonFunctionInfo = dataStreamPythonFunctionInfo;
+        this.delegateOperatorFactory = SimpleOperatorFactory.of(new 
DelegateOperator<>());
         updateManagedMemoryStateBackendUseCase(false);
     }
 
@@ -67,4 +71,8 @@ public class PythonBroadcastStateTransformation<IN1, IN2, OUT>
     public DataStreamPythonFunctionInfo getDataStreamPythonFunctionInfo() {
         return dataStreamPythonFunctionInfo;
     }
+
+    public SimpleOperatorFactory<OUT> getOperatorFactory() {
+        return delegateOperatorFactory;
+    }
 }
diff --git 
a/flink-python/src/main/java/org/apache/flink/streaming/api/transformations/python/PythonKeyedBroadcastStateTransformation.java
 
b/flink-python/src/main/java/org/apache/flink/streaming/api/transformations/python/PythonKeyedBroadcastStateTransformation.java
index 03919f609eb..72341622243 100644
--- 
a/flink-python/src/main/java/org/apache/flink/streaming/api/transformations/python/PythonKeyedBroadcastStateTransformation.java
+++ 
b/flink-python/src/main/java/org/apache/flink/streaming/api/transformations/python/PythonKeyedBroadcastStateTransformation.java
@@ -24,6 +24,7 @@ import org.apache.flink.api.dag.Transformation;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.Configuration;
 import 
org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
+import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
 import 
org.apache.flink.streaming.api.transformations.AbstractBroadcastStateTransformation;
 import org.apache.flink.types.Row;
 
@@ -36,12 +37,14 @@ import java.util.List;
  */
 @Internal
 public class PythonKeyedBroadcastStateTransformation<OUT>
-        extends AbstractBroadcastStateTransformation<Row, Row, OUT> {
+        extends AbstractBroadcastStateTransformation<Row, Row, OUT>
+        implements DelegateOperatorTransformation<OUT> {
 
     private final Configuration configuration;
     private final DataStreamPythonFunctionInfo dataStreamPythonFunctionInfo;
     private final TypeInformation<Row> stateKeyType;
     private final KeySelector<Row, Row> keySelector;
+    private final SimpleOperatorFactory<OUT> delegateOperatorFactory;
 
     public PythonKeyedBroadcastStateTransformation(
             String name,
@@ -65,6 +68,7 @@ public class PythonKeyedBroadcastStateTransformation<OUT>
         this.dataStreamPythonFunctionInfo = dataStreamPythonFunctionInfo;
         this.stateKeyType = keyType;
         this.keySelector = keySelector;
+        this.delegateOperatorFactory = SimpleOperatorFactory.of(new 
DelegateOperator<>());
         updateManagedMemoryStateBackendUseCase(true);
     }
 
@@ -83,4 +87,9 @@ public class PythonKeyedBroadcastStateTransformation<OUT>
     public KeySelector<Row, Row> getKeySelector() {
         return keySelector;
     }
+
+    @Override
+    public SimpleOperatorFactory<OUT> getOperatorFactory() {
+        return delegateOperatorFactory;
+    }
 }
diff --git 
a/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonBroadcastStateTransformationTranslator.java
 
b/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonBroadcastStateTransformationTranslator.java
index 32d9a259754..6bd777c0afa 100644
--- 
a/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonBroadcastStateTransformationTranslator.java
+++ 
b/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonBroadcastStateTransformationTranslator.java
@@ -21,11 +21,12 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.python.PythonOptions;
 import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
-import org.apache.flink.streaming.api.operators.StreamOperator;
+import 
org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator;
 import 
org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonBatchCoBroadcastProcessOperator;
 import 
org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonCoProcessOperator;
 import 
org.apache.flink.streaming.api.operators.python.process.ExternalPythonBatchCoBroadcastProcessOperator;
 import 
org.apache.flink.streaming.api.operators.python.process.ExternalPythonCoProcessOperator;
+import 
org.apache.flink.streaming.api.transformations.python.DelegateOperatorTransformation;
 import 
org.apache.flink.streaming.api.transformations.python.PythonBroadcastStateTransformation;
 import 
org.apache.flink.streaming.runtime.translators.AbstractTwoInputTransformationTranslator;
 import org.apache.flink.util.Preconditions;
@@ -52,7 +53,7 @@ public class 
PythonBroadcastStateTransformationTranslator<IN1, IN2, OUT>
 
         Configuration config = transformation.getConfiguration();
 
-        StreamOperator<OUT> operator;
+        AbstractPythonFunctionOperator<OUT> operator;
 
         if (config.get(PythonOptions.PYTHON_EXECUTION_MODE).equals("thread")) {
             operator =
@@ -72,6 +73,8 @@ public class 
PythonBroadcastStateTransformationTranslator<IN1, IN2, OUT>
                             transformation.getOutputType());
         }
 
+        DelegateOperatorTransformation.configureOperator(transformation, 
operator);
+
         return translateInternal(
                 transformation,
                 transformation.getRegularInput(),
@@ -91,7 +94,7 @@ public class 
PythonBroadcastStateTransformationTranslator<IN1, IN2, OUT>
 
         Configuration config = transformation.getConfiguration();
 
-        StreamOperator<OUT> operator;
+        AbstractPythonFunctionOperator<OUT> operator;
 
         if (config.get(PythonOptions.PYTHON_EXECUTION_MODE).equals("thread")) {
             operator =
@@ -112,6 +115,8 @@ public class 
PythonBroadcastStateTransformationTranslator<IN1, IN2, OUT>
                             transformation.getOutputType());
         }
 
+        DelegateOperatorTransformation.configureOperator(transformation, 
operator);
+
         return translateInternal(
                 transformation,
                 transformation.getRegularInput(),
diff --git 
a/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonKeyedBroadcastStateTransformationTranslator.java
 
b/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonKeyedBroadcastStateTransformationTranslator.java
index cdbf89c1420..9fac56d246f 100644
--- 
a/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonKeyedBroadcastStateTransformationTranslator.java
+++ 
b/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonKeyedBroadcastStateTransformationTranslator.java
@@ -21,11 +21,12 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.python.PythonOptions;
 import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
-import org.apache.flink.streaming.api.operators.StreamOperator;
+import 
org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator;
 import 
org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonBatchKeyedCoBroadcastProcessOperator;
 import 
org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonKeyedCoProcessOperator;
 import 
org.apache.flink.streaming.api.operators.python.process.ExternalPythonBatchKeyedCoBroadcastProcessOperator;
 import 
org.apache.flink.streaming.api.operators.python.process.ExternalPythonKeyedCoProcessOperator;
+import 
org.apache.flink.streaming.api.transformations.python.DelegateOperatorTransformation;
 import 
org.apache.flink.streaming.api.transformations.python.PythonKeyedBroadcastStateTransformation;
 import 
org.apache.flink.streaming.runtime.translators.AbstractTwoInputTransformationTranslator;
 import org.apache.flink.types.Row;
@@ -53,7 +54,7 @@ public class 
PythonKeyedBroadcastStateTransformationTranslator<OUT>
 
         Configuration config = transformation.getConfiguration();
 
-        StreamOperator<OUT> operator;
+        AbstractPythonFunctionOperator<OUT> operator;
 
         if (config.get(PythonOptions.PYTHON_EXECUTION_MODE).equals("thread")) {
             operator =
@@ -73,6 +74,8 @@ public class 
PythonKeyedBroadcastStateTransformationTranslator<OUT>
                             transformation.getOutputType());
         }
 
+        DelegateOperatorTransformation.configureOperator(transformation, 
operator);
+
         return translateInternal(
                 transformation,
                 transformation.getRegularInput(),
@@ -92,7 +95,7 @@ public class 
PythonKeyedBroadcastStateTransformationTranslator<OUT>
 
         Configuration config = transformation.getConfiguration();
 
-        StreamOperator<OUT> operator;
+        AbstractPythonFunctionOperator<OUT> operator;
 
         if (config.get(PythonOptions.PYTHON_EXECUTION_MODE).equals("thread")) {
             operator =
@@ -113,6 +116,8 @@ public class 
PythonKeyedBroadcastStateTransformationTranslator<OUT>
                             transformation.getOutputType());
         }
 
+        DelegateOperatorTransformation.configureOperator(transformation, 
operator);
+
         return translateInternal(
                 transformation,
                 transformation.getRegularInput(),

Reply via email to