Posted to commits@flink.apache.org by hx...@apache.org on 2021/11/11 02:04:31 UTC

[flink] branch release-1.12 updated: [FLINK-24860][python] Fix the wrong position mappings in the Python UDTF

This is an automated email from the ASF dual-hosted git repository.

hxb pushed a commit to branch release-1.12
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.12 by this push:
     new 7b2ae80  [FLINK-24860][python] Fix the wrong position mappings in the Python UDTF
7b2ae80 is described below

commit 7b2ae80c943fdb91c15d9f844ddab4a9a1aabe73
Author: huangxingbo <hx...@gmail.com>
AuthorDate: Wed Nov 10 19:37:04 2021 +0800

    [FLINK-24860][python] Fix the wrong position mappings in the Python UDTF
    
    This closes #17752.
---
 .../rules/logical/PythonCorrelateSplitRule.java    | 71 ++++++++++++++----
 .../plan/rules/logical/PythonCalcSplitRule.scala   | 10 ++-
 .../CalcPythonCorrelateTransposeRuleTest.xml       |  8 +--
 .../rules/logical/PythonCorrelateSplitRuleTest.xml | 17 +++--
 .../rules/logical/PythonCorrelateSplitRule.java    | 83 ++++++++++++++++++----
 .../plan/rules/logical/PythonCalcSplitRule.scala   | 15 +++-
 6 files changed, 161 insertions(+), 43 deletions(-)
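For context, the rewritten rules are exercised by queries that laterally join a table against a table function whose argument list mixes Python and Java expressions over the left table's columns; the split rules extract those expressions into a separate Calc below the correlate, and this commit corrects the field positions used when the invocation is rewritten against that Calc's output. The following is a minimal, hypothetical PyFlink sketch of such a query shape (names and exact API calls are illustrative and inferred from the updated test plans below, not taken from FLINK-24860):

    from pyflink.table import EnvironmentSettings, TableEnvironment, DataTypes
    from pyflink.table.udf import udtf

    # Illustrative setup only; any source table with a few numeric columns would do.
    settings = EnvironmentSettings.new_instance() \
        .in_streaming_mode().use_blink_planner().build()
    t_env = TableEnvironment.create(settings)
    t_env.execute_sql(
        "CREATE TABLE src (a BIGINT, b BIGINT, c BIGINT) "
        "WITH ('connector' = 'datagen')")

    # A Python UDTF; its arguments below include a Java expression (a * a) over
    # the left table, which PythonCorrelateSplitRule extracts into a Calc.
    @udtf(result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()])
    def py_table_func(i, j):
        yield i, j

    t_env.register_function("py_table_func", py_table_func)

    # The rewritten UDTF invocation must reference the extracted expression and
    # the plain column b at the right positions; FLINK-24860 is about exactly
    # those position mappings.
    result = t_env.sql_query(
        "SELECT a, b, c, x, y FROM src, "
        "LATERAL TABLE(py_table_func(a * a, b)) AS T(x, y)")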

diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java
index dc3e0be..15eefdf 100644
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java
@@ -23,14 +23,17 @@ import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate;
 import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
 import org.apache.flink.table.planner.plan.rules.physical.stream.StreamExecCorrelateRule;
 import org.apache.flink.table.planner.plan.utils.PythonUtil;
+import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor;
 
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.plan.hep.HepRelVertex;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeField;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCorrelVariable;
 import org.apache.calcite.rex.RexFieldAccess;
 import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexNode;
@@ -112,10 +115,41 @@ public class PythonCorrelateSplitRule extends RelOptRule {
         for (int i = 0; i < primitiveFieldCount; i++) {
             calcProjects.add(RexInputRef.of(i, rowType));
         }
+        // change RexCorrelVariable to RexInputRef.
+        RexDefaultVisitor<RexNode> visitor =
+                new RexDefaultVisitor<RexNode>() {
+                    @Override
+                    public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
+                        RexNode expr = fieldAccess.getReferenceExpr();
+                        if (expr instanceof RexCorrelVariable) {
+                            RelDataTypeField field = fieldAccess.getField();
+                            return new RexInputRef(field.getIndex(), field.getType());
+                        } else {
+                            return rexBuilder.makeFieldAccess(
+                                    expr.accept(this), fieldAccess.getField().getIndex());
+                        }
+                    }
+
+                    @Override
+                    public RexNode visitNode(RexNode rexNode) {
+                        return rexNode;
+                    }
+                };
         // add the fields of the extracted rex calls.
         Iterator<RexNode> iterator = extractedRexNodes.iterator();
         while (iterator.hasNext()) {
-            calcProjects.add(iterator.next());
+            RexNode rexNode = iterator.next();
+            if (rexNode instanceof RexCall) {
+                RexCall rexCall = (RexCall) rexNode;
+                List<RexNode> newProjects =
+                        rexCall.getOperands().stream()
+                                .map(x -> x.accept(visitor))
+                                .collect(Collectors.toList());
+                RexCall newRexCall = rexCall.clone(rexCall.getType(), newProjects);
+                calcProjects.add(newRexCall);
+            } else {
+                calcProjects.add(rexNode);
+            }
         }
 
         List<String> nameList = new LinkedList<>();
@@ -252,18 +286,31 @@ public class PythonCorrelateSplitRule extends RelOptRule {
                     mergedCalc.copy(mergedCalc.getTraitSet(), newScan, mergedCalc.getProgram());
         }
 
-        FlinkLogicalCalc leftCalc =
-                createNewLeftCalc(left, rexBuilder, extractedRexNodes, correlate);
+        FlinkLogicalCorrelate newCorrelate;
+        if (extractedRexNodes.size() > 0) {
+            FlinkLogicalCalc leftCalc =
+                    createNewLeftCalc(left, rexBuilder, extractedRexNodes, correlate);
 
-        FlinkLogicalCorrelate newCorrelate =
-                new FlinkLogicalCorrelate(
-                        correlate.getCluster(),
-                        correlate.getTraitSet(),
-                        leftCalc,
-                        rightNewInput,
-                        correlate.getCorrelationId(),
-                        correlate.getRequiredColumns(),
-                        correlate.getJoinType());
+            newCorrelate =
+                    new FlinkLogicalCorrelate(
+                            correlate.getCluster(),
+                            correlate.getTraitSet(),
+                            leftCalc,
+                            rightNewInput,
+                            correlate.getCorrelationId(),
+                            correlate.getRequiredColumns(),
+                            correlate.getJoinType());
+        } else {
+            newCorrelate =
+                    new FlinkLogicalCorrelate(
+                            correlate.getCluster(),
+                            correlate.getTraitSet(),
+                            left,
+                            rightNewInput,
+                            correlate.getCorrelationId(),
+                            correlate.getRequiredColumns(),
+                            correlate.getJoinType());
+        }
 
         FlinkLogicalCalc newTopCalc =
                 createTopCalc(
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala
index 10596b3..c6c9985 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala
@@ -22,7 +22,7 @@ import java.util.function.Function
 
 import org.apache.calcite.plan.RelOptRule.{any, operand}
 import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
-import org.apache.calcite.rex.{RexBuilder, RexCall, RexFieldAccess, RexInputRef, RexLocalRef, RexNode, RexProgram}
+import org.apache.calcite.rex.{RexBuilder, RexCall, RexCorrelVariable, RexFieldAccess, RexInputRef, RexLocalRef, RexNode, RexProgram}
 import org.apache.calcite.sql.validate.SqlValidatorUtil
 import org.apache.flink.table.functions.ScalarFunction
 import org.apache.flink.table.functions.python.PythonFunctionKind
@@ -393,7 +393,13 @@ private class ScalarFunctionSplitter(
       expr match {
         case localRef: RexLocalRef if containsPythonCall(program.expandLocalRef(localRef))
           => getExtractedRexFieldAccess(fieldAccess, localRef.getIndex)
-        case _ => getExtractedRexNode(fieldAccess)
+        case _: RexCorrelVariable =>
+          val field = fieldAccess.getField
+          new RexInputRef(field.getIndex, field.getType)
+        case _ =>
+          val newFieldAccess = rexBuilder.makeFieldAccess(
+            expr.accept(this), fieldAccess.getField.getIndex)
+          getExtractedRexNode(newFieldAccess)
       }
     } else {
       fieldAccess
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRuleTest.xml
index 6a1c397..7c146e2 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRuleTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRuleTest.xml
@@ -31,12 +31,12 @@ LogicalProject(a=[$0], b=[$1], c=[$2], x=[$3], y=[$4])
     </Resource>
     <Resource name="planAfter">
       <![CDATA[
-FlinkLogicalCalc(select=[a, b, c, f00 AS f0, f10 AS f1], where=[AND(=(f0, 2), =(+(f10, 1), *(f10, f10)), =(f00, a))])
-+- FlinkLogicalCalc(select=[a, b, c, f00, f10, pyFunc(f00, f00) AS f0])
+FlinkLogicalCalc(select=[a, b, c, f00 AS f0, f1], where=[AND(=(f0, 2), =(+(f1, 1), *(f1, f1)), =(f00, a))])
++- FlinkLogicalCalc(select=[a, b, c, f00, f1, pyFunc(f00, f00) AS f0])
    +- FlinkLogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0, 1}])
-      :- FlinkLogicalCalc(select=[a, b, c, *($cor0.a, $cor0.a) AS f0, $cor0.b AS f1])
+      :- FlinkLogicalCalc(select=[a, b, c, *(a, a) AS f0])
       :  +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
-      +- FlinkLogicalTableFunctionScan(invocation=[func($3, $4)], rowType=[RecordType(INTEGER f0, INTEGER f1)], elementType=[class [Ljava.lang.Object;])
+      +- FlinkLogicalTableFunctionScan(invocation=[func($3, $1)], rowType=[RecordType(INTEGER f0, INTEGER f1)], elementType=[class [Ljava.lang.Object;])
 ]]>
     </Resource>
   </TestCase>
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRuleTest.xml
index 4fea72c..829de02 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRuleTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRuleTest.xml
@@ -32,9 +32,8 @@ LogicalProject(a=[$0], b=[$1], c=[$2], x=[$3])
       <![CDATA[
 FlinkLogicalCalc(select=[a, b, c, f00 AS f0])
 +- FlinkLogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{2}])
-   :- FlinkLogicalCalc(select=[a, b, c, pyFunc(f0) AS f0])
-   :  +- FlinkLogicalCalc(select=[a, b, c, $cor0.c AS f0])
-   :     +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+   :- FlinkLogicalCalc(select=[a, b, c, pyFunc(c) AS f0])
+   :  +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
    +- FlinkLogicalTableFunctionScan(invocation=[javaFunc($3)], rowType=[RecordType(VARCHAR(2147483647) f0)], elementType=[class [Ljava.lang.Object;])
 ]]>
     </Resource>
@@ -53,11 +52,11 @@ LogicalProject(a=[$0], b=[$1], c=[$2], x=[$3], y=[$4])
     </Resource>
     <Resource name="planAfter">
       <![CDATA[
-FlinkLogicalCalc(select=[a, b, c, f00 AS f0, f10 AS f1])
+FlinkLogicalCalc(select=[a, b, c, f00 AS f0, f1])
 +- FlinkLogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0, 1, 2}])
-   :- FlinkLogicalCalc(select=[a, b, c, *($cor0.a, $cor0.a) AS f0, $cor0.b AS f1, $cor0.c AS f2])
+   :- FlinkLogicalCalc(select=[a, b, c, *(a, a) AS f0])
    :  +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
-   +- FlinkLogicalTableFunctionScan(invocation=[func($3, pyFunc($4, $5))], rowType=[RecordType(INTEGER f0, INTEGER f1)], elementType=[class [Ljava.lang.Object;])
+   +- FlinkLogicalTableFunctionScan(invocation=[func($3, pyFunc($1, $2))], rowType=[RecordType(INTEGER f0, INTEGER f1)], elementType=[class [Ljava.lang.Object;])
 ]]>
     </Resource>
   </TestCase>
@@ -78,7 +77,7 @@ LogicalProject(a=[$0], b=[$1], c=[$2], x=[$4])
 FlinkLogicalCalc(select=[a, b, c, f00 AS x])
 +- FlinkLogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{3}])
    :- FlinkLogicalCalc(select=[a, b, c, d, pyFunc(f0) AS f0])
-   :  +- FlinkLogicalCalc(select=[a, b, c, d, $cor0.d._1 AS f0])
+   :  +- FlinkLogicalCalc(select=[a, b, c, d, d._1 AS f0])
    :     +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
    +- FlinkLogicalTableFunctionScan(invocation=[javaFunc($4)], rowType=[RecordType(VARCHAR(2147483647) f0)], elementType=[class [Ljava.lang.Object;])
 ]]>
@@ -100,9 +99,9 @@ LogicalProject(a=[$0], b=[$1], c=[$2], x=[$4], y=[$5])
       <![CDATA[
 FlinkLogicalCalc(select=[a, b, c, f00 AS x, f10 AS y])
 +- FlinkLogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0, 2, 3}])
-   :- FlinkLogicalCalc(select=[a, b, c, d, *($cor0.d._1, $cor0.a) AS f0, $cor0.d._2 AS f1, $cor0.c AS f2])
+   :- FlinkLogicalCalc(select=[a, b, c, d, *(d._1, a) AS f0, d._2 AS f1])
    :  +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
-   +- FlinkLogicalTableFunctionScan(invocation=[func($4, pyFunc($5, $6))], rowType=[RecordType(INTEGER f0, INTEGER f1)], elementType=[class [Ljava.lang.Object;])
+   +- FlinkLogicalTableFunctionScan(invocation=[func($4, pyFunc($5, $2))], rowType=[RecordType(INTEGER f0, INTEGER f1)], elementType=[class [Ljava.lang.Object;])
 ]]>
     </Resource>
   </TestCase>
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/PythonCorrelateSplitRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/PythonCorrelateSplitRule.java
index 67750a0..8114ae0 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/PythonCorrelateSplitRule.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/PythonCorrelateSplitRule.java
@@ -23,14 +23,17 @@ import org.apache.flink.table.plan.nodes.logical.FlinkLogicalCorrelate;
 import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTableFunctionScan;
 import org.apache.flink.table.plan.util.CorrelateUtil;
 import org.apache.flink.table.plan.util.PythonUtil;
+import org.apache.flink.table.plan.util.RexDefaultVisitor;
 
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.plan.hep.HepRelVertex;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeField;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCorrelVariable;
 import org.apache.calcite.rex.RexFieldAccess;
 import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexNode;
@@ -119,10 +122,41 @@ public class PythonCorrelateSplitRule extends RelOptRule {
         for (int i = 0; i < primitiveFieldCount; i++) {
             calcProjects.add(RexInputRef.of(i, rowType));
         }
+        // change RexCorrelVariable to RexInputRef.
+        RexDefaultVisitor<RexNode> visitor =
+                new RexDefaultVisitor<RexNode>() {
+                    @Override
+                    public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
+                        RexNode expr = fieldAccess.getReferenceExpr();
+                        if (expr instanceof RexCorrelVariable) {
+                            RelDataTypeField field = fieldAccess.getField();
+                            return new RexInputRef(field.getIndex(), field.getType());
+                        } else {
+                            return rexBuilder.makeFieldAccess(
+                                    expr.accept(this), fieldAccess.getField().getIndex());
+                        }
+                    }
+
+                    @Override
+                    public RexNode visitNode(RexNode rexNode) {
+                        return rexNode;
+                    }
+                };
         // add the fields of the extracted rex calls.
         Iterator<RexNode> iterator = extractedRexNodes.iterator();
         while (iterator.hasNext()) {
-            calcProjects.add(iterator.next());
+            RexNode rexNode = iterator.next();
+            if (rexNode instanceof RexCall) {
+                RexCall rexCall = (RexCall) rexNode;
+                List<RexNode> newProjects =
+                        rexCall.getOperands().stream()
+                                .map(x -> x.accept(visitor))
+                                .collect(Collectors.toList());
+                RexCall newRexCall = rexCall.clone(rexCall.getType(), newProjects);
+                calcProjects.add(newRexCall);
+            } else {
+                calcProjects.add(rexNode);
+            }
         }
 
         List<String> nameList = new LinkedList<>();
@@ -196,10 +230,12 @@ public class PythonCorrelateSplitRule extends RelOptRule {
     }
 
     private ScalarFunctionSplitter createScalarFunctionSplitter(
+            RexBuilder rexBuilder,
             int primitiveLeftFieldCount,
             ArrayBuffer<RexNode> extractedRexNodes,
             RexNode tableFunctionNode) {
         return new ScalarFunctionSplitter(
+                rexBuilder,
                 primitiveLeftFieldCount,
                 extractedRexNodes,
                 node -> {
@@ -233,7 +269,10 @@ public class PythonCorrelateSplitRule extends RelOptRule {
                     createNewScan(
                             scan,
                             createScalarFunctionSplitter(
-                                    primitiveLeftFieldCount, extractedRexNodes, scan.getCall()));
+                                    rexBuilder,
+                                    primitiveLeftFieldCount,
+                                    extractedRexNodes,
+                                    scan.getCall()));
         } else {
             FlinkLogicalCalc calc = (FlinkLogicalCalc) right;
             FlinkLogicalTableFunctionScan scan = CorrelateUtil.getTableFunctionScan(calc).get();
@@ -242,23 +281,39 @@ public class PythonCorrelateSplitRule extends RelOptRule {
                     createNewScan(
                             scan,
                             createScalarFunctionSplitter(
-                                    primitiveLeftFieldCount, extractedRexNodes, scan.getCall()));
+                                    rexBuilder,
+                                    primitiveLeftFieldCount,
+                                    extractedRexNodes,
+                                    scan.getCall()));
             rightNewInput =
                     mergedCalc.copy(mergedCalc.getTraitSet(), newScan, mergedCalc.getProgram());
         }
 
-        FlinkLogicalCalc leftCalc =
-                createNewLeftCalc(left, rexBuilder, extractedRexNodes, correlate);
+        FlinkLogicalCorrelate newCorrelate;
+        if (extractedRexNodes.size() > 0) {
+            FlinkLogicalCalc leftCalc =
+                    createNewLeftCalc(left, rexBuilder, extractedRexNodes, correlate);
 
-        FlinkLogicalCorrelate newCorrelate =
-                new FlinkLogicalCorrelate(
-                        correlate.getCluster(),
-                        correlate.getTraitSet(),
-                        leftCalc,
-                        rightNewInput,
-                        correlate.getCorrelationId(),
-                        correlate.getRequiredColumns(),
-                        correlate.getJoinType());
+            newCorrelate =
+                    new FlinkLogicalCorrelate(
+                            correlate.getCluster(),
+                            correlate.getTraitSet(),
+                            leftCalc,
+                            rightNewInput,
+                            correlate.getCorrelationId(),
+                            correlate.getRequiredColumns(),
+                            correlate.getJoinType());
+        } else {
+            newCorrelate =
+                    new FlinkLogicalCorrelate(
+                            correlate.getCluster(),
+                            correlate.getTraitSet(),
+                            left,
+                            rightNewInput,
+                            correlate.getCorrelationId(),
+                            correlate.getRequiredColumns(),
+                            correlate.getJoinType());
+        }
 
         FlinkLogicalCalc newTopCalc =
                 createTopCalc(
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/PythonCalcSplitRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/PythonCalcSplitRule.scala
index 2b73c0b..a055a3c 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/PythonCalcSplitRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/PythonCalcSplitRule.scala
@@ -22,7 +22,7 @@ import java.util.function.Function
 
 import org.apache.calcite.plan.RelOptRule.{any, operand}
 import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
-import org.apache.calcite.rex.{RexBuilder, RexCall, RexFieldAccess, RexInputRef, RexNode, RexProgram}
+import org.apache.calcite.rex.{RexBuilder, RexCall, RexCorrelVariable, RexFieldAccess, RexInputRef, RexNode, RexProgram}
 import org.apache.calcite.sql.validate.SqlValidatorUtil
 import org.apache.flink.table.functions.ScalarFunction
 import org.apache.flink.table.functions.python.PythonFunctionKind
@@ -53,6 +53,7 @@ abstract class PythonCalcSplitRuleBase(description: String)
 
     val extractedFunctionOffset = input.getRowType.getFieldCount
     val splitter = new ScalarFunctionSplitter(
+      rexBuilder,
       extractedFunctionOffset,
       extractedRexNodes,
       new Function[RexNode, Boolean] {
@@ -304,6 +305,7 @@ object PythonCalcRewriteProjectionRule extends PythonCalcSplitRuleBase(
 }
 
 private class ScalarFunctionSplitter(
+    rexBuilder: RexBuilder,
     extractedFunctionOffset: Int,
     extractedRexNodes: mutable.ArrayBuffer[RexNode],
     needConvert: Function[RexNode, Boolean])
@@ -319,7 +321,16 @@ private class ScalarFunctionSplitter(
 
   override def visitFieldAccess(fieldAccess: RexFieldAccess): RexNode = {
     if (needConvert(fieldAccess)) {
-      getExtractedRexNode(fieldAccess)
+      val expr = fieldAccess.getReferenceExpr
+      expr match {
+        case _: RexCorrelVariable =>
+          val field = fieldAccess.getField
+          new RexInputRef(field.getIndex, field.getType)
+        case _ =>
+          val newFieldAccess = rexBuilder.makeFieldAccess(
+            expr.accept(this), fieldAccess.getField.getIndex)
+          getExtractedRexNode(newFieldAccess)
+      }
     } else {
       fieldAccess
     }
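
As a hedged follow-up to the sketch above: one way to eyeball the effect of the corrected mappings is to print the optimized plan from PyFlink and compare the FlinkLogicalTableFunctionScan invocation with the planAfter blocks in the updated test resources (this assumes Table.explain() is available in the version at hand):

    # Continuing the hypothetical sketch: print the abstract syntax tree and the
    # optimized plan; the table function invocation should now reference the
    # correct input positions of the Calc produced by the split rules.
    print(result.explain())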