You are viewing a plain-text version of this content. The canonical link to the original message was not preserved in this copy.
Posted to commits@flink.apache.org by ja...@apache.org on 2022/09/07 02:31:41 UTC

[flink] 01/03: [fixup][table-planner] Using user classloader instead of thread context classloader

This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 523546101f0180999f11d68269aad53c59134064
Author: fengli <ld...@163.com>
AuthorDate: Mon Aug 29 20:10:53 2022 +0800

    [fixup][table-planner] Using user classloader instead of thread context classloader
---
 .../exec/batch/BatchExecPythonGroupAggregate.java  |  9 ++--
 .../batch/BatchExecPythonGroupWindowAggregate.java |  8 ++--
 .../exec/batch/BatchExecPythonOverAggregate.java   |  8 ++--
 .../nodes/exec/common/CommonExecPythonCalc.java    | 30 ++++++++----
 .../exec/common/CommonExecPythonCorrelate.java     | 24 ++++++----
 .../stream/StreamExecPythonGroupAggregate.java     | 11 +++--
 .../StreamExecPythonGroupTableAggregate.java       | 12 +++--
 .../StreamExecPythonGroupWindowAggregate.java      | 16 +++++--
 .../exec/stream/StreamExecPythonOverAggregate.java | 10 ++--
 .../plan/nodes/exec/utils/CommonPythonUtil.java    | 53 ++++++++++++++--------
 .../physical/common/CommonPhysicalMatchRule.java   |  3 +-
 .../table/planner/delegation/PlannerBase.scala     |  2 +-
 .../physical/batch/BatchPhysicalSortRule.scala     |  1 -
 13 files changed, 125 insertions(+), 62 deletions(-)

diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupAggregate.java
index dbb6033c364..98e2ca2551d 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupAggregate.java
@@ -94,7 +94,8 @@ public class BatchExecPythonGroupAggregate extends ExecNodeBase<RowData>
         final RowType inputRowType = (RowType) inputEdge.getOutputType();
         final RowType outputRowType = InternalTypeInfo.of(getOutputType()).toRowType();
         Configuration pythonConfig =
-                CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+                CommonPythonUtil.extractPythonConfiguration(
+                        planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
         OneInputTransformation<RowData, RowData> transform =
                 createPythonOneInputTransformation(
                         inputTransform,
@@ -104,7 +105,8 @@ public class BatchExecPythonGroupAggregate extends ExecNodeBase<RowData>
                         config,
                         planner.getFlinkContext().getClassLoader());
 
-        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+                pythonConfig, planner.getFlinkContext().getClassLoader())) {
             transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
         }
         return transform;
@@ -149,7 +151,8 @@ public class BatchExecPythonGroupAggregate extends ExecNodeBase<RowData>
             int[] udafInputOffsets,
             PythonFunctionInfo[] pythonFunctionInfos) {
         final Class<?> clazz =
-                CommonPythonUtil.loadClass(ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME);
+                CommonPythonUtil.loadClass(
+                        ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME, classLoader);
 
         RowType udfInputType = (RowType) Projection.of(udafInputOffsets).project(inputRowType);
         RowType udfOutputType =
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupWindowAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupWindowAggregate.java
index ae8a9c2ad02..930a2f7fe59 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupWindowAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupWindowAggregate.java
@@ -114,7 +114,8 @@ public class BatchExecPythonGroupWindowAggregate extends ExecNodeBase<RowData>
 
         final Tuple2<Long, Long> windowSizeAndSlideSize = WindowCodeGenerator.getWindowDef(window);
         final Configuration pythonConfig =
-                CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+                CommonPythonUtil.extractPythonConfiguration(
+                        planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
         int groupBufferLimitSize =
                 pythonConfig.getInteger(
                         ExecutionConfigOptions.TABLE_EXEC_WINDOW_AGG_BUFFER_SIZE_LIMIT);
@@ -130,7 +131,8 @@ public class BatchExecPythonGroupWindowAggregate extends ExecNodeBase<RowData>
                         pythonConfig,
                         config,
                         planner.getFlinkContext().getClassLoader());
-        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+                pythonConfig, planner.getFlinkContext().getClassLoader())) {
             transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
         }
         return transform;
