Posted to commits@flink.apache.org by di...@apache.org on 2020/06/03 01:55:07 UTC

[flink] branch master updated: [FLINK-17923][python] Allow Python worker to use off-heap memory

This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 77e5494  [FLINK-17923][python] Allow Python worker to use off-heap memory
77e5494 is described below

commit 77e5494c1c252ba2dd458078380ee862fa423e4e
Author: Dian Fu <di...@apache.org>
AuthorDate: Thu May 28 13:34:38 2020 +0800

    [FLINK-17923][python] Allow Python worker to use off-heap memory
    
    This closes #12370.
---
 docs/_includes/generated/python_configuration.html |  6 +++
 flink-python/pyflink/testing/test_case_utils.py    |  8 +++
 .../java/org/apache/flink/python/PythonConfig.java | 10 ++++
 .../org/apache/flink/python/PythonOptions.java     | 14 +++++
 .../python/AbstractPythonFunctionOperator.java     |  4 +-
 .../client/python/PythonFunctionFactoryTest.java   |  3 ++
 .../org/apache/flink/python/PythonConfigTest.java  |  9 ++++
 .../PythonScalarFunctionOperatorTestBase.java      |  3 ++
 .../plan/nodes/common/CommonPythonBase.scala       | 61 ++++++++++++++++++++-
 .../nodes/physical/batch/BatchExecPythonCalc.scala | 21 ++------
 .../physical/batch/BatchExecPythonCorrelate.scala  |  9 +++-
 .../physical/stream/StreamExecPythonCalc.scala     |  6 ++-
 .../stream/StreamExecPythonCorrelate.scala         |  9 +++-
 .../flink/table/plan/nodes/CommonPythonBase.scala  | 62 +++++++++++++++++++++-
 14 files changed, 201 insertions(+), 24 deletions(-)
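
    A minimal sketch of what this change means for a PyFlink job, based only on the
    option keys and setter calls that appear in the patch below; the environment setup
    mirrors the test utilities touched here and is illustrative, not part of the commit:

    from pyflink.datastream import StreamExecutionEnvironment
    from pyflink.table import EnvironmentSettings, StreamTableEnvironment

    env = StreamExecutionEnvironment.get_execution_environment()
    t_env = StreamTableEnvironment.create(
        env,
        environment_settings=EnvironmentSettings.new_instance()
            .in_streaming_mode().use_blink_planner().build())
    config = t_env.get_config().get_configuration()

    # Default after this change (python.fn-execution.memory.managed = false):
    # the Python worker uses Task Off-Heap Memory. The configured budget must cover
    # python.fn-execution.framework.memory.size + python.fn-execution.buffer.memory.size,
    # otherwise planning fails with a TableException.
    config.set_string("taskmanager.memory.task.off-heap.size", "80mb")

    # Alternatively, let the Python worker use the managed memory budget of the task
    # slot (not allowed together with a RocksDB state backend that also uses managed
    # memory, per the check added in CommonPythonBase):
    # config.set_string("python.fn-execution.memory.managed", "true")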

diff --git a/docs/_includes/generated/python_configuration.html b/docs/_includes/generated/python_configuration.html
index 967cb66..890d025 100644
--- a/docs/_includes/generated/python_configuration.html
+++ b/docs/_includes/generated/python_configuration.html
@@ -63,6 +63,12 @@
             <td>The amount of memory to be allocated by the Python framework. The sum of the value of this configuration and "python.fn-execution.buffer.memory.size" represents the total memory of a Python worker. The memory will be accounted as managed memory if the actual memory allocated to an operator is no less than the total memory of a Python worker. Otherwise, this configuration takes no effect.</td>
         </tr>
         <tr>
+            <td><h5>python.fn-execution.memory.managed</h5></td>
+            <td style="word-wrap: break-word;">false</td>
+            <td>Boolean</td>
+            <td>If set, the Python worker will configure itself to use the managed memory budget of the task slot. Otherwise, it will use the Off-Heap Memory of the task slot. In this case, users should set the Task Off-Heap Memory using the configuration key taskmanager.memory.task.off-heap.size. For each Python worker, the required Task Off-Heap Memory is the sum of the value of python.fn-execution.framework.memory.size and python.fn-execution.buffer.memory.size.</td>
+        </tr>
+        <tr>
             <td><h5>python.metric.enabled</h5></td>
             <td style="word-wrap: break-word;">true</td>
             <td>Boolean</td>
diff --git a/flink-python/pyflink/testing/test_case_utils.py b/flink-python/pyflink/testing/test_case_utils.py
index 35e889a..ac562db 100644
--- a/flink-python/pyflink/testing/test_case_utils.py
+++ b/flink-python/pyflink/testing/test_case_utils.py
@@ -128,6 +128,8 @@ class PyFlinkStreamTableTestCase(PyFlinkTestCase):
             self.env,
             environment_settings=EnvironmentSettings.new_instance()
                 .in_streaming_mode().use_old_planner().build())
+        self.t_env.get_config().get_configuration().set_string(
+            "taskmanager.memory.task.off-heap.size", "80mb")
 
 
 class PyFlinkBatchTableTestCase(PyFlinkTestCase):
@@ -140,6 +142,8 @@ class PyFlinkBatchTableTestCase(PyFlinkTestCase):
         self.env = ExecutionEnvironment.get_execution_environment()
         self.env.set_parallelism(2)
         self.t_env = BatchTableEnvironment.create(self.env, TableConfig())
+        self.t_env.get_config().get_configuration().set_string(
+            "taskmanager.memory.task.off-heap.size", "80mb")
 
     def collect(self, table):
         j_table = table._j_table
@@ -162,6 +166,8 @@ class PyFlinkBlinkStreamTableTestCase(PyFlinkTestCase):
         self.t_env = StreamTableEnvironment.create(
             self.env, environment_settings=EnvironmentSettings.new_instance()
                 .in_streaming_mode().use_blink_planner().build())
+        self.t_env.get_config().get_configuration().set_string(
+            "taskmanager.memory.task.off-heap.size", "80mb")
 
 
 class PyFlinkBlinkBatchTableTestCase(PyFlinkTestCase):
@@ -174,6 +180,8 @@ class PyFlinkBlinkBatchTableTestCase(PyFlinkTestCase):
         self.t_env = BatchTableEnvironment.create(
             environment_settings=EnvironmentSettings.new_instance()
             .in_batch_mode().use_blink_planner().build())
+        self.t_env.get_config().get_configuration().set_string(
+            "taskmanager.memory.task.off-heap.size", "80mb")
         self.t_env._j_tenv.getPlanner().getExecEnv().setParallelism(2)
 
 
diff --git a/flink-python/src/main/java/org/apache/flink/python/PythonConfig.java b/flink-python/src/main/java/org/apache/flink/python/PythonConfig.java
index 01a7b95..1bec9d4f 100644
--- a/flink-python/src/main/java/org/apache/flink/python/PythonConfig.java
+++ b/flink-python/src/main/java/org/apache/flink/python/PythonConfig.java
@@ -102,6 +102,11 @@ public class PythonConfig implements Serializable {
 	 */
 	private final boolean metricEnabled;
 
+	/**
+	 * Whether to use managed memory for the Python worker.
+	 */
+	private final boolean isUsingManagedMemory;
+
 	public PythonConfig(Configuration config) {
 		maxBundleSize = config.get(PythonOptions.MAX_BUNDLE_SIZE);
 		maxBundleTimeMills = config.get(PythonOptions.MAX_BUNDLE_TIME_MILLS);
@@ -118,6 +123,7 @@ public class PythonConfig implements Serializable {
 		pythonArchivesInfo = config.getOptional(PythonDependencyUtils.PYTHON_ARCHIVES).orElse(new HashMap<>());
 		pythonExec = config.get(PythonOptions.PYTHON_EXECUTABLE);
 		metricEnabled = config.getBoolean(PythonOptions.PYTHON_METRIC_ENABLED);
+		isUsingManagedMemory = config.getBoolean(PythonOptions.USE_MANAGED_MEMORY);
 	}
 
 	public int getMaxBundleSize() {
@@ -163,4 +169,8 @@ public class PythonConfig implements Serializable {
 	public boolean isMetricEnabled() {
 		return metricEnabled;
 	}
+
+	public boolean isUsingManagedMemory() {
+		return isUsingManagedMemory;
+	}
 }
diff --git a/flink-python/src/main/java/org/apache/flink/python/PythonOptions.java b/flink-python/src/main/java/org/apache/flink/python/PythonOptions.java
index 254ad18..791c9d4 100644
--- a/flink-python/src/main/java/org/apache/flink/python/PythonOptions.java
+++ b/flink-python/src/main/java/org/apache/flink/python/PythonOptions.java
@@ -21,6 +21,7 @@ package org.apache.flink.python;
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.configuration.ConfigOption;
 import org.apache.flink.configuration.ConfigOptions;
+import org.apache.flink.configuration.TaskManagerOptions;
 
 /**
  * Configuration options for the Python API.
@@ -148,4 +149,17 @@ public class PythonOptions {
 			"The priority is as following: 1. the configuration 'python.client.executable' defined in " +
 			"the source code; 2. the environment variable PYFLINK_EXECUTABLE; 3. the configuration " +
 			"'python.client.executable' defined in flink-conf.yaml");
+
+	/**
+	 * Whether the memory used by the Python framework is managed memory.
+	 */
+	public static final ConfigOption<Boolean> USE_MANAGED_MEMORY = ConfigOptions
+		.key("python.fn-execution.memory.managed")
+		.defaultValue(false)
+		.withDescription(String.format("If set, the Python worker will configure itself to use the " +
+			"managed memory budget of the task slot. Otherwise, it will use the Off-Heap Memory " +
+			"of the task slot. In this case, users should set the Task Off-Heap Memory using the " +
+			"configuration key %s. For each Python worker, the required Task Off-Heap Memory " +
+			"is the sum of the value of %s and %s.", TaskManagerOptions.TASK_OFF_HEAP_MEMORY.key(),
+			PYTHON_FRAMEWORK_MEMORY_SIZE.key(), PYTHON_DATA_BUFFER_MEMORY_SIZE.key()));
 }
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java
index b1df221..cb7a3ea 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java
@@ -117,7 +117,9 @@ public abstract class AbstractPythonFunctionOperator<IN, OUT>
 		try {
 			this.bundleStarted = new AtomicBoolean(false);
 
-			reserveMemoryForPythonWorker();
+			if (config.isUsingManagedMemory()) {
+				reserveMemoryForPythonWorker();
+			}
 
 			this.maxBundleSize = config.getMaxBundleSize();
 			if (this.maxBundleSize <= 0) {
diff --git a/flink-python/src/test/java/org/apache/flink/client/python/PythonFunctionFactoryTest.java b/flink-python/src/test/java/org/apache/flink/client/python/PythonFunctionFactoryTest.java
index 6e40739..3a1c6a8 100644
--- a/flink-python/src/test/java/org/apache/flink/client/python/PythonFunctionFactoryTest.java
+++ b/flink-python/src/test/java/org/apache/flink/client/python/PythonFunctionFactoryTest.java
@@ -33,6 +33,7 @@ import java.lang.reflect.Field;
 import java.util.Map;
 import java.util.UUID;
 
+import static org.apache.flink.configuration.TaskManagerOptions.TASK_OFF_HEAP_MEMORY;
 import static org.apache.flink.python.PythonOptions.PYTHON_FILES;
 import static org.apache.flink.table.api.Expressions.$;
 import static org.apache.flink.table.api.Expressions.call;
@@ -71,10 +72,12 @@ public class PythonFunctionFactoryTest {
 		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
 		flinkTableEnv = BatchTableEnvironment.create(env);
 		flinkTableEnv.getConfig().getConfiguration().set(PYTHON_FILES, pyFilePath.getAbsolutePath());
+		flinkTableEnv.getConfig().getConfiguration().setString(TASK_OFF_HEAP_MEMORY.key(), "80mb");
 		StreamExecutionEnvironment sEnv = StreamExecutionEnvironment.getExecutionEnvironment();
 		blinkTableEnv = StreamTableEnvironment.create(
 			sEnv, EnvironmentSettings.newInstance().useBlinkPlanner().inStreamingMode().build());
 		blinkTableEnv.getConfig().getConfiguration().set(PYTHON_FILES, pyFilePath.getAbsolutePath());
+		blinkTableEnv.getConfig().getConfiguration().setString(TASK_OFF_HEAP_MEMORY.key(), "80mb");
 		flinkSourceTable = flinkTableEnv.fromDataSet(env.fromElements("1", "2", "3")).as("str");
 		blinkSourceTable = blinkTableEnv.fromDataStream(sEnv.fromElements("1", "2", "3")).as("str");
 	}
diff --git a/flink-python/src/test/java/org/apache/flink/python/PythonConfigTest.java b/flink-python/src/test/java/org/apache/flink/python/PythonConfigTest.java
index f889b03..549c63c 100644
--- a/flink-python/src/test/java/org/apache/flink/python/PythonConfigTest.java
+++ b/flink-python/src/test/java/org/apache/flink/python/PythonConfigTest.java
@@ -53,6 +53,8 @@ public class PythonConfigTest {
 		assertThat(pythonConfig.getPythonRequirementsCacheDirInfo().isPresent(), is(false));
 		assertThat(pythonConfig.getPythonArchivesInfo().isEmpty(), is(true));
 		assertThat(pythonConfig.getPythonExec(), is("python"));
+		assertThat(pythonConfig.isUsingManagedMemory(),
+			is(equalTo(PythonOptions.USE_MANAGED_MEMORY.defaultValue())));
 	}
 
 	@Test
@@ -149,4 +151,11 @@ public class PythonConfigTest {
 		assertThat(pythonConfig.getPythonExec(), is(equalTo("/usr/local/bin/python3")));
 	}
 
+	@Test
+	public void testManagedMemory() {
+		Configuration config = new Configuration();
+		config.set(PythonOptions.USE_MANAGED_MEMORY, true);
+		PythonConfig pythonConfig = new PythonConfig(config);
+		assertThat(pythonConfig.isUsingManagedMemory(), is(equalTo(true)));
+	}
 }
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/scalar/PythonScalarFunctionOperatorTestBase.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/scalar/PythonScalarFunctionOperatorTestBase.java
index 05cdc63..fba3d3c 100644
--- a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/scalar/PythonScalarFunctionOperatorTestBase.java
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/scalar/PythonScalarFunctionOperatorTestBase.java
@@ -22,6 +22,7 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.TaskManagerOptions;
 import org.apache.flink.python.PythonOptions;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobVertex;
@@ -202,6 +203,8 @@ public abstract class PythonScalarFunctionOperatorTestBase<IN, OUT, UDFIN> {
 		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
 		env.setParallelism(1);
 		StreamTableEnvironment tEnv = createTableEnvironment(env);
+		tEnv.getConfig().getConfiguration().setString(
+			TaskManagerOptions.TASK_OFF_HEAP_MEMORY.key(), "80mb");
 		tEnv.registerFunction("pyFunc", new PythonScalarFunction("pyFunc"));
 		DataStream<Tuple2<Integer, Integer>> ds = env.fromElements(new Tuple2<>(1, 2));
 		Table t = tEnv.fromDataStream(ds, $("a"), $("b")).select(call("pyFunc", $("a"), $("b")));
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonBase.scala
index fc96c57..263d8c4 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonBase.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonBase.scala
@@ -20,7 +20,7 @@ package org.apache.flink.table.planner.plan.nodes.common
 
 import org.apache.calcite.rex.{RexCall, RexLiteral, RexNode}
 import org.apache.calcite.sql.`type`.SqlTypeName
-import org.apache.flink.configuration.Configuration
+import org.apache.flink.configuration.{ConfigOption, Configuration, MemorySize, TaskManagerOptions}
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
 import org.apache.flink.table.api.{TableConfig, TableException}
 import org.apache.flink.table.functions.FunctionDefinition
@@ -116,6 +116,7 @@ trait CommonPythonBase {
     method.setAccessible(true)
     val config = new Configuration(method.invoke(env).asInstanceOf[Configuration])
     config.addAll(tableConfig.getConfiguration)
+    checkPythonWorkerMemory(config, env)
     config
   }
 
@@ -128,6 +129,64 @@ trait CommonPythonBase {
     }
     realEnv
   }
+
+  protected def isPythonWorkerUsingManagedMemory(config: Configuration): Boolean = {
+    val clazz = loadClass("org.apache.flink.python.PythonOptions")
+    config.getBoolean(clazz.getField("USE_MANAGED_MEMORY").get(null)
+      .asInstanceOf[ConfigOption[java.lang.Boolean]])
+  }
+
+  protected def getPythonWorkerMemory(config: Configuration): MemorySize = {
+    val clazz = loadClass("org.apache.flink.python.PythonOptions")
+    val pythonFrameworkMemorySize = MemorySize.parse(
+      config.getString(
+        clazz.getField("PYTHON_FRAMEWORK_MEMORY_SIZE").get(null)
+          .asInstanceOf[ConfigOption[String]]))
+    val pythonBufferMemorySize = MemorySize.parse(
+      config.getString(
+        clazz.getField("PYTHON_DATA_BUFFER_MEMORY_SIZE").get(null)
+          .asInstanceOf[ConfigOption[String]]))
+    pythonFrameworkMemorySize.add(pythonBufferMemorySize)
+  }
+
+  private def checkPythonWorkerMemory(
+      config: Configuration, env: StreamExecutionEnvironment): Unit = {
+    if (!isPythonWorkerUsingManagedMemory(config)) {
+      val taskOffHeapMemory = config.get(TaskManagerOptions.TASK_OFF_HEAP_MEMORY)
+      val requiredPythonWorkerOffHeapMemory = getPythonWorkerMemory(config)
+      if (taskOffHeapMemory.compareTo(requiredPythonWorkerOffHeapMemory) < 0) {
+        throw new TableException(String.format("The configured Task Off-Heap Memory %s is less " +
+          "than the least required Python worker Memory %s. The Task Off-Heap Memory can be " +
+          "configured using the configuration key 'taskmanager.memory.task.off-heap.size'.",
+          taskOffHeapMemory, requiredPythonWorkerOffHeapMemory))
+      }
+    } else if (isRocksDbUsingManagedMemory(env)) {
+      throw new TableException("Currently it doesn't support to use Managed Memory for both " +
+        "RocksDB state backend and Python worker at the same time. You can either configure " +
+        "RocksDB state backend to use Task Off-Heap Memory via the configuration key " +
+        "'state.backend.rocksdb.memory.managed' or configure Python worker to use " +
+        "Task Off-Heap Memory via the configuration key " +
+        "'python.fn-execution.memory.managed'.")
+    }
+  }
+
+  private def isRocksDbUsingManagedMemory(env: StreamExecutionEnvironment): Boolean = {
+    val stateBackend = env.getStateBackend
+    if (stateBackend != null && env.getStateBackend.getClass.getCanonicalName.equals(
+      "org.apache.flink.contrib.streaming.state.RocksDBStateBackend")) {
+      val clazz = loadClass("org.apache.flink.contrib.streaming.state.RocksDBStateBackend")
+      val getMemoryConfigurationMethod = clazz.getDeclaredMethod("getMemoryConfiguration")
+      val rocksDbConfig = getMemoryConfigurationMethod.invoke(stateBackend)
+      val isUsingManagedMemoryMethod =
+        rocksDbConfig.getClass.getDeclaredMethod("isUsingManagedMemory")
+      val isUsingFixedMemoryPerSlotMethod =
+        rocksDbConfig.getClass.getDeclaredMethod("isUsingFixedMemoryPerSlot")
+      isUsingManagedMemoryMethod.invoke(rocksDbConfig).asInstanceOf[Boolean] &&
+        !isUsingFixedMemoryPerSlotMethod.invoke(rocksDbConfig).asInstanceOf[Boolean]
+    } else {
+      false
+    }
+  }
 }
 
 object CommonPythonBase {
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonCalc.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonCalc.scala
index 3b1fa3d..48387ee 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonCalc.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonCalc.scala
@@ -19,7 +19,6 @@
 package org.apache.flink.table.planner.plan.nodes.physical.batch
 
 import org.apache.flink.api.dag.Transformation
-import org.apache.flink.configuration.{ConfigOption, Configuration, MemorySize}
 import org.apache.flink.table.data.RowData
 import org.apache.flink.table.planner.delegation.BatchPlanner
 import org.apache.flink.table.planner.plan.nodes.common.CommonPythonCalc
@@ -61,20 +60,10 @@ class BatchExecPythonCalc(
       "BatchExecPythonCalc",
       getConfig(planner.getExecEnv, planner.getTableConfig))
 
-    ExecNode.setManagedMemoryWeight(
-      ret, getPythonWorkerMemory(planner.getTableConfig.getConfiguration))
-  }
-
-  private def getPythonWorkerMemory(config: Configuration): Long = {
-    val clazz = loadClass("org.apache.flink.python.PythonOptions")
-    val pythonFrameworkMemorySize = MemorySize.parse(
-      config.getString(
-        clazz.getField("PYTHON_FRAMEWORK_MEMORY_SIZE").get(null)
-          .asInstanceOf[ConfigOption[String]]))
-    val pythonBufferMemorySize = MemorySize.parse(
-      config.getString(
-        clazz.getField("PYTHON_DATA_BUFFER_MEMORY_SIZE").get(null)
-          .asInstanceOf[ConfigOption[String]]))
-    pythonFrameworkMemorySize.add(pythonBufferMemorySize).getBytes
+    if (isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) {
+      ExecNode.setManagedMemoryWeight(
+        ret, getPythonWorkerMemory(planner.getTableConfig.getConfiguration).getBytes)
+    }
+    ret
   }
 }
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonCorrelate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonCorrelate.scala
index 5f765c9..062ddca 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonCorrelate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonCorrelate.scala
@@ -22,12 +22,12 @@ import org.apache.flink.table.data.RowData
 import org.apache.flink.table.planner.delegation.BatchPlanner
 import org.apache.flink.table.planner.plan.nodes.common.CommonPythonCorrelate
 import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan
-
 import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.core.{Correlate, JoinRelType}
 import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rex.{RexNode, RexProgram}
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNode
 
 /**
   * Batch physical RelNode for [[Correlate]] (Python user defined table function).
@@ -72,12 +72,17 @@ class BatchExecPythonCorrelate(
       planner: BatchPlanner): Transformation[RowData] = {
     val inputTransformation = getInputNodes.get(0).translateToPlan(planner)
       .asInstanceOf[Transformation[RowData]]
-    createPythonOneInputTransformation(
+    val ret = createPythonOneInputTransformation(
       inputTransformation,
       scan,
       "BatchExecPythonCorrelate",
       outputRowType,
       getConfig(planner.getExecEnv, planner.getTableConfig),
       joinType)
+    if (isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) {
+      ExecNode.setManagedMemoryWeight(
+        ret, getPythonWorkerMemory(planner.getTableConfig.getConfiguration).getBytes)
+    }
+    ret
   }
 }
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCalc.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCalc.scala
index 0aa6999..bcb9a41 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCalc.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCalc.scala
@@ -22,12 +22,12 @@ import org.apache.flink.api.dag.Transformation
 import org.apache.flink.table.data.RowData
 import org.apache.flink.table.planner.delegation.StreamPlanner
 import org.apache.flink.table.planner.plan.nodes.common.CommonPythonCalc
-
 import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
 import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.core.Calc
 import org.apache.calcite.rex.RexProgram
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNode
 
 /**
   * Stream physical RelNode for Python ScalarFunctions.
@@ -64,6 +64,10 @@ class StreamExecPythonCalc(
       ret.setParallelism(1)
       ret.setMaxParallelism(1)
     }
+    if (isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) {
+      ExecNode.setManagedMemoryWeight(
+        ret, getPythonWorkerMemory(planner.getTableConfig.getConfiguration).getBytes)
+    }
     ret
   }
 }
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCorrelate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCorrelate.scala
index fd8224c7..4b83baf 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCorrelate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCorrelate.scala
@@ -23,12 +23,12 @@ import org.apache.flink.table.data.RowData
 import org.apache.flink.table.planner.delegation.StreamPlanner
 import org.apache.flink.table.planner.plan.nodes.common.CommonPythonCorrelate
 import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan
-
 import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
 import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.core.JoinRelType
 import org.apache.calcite.rex.{RexNode, RexProgram}
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNode
 
 /**
   * Flink RelNode which matches along with join a python user defined table function.
@@ -77,12 +77,17 @@ class StreamExecPythonCorrelate(
       planner: StreamPlanner): Transformation[RowData] = {
     val inputTransformation = getInputNodes.get(0).translateToPlan(planner)
       .asInstanceOf[Transformation[RowData]]
-    createPythonOneInputTransformation(
+    val ret = createPythonOneInputTransformation(
       inputTransformation,
       scan,
       "StreamExecPythonCorrelate",
       outputRowType,
       getConfig(planner.getExecEnv, planner.getTableConfig),
       joinType)
+    if (isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) {
+      ExecNode.setManagedMemoryWeight(
+        ret, getPythonWorkerMemory(planner.getTableConfig.getConfiguration).getBytes)
+    }
+    ret
   }
 }
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonPythonBase.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonPythonBase.scala
index 6d5b1da..0796cd4 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonPythonBase.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/CommonPythonBase.scala
@@ -20,7 +20,7 @@ package org.apache.flink.table.plan.nodes
 import org.apache.calcite.rex.{RexCall, RexLiteral, RexNode}
 import org.apache.calcite.sql.`type`.SqlTypeName
 import org.apache.flink.api.java.ExecutionEnvironment
-import org.apache.flink.configuration.Configuration
+import org.apache.flink.configuration.{ConfigOption, Configuration, MemorySize, TaskManagerOptions}
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
 import org.apache.flink.table.api.{TableConfig, TableException}
 import org.apache.flink.table.functions.UserDefinedFunction
@@ -126,6 +126,7 @@ trait CommonPythonBase {
     method.setAccessible(true)
     val config = new Configuration(method.invoke(env).asInstanceOf[Configuration])
     config.addAll(tableConfig.getConfiguration)
+    checkPythonWorkerMemory(config, env)
     config
   }
 
@@ -138,6 +139,7 @@ trait CommonPythonBase {
     // ensure the user specified configuration has priority over others.
     val config = new Configuration(env.getConfiguration)
     config.addAll(tableConfig.getConfiguration)
+    checkPythonWorkerMemory(config)
     config
   }
 
@@ -150,6 +152,64 @@ trait CommonPythonBase {
     }
     realEnv
   }
+
+  private def isPythonWorkerUsingManagedMemory(config: Configuration): Boolean = {
+    val clazz = loadClass("org.apache.flink.python.PythonOptions")
+    config.getBoolean(clazz.getField("USE_MANAGED_MEMORY").get(null)
+      .asInstanceOf[ConfigOption[java.lang.Boolean]])
+  }
+
+  private def getPythonWorkerMemory(config: Configuration): MemorySize = {
+    val clazz = loadClass("org.apache.flink.python.PythonOptions")
+    val pythonFrameworkMemorySize = MemorySize.parse(
+      config.getString(
+        clazz.getField("PYTHON_FRAMEWORK_MEMORY_SIZE").get(null)
+          .asInstanceOf[ConfigOption[String]]))
+    val pythonBufferMemorySize = MemorySize.parse(
+      config.getString(
+        clazz.getField("PYTHON_DATA_BUFFER_MEMORY_SIZE").get(null)
+          .asInstanceOf[ConfigOption[String]]))
+    pythonFrameworkMemorySize.add(pythonBufferMemorySize)
+  }
+
+  private def checkPythonWorkerMemory(
+      config: Configuration, env: StreamExecutionEnvironment = null): Unit = {
+    if (!isPythonWorkerUsingManagedMemory(config)) {
+      val taskOffHeapMemory = config.get(TaskManagerOptions.TASK_OFF_HEAP_MEMORY)
+      val requiredPythonWorkerOffHeapMemory = getPythonWorkerMemory(config)
+      if (taskOffHeapMemory.compareTo(requiredPythonWorkerOffHeapMemory) < 0) {
+        throw new TableException(String.format("The configured Task Off-Heap Memory %s is less " +
+          "than the least required Python worker Memory %s. The Task Off-Heap Memory can be " +
+          "configured using the configuration key 'taskmanager.memory.task.off-heap.size'.",
+          taskOffHeapMemory, requiredPythonWorkerOffHeapMemory))
+      }
+    } else if (env != null && isRocksDbUsingManagedMemory(env)) {
+      throw new TableException("Currently it doesn't support to use Managed Memory for both " +
+        "RocksDB state backend and Python worker at the same time. You can either configure " +
+        "RocksDB state backend to use Task Off-Heap Memory via the configuration key " +
+        "'state.backend.rocksdb.memory.managed' or configure Python worker to use " +
+        "Task Off-Heap Memory via the configuration key " +
+        "'python.fn-execution.memory.managed'.")
+    }
+  }
+
+  private def isRocksDbUsingManagedMemory(env: StreamExecutionEnvironment): Boolean = {
+    val stateBackend = env.getStateBackend
+    if (stateBackend != null && stateBackend.getClass.getCanonicalName.equals(
+      "org.apache.flink.contrib.streaming.state.RocksDBStateBackend")) {
+      val clazz = loadClass("org.apache.flink.contrib.streaming.state.RocksDBStateBackend")
+      val getMemoryConfigurationMethod = clazz.getDeclaredMethod("getMemoryConfiguration")
+      val rocksDbConfig = getMemoryConfigurationMethod.invoke(stateBackend)
+      val isUsingManagedMemoryMethod =
+        rocksDbConfig.getClass.getDeclaredMethod("isUsingManagedMemory")
+      val isUsingFixedMemoryPerSlotMethod =
+        rocksDbConfig.getClass.getDeclaredMethod("isUsingFixedMemoryPerSlot")
+      isUsingManagedMemoryMethod.invoke(rocksDbConfig).asInstanceOf[Boolean] &&
+        !isUsingFixedMemoryPerSlotMethod.invoke(rocksDbConfig).asInstanceOf[Boolean]
+    } else {
+      false
+    }
+  }
 }
 
 object CommonPythonBase {