This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch release-1.13
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.13 by this push:
     new 9f40251  [FLINK-22348][python] Fix DataStream.execute_and_collect which doesn't declare managed memory for Python operators
9f40251 is described below

commit 9f4025123d7f9da2d05cafea5b7cd0f186dab467
Author: huangxingbo <[email protected]>
AuthorDate: Mon Apr 19 16:10:23 2021 +0800

    [FLINK-22348][python] Fix DataStream.execute_and_collect which doesn't declare managed memory for Python operators
    
    This closes #15665.
---
 flink-python/pyflink/datastream/data_stream.py     |   2 +
 .../pyflink/datastream/tests/test_data_stream.py   |  10 +-
 .../tests/test_stream_execution_environment.py     |  42 ++-
 flink-python/pyflink/table/table_environment.py    |   4 +-
 .../table/tests/test_table_environment_api.py      |   2 +-
 flink-python/pyflink/util/java_utils.py            |   8 +-
 .../java/org/apache/flink/python/PythonConfig.java |   8 +-
 .../apache/flink/python/util/PythonConfigUtil.java | 301 +++++++++++----------
 8 files changed, 218 insertions(+), 159 deletions(-)

diff --git a/flink-python/pyflink/datastream/data_stream.py 
b/flink-python/pyflink/datastream/data_stream.py
index 6cebafd..ce85de9 100644
--- a/flink-python/pyflink/datastream/data_stream.py
+++ b/flink-python/pyflink/datastream/data_stream.py
@@ -641,6 +641,8 @@ class DataStream(object):
         :param job_execution_name: The name of the job execution.
         :param limit: The limit for the collected elements.
         """
+        JPythonConfigUtil = get_gateway().jvm.org.apache.flink.python.util.PythonConfigUtil
+        JPythonConfigUtil.configPythonOperator(self._j_data_stream.getExecutionEnvironment())
         if job_execution_name is None and limit is None:
             return CloseableIterator(self._j_data_stream.executeAndCollect(), self.get_type())
         elif job_execution_name is not None and limit is None:
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py 
b/flink-python/pyflink/datastream/tests/test_data_stream.py
index 9d0543a..994cd60 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -387,11 +387,11 @@ class DataStreamTests(object):
                       decimal.Decimal('2000000000000000000.061111111111111'
                                       '11111111111111'))]
         expected = test_data
-        ds = self.env.from_collection(test_data)
+        ds = self.env.from_collection(test_data).map(lambda a: a)
         with ds.execute_and_collect() as results:
-            actual = []
-            for result in results:
-                actual.append(result)
+            actual = [result for result in results]
+            actual.sort()
+            expected.sort()
             self.assertEqual(expected, actual)
 
     def test_key_by_map(self):
@@ -942,7 +942,7 @@ class StreamingModeDataStreamTests(DataStreamTests, 
PyFlinkStreamingTestCase):
         expected_num_partitions = 5
 
         def my_partitioner(key, num_partitions):
-            assert expected_num_partitions, num_partitions
+            assert expected_num_partitions == num_partitions
             return key % num_partitions
 
         partitioned_stream = ds.map(lambda x: x, 
output_type=Types.ROW([Types.STRING(),
diff --git 
a/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py 
b/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
index ffc54ae..5df261e 100644
--- a/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
+++ b/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
@@ -38,9 +38,11 @@ from pyflink.datastream.tests.test_util import 
DataStreamTestSinkFunction
 from pyflink.find_flink_home import _find_flink_source_root
 from pyflink.java_gateway import get_gateway
 from pyflink.pyflink_gateway_server import on_windows
-from pyflink.table import DataTypes, CsvTableSource, CsvTableSink, 
StreamTableEnvironment
+from pyflink.table import DataTypes, CsvTableSource, CsvTableSink, 
StreamTableEnvironment, \
+    EnvironmentSettings
 from pyflink.testing.test_case_utils import PyFlinkTestCase, 
exec_insert_table, \
     invoke_java_object_method
+from pyflink.util.java_utils import get_j_env_configuration
 
 
 class StreamExecutionEnvironmentTests(PyFlinkTestCase):
@@ -337,20 +339,48 @@ class StreamExecutionEnvironmentTests(PyFlinkTestCase):
         import uuid
         python_file_dir = os.path.join(self.tempdir, "python_file_dir_" + 
str(uuid.uuid4()))
         os.mkdir(python_file_dir)
-        python_file_path = os.path.join(python_file_dir, 
"test_stream_dependency_manage_lib.py")
+        python_file_path = os.path.join(python_file_dir, "test_dep1.py")
         with open(python_file_path, 'w') as f:
             f.write("def add_two(a):\n    return a + 2")
 
         def plus_two_map(value):
-            from test_stream_dependency_manage_lib import add_two
+            from test_dep1 import add_two
             return add_two(value)
 
+        get_j_env_configuration(self.env._j_stream_execution_environment).\
+            setString("taskmanager.numberOfTaskSlots", "10")
         self.env.add_python_file(python_file_path)
         ds = self.env.from_collection([1, 2, 3, 4, 5])
-        ds.map(plus_two_map).add_sink(self.test_sink)
-        self.env.execute("test add python file")
+        ds = ds.map(plus_two_map, Types.LONG()) \
+               .slot_sharing_group("data_stream") \
+               .map(lambda i: i, Types.LONG()) \
+               .slot_sharing_group("table")
+
+        python_file_path = os.path.join(python_file_dir, "test_dep2.py")
+        with open(python_file_path, 'w') as f:
+            f.write("def add_three(a):\n    return a + 3")
+
+        def plus_three(value):
+            from test_dep2 import add_three
+            return add_three(value)
+
+        t_env = StreamTableEnvironment.create(
+            stream_execution_environment=self.env,
+            
environment_settings=EnvironmentSettings.new_instance().use_blink_planner().build())
+        self.env.add_python_file(python_file_path)
+
+        from pyflink.table.udf import udf
+        from pyflink.table.expressions import col
+        add_three = udf(plus_three, result_type=DataTypes.BIGINT())
+
+        tab = t_env.from_data_stream(ds, 'a') \
+                   .select(add_three(col('a')))
+        t_env.to_append_stream(tab, Types.ROW([Types.LONG()])) \
+             .map(lambda i: i[0]) \
+             .add_sink(self.test_sink)
+        self.env.execute("test add_python_file")
         result = self.test_sink.get_results(True)
-        expected = ['3', '4', '5', '6', '7']
+        expected = ['6', '7', '8', '9', '10']
         result.sort()
         expected.sort()
         self.assertEqual(expected, result)
diff --git a/flink-python/pyflink/table/table_environment.py 
b/flink-python/pyflink/table/table_environment.py
index ee43076..2114dad 100644
--- a/flink-python/pyflink/table/table_environment.py
+++ b/flink-python/pyflink/table/table_environment.py
@@ -1548,7 +1548,7 @@ class TableEnvironment(object):
 
     def _set_python_executable_for_local_executor(self):
         jvm = get_gateway().jvm
-        j_config = get_j_env_configuration(self)
+        j_config = get_j_env_configuration(self._get_j_env())
         if not j_config.containsKey(jvm.PythonOptions.PYTHON_EXECUTABLE.key()) 
\
                 and is_local_deployment(j_config):
             j_config.setString(jvm.PythonOptions.PYTHON_EXECUTABLE.key(), 
sys.executable)
@@ -1559,7 +1559,7 @@ class TableEnvironment(object):
         if jar_urls is not None:
             # normalize and remove duplicates
             jar_urls_set = set([jvm.java.net.URL(url).toString() for url in 
jar_urls.split(";")])
-            j_configuration = get_j_env_configuration(self)
+            j_configuration = get_j_env_configuration(self._get_j_env())
             if j_configuration.containsKey(config_key):
                 for url in j_configuration.getString(config_key, 
"").split(";"):
                     jar_urls_set.add(url)
diff --git a/flink-python/pyflink/table/tests/test_table_environment_api.py 
b/flink-python/pyflink/table/tests/test_table_environment_api.py
index 7c6778c..3774178 100644
--- a/flink-python/pyflink/table/tests/test_table_environment_api.py
+++ b/flink-python/pyflink/table/tests/test_table_environment_api.py
@@ -51,7 +51,7 @@ class TableEnvironmentTest(object):
 
     def test_set_sys_executable_for_local_mode(self):
         jvm = get_gateway().jvm
-        actual_executable = get_j_env_configuration(self.t_env) \
+        actual_executable = get_j_env_configuration(self.t_env._get_j_env()) \
             .getString(jvm.PythonOptions.PYTHON_EXECUTABLE.key(), None)
         self.assertEqual(sys.executable, actual_executable)
 
diff --git a/flink-python/pyflink/util/java_utils.py 
b/flink-python/pyflink/util/java_utils.py
index 3ffe223..8ea8d9b 100644
--- a/flink-python/pyflink/util/java_utils.py
+++ b/flink-python/pyflink/util/java_utils.py
@@ -80,12 +80,12 @@ def is_instance_of(java_object, java_class):
         param, java_object)
 
 
-def get_j_env_configuration(t_env):
-    if is_instance_of(t_env._get_j_env(), 
"org.apache.flink.api.java.ExecutionEnvironment"):
-        return t_env._get_j_env().getConfiguration()
+def get_j_env_configuration(j_env):
+    if is_instance_of(j_env, "org.apache.flink.api.java.ExecutionEnvironment"):
+        return j_env.getConfiguration()
     else:
         return invoke_method(
-            t_env._get_j_env(),
+            j_env,
             
"org.apache.flink.streaming.api.environment.StreamExecutionEnvironment",
             "getConfiguration"
         )
diff --git 
a/flink-python/src/main/java/org/apache/flink/python/PythonConfig.java 
b/flink-python/src/main/java/org/apache/flink/python/PythonConfig.java
index d8d4144..9fa8383 100644
--- a/flink-python/src/main/java/org/apache/flink/python/PythonConfig.java
+++ b/flink-python/src/main/java/org/apache/flink/python/PythonConfig.java
@@ -84,10 +84,10 @@ public class PythonConfig implements Serializable {
     private final boolean isUsingManagedMemory;
 
     /** The Configuration that contains execution configs and dependencies 
info. */
-    private final Configuration mergedConfig;
+    private final Configuration config;
 
     public PythonConfig(Configuration config) {
-        mergedConfig = config;
+        this.config = config;
         maxBundleSize = config.get(PythonOptions.MAX_BUNDLE_SIZE);
         maxBundleTimeMills = config.get(PythonOptions.MAX_BUNDLE_TIME_MILLS);
         maxArrowBatchSize = config.get(PythonOptions.MAX_ARROW_BATCH_SIZE);
@@ -148,7 +148,7 @@ public class PythonConfig implements Serializable {
         return isUsingManagedMemory;
     }
 
-    public Configuration getMergedConfig() {
-        return mergedConfig;
+    public Configuration getConfig() {
+        return config;
     }
 }
diff --git 
a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java 
b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
index 6681c1f..cc65713 100644
--- 
a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
+++ 
b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
@@ -30,20 +30,16 @@ import org.apache.flink.core.memory.ManagedMemoryUseCase;
 import org.apache.flink.python.PythonConfig;
 import org.apache.flink.python.PythonOptions;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.streaming.api.graph.StreamEdge;
 import org.apache.flink.streaming.api.graph.StreamGraph;
-import org.apache.flink.streaming.api.graph.StreamNode;
 import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
-import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
 import 
org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator;
 import 
org.apache.flink.streaming.api.operators.python.OneInputPythonFunctionOperator;
-import 
org.apache.flink.streaming.api.operators.python.PythonKeyedProcessOperator;
 import 
org.apache.flink.streaming.api.operators.python.PythonPartitionCustomOperator;
 import 
org.apache.flink.streaming.api.operators.python.PythonTimestampsAndWatermarksOperator;
-import 
org.apache.flink.streaming.api.operators.python.TwoInputPythonFunctionOperator;
 import 
org.apache.flink.streaming.api.transformations.AbstractMultipleInputTransformation;
 import org.apache.flink.streaming.api.transformations.OneInputTransformation;
+import org.apache.flink.streaming.api.transformations.PartitionTransformation;
 import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
 import org.apache.flink.streaming.api.transformations.WithBoundedness;
 import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
@@ -53,7 +49,6 @@ import org.apache.flink.table.api.TableException;
 import java.lang.reflect.Field;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
-import java.util.Collection;
 import java.util.List;
 
 /**
@@ -102,37 +97,6 @@ public class PythonConfigUtil {
         return (Configuration) getConfigurationMethod.invoke(env);
     }
 
-    /**
-     * Configure the {@link OneInputPythonFunctionOperator} to be chained with 
the
-     * upstream/downstream operator by setting their parallelism, slot sharing 
group, co-location
-     * group to be the same, and applying a {@link ForwardPartitioner}. 1. 
operator with name
-     * "_keyed_stream_values_operator" should align with its downstream 
operator. 2. operator with
-     * name "_stream_key_by_map_operator" should align with its upstream 
operator.
-     */
-    private static void alignStreamNode(StreamNode streamNode, StreamGraph 
streamGraph) {
-        if 
(streamNode.getOperatorName().equals(KEYED_STREAM_VALUE_OPERATOR_NAME)) {
-            StreamEdge downStreamEdge = streamNode.getOutEdges().get(0);
-            StreamNode downStreamNode = 
streamGraph.getStreamNode(downStreamEdge.getTargetId());
-            chainStreamNode(downStreamEdge, streamNode, downStreamNode);
-            downStreamEdge.setPartitioner(new ForwardPartitioner());
-        }
-
-        if 
(streamNode.getOperatorName().equals(STREAM_KEY_BY_MAP_OPERATOR_NAME)
-                || 
streamNode.getOperatorName().equals(STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME)) 
{
-            StreamEdge upStreamEdge = streamNode.getInEdges().get(0);
-            StreamNode upStreamNode = 
streamGraph.getStreamNode(upStreamEdge.getSourceId());
-            chainStreamNode(upStreamEdge, streamNode, upStreamNode);
-        }
-    }
-
-    private static void chainStreamNode(
-            StreamEdge streamEdge, StreamNode firstStream, StreamNode 
secondStream) {
-        streamEdge.setPartitioner(new ForwardPartitioner<>());
-        firstStream.setParallelism(secondStream.getParallelism());
-        firstStream.setCoLocationGroup(secondStream.getCoLocationGroup());
-        firstStream.setSlotSharingGroup(secondStream.getSlotSharingGroup());
-    }
-
     /** Set Python Operator Use Managed Memory. */
     public static void declareManagedMemory(
             Transformation<?> transformation,
@@ -144,16 +108,6 @@ public class PythonConfigUtil {
         }
     }
 
-    private static void declareManagedMemory(Transformation<?> transformation) 
{
-        if (isPythonOperator(transformation)) {
-            
transformation.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
-        }
-        List<Transformation<?>> inputTransformations = 
transformation.getInputs();
-        for (Transformation inputTransformation : inputTransformations) {
-            declareManagedMemory(inputTransformation);
-        }
-    }
-
     /**
      * Generate a {@link StreamGraph} for transformations maintained by 
current {@link
      * StreamExecutionEnvironment}, and reset the merged env configurations 
with dependencies to
@@ -165,98 +119,174 @@ public class PythonConfigUtil {
             StreamExecutionEnvironment env, boolean clearTransformations)
             throws IllegalAccessException, NoSuchMethodException, 
InvocationTargetException,
                     NoSuchFieldException {
-        Configuration mergedConfig = getEnvConfigWithDependencies(env);
-
-        boolean executedInBatchMode = isExecuteInBatchMode(env, mergedConfig);
-
-        if (mergedConfig.getBoolean(PythonOptions.USE_MANAGED_MEMORY)) {
-            Field transformationsField =
-                    
StreamExecutionEnvironment.class.getDeclaredField("transformations");
-            transformationsField.setAccessible(true);
-            for (Transformation transform :
-                    (List<Transformation<?>>) transformationsField.get(env)) {
-                if (isPythonOperator(transform)) {
-                    
transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
-                }
-            }
-        }
+        configPythonOperator(env);
 
         String jobName =
                 getEnvironmentConfig(env)
                         .getString(
                                 PipelineOptions.NAME, 
StreamExecutionEnvironment.DEFAULT_JOB_NAME);
-        StreamGraph streamGraph = env.getStreamGraph(jobName, 
clearTransformations);
-        Collection<StreamNode> streamNodes = streamGraph.getStreamNodes();
-        for (StreamNode streamNode : streamNodes) {
-            alignStreamNode(streamNode, streamGraph);
-            StreamOperatorFactory streamOperatorFactory = 
streamNode.getOperatorFactory();
-            if (streamOperatorFactory instanceof SimpleOperatorFactory) {
-                StreamOperator streamOperator =
-                        ((SimpleOperatorFactory) 
streamOperatorFactory).getOperator();
-                if ((streamOperator instanceof OneInputPythonFunctionOperator)
-                        || (streamOperator instanceof 
TwoInputPythonFunctionOperator)
-                        || (streamOperator instanceof 
PythonKeyedProcessOperator)) {
-                    AbstractPythonFunctionOperator pythonFunctionOperator =
-                            (AbstractPythonFunctionOperator) streamOperator;
-
-                    Configuration oldConfig =
-                            
pythonFunctionOperator.getPythonConfig().getMergedConfig();
+        return env.getStreamGraph(jobName, clearTransformations);
+    }
+
+    @SuppressWarnings("unchecked")
+    public static void configPythonOperator(StreamExecutionEnvironment env)
+            throws IllegalAccessException, NoSuchMethodException, 
InvocationTargetException,
+                    NoSuchFieldException {
+        Configuration mergedConfig = getEnvConfigWithDependencies(env);
+
+        boolean executedInBatchMode = isExecuteInBatchMode(env, mergedConfig);
+
+        Field transformationsField =
+                
StreamExecutionEnvironment.class.getDeclaredField("transformations");
+        transformationsField.setAccessible(true);
+        List<Transformation<?>> transformations =
+                (List<Transformation<?>>) transformationsField.get(env);
+        for (Transformation<?> transformation : transformations) {
+            alignTransformation(transformation);
+
+            if (isPythonOperator(transformation)) {
+                // declare it is a Python operator
+                
transformation.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
+
+                AbstractPythonFunctionOperator<?> pythonFunctionOperator =
+                        getPythonOperator(transformation);
+                if (pythonFunctionOperator != null) {
+                    Configuration oldConfig = 
pythonFunctionOperator.getPythonConfig().getConfig();
+                    // update dependency related configurations for Python 
operators
                     pythonFunctionOperator.setPythonConfig(
                             generateNewPythonConfig(oldConfig, mergedConfig));
 
-                    if (streamOperator instanceof 
PythonTimestampsAndWatermarksOperator) {
-                        ((PythonTimestampsAndWatermarksOperator) 
streamOperator)
+                    // set the emitProgressiveWatermarks flag for
+                    // PythonTimestampsAndWatermarksOperator
+                    if (pythonFunctionOperator instanceof 
PythonTimestampsAndWatermarksOperator) {
+                        ((PythonTimestampsAndWatermarksOperator<?>) 
pythonFunctionOperator)
                                 
.configureEmitProgressiveWatermarks(!executedInBatchMode);
                     }
                 }
             }
         }
 
-        setStreamPartitionCustomOperatorNumPartitions(streamNodes, 
streamGraph);
+        setPartitionCustomOperatorNumPartitions(transformations);
+    }
 
-        return streamGraph;
+    public static Configuration getMergedConfig(
+            StreamExecutionEnvironment env, TableConfig tableConfig) {
+        try {
+            Configuration config = new 
Configuration(getEnvironmentConfig(env));
+            PythonDependencyUtils.merge(config, 
tableConfig.getConfiguration());
+            Configuration mergedConfig =
+                    
PythonDependencyUtils.configurePythonDependencies(env.getCachedFiles(), config);
+            mergedConfig.setString("table.exec.timezone", 
tableConfig.getLocalTimeZone().getId());
+            return mergedConfig;
+        } catch (IllegalAccessException | NoSuchMethodException | 
InvocationTargetException e) {
+            throw new TableException("Method getMergedConfig failed.", e);
+        }
     }
 
-    private static boolean isPythonOperator(StreamOperatorFactory 
streamOperatorFactory) {
-        if (streamOperatorFactory instanceof SimpleOperatorFactory) {
-            return ((SimpleOperatorFactory) 
streamOperatorFactory).getOperator()
-                    instanceof AbstractPythonFunctionOperator;
-        } else {
-            return false;
+    @SuppressWarnings("unchecked")
+    public static Configuration getMergedConfig(ExecutionEnvironment env, 
TableConfig tableConfig) {
+        try {
+            Field field = 
ExecutionEnvironment.class.getDeclaredField("cacheFile");
+            field.setAccessible(true);
+            Configuration config = new Configuration(env.getConfiguration());
+            PythonDependencyUtils.merge(config, 
tableConfig.getConfiguration());
+            Configuration mergedConfig =
+                    PythonDependencyUtils.configurePythonDependencies(
+                            (List<Tuple2<String, 
DistributedCache.DistributedCacheEntry>>)
+                                    field.get(env),
+                            config);
+            mergedConfig.setString("table.exec.timezone", 
tableConfig.getLocalTimeZone().getId());
+            return mergedConfig;
+        } catch (NoSuchFieldException | IllegalAccessException e) {
+            throw new TableException("Method getMergedConfig failed.", e);
+        }
+    }
+
+    /**
+     * Configure the {@link OneInputPythonFunctionOperator} to be chained with 
the
+     * upstream/downstream operator by setting their parallelism, slot sharing 
group, co-location
+     * group to be the same, and applying a {@link ForwardPartitioner}. 1. 
operator with name
+     * "_keyed_stream_values_operator" should align with its downstream 
operator. 2. operator with
+     * name "_stream_key_by_map_operator" should align with its upstream 
operator.
+     */
+    private static void alignTransformation(Transformation<?> transformation)
+            throws NoSuchFieldException, IllegalAccessException {
+        String transformName = transformation.getName();
+        Transformation<?> inputTransformation = 
transformation.getInputs().get(0);
+        String inputTransformName = inputTransformation.getName();
+        if (inputTransformName.equals(KEYED_STREAM_VALUE_OPERATOR_NAME)) {
+            chainTransformation(inputTransformation, transformation);
+            configForwardPartitioner(inputTransformation, transformation);
+        }
+        if (transformName.equals(STREAM_KEY_BY_MAP_OPERATOR_NAME)
+                || 
transformName.equals(STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME)) {
+            chainTransformation(transformation, inputTransformation);
+            configForwardPartitioner(inputTransformation, transformation);
+        }
+    }
+
+    private static void chainTransformation(
+            Transformation<?> firstTransformation, Transformation<?> 
secondTransformation) {
+        
firstTransformation.setSlotSharingGroup(secondTransformation.getSlotSharingGroup());
+        
firstTransformation.setCoLocationGroupKey(secondTransformation.getCoLocationGroupKey());
+        
firstTransformation.setParallelism(secondTransformation.getParallelism());
+    }
+
+    private static void configForwardPartitioner(
+            Transformation<?> upTransformation, Transformation<?> 
transformation)
+            throws IllegalAccessException, NoSuchFieldException {
+        // set ForwardPartitioner
+        PartitionTransformation<?> partitionTransform =
+                new PartitionTransformation<>(upTransformation, new 
ForwardPartitioner<>());
+        Field inputTransformationField = 
transformation.getClass().getDeclaredField("input");
+        inputTransformationField.setAccessible(true);
+        inputTransformationField.set(transformation, partitionTransform);
+    }
+
+    private static AbstractPythonFunctionOperator<?> getPythonOperator(
+            Transformation<?> transformation) {
+        StreamOperatorFactory<?> operatorFactory = null;
+        if (transformation instanceof OneInputTransformation) {
+            operatorFactory = ((OneInputTransformation<?, ?>) 
transformation).getOperatorFactory();
+        } else if (transformation instanceof TwoInputTransformation) {
+            operatorFactory =
+                    ((TwoInputTransformation<?, ?, ?>) 
transformation).getOperatorFactory();
+        } else if (transformation instanceof 
AbstractMultipleInputTransformation) {
+            operatorFactory =
+                    ((AbstractMultipleInputTransformation<?>) 
transformation).getOperatorFactory();
+        }
+
+        if (operatorFactory instanceof SimpleOperatorFactory
+                && ((SimpleOperatorFactory<?>) operatorFactory).getOperator()
+                        instanceof AbstractPythonFunctionOperator) {
+            return (AbstractPythonFunctionOperator<?>)
+                    ((SimpleOperatorFactory<?>) operatorFactory).getOperator();
         }
+
+        return null;
     }
 
     private static boolean isPythonOperator(Transformation<?> transform) {
         if (transform instanceof OneInputTransformation) {
-            return isPythonOperator(((OneInputTransformation) 
transform).getOperatorFactory());
+            return isPythonOperator(
+                    ((OneInputTransformation<?, ?>) 
transform).getOperatorFactory());
         } else if (transform instanceof TwoInputTransformation) {
-            return isPythonOperator(((TwoInputTransformation) 
transform).getOperatorFactory());
+            return isPythonOperator(
+                    ((TwoInputTransformation<?, ?, ?>) 
transform).getOperatorFactory());
         } else if (transform instanceof AbstractMultipleInputTransformation) {
             return isPythonOperator(
-                    ((AbstractMultipleInputTransformation) 
transform).getOperatorFactory());
+                    ((AbstractMultipleInputTransformation<?>) 
transform).getOperatorFactory());
         } else {
             return false;
         }
     }
 
-    private static void setStreamPartitionCustomOperatorNumPartitions(
-            Collection<StreamNode> streamNodes, StreamGraph streamGraph) {
-        for (StreamNode streamNode : streamNodes) {
-            StreamOperatorFactory streamOperatorFactory = 
streamNode.getOperatorFactory();
-            if (streamOperatorFactory instanceof SimpleOperatorFactory) {
-                StreamOperator streamOperator =
-                        ((SimpleOperatorFactory) 
streamOperatorFactory).getOperator();
-                if (streamOperator instanceof PythonPartitionCustomOperator) {
-                    PythonPartitionCustomOperator 
partitionCustomFunctionOperator =
-                            (PythonPartitionCustomOperator) streamOperator;
-                    // Update the numPartitions of PartitionCustomOperator 
after aligned all
-                    // operators.
-                    partitionCustomFunctionOperator.setNumPartitions(
-                            streamGraph
-                                    
.getStreamNode(streamNode.getOutEdges().get(0).getTargetId())
-                                    .getParallelism());
-                }
-            }
+    private static boolean isPythonOperator(StreamOperatorFactory<?> 
streamOperatorFactory) {
+        if (streamOperatorFactory instanceof SimpleOperatorFactory) {
+            return ((SimpleOperatorFactory<?>) 
streamOperatorFactory).getOperator()
+                    instanceof AbstractPythonFunctionOperator;
+        } else {
+            return false;
         }
     }
 
@@ -285,7 +315,8 @@ public class PythonConfigUtil {
                 
StreamExecutionEnvironment.class.getDeclaredField("transformations");
         transformationsField.setAccessible(true);
         boolean existsUnboundedSource = false;
-        for (Transformation transform : (List<Transformation<?>>) 
transformationsField.get(env)) {
+        for (Transformation<?> transform :
+                (List<Transformation<?>>) transformationsField.get(env)) {
             existsUnboundedSource =
                     existsUnboundedSource
                             || (transform instanceof WithBoundedness
@@ -295,36 +326,32 @@ public class PythonConfigUtil {
         return !existsUnboundedSource;
     }
 
-    public static Configuration getMergedConfig(
-            StreamExecutionEnvironment env, TableConfig tableConfig) {
-        try {
-            Configuration config = new 
Configuration(getEnvironmentConfig(env));
-            PythonDependencyUtils.merge(config, 
tableConfig.getConfiguration());
-            Configuration mergedConfig =
-                    
PythonDependencyUtils.configurePythonDependencies(env.getCachedFiles(), config);
-            mergedConfig.setString("table.exec.timezone", 
tableConfig.getLocalTimeZone().getId());
-            return mergedConfig;
-        } catch (IllegalAccessException | NoSuchMethodException | 
InvocationTargetException e) {
-            throw new TableException("Method getMergedConfig failed.", e);
+    private static void declareManagedMemory(Transformation<?> transformation) 
{
+        if (isPythonOperator(transformation)) {
+            
transformation.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
+        }
+
+        for (Transformation<?> inputTransformation : 
transformation.getInputs()) {
+            declareManagedMemory(inputTransformation);
         }
     }
 
-    @SuppressWarnings("unchecked")
-    public static Configuration getMergedConfig(ExecutionEnvironment env, 
TableConfig tableConfig) {
-        try {
-            Field field = 
ExecutionEnvironment.class.getDeclaredField("cacheFile");
-            field.setAccessible(true);
-            Configuration config = new Configuration(env.getConfiguration());
-            PythonDependencyUtils.merge(config, 
tableConfig.getConfiguration());
-            Configuration mergedConfig =
-                    PythonDependencyUtils.configurePythonDependencies(
-                            (List<Tuple2<String, 
DistributedCache.DistributedCacheEntry>>)
-                                    field.get(env),
-                            config);
-            mergedConfig.setString("table.exec.timezone", 
tableConfig.getLocalTimeZone().getId());
-            return mergedConfig;
-        } catch (NoSuchFieldException | IllegalAccessException e) {
-            throw new TableException("Method getMergedConfig failed.", e);
+    private static void setPartitionCustomOperatorNumPartitions(
+            List<Transformation<?>> transformations) {
+        // Update the numPartitions of PartitionCustomOperator after aligned 
all operators.
+        for (Transformation<?> transformation : transformations) {
+            Transformation<?> firstInputTransformation = 
transformation.getInputs().get(0);
+            if (firstInputTransformation instanceof PartitionTransformation) {
+                firstInputTransformation = 
firstInputTransformation.getInputs().get(0);
+            }
+            AbstractPythonFunctionOperator<?> pythonFunctionOperator =
+                    getPythonOperator(firstInputTransformation);
+            if (pythonFunctionOperator instanceof 
PythonPartitionCustomOperator) {
+                PythonPartitionCustomOperator<?, ?> 
partitionCustomFunctionOperator =
+                        (PythonPartitionCustomOperator<?, ?>) 
pythonFunctionOperator;
+
+                
partitionCustomFunctionOperator.setNumPartitions(transformation.getParallelism());
+            }
         }
     }
 }

Reply via email to