@@ -204,7 +206,7 @@ public class BatchExecPythonGroupWindowAggregate extends ExecNodeBase<RowData>
             PythonFunctionInfo[] pythonFunctionInfos) {
         Class<?> clazz =
                 CommonPythonUtil.loadClass(
-                        ARROW_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME);
+                        ARROW_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME, classLoader);
 
         RowType udfInputType = (RowType) Projection.of(udafInputOffsets).project(inputRowType);
         RowType udfOutputType =
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonOverAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonOverAggregate.java
index 5023931259f..9f4717aa5ef 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonOverAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonOverAggregate.java
@@ -153,7 +153,8 @@ public class BatchExecPythonOverAggregate extends BatchExecOverAggregateBase {
             }
         }
         Configuration pythonConfig =
-                CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+                CommonPythonUtil.extractPythonConfiguration(
+                        planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
         OneInputTransformation<RowData, RowData> transform =
                 createPythonOneInputTransformation(
                         inputTransform,
@@ -163,7 +164,8 @@ public class BatchExecPythonOverAggregate extends BatchExecOverAggregateBase {
                         pythonConfig,
                         config,
                         planner.getFlinkContext().getClassLoader());
-        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+                pythonConfig, planner.getFlinkContext().getClassLoader())) {
             transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
         }
         return transform;
@@ -213,7 +215,7 @@ public class BatchExecPythonOverAggregate extends BatchExecOverAggregateBase {
             PythonFunctionInfo[] pythonFunctionInfos) {
         Class<?> clazz =
                 CommonPythonUtil.loadClass(
-                        ARROW_PYTHON_OVER_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME);
+                        ARROW_PYTHON_OVER_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME, classLoader);
 
         RowType udfInputType = (RowType) Projection.of(udafInputOffsets).project(inputRowType);
         RowType udfOutputType =
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCalc.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCalc.java
index e102de9d063..d0249791edd 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCalc.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCalc.java
@@ -108,14 +108,16 @@ public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData>
         final Transformation<RowData> inputTransform =
                 (Transformation<RowData>) inputEdge.translateToPlan(planner);
         final Configuration pythonConfig =
-                CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+                CommonPythonUtil.extractPythonConfiguration(
+                        planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
         OneInputTransformation<RowData, RowData> ret =
                 createPythonOneInputTransformation(
                         inputTransform,
                         config,
                         planner.getFlinkContext().getClassLoader(),
                         pythonConfig);
-        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+                pythonConfig, planner.getFlinkContext().getClassLoader())) {
             ret.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
         }
         return ret;
@@ -139,7 +141,7 @@ public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData>
                         .collect(Collectors.toList());
 
         Tuple2<int[], PythonFunctionInfo[]> extractResult =
-                extractPythonScalarFunctionInfos(pythonRexCalls);
+                extractPythonScalarFunctionInfos(pythonRexCalls, classLoader);
         int[] pythonUdfInputOffsets = extractResult.f0;
         PythonFunctionInfo[] pythonFunctionInfos = extractResult.f1;
         LogicalType[] inputLogicalTypes =
