You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by he...@apache.org on 2020/03/13 03:07:31 UTC
[flink] branch master updated:
[FLINK-16008][python][table-planner][table-planner-blink] Add rules to
transpose the join condition as a Calc on top of the Python Correlate node
(#11299)
This is an automated email from the ASF dual-hosted git repository.
hequn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new ae1d1b2 [FLINK-16008][python][table-planner][table-planner-blink] Add rules to transpose the join condition as a Calc on top of the Python Correlate node (#11299)
ae1d1b2 is described below
commit ae1d1b2481b4c585b7ee53f53ec44b0d9bcda3b5
Author: HuangXingBo <hx...@gmail.com>
AuthorDate: Fri Mar 13 11:07:15 2020 +0800
[FLINK-16008][python][table-planner][table-planner-blink] Add rules to transpose the join condition as a Calc on top of the Python Correlate node (#11299)
Since joining a Python UDTF with conditions is not currently supported,
add a rule that transposes the condition into a Calc on top of the Python Correlate node.
---
.../logical/CalcPythonCorrelateTransposeRule.java | 110 ++++++++++++++++++++
.../stream/StreamExecPythonCorrelate.scala | 5 +
.../planner/plan/rules/FlinkStreamRuleSets.scala | 2 +
.../SplitPythonConditionFromCorrelateRule.scala | 8 +-
.../CalcPythonCorrelateTransposeRuleTest.xml | 43 ++++++++
.../CalcPythonCorrelateTransposeRuleTest.scala | 67 ++++++++++++
.../logical/CalcPythonCorrelateTransposeRule.java | 113 +++++++++++++++++++++
.../datastream/DataStreamPythonCorrelate.scala | 6 +-
.../flink/table/plan/rules/FlinkRuleSets.scala | 2 +
.../SplitPythonConditionFromCorrelateRule.scala | 5 +-
.../CalcPythonCorrelateTransposeRuleTest.scala | 63 ++++++++++++
11 files changed, 420 insertions(+), 4 deletions(-)
diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRule.java
new file mode 100644
index 0000000..5501949
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRule.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate;
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalRel;
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
+import org.apache.flink.table.planner.plan.rules.physical.stream.StreamExecCorrelateRule;
+import org.apache.flink.table.planner.plan.utils.PythonUtil;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexProgram;
+import org.apache.calcite.rex.RexProgramBuilder;
+import org.apache.calcite.rex.RexUtil;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Rule that moves the condition of the Calc under a Python correlate node into a new Calc on top of the correlate, if the join type is inner join.
+ */
+public class CalcPythonCorrelateTransposeRule extends RelOptRule {
+
+ public static final CalcPythonCorrelateTransposeRule INSTANCE =
+ new CalcPythonCorrelateTransposeRule();
+
+ private CalcPythonCorrelateTransposeRule() {
+ super(operand(FlinkLogicalCorrelate.class,
+ operand(FlinkLogicalRel.class, any()),
+ operand(FlinkLogicalCalc.class, any())),
+ "CalcPythonCorrelateTransposeRule");
+ }
+
+ @Override
+ public boolean matches(RelOptRuleCall call) {
+ FlinkLogicalCorrelate correlate = call.rel(0);
+ FlinkLogicalCalc right = call.rel(2);
+ JoinRelType joinType = correlate.getJoinType();
+ FlinkLogicalCalc mergedCalc = StreamExecCorrelateRule.getMergedCalc(right);
+ FlinkLogicalTableFunctionScan scan = StreamExecCorrelateRule.getTableScan(mergedCalc);
+ return joinType == JoinRelType.INNER &&
+ PythonUtil.isPythonCall(scan.getCall(), null) &&
+ mergedCalc.getProgram().getCondition() != null;
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ FlinkLogicalCorrelate correlate = call.rel(0);
+ FlinkLogicalCalc right = call.rel(2);
+ RexBuilder rexBuilder = call.builder().getRexBuilder();
+ FlinkLogicalCalc mergedCalc = StreamExecCorrelateRule.getMergedCalc(right);
+ FlinkLogicalTableFunctionScan tableScan = StreamExecCorrelateRule.getTableScan(mergedCalc);
+ RexProgram mergedCalcProgram = mergedCalc.getProgram();
+
+ InputRefRewriter inputRefRewriter = new InputRefRewriter(
+ correlate.getRowType().getFieldCount() - mergedCalc.getRowType().getFieldCount());
+ List<RexNode> correlateFilters = RelOptUtil
+ .conjunctions(mergedCalcProgram.expandLocalRef(mergedCalcProgram.getCondition()))
+ .stream()
+ .map(x -> x.accept(inputRefRewriter))
+ .collect(Collectors.toList());
+
+ FlinkLogicalCorrelate newCorrelate = new FlinkLogicalCorrelate(
+ correlate.getCluster(),
+ correlate.getTraitSet(),
+ correlate.getLeft(),
+ tableScan,
+ correlate.getCorrelationId(),
+ correlate.getRequiredColumns(),
+ correlate.getJoinType());
+
+ RexNode topCalcCondition = RexUtil.composeConjunction(rexBuilder, correlateFilters);
+ RexProgram rexProgram = new RexProgramBuilder(
+ newCorrelate.getRowType(), rexBuilder).getProgram();
+ FlinkLogicalCalc newTopCalc = new FlinkLogicalCalc(
+ newCorrelate.getCluster(),
+ newCorrelate.getTraitSet(),
+ newCorrelate,
+ RexProgram.create(
+ newCorrelate.getRowType(),
+ rexProgram.getExprList(),
+ topCalcCondition,
+ newCorrelate.getRowType(),
+ rexBuilder));
+
+ call.transformTo(newTopCalc);
+ }
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCorrelate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCorrelate.scala
index ca847fe..4ece477 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCorrelate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonCorrelate.scala
@@ -26,6 +26,7 @@ import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.JoinRelType
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rex.{RexNode, RexProgram}
+import org.apache.flink.table.api.TableException
import org.apache.flink.table.planner.plan.nodes.common.CommonPythonCorrelate
/**
@@ -51,6 +52,10 @@ class StreamExecPythonCorrelate(
joinType)
with CommonPythonCorrelate {
+ if (condition.isDefined) {
+ throw new TableException("Currently Python correlate does not support conditions in left join.")
+ }
+
def copy(
traitSet: RelTraitSet,
newChild: RelNode,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
index 44760f6..41fe922 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
@@ -342,6 +342,8 @@ object FlinkStreamRuleSets {
// Rule that splits python ScalarFunctions from
// java/scala ScalarFunctions in correlate conditions
SplitPythonConditionFromCorrelateRule.INSTANCE,
+ // Rule that transposes the conditions into a Calc on top of the Python correlate node.
+ CalcPythonCorrelateTransposeRule.INSTANCE,
// Rule that splits java calls from python TableFunction
PythonCorrelateSplitRule.INSTANCE,
// merge calc after calc transpose
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala
index a5de6b7..4806d40 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala
@@ -22,9 +22,9 @@ import org.apache.calcite.plan.RelOptRule.{any, operand}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptUtil}
import org.apache.calcite.rel.core.JoinRelType
import org.apache.calcite.rex._
-import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalCorrelate, FlinkLogicalRel}
+import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalCorrelate, FlinkLogicalRel, FlinkLogicalTableFunctionScan}
import org.apache.flink.table.planner.plan.rules.physical.stream.StreamExecCorrelateRule
-import org.apache.flink.table.planner.plan.utils.PythonUtil.containsPythonCall
+import org.apache.flink.table.planner.plan.utils.PythonUtil.{containsPythonCall, isNonPythonCall}
import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor
import scala.collection.JavaConversions._
@@ -52,7 +52,11 @@ class SplitPythonConditionFromCorrelateRule
val right: FlinkLogicalCalc = call.rel(2).asInstanceOf[FlinkLogicalCalc]
val joinType: JoinRelType = correlate.getJoinType
val mergedCalc = StreamExecCorrelateRule.getMergedCalc(right)
+ val tableScan = StreamExecCorrelateRule
+ .getTableScan(mergedCalc)
+ .asInstanceOf[FlinkLogicalTableFunctionScan]
joinType == JoinRelType.INNER &&
+ isNonPythonCall(tableScan.getCall) &&
Option(mergedCalc.getProgram.getCondition)
.map(mergedCalc.getProgram.expandLocalRef)
.exists(containsPythonCall(_))
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRuleTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRuleTest.xml
new file mode 100644
index 0000000..b1dda00
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRuleTest.xml
@@ -0,0 +1,43 @@
+<?xml version="1.0" ?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<Root>
+ <TestCase name="testPythonFunctionInCorrelateCondition">
+ <Resource name="sql">
+ <![CDATA[SELECT a, b, c, x, y FROM MyTable, LATERAL TABLE(func(a * a, b)) AS T(x, y) WHERE x = a and pyFunc(x, x) = 2 and y + 1 = y * y]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], x=[$3], y=[$4])
++- LogicalFilter(condition=[AND(=($3, $0), =(pyFunc($3, $3), 2), =(+($4, 1), *($4, $4)))])
+ +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0, 1}])
+ :- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]])
+ +- LogicalTableFunctionScan(invocation=[func(*($cor0.a, $cor0.a), $cor0.b)], rowType=[RecordType(INTEGER f0, INTEGER f1)], elementType=[class [Ljava.lang.Object;])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+FlinkLogicalCalc(select=[a, b, c, f00 AS f0, f1], where=[AND(=(f0, 2), =(+(f1, 1), *(f1, f1)), =(f00, a))])
++- FlinkLogicalCalc(select=[a, b, c, f00, f1, pyFunc(f00, f00) AS f0])
+ +- FlinkLogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0, 1}])
+ :- FlinkLogicalCalc(select=[a, b, c, *($cor0.a, $cor0.a) AS f0])
+ : +- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+ +- FlinkLogicalTableFunctionScan(invocation=[func($3, $cor0.b)], rowType=[RecordType(INTEGER f0, INTEGER f1)], elementType=[class [Ljava.lang.Object;])
+]]>
+ </Resource>
+ </TestCase>
+</Root>
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRuleTest.scala
new file mode 100644
index 0000000..5f97cfa
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRuleTest.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical
+
+import org.apache.calcite.plan.hep.HepMatchOrder
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.planner.plan.nodes.FlinkConventions
+import org.apache.flink.table.planner.plan.optimize.program._
+import org.apache.flink.table.planner.plan.rules.FlinkStreamRuleSets
+import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.PythonScalarFunction
+import org.apache.flink.table.planner.utils.{MockPythonTableFunction, TableTestBase}
+import org.junit.{Before, Test}
+
+class CalcPythonCorrelateTransposeRuleTest extends TableTestBase {
+
+ private val util = streamTestUtil()
+
+ @Before
+ def setup(): Unit = {
+ val programs = new FlinkChainedProgram[StreamOptimizeContext]()
+ // query decorrelation
+ programs.addLast("decorrelate", new FlinkDecorrelateProgram)
+ programs.addLast(
+ "logical",
+ FlinkVolcanoProgramBuilder.newBuilder
+ .add(FlinkStreamRuleSets.LOGICAL_OPT_RULES)
+ .setRequiredOutputTraits(Array(FlinkConventions.LOGICAL))
+ .build())
+ programs.addLast(
+ "logical_rewrite",
+ FlinkHepRuleSetProgramBuilder.newBuilder
+ .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE)
+ .setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
+ .add(FlinkStreamRuleSets.LOGICAL_REWRITE)
+ .build())
+ util.replaceStreamProgram(programs)
+
+ util.addTableSource[(Int, Int, Int)]("MyTable", 'a, 'b, 'c)
+ util.addFunction("func", new MockPythonTableFunction)
+ util.addFunction("pyFunc", new PythonScalarFunction("pyFunc"))
+ }
+
+ @Test
+ def testPythonFunctionInCorrelateCondition(): Unit = {
+ val sqlQuery = "SELECT a, b, c, x, y FROM MyTable, LATERAL TABLE(func(a * a, b)) AS T(x, y) " +
+ "WHERE x = a and pyFunc(x, x) = 2 and y + 1 = y * y"
+ util.verifyPlan(sqlQuery)
+ }
+
+}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/CalcPythonCorrelateTransposeRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/CalcPythonCorrelateTransposeRule.java
new file mode 100644
index 0000000..41975b7
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/plan/rules/logical/CalcPythonCorrelateTransposeRule.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.logical;
+
+import org.apache.flink.table.plan.nodes.logical.FlinkLogicalCalc;
+import org.apache.flink.table.plan.nodes.logical.FlinkLogicalCorrelate;
+import org.apache.flink.table.plan.nodes.logical.FlinkLogicalRel;
+import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTableFunctionScan;
+import org.apache.flink.table.plan.util.CorrelateUtil;
+import org.apache.flink.table.plan.util.PythonUtil;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexProgram;
+import org.apache.calcite.rex.RexProgramBuilder;
+import org.apache.calcite.rex.RexUtil;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+import scala.Option;
+
+/**
+ * Rule that moves the condition of the Calc under a Python correlate node into a new Calc on top of the correlate, if the join type is inner join.
+ */
+public class CalcPythonCorrelateTransposeRule extends RelOptRule {
+
+ public static final CalcPythonCorrelateTransposeRule INSTANCE =
+ new CalcPythonCorrelateTransposeRule();
+
+ private CalcPythonCorrelateTransposeRule() {
+ super(operand(FlinkLogicalCorrelate.class,
+ operand(FlinkLogicalRel.class, any()),
+ operand(FlinkLogicalCalc.class, any())),
+ "CalcPythonCorrelateTransposeRule");
+ }
+
+ @Override
+ public boolean matches(RelOptRuleCall call) {
+ FlinkLogicalCorrelate correlate = call.rel(0);
+ FlinkLogicalCalc right = call.rel(2);
+ JoinRelType joinType = correlate.getJoinType();
+ FlinkLogicalCalc mergedCalc = CorrelateUtil.getMergedCalc(right);
+ Option<FlinkLogicalTableFunctionScan> scan = CorrelateUtil.getTableFunctionScan(mergedCalc);
+ return joinType == JoinRelType.INNER &&
+ scan.isDefined() &&
+ PythonUtil.isPythonCall(scan.get().getCall(), null) &&
+ mergedCalc.getProgram().getCondition() != null;
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ FlinkLogicalCorrelate correlate = call.rel(0);
+ FlinkLogicalCalc right = call.rel(2);
+ RexBuilder rexBuilder = call.builder().getRexBuilder();
+ FlinkLogicalCalc mergedCalc = CorrelateUtil.getMergedCalc(right);
+ FlinkLogicalTableFunctionScan tableScan = CorrelateUtil.getTableFunctionScan(mergedCalc).get();
+ RexProgram mergedCalcProgram = mergedCalc.getProgram();
+
+ InputRefRewriter inputRefRewriter = new InputRefRewriter(
+ correlate.getRowType().getFieldCount() - mergedCalc.getRowType().getFieldCount());
+ List<RexNode> correlateFilters = RelOptUtil
+ .conjunctions(mergedCalcProgram.expandLocalRef(mergedCalcProgram.getCondition()))
+ .stream()
+ .map(x -> x.accept(inputRefRewriter))
+ .collect(Collectors.toList());
+
+ FlinkLogicalCorrelate newCorrelate = new FlinkLogicalCorrelate(
+ correlate.getCluster(),
+ correlate.getTraitSet(),
+ correlate.getLeft(),
+ tableScan,
+ correlate.getCorrelationId(),
+ correlate.getRequiredColumns(),
+ correlate.getJoinType());
+
+ RexNode topCalcCondition = RexUtil.composeConjunction(rexBuilder, correlateFilters);
+ RexProgram rexProgram = new RexProgramBuilder(
+ newCorrelate.getRowType(), rexBuilder).getProgram();
+ FlinkLogicalCalc newTopCalc = new FlinkLogicalCalc(
+ newCorrelate.getCluster(),
+ newCorrelate.getTraitSet(),
+ newCorrelate,
+ RexProgram.create(
+ newCorrelate.getRowType(),
+ rexProgram.getExprList(),
+ topCalcCondition,
+ newCorrelate.getRowType(),
+ rexBuilder));
+
+ call.transformTo(newTopCalc);
+ }
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamPythonCorrelate.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamPythonCorrelate.scala
index 0021bac..1e46c35 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamPythonCorrelate.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamPythonCorrelate.scala
@@ -22,7 +22,7 @@ import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.core.JoinRelType
import org.apache.calcite.rex.{RexCall, RexNode}
import org.apache.flink.streaming.api.datastream.DataStream
-import org.apache.flink.table.api.StreamQueryConfig
+import org.apache.flink.table.api.{StreamQueryConfig, TableException}
import org.apache.flink.table.functions.utils.TableSqlFunction
import org.apache.flink.table.plan.nodes.CommonPythonCorrelate
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTableFunctionScan
@@ -57,6 +57,10 @@ class DataStreamPythonCorrelate(
joinType)
with CommonPythonCorrelate {
+ if (condition.isDefined) {
+ throw new TableException("Currently Python correlate does not support conditions in left join.")
+ }
+
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new DataStreamPythonCorrelate(
cluster,
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
index 7bbfa6a..534abe4 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
@@ -152,6 +152,8 @@ object FlinkRuleSets {
// Rule that splits python ScalarFunctions from
// java/scala ScalarFunctions in correlate conditions
SplitPythonConditionFromCorrelateRule.INSTANCE,
+ // Rule that transposes the conditions into a Calc on top of the Python correlate node.
+ CalcPythonCorrelateTransposeRule.INSTANCE,
// Rule that splits java calls from python TableFunction
PythonCorrelateSplitRule.INSTANCE,
CalcMergeRule.INSTANCE,
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala
index 09a90a6..cee224a 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala
@@ -23,7 +23,7 @@ import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptUtil}
import org.apache.calcite.rel.core.JoinRelType
import org.apache.calcite.rex._
import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalCorrelate, FlinkLogicalRel}
-import org.apache.flink.table.plan.util.PythonUtil.containsPythonCall
+import org.apache.flink.table.plan.util.PythonUtil.{containsPythonCall, isNonPythonCall}
import org.apache.flink.table.plan.util.{CorrelateUtil, RexDefaultVisitor}
import scala.collection.JavaConversions._
@@ -51,7 +51,10 @@ class SplitPythonConditionFromCorrelateRule
val right: FlinkLogicalCalc = call.rel(2).asInstanceOf[FlinkLogicalCalc]
val joinType: JoinRelType = correlate.getJoinType
val mergedCalc = CorrelateUtil.getMergedCalc(right)
+ val tableScan = CorrelateUtil.getTableFunctionScan(mergedCalc)
joinType == JoinRelType.INNER &&
+ tableScan.isDefined &&
+ isNonPythonCall(tableScan.get.getCall) &&
Option(mergedCalc.getProgram.getCondition)
.map(mergedCalc.getProgram.expandLocalRef)
.exists(containsPythonCall(_))
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/CalcPythonCorrelateTransposeRuleTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/CalcPythonCorrelateTransposeRuleTest.scala
new file mode 100644
index 0000000..a5bd851
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/plan/CalcPythonCorrelateTransposeRuleTest.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.runtime.utils.JavaUserDefinedScalarFunctions.PythonScalarFunction
+import org.apache.flink.table.utils.{MockPythonTableFunction, TableTestBase}
+import org.apache.flink.table.utils.TableTestUtil.{streamTableNode, term, unaryNode}
+import org.junit.Test
+
+class CalcPythonCorrelateTransposeRuleTest extends TableTestBase {
+ @Test
+ def testPythonTableFunctionWithCondition(): Unit = {
+ val util = streamTestUtil()
+ val table = util.addTable[(Int, Int, Int)]("MyTable", 'a, 'b, 'c)
+ val scalarFunc = new PythonScalarFunction("pyFunc")
+ val tableFunc = new MockPythonTableFunction()
+
+ val resultTable = table.joinLateral(
+ tableFunc('a * 'a, 'b) as('x, 'y),
+ 'x === 'a && scalarFunc('x, 'x) === 2 && 'y + 1 === 'y * 'y)
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamPythonCalc",
+ unaryNode(
+ "DataStreamPythonCorrelate",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(table),
+ term("select", "a, b, c, *(a, a) AS f0")),
+ term("invocation", s"${tableFunc.functionIdentifier}" +
+ s"($$3, $$1)"),
+ term("correlate", s"table(${tableFunc.getClass.getSimpleName}" +
+ s"(f0, b))"),
+ term("select", "a, b, c, f0, x, y"),
+ term("rowType",
+ "RecordType(INTEGER a, INTEGER b, INTEGER c, INTEGER f0, INTEGER x, INTEGER y)"),
+ term("joinType", "INNER")),
+ term("select", "a, b, c, x, y, pyFunc(x, x) AS f0")),
+ term("select", "a, b, c, x, y"),
+ term("where", "AND(AND(=(f0, 2), =(+(y, 1), *(y, y))), =(x, a))"))
+ util.verifyTable(resultTable, expected)
+ }
+
+}