This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch release-1.12
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.12 by this push:
     new 91d9004  [FLINK-23133][python] Properly handle the dependencies when 
mixing use of Python Table API and Python DataStream API
91d9004 is described below

commit 91d9004a4b41f3f7c499d410c1dd09299dd44b25
Author: Dian Fu <[email protected]>
AuthorDate: Thu Jun 24 15:26:11 2021 +0800

    [FLINK-23133][python] Properly handle the dependencies when mixing use of 
Python Table API and Python DataStream API
    
    This closes #16272.
---
 docs/dev/python/dependency_management.md           |   5 +
 docs/dev/python/dependency_management.zh.md        |   5 +
 .../tests/test_stream_execution_environment.py     |  43 ++++-
 flink-python/pyflink/table/table_environment.py    |   5 +-
 .../apache/flink/python/util/PythonConfigUtil.java | 209 ++++++++++++---------
 5 files changed, 168 insertions(+), 99 deletions(-)

diff --git a/docs/dev/python/dependency_management.md 
b/docs/dev/python/dependency_management.md
index 5174af7..813878e 100644
--- a/docs/dev/python/dependency_management.md
+++ b/docs/dev/python/dependency_management.md
@@ -30,6 +30,11 @@ the local Python environment, download the machine learning 
model to local, etc.
 However, this approach doesn't work well when users want to submit the PyFlink 
jobs to remote clusters.
 In the following sections, we will introduce the options provided in PyFlink 
for these requirements.
 
+<span class="label label-info">Note</span> Both the Python DataStream API and the Python Table API provide
+APIs for each kind of dependency. If you mix the Python DataStream API and the Python Table API
+in a single job, you should specify the dependencies via the Python DataStream API so that they take effect in
+both the Python DataStream API and the Python Table API.
+
 * This will be replaced by the TOC
 {:toc}
 
diff --git a/docs/dev/python/dependency_management.zh.md 
b/docs/dev/python/dependency_management.zh.md
index 9da533e..8e40d11 100644
--- a/docs/dev/python/dependency_management.zh.md
+++ b/docs/dev/python/dependency_management.zh.md
@@ -30,6 +30,11 @@ the local Python environment, download the machine learning 
model to local, etc.
 However, this approach doesn't work well when users want to submit the PyFlink 
jobs to remote clusters.
 In the following sections, we will introduce the options provided in PyFlink 
for these requirements.
 
+<span class="label label-info">Note</span> Both the Python DataStream API and the Python Table API provide
+APIs for each kind of dependency. If you mix the Python DataStream API and the Python Table API
+in a single job, you should specify the dependencies via the Python DataStream API so that they take effect in
+both the Python DataStream API and the Python Table API.
+
 * This will be replaced by the TOC
 {:toc}
 
diff --git 
a/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py 
b/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
index f4756ed..34677f9 100644
--- a/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
+++ b/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
@@ -37,7 +37,8 @@ from pyflink.datastream.tests.test_util import 
DataStreamTestSinkFunction
 from pyflink.find_flink_home import _find_flink_source_root
 from pyflink.java_gateway import get_gateway
 from pyflink.pyflink_gateway_server import on_windows
-from pyflink.table import DataTypes, CsvTableSource, CsvTableSink, 
StreamTableEnvironment
+from pyflink.table import DataTypes, CsvTableSource, CsvTableSink, 
StreamTableEnvironment, \
+    EnvironmentSettings
 from pyflink.testing.test_case_utils import PyFlinkTestCase, exec_insert_table
 
 
@@ -350,6 +351,46 @@ class StreamExecutionEnvironmentTests(PyFlinkTestCase):
         expected.sort()
         self.assertEqual(expected, result)
 