@@ -185,11 +187,14 @@ public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData>
     }
 
     private Tuple2<int[], PythonFunctionInfo[]> extractPythonScalarFunctionInfos(
-            List<RexCall> rexCalls) {
+            List<RexCall> rexCalls, ClassLoader classLoader) {
         LinkedHashMap<RexNode, Integer> inputNodes = new LinkedHashMap<>();
         PythonFunctionInfo[] pythonFunctionInfos =
                 rexCalls.stream()
-                        .map(x -> CommonPythonUtil.createPythonFunctionInfo(x, inputNodes))
+                        .map(
+                                x ->
+                                        CommonPythonUtil.createPythonFunctionInfo(
+                                                x, inputNodes, classLoader))
                         .collect(Collectors.toList())
                         .toArray(new PythonFunctionInfo[rexCalls.size()]);
 
@@ -221,14 +226,21 @@ public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData>
             int[] forwardedFields,
             boolean isArrow) {
         Class<?> clazz;
-        boolean isInProcessMode = CommonPythonUtil.isPythonWorkerInProcessMode(pythonConfig);
+        boolean isInProcessMode =
+                CommonPythonUtil.isPythonWorkerInProcessMode(pythonConfig, classLoader);
         if (isArrow) {
-            clazz = CommonPythonUtil.loadClass(ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME);
+            clazz =
+                    CommonPythonUtil.loadClass(
+                            ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME, classLoader);
         } else {
             if (isInProcessMode) {
-                clazz = CommonPythonUtil.loadClass(PYTHON_SCALAR_FUNCTION_OPERATOR_NAME);
+                clazz =
+                        CommonPythonUtil.loadClass(
+                                PYTHON_SCALAR_FUNCTION_OPERATOR_NAME, classLoader);
             } else {
-                clazz = CommonPythonUtil.loadClass(EMBEDDED_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME);
+                clazz =
+                        CommonPythonUtil.loadClass(
+                                EMBEDDED_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME, classLoader);
             }
         }
 
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCorrelate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCorrelate.java
index 8661fd9b5b6..81940866104 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCorrelate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCorrelate.java
@@ -102,7 +102,8 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData>
         final Transformation<RowData> inputTransform =
                 (Transformation<RowData>) inputEdge.translateToPlan(planner);
         final Configuration pythonConfig =
-                CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+                CommonPythonUtil.extractPythonConfiguration(
+                        planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
         final ExecNodeConfig pythonNodeConfig =
                 ExecNodeConfig.ofNodeConfig(pythonConfig, config.isCompiled());
         final OneInputTransformation<RowData, RowData> transform =
@@ -111,7 +112,8 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData>
                         pythonNodeConfig,
                         planner.getFlinkContext().getClassLoader(),
                         pythonConfig);
-        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+                pythonConfig, planner.getFlinkContext().getClassLoader())) {
             transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
         }
         return transform;
@@ -122,7 +124,8 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData>
             ExecNodeConfig pythonNodeConfig,
             ClassLoader classLoader,
             Configuration pythonConfig) {
-        Tuple2<int[], PythonFunctionInfo> extractResult = extractPythonTableFunctionInfo();
+        Tuple2<int[], PythonFunctionInfo> extractResult =
+                extractPythonTableFunctionInfo(classLoader);
         int[] pythonUdtfInputOffsets = extractResult.f0;
         PythonFunctionInfo pythonFunctionInfo = extractResult.f1;
         InternalTypeInfo<RowData> pythonOperatorInputRowType =
@@ -146,10 +149,11 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData>
                 inputTransform.getParallelism());
     }
 
