This is an automated email from the ASF dual-hosted git repository.
dianfu pushed a commit to branch release-1.12
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.12 by this push:
new 91d9004 [FLINK-23133][python] Properly handle the dependencies when
mixing use of Python Table API and Python DataStream API
91d9004 is described below
commit 91d9004a4b41f3f7c499d410c1dd09299dd44b25
Author: Dian Fu <[email protected]>
AuthorDate: Thu Jun 24 15:26:11 2021 +0800
[FLINK-23133][python] Properly handle the dependencies when mixing use of
Python Table API and Python DataStream API
This closes #16272.
---
docs/dev/python/dependency_management.md | 5 +
docs/dev/python/dependency_management.zh.md | 5 +
.../tests/test_stream_execution_environment.py | 43 ++++-
flink-python/pyflink/table/table_environment.py | 5 +-
.../apache/flink/python/util/PythonConfigUtil.java | 209 ++++++++++++---------
5 files changed, 168 insertions(+), 99 deletions(-)
diff --git a/docs/dev/python/dependency_management.md b/docs/dev/python/dependency_management.md
index 5174af7..813878e 100644
--- a/docs/dev/python/dependency_management.md
+++ b/docs/dev/python/dependency_management.md
@@ -30,6 +30,11 @@ the local Python environment, download the machine learning
model to local, etc.
However, this approach doesn't work well when users want to submit the PyFlink
jobs to remote clusters.
In the following sections, we will introduce the options provided in PyFlink
for these requirements.
+<span class="label label-info">Note</span> Both the Python DataStream API and the Python Table API provide
+APIs for each kind of dependency. If you mix the Python DataStream API and the Python Table API
+in a single job, you should specify the dependencies via the Python DataStream API so that they take effect for
+both the Python DataStream API and the Python Table API.
+
* This will be replaced by the TOC
{:toc}
diff --git a/docs/dev/python/dependency_management.zh.md b/docs/dev/python/dependency_management.zh.md
index 9da533e..8e40d11 100644
--- a/docs/dev/python/dependency_management.zh.md
+++ b/docs/dev/python/dependency_management.zh.md
@@ -30,6 +30,11 @@ the local Python environment, download the machine learning
model to local, etc.
However, this approach doesn't work well when users want to submit the PyFlink
jobs to remote clusters.
In the following sections, we will introduce the options provided in PyFlink
for these requirements.
+<span class="label label-info">Note</span> Both the Python DataStream API and the Python Table API provide
+APIs for each kind of dependency. If you mix the Python DataStream API and the Python Table API
+in a single job, you should specify the dependencies via the Python DataStream API so that they take effect for
+both the Python DataStream API and the Python Table API.
+
* This will be replaced by the TOC
{:toc}
diff --git a/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py b/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
index f4756ed..34677f9 100644
--- a/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
+++ b/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
@@ -37,7 +37,8 @@ from pyflink.datastream.tests.test_util import
DataStreamTestSinkFunction
from pyflink.find_flink_home import _find_flink_source_root
from pyflink.java_gateway import get_gateway
from pyflink.pyflink_gateway_server import on_windows
-from pyflink.table import DataTypes, CsvTableSource, CsvTableSink,
StreamTableEnvironment
+from pyflink.table import DataTypes, CsvTableSource, CsvTableSink,
StreamTableEnvironment, \
+ EnvironmentSettings
from pyflink.testing.test_case_utils import PyFlinkTestCase, exec_insert_table
@@ -350,6 +351,46 @@ class StreamExecutionEnvironmentTests(PyFlinkTestCase):
expected.sort()
self.assertEqual(expected, result)
+ def test_add_python_file_2(self):
+ import uuid
+ python_file_dir = os.path.join(self.tempdir, "python_file_dir_" +
str(uuid.uuid4()))
+ os.mkdir(python_file_dir)
+ python_file_path = os.path.join(python_file_dir, "test_dep1.py")
+ with open(python_file_path, 'w') as f:
+ f.write("def add_two(a):\n return a + 2")
+
+ def plus_two_map(value):
+ from test_dep1 import add_two
+ return add_two(value)
+
+ self.env.add_python_file(python_file_path)
+ ds = self.env.from_collection([1, 2, 3, 4, 5])
+ ds = ds.map(plus_two_map, Types.LONG())
+ python_file_path = os.path.join(python_file_dir, "test_dep2.py")
+ with open(python_file_path, 'w') as f:
+ f.write("def add_three(a):\n return a + 3")
+
+ def plus_three(value):
+ from test_dep2 import add_three
+ return add_three(value)
+
+ t_env = StreamTableEnvironment.create(
+ stream_execution_environment=self.env,
+
environment_settings=EnvironmentSettings.new_instance().use_blink_planner().build())
+ self.env.add_python_file(python_file_path)
+
+ from pyflink.table.udf import udf
+ from pyflink.table.expressions import col
+ add_three = udf(plus_three, result_type=DataTypes.BIGINT())
+
+ tab = t_env.from_data_stream(ds, 'a') \
+ .select(add_three(col('a')))
+ result = [i[0] for i in tab.execute().collect()]
+ expected = [6, 7, 8, 9, 10]
+ result.sort()
+ expected.sort()
+ self.assertEqual(expected, result)
+
def test_set_requirements_without_cached_directory(self):
import uuid
requirements_txt_path = os.path.join(self.tempdir, str(uuid.uuid4()))
diff --git a/flink-python/pyflink/table/table_environment.py b/flink-python/pyflink/table/table_environment.py
index 0938069..fd7b9e5 100644
--- a/flink-python/pyflink/table/table_environment.py
+++ b/flink-python/pyflink/table/table_environment.py
@@ -1731,10 +1731,7 @@ class StreamTableEnvironment(TableEnvironment):
"""
j_data_stream = data_stream._j_data_stream
JPythonConfigUtil =
get_gateway().jvm.org.apache.flink.python.util.PythonConfigUtil
- JPythonConfigUtil.declareManagedMemory(
- j_data_stream.getTransformation(),
- self._get_j_env(),
- self._j_tenv.getConfig())
+
JPythonConfigUtil.configPythonOperator(j_data_stream.getExecutionEnvironment())
if len(fields) == 0:
return Table(j_table=self._j_tenv.fromDataStream(j_data_stream),
t_env=self)
elif all(isinstance(f, Expression) for f in fields):
diff --git a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
index 78818c3..49960d8 100644
--- a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
+++ b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
@@ -30,20 +30,16 @@ import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.python.PythonConfig;
import org.apache.flink.python.PythonOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.streaming.api.graph.StreamEdge;
import org.apache.flink.streaming.api.graph.StreamGraph;
-import org.apache.flink.streaming.api.graph.StreamNode;
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
-import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import
org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator;
import
org.apache.flink.streaming.api.operators.python.OneInputPythonFunctionOperator;
-import
org.apache.flink.streaming.api.operators.python.PythonKeyedProcessOperator;
import
org.apache.flink.streaming.api.operators.python.PythonPartitionCustomOperator;
import
org.apache.flink.streaming.api.operators.python.PythonTimestampsAndWatermarksOperator;
-import
org.apache.flink.streaming.api.operators.python.TwoInputPythonFunctionOperator;
import
org.apache.flink.streaming.api.transformations.AbstractMultipleInputTransformation;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
+import org.apache.flink.streaming.api.transformations.PartitionTransformation;
import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
import org.apache.flink.streaming.api.transformations.WithBoundedness;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
@@ -53,7 +49,6 @@ import org.apache.flink.table.api.TableException;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
-import java.util.Collection;
import java.util.List;
/**
@@ -102,37 +97,6 @@ public class PythonConfigUtil {
return (Configuration) getConfigurationMethod.invoke(env);
}
- /**
- * Configure the {@link OneInputPythonFunctionOperator} to be chained with
the
- * upstream/downstream operator by setting their parallelism, slot sharing
group, co-location
- * group to be the same, and applying a {@link ForwardPartitioner}. 1.
operator with name
- * "_keyed_stream_values_operator" should align with its downstream
operator. 2. operator with
- * name "_stream_key_by_map_operator" should align with its upstream
operator.
- */
- private static void alignStreamNode(StreamNode streamNode, StreamGraph
streamGraph) {
- if
(streamNode.getOperatorName().equals(KEYED_STREAM_VALUE_OPERATOR_NAME)) {
- StreamEdge downStreamEdge = streamNode.getOutEdges().get(0);
- StreamNode downStreamNode =
streamGraph.getStreamNode(downStreamEdge.getTargetId());
- chainStreamNode(downStreamEdge, streamNode, downStreamNode);
- downStreamEdge.setPartitioner(new ForwardPartitioner());
- }
-
- if
(streamNode.getOperatorName().equals(STREAM_KEY_BY_MAP_OPERATOR_NAME)
- ||
streamNode.getOperatorName().equals(STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME))
{
- StreamEdge upStreamEdge = streamNode.getInEdges().get(0);
- StreamNode upStreamNode =
streamGraph.getStreamNode(upStreamEdge.getSourceId());
- chainStreamNode(upStreamEdge, streamNode, upStreamNode);
- }
- }
-
- private static void chainStreamNode(
- StreamEdge streamEdge, StreamNode firstStream, StreamNode
secondStream) {
- streamEdge.setPartitioner(new ForwardPartitioner<>());
- firstStream.setParallelism(secondStream.getParallelism());
- firstStream.setCoLocationGroup(secondStream.getCoLocationGroup());
- firstStream.setSlotSharingGroup(secondStream.getSlotSharingGroup());
- }
-
/** Set Python Operator Use Managed Memory. */
public static void declareManagedMemory(
Transformation<?> transformation,
@@ -165,60 +129,119 @@ public class PythonConfigUtil {
StreamExecutionEnvironment env, boolean clearTransformations)
throws IllegalAccessException, NoSuchMethodException,
InvocationTargetException,
NoSuchFieldException {
- Configuration mergedConfig = getEnvConfigWithDependencies(env);
-
- boolean executedInBatchMode = isExecuteInBatchMode(env, mergedConfig);
- if (executedInBatchMode) {
- throw new UnsupportedOperationException(
- "Batch mode is still not supported in Python DataStream
API.");
- }
-
- if (mergedConfig.getBoolean(PythonOptions.USE_MANAGED_MEMORY)) {
- Field transformationsField =
-
StreamExecutionEnvironment.class.getDeclaredField("transformations");
- transformationsField.setAccessible(true);
- for (Transformation transform :
- (List<Transformation<?>>) transformationsField.get(env)) {
- if (isPythonOperator(transform)) {
-
transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
- }
- }
- }
+ configPythonOperator(env);
String jobName =
getEnvironmentConfig(env)
.getString(
PipelineOptions.NAME,
StreamExecutionEnvironment.DEFAULT_JOB_NAME);
- StreamGraph streamGraph = env.getStreamGraph(jobName,
clearTransformations);
- Collection<StreamNode> streamNodes = streamGraph.getStreamNodes();
- for (StreamNode streamNode : streamNodes) {
- alignStreamNode(streamNode, streamGraph);
- StreamOperatorFactory streamOperatorFactory =
streamNode.getOperatorFactory();
- if (streamOperatorFactory instanceof SimpleOperatorFactory) {
- StreamOperator streamOperator =
- ((SimpleOperatorFactory)
streamOperatorFactory).getOperator();
- if ((streamOperator instanceof OneInputPythonFunctionOperator)
- || (streamOperator instanceof
TwoInputPythonFunctionOperator)
- || (streamOperator instanceof
PythonKeyedProcessOperator)) {
- AbstractPythonFunctionOperator pythonFunctionOperator =
- (AbstractPythonFunctionOperator) streamOperator;
+ return env.getStreamGraph(jobName, clearTransformations);
+ }
+
+ @SuppressWarnings("unchecked")
+ public static void configPythonOperator(StreamExecutionEnvironment env)
+ throws IllegalAccessException, NoSuchMethodException,
InvocationTargetException,
+ NoSuchFieldException {
+ Configuration mergedConfig = getEnvConfigWithDependencies(env);
+
+ boolean executedInBatchMode = isExecuteInBatchMode(env, mergedConfig);
+
+ Field transformationsField =
+
StreamExecutionEnvironment.class.getDeclaredField("transformations");
+ transformationsField.setAccessible(true);
+ List<Transformation<?>> transformations =
+ (List<Transformation<?>>) transformationsField.get(env);
+ for (Transformation<?> transformation : transformations) {
+ alignTransformation(transformation);
+ if (isPythonOperator(transformation)) {
+ // declare it is a Python operator
+
transformation.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
+
+ AbstractPythonFunctionOperator<?> pythonFunctionOperator =
+ getPythonOperator(transformation);
+ if (pythonFunctionOperator != null) {
Configuration oldConfig =
pythonFunctionOperator.getPythonConfig().getMergedConfig();
+ // update dependency related configurations for Python
operators
pythonFunctionOperator.setPythonConfig(
generateNewPythonConfig(oldConfig, mergedConfig));
- if (streamOperator instanceof
PythonTimestampsAndWatermarksOperator) {
- ((PythonTimestampsAndWatermarksOperator)
streamOperator)
+ // set the emitProgressiveWatermarks flag for
+ // PythonTimestampsAndWatermarksOperator
+ if (pythonFunctionOperator instanceof
PythonTimestampsAndWatermarksOperator) {
+ ((PythonTimestampsAndWatermarksOperator<?>)
pythonFunctionOperator)
.configureEmitProgressiveWatermarks(!executedInBatchMode);
}
}
}
}
- setStreamPartitionCustomOperatorNumPartitions(streamNodes,
streamGraph);
+ setPartitionCustomOperatorNumPartitions(transformations);
+ }
- return streamGraph;
+ /**
+ * Configure the {@link OneInputPythonFunctionOperator} to be chained with the
+ * upstream/downstream operator by setting their parallelism, slot sharing group and co-location
+ * group to be the same, and applying a {@link ForwardPartitioner}: (1) the operator named
+ * "_keyed_stream_values_operator" is aligned with its downstream operator; (2) the operator
+ * named "_stream_key_by_map_operator" is aligned with its upstream operator.
+ */
+ private static void alignTransformation(Transformation<?> transformation)
+ throws NoSuchFieldException, IllegalAccessException {
+ String transformName = transformation.getName();
+ Transformation<?> inputTransformation =
transformation.getInputs().get(0);
+ String inputTransformName = inputTransformation.getName();
+ if (inputTransformName.equals(KEYED_STREAM_VALUE_OPERATOR_NAME)) {
+ chainTransformation(inputTransformation, transformation);
+ configForwardPartitioner(inputTransformation, transformation);
+ }
+ if (transformName.equals(STREAM_KEY_BY_MAP_OPERATOR_NAME)
+ ||
transformName.equals(STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME)) {
+ chainTransformation(transformation, inputTransformation);
+ configForwardPartitioner(inputTransformation, transformation);
+ }
+ }
+
+ private static void chainTransformation(
+ Transformation<?> firstTransformation, Transformation<?>
secondTransformation) {
+
firstTransformation.setSlotSharingGroup(secondTransformation.getSlotSharingGroup());
+
firstTransformation.setCoLocationGroupKey(secondTransformation.getCoLocationGroupKey());
+
firstTransformation.setParallelism(secondTransformation.getParallelism());
+ }
+
+ private static void configForwardPartitioner(
+ Transformation<?> upTransformation, Transformation<?>
transformation)
+ throws IllegalAccessException, NoSuchFieldException {
+ // set ForwardPartitioner
+ PartitionTransformation<?> partitionTransform =
+ new PartitionTransformation<>(upTransformation, new
ForwardPartitioner<>());
+ Field inputTransformationField =
transformation.getClass().getDeclaredField("input");
+ inputTransformationField.setAccessible(true);
+ inputTransformationField.set(transformation, partitionTransform);
+ }
+
+ private static AbstractPythonFunctionOperator<?> getPythonOperator(
+ Transformation<?> transformation) {
+ StreamOperatorFactory<?> operatorFactory = null;
+ if (transformation instanceof OneInputTransformation) {
+ operatorFactory = ((OneInputTransformation<?, ?>)
transformation).getOperatorFactory();
+ } else if (transformation instanceof TwoInputTransformation) {
+ operatorFactory =
+ ((TwoInputTransformation<?, ?, ?>)
transformation).getOperatorFactory();
+ } else if (transformation instanceof
AbstractMultipleInputTransformation) {
+ operatorFactory =
+ ((AbstractMultipleInputTransformation<?>)
transformation).getOperatorFactory();
+ }
+
+ if (operatorFactory instanceof SimpleOperatorFactory
+ && ((SimpleOperatorFactory<?>) operatorFactory).getOperator()
+ instanceof AbstractPythonFunctionOperator) {
+ return (AbstractPythonFunctionOperator<?>)
+ ((SimpleOperatorFactory<?>) operatorFactory).getOperator();
+ }
+
+ return null;
}
private static boolean isPythonOperator(StreamOperatorFactory
streamOperatorFactory) {
@@ -243,27 +266,6 @@ public class PythonConfigUtil {
}
}
- private static void setStreamPartitionCustomOperatorNumPartitions(
- Collection<StreamNode> streamNodes, StreamGraph streamGraph) {
- for (StreamNode streamNode : streamNodes) {
- StreamOperatorFactory streamOperatorFactory =
streamNode.getOperatorFactory();
- if (streamOperatorFactory instanceof SimpleOperatorFactory) {
- StreamOperator streamOperator =
- ((SimpleOperatorFactory)
streamOperatorFactory).getOperator();
- if (streamOperator instanceof PythonPartitionCustomOperator) {
- PythonPartitionCustomOperator
partitionCustomFunctionOperator =
- (PythonPartitionCustomOperator) streamOperator;
- // Update the numPartitions of PartitionCustomOperator
after aligned all
- // operators.
- partitionCustomFunctionOperator.setNumPartitions(
- streamGraph
-
.getStreamNode(streamNode.getOutEdges().get(0).getTargetId())
- .getParallelism());
- }
- }
- }
- }
-
/**
* Generator a new {@link PythonConfig} with the combined config which is
derived from
* oldConfig.
@@ -331,4 +333,23 @@ public class PythonConfigUtil {
throw new TableException("Method getMergedConfig failed.", e);
}
}
+
+ private static void setPartitionCustomOperatorNumPartitions(
+ List<Transformation<?>> transformations) {
+ // Update the numPartitions of PartitionCustomOperator after aligned
all operators.
+ for (Transformation<?> transformation : transformations) {
+ Transformation<?> firstInputTransformation =
transformation.getInputs().get(0);
+ if (firstInputTransformation instanceof PartitionTransformation) {
+ firstInputTransformation =
firstInputTransformation.getInputs().get(0);
+ }
+ AbstractPythonFunctionOperator<?> pythonFunctionOperator =
+ getPythonOperator(firstInputTransformation);
+ if (pythonFunctionOperator instanceof
PythonPartitionCustomOperator) {
+ PythonPartitionCustomOperator<?, ?>
partitionCustomFunctionOperator =
+ (PythonPartitionCustomOperator<?, ?>)
pythonFunctionOperator;
+
+
partitionCustomFunctionOperator.setNumPartitions(transformation.getParallelism());
+ }
+ }
+ }
}