+    def test_add_python_file_2(self):
+        import uuid
+        python_file_dir = os.path.join(self.tempdir, "python_file_dir_" + 
str(uuid.uuid4()))
+        os.mkdir(python_file_dir)
+        python_file_path = os.path.join(python_file_dir, "test_dep1.py")
+        with open(python_file_path, 'w') as f:
+            f.write("def add_two(a):\n    return a + 2")
+
+        def plus_two_map(value):
+            from test_dep1 import add_two
+            return add_two(value)
+
+        self.env.add_python_file(python_file_path)
+        ds = self.env.from_collection([1, 2, 3, 4, 5])
+        ds = ds.map(plus_two_map, Types.LONG())
+        python_file_path = os.path.join(python_file_dir, "test_dep2.py")
+        with open(python_file_path, 'w') as f:
+            f.write("def add_three(a):\n    return a + 3")
+
+        def plus_three(value):
+            from test_dep2 import add_three
+            return add_three(value)
+
+        t_env = StreamTableEnvironment.create(
+            stream_execution_environment=self.env,
+            
environment_settings=EnvironmentSettings.new_instance().use_blink_planner().build())
+        self.env.add_python_file(python_file_path)
+
+        from pyflink.table.udf import udf
+        from pyflink.table.expressions import col
+        add_three = udf(plus_three, result_type=DataTypes.BIGINT())
+
+        tab = t_env.from_data_stream(ds, 'a') \
+                   .select(add_three(col('a')))
+        result = [i[0] for i in tab.execute().collect()]
+        expected = [6, 7, 8, 9, 10]
+        result.sort()
+        expected.sort()
+        self.assertEqual(expected, result)
+
     def test_set_requirements_without_cached_directory(self):
         import uuid
         requirements_txt_path = os.path.join(self.tempdir, str(uuid.uuid4()))
diff --git a/flink-python/pyflink/table/table_environment.py 
b/flink-python/pyflink/table/table_environment.py
index 0938069..fd7b9e5 100644
--- a/flink-python/pyflink/table/table_environment.py
+++ b/flink-python/pyflink/table/table_environment.py
@@ -1731,10 +1731,7 @@ class StreamTableEnvironment(TableEnvironment):
         """
         j_data_stream = data_stream._j_data_stream
         JPythonConfigUtil = 
get_gateway().jvm.org.apache.flink.python.util.PythonConfigUtil
-        JPythonConfigUtil.declareManagedMemory(
-            j_data_stream.getTransformation(),
-            self._get_j_env(),
-            self._j_tenv.getConfig())
+        
JPythonConfigUtil.configPythonOperator(j_data_stream.getExecutionEnvironment())
         if len(fields) == 0:
             return Table(j_table=self._j_tenv.fromDataStream(j_data_stream), 
t_env=self)
         elif all(isinstance(f, Expression) for f in fields):
diff --git 
a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java 
b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
index 78818c3..49960d8 100644
--- 
a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
+++ 
b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
@@ -30,20 +30,16 @@ import org.apache.flink.core.memory.ManagedMemoryUseCase;
 import org.apache.flink.python.PythonConfig;
 import org.apache.flink.python.PythonOptions;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.streaming.api.graph.StreamEdge;
 import org.apache.flink.streaming.api.graph.StreamGraph;
-import org.apache.flink.streaming.api.graph.StreamNode;
 import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
-import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
 import 
org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator;
 import 
org.apache.flink.streaming.api.operators.python.OneInputPythonFunctionOperator;
-import 
org.apache.flink.streaming.api.operators.python.PythonKeyedProcessOperator;
 import 
org.apache.flink.streaming.api.operators.python.PythonPartitionCustomOperator;
 import 
org.apache.flink.streaming.api.operators.python.PythonTimestampsAndWatermarksOperator;
-import 
org.apache.flink.streaming.api.operators.python.TwoInputPythonFunctionOperator;
 import 
org.apache.flink.streaming.api.transformations.AbstractMultipleInputTransformation;
 import org.apache.flink.streaming.api.transformations.OneInputTransformation;
+import org.apache.flink.streaming.api.transformations.PartitionTransformation;
 import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
 import org.apache.flink.streaming.api.transformations.WithBoundedness;
 import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
@@ -53,7 +49,6 @@ import org.apache.flink.table.api.TableException;
 import java.lang.reflect.Field;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
-import java.util.Collection;
 import java.util.List;
 
 /**
@@ -102,37 +97,6 @@ public class PythonConfigUtil {
         return (Configuration) getConfigurationMethod.invoke(env);
     }
 
-    /**
-     * Configure the {@link OneInputPythonFunctionOperator} to be chained with 
the
-     * upstream/downstream operator by setting their parallelism, slot sharing 
group, co-location
-     * group to be the same, and applying a {@link ForwardPartitioner}. 1. 
operator with name
-     * "_keyed_stream_values_operator" should align with its downstream 
operator. 2. operator with
-     * name "_stream_key_by_map_operator" should align with its upstream 
operator.
-     */
-    private static void alignStreamNode(StreamNode streamNode, StreamGraph 
streamGraph) {
-        if 
(streamNode.getOperatorName().equals(KEYED_STREAM_VALUE_OPERATOR_NAME)) {
-            StreamEdge downStreamEdge = streamNode.getOutEdges().get(0);
-            StreamNode downStreamNode = 
streamGraph.getStreamNode(downStreamEdge.getTargetId());
-            chainStreamNode(downStreamEdge, streamNode, downStreamNode);
-            downStreamEdge.setPartitioner(new ForwardPartitioner());
-        }
-
-        if 
(streamNode.getOperatorName().equals(STREAM_KEY_BY_MAP_OPERATOR_NAME)
-                || 
streamNode.getOperatorName().equals(STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME)) 
{
-            StreamEdge upStreamEdge = streamNode.getInEdges().get(0);
-            StreamNode upStreamNode = 
streamGraph.getStreamNode(upStreamEdge.getSourceId());
-            chainStreamNode(upStreamEdge, streamNode, upStreamNode);
-        }
-    }
-
-    private static void chainStreamNode(
-            StreamEdge streamEdge, StreamNode firstStream, StreamNode 
secondStream) {
-        streamEdge.setPartitioner(new ForwardPartitioner<>());
-        firstStream.setParallelism(secondStream.getParallelism());
-        firstStream.setCoLocationGroup(secondStream.getCoLocationGroup());
-        firstStream.setSlotSharingGroup(secondStream.getSlotSharingGroup());
-    }
-
     /** Set Python Operator Use Managed Memory. */
     public static void declareManagedMemory(
             Transformation<?> transformation,
@@ -165,60 +129,119 @@ public class PythonConfigUtil {
             StreamExecutionEnvironment env, boolean clearTransformations)
             throws IllegalAccessException, NoSuchMethodException, 
InvocationTargetException,
                     NoSuchFieldException {
-        Configuration mergedConfig = getEnvConfigWithDependencies(env);
-
-        boolean executedInBatchMode = isExecuteInBatchMode(env, mergedConfig);
-        if (executedInBatchMode) {
-            throw new UnsupportedOperationException(
-                    "Batch mode is still not supported in Python DataStream 
API.");
-        }
-
-        if (mergedConfig.getBoolean(PythonOptions.USE_MANAGED_MEMORY)) {
-            Field transformationsField =
-                    
StreamExecutionEnvironment.class.getDeclaredField("transformations");
-            transformationsField.setAccessible(true);
-            for (Transformation transform :
-                    (List<Transformation<?>>) transformationsField.get(env)) {
-                if (isPythonOperator(transform)) {
-                    
transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
-                }
-            }
-        }
+        configPythonOperator(env);
 
         String jobName =
                 getEnvironmentConfig(env)
                         .getString(
                                 PipelineOptions.NAME, 
StreamExecutionEnvironment.DEFAULT_JOB_NAME);
-        StreamGraph streamGraph = env.getStreamGraph(jobName, 
clearTransformations);
-        Collection<StreamNode> streamNodes = streamGraph.getStreamNodes();
-        for (StreamNode streamNode : streamNodes) {
-            alignStreamNode(streamNode, streamGraph);
-            StreamOperatorFactory streamOperatorFactory = 
streamNode.getOperatorFactory();
-            if (streamOperatorFactory instanceof SimpleOperatorFactory) {
-                StreamOperator streamOperator =
-                        ((SimpleOperatorFactory) 
streamOperatorFactory).getOperator();
-                if ((streamOperator instanceof OneInputPythonFunctionOperator)
-                        || (streamOperator instanceof 
TwoInputPythonFunctionOperator)
-                        || (streamOperator instanceof 
PythonKeyedProcessOperator)) {
-                    AbstractPythonFunctionOperator pythonFunctionOperator =
-                            (AbstractPythonFunctionOperator) streamOperator;
+        return env.getStreamGraph(jobName, clearTransformations);
+    }
+
+    @SuppressWarnings("unchecked")
+    public static void configPythonOperator(StreamExecutionEnvironment env)
+            throws IllegalAccessException, NoSuchMethodException, 
InvocationTargetException,
+                    NoSuchFieldException {
+        Configuration mergedConfig = getEnvConfigWithDependencies(env);
+
+        boolean executedInBatchMode = isExecuteInBatchMode(env, mergedConfig);
+
+        Field transformationsField =
+                
StreamExecutionEnvironment.class.getDeclaredField("transformations");
+        transformationsField.setAccessible(true);
+        List<Transformation<?>> transformations =
+                (List<Transformation<?>>) transformationsField.get(env);
+        for (Transformation<?> transformation : transformations) {
+            alignTransformation(transformation);
 
+            if (isPythonOperator(transformation)) {
+                // declare it is a Python operator
+                
transformation.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
+
+                AbstractPythonFunctionOperator<?> pythonFunctionOperator =
+                        getPythonOperator(transformation);
+                if (pythonFunctionOperator != null) {
                     Configuration oldConfig =
                             
pythonFunctionOperator.getPythonConfig().getMergedConfig();
+                    // update dependency related configurations for Python 
operators
                     pythonFunctionOperator.setPythonConfig(
                             generateNewPythonConfig(oldConfig, mergedConfig));
 
-                    if (streamOperator instanceof 
PythonTimestampsAndWatermarksOperator) {
-                        ((PythonTimestampsAndWatermarksOperator) 
streamOperator)
+                    // set the emitProgressiveWatermarks flag for
+                    // PythonTimestampsAndWatermarksOperator
+                    if (pythonFunctionOperator instanceof 
PythonTimestampsAndWatermarksOperator) {
+                        ((PythonTimestampsAndWatermarksOperator<?>) 
pythonFunctionOperator)
                                 
.configureEmitProgressiveWatermarks(!executedInBatchMode);
                     }
                 }
             }
         }
 
-        setStreamPartitionCustomOperatorNumPartitions(streamNodes, 
streamGraph);
+        setPartitionCustomOperatorNumPartitions(transformations);
+    }
 
-        return streamGraph;
+    /**
+     * Configure the {@link OneInputPythonFunctionOperator} to be chained with 
the
+     * upstream/downstream operator by setting their parallelism, slot sharing 
group, co-location
+     * group to be the same, and applying a {@link ForwardPartitioner}. 1. 
operator with name
+     * "_keyed_stream_values_operator" should align with its downstream 
operator. 2. operator with
+     * name "_stream_key_by_map_operator" should align with its upstream 
operator.
+     */
+    private static void alignTransformation(Transformation<?> transformation)
+            throws NoSuchFieldException, IllegalAccessException {
+        String transformName = transformation.getName();
+        Transformation<?> inputTransformation = 
transformation.getInputs().get(0);
+        String inputTransformName = inputTransformation.getName();
+        if (inputTransformName.equals(KEYED_STREAM_VALUE_OPERATOR_NAME)) {
+            chainTransformation(inputTransformation, transformation);
+            configForwardPartitioner(inputTransformation, transformation);
+        }
+        if (transformName.equals(STREAM_KEY_BY_MAP_OPERATOR_NAME)
+                || 
transformName.equals(STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME)) {
+            chainTransformation(transformation, inputTransformation);
+            configForwardPartitioner(inputTransformation, transformation);
+        }
+    }
+
+    private static void chainTransformation(
+            Transformation<?> firstTransformation, Transformation<?> 
secondTransformation) {
+        
firstTransformation.setSlotSharingGroup(secondTransformation.getSlotSharingGroup());
+        
firstTransformation.setCoLocationGroupKey(secondTransformation.getCoLocationGroupKey());
+        
firstTransformation.setParallelism(secondTransformation.getParallelism());
+    }
+
+    private static void configForwardPartitioner(
+            Transformation<?> upTransformation, Transformation<?> 
transformation)
+            throws IllegalAccessException, NoSuchFieldException {
+        // set ForwardPartitioner
+        PartitionTransformation<?> partitionTransform =
+                new PartitionTransformation<>(upTransformation, new 
ForwardPartitioner<>());
+        Field inputTransformationField = 
transformation.getClass().getDeclaredField("input");
+        inputTransformationField.setAccessible(true);
+        inputTransformationField.set(transformation, partitionTransform);
+    }
+
+    private static AbstractPythonFunctionOperator<?> getPythonOperator(
+            Transformation<?> transformation) {
+        StreamOperatorFactory<?> operatorFactory = null;
+        if (transformation instanceof OneInputTransformation) {
+            operatorFactory = ((OneInputTransformation<?, ?>) 
transformation).getOperatorFactory();
+        } else if (transformation instanceof TwoInputTransformation) {
+            operatorFactory =
+                    ((TwoInputTransformation<?, ?, ?>) 
transformation).getOperatorFactory();
+        } else if (transformation instanceof 
AbstractMultipleInputTransformation) {
+            operatorFactory =
+                    ((AbstractMultipleInputTransformation<?>) 
transformation).getOperatorFactory();
+        }
+
+        if (operatorFactory instanceof SimpleOperatorFactory
+                && ((SimpleOperatorFactory<?>) operatorFactory).getOperator()
+                        instanceof AbstractPythonFunctionOperator) {
+            return (AbstractPythonFunctionOperator<?>)
+                    ((SimpleOperatorFactory<?>) operatorFactory).getOperator();
+        }
+
+        return null;
     }
 
     private static boolean isPythonOperator(StreamOperatorFactory 
streamOperatorFactory) {
@@ -243,27 +266,6 @@ public class PythonConfigUtil {
         }
     }
 
-    private static void setStreamPartitionCustomOperatorNumPartitions(
-            Collection<StreamNode> streamNodes, StreamGraph streamGraph) {
-        for (StreamNode streamNode : streamNodes) {
-            StreamOperatorFactory streamOperatorFactory = 
streamNode.getOperatorFactory();
-            if (streamOperatorFactory instanceof SimpleOperatorFactory) {
-                StreamOperator streamOperator =
-                        ((SimpleOperatorFactory) 
streamOperatorFactory).getOperator();
-                if (streamOperator instanceof PythonPartitionCustomOperator) {
-                    PythonPartitionCustomOperator 
partitionCustomFunctionOperator =
-                            (PythonPartitionCustomOperator) streamOperator;
-                    // Update the numPartitions of PartitionCustomOperator 
after aligned all
-                    // operators.
-                    partitionCustomFunctionOperator.setNumPartitions(
-                            streamGraph
-                                    
.getStreamNode(streamNode.getOutEdges().get(0).getTargetId())
-                                    .getParallelism());
-                }
-            }
-        }
-    }
-
     /**
      * Generator a new {@link PythonConfig} with the combined config which is 
derived from
      * oldConfig.
@@ -331,4 +333,23 @@ public class PythonConfigUtil {
             throw new TableException("Method getMergedConfig failed.", e);
         }
     }
+
+    private static void setPartitionCustomOperatorNumPartitions(
+            List<Transformation<?>> transformations) {
+        // Update the numPartitions of PartitionCustomOperator after aligned 
all operators.
+        for (Transformation<?> transformation : transformations) {
+            Transformation<?> firstInputTransformation = 
transformation.getInputs().get(0);
+            if (firstInputTransformation instanceof PartitionTransformation) {
+                firstInputTransformation = 
firstInputTransformation.getInputs().get(0);
+            }
+            AbstractPythonFunctionOperator<?> pythonFunctionOperator =
+                    getPythonOperator(firstInputTransformation);
+            if (pythonFunctionOperator instanceof 
PythonPartitionCustomOperator) {
+                PythonPartitionCustomOperator<?, ?> 
partitionCustomFunctionOperator =
+                        (PythonPartitionCustomOperator<?, ?>) 
pythonFunctionOperator;
+
+                
partitionCustomFunctionOperator.setNumPartitions(transformation.getParallelism());
+            }
+        }
+    }
 }

Reply via email to