-    private Tuple2<int[], PythonFunctionInfo> extractPythonTableFunctionInfo() {
+    private Tuple2<int[], PythonFunctionInfo> extractPythonTableFunctionInfo(
+            ClassLoader classLoader) {
         LinkedHashMap<RexNode, Integer> inputNodes = new LinkedHashMap<>();
         PythonFunctionInfo pythonTableFunctionInfo =
-                CommonPythonUtil.createPythonFunctionInfo(invocation, inputNodes);
+                CommonPythonUtil.createPythonFunctionInfo(invocation, inputNodes, classLoader);
         int[] udtfInputOffsets =
                 inputNodes.keySet().stream()
                         .filter(x -> x instanceof RexInputRef)
@@ -168,7 +172,8 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData>
             InternalTypeInfo<RowData> outputRowType,
             PythonFunctionInfo pythonFunctionInfo,
             int[] udtfInputOffsets) {
-        boolean isInProcessMode = CommonPythonUtil.isPythonWorkerInProcessMode(pythonConfig);
+        boolean isInProcessMode =
+                CommonPythonUtil.isPythonWorkerInProcessMode(pythonConfig, classLoader);
 
         final RowType inputType = inputRowType.toRowType();
         final RowType outputType = outputRowType.toRowType();
@@ -180,7 +185,9 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData>
 
         try {
             if (isInProcessMode) {
-                Class clazz = CommonPythonUtil.loadClass(PYTHON_TABLE_FUNCTION_OPERATOR_NAME);
+                Class clazz =
+                        CommonPythonUtil.loadClass(
+                                PYTHON_TABLE_FUNCTION_OPERATOR_NAME, classLoader);
                 Constructor ctor =
                         clazz.getConstructor(
                                 Configuration.class,
@@ -206,7 +213,8 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData>
                                         udtfInputOffsets));
             } else {
                 Class clazz =
-                        CommonPythonUtil.loadClass(EMBEDDED_PYTHON_TABLE_FUNCTION_OPERATOR_NAME);
+                        CommonPythonUtil.loadClass(
+                                EMBEDDED_PYTHON_TABLE_FUNCTION_OPERATOR_NAME, classLoader);
                 Constructor ctor =
                         clazz.getConstructor(
                                 Configuration.class,
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupAggregate.java
index 55a8a8cd3d5..4595191332b 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupAggregate.java
@@ -175,10 +175,12 @@ public class StreamExecPythonGroupAggregate extends StreamExecAggregateBase {
         PythonAggregateFunctionInfo[] pythonFunctionInfos = aggInfosAndDataViewSpecs.f0;
         DataViewSpec[][] dataViewSpecs = aggInfosAndDataViewSpecs.f1;
         Configuration pythonConfig =
-                CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+                CommonPythonUtil.extractPythonConfiguration(
+                        planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
         final OneInputStreamOperator<RowData, RowData> operator =
                 getPythonAggregateFunctionOperator(
                         pythonConfig,
+                        planner.getFlinkContext().getClassLoader(),
                         inputRowType,
                         InternalTypeInfo.of(getOutputType()).toRowType(),
                         pythonFunctionInfos,
@@ -196,7 +198,8 @@ public class StreamExecPythonGroupAggregate extends StreamExecAggregateBase {
                         InternalTypeInfo.of(getOutputType()),
                         inputTransform.getParallelism());
 
-        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+                pythonConfig, planner.getFlinkContext().getClassLoader())) {
             transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
         }
 
@@ -214,6 +217,7 @@ public class StreamExecPythonGroupAggregate extends StreamExecAggregateBase {
     @SuppressWarnings("unchecked")
     private OneInputStreamOperator<RowData, RowData> getPythonAggregateFunctionOperator(
             Configuration config,
+            ClassLoader classLoader,
             RowType inputType,
             RowType outputType,
             PythonAggregateFunctionInfo[] aggregateFunctions,
@@ -222,7 +226,8 @@ public class StreamExecPythonGroupAggregate extends StreamExecAggregateBase {
             long maxIdleStateRetentionTime,
             int indexOfCountStar,
             boolean countStarInserted) {
-        Class<?> clazz = CommonPythonUtil.loadClass(PYTHON_STREAM_AGGREAGTE_OPERATOR_NAME);
+        Class<?> clazz =
+                CommonPythonUtil.loadClass(PYTHON_STREAM_AGGREAGTE_OPERATOR_NAME, classLoader);
         try {
             Constructor<?> ctor =
                     clazz.getConstructor(
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupTableAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupTableAggregate.java
index 179b302941a..3d05d273def 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupTableAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupTableAggregate.java
@@ -131,10 +131,12 @@ public class StreamExecPythonGroupTableAggregate extends ExecNodeBase<RowData>
         PythonAggregateFunctionInfo[] pythonFunctionInfos = aggInfosAndDataViewSpecs.f0;
         DataViewSpec[][] dataViewSpecs = aggInfosAndDataViewSpecs.f1;
         Configuration pythonConfig =
-                CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+                CommonPythonUtil.extractPythonConfiguration(
+                        planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
         OneInputStreamOperator<RowData, RowData> pythonOperator =
                 getPythonTableAggregateFunctionOperator(
                         pythonConfig,
+                        planner.getFlinkContext().getClassLoader(),
                         inputRowType,
                         InternalTypeInfo.of(getOutputType()).toRowType(),
                         pythonFunctionInfos,
@@ -153,7 +155,8 @@ public class StreamExecPythonGroupTableAggregate extends ExecNodeBase<RowData>
                         InternalTypeInfo.of(getOutputType()),
                         inputTransform.getParallelism());
 
-        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+                pythonConfig, planner.getFlinkContext().getClassLoader())) {
             transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
         }
 
@@ -171,6 +174,7 @@ public class StreamExecPythonGroupTableAggregate extends ExecNodeBase<RowData>
     @SuppressWarnings("unchecked")
     private OneInputStreamOperator<RowData, RowData> getPythonTableAggregateFunctionOperator(
             Configuration config,
+            ClassLoader classLoader,
             RowType inputRowType,
             RowType outputRowType,
             PythonAggregateFunctionInfo[] aggregateFunctions,
@@ -179,7 +183,9 @@ public class StreamExecPythonGroupTableAggregate extends ExecNodeBase<RowData>
             long maxIdleStateRetentionTime,
             boolean generateUpdateBefore,
             int indexOfCountStar) {
-        Class<?> clazz = CommonPythonUtil.loadClass(PYTHON_STREAM_TABLE_AGGREGATE_OPERATOR_NAME);
+        Class<?> clazz =
+                CommonPythonUtil.loadClass(
+                        PYTHON_STREAM_TABLE_AGGREGATE_OPERATOR_NAME, classLoader);
         try {
             Constructor<?> ctor =
                     clazz.getConstructor(
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java
index 8aa55962285..6e210a642b0 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java
@@ -258,7 +258,8 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas
         WindowAssigner<?> windowAssigner = windowAssignerAndTrigger.f0;
         Trigger<?> trigger = windowAssignerAndTrigger.f1;
         final Configuration pythonConfig =
-                CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+                CommonPythonUtil.extractPythonConfiguration(
+                        planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
         final ExecNodeConfig pythonNodeConfig =
                 ExecNodeConfig.ofNodeConfig(pythonConfig, config.isCompiled());
         boolean isGeneralPythonUDAF =
@@ -289,6 +290,7 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas
                             emitStrategy.getAllowLateness(),
                             pythonConfig,
                             pythonNodeConfig,
+                            planner.getFlinkContext().getClassLoader(),
                             shiftTimeZone);
         } else {
             transform =
@@ -306,7 +308,8 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas
                             shiftTimeZone);
         }
 
-        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+                pythonConfig, planner.getFlinkContext().getClassLoader())) {
             transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
         }
         // set KeyType and Selector for state
@@ -436,6 +439,7 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas
                     long allowance,
                     Configuration pythonConfig,
                     ExecNodeConfig pythonNodeConfig,
+                    ClassLoader classLoader,
                     ZoneId shiftTimeZone) {
         final int inputCountIndex = aggInfoList.getIndexOfCountStar();
         final boolean countStarInserted = aggInfoList.countStarInserted();
@@ -446,6 +450,7 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas
         OneInputStreamOperator<RowData, RowData> pythonOperator =
                 getGeneralPythonStreamGroupWindowAggregateFunctionOperator(
                         pythonConfig,
+                        classLoader,
                         inputRowType,
                         outputRowType,
                         windowAssigner,
@@ -484,7 +489,8 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas
                     ZoneId shiftTimeZone) {
         Class clazz =
                 CommonPythonUtil.loadClass(
-                        ARROW_STREAM_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME);
+                        ARROW_STREAM_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME,
+                        classLoader);
         RowType userDefinedFunctionInputType =
                 (RowType) Projection.of(udafInputOffsets).project(inputRowType);
         RowType userDefinedFunctionOutputType =
@@ -542,6 +548,7 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas
     private OneInputStreamOperator<RowData, RowData>
             getGeneralPythonStreamGroupWindowAggregateFunctionOperator(
                     Configuration config,
+                    ClassLoader classLoader,
                     RowType inputType,
                     RowType outputType,
                     WindowAssigner<?> windowAssigner,
@@ -555,7 +562,8 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas
                     ZoneId shiftTimeZone) {
         Class clazz =
                 CommonPythonUtil.loadClass(
-                        GENERAL_STREAM_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME);
+                        GENERAL_STREAM_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME,
+                        classLoader);
 
         boolean isRowTime = AggregateUtil.isRowtimeAttribute(window.timeAttribute());
 
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonOverAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonOverAggregate.java
index fd507b97d4a..d1057bcab6f 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonOverAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonOverAggregate.java
@@ -197,7 +197,8 @@ public class StreamExecPythonOverAggregate extends ExecNodeBase<RowData>
         }
         long precedingOffset = -1 * (long) boundValue;
         Configuration pythonConfig =
-                CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+                CommonPythonUtil.extractPythonConfiguration(
+                        planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
         OneInputTransformation<RowData, RowData> transform =
                 createPythonOneInputTransformation(
                         inputTransform,
@@ -213,7 +214,8 @@ public class StreamExecPythonOverAggregate extends ExecNodeBase<RowData>
                         config,
                         planner.getFlinkContext().getClassLoader());
 
-        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+                pythonConfig, planner.getFlinkContext().getClassLoader())) {
             transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
         }
 
@@ -306,7 +308,7 @@ public class StreamExecPythonOverAggregate extends ExecNodeBase<RowData>
                 className =
                         ARROW_PYTHON_OVER_WINDOW_ROWS_PROC_TIME_AGGREGATE_FUNCTION_OPERATOR_NAME;
             }
-            Class<?> clazz = CommonPythonUtil.loadClass(className);
+            Class<?> clazz = CommonPythonUtil.loadClass(className, classLoader);
 
             try {
                 Constructor<?> ctor =
@@ -349,7 +351,7 @@ public class StreamExecPythonOverAggregate extends ExecNodeBase<RowData>
                 className =
                         ARROW_PYTHON_OVER_WINDOW_RANGE_PROC_TIME_AGGREGATE_FUNCTION_OPERATOR_NAME;
             }
-            Class<?> clazz = CommonPythonUtil.loadClass(className);
+            Class<?> clazz = CommonPythonUtil.loadClass(className, classLoader);
             try {
                 Constructor<?> ctor =
                         clazz.getConstructor(
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
index a949ad2afa6..201407b718a 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
@@ -98,9 +98,9 @@ public class CommonPythonUtil {
 
     private CommonPythonUtil() {}
 
-    public static Class<?> loadClass(String className) {
+    public static Class<?> loadClass(String className, ClassLoader classLoader) {
         try {
-            return Class.forName(className, false, Thread.currentThread().getContextClassLoader());
+            return Class.forName(className, false, classLoader);
         } catch (ClassNotFoundException e) {
             throw new TableException(
                     "The dependency of 'flink-python' is not present on the classpath.", e);
@@ -108,8 +108,8 @@ public class CommonPythonUtil {
     }
 
     public static Configuration extractPythonConfiguration(
-            StreamExecutionEnvironment env, ReadableConfig tableConfig) {
-        Class<?> clazz = loadClass(PYTHON_CONFIG_UTILS_CLASS);
+            StreamExecutionEnvironment env, ReadableConfig tableConfig, ClassLoader classLoader) {
+        Class<?> clazz = loadClass(PYTHON_CONFIG_UTILS_CLASS, classLoader);
         try {
             StreamExecutionEnvironment realEnv = getRealEnvironment(env);
             Method method =
@@ -125,20 +125,27 @@ public class CommonPythonUtil {
     }
 
     public static PythonFunctionInfo createPythonFunctionInfo(
-            RexCall pythonRexCall, Map<RexNode, Integer> inputNodes) {
+            RexCall pythonRexCall, Map<RexNode, Integer> inputNodes, ClassLoader classLoader) {
         SqlOperator operator = pythonRexCall.getOperator();
         try {
             if (operator instanceof ScalarSqlFunction) {
                 return createPythonFunctionInfo(
-                        pythonRexCall, inputNodes, ((ScalarSqlFunction) operator).scalarFunction());
+                        pythonRexCall,
+                        inputNodes,
+                        ((ScalarSqlFunction) operator).scalarFunction(),
+                        classLoader);
             } else if (operator instanceof TableSqlFunction) {
                 return createPythonFunctionInfo(
-                        pythonRexCall, inputNodes, ((TableSqlFunction) operator).udtf());
+                        pythonRexCall,
+                        inputNodes,
+                        ((TableSqlFunction) operator).udtf(),
+                        classLoader);
             } else if (operator instanceof BridgingSqlFunction) {
                 return createPythonFunctionInfo(
                         pythonRexCall,
                         inputNodes,
-                        ((BridgingSqlFunction) operator).getDefinition());
+                        ((BridgingSqlFunction) operator).getDefinition(),
+                        classLoader);
             }
         } catch (InvocationTargetException | IllegalAccessException e) {
             throw new TableException("Method pickleValue accessed failed. ", e);
@@ -147,8 +154,9 @@ public class CommonPythonUtil {
     }
 
     @SuppressWarnings("unchecked")
-    public static boolean isPythonWorkerUsingManagedMemory(Configuration config) {
-        Class<?> clazz = loadClass(PYTHON_OPTIONS_CLASS);
+    public static boolean isPythonWorkerUsingManagedMemory(
+            Configuration config, ClassLoader classLoader) {
+        Class<?> clazz = loadClass(PYTHON_OPTIONS_CLASS, classLoader);
         try {
             return config.getBoolean(
                     (ConfigOption<Boolean>) (clazz.getField("USE_MANAGED_MEMORY").get(null)));
@@ -158,8 +166,9 @@ public class CommonPythonUtil {
     }
 
     @SuppressWarnings("unchecked")
-    public static boolean isPythonWorkerInProcessMode(Configuration config) {
-        Class<?> clazz = loadClass(PYTHON_OPTIONS_CLASS);
+    public static boolean isPythonWorkerInProcessMode(
+            Configuration config, ClassLoader classLoader) {
+        Class<?> clazz = loadClass(PYTHON_OPTIONS_CLASS, classLoader);
         try {
             return config.getString(
                             (ConfigOption<String>)
@@ -337,7 +346,8 @@ public class CommonPythonUtil {
                         });
     }
 
-    private static byte[] convertLiteralToPython(RexLiteral o, SqlTypeName typeName)
+    private static byte[] convertLiteralToPython(
+            RexLiteral o, SqlTypeName typeName, ClassLoader classLoader)
             throws InvocationTargetException, IllegalAccessException {
         byte type;
         Object value;
@@ -396,16 +406,18 @@ public class CommonPythonUtil {
                     throw new RuntimeException("Unsupported type " + typeName);
             }
         }
-        loadPickleValue();
+        loadPickleValue(classLoader);
         return (byte[]) pickleValue.invoke(null, value, type);
     }
 
-    private static void loadPickleValue() {
+    private static void loadPickleValue(ClassLoader classLoader) {
         if (pickleValue == null) {
             synchronized (CommonPythonUtil.class) {
                 if (pickleValue == null) {
                     Class<?> clazz =
-                            loadClass("org.apache.flink.api.common.python.PythonBridgeUtils");
+                            loadClass(
+                                    "org.apache.flink.api.common.python.PythonBridgeUtils",
+                                    classLoader);
                     try {
                         pickleValue = clazz.getMethod("pickleValue", Object.class, byte.class);
                     } catch (NoSuchMethodException e) {
@@ -419,18 +431,21 @@ public class CommonPythonUtil {
     private static PythonFunctionInfo createPythonFunctionInfo(
             RexCall pythonRexCall,
             Map<RexNode, Integer> inputNodes,
-            FunctionDefinition functionDefinition)
+            FunctionDefinition functionDefinition,
+            ClassLoader classLoader)
             throws InvocationTargetException, IllegalAccessException {
         ArrayList<Object> inputs = new ArrayList<>();
         for (RexNode operand : pythonRexCall.getOperands()) {
             if (operand instanceof RexCall) {
                 RexCall childPythonRexCall = (RexCall) operand;
                 PythonFunctionInfo argPythonInfo =
-                        createPythonFunctionInfo(childPythonRexCall, inputNodes);
+                        createPythonFunctionInfo(childPythonRexCall, inputNodes, classLoader);
                 inputs.add(argPythonInfo);
             } else if (operand instanceof RexLiteral) {
                 RexLiteral literal = (RexLiteral) operand;
-                inputs.add(convertLiteralToPython(literal, literal.getType().getSqlTypeName()));
+                inputs.add(
+                        convertLiteralToPython(
+                                literal, literal.getType().getSqlTypeName(), classLoader));
             } else {
                 if (inputNodes.containsKey(operand)) {
                     inputs.add(inputNodes.get(operand));
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/CommonPhysicalMatchRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/CommonPhysicalMatchRule.java
index 397476b0ebb..ad1b15e8ea8 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/CommonPhysicalMatchRule.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/CommonPhysicalMatchRule.java
@@ -25,6 +25,7 @@ import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalMatch;
 import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution;
 import org.apache.flink.table.planner.plan.utils.MatchUtil.AggregationPatternVariableFinder;
 import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor;
+import org.apache.flink.table.planner.utils.ShortcutUtils;
 
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelOptRule;
@@ -87,7 +88,7 @@ public abstract class CommonPhysicalMatchRule extends ConverterRule {
             Class.forName(
                     "org.apache.flink.cep.pattern.Pattern",
                     false,
-                    Thread.currentThread().getContextClassLoader());
+                    ShortcutUtils.unwrapContext(rel).getClassLoader());
         } catch (ClassNotFoundException e) {
             throw new TableException(
                     "MATCH RECOGNIZE clause requires flink-cep dependency to be present on the classpath.",
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala
index 8fb8f80d37d..2681ed64fad 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala
@@ -473,7 +473,7 @@ abstract class PlannerBase(
     tableConfig.set(TABLE_QUERY_CURRENT_DATABASE, currentDatabase)
 
     // We pass only the configuration to avoid reconfiguration with the rootConfiguration
-    getExecEnv.configure(tableConfig.getConfiguration, Thread.currentThread().getContextClassLoader)
+    getExecEnv.configure(tableConfig.getConfiguration, classLoader)
 
     // Use config parallelism to override env parallelism.
     val defaultParallelism =
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortRule.scala
index b1b58d803c0..ef387d485d1 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortRule.scala
@@ -20,7 +20,6 @@ package org.apache.flink.table.planner.plan.rules.physical.batch
 import org.apache.flink.annotation.Experimental
 import org.apache.flink.configuration.ConfigOption
 import org.apache.flink.configuration.ConfigOptions.key
-import org.apache.flink.table.planner.calcite.FlinkContext
 import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
 import org.apache.flink.table.planner.plan.nodes.FlinkConventions
 import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalSort