You are viewing a plain text version of this content; the canonical hyperlink was removed during plain-text conversion and is available in the original archived message.
Posted to commits@flink.apache.org by ja...@apache.org on 2022/09/07 02:31:41 UTC
[flink] 01/03: [fixup][table-planner] Using user classloader instead of thread context classloader
This is an automated email from the ASF dual-hosted git repository.
jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 523546101f0180999f11d68269aad53c59134064
Author: fengli <ld...@163.com>
AuthorDate: Mon Aug 29 20:10:53 2022 +0800
[fixup][table-planner] Using user classloader instead of thread context classloader
---
.../exec/batch/BatchExecPythonGroupAggregate.java | 9 ++--
.../batch/BatchExecPythonGroupWindowAggregate.java | 8 ++--
.../exec/batch/BatchExecPythonOverAggregate.java | 8 ++--
.../nodes/exec/common/CommonExecPythonCalc.java | 30 ++++++++----
.../exec/common/CommonExecPythonCorrelate.java | 24 ++++++----
.../stream/StreamExecPythonGroupAggregate.java | 11 +++--
.../StreamExecPythonGroupTableAggregate.java | 12 +++--
.../StreamExecPythonGroupWindowAggregate.java | 16 +++++--
.../exec/stream/StreamExecPythonOverAggregate.java | 10 ++--
.../plan/nodes/exec/utils/CommonPythonUtil.java | 53 ++++++++++++++--------
.../physical/common/CommonPhysicalMatchRule.java | 3 +-
.../table/planner/delegation/PlannerBase.scala | 2 +-
.../physical/batch/BatchPhysicalSortRule.scala | 1 -
13 files changed, 125 insertions(+), 62 deletions(-)
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupAggregate.java
index dbb6033c364..98e2ca2551d 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupAggregate.java
@@ -94,7 +94,8 @@ public class BatchExecPythonGroupAggregate extends ExecNodeBase<RowData>
final RowType inputRowType = (RowType) inputEdge.getOutputType();
final RowType outputRowType = InternalTypeInfo.of(getOutputType()).toRowType();
Configuration pythonConfig =
- CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+ CommonPythonUtil.extractPythonConfiguration(
+ planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
OneInputTransformation<RowData, RowData> transform =
createPythonOneInputTransformation(
inputTransform,
@@ -104,7 +105,8 @@ public class BatchExecPythonGroupAggregate extends ExecNodeBase<RowData>
config,
planner.getFlinkContext().getClassLoader());
- if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+ if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+ pythonConfig, planner.getFlinkContext().getClassLoader())) {
transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
}
return transform;
@@ -149,7 +151,8 @@ public class BatchExecPythonGroupAggregate extends ExecNodeBase<RowData>
int[] udafInputOffsets,
PythonFunctionInfo[] pythonFunctionInfos) {
final Class<?> clazz =
- CommonPythonUtil.loadClass(ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME);
+ CommonPythonUtil.loadClass(
+ ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME, classLoader);
RowType udfInputType = (RowType) Projection.of(udafInputOffsets).project(inputRowType);
RowType udfOutputType =
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupWindowAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupWindowAggregate.java
index ae8a9c2ad02..930a2f7fe59 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupWindowAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupWindowAggregate.java
@@ -114,7 +114,8 @@ public class BatchExecPythonGroupWindowAggregate extends ExecNodeBase<RowData>
final Tuple2<Long, Long> windowSizeAndSlideSize = WindowCodeGenerator.getWindowDef(window);
final Configuration pythonConfig =
- CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+ CommonPythonUtil.extractPythonConfiguration(
+ planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
int groupBufferLimitSize =
pythonConfig.getInteger(
ExecutionConfigOptions.TABLE_EXEC_WINDOW_AGG_BUFFER_SIZE_LIMIT);
@@ -130,7 +131,8 @@ public class BatchExecPythonGroupWindowAggregate extends ExecNodeBase<RowData>
pythonConfig,
config,
planner.getFlinkContext().getClassLoader());
- if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+ if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+ pythonConfig, planner.getFlinkContext().getClassLoader())) {
transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
}
return transform;
@@ -204,7 +206,7 @@ public class BatchExecPythonGroupWindowAggregate extends ExecNodeBase<RowData>
PythonFunctionInfo[] pythonFunctionInfos) {
Class<?> clazz =
CommonPythonUtil.loadClass(
- ARROW_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME);
+ ARROW_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME, classLoader);
RowType udfInputType = (RowType) Projection.of(udafInputOffsets).project(inputRowType);
RowType udfOutputType =
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonOverAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonOverAggregate.java
index 5023931259f..9f4717aa5ef 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonOverAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonOverAggregate.java
@@ -153,7 +153,8 @@ public class BatchExecPythonOverAggregate extends BatchExecOverAggregateBase {
}
}
Configuration pythonConfig =
- CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+ CommonPythonUtil.extractPythonConfiguration(
+ planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
OneInputTransformation<RowData, RowData> transform =
createPythonOneInputTransformation(
inputTransform,
@@ -163,7 +164,8 @@ public class BatchExecPythonOverAggregate extends BatchExecOverAggregateBase {
pythonConfig,
config,
planner.getFlinkContext().getClassLoader());
- if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+ if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+ pythonConfig, planner.getFlinkContext().getClassLoader())) {
transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
}
return transform;
@@ -213,7 +215,7 @@ public class BatchExecPythonOverAggregate extends BatchExecOverAggregateBase {
PythonFunctionInfo[] pythonFunctionInfos) {
Class<?> clazz =
CommonPythonUtil.loadClass(
- ARROW_PYTHON_OVER_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME);
+ ARROW_PYTHON_OVER_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME, classLoader);
RowType udfInputType = (RowType) Projection.of(udafInputOffsets).project(inputRowType);
RowType udfOutputType =
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCalc.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCalc.java
index e102de9d063..d0249791edd 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCalc.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCalc.java
@@ -108,14 +108,16 @@ public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData>
final Transformation<RowData> inputTransform =
(Transformation<RowData>) inputEdge.translateToPlan(planner);
final Configuration pythonConfig =
- CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+ CommonPythonUtil.extractPythonConfiguration(
+ planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
OneInputTransformation<RowData, RowData> ret =
createPythonOneInputTransformation(
inputTransform,
config,
planner.getFlinkContext().getClassLoader(),
pythonConfig);
- if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+ if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+ pythonConfig, planner.getFlinkContext().getClassLoader())) {
ret.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
}
return ret;
@@ -139,7 +141,7 @@ public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData>
.collect(Collectors.toList());
Tuple2<int[], PythonFunctionInfo[]> extractResult =
- extractPythonScalarFunctionInfos(pythonRexCalls);
+ extractPythonScalarFunctionInfos(pythonRexCalls, classLoader);
int[] pythonUdfInputOffsets = extractResult.f0;
PythonFunctionInfo[] pythonFunctionInfos = extractResult.f1;
LogicalType[] inputLogicalTypes =
@@ -185,11 +187,14 @@ public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData>
}
private Tuple2<int[], PythonFunctionInfo[]> extractPythonScalarFunctionInfos(
- List<RexCall> rexCalls) {
+ List<RexCall> rexCalls, ClassLoader classLoader) {
LinkedHashMap<RexNode, Integer> inputNodes = new LinkedHashMap<>();
PythonFunctionInfo[] pythonFunctionInfos =
rexCalls.stream()
- .map(x -> CommonPythonUtil.createPythonFunctionInfo(x, inputNodes))
+ .map(
+ x ->
+ CommonPythonUtil.createPythonFunctionInfo(
+ x, inputNodes, classLoader))
.collect(Collectors.toList())
.toArray(new PythonFunctionInfo[rexCalls.size()]);
@@ -221,14 +226,21 @@ public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData>
int[] forwardedFields,
boolean isArrow) {
Class<?> clazz;
- boolean isInProcessMode = CommonPythonUtil.isPythonWorkerInProcessMode(pythonConfig);
+ boolean isInProcessMode =
+ CommonPythonUtil.isPythonWorkerInProcessMode(pythonConfig, classLoader);
if (isArrow) {
- clazz = CommonPythonUtil.loadClass(ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME);
+ clazz =
+ CommonPythonUtil.loadClass(
+ ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME, classLoader);
} else {
if (isInProcessMode) {
- clazz = CommonPythonUtil.loadClass(PYTHON_SCALAR_FUNCTION_OPERATOR_NAME);
+ clazz =
+ CommonPythonUtil.loadClass(
+ PYTHON_SCALAR_FUNCTION_OPERATOR_NAME, classLoader);
} else {
- clazz = CommonPythonUtil.loadClass(EMBEDDED_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME);
+ clazz =
+ CommonPythonUtil.loadClass(
+ EMBEDDED_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME, classLoader);
}
}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCorrelate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCorrelate.java
index 8661fd9b5b6..81940866104 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCorrelate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCorrelate.java
@@ -102,7 +102,8 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData>
final Transformation<RowData> inputTransform =
(Transformation<RowData>) inputEdge.translateToPlan(planner);
final Configuration pythonConfig =
- CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+ CommonPythonUtil.extractPythonConfiguration(
+ planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
final ExecNodeConfig pythonNodeConfig =
ExecNodeConfig.ofNodeConfig(pythonConfig, config.isCompiled());
final OneInputTransformation<RowData, RowData> transform =
@@ -111,7 +112,8 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData>
pythonNodeConfig,
planner.getFlinkContext().getClassLoader(),
pythonConfig);
- if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+ if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+ pythonConfig, planner.getFlinkContext().getClassLoader())) {
transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
}
return transform;
@@ -122,7 +124,8 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData>
ExecNodeConfig pythonNodeConfig,
ClassLoader classLoader,
Configuration pythonConfig) {
- Tuple2<int[], PythonFunctionInfo> extractResult = extractPythonTableFunctionInfo();
+ Tuple2<int[], PythonFunctionInfo> extractResult =
+ extractPythonTableFunctionInfo(classLoader);
int[] pythonUdtfInputOffsets = extractResult.f0;
PythonFunctionInfo pythonFunctionInfo = extractResult.f1;
InternalTypeInfo<RowData> pythonOperatorInputRowType =
@@ -146,10 +149,11 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData>
inputTransform.getParallelism());
}
- private Tuple2<int[], PythonFunctionInfo> extractPythonTableFunctionInfo() {
+ private Tuple2<int[], PythonFunctionInfo> extractPythonTableFunctionInfo(
+ ClassLoader classLoader) {
LinkedHashMap<RexNode, Integer> inputNodes = new LinkedHashMap<>();
PythonFunctionInfo pythonTableFunctionInfo =
- CommonPythonUtil.createPythonFunctionInfo(invocation, inputNodes);
+ CommonPythonUtil.createPythonFunctionInfo(invocation, inputNodes, classLoader);
int[] udtfInputOffsets =
inputNodes.keySet().stream()
.filter(x -> x instanceof RexInputRef)
@@ -168,7 +172,8 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData>
InternalTypeInfo<RowData> outputRowType,
PythonFunctionInfo pythonFunctionInfo,
int[] udtfInputOffsets) {
- boolean isInProcessMode = CommonPythonUtil.isPythonWorkerInProcessMode(pythonConfig);
+ boolean isInProcessMode =
+ CommonPythonUtil.isPythonWorkerInProcessMode(pythonConfig, classLoader);
final RowType inputType = inputRowType.toRowType();
final RowType outputType = outputRowType.toRowType();
@@ -180,7 +185,9 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData>
try {
if (isInProcessMode) {
- Class clazz = CommonPythonUtil.loadClass(PYTHON_TABLE_FUNCTION_OPERATOR_NAME);
+ Class clazz =
+ CommonPythonUtil.loadClass(
+ PYTHON_TABLE_FUNCTION_OPERATOR_NAME, classLoader);
Constructor ctor =
clazz.getConstructor(
Configuration.class,
@@ -206,7 +213,8 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData>
udtfInputOffsets));
} else {
Class clazz =
- CommonPythonUtil.loadClass(EMBEDDED_PYTHON_TABLE_FUNCTION_OPERATOR_NAME);
+ CommonPythonUtil.loadClass(
+ EMBEDDED_PYTHON_TABLE_FUNCTION_OPERATOR_NAME, classLoader);
Constructor ctor =
clazz.getConstructor(
Configuration.class,
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupAggregate.java
index 55a8a8cd3d5..4595191332b 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupAggregate.java
@@ -175,10 +175,12 @@ public class StreamExecPythonGroupAggregate extends StreamExecAggregateBase {
PythonAggregateFunctionInfo[] pythonFunctionInfos = aggInfosAndDataViewSpecs.f0;
DataViewSpec[][] dataViewSpecs = aggInfosAndDataViewSpecs.f1;
Configuration pythonConfig =
- CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+ CommonPythonUtil.extractPythonConfiguration(
+ planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
final OneInputStreamOperator<RowData, RowData> operator =
getPythonAggregateFunctionOperator(
pythonConfig,
+ planner.getFlinkContext().getClassLoader(),
inputRowType,
InternalTypeInfo.of(getOutputType()).toRowType(),
pythonFunctionInfos,
@@ -196,7 +198,8 @@ public class StreamExecPythonGroupAggregate extends StreamExecAggregateBase {
InternalTypeInfo.of(getOutputType()),
inputTransform.getParallelism());
- if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+ if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+ pythonConfig, planner.getFlinkContext().getClassLoader())) {
transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
}
@@ -214,6 +217,7 @@ public class StreamExecPythonGroupAggregate extends StreamExecAggregateBase {
@SuppressWarnings("unchecked")
private OneInputStreamOperator<RowData, RowData> getPythonAggregateFunctionOperator(
Configuration config,
+ ClassLoader classLoader,
RowType inputType,
RowType outputType,
PythonAggregateFunctionInfo[] aggregateFunctions,
@@ -222,7 +226,8 @@ public class StreamExecPythonGroupAggregate extends StreamExecAggregateBase {
long maxIdleStateRetentionTime,
int indexOfCountStar,
boolean countStarInserted) {
- Class<?> clazz = CommonPythonUtil.loadClass(PYTHON_STREAM_AGGREAGTE_OPERATOR_NAME);
+ Class<?> clazz =
+ CommonPythonUtil.loadClass(PYTHON_STREAM_AGGREAGTE_OPERATOR_NAME, classLoader);
try {
Constructor<?> ctor =
clazz.getConstructor(
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupTableAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupTableAggregate.java
index 179b302941a..3d05d273def 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupTableAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupTableAggregate.java
@@ -131,10 +131,12 @@ public class StreamExecPythonGroupTableAggregate extends ExecNodeBase<RowData>
PythonAggregateFunctionInfo[] pythonFunctionInfos = aggInfosAndDataViewSpecs.f0;
DataViewSpec[][] dataViewSpecs = aggInfosAndDataViewSpecs.f1;
Configuration pythonConfig =
- CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+ CommonPythonUtil.extractPythonConfiguration(
+ planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
OneInputStreamOperator<RowData, RowData> pythonOperator =
getPythonTableAggregateFunctionOperator(
pythonConfig,
+ planner.getFlinkContext().getClassLoader(),
inputRowType,
InternalTypeInfo.of(getOutputType()).toRowType(),
pythonFunctionInfos,
@@ -153,7 +155,8 @@ public class StreamExecPythonGroupTableAggregate extends ExecNodeBase<RowData>
InternalTypeInfo.of(getOutputType()),
inputTransform.getParallelism());
- if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+ if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+ pythonConfig, planner.getFlinkContext().getClassLoader())) {
transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
}
@@ -171,6 +174,7 @@ public class StreamExecPythonGroupTableAggregate extends ExecNodeBase<RowData>
@SuppressWarnings("unchecked")
private OneInputStreamOperator<RowData, RowData> getPythonTableAggregateFunctionOperator(
Configuration config,
+ ClassLoader classLoader,
RowType inputRowType,
RowType outputRowType,
PythonAggregateFunctionInfo[] aggregateFunctions,
@@ -179,7 +183,9 @@ public class StreamExecPythonGroupTableAggregate extends ExecNodeBase<RowData>
long maxIdleStateRetentionTime,
boolean generateUpdateBefore,
int indexOfCountStar) {
- Class<?> clazz = CommonPythonUtil.loadClass(PYTHON_STREAM_TABLE_AGGREGATE_OPERATOR_NAME);
+ Class<?> clazz =
+ CommonPythonUtil.loadClass(
+ PYTHON_STREAM_TABLE_AGGREGATE_OPERATOR_NAME, classLoader);
try {
Constructor<?> ctor =
clazz.getConstructor(
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java
index 8aa55962285..6e210a642b0 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java
@@ -258,7 +258,8 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas
WindowAssigner<?> windowAssigner = windowAssignerAndTrigger.f0;
Trigger<?> trigger = windowAssignerAndTrigger.f1;
final Configuration pythonConfig =
- CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+ CommonPythonUtil.extractPythonConfiguration(
+ planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
final ExecNodeConfig pythonNodeConfig =
ExecNodeConfig.ofNodeConfig(pythonConfig, config.isCompiled());
boolean isGeneralPythonUDAF =
@@ -289,6 +290,7 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas
emitStrategy.getAllowLateness(),
pythonConfig,
pythonNodeConfig,
+ planner.getFlinkContext().getClassLoader(),
shiftTimeZone);
} else {
transform =
@@ -306,7 +308,8 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas
shiftTimeZone);
}
- if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+ if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+ pythonConfig, planner.getFlinkContext().getClassLoader())) {
transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
}
// set KeyType and Selector for state
@@ -436,6 +439,7 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas
long allowance,
Configuration pythonConfig,
ExecNodeConfig pythonNodeConfig,
+ ClassLoader classLoader,
ZoneId shiftTimeZone) {
final int inputCountIndex = aggInfoList.getIndexOfCountStar();
final boolean countStarInserted = aggInfoList.countStarInserted();
@@ -446,6 +450,7 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas
OneInputStreamOperator<RowData, RowData> pythonOperator =
getGeneralPythonStreamGroupWindowAggregateFunctionOperator(
pythonConfig,
+ classLoader,
inputRowType,
outputRowType,
windowAssigner,
@@ -484,7 +489,8 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas
ZoneId shiftTimeZone) {
Class clazz =
CommonPythonUtil.loadClass(
- ARROW_STREAM_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME);
+ ARROW_STREAM_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME,
+ classLoader);
RowType userDefinedFunctionInputType =
(RowType) Projection.of(udafInputOffsets).project(inputRowType);
RowType userDefinedFunctionOutputType =
@@ -542,6 +548,7 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas
private OneInputStreamOperator<RowData, RowData>
getGeneralPythonStreamGroupWindowAggregateFunctionOperator(
Configuration config,
+ ClassLoader classLoader,
RowType inputType,
RowType outputType,
WindowAssigner<?> windowAssigner,
@@ -555,7 +562,8 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas
ZoneId shiftTimeZone) {
Class clazz =
CommonPythonUtil.loadClass(
- GENERAL_STREAM_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME);
+ GENERAL_STREAM_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME,
+ classLoader);
boolean isRowTime = AggregateUtil.isRowtimeAttribute(window.timeAttribute());
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonOverAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonOverAggregate.java
index fd507b97d4a..d1057bcab6f 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonOverAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonOverAggregate.java
@@ -197,7 +197,8 @@ public class StreamExecPythonOverAggregate extends ExecNodeBase<RowData>
}
long precedingOffset = -1 * (long) boundValue;
Configuration pythonConfig =
- CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
+ CommonPythonUtil.extractPythonConfiguration(
+ planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader());
OneInputTransformation<RowData, RowData> transform =
createPythonOneInputTransformation(
inputTransform,
@@ -213,7 +214,8 @@ public class StreamExecPythonOverAggregate extends ExecNodeBase<RowData>
config,
planner.getFlinkContext().getClassLoader());
- if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
+ if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(
+ pythonConfig, planner.getFlinkContext().getClassLoader())) {
transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
}
@@ -306,7 +308,7 @@ public class StreamExecPythonOverAggregate extends ExecNodeBase<RowData>
className =
ARROW_PYTHON_OVER_WINDOW_ROWS_PROC_TIME_AGGREGATE_FUNCTION_OPERATOR_NAME;
}
- Class<?> clazz = CommonPythonUtil.loadClass(className);
+ Class<?> clazz = CommonPythonUtil.loadClass(className, classLoader);
try {
Constructor<?> ctor =
@@ -349,7 +351,7 @@ public class StreamExecPythonOverAggregate extends ExecNodeBase<RowData>
className =
ARROW_PYTHON_OVER_WINDOW_RANGE_PROC_TIME_AGGREGATE_FUNCTION_OPERATOR_NAME;
}
- Class<?> clazz = CommonPythonUtil.loadClass(className);
+ Class<?> clazz = CommonPythonUtil.loadClass(className, classLoader);
try {
Constructor<?> ctor =
clazz.getConstructor(
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
index a949ad2afa6..201407b718a 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
@@ -98,9 +98,9 @@ public class CommonPythonUtil {
private CommonPythonUtil() {}
- public static Class<?> loadClass(String className) {
+ public static Class<?> loadClass(String className, ClassLoader classLoader) {
try {
- return Class.forName(className, false, Thread.currentThread().getContextClassLoader());
+ return Class.forName(className, false, classLoader);
} catch (ClassNotFoundException e) {
throw new TableException(
"The dependency of 'flink-python' is not present on the classpath.", e);
@@ -108,8 +108,8 @@ public class CommonPythonUtil {
}
public static Configuration extractPythonConfiguration(
- StreamExecutionEnvironment env, ReadableConfig tableConfig) {
- Class<?> clazz = loadClass(PYTHON_CONFIG_UTILS_CLASS);
+ StreamExecutionEnvironment env, ReadableConfig tableConfig, ClassLoader classLoader) {
+ Class<?> clazz = loadClass(PYTHON_CONFIG_UTILS_CLASS, classLoader);
try {
StreamExecutionEnvironment realEnv = getRealEnvironment(env);
Method method =
@@ -125,20 +125,27 @@ public class CommonPythonUtil {
}
public static PythonFunctionInfo createPythonFunctionInfo(
- RexCall pythonRexCall, Map<RexNode, Integer> inputNodes) {
+ RexCall pythonRexCall, Map<RexNode, Integer> inputNodes, ClassLoader classLoader) {
SqlOperator operator = pythonRexCall.getOperator();
try {
if (operator instanceof ScalarSqlFunction) {
return createPythonFunctionInfo(
- pythonRexCall, inputNodes, ((ScalarSqlFunction) operator).scalarFunction());
+ pythonRexCall,
+ inputNodes,
+ ((ScalarSqlFunction) operator).scalarFunction(),
+ classLoader);
} else if (operator instanceof TableSqlFunction) {
return createPythonFunctionInfo(
- pythonRexCall, inputNodes, ((TableSqlFunction) operator).udtf());
+ pythonRexCall,
+ inputNodes,
+ ((TableSqlFunction) operator).udtf(),
+ classLoader);
} else if (operator instanceof BridgingSqlFunction) {
return createPythonFunctionInfo(
pythonRexCall,
inputNodes,
- ((BridgingSqlFunction) operator).getDefinition());
+ ((BridgingSqlFunction) operator).getDefinition(),
+ classLoader);
}
} catch (InvocationTargetException | IllegalAccessException e) {
throw new TableException("Method pickleValue accessed failed. ", e);
@@ -147,8 +154,9 @@ public class CommonPythonUtil {
}
@SuppressWarnings("unchecked")
- public static boolean isPythonWorkerUsingManagedMemory(Configuration config) {
- Class<?> clazz = loadClass(PYTHON_OPTIONS_CLASS);
+ public static boolean isPythonWorkerUsingManagedMemory(
+ Configuration config, ClassLoader classLoader) {
+ Class<?> clazz = loadClass(PYTHON_OPTIONS_CLASS, classLoader);
try {
return config.getBoolean(
(ConfigOption<Boolean>) (clazz.getField("USE_MANAGED_MEMORY").get(null)));
@@ -158,8 +166,9 @@ public class CommonPythonUtil {
}
@SuppressWarnings("unchecked")
- public static boolean isPythonWorkerInProcessMode(Configuration config) {
- Class<?> clazz = loadClass(PYTHON_OPTIONS_CLASS);
+ public static boolean isPythonWorkerInProcessMode(
+ Configuration config, ClassLoader classLoader) {
+ Class<?> clazz = loadClass(PYTHON_OPTIONS_CLASS, classLoader);
try {
return config.getString(
(ConfigOption<String>)
@@ -337,7 +346,8 @@ public class CommonPythonUtil {
});
}
- private static byte[] convertLiteralToPython(RexLiteral o, SqlTypeName typeName)
+ private static byte[] convertLiteralToPython(
+ RexLiteral o, SqlTypeName typeName, ClassLoader classLoader)
throws InvocationTargetException, IllegalAccessException {
byte type;
Object value;
@@ -396,16 +406,18 @@ public class CommonPythonUtil {
throw new RuntimeException("Unsupported type " + typeName);
}
}
- loadPickleValue();
+ loadPickleValue(classLoader);
return (byte[]) pickleValue.invoke(null, value, type);
}
- private static void loadPickleValue() {
+ private static void loadPickleValue(ClassLoader classLoader) {
if (pickleValue == null) {
synchronized (CommonPythonUtil.class) {
if (pickleValue == null) {
Class<?> clazz =
- loadClass("org.apache.flink.api.common.python.PythonBridgeUtils");
+ loadClass(
+ "org.apache.flink.api.common.python.PythonBridgeUtils",
+ classLoader);
try {
pickleValue = clazz.getMethod("pickleValue", Object.class, byte.class);
} catch (NoSuchMethodException e) {
@@ -419,18 +431,21 @@ public class CommonPythonUtil {
private static PythonFunctionInfo createPythonFunctionInfo(
RexCall pythonRexCall,
Map<RexNode, Integer> inputNodes,
- FunctionDefinition functionDefinition)
+ FunctionDefinition functionDefinition,
+ ClassLoader classLoader)
throws InvocationTargetException, IllegalAccessException {
ArrayList<Object> inputs = new ArrayList<>();
for (RexNode operand : pythonRexCall.getOperands()) {
if (operand instanceof RexCall) {
RexCall childPythonRexCall = (RexCall) operand;
PythonFunctionInfo argPythonInfo =
- createPythonFunctionInfo(childPythonRexCall, inputNodes);
+ createPythonFunctionInfo(childPythonRexCall, inputNodes, classLoader);
inputs.add(argPythonInfo);
} else if (operand instanceof RexLiteral) {
RexLiteral literal = (RexLiteral) operand;
- inputs.add(convertLiteralToPython(literal, literal.getType().getSqlTypeName()));
+ inputs.add(
+ convertLiteralToPython(
+ literal, literal.getType().getSqlTypeName(), classLoader));
} else {
if (inputNodes.containsKey(operand)) {
inputs.add(inputNodes.get(operand));
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/CommonPhysicalMatchRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/CommonPhysicalMatchRule.java
index 397476b0ebb..ad1b15e8ea8 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/CommonPhysicalMatchRule.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/CommonPhysicalMatchRule.java
@@ -25,6 +25,7 @@ import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalMatch;
import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution;
import org.apache.flink.table.planner.plan.utils.MatchUtil.AggregationPatternVariableFinder;
import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor;
+import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
@@ -87,7 +88,7 @@ public abstract class CommonPhysicalMatchRule extends ConverterRule {
Class.forName(
"org.apache.flink.cep.pattern.Pattern",
false,
- Thread.currentThread().getContextClassLoader());
+ ShortcutUtils.unwrapContext(rel).getClassLoader());
} catch (ClassNotFoundException e) {
throw new TableException(
"MATCH RECOGNIZE clause requires flink-cep dependency to be present on the classpath.",
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala
index 8fb8f80d37d..2681ed64fad 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala
@@ -473,7 +473,7 @@ abstract class PlannerBase(
tableConfig.set(TABLE_QUERY_CURRENT_DATABASE, currentDatabase)
// We pass only the configuration to avoid reconfiguration with the rootConfiguration
- getExecEnv.configure(tableConfig.getConfiguration, Thread.currentThread().getContextClassLoader)
+ getExecEnv.configure(tableConfig.getConfiguration, classLoader)
// Use config parallelism to override env parallelism.
val defaultParallelism =
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortRule.scala
index b1b58d803c0..ef387d485d1 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortRule.scala
@@ -20,7 +20,6 @@ package org.apache.flink.table.planner.plan.rules.physical.batch
import org.apache.flink.annotation.Experimental
import org.apache.flink.configuration.ConfigOption
import org.apache.flink.configuration.ConfigOptions.key
-import org.apache.flink.table.planner.calcite.FlinkContext
import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalSort