Posted to commits@flink.apache.org by he...@apache.org on 2020/03/13 03:07:31 UTC

[flink] branch master updated: [FLINK-16008][python][table-planner][table-planner-blink] Add rules to transpose the join condition as a Calc on top of the Python Correlate node (#11299)

This is an automated email from the ASF dual-hosted git repository.

hequn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new ae1d1b2  [FLINK-16008][python][table-planner][table-planner-blink] Add rules to transpose the join condition as a Calc on top of the Python Correlate node (#11299)
ae1d1b2 is described below

commit ae1d1b2481b4c585b7ee53f53ec44b0d9bcda3b5
Author: HuangXingBo <hx...@gmail.com>
AuthorDate: Fri Mar 13 11:07:15 2020 +0800

    [FLINK-16008][python][table-planner][table-planner-blink] Add rules to transpose the join condition as a Calc on top of the Python Correlate node (#11299)
    
    Since currently we don't support joining a Python UDTF with conditions,
    add a rule to transpose the condition as a Calc on top of the Python Correlate node.
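    
    For illustration only: the affected shape is a lateral join against a Python
    UDTF whose WHERE clause carries predicates the Python correlate cannot
    evaluate itself. A minimal sketch mirroring the new
    CalcPythonCorrelateTransposeRuleTest added below (the class and method names
    here are illustrative; MockPythonTableFunction and PythonScalarFunction are
    the existing test utilities used by that test):
    
        import org.apache.flink.api.scala._
        import org.apache.flink.table.api.scala._
        import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.PythonScalarFunction
        import org.apache.flink.table.planner.utils.{MockPythonTableFunction, TableTestBase}
        import org.junit.Test
    
        class PythonCorrelateConditionExample extends TableTestBase {
    
          @Test
          def testJoinPythonUdtfWithCondition(): Unit = {
            val util = streamTestUtil()
            util.addTableSource[(Int, Int, Int)]("MyTable", 'a, 'b, 'c)
            util.addFunction("func", new MockPythonTableFunction)          // table function flagged as Python
            util.addFunction("pyFunc", new PythonScalarFunction("pyFunc")) // scalar function flagged as Python
    
            // The join condition cannot run inside the Python correlate, so the new
            // rule rewrites it into a Calc placed on top of the correlate node.
            util.verifyPlan(
              "SELECT a, b, c, x, y FROM MyTable, LATERAL TABLE(func(a * a, b)) AS T(x, y) " +
                "WHERE x = a and pyFunc(x, x) = 2 and y + 1 = y * y")
          }
        }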
---
 .../logical/CalcPythonCorrelateTransposeRule.java  | 110 ++++++++++++++++++++
 .../stream/StreamExecPythonCorrelate.scala         |   5 +
 .../planner/plan/rules/FlinkStreamRuleSets.scala   |   2 +
 .../SplitPythonConditionFromCorrelateRule.scala    |   8 +-
 .../CalcPythonCorrelateTransposeRuleTest.xml       |  43 ++++++++
 .../CalcPythonCorrelateTransposeRuleTest.scala     |  67 ++++++++++++
 .../logical/CalcPythonCorrelateTransposeRule.java  | 113 +++++++++++++++++++++
 .../datastream/DataStreamPythonCorrelate.scala     |   6 +-
 .../flink/table/plan/rules/FlinkRuleSets.scala     |   2 +
 .../SplitPythonConditionFromCorrelateRule.scala    |   5 +-
 .../CalcPythonCorrelateTransposeRuleTest.scala     |  63 ++++++++++++
 11 files changed, 420 insertions(+), 4 deletions(-)

diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRule.java
new file mode 100644
index 0000000..5501949
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRule.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate;
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalRel;
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
+import org.apache.flink.table.planner.plan.rules.physical.stream.StreamExecCorrelateRule;
+import org.apache.flink.table.planner.plan.utils.PythonUtil;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexProgram;
+import org.apache.calcite.rex.RexProgramBuilder;
+import org.apache.calcite.rex.RexUtil;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Rule that transposes the join conditions into a Calc on top of the Python correlate node if the join type is inner join.
+ */
+public class CalcPythonCorrelateTransposeRule extends RelOptRule {
+
+	public static final CalcPythonCorrelateTransposeRule INSTANCE =
+		new CalcPythonCorrelateTransposeRule();
+
+	private CalcPythonCorrelateTransposeRule() {
+		super(operand(FlinkLogicalCorrelate.class,
+			operand(FlinkLogicalRel.class, any()),
+			operand(FlinkLogicalCalc.class, any())),
+			"CalcPythonCorrelateTransposeRule");
+	}
+
+	@Override
+	public boolean matches(RelOptRuleCall call) {
+		FlinkLogicalCorrelate correlate = call.rel(0);
+		FlinkLogicalCalc right = call.rel(2);
+		JoinRelType joinType = correlate.getJoinType();
+		FlinkLogicalCalc mergedCalc = StreamExecCorrelateRule.getMergedCalc(right);
+		FlinkLogicalTableFunctionScan scan = StreamExecCorrelateRule.getTableScan(mergedCalc);
+		return joinType == JoinRelType.INNER &&
+			PythonUtil.isPythonCall(scan.getCall(), null) &&
+			mergedCalc.getProgram().getCondition() != null;
+	}
+
+	@Override
+	public void onMatch(RelOptRuleCall call) {
+		FlinkLogicalCorrelate correlate = call.rel(0);
+		FlinkLogicalCalc right = call.rel(2);
+		RexBuilder rexBuilder = call.builder().getRexBuilder();
+		FlinkLogicalCalc mergedCalc = StreamExecCorrelateRule.getMergedCalc(right);
+		FlinkLogicalTableFunctionScan tableScan = StreamExecCorrelateRule.getTableScan(mergedCalc);
+		RexProgram mergedCalcProgram = mergedCalc.getProgram();
+
+		InputRefRewriter inputRefRewriter = new InputRefRewriter(
+			correlate.getRowType().getFieldCount() - mergedCalc.getRowType().getFieldCount());
+		List<RexNode> correlateFilters = RelOptUtil
+			.conjunctions(mergedCalcProgram.expandLocalRef(mergedCalcProgram.getCondition()))
+			.stream()
+			.map(x -> x.accept(inputRefRewriter))
+			.collect(Collectors.toList());
+
+		FlinkLogicalCorrelate newCorrelate = new FlinkLogicalCorrelate(
+			correlate.getCluster(),
+			correlate.getTraitSet(),
+			correlate.getLeft(),
+			tableScan,
+			correlate.getCorrelationId(),
+			correlate.getRequiredColumns(),
+			correlate.getJoinType());
+
+		RexNode topCalcCondition = RexUtil.composeConjunction(rexBuilder, correlateFilters);
+		RexProgram rexProgram = new RexProgramBuilder(
+			newCorrelate.getRowType(), rexBuilder).getProgram();
+		FlinkLogicalCalc newTopCalc = new FlinkLogicalCalc(
+			newCorrelate.getCluster(),
+			newCorrelate.getTraitSet(),
+			newCorrelate,
+			RexProgram.create(
+				newCorrelate.getRowType(),
+				rexProgram.getExprList(),
+				topCalcCondition,
+				newCorrelate.getRowType(),
+				rexBuilder));
+
+		call.transformTo(newTopCalc);
+	}
+}
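
The condition pulled out of the merged Calc references the UDTF side's fields, while
the new top Calc sees the full (left input ++ UDTF output) row of the correlate; that
is why onMatch shifts every input reference by
correlate.getRowType().getFieldCount() - mergedCalc.getRowType().getFieldCount()
before composing the conjunction. A minimal, self-contained sketch of that offset
arithmetic (the helper is hypothetical; the schema is the one used by the tests,
MyTable(a, b, c) laterally joined with T(x, y)):

    // Sketch only: mirrors the offset handed to InputRefRewriter in onMatch above.
    object InputRefOffsetSketch {
      // An input ref $i inside the merged Calc addresses the i-th field of the UDTF
      // side; in the Calc placed on top of the correlate the same field sits behind
      // the left input's columns, so the index grows by the difference in field counts.
      def shift(calcFieldIndex: Int, correlateFieldCount: Int, calcFieldCount: Int): Int =
        calcFieldIndex + (correlateFieldCount - calcFieldCount)

      def main(args: Array[String]): Unit = {
        // MyTable(a, b, c) laterally joined with T(x, y): the original correlate row
        // type has 5 fields and the merged Calc has 2, so the offset is 3 and `x`
        // moves from $0 (inside the Calc) to $3 (on top of the new correlate).
        println(shift(0, correlateFieldCount = 5, calcFieldCount = 2)) // prints 3
      }
    }
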
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCorrelate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCorrelate.scala
index ca847fe..4ece477 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCorrelate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCorrelate.scala
@@ -26,6 +26,7 @@ import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.core.JoinRelType
 import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rex.{RexNode, RexProgram}
+import org.apache.flink.table.api.TableException
 import org.apache.flink.table.planner.plan.nodes.common.CommonPythonCorrelate
 
 /**
@@ -51,6 +52,10 @@ class StreamExecPythonCorrelate(
     joinType)
   with CommonPythonCorrelate {
 
+  if (condition.isDefined) {
+    throw new TableException("Currently Python correlate does not support conditions in left join.")
+  }
+
   def copy(
       traitSet: RelTraitSet,
       newChild: RelNode,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
index 44760f6..41fe922 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
@@ -342,6 +342,8 @@ object FlinkStreamRuleSets {
     // Rule that splits python ScalarFunctions from
     // java/scala ScalarFunctions in correlate conditions
     SplitPythonConditionFromCorrelateRule.INSTANCE,
+    // Rule that transposes the conditions into a Calc on top of the Python correlate node.
+    CalcPythonCorrelateTransposeRule.INSTANCE,
     // Rule that splits java calls from python TableFunction
     PythonCorrelateSplitRule.INSTANCE,
     // merge calc after calc transpose
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala
index a5de6b7..4806d40 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala
@@ -22,9 +22,9 @@ import org.apache.calcite.plan.RelOptRule.{any, operand}
 import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptUtil}
 import org.apache.calcite.rel.core.JoinRelType
 import org.apache.calcite.rex._
-import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalCorrelate, FlinkLogicalRel}
+import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalCorrelate, FlinkLogicalRel, FlinkLogicalTableFunctionScan}
 import org.apache.flink.table.planner.plan.rules.physical.stream.StreamExecCorrelateRule
-import org.apache.flink.table.planner.plan.utils.PythonUtil.containsPythonCall
+import org.apache.flink.table.planner.plan.utils.PythonUtil.{containsPythonCall, isNonPythonCall}
 import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor
 
 import scala.collection.JavaConversions._
@@ -52,7 +52,11 @@ class SplitPythonConditionFromCorrelateRule
     val right: FlinkLogicalCalc = call.rel(2).asInstanceOf[FlinkLogicalCalc]
     val joinType: JoinRelType = correlate.getJoinType
     val mergedCalc = StreamExecCorrelateRule.getMergedCalc(right)
+    val tableScan = StreamExecCorrelateRule
+      .getTableScan(mergedCalc)
+      .asInstanceOf[FlinkLogicalTableFunctionScan]
     joinType == JoinRelType.INNER &&
+      isNonPythonCall(tableScan.getCall) &&
       Option(mergedCalc.getProgram.getCondition)
         .map(mergedCalc.getProgram.expandLocalRef)
         .exists(containsPythonCall(_))
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRuleTest.xml
new file mode 100644
index 0000000..b1dda00
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRuleTest.xml
@@ -0,0 +1,43 @@
+<?xml version="1.0" ?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements.  See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License.  You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<Root>
+  <TestCase name="testPythonFunctionInCorrelateCondition">
+    <Resource name="sql">
+      <![CDATA[SELECT a, b, c, x, y FROM MyTable, LATERAL TABLE(func(a * a, b)) AS T(x, y) WHERE x = a and pyFunc(x, x) = 2 and y + 1 = y * y]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], x=[$3], y=[$4])
++- LogicalFilter(condition=[AND(=($3, $0), =(pyFunc($3, $3), 2), =(+($4, 1), *($4, $4)))])
+   +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0, 1}])
+      :- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]])
+      +- LogicalTableFunctionScan(invocation=[func(*($cor0.a, $cor0.a), $cor0.b)], rowType=[RecordType(INTEGER f0, INTEGER f1)], elementType=[class [Ljava.lang.Object;])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+FlinkLogicalCalc(select=[a, b, c, f00 AS f0, f1], where=[AND(=(f0, 2), =(+(f1, 1), *(f1, f1)), =(f00, a))])
++- FlinkLogicalCalc(select=[a, b, c, f00, f1, pyFunc(f00, f00) AS f0])
+   +- FlinkLogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0, 1}])
+      :- FlinkLogicalCalc(select=[a, b, c, *($cor0.a, $cor0.a) AS f0])
+      :  +- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+      +- FlinkLogicalTableFunctionScan(invocation=[func($3, $cor0.b)], rowType=[RecordType(INTEGER f0, INTEGER f1)], elementType=[class [Ljava.lang.Object;])
+]]>
+    </Resource>
+  </TestCase>
+</Root>
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRuleTest.scala
new file mode 100644
index 0000000..5f97cfa
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRuleTest.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical
+
+import org.apache.calcite.plan.hep.HepMatchOrder
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.planner.plan.nodes.FlinkConventions
+import org.apache.flink.table.planner.plan.optimize.program._
+import org.apache.flink.table.planner.plan.rules.FlinkStreamRuleSets
+import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.PythonScalarFunction
+import org.apache.flink.table.planner.utils.{MockPythonTableFunction, TableTestBase}
+import org.junit.{Before, Test}
+
+class CalcPythonCorrelateTransposeRuleTest extends TableTestBase {
+
+  private val util = streamTestUtil()
+
+  @Before
+  def setup(): Unit = {
+    val programs = new FlinkChainedProgram[StreamOptimizeContext]()
+    // query decorrelation
+    programs.addLast("decorrelate", new FlinkDecorrelateProgram)
+    programs.addLast(
+      "logical",
+      FlinkVolcanoProgramBuilder.newBuilder
+        .add(FlinkStreamRuleSets.LOGICAL_OPT_RULES)
+        .setRequiredOutputTraits(Array(FlinkConventions.LOGICAL))
+        .build())
+    programs.addLast(
+      "logical_rewrite",
+      FlinkHepRuleSetProgramBuilder.newBuilder
+        .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE)
+        .setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
+        .add(FlinkStreamRuleSets.LOGICAL_REWRITE)
+        .build())
+    util.replaceStreamProgram(programs)
+
+    util.addTableSource[(Int, Int, Int)]("MyTable", 'a, 'b, 'c)
+    util.addFunction("func", new MockPythonTableFunction)
+    util.addFunction("pyFunc", new PythonScalarFunction("pyFunc"))
+  }
+
+  @Test
+  def testPythonFunctionInCorrelateCondition(): Unit = {
+    val sqlQuery = "SELECT a, b, c, x, y FROM MyTable, LATERAL TABLE(func(a * a, b)) AS T(x, y) " +
+      "WHERE x = a and pyFunc(x, x) = 2 and y + 1 = y * y"
+    util.verifyPlan(sqlQuery)
+  }
+
+}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/CalcPythonCorrelateTransposeRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/CalcPythonCorrelateTransposeRule.java
new file mode 100644
index 0000000..41975b7
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/CalcPythonCorrelateTransposeRule.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.logical;
+
+import org.apache.flink.table.plan.nodes.logical.FlinkLogicalCalc;
+import org.apache.flink.table.plan.nodes.logical.FlinkLogicalCorrelate;
+import org.apache.flink.table.plan.nodes.logical.FlinkLogicalRel;
+import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTableFunctionScan;
+import org.apache.flink.table.plan.util.CorrelateUtil;
+import org.apache.flink.table.plan.util.PythonUtil;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexProgram;
+import org.apache.calcite.rex.RexProgramBuilder;
+import org.apache.calcite.rex.RexUtil;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+import scala.Option;
+
+/**
+ * Rule that transposes the join conditions into a Calc on top of the Python correlate node if the join type is inner join.
+ */
+public class CalcPythonCorrelateTransposeRule extends RelOptRule {
+
+	public static final CalcPythonCorrelateTransposeRule INSTANCE =
+		new CalcPythonCorrelateTransposeRule();
+
+	private CalcPythonCorrelateTransposeRule() {
+		super(operand(FlinkLogicalCorrelate.class,
+			operand(FlinkLogicalRel.class, any()),
+			operand(FlinkLogicalCalc.class, any())),
+			"CalcPythonCorrelateTransposeRule");
+	}
+
+	@Override
+	public boolean matches(RelOptRuleCall call) {
+		FlinkLogicalCorrelate correlate = call.rel(0);
+		FlinkLogicalCalc right = call.rel(2);
+		JoinRelType joinType = correlate.getJoinType();
+		FlinkLogicalCalc mergedCalc = CorrelateUtil.getMergedCalc(right);
+		Option<FlinkLogicalTableFunctionScan> scan = CorrelateUtil.getTableFunctionScan(mergedCalc);
+		return joinType == JoinRelType.INNER &&
+			scan.isDefined() &&
+			PythonUtil.isPythonCall(scan.get().getCall(), null) &&
+			mergedCalc.getProgram().getCondition() != null;
+	}
+
+	@Override
+	public void onMatch(RelOptRuleCall call) {
+		FlinkLogicalCorrelate correlate = call.rel(0);
+		FlinkLogicalCalc right = call.rel(2);
+		RexBuilder rexBuilder = call.builder().getRexBuilder();
+		FlinkLogicalCalc mergedCalc = CorrelateUtil.getMergedCalc(right);
+		FlinkLogicalTableFunctionScan tableScan = CorrelateUtil.getTableFunctionScan(mergedCalc).get();
+		RexProgram mergedCalcProgram = mergedCalc.getProgram();
+
+		InputRefRewriter inputRefRewriter = new InputRefRewriter(
+			correlate.getRowType().getFieldCount() - mergedCalc.getRowType().getFieldCount());
+		List<RexNode> correlateFilters = RelOptUtil
+			.conjunctions(mergedCalcProgram.expandLocalRef(mergedCalcProgram.getCondition()))
+			.stream()
+			.map(x -> x.accept(inputRefRewriter))
+			.collect(Collectors.toList());
+
+		FlinkLogicalCorrelate newCorrelate = new FlinkLogicalCorrelate(
+			correlate.getCluster(),
+			correlate.getTraitSet(),
+			correlate.getLeft(),
+			tableScan,
+			correlate.getCorrelationId(),
+			correlate.getRequiredColumns(),
+			correlate.getJoinType());
+
+		RexNode topCalcCondition = RexUtil.composeConjunction(rexBuilder, correlateFilters);
+		RexProgram rexProgram = new RexProgramBuilder(
+			newCorrelate.getRowType(), rexBuilder).getProgram();
+		FlinkLogicalCalc newTopCalc = new FlinkLogicalCalc(
+			newCorrelate.getCluster(),
+			newCorrelate.getTraitSet(),
+			newCorrelate,
+			RexProgram.create(
+				newCorrelate.getRowType(),
+				rexProgram.getExprList(),
+				topCalcCondition,
+				newCorrelate.getRowType(),
+				rexBuilder));
+
+		call.transformTo(newTopCalc);
+	}
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamPythonCorrelate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamPythonCorrelate.scala
index 0021bac..1e46c35 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamPythonCorrelate.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamPythonCorrelate.scala
@@ -22,7 +22,7 @@ import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rel.core.JoinRelType
 import org.apache.calcite.rex.{RexCall, RexNode}
 import org.apache.flink.streaming.api.datastream.DataStream
-import org.apache.flink.table.api.StreamQueryConfig
+import org.apache.flink.table.api.{StreamQueryConfig, TableException}
 import org.apache.flink.table.functions.utils.TableSqlFunction
 import org.apache.flink.table.plan.nodes.CommonPythonCorrelate
 import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTableFunctionScan
@@ -57,6 +57,10 @@ class DataStreamPythonCorrelate(
     joinType)
   with CommonPythonCorrelate {
 
+  if (condition.isDefined) {
+    throw new TableException("Currently Python correlate does not support conditions in left join.")
+  }
+
   override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
     new DataStreamPythonCorrelate(
       cluster,
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
index 7bbfa6a..534abe4 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
@@ -152,6 +152,8 @@ object FlinkRuleSets {
     // Rule that splits python ScalarFunctions from
     // java/scala ScalarFunctions in correlate conditions
     SplitPythonConditionFromCorrelateRule.INSTANCE,
+    // Rule that transposes the conditions into a Calc on top of the Python correlate node.
+    CalcPythonCorrelateTransposeRule.INSTANCE,
     // Rule that splits java calls from python TableFunction
     PythonCorrelateSplitRule.INSTANCE,
     CalcMergeRule.INSTANCE,
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala
index 09a90a6..cee224a 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala
@@ -23,7 +23,7 @@ import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptUtil}
 import org.apache.calcite.rel.core.JoinRelType
 import org.apache.calcite.rex._
 import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalCorrelate, FlinkLogicalRel}
-import org.apache.flink.table.plan.util.PythonUtil.containsPythonCall
+import org.apache.flink.table.plan.util.PythonUtil.{containsPythonCall, isNonPythonCall}
 import org.apache.flink.table.plan.util.{CorrelateUtil, RexDefaultVisitor}
 
 import scala.collection.JavaConversions._
@@ -51,7 +51,10 @@ class SplitPythonConditionFromCorrelateRule
     val right: FlinkLogicalCalc = call.rel(2).asInstanceOf[FlinkLogicalCalc]
     val joinType: JoinRelType = correlate.getJoinType
     val mergedCalc = CorrelateUtil.getMergedCalc(right)
+    val tableScan = CorrelateUtil.getTableFunctionScan(mergedCalc)
     joinType == JoinRelType.INNER &&
+      tableScan.isDefined &&
+      isNonPythonCall(tableScan.get.getCall) &&
       Option(mergedCalc.getProgram.getCondition)
         .map(mergedCalc.getProgram.expandLocalRef)
         .exists(containsPythonCall(_))
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/CalcPythonCorrelateTransposeRuleTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/CalcPythonCorrelateTransposeRuleTest.scala
new file mode 100644
index 0000000..a5bd851
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/CalcPythonCorrelateTransposeRuleTest.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.runtime.utils.JavaUserDefinedScalarFunctions.PythonScalarFunction
+import org.apache.flink.table.utils.{MockPythonTableFunction, TableTestBase}
+import org.apache.flink.table.utils.TableTestUtil.{streamTableNode, term, unaryNode}
+import org.junit.Test
+
+class CalcPythonCorrelateTransposeRuleTest extends TableTestBase {
+  @Test
+  def testPythonTableFunctionWithCondition(): Unit = {
+    val util = streamTestUtil()
+    val table = util.addTable[(Int, Int, Int)]("MyTable", 'a, 'b, 'c)
+    val scalarFunc = new PythonScalarFunction("pyFunc")
+    val tableFunc = new MockPythonTableFunction()
+
+    val resultTable = table.joinLateral(
+      tableFunc('a * 'a, 'b) as('x, 'y),
+      'x === 'a && scalarFunc('x, 'x) === 2 && 'y + 1 === 'y * 'y)
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamPythonCalc",
+        unaryNode(
+          "DataStreamPythonCorrelate",
+          unaryNode(
+            "DataStreamCalc",
+            streamTableNode(table),
+            term("select", "a, b, c, *(a, a) AS f0")),
+          term("invocation", s"${tableFunc.functionIdentifier}" +
+            s"($$3, $$1)"),
+          term("correlate", s"table(${tableFunc.getClass.getSimpleName}" +
+            s"(f0, b))"),
+          term("select", "a, b, c, f0, x, y"),
+          term("rowType",
+               "RecordType(INTEGER a, INTEGER b, INTEGER c, INTEGER f0, INTEGER x, INTEGER y)"),
+          term("joinType", "INNER")),
+        term("select", "a, b, c, x, y, pyFunc(x, x) AS f0")),
+      term("select", "a, b, c, x, y"),
+      term("where", "AND(AND(=(f0, 2), =(+(y, 1), *(y, y))), =(x, a))"))
+    util.verifyTable(resultTable, expected)
+  }
+
+}