This is an automated email from the ASF dual-hosted git repository.
dianfu pushed a commit to branch release-1.13
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.13 by this push:
new 9f40251 [FLINK-22348][python] Fix DataStream.execute_and_collect which doesn't declare managed memory for Python operators
9f40251 is described below
commit 9f4025123d7f9da2d05cafea5b7cd0f186dab467
Author: huangxingbo <[email protected]>
AuthorDate: Mon Apr 19 16:10:23 2021 +0800
[FLINK-22348][python] Fix DataStream.execute_and_collect which doesn't declare managed memory for Python operators
This closes #15665.
---
flink-python/pyflink/datastream/data_stream.py | 2 +
.../pyflink/datastream/tests/test_data_stream.py | 10 +-
.../tests/test_stream_execution_environment.py | 42 ++-
flink-python/pyflink/table/table_environment.py | 4 +-
.../table/tests/test_table_environment_api.py | 2 +-
flink-python/pyflink/util/java_utils.py | 8 +-
.../java/org/apache/flink/python/PythonConfig.java | 8 +-
.../apache/flink/python/util/PythonConfigUtil.java | 301 +++++++++++----------
8 files changed, 218 insertions(+), 159 deletions(-)
diff --git a/flink-python/pyflink/datastream/data_stream.py
b/flink-python/pyflink/datastream/data_stream.py
index 6cebafd..ce85de9 100644
--- a/flink-python/pyflink/datastream/data_stream.py
+++ b/flink-python/pyflink/datastream/data_stream.py
@@ -641,6 +641,8 @@ class DataStream(object):
:param job_execution_name: The name of the job execution.
:param limit: The limit for the collected elements.
"""
+ JPythonConfigUtil =
get_gateway().jvm.org.apache.flink.python.util.PythonConfigUtil
+
JPythonConfigUtil.configPythonOperator(self._j_data_stream.getExecutionEnvironment())
if job_execution_name is None and limit is None:
return CloseableIterator(self._j_data_stream.executeAndCollect(),
self.get_type())
elif job_execution_name is not None and limit is None:
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py
b/flink-python/pyflink/datastream/tests/test_data_stream.py
index 9d0543a..994cd60 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -387,11 +387,11 @@ class DataStreamTests(object):
decimal.Decimal('2000000000000000000.061111111111111'
'11111111111111'))]
expected = test_data
- ds = self.env.from_collection(test_data)
+ ds = self.env.from_collection(test_data).map(lambda a: a)
with ds.execute_and_collect() as results:
- actual = []
- for result in results:
- actual.append(result)
+ actual = [result for result in results]
+ actual.sort()
+ expected.sort()
self.assertEqual(expected, actual)
def test_key_by_map(self):
@@ -942,7 +942,7 @@ class StreamingModeDataStreamTests(DataStreamTests,
PyFlinkStreamingTestCase):
expected_num_partitions = 5
def my_partitioner(key, num_partitions):
- assert expected_num_partitions, num_partitions
+ assert expected_num_partitions == num_partitions
return key % num_partitions
partitioned_stream = ds.map(lambda x: x,
output_type=Types.ROW([Types.STRING(),
diff --git
a/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
b/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
index ffc54ae..5df261e 100644
--- a/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
+++ b/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py
@@ -38,9 +38,11 @@ from pyflink.datastream.tests.test_util import
DataStreamTestSinkFunction
from pyflink.find_flink_home import _find_flink_source_root
from pyflink.java_gateway import get_gateway
from pyflink.pyflink_gateway_server import on_windows
-from pyflink.table import DataTypes, CsvTableSource, CsvTableSink,
StreamTableEnvironment
+from pyflink.table import DataTypes, CsvTableSource, CsvTableSink,
StreamTableEnvironment, \
+ EnvironmentSettings
from pyflink.testing.test_case_utils import PyFlinkTestCase,
exec_insert_table, \
invoke_java_object_method
+from pyflink.util.java_utils import get_j_env_configuration
class StreamExecutionEnvironmentTests(PyFlinkTestCase):
@@ -337,20 +339,48 @@ class StreamExecutionEnvironmentTests(PyFlinkTestCase):
import uuid
python_file_dir = os.path.join(self.tempdir, "python_file_dir_" +
str(uuid.uuid4()))
os.mkdir(python_file_dir)
- python_file_path = os.path.join(python_file_dir,
"test_stream_dependency_manage_lib.py")
+ python_file_path = os.path.join(python_file_dir, "test_dep1.py")
with open(python_file_path, 'w') as f:
f.write("def add_two(a):\n return a + 2")
def plus_two_map(value):
- from test_stream_dependency_manage_lib import add_two
+ from test_dep1 import add_two
return add_two(value)
+ get_j_env_configuration(self.env._j_stream_execution_environment).\
+ setString("taskmanager.numberOfTaskSlots", "10")
self.env.add_python_file(python_file_path)
ds = self.env.from_collection([1, 2, 3, 4, 5])
- ds.map(plus_two_map).add_sink(self.test_sink)
- self.env.execute("test add python file")
+ ds = ds.map(plus_two_map, Types.LONG()) \
+ .slot_sharing_group("data_stream") \
+ .map(lambda i: i, Types.LONG()) \
+ .slot_sharing_group("table")
+
+ python_file_path = os.path.join(python_file_dir, "test_dep2.py")
+ with open(python_file_path, 'w') as f:
+ f.write("def add_three(a):\n return a + 3")
+
+ def plus_three(value):
+ from test_dep2 import add_three
+ return add_three(value)
+
+ t_env = StreamTableEnvironment.create(
+ stream_execution_environment=self.env,
+
environment_settings=EnvironmentSettings.new_instance().use_blink_planner().build())
+ self.env.add_python_file(python_file_path)
+
+ from pyflink.table.udf import udf
+ from pyflink.table.expressions import col
+ add_three = udf(plus_three, result_type=DataTypes.BIGINT())
+
+ tab = t_env.from_data_stream(ds, 'a') \
+ .select(add_three(col('a')))
+ t_env.to_append_stream(tab, Types.ROW([Types.LONG()])) \
+ .map(lambda i: i[0]) \
+ .add_sink(self.test_sink)
+ self.env.execute("test add_python_file")
result = self.test_sink.get_results(True)
- expected = ['3', '4', '5', '6', '7']
+ expected = ['6', '7', '8', '9', '10']
result.sort()
expected.sort()
self.assertEqual(expected, result)
diff --git a/flink-python/pyflink/table/table_environment.py
b/flink-python/pyflink/table/table_environment.py
index ee43076..2114dad 100644
--- a/flink-python/pyflink/table/table_environment.py
+++ b/flink-python/pyflink/table/table_environment.py
@@ -1548,7 +1548,7 @@ class TableEnvironment(object):
def _set_python_executable_for_local_executor(self):
jvm = get_gateway().jvm
- j_config = get_j_env_configuration(self)
+ j_config = get_j_env_configuration(self._get_j_env())
if not j_config.containsKey(jvm.PythonOptions.PYTHON_EXECUTABLE.key())
\
and is_local_deployment(j_config):
j_config.setString(jvm.PythonOptions.PYTHON_EXECUTABLE.key(),
sys.executable)
@@ -1559,7 +1559,7 @@ class TableEnvironment(object):
if jar_urls is not None:
# normalize and remove duplicates
jar_urls_set = set([jvm.java.net.URL(url).toString() for url in
jar_urls.split(";")])
- j_configuration = get_j_env_configuration(self)
+ j_configuration = get_j_env_configuration(self._get_j_env())
if j_configuration.containsKey(config_key):
for url in j_configuration.getString(config_key,
"").split(";"):
jar_urls_set.add(url)
diff --git a/flink-python/pyflink/table/tests/test_table_environment_api.py
b/flink-python/pyflink/table/tests/test_table_environment_api.py
index 7c6778c..3774178 100644
--- a/flink-python/pyflink/table/tests/test_table_environment_api.py
+++ b/flink-python/pyflink/table/tests/test_table_environment_api.py
@@ -51,7 +51,7 @@ class TableEnvironmentTest(object):
def test_set_sys_executable_for_local_mode(self):
jvm = get_gateway().jvm
- actual_executable = get_j_env_configuration(self.t_env) \
+ actual_executable = get_j_env_configuration(self.t_env._get_j_env()) \
.getString(jvm.PythonOptions.PYTHON_EXECUTABLE.key(), None)
self.assertEqual(sys.executable, actual_executable)
diff --git a/flink-python/pyflink/util/java_utils.py
b/flink-python/pyflink/util/java_utils.py
index 3ffe223..8ea8d9b 100644
--- a/flink-python/pyflink/util/java_utils.py
+++ b/flink-python/pyflink/util/java_utils.py
@@ -80,12 +80,12 @@ def is_instance_of(java_object, java_class):
param, java_object)
-def get_j_env_configuration(t_env):
- if is_instance_of(t_env._get_j_env(),
"org.apache.flink.api.java.ExecutionEnvironment"):
- return t_env._get_j_env().getConfiguration()
+def get_j_env_configuration(j_env):
+ if is_instance_of(j_env, "org.apache.flink.api.java.ExecutionEnvironment"):
+ return j_env.getConfiguration()
else:
return invoke_method(
- t_env._get_j_env(),
+ j_env,
"org.apache.flink.streaming.api.environment.StreamExecutionEnvironment",
"getConfiguration"
)
diff --git
a/flink-python/src/main/java/org/apache/flink/python/PythonConfig.java
b/flink-python/src/main/java/org/apache/flink/python/PythonConfig.java
index d8d4144..9fa8383 100644
--- a/flink-python/src/main/java/org/apache/flink/python/PythonConfig.java
+++ b/flink-python/src/main/java/org/apache/flink/python/PythonConfig.java
@@ -84,10 +84,10 @@ public class PythonConfig implements Serializable {
private final boolean isUsingManagedMemory;
/** The Configuration that contains execution configs and dependencies
info. */
- private final Configuration mergedConfig;
+ private final Configuration config;
public PythonConfig(Configuration config) {
- mergedConfig = config;
+ this.config = config;
maxBundleSize = config.get(PythonOptions.MAX_BUNDLE_SIZE);
maxBundleTimeMills = config.get(PythonOptions.MAX_BUNDLE_TIME_MILLS);
maxArrowBatchSize = config.get(PythonOptions.MAX_ARROW_BATCH_SIZE);
@@ -148,7 +148,7 @@ public class PythonConfig implements Serializable {
return isUsingManagedMemory;
}
- public Configuration getMergedConfig() {
- return mergedConfig;
+ public Configuration getConfig() {
+ return config;
}
}
diff --git
a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
index 6681c1f..cc65713 100644
---
a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
+++
b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java
@@ -30,20 +30,16 @@ import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.python.PythonConfig;
import org.apache.flink.python.PythonOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.streaming.api.graph.StreamEdge;
import org.apache.flink.streaming.api.graph.StreamGraph;
-import org.apache.flink.streaming.api.graph.StreamNode;
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
-import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import
org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator;
import
org.apache.flink.streaming.api.operators.python.OneInputPythonFunctionOperator;
-import
org.apache.flink.streaming.api.operators.python.PythonKeyedProcessOperator;
import
org.apache.flink.streaming.api.operators.python.PythonPartitionCustomOperator;
import
org.apache.flink.streaming.api.operators.python.PythonTimestampsAndWatermarksOperator;
-import
org.apache.flink.streaming.api.operators.python.TwoInputPythonFunctionOperator;
import
org.apache.flink.streaming.api.transformations.AbstractMultipleInputTransformation;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
+import org.apache.flink.streaming.api.transformations.PartitionTransformation;
import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
import org.apache.flink.streaming.api.transformations.WithBoundedness;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
@@ -53,7 +49,6 @@ import org.apache.flink.table.api.TableException;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
-import java.util.Collection;
import java.util.List;
/**
@@ -102,37 +97,6 @@ public class PythonConfigUtil {
return (Configuration) getConfigurationMethod.invoke(env);
}
- /**
- * Configure the {@link OneInputPythonFunctionOperator} to be chained with
the
- * upstream/downstream operator by setting their parallelism, slot sharing
group, co-location
- * group to be the same, and applying a {@link ForwardPartitioner}. 1.
operator with name
- * "_keyed_stream_values_operator" should align with its downstream
operator. 2. operator with
- * name "_stream_key_by_map_operator" should align with its upstream
operator.
- */
- private static void alignStreamNode(StreamNode streamNode, StreamGraph
streamGraph) {
- if
(streamNode.getOperatorName().equals(KEYED_STREAM_VALUE_OPERATOR_NAME)) {
- StreamEdge downStreamEdge = streamNode.getOutEdges().get(0);
- StreamNode downStreamNode =
streamGraph.getStreamNode(downStreamEdge.getTargetId());
- chainStreamNode(downStreamEdge, streamNode, downStreamNode);
- downStreamEdge.setPartitioner(new ForwardPartitioner());
- }
-
- if
(streamNode.getOperatorName().equals(STREAM_KEY_BY_MAP_OPERATOR_NAME)
- ||
streamNode.getOperatorName().equals(STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME))
{
- StreamEdge upStreamEdge = streamNode.getInEdges().get(0);
- StreamNode upStreamNode =
streamGraph.getStreamNode(upStreamEdge.getSourceId());
- chainStreamNode(upStreamEdge, streamNode, upStreamNode);
- }
- }
-
- private static void chainStreamNode(
- StreamEdge streamEdge, StreamNode firstStream, StreamNode
secondStream) {
- streamEdge.setPartitioner(new ForwardPartitioner<>());
- firstStream.setParallelism(secondStream.getParallelism());
- firstStream.setCoLocationGroup(secondStream.getCoLocationGroup());
- firstStream.setSlotSharingGroup(secondStream.getSlotSharingGroup());
- }
-
/** Set Python Operator Use Managed Memory. */
public static void declareManagedMemory(
Transformation<?> transformation,
@@ -144,16 +108,6 @@ public class PythonConfigUtil {
}
}
- private static void declareManagedMemory(Transformation<?> transformation)
{
- if (isPythonOperator(transformation)) {
-
transformation.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
- }
- List<Transformation<?>> inputTransformations =
transformation.getInputs();
- for (Transformation inputTransformation : inputTransformations) {
- declareManagedMemory(inputTransformation);
- }
- }
-
/**
* Generate a {@link StreamGraph} for transformations maintained by
current {@link
* StreamExecutionEnvironment}, and reset the merged env configurations
with dependencies to
@@ -165,98 +119,174 @@ public class PythonConfigUtil {
StreamExecutionEnvironment env, boolean clearTransformations)
throws IllegalAccessException, NoSuchMethodException,
InvocationTargetException,
NoSuchFieldException {
- Configuration mergedConfig = getEnvConfigWithDependencies(env);
-
- boolean executedInBatchMode = isExecuteInBatchMode(env, mergedConfig);
-
- if (mergedConfig.getBoolean(PythonOptions.USE_MANAGED_MEMORY)) {
- Field transformationsField =
-
StreamExecutionEnvironment.class.getDeclaredField("transformations");
- transformationsField.setAccessible(true);
- for (Transformation transform :
- (List<Transformation<?>>) transformationsField.get(env)) {
- if (isPythonOperator(transform)) {
-
transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
- }
- }
- }
+ configPythonOperator(env);
String jobName =
getEnvironmentConfig(env)
.getString(
PipelineOptions.NAME,
StreamExecutionEnvironment.DEFAULT_JOB_NAME);
- StreamGraph streamGraph = env.getStreamGraph(jobName,
clearTransformations);
- Collection<StreamNode> streamNodes = streamGraph.getStreamNodes();
- for (StreamNode streamNode : streamNodes) {
- alignStreamNode(streamNode, streamGraph);
- StreamOperatorFactory streamOperatorFactory =
streamNode.getOperatorFactory();
- if (streamOperatorFactory instanceof SimpleOperatorFactory) {
- StreamOperator streamOperator =
- ((SimpleOperatorFactory)
streamOperatorFactory).getOperator();
- if ((streamOperator instanceof OneInputPythonFunctionOperator)
- || (streamOperator instanceof
TwoInputPythonFunctionOperator)
- || (streamOperator instanceof
PythonKeyedProcessOperator)) {
- AbstractPythonFunctionOperator pythonFunctionOperator =
- (AbstractPythonFunctionOperator) streamOperator;
-
- Configuration oldConfig =
-
pythonFunctionOperator.getPythonConfig().getMergedConfig();
+ return env.getStreamGraph(jobName, clearTransformations);
+ }
+
+ @SuppressWarnings("unchecked")
+ public static void configPythonOperator(StreamExecutionEnvironment env)
+ throws IllegalAccessException, NoSuchMethodException,
InvocationTargetException,
+ NoSuchFieldException {
+ Configuration mergedConfig = getEnvConfigWithDependencies(env);
+
+ boolean executedInBatchMode = isExecuteInBatchMode(env, mergedConfig);
+
+ Field transformationsField =
+
StreamExecutionEnvironment.class.getDeclaredField("transformations");
+ transformationsField.setAccessible(true);
+ List<Transformation<?>> transformations =
+ (List<Transformation<?>>) transformationsField.get(env);
+ for (Transformation<?> transformation : transformations) {
+ alignTransformation(transformation);
+
+ if (isPythonOperator(transformation)) {
+ // declare it is a Python operator
+
transformation.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
+
+ AbstractPythonFunctionOperator<?> pythonFunctionOperator =
+ getPythonOperator(transformation);
+ if (pythonFunctionOperator != null) {
+ Configuration oldConfig =
pythonFunctionOperator.getPythonConfig().getConfig();
+ // update dependency related configurations for Python
operators
pythonFunctionOperator.setPythonConfig(
generateNewPythonConfig(oldConfig, mergedConfig));
- if (streamOperator instanceof
PythonTimestampsAndWatermarksOperator) {
- ((PythonTimestampsAndWatermarksOperator)
streamOperator)
+ // set the emitProgressiveWatermarks flag for
+ // PythonTimestampsAndWatermarksOperator
+ if (pythonFunctionOperator instanceof
PythonTimestampsAndWatermarksOperator) {
+ ((PythonTimestampsAndWatermarksOperator<?>)
pythonFunctionOperator)
.configureEmitProgressiveWatermarks(!executedInBatchMode);
}
}
}
}
- setStreamPartitionCustomOperatorNumPartitions(streamNodes,
streamGraph);
+ setPartitionCustomOperatorNumPartitions(transformations);
+ }
- return streamGraph;
+ public static Configuration getMergedConfig(
+ StreamExecutionEnvironment env, TableConfig tableConfig) {
+ try {
+ Configuration config = new
Configuration(getEnvironmentConfig(env));
+ PythonDependencyUtils.merge(config,
tableConfig.getConfiguration());
+ Configuration mergedConfig =
+
PythonDependencyUtils.configurePythonDependencies(env.getCachedFiles(), config);
+ mergedConfig.setString("table.exec.timezone",
tableConfig.getLocalTimeZone().getId());
+ return mergedConfig;
+ } catch (IllegalAccessException | NoSuchMethodException |
InvocationTargetException e) {
+ throw new TableException("Method getMergedConfig failed.", e);
+ }
}
- private static boolean isPythonOperator(StreamOperatorFactory
streamOperatorFactory) {
- if (streamOperatorFactory instanceof SimpleOperatorFactory) {
- return ((SimpleOperatorFactory)
streamOperatorFactory).getOperator()
- instanceof AbstractPythonFunctionOperator;
- } else {
- return false;
+ @SuppressWarnings("unchecked")
+ public static Configuration getMergedConfig(ExecutionEnvironment env,
TableConfig tableConfig) {
+ try {
+ Field field =
ExecutionEnvironment.class.getDeclaredField("cacheFile");
+ field.setAccessible(true);
+ Configuration config = new Configuration(env.getConfiguration());
+ PythonDependencyUtils.merge(config,
tableConfig.getConfiguration());
+ Configuration mergedConfig =
+ PythonDependencyUtils.configurePythonDependencies(
+ (List<Tuple2<String,
DistributedCache.DistributedCacheEntry>>)
+ field.get(env),
+ config);
+ mergedConfig.setString("table.exec.timezone",
tableConfig.getLocalTimeZone().getId());
+ return mergedConfig;
+ } catch (NoSuchFieldException | IllegalAccessException e) {
+ throw new TableException("Method getMergedConfig failed.", e);
+ }
+ }
+
+ /**
+ * Configure the {@link OneInputPythonFunctionOperator} to be chained with
the
+ * upstream/downstream operator by setting their parallelism, slot sharing
group, co-location
+ * group to be the same, and applying a {@link ForwardPartitioner}. 1.
operator with name
+ * "_keyed_stream_values_operator" should align with its downstream
operator. 2. operator with
+ * name "_stream_key_by_map_operator" should align with its upstream
operator.
+ */
+ private static void alignTransformation(Transformation<?> transformation)
+ throws NoSuchFieldException, IllegalAccessException {
+ String transformName = transformation.getName();
+ Transformation<?> inputTransformation =
transformation.getInputs().get(0);
+ String inputTransformName = inputTransformation.getName();
+ if (inputTransformName.equals(KEYED_STREAM_VALUE_OPERATOR_NAME)) {
+ chainTransformation(inputTransformation, transformation);
+ configForwardPartitioner(inputTransformation, transformation);
+ }
+ if (transformName.equals(STREAM_KEY_BY_MAP_OPERATOR_NAME)
+ ||
transformName.equals(STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME)) {
+ chainTransformation(transformation, inputTransformation);
+ configForwardPartitioner(inputTransformation, transformation);
+ }
+ }
+
+ private static void chainTransformation(
+ Transformation<?> firstTransformation, Transformation<?>
secondTransformation) {
+
firstTransformation.setSlotSharingGroup(secondTransformation.getSlotSharingGroup());
+
firstTransformation.setCoLocationGroupKey(secondTransformation.getCoLocationGroupKey());
+
firstTransformation.setParallelism(secondTransformation.getParallelism());
+ }
+
+ private static void configForwardPartitioner(
+ Transformation<?> upTransformation, Transformation<?>
transformation)
+ throws IllegalAccessException, NoSuchFieldException {
+ // set ForwardPartitioner
+ PartitionTransformation<?> partitionTransform =
+ new PartitionTransformation<>(upTransformation, new
ForwardPartitioner<>());
+ Field inputTransformationField =
transformation.getClass().getDeclaredField("input");
+ inputTransformationField.setAccessible(true);
+ inputTransformationField.set(transformation, partitionTransform);
+ }
+
+ private static AbstractPythonFunctionOperator<?> getPythonOperator(
+ Transformation<?> transformation) {
+ StreamOperatorFactory<?> operatorFactory = null;
+ if (transformation instanceof OneInputTransformation) {
+ operatorFactory = ((OneInputTransformation<?, ?>)
transformation).getOperatorFactory();
+ } else if (transformation instanceof TwoInputTransformation) {
+ operatorFactory =
+ ((TwoInputTransformation<?, ?, ?>)
transformation).getOperatorFactory();
+ } else if (transformation instanceof
AbstractMultipleInputTransformation) {
+ operatorFactory =
+ ((AbstractMultipleInputTransformation<?>)
transformation).getOperatorFactory();
+ }
+
+ if (operatorFactory instanceof SimpleOperatorFactory
+ && ((SimpleOperatorFactory<?>) operatorFactory).getOperator()
+ instanceof AbstractPythonFunctionOperator) {
+ return (AbstractPythonFunctionOperator<?>)
+ ((SimpleOperatorFactory<?>) operatorFactory).getOperator();
}
+
+ return null;
}
private static boolean isPythonOperator(Transformation<?> transform) {
if (transform instanceof OneInputTransformation) {
- return isPythonOperator(((OneInputTransformation)
transform).getOperatorFactory());
+ return isPythonOperator(
+ ((OneInputTransformation<?, ?>)
transform).getOperatorFactory());
} else if (transform instanceof TwoInputTransformation) {
- return isPythonOperator(((TwoInputTransformation)
transform).getOperatorFactory());
+ return isPythonOperator(
+ ((TwoInputTransformation<?, ?, ?>)
transform).getOperatorFactory());
} else if (transform instanceof AbstractMultipleInputTransformation) {
return isPythonOperator(
- ((AbstractMultipleInputTransformation)
transform).getOperatorFactory());
+ ((AbstractMultipleInputTransformation<?>)
transform).getOperatorFactory());
} else {
return false;
}
}
- private static void setStreamPartitionCustomOperatorNumPartitions(
- Collection<StreamNode> streamNodes, StreamGraph streamGraph) {
- for (StreamNode streamNode : streamNodes) {
- StreamOperatorFactory streamOperatorFactory =
streamNode.getOperatorFactory();
- if (streamOperatorFactory instanceof SimpleOperatorFactory) {
- StreamOperator streamOperator =
- ((SimpleOperatorFactory)
streamOperatorFactory).getOperator();
- if (streamOperator instanceof PythonPartitionCustomOperator) {
- PythonPartitionCustomOperator
partitionCustomFunctionOperator =
- (PythonPartitionCustomOperator) streamOperator;
- // Update the numPartitions of PartitionCustomOperator
after aligned all
- // operators.
- partitionCustomFunctionOperator.setNumPartitions(
- streamGraph
-
.getStreamNode(streamNode.getOutEdges().get(0).getTargetId())
- .getParallelism());
- }
- }
+ private static boolean isPythonOperator(StreamOperatorFactory<?>
streamOperatorFactory) {
+ if (streamOperatorFactory instanceof SimpleOperatorFactory) {
+ return ((SimpleOperatorFactory<?>)
streamOperatorFactory).getOperator()
+ instanceof AbstractPythonFunctionOperator;
+ } else {
+ return false;
}
}
@@ -285,7 +315,8 @@ public class PythonConfigUtil {
StreamExecutionEnvironment.class.getDeclaredField("transformations");
transformationsField.setAccessible(true);
boolean existsUnboundedSource = false;
- for (Transformation transform : (List<Transformation<?>>)
transformationsField.get(env)) {
+ for (Transformation<?> transform :
+ (List<Transformation<?>>) transformationsField.get(env)) {
existsUnboundedSource =
existsUnboundedSource
|| (transform instanceof WithBoundedness
@@ -295,36 +326,32 @@ public class PythonConfigUtil {
return !existsUnboundedSource;
}
- public static Configuration getMergedConfig(
- StreamExecutionEnvironment env, TableConfig tableConfig) {
- try {
- Configuration config = new
Configuration(getEnvironmentConfig(env));
- PythonDependencyUtils.merge(config,
tableConfig.getConfiguration());
- Configuration mergedConfig =
-
PythonDependencyUtils.configurePythonDependencies(env.getCachedFiles(), config);
- mergedConfig.setString("table.exec.timezone",
tableConfig.getLocalTimeZone().getId());
- return mergedConfig;
- } catch (IllegalAccessException | NoSuchMethodException |
InvocationTargetException e) {
- throw new TableException("Method getMergedConfig failed.", e);
+ private static void declareManagedMemory(Transformation<?> transformation)
{
+ if (isPythonOperator(transformation)) {
+
transformation.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
+ }
+
+ for (Transformation<?> inputTransformation :
transformation.getInputs()) {
+ declareManagedMemory(inputTransformation);
}
}
- @SuppressWarnings("unchecked")
- public static Configuration getMergedConfig(ExecutionEnvironment env,
TableConfig tableConfig) {
- try {
- Field field =
ExecutionEnvironment.class.getDeclaredField("cacheFile");
- field.setAccessible(true);
- Configuration config = new Configuration(env.getConfiguration());
- PythonDependencyUtils.merge(config,
tableConfig.getConfiguration());
- Configuration mergedConfig =
- PythonDependencyUtils.configurePythonDependencies(
- (List<Tuple2<String,
DistributedCache.DistributedCacheEntry>>)
- field.get(env),
- config);
- mergedConfig.setString("table.exec.timezone",
tableConfig.getLocalTimeZone().getId());
- return mergedConfig;
- } catch (NoSuchFieldException | IllegalAccessException e) {
- throw new TableException("Method getMergedConfig failed.", e);
+ private static void setPartitionCustomOperatorNumPartitions(
+ List<Transformation<?>> transformations) {
+ // Update the numPartitions of PartitionCustomOperator after aligned
all operators.
+ for (Transformation<?> transformation : transformations) {
+ Transformation<?> firstInputTransformation =
transformation.getInputs().get(0);
+ if (firstInputTransformation instanceof PartitionTransformation) {
+ firstInputTransformation =
firstInputTransformation.getInputs().get(0);
+ }
+ AbstractPythonFunctionOperator<?> pythonFunctionOperator =
+ getPythonOperator(firstInputTransformation);
+ if (pythonFunctionOperator instanceof
PythonPartitionCustomOperator) {
+ PythonPartitionCustomOperator<?, ?>
partitionCustomFunctionOperator =
+ (PythonPartitionCustomOperator<?, ?>)
pythonFunctionOperator;
+
+
partitionCustomFunctionOperator.setNumPartitions(transformation.getParallelism());
+ }
}
}
}