This is an automated email from the ASF dual-hosted git repository. dianfu pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new 77e5494 [FLINK-17923][python] Allow Python worker to use off-heap memory 77e5494 is described below commit 77e5494c1c252ba2dd458078380ee862fa423e4e Author: Dian Fu <dia...@apache.org> AuthorDate: Thu May 28 13:34:38 2020 +0800 [FLINK-17923][python] Allow Python worker to use off-heap memory This closes #12370. --- docs/_includes/generated/python_configuration.html | 6 +++ flink-python/pyflink/testing/test_case_utils.py | 8 +++ .../java/org/apache/flink/python/PythonConfig.java | 10 ++++ .../org/apache/flink/python/PythonOptions.java | 14 +++++ .../python/AbstractPythonFunctionOperator.java | 4 +- .../client/python/PythonFunctionFactoryTest.java | 3 ++ .../org/apache/flink/python/PythonConfigTest.java | 9 ++++ .../PythonScalarFunctionOperatorTestBase.java | 3 ++ .../plan/nodes/common/CommonPythonBase.scala | 61 ++++++++++++++++++++- .../nodes/physical/batch/BatchExecPythonCalc.scala | 21 ++------ .../physical/batch/BatchExecPythonCorrelate.scala | 9 +++- .../physical/stream/StreamExecPythonCalc.scala | 6 ++- .../stream/StreamExecPythonCorrelate.scala | 9 +++- .../flink/table/plan/nodes/CommonPythonBase.scala | 62 +++++++++++++++++++++- 14 files changed, 201 insertions(+), 24 deletions(-) diff --git a/docs/_includes/generated/python_configuration.html b/docs/_includes/generated/python_configuration.html index 967cb66..890d025 100644 --- a/docs/_includes/generated/python_configuration.html +++ b/docs/_includes/generated/python_configuration.html @@ -63,6 +63,12 @@ <td>The amount of memory to be allocated by the Python framework. The sum of the value of this configuration and "python.fn-execution.buffer.memory.size" represents the total memory of a Python worker. The memory will be accounted as managed memory if the actual memory allocated to an operator is no less than the total memory of a Python worker. Otherwise, this configuration takes no effect.</td> </tr> <tr> + <td><h5>python.fn-execution.memory.managed</h5></td> + <td style="word-wrap: break-word;">false</td> + <td>Boolean</td> + <td>If set, the Python worker will configure itself to use the managed memory budget of the task slot. Otherwise, it will use the Off-Heap Memory of the task slot. In this case, users should set the Task Off-Heap Memory using the configuration key taskmanager.memory.task.off-heap.size. For each Python worker, the required Task Off-Heap Memory is the sum of the value of python.fn-execution.framework.memory.size and python.fn-execution.buffer.memory.size.</td> + </tr> + <tr> <td><h5>python.metric.enabled</h5></td> <td style="word-wrap: break-word;">true</td> <td>Boolean</td> diff --git a/flink-python/pyflink/testing/test_case_utils.py b/flink-python/pyflink/testing/test_case_utils.py index 35e889a..ac562db 100644 --- a/flink-python/pyflink/testing/test_case_utils.py +++ b/flink-python/pyflink/testing/test_case_utils.py @@ -128,6 +128,8 @@ class PyFlinkStreamTableTestCase(PyFlinkTestCase): self.env, environment_settings=EnvironmentSettings.new_instance() .in_streaming_mode().use_old_planner().build()) + self.t_env.get_config().get_configuration().set_string( + "taskmanager.memory.task.off-heap.size", "80mb") class PyFlinkBatchTableTestCase(PyFlinkTestCase): @@ -140,6 +142,8 @@ class PyFlinkBatchTableTestCase(PyFlinkTestCase): self.env = ExecutionEnvironment.get_execution_environment() self.env.set_parallelism(2) self.t_env = BatchTableEnvironment.create(self.env, TableConfig()) + self.t_env.get_config().get_configuration().set_string( + "taskmanager.memory.task.off-heap.size", "80mb") def collect(self, table): j_table = table._j_table @@ -162,6 +166,8 @@ class PyFlinkBlinkStreamTableTestCase(PyFlinkTestCase): self.t_env = StreamTableEnvironment.create( self.env, environment_settings=EnvironmentSettings.new_instance() .in_streaming_mode().use_blink_planner().build()) + self.t_env.get_config().get_configuration().set_string( + "taskmanager.memory.task.off-heap.size", "80mb") class PyFlinkBlinkBatchTableTestCase(PyFlinkTestCase): @@ -174,6 +180,8 @@ class PyFlinkBlinkBatchTableTestCase(PyFlinkTestCase): self.t_env = BatchTableEnvironment.create( environment_settings=EnvironmentSettings.new_instance() .in_batch_mode().use_blink_planner().build()) + self.t_env.get_config().get_configuration().set_string( + "taskmanager.memory.task.off-heap.size", "80mb") self.t_env._j_tenv.getPlanner().getExecEnv().setParallelism(2) diff --git a/flink-python/src/main/java/org/apache/flink/python/PythonConfig.java b/flink-python/src/main/java/org/apache/flink/python/PythonConfig.java index 01a7b95..1bec9d4f 100644 --- a/flink-python/src/main/java/org/apache/flink/python/PythonConfig.java +++ b/flink-python/src/main/java/org/apache/flink/python/PythonConfig.java @@ -102,6 +102,11 @@ public class PythonConfig implements Serializable { */ private final boolean metricEnabled; + /** + * Whether to use managed memory for the Python worker. + */ + private final boolean isUsingManagedMemory; + public PythonConfig(Configuration config) { maxBundleSize = config.get(PythonOptions.MAX_BUNDLE_SIZE); maxBundleTimeMills = config.get(PythonOptions.MAX_BUNDLE_TIME_MILLS); @@ -118,6 +123,7 @@ public class PythonConfig implements Serializable { pythonArchivesInfo = config.getOptional(PythonDependencyUtils.PYTHON_ARCHIVES).orElse(new HashMap<>()); pythonExec = config.get(PythonOptions.PYTHON_EXECUTABLE); metricEnabled = config.getBoolean(PythonOptions.PYTHON_METRIC_ENABLED); + isUsingManagedMemory = config.getBoolean(PythonOptions.USE_MANAGED_MEMORY); } public int getMaxBundleSize() { @@ -163,4 +169,8 @@ public class PythonConfig implements Serializable { public boolean isMetricEnabled() { return metricEnabled; } + + public boolean isUsingManagedMemory() { + return isUsingManagedMemory; + } } diff --git a/flink-python/src/main/java/org/apache/flink/python/PythonOptions.java b/flink-python/src/main/java/org/apache/flink/python/PythonOptions.java index 254ad18..791c9d4 100644 --- a/flink-python/src/main/java/org/apache/flink/python/PythonOptions.java +++ b/flink-python/src/main/java/org/apache/flink/python/PythonOptions.java @@ -21,6 +21,7 @@ package org.apache.flink.python; import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.configuration.ConfigOption; import org.apache.flink.configuration.ConfigOptions; +import org.apache.flink.configuration.TaskManagerOptions; /** * Configuration options for the Python API. @@ -148,4 +149,17 @@ public class PythonOptions { "The priority is as following: 1. the configuration 'python.client.executable' defined in " + "the source code; 2. the environment variable PYFLINK_EXECUTABLE; 3. the configuration " + "'python.client.executable' defined in flink-conf.yaml"); + + /** + * Whether the memory used by the Python framework is managed memory. + */ + public static final ConfigOption<Boolean> USE_MANAGED_MEMORY = ConfigOptions + .key("python.fn-execution.memory.managed") + .defaultValue(false) + .withDescription(String.format("If set, the Python worker will configure itself to use the " + + "managed memory budget of the task slot. Otherwise, it will use the Off-Heap Memory " + + "of the task slot. In this case, users should set the Task Off-Heap Memory using the " + + "configuration key %s. For each Python worker, the required Task Off-Heap Memory " + + "is the sum of the value of %s and %s.", TaskManagerOptions.TASK_OFF_HEAP_MEMORY.key(), + PYTHON_FRAMEWORK_MEMORY_SIZE.key(), PYTHON_DATA_BUFFER_MEMORY_SIZE.key())); } diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java index b1df221..cb7a3ea 100644 --- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java +++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java @@ -117,7 +117,9 @@ public abstract class AbstractPythonFunctionOperator<IN, OUT> try { this.bundleStarted = new AtomicBoolean(false); - reserveMemoryForPythonWorker(); + if (config.isUsingManagedMemory()) { + reserveMemoryForPythonWorker(); + } this.maxBundleSize = config.getMaxBundleSize(); if (this.maxBundleSize <= 0) { diff --git a/flink-python/src/test/java/org/apache/flink/client/python/PythonFunctionFactoryTest.java b/flink-python/src/test/java/org/apache/flink/client/python/PythonFunctionFactoryTest.java index 6e40739..3a1c6a8 100644 --- a/flink-python/src/test/java/org/apache/flink/client/python/PythonFunctionFactoryTest.java +++ b/flink-python/src/test/java/org/apache/flink/client/python/PythonFunctionFactoryTest.java @@ -33,6 +33,7 @@ import java.lang.reflect.Field; import java.util.Map; import java.util.UUID; +import static org.apache.flink.configuration.TaskManagerOptions.TASK_OFF_HEAP_MEMORY; import static org.apache.flink.python.PythonOptions.PYTHON_FILES; import static org.apache.flink.table.api.Expressions.$; import static org.apache.flink.table.api.Expressions.call; @@ -71,10 +72,12 @@ public class PythonFunctionFactoryTest { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); flinkTableEnv = BatchTableEnvironment.create(env); flinkTableEnv.getConfig().getConfiguration().set(PYTHON_FILES, pyFilePath.getAbsolutePath()); + flinkTableEnv.getConfig().getConfiguration().setString(TASK_OFF_HEAP_MEMORY.key(), "80mb"); StreamExecutionEnvironment sEnv = StreamExecutionEnvironment.getExecutionEnvironment(); blinkTableEnv = StreamTableEnvironment.create( sEnv, EnvironmentSettings.newInstance().useBlinkPlanner().inStreamingMode().build()); blinkTableEnv.getConfig().getConfiguration().set(PYTHON_FILES, pyFilePath.getAbsolutePath()); + blinkTableEnv.getConfig().getConfiguration().setString(TASK_OFF_HEAP_MEMORY.key(), "80mb"); flinkSourceTable = flinkTableEnv.fromDataSet(env.fromElements("1", "2", "3")).as("str"); blinkSourceTable = blinkTableEnv.fromDataStream(sEnv.fromElements("1", "2", "3")).as("str"); } diff --git a/flink-python/src/test/java/org/apache/flink/python/PythonConfigTest.java b/flink-python/src/test/java/org/apache/flink/python/PythonConfigTest.java index f889b03..549c63c 100644 --- a/flink-python/src/test/java/org/apache/flink/python/PythonConfigTest.java +++ b/flink-python/src/test/java/org/apache/flink/python/PythonConfigTest.java @@ -53,6 +53,8 @@ public class PythonConfigTest { assertThat(pythonConfig.getPythonRequirementsCacheDirInfo().isPresent(), is(false)); assertThat(pythonConfig.getPythonArchivesInfo().isEmpty(), is(true)); assertThat(pythonConfig.getPythonExec(), is("python")); + assertThat(pythonConfig.isUsingManagedMemory(), + is(equalTo(PythonOptions.USE_MANAGED_MEMORY.defaultValue()))); } @Test @@ -149,4 +151,11 @@ public class PythonConfigTest { assertThat(pythonConfig.getPythonExec(), is(equalTo("/usr/local/bin/python3"))); } + @Test + public void testManagedMemory() { + Configuration config = new Configuration(); + config.set(PythonOptions.USE_MANAGED_MEMORY, true); + PythonConfig pythonConfig = new PythonConfig(config); + assertThat(pythonConfig.isUsingManagedMemory(), is(equalTo(true))); + } } diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/scalar/PythonScalarFunctionOperatorTestBase.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/scalar/PythonScalarFunctionOperatorTestBase.java index 05cdc63..fba3d3c 100644 --- a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/scalar/PythonScalarFunctionOperatorTestBase.java +++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/scalar/PythonScalarFunctionOperatorTestBase.java @@ -22,6 +22,7 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.TaskManagerOptions; import org.apache.flink.python.PythonOptions; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobVertex; @@ -202,6 +203,8 @@ public abstract class PythonScalarFunctionOperatorTestBase<IN, OUT, UDFIN> { StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(1); StreamTableEnvironment tEnv = createTableEnvironment(env); + tEnv.getConfig().getConfiguration().setString( + TaskManagerOptions.TASK_OFF_HEAP_MEMORY.key(), "80mb"); tEnv.registerFunction("pyFunc", new PythonScalarFunction("pyFunc")); DataStream<Tuple2<Integer, Integer>> ds = env.fromElements(new Tuple2<>(1, 2)); Table t = tEnv.fromDataStream(ds, $("a"), $("b")).select(call("pyFunc", $("a"), $("b"))); diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonBase.scala index fc96c57..263d8c4 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonBase.scala @@ -20,7 +20,7 @@ package org.apache.flink.table.planner.plan.nodes.common import org.apache.calcite.rex.{RexCall, RexLiteral, RexNode} import org.apache.calcite.sql.`type`.SqlTypeName -import org.apache.flink.configuration.Configuration +import org.apache.flink.configuration.{ConfigOption, Configuration, MemorySize, TaskManagerOptions} import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment import org.apache.flink.table.api.{TableConfig, TableException} import org.apache.flink.table.functions.FunctionDefinition @@ -116,6 +116,7 @@ trait CommonPythonBase { method.setAccessible(true) val config = new Configuration(method.invoke(env).asInstanceOf[Configuration]) config.addAll(tableConfig.getConfiguration) + checkPythonWorkerMemory(config, env) config } @@ -128,6 +129,64 @@ trait CommonPythonBase { } realEnv } + + protected def isPythonWorkerUsingManagedMemory(config: Configuration): Boolean = { + val clazz = loadClass("org.apache.flink.python.PythonOptions") + config.getBoolean(clazz.getField("USE_MANAGED_MEMORY").get(null) + .asInstanceOf[ConfigOption[java.lang.Boolean]]) + } + + protected def getPythonWorkerMemory(config: Configuration): MemorySize = { + val clazz = loadClass("org.apache.flink.python.PythonOptions") + val pythonFrameworkMemorySize = MemorySize.parse( + config.getString( + clazz.getField("PYTHON_FRAMEWORK_MEMORY_SIZE").get(null) + .asInstanceOf[ConfigOption[String]])) + val pythonBufferMemorySize = MemorySize.parse( + config.getString( + clazz.getField("PYTHON_DATA_BUFFER_MEMORY_SIZE").get(null) + .asInstanceOf[ConfigOption[String]])) + pythonFrameworkMemorySize.add(pythonBufferMemorySize) + } + + private def checkPythonWorkerMemory( + config: Configuration, env: StreamExecutionEnvironment): Unit = { + if (!isPythonWorkerUsingManagedMemory(config)) { + val taskOffHeapMemory = config.get(TaskManagerOptions.TASK_OFF_HEAP_MEMORY) + val requiredPythonWorkerOffHeapMemory = getPythonWorkerMemory(config) + if (taskOffHeapMemory.compareTo(requiredPythonWorkerOffHeapMemory) < 0) { + throw new TableException(String.format("The configured Task Off-Heap Memory %s is less " + + "than the least required Python worker Memory %s. The Task Off-Heap Memory can be " + + "configured using the configuration key 'taskmanager.memory.task.off-heap.size'.", + taskOffHeapMemory, requiredPythonWorkerOffHeapMemory)) + } + } else if (isRocksDbUsingManagedMemory(env)) { + throw new TableException("Currently it doesn't support to use Managed Memory for both " + + "RocksDB state backend and Python worker at the same time. You can either configure " + + "RocksDB state backend to use Task Off-Heap Memory via the configuration key " + + "'state.backend.rocksdb.memory.managed' or configure Python worker to use " + + "Task Off-Heap Memory via the configuration key " + + "'python.fn-execution.memory.managed'.") + } + } + + private def isRocksDbUsingManagedMemory(env: StreamExecutionEnvironment): Boolean = { + val stateBackend = env.getStateBackend + if (stateBackend != null && env.getStateBackend.getClass.getCanonicalName.equals( + "org.apache.flink.contrib.streaming.state.RocksDBStateBackend")) { + val clazz = loadClass("org.apache.flink.contrib.streaming.state.RocksDBStateBackend") + val getMemoryConfigurationMethod = clazz.getDeclaredMethod("getMemoryConfiguration") + val rocksDbConfig = getMemoryConfigurationMethod.invoke(stateBackend) + val isUsingManagedMemoryMethod = + rocksDbConfig.getClass.getDeclaredMethod("isUsingManagedMemory") + val isUsingFixedMemoryPerSlotMethod = + rocksDbConfig.getClass.getDeclaredMethod("isUsingFixedMemoryPerSlot") + isUsingManagedMemoryMethod.invoke(rocksDbConfig).asInstanceOf[Boolean] && + !isUsingFixedMemoryPerSlotMethod.invoke(rocksDbConfig).asInstanceOf[Boolean] + } else { + false + } + } } object CommonPythonBase { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonCalc.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonCalc.scala index 3b1fa3d..48387ee 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonCalc.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonCalc.scala @@ -19,7 +19,6 @@ package org.apache.flink.table.planner.plan.nodes.physical.batch import org.apache.flink.api.dag.Transformation -import org.apache.flink.configuration.{ConfigOption, Configuration, MemorySize} import org.apache.flink.table.data.RowData import org.apache.flink.table.planner.delegation.BatchPlanner import org.apache.flink.table.planner.plan.nodes.common.CommonPythonCalc @@ -61,20 +60,10 @@ class BatchExecPythonCalc( "BatchExecPythonCalc", getConfig(planner.getExecEnv, planner.getTableConfig)) - ExecNode.setManagedMemoryWeight( - ret, getPythonWorkerMemory(planner.getTableConfig.getConfiguration)) - } - - private def getPythonWorkerMemory(config: Configuration): Long = { - val clazz = loadClass("org.apache.flink.python.PythonOptions") - val pythonFrameworkMemorySize = MemorySize.parse( - config.getString( - clazz.getField("PYTHON_FRAMEWORK_MEMORY_SIZE").get(null) - .asInstanceOf[ConfigOption[String]])) - val pythonBufferMemorySize = MemorySize.parse( - config.getString( - clazz.getField("PYTHON_DATA_BUFFER_MEMORY_SIZE").get(null) - .asInstanceOf[ConfigOption[String]])) - pythonFrameworkMemorySize.add(pythonBufferMemorySize).getBytes + if (isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) { + ExecNode.setManagedMemoryWeight( + ret, getPythonWorkerMemory(planner.getTableConfig.getConfiguration).getBytes) + } + ret } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonCorrelate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonCorrelate.scala index 5f765c9..062ddca 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonCorrelate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonCorrelate.scala @@ -22,12 +22,12 @@ import org.apache.flink.table.data.RowData import org.apache.flink.table.planner.delegation.BatchPlanner import org.apache.flink.table.planner.plan.nodes.common.CommonPythonCorrelate import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan - import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.{Correlate, JoinRelType} import org.apache.calcite.rel.RelNode import org.apache.calcite.rex.{RexNode, RexProgram} +import org.apache.flink.table.planner.plan.nodes.exec.ExecNode /** * Batch physical RelNode for [[Correlate]] (Python user defined table function). @@ -72,12 +72,17 @@ class BatchExecPythonCorrelate( planner: BatchPlanner): Transformation[RowData] = { val inputTransformation = getInputNodes.get(0).translateToPlan(planner) .asInstanceOf[Transformation[RowData]] - createPythonOneInputTransformation( + val ret = createPythonOneInputTransformation( inputTransformation, scan, "BatchExecPythonCorrelate", outputRowType, getConfig(planner.getExecEnv, planner.getTableConfig), joinType) + if (isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) { + ExecNode.setManagedMemoryWeight( + ret, getPythonWorkerMemory(planner.getTableConfig.getConfiguration).getBytes) + } + ret } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCalc.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCalc.scala index 0aa6999..bcb9a41 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCalc.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCalc.scala @@ -22,12 +22,12 @@ import org.apache.flink.api.dag.Transformation import org.apache.flink.table.data.RowData import org.apache.flink.table.planner.delegation.StreamPlanner import org.apache.flink.table.planner.plan.nodes.common.CommonPythonCalc - import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.Calc import org.apache.calcite.rex.RexProgram +import org.apache.flink.table.planner.plan.nodes.exec.ExecNode /** * Stream physical RelNode for Python ScalarFunctions. @@ -64,6 +64,10 @@ class StreamExecPythonCalc( ret.setParallelism(1) ret.setMaxParallelism(1) } + if (isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) { + ExecNode.setManagedMemoryWeight( + ret, getPythonWorkerMemory(planner.getTableConfig.getConfiguration).getBytes) + } ret } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCorrelate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCorrelate.scala index fd8224c7..4b83baf 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCorrelate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCorrelate.scala @@ -23,12 +23,12 @@ import org.apache.flink.table.data.RowData import org.apache.flink.table.planner.delegation.StreamPlanner import org.apache.flink.table.planner.plan.nodes.common.CommonPythonCorrelate import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan - import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.JoinRelType import org.apache.calcite.rex.{RexNode, RexProgram} +import org.apache.flink.table.planner.plan.nodes.exec.ExecNode /** * Flink RelNode which matches along with join a python user defined table function. @@ -77,12 +77,17 @@ class StreamExecPythonCorrelate( planner: StreamPlanner): Transformation[RowData] = { val inputTransformation = getInputNodes.get(0).translateToPlan(planner) .asInstanceOf[Transformation[RowData]] - createPythonOneInputTransformation( + val ret = createPythonOneInputTransformation( inputTransformation, scan, "StreamExecPythonCorrelate", outputRowType, getConfig(planner.getExecEnv, planner.getTableConfig), joinType) + if (isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) { + ExecNode.setManagedMemoryWeight( + ret, getPythonWorkerMemory(planner.getTableConfig.getConfiguration).getBytes) + } + ret } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonPythonBase.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonPythonBase.scala index 6d5b1da..0796cd4 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonPythonBase.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonPythonBase.scala @@ -20,7 +20,7 @@ package org.apache.flink.table.plan.nodes import org.apache.calcite.rex.{RexCall, RexLiteral, RexNode} import org.apache.calcite.sql.`type`.SqlTypeName import org.apache.flink.api.java.ExecutionEnvironment -import org.apache.flink.configuration.Configuration +import org.apache.flink.configuration.{ConfigOption, Configuration, MemorySize, TaskManagerOptions} import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment import org.apache.flink.table.api.{TableConfig, TableException} import org.apache.flink.table.functions.UserDefinedFunction @@ -126,6 +126,7 @@ trait CommonPythonBase { method.setAccessible(true) val config = new Configuration(method.invoke(env).asInstanceOf[Configuration]) config.addAll(tableConfig.getConfiguration) + checkPythonWorkerMemory(config, env) config } @@ -138,6 +139,7 @@ trait CommonPythonBase { // ensure the user specified configuration has priority over others. val config = new Configuration(env.getConfiguration) config.addAll(tableConfig.getConfiguration) + checkPythonWorkerMemory(config) config } @@ -150,6 +152,64 @@ trait CommonPythonBase { } realEnv } + + private def isPythonWorkerUsingManagedMemory(config: Configuration): Boolean = { + val clazz = loadClass("org.apache.flink.python.PythonOptions") + config.getBoolean(clazz.getField("USE_MANAGED_MEMORY").get(null) + .asInstanceOf[ConfigOption[java.lang.Boolean]]) + } + + private def getPythonWorkerMemory(config: Configuration): MemorySize = { + val clazz = loadClass("org.apache.flink.python.PythonOptions") + val pythonFrameworkMemorySize = MemorySize.parse( + config.getString( + clazz.getField("PYTHON_FRAMEWORK_MEMORY_SIZE").get(null) + .asInstanceOf[ConfigOption[String]])) + val pythonBufferMemorySize = MemorySize.parse( + config.getString( + clazz.getField("PYTHON_DATA_BUFFER_MEMORY_SIZE").get(null) + .asInstanceOf[ConfigOption[String]])) + pythonFrameworkMemorySize.add(pythonBufferMemorySize) + } + + private def checkPythonWorkerMemory( + config: Configuration, env: StreamExecutionEnvironment = null): Unit = { + if (!isPythonWorkerUsingManagedMemory(config)) { + val taskOffHeapMemory = config.get(TaskManagerOptions.TASK_OFF_HEAP_MEMORY) + val requiredPythonWorkerOffHeapMemory = getPythonWorkerMemory(config) + if (taskOffHeapMemory.compareTo(requiredPythonWorkerOffHeapMemory) < 0) { + throw new TableException(String.format("The configured Task Off-Heap Memory %s is less " + + "than the least required Python worker Memory %s. The Task Off-Heap Memory can be " + + "configured using the configuration key 'taskmanager.memory.task.off-heap.size'.", + taskOffHeapMemory, requiredPythonWorkerOffHeapMemory)) + } + } else if (env != null && isRocksDbUsingManagedMemory(env)) { + throw new TableException("Currently it doesn't support to use Managed Memory for both " + + "RocksDB state backend and Python worker at the same time. You can either configure " + + "RocksDB state backend to use Task Off-Heap Memory via the configuration key " + + "'state.backend.rocksdb.memory.managed' or configure Python worker to use " + + "Task Off-Heap Memory via the configuration key " + + "'python.fn-execution.memory.managed'.") + } + } + + private def isRocksDbUsingManagedMemory(env: StreamExecutionEnvironment): Boolean = { + val stateBackend = env.getStateBackend + if (stateBackend != null && stateBackend.getClass.getCanonicalName.equals( + "org.apache.flink.contrib.streaming.state.RocksDBStateBackend")) { + val clazz = loadClass("org.apache.flink.contrib.streaming.state.RocksDBStateBackend") + val getMemoryConfigurationMethod = clazz.getDeclaredMethod("getMemoryConfiguration") + val rocksDbConfig = getMemoryConfigurationMethod.invoke(stateBackend) + val isUsingManagedMemoryMethod = + rocksDbConfig.getClass.getDeclaredMethod("isUsingManagedMemory") + val isUsingFixedMemoryPerSlotMethod = + rocksDbConfig.getClass.getDeclaredMethod("isUsingFixedMemoryPerSlot") + isUsingManagedMemoryMethod.invoke(rocksDbConfig).asInstanceOf[Boolean] && + !isUsingFixedMemoryPerSlotMethod.invoke(rocksDbConfig).asInstanceOf[Boolean] + } else { + false + } + } } object CommonPythonBase {