You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by pn...@apache.org on 2018/09/21 11:43:39 UTC
[flink] 06/11: [FLINK-9713][table][sql] Support versioned join in
planning phase
This is an automated email from the ASF dual-hosted git repository.
pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 77c38346cb06b6e6c1bb672695c54f4ba253bd3f
Author: Piotr Nowojski <pi...@gmail.com>
AuthorDate: Thu Jul 5 20:02:51 2018 +0200
[FLINK-9713][table][sql] Support versioned join in planning phase
---
.../flink/table/api/BatchTableEnvironment.scala | 3 +-
.../flink/table/api/StreamTableEnvironment.scala | 3 +-
.../apache/flink/table/api/TableEnvironment.scala | 8 +
.../flink/table/calcite/FlinkRelBuilder.scala | 2 +
.../table/calcite/RelTimeIndicatorConverter.scala | 53 +++++-
.../logical/rel/LogicalTemporalTableJoin.scala | 156 ++++++++++++++++
.../datastream/DataStreamTemporalTableJoin.scala | 82 ++++++++
.../logical/FlinkLogicalTemporalTableJoin.scala | 97 ++++++++++
.../flink/table/plan/rules/FlinkRuleSets.scala | 10 +
.../DataStreamTemporalTableJoinRule.scala | 78 ++++++++
.../LogicalCorrelateToTemporalTableJoinRule.scala | 207 +++++++++++++++++++++
.../flink/table/plan/util/RexDefaultVisitor.scala | 66 +++++++
.../api/batch/sql/TemporalTableJoinTest.scala | 112 +++++++++++
.../api/batch/table/TemporalTableJoinTest.scala | 77 ++++++++
.../api/stream/sql/TemporalTableJoinTest.scala | 130 +++++++++++++
.../api/stream/table/TemporalTableJoinTest.scala | 202 +++++++++++++++++++-
.../TemporalTableJoinValidationTest.scala | 2 -
.../table/plan/TimeIndicatorConversionTest.scala | 110 +++++++++++
.../apache/flink/table/utils/TableTestBase.scala | 49 ++++-
19 files changed, 1428 insertions(+), 19 deletions(-)
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala
index 522a03e..5a34ee1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala
@@ -449,7 +449,8 @@ abstract class BatchTableEnvironment(
*/
private[flink] def optimize(relNode: RelNode): RelNode = {
val convSubQueryPlan = optimizeConvertSubQueries(relNode)
- val fullNode = optimizeConvertTableReferences(convSubQueryPlan)
+ val temporalTableJoinPlan = optimizeConvertToTemporalJoin(convSubQueryPlan)
+ val fullNode = optimizeConvertTableReferences(temporalTableJoinPlan)
val decorPlan = RelDecorrelator.decorrelateQuery(fullNode)
val normalizedPlan = optimizeNormalizeLogicalPlan(decorPlan)
val logicalPlan = optimizeLogicalPlan(normalizedPlan)
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala
index 860f8b2..5f45cc3 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala
@@ -804,7 +804,8 @@ abstract class StreamTableEnvironment(
*/
private[flink] def optimize(relNode: RelNode, updatesAsRetraction: Boolean): RelNode = {
val convSubQueryPlan = optimizeConvertSubQueries(relNode)
- val fullNode = optimizeConvertTableReferences(convSubQueryPlan)
+ val temporalTableJoinPlan = optimizeConvertToTemporalJoin(convSubQueryPlan)
+ val fullNode = optimizeConvertTableReferences(temporalTableJoinPlan)
val decorPlan = RelDecorrelator.decorrelateQuery(fullNode)
val planWithMaterializedTimeAttributes =
RelTimeIndicatorConverter.convert(decorPlan, getRelBuilder.getRexBuilder)
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
index cce270c..d740c3f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
@@ -242,6 +242,14 @@ abstract class TableEnvironment(val config: TableConfig) {
relNode.getTraitSet)
}
+ /**
+ * Runs the HEP planner over the given plan in bottom-up match order using
+ * [[FlinkRuleSets.TEMPORAL_JOIN_RULES]], passing the plan's own trait set to the planner.
+ *
+ * @param relNode plan to rewrite
+ * @return the plan returned by `runHepPlanner`
+ */
+ protected def optimizeConvertToTemporalJoin(relNode: RelNode): RelNode = {
+ runHepPlanner(
+ HepMatchOrder.BOTTOM_UP,
+ FlinkRuleSets.TEMPORAL_JOIN_RULES,
+ relNode,
+ relNode.getTraitSet)
+ }
+
protected def optimizeConvertTableReferences(relNode: RelNode): RelNode = {
runHepPlanner(
HepMatchOrder.BOTTOM_UP,
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala
index 1ac9b53..1aecdd8 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala
@@ -46,6 +46,8 @@ class FlinkRelBuilder(
relOptCluster,
relOptSchema) {
+ /** Returns the `relOptSchema` value held by this builder. */
+ def getRelOptSchema: RelOptSchema = relOptSchema
+
def getPlanner: RelOptPlanner = cluster.getPlanner
def getCluster: RelOptCluster = relOptCluster
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
index f67b715..34b98a8 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
@@ -28,7 +28,7 @@ import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo
import org.apache.flink.table.api.{TableException, ValidationException}
import org.apache.flink.table.calcite.FlinkTypeFactory.{isRowtimeIndicatorType, _}
import org.apache.flink.table.functions.sql.ProctimeSqlFunction
-import org.apache.flink.table.plan.logical.rel.LogicalWindowAggregate
+import org.apache.flink.table.plan.logical.rel.{LogicalTemporalTableJoin, LogicalWindowAggregate}
import org.apache.flink.table.plan.schema.TimeIndicatorRelDataType
import org.apache.flink.table.validate.BasicOperatorTable
@@ -117,11 +117,13 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle {
aggregate.getNamedProperties,
convAggregate)
+ case temporalTableJoin: LogicalTemporalTableJoin =>
+ visit(temporalTableJoin)
+
case _ =>
throw new TableException(s"Unsupported logical operator: ${other.getClass.getSimpleName}")
}
-
override def visit(exchange: LogicalExchange): RelNode =
throw new TableException("Logical exchange in a stream environment is not supported yet.")
@@ -163,9 +165,18 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle {
val right = join.getRight.accept(this)
LogicalJoin.create(left, right, join.getCondition, join.getVariablesSet, join.getJoinType)
-
}
+ /**
+ * Converts a [[LogicalTemporalTableJoin]]: both inputs are visited with this shuttle, the join
+ * is copied onto the converted inputs, and the copy is handed to
+ * `materializerUtils.projectAndMaterializeFields` together with the field indices computed by
+ * `gatherIndicesToMaterialize`.
+ */
+ def visit(temporalJoin: LogicalTemporalTableJoin): RelNode = {
+ val left = temporalJoin.getLeft.accept(this)
+ val right = temporalJoin.getRight.accept(this)
+
+ val rewrittenTemporalJoin = temporalJoin.copy(temporalJoin.getTraitSet, List(left, right))
+
+ val indicesToMaterialize = gatherIndicesToMaterialize(rewrittenTemporalJoin, left, right)
+
+ materializerUtils.projectAndMaterializeFields(rewrittenTemporalJoin, indicesToMaterialize)
+ }
override def visit(correlate: LogicalCorrelate): RelNode = {
// visit children and update inputs
@@ -204,13 +215,43 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle {
correlate.getJoinType)
}
+  /**
+    * Collects the indices (in the join's output row) of the fields that have to be materialized
+    * for a temporal join.
+    *
+    * Every field contributed by the right input is always included. If the right input has no
+    * rowtime attribute, every rowtime indicator field of the left input is included as well.
+    *
+    * @param temporalJoin join whose row type gives the total number of output fields
+    * @param left         converted left input of the join
+    * @param right        converted right input of the join
+    * @return the set of output field indices to materialize
+    */
+  private def gatherIndicesToMaterialize(
+      temporalJoin: Join,
+      left: RelNode,
+      right: RelNode)
+    : Set[Int] = {
+
+    val leftFieldCount = left.getRowType.getFieldCount
+
+    // Materialize all of the time attributes from the right side of temporal join
+    val rightIndices: Set[Int] =
+      (leftFieldCount until temporalJoin.getRowType.getFieldCount).toSet
+
+    val leftRowtimeIndices: Set[Int] =
+      if (hasRowtimeAttribute(right.getRowType)) {
+        Set.empty[Int]
+      } else {
+        // No rowtime on the right side means that this must be a processing time temporal join
+        // and that we can not provide watermarks even if there is a rowtime time attribute
+        // on the left side (besides processing time attribute used for temporal join).
+        (0 until leftFieldCount)
+          .filter { fieldIndex =>
+            isRowtimeIndicatorType(left.getRowType.getFieldList.get(fieldIndex).getType)
+          }
+          .toSet
+      }
+
+    rightIndices ++ leftRowtimeIndices
+  }
+
+ /** Returns true if at least one field of `rowType` has a rowtime indicator type. */
+ private def hasRowtimeAttribute(rowType: RelDataType): Boolean = {
+ rowType.getFieldList.exists(field => isRowtimeIndicatorType(field.getType))
+ }
+
private def convertAggregate(aggregate: Aggregate): LogicalAggregate = {
// visit children and update inputs
val input = aggregate.getInput.accept(this)
// add a project to materialize aggregation arguments/grouping keys
- val indicesToMaterialize = gatherIndicesToMaterialize(aggregate)
+ val indicesToMaterialize = gatherIndicesToMaterialize(aggregate, input)
val needsMaterialization = indicesToMaterialize.exists(idx =>
isTimeIndicatorType(input.getRowType.getFieldList.get(idx).getType))
@@ -266,13 +307,13 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle {
updatedAggCalls)
}
- private def gatherIndicesToMaterialize(aggregate: Aggregate): Set[Int] = {
+ private def gatherIndicesToMaterialize(aggregate: Aggregate, input: RelNode): Set[Int] = {
val indicesToMaterialize = mutable.Set[Int]()
// check arguments of agg calls
aggregate.getAggCallList.foreach(call => if (call.getArgList.size() == 0) {
// count(*) has an empty argument list
- (0 until aggregate.getRowType.getFieldCount).foreach(indicesToMaterialize.add)
+ (0 until input.getRowType.getFieldCount).foreach(indicesToMaterialize.add)
} else {
// for other aggregations
call.getArgList.map(_.asInstanceOf[Int]).foreach(indicesToMaterialize.add)
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/rel/LogicalTemporalTableJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/rel/LogicalTemporalTableJoin.scala
new file mode 100644
index 0000000..3b1d51b
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/rel/LogicalTemporalTableJoin.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.logical.rel
+
+import java.util.Collections
+
+import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.core._
+import org.apache.calcite.rex.{RexBuilder, RexNode}
+import org.apache.calcite.sql.`type`.{OperandTypes, ReturnTypes}
+import org.apache.calcite.sql.{SqlFunction, SqlFunctionCategory, SqlKind}
+import org.apache.flink.util.Preconditions.checkArgument
+
+/**
+ * Represents a join between a table and [[org.apache.flink.table.functions.TemporalTableFunction]].
+ *
+ * The node is always constructed as an inner join with an empty set of correlation variables.
+ * Its `copy` method checks that neither the join type nor the `semiJoinDone` flag is changed.
+ *
+ * @param cluster cluster handed to the [[Join]] base class
+ * @param traitSet trait set handed to the [[Join]] base class
+ * @param left left input of the join
+ * @param right table scan (or other more complex table expression) of underlying
+ * [[org.apache.flink.table.functions.TemporalTableFunction]]
+ * @param condition must contain [[LogicalTemporalTableJoin#TEMPORAL_JOIN_CONDITION]] with
+ * correctly defined references to rightTimeAttribute,
+ * rightPrimaryKeyExpression and leftTimeAttribute. We can not implement
+ * those references as separate fields, because of problems with Calcite's
+ * optimization rules like projections push downs, column
+ * pruning/renaming/reordering, etc. Later rightTimeAttribute,
+ * rightPrimaryKeyExpression and leftTimeAttribute will be extracted from
+ * the condition.
+ */
+class LogicalTemporalTableJoin private(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ left: RelNode,
+ right: RelNode,
+ condition: RexNode)
+ extends Join(
+ cluster,
+ traitSet,
+ left,
+ right,
+ condition,
+ Collections.emptySet().asInstanceOf[java.util.Set[CorrelationId]],
+ JoinRelType.INNER) {
+
+ override def copy(
+ traitSet: RelTraitSet,
+ condition: RexNode,
+ left: RelNode,
+ right: RelNode,
+ joinType: JoinRelType,
+ semiJoinDone: Boolean): LogicalTemporalTableJoin = {
+ checkArgument(joinType == this.getJoinType,
+ "Can not change join type".asInstanceOf[Object])
+ checkArgument(semiJoinDone == this.isSemiJoinDone,
+ "Can not change semiJoinDone".asInstanceOf[Object])
+ new LogicalTemporalTableJoin(
+ cluster,
+ traitSet,
+ left,
+ right,
+ condition)
+ }
+}
+
+object LogicalTemporalTableJoin {
+ /**
+ * Marker function used inside the join condition of a [[LogicalTemporalTableJoin]]. It is
+ * declared to return a non-null boolean and accepts either the operands
+ * (LEFT_TIME_ATTRIBUTE, RIGHT_TIME_ATTRIBUTE, PRIMARY_KEY) or (LEFT_TIME_ATTRIBUTE, PRIMARY_KEY).
+ */
+ val TEMPORAL_JOIN_CONDITION = new SqlFunction(
+ "__TEMPORAL_JOIN_CONDITION",
+ SqlKind.OTHER_FUNCTION,
+ ReturnTypes.BOOLEAN_NOT_NULL,
+ null,
+ OperandTypes.or(
+ OperandTypes.sequence(
+ "'(LEFT_TIME_ATTRIBUTE, RIGHT_TIME_ATTRIBUTE, PRIMARY_KEY)'",
+ OperandTypes.DATETIME,
+ OperandTypes.DATETIME,
+ OperandTypes.ANY),
+ OperandTypes.sequence(
+ "'(LEFT_TIME_ATTRIBUTE, PRIMARY_KEY)'",
+ OperandTypes.DATETIME,
+ OperandTypes.ANY)),
+ SqlFunctionCategory.SYSTEM)
+
+ /**
+ * Creates a [[LogicalTemporalTableJoin]] whose condition is a call to
+ * [[TEMPORAL_JOIN_CONDITION]] with the three operands leftTimeAttribute, rightTimeAttribute
+ * and rightPrimaryKeyExpression (in that order).
+ */
+ def createRowtime(
+ rexBuilder: RexBuilder,
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ left: RelNode,
+ right: RelNode,
+ leftTimeAttribute: RexNode,
+ rightTimeAttribute: RexNode,
+ rightPrimaryKeyExpression: RexNode)
+ : LogicalTemporalTableJoin = {
+ new LogicalTemporalTableJoin(
+ cluster,
+ traitSet,
+ left,
+ right,
+ rexBuilder.makeCall(
+ TEMPORAL_JOIN_CONDITION,
+ leftTimeAttribute,
+ rightTimeAttribute,
+ rightPrimaryKeyExpression))
+ }
+
+ /**
+ * Creates a [[LogicalTemporalTableJoin]] whose condition is a call to
+ * [[TEMPORAL_JOIN_CONDITION]] with the two operands leftTimeAttribute and
+ * rightPrimaryKeyExpression (in that order).
+ *
+ * @param leftTimeAttribute is needed because otherwise,
+ * [[LogicalTemporalTableJoin#TEMPORAL_JOIN_CONDITION]] could be pushed
+ * down below [[LogicalTemporalTableJoin]], since it wouldn't have any
+ * references to the left node.
+ */
+ def createProctime(
+ rexBuilder: RexBuilder,
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ left: RelNode,
+ right: RelNode,
+ leftTimeAttribute: RexNode,
+ rightPrimaryKeyExpression: RexNode)
+ : LogicalTemporalTableJoin = {
+ new LogicalTemporalTableJoin(
+ cluster,
+ traitSet,
+ left,
+ right,
+ rexBuilder.makeCall(
+ TEMPORAL_JOIN_CONDITION,
+ leftTimeAttribute,
+ rightPrimaryKeyExpression))
+ }
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamTemporalTableJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamTemporalTableJoin.scala
new file mode 100644
index 0000000..60f36d3
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamTemporalTableJoin.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.nodes.datastream
+
+import org.apache.calcite.plan._
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.core.{JoinInfo, JoinRelType}
+import org.apache.calcite.rex.RexNode
+import org.apache.flink.streaming.api.datastream.DataStream
+import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment}
+import org.apache.flink.table.plan.schema.RowSchema
+import org.apache.flink.table.runtime.types.CRow
+import org.apache.flink.util.Preconditions.checkState
+
+/**
+ * RelNode for a stream join with [[org.apache.flink.table.functions.TemporalTableFunction]].
+ *
+ * The node always passes [[JoinRelType.INNER]] to [[DataStreamJoin]]. It reports that it needs
+ * updates as retractions and that it does not produce retractions. `copy` requires exactly two
+ * inputs and reuses all other constructor arguments. `translateToPlan` is not implemented in
+ * this class and throws [[NotImplementedError]].
+ */
+class DataStreamTemporalTableJoin(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ leftNode: RelNode,
+ rightNode: RelNode,
+ joinCondition: RexNode,
+ joinInfo: JoinInfo,
+ leftSchema: RowSchema,
+ rightSchema: RowSchema,
+ schema: RowSchema,
+ ruleDescription: String)
+ extends DataStreamJoin(
+ cluster,
+ traitSet,
+ leftNode,
+ rightNode,
+ joinCondition,
+ joinInfo,
+ JoinRelType.INNER,
+ leftSchema,
+ rightSchema,
+ schema,
+ ruleDescription) {
+
+ override def needsUpdatesAsRetraction: Boolean = true
+
+ override def producesRetractions: Boolean = false
+
+ override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
+ checkState(inputs.size() == 2)
+ new DataStreamTemporalTableJoin(
+ cluster,
+ traitSet,
+ inputs.get(0),
+ inputs.get(1),
+ joinCondition,
+ joinInfo,
+ leftSchema,
+ rightSchema,
+ schema,
+ ruleDescription)
+ }
+
+ override def translateToPlan(
+ tableEnv: StreamTableEnvironment,
+ queryConfig: StreamQueryConfig): DataStream[CRow] = {
+ throw new NotImplementedError()
+ }
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTemporalTableJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTemporalTableJoin.scala
new file mode 100644
index 0000000..4be2fb9
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTemporalTableJoin.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.nodes.logical
+
+import org.apache.calcite.plan._
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.calcite.rel.core._
+import org.apache.calcite.rex.RexNode
+import org.apache.flink.table.plan.logical.rel.LogicalTemporalTableJoin
+import org.apache.flink.table.plan.nodes.FlinkConventions
+import org.apache.flink.util.Preconditions.checkArgument
+
+/**
+ * Represents a join between a table and
+ * [[org.apache.flink.table.functions.TemporalTableFunction]]. For more details please check
+ * [[LogicalTemporalTableJoin]].
+ *
+ * The node always passes [[JoinRelType.INNER]] to [[FlinkLogicalJoinBase]]. Its `copy` method
+ * checks that neither the join type nor the `semiJoinDone` flag is changed.
+ */
+class FlinkLogicalTemporalTableJoin(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ left: RelNode,
+ right: RelNode,
+ condition: RexNode)
+ extends FlinkLogicalJoinBase(
+ cluster,
+ traitSet,
+ left,
+ right,
+ condition,
+ JoinRelType.INNER) {
+
+ def copy(
+ traitSet: RelTraitSet,
+ condition: RexNode,
+ left: RelNode,
+ right: RelNode,
+ joinType: JoinRelType,
+ semiJoinDone: Boolean): FlinkLogicalTemporalTableJoin = {
+ checkArgument(joinType == this.getJoinType,
+ "Can not change join type".asInstanceOf[Object])
+ checkArgument(semiJoinDone == this.isSemiJoinDone,
+ "Can not change semiJoinDone".asInstanceOf[Object])
+ new FlinkLogicalTemporalTableJoin(
+ cluster,
+ traitSet,
+ left,
+ right,
+ condition)
+ }
+}
+
+/**
+ * Converter rule from [[LogicalTemporalTableJoin]] in [[Convention.NONE]] to
+ * [[FlinkLogicalTemporalTableJoin]] in [[FlinkConventions.LOGICAL]]. It matches unconditionally
+ * and converts both inputs to the logical convention while keeping the join condition.
+ */
+class FlinkLogicalTemporalTableJoinConverter
+ extends ConverterRule(
+ classOf[LogicalTemporalTableJoin],
+ Convention.NONE,
+ FlinkConventions.LOGICAL,
+ "FlinkLogicalTemporalTableJoinConverter") {
+
+ override def matches(call: RelOptRuleCall): Boolean = {
+ true
+ }
+
+ override def convert(rel: RelNode): RelNode = {
+ val temporalTableJoin = rel.asInstanceOf[LogicalTemporalTableJoin]
+ val traitSet = rel.getTraitSet.replace(FlinkConventions.LOGICAL)
+ val newLeft = RelOptRule.convert(temporalTableJoin.getLeft, FlinkConventions.LOGICAL)
+ val newRight = RelOptRule.convert(temporalTableJoin.getRight, FlinkConventions.LOGICAL)
+
+ new FlinkLogicalTemporalTableJoin(
+ rel.getCluster,
+ traitSet,
+ newLeft,
+ newRight,
+ temporalTableJoin.getCondition)
+ }
+}
+
+/** Holds the [[FlinkLogicalTemporalTableJoinConverter]] instance exposed as `CONVERTER`. */
+object FlinkLogicalTemporalTableJoin {
+ val CONVERTER = new FlinkLogicalTemporalTableJoinConverter
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
index 52dab8b..e4cd8d1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
@@ -38,6 +38,14 @@ object FlinkRuleSets {
SubQueryRemoveRule.JOIN)
/**
+ * Handles proper conversion of correlate queries with temporal table functions
+ * into temporal table joins. This can create new table scans in the plan.
+ */
+ val TEMPORAL_JOIN_RULES: RuleSet = RuleSets.ofList(
+ LogicalCorrelateToTemporalTableJoinRule.INSTANCE
+ )
+
+ /**
* Convert table references before query decorrelation.
*/
val TABLE_REF_RULES: RuleSet = RuleSets.ofList(
@@ -127,6 +135,7 @@ object FlinkRuleSets {
FlinkLogicalCorrelate.CONVERTER,
FlinkLogicalIntersect.CONVERTER,
FlinkLogicalJoin.CONVERTER,
+ FlinkLogicalTemporalTableJoin.CONVERTER,
FlinkLogicalMinus.CONVERTER,
FlinkLogicalSort.CONVERTER,
FlinkLogicalUnion.CONVERTER,
@@ -211,6 +220,7 @@ object FlinkRuleSets {
DataStreamCorrelateRule.INSTANCE,
DataStreamWindowJoinRule.INSTANCE,
DataStreamJoinRule.INSTANCE,
+ DataStreamTemporalTableJoinRule.INSTANCE,
StreamTableSourceScanRule.INSTANCE
)
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamTemporalTableJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamTemporalTableJoinRule.scala
new file mode 100644
index 0000000..94ff19c
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamTemporalTableJoinRule.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.datastream
+
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.calcite.rel.core.JoinRelType
+import org.apache.flink.table.api.TableConfig
+import org.apache.flink.table.plan.nodes.FlinkConventions
+import org.apache.flink.table.plan.nodes.datastream.DataStreamTemporalTableJoin
+import org.apache.flink.table.plan.nodes.logical.FlinkLogicalTemporalTableJoin
+import org.apache.flink.table.plan.schema.RowSchema
+import org.apache.flink.table.runtime.join.WindowJoinUtil
+
+/**
+  * Converter rule from [[FlinkLogicalTemporalTableJoin]] in [[FlinkConventions.LOGICAL]] to
+  * [[DataStreamTemporalTableJoin]] in [[FlinkConventions.DATASTREAM]].
+  */
+class DataStreamTemporalTableJoinRule
+  extends ConverterRule(
+    classOf[FlinkLogicalTemporalTableJoin],
+    FlinkConventions.LOGICAL,
+    FlinkConventions.DATASTREAM,
+    "DataStreamTemporalTableJoinRule") {
+
+  /**
+    * Matches only inner joins for which no window bounds can be extracted from the
+    * remaining (non-equi) part of the join condition.
+    */
+  override def matches(call: RelOptRuleCall): Boolean = {
+    val join: FlinkLogicalTemporalTableJoin = call.rel(0)
+    val joinInfo = join.analyzeCondition
+
+    // Only the window bounds are needed here; the remaining predicates are ignored.
+    val (windowBounds, _) = WindowJoinUtil.extractWindowBoundsFromPredicate(
+      joinInfo.getRemaining(join.getCluster.getRexBuilder),
+      join.getLeft.getRowType.getFieldCount,
+      join.getRowType,
+      join.getCluster.getRexBuilder,
+      TableConfig.DEFAULT)
+
+    windowBounds.isEmpty && join.getJoinType == JoinRelType.INNER
+  }
+
+  /**
+    * Builds the [[DataStreamTemporalTableJoin]]: both inputs are converted to the DATASTREAM
+    * convention, the original join condition is kept, and row schemas are derived from the
+    * converted inputs and from the row type of `rel`.
+    */
+  override def convert(rel: RelNode): RelNode = {
+    val temporalJoin = rel.asInstanceOf[FlinkLogicalTemporalTableJoin]
+    val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASTREAM)
+    val left: RelNode = RelOptRule.convert(temporalJoin.getInput(0), FlinkConventions.DATASTREAM)
+    val right: RelNode = RelOptRule.convert(temporalJoin.getInput(1), FlinkConventions.DATASTREAM)
+    val joinInfo = temporalJoin.analyzeCondition
+    val leftRowSchema = new RowSchema(left.getRowType)
+    val rightRowSchema = new RowSchema(right.getRowType)
+
+    new DataStreamTemporalTableJoin(
+      rel.getCluster,
+      traitSet,
+      left,
+      right,
+      temporalJoin.getCondition,
+      joinInfo,
+      leftRowSchema,
+      rightRowSchema,
+      new RowSchema(rel.getRowType),
+      description)
+  }
+}
+
+/** Holds the [[DataStreamTemporalTableJoinRule]] instance exposed as `INSTANCE`. */
+object DataStreamTemporalTableJoinRule {
+ val INSTANCE: RelOptRule = new DataStreamTemporalTableJoinRule
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala
new file mode 100644
index 0000000..cb666fa
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.logical
+
+import org.apache.calcite.plan.RelOptRule.{any, none, operand, some}
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.core.TableFunctionScan
+import org.apache.calcite.rel.logical.LogicalCorrelate
+import org.apache.calcite.rex._
+import org.apache.flink.table.api.{Table, Types, ValidationException}
+import org.apache.flink.table.calcite.FlinkTypeFactory.{isProctimeIndicatorType, isTimeIndicatorType}
+import org.apache.flink.table.expressions._
+import org.apache.flink.table.functions.TemporalTableFunction
+import org.apache.flink.table.functions.utils.TableSqlFunction
+import org.apache.flink.table.plan.logical.rel.LogicalTemporalTableJoin
+import org.apache.flink.table.plan.util.RexDefaultVisitor
+import org.apache.flink.util.Preconditions.checkState
+
+/**
+  * Rewrites a [[LogicalCorrelate]] whose right input is a [[TableFunctionScan]] calling a
+  * [[TemporalTableFunction]] into a [[LogicalTemporalTableJoin]] between the left input and the
+  * plan of the function's underlying history table. Correlates whose table function call is
+  * not a [[TemporalTableFunction]] are left untouched.
+  */
+class LogicalCorrelateToTemporalTableJoinRule
+  extends RelOptRule(
+    operand(classOf[LogicalCorrelate],
+      some(
+        operand(classOf[RelNode], any()),
+        operand(classOf[TableFunctionScan], none()))),
+    "LogicalCorrelateToTemporalTableJoinRule") {
+
+  /**
+    * Returns the field name of the given time attribute expression.
+    *
+    * @throws ValidationException if the expression is not a resolved field reference of type
+    *                             LONG, SQL_TIMESTAMP or a time indicator type
+    */
+  def extractNameFromTimeAttribute(timeAttribute: Expression): String = {
+    timeAttribute match {
+      case ResolvedFieldReference(name, _)
+        if timeAttribute.resultType == Types.LONG ||
+          timeAttribute.resultType == Types.SQL_TIMESTAMP ||
+          isTimeIndicatorType(timeAttribute.resultType) =>
+        name
+      case _ => throw new ValidationException(
+        s"Invalid timeAttribute [$timeAttribute] in TemporalTableFunction")
+    }
+  }
+
+  override def onMatch(call: RelOptRuleCall): Unit = {
+    val logicalCorrelate: LogicalCorrelate = call.rel(0)
+    val leftNode: RelNode = call.rel(1)
+    val rightTableFunctionScan: TableFunctionScan = call.rel(2)
+
+    val cluster = logicalCorrelate.getCluster
+
+    new GetTemporalTableFunctionCall(cluster.getRexBuilder, leftNode)
+      .visit(rightTableFunctionScan.getCall) match {
+      case None =>
+        // Do nothing and handle standard TableFunction
+      case Some(TemporalTableFunctionCall(rightTemporalTableFunction, leftTimeAttribute)) =>
+        // If TemporalTableFunction was found, rewrite LogicalCorrelate to TemporalJoin
+        val underlyingHistoryTable: Table = rightTemporalTableFunction.getUnderlyingHistoryTable
+        val relBuilder = this.relBuilderFactory.create(
+          cluster,
+          underlyingHistoryTable.relBuilder.getRelOptSchema)
+        val rexBuilder = cluster.getRexBuilder
+
+        val rightNode: RelNode = underlyingHistoryTable.logicalPlan.toRelNode(relBuilder)
+
+        val rightTimeIndicatorExpression = createRightExpression(
+          rexBuilder,
+          leftNode,
+          rightNode,
+          extractNameFromTimeAttribute(rightTemporalTableFunction.getTimeAttribute))
+
+        val rightPrimaryKeyExpression = createRightExpression(
+          rexBuilder,
+          leftNode,
+          rightNode,
+          rightTemporalTableFunction.getPrimaryKey)
+
+        relBuilder.push(
+          if (isProctimeIndicatorType(rightTemporalTableFunction.getTimeAttribute.resultType)) {
+            LogicalTemporalTableJoin.createProctime(
+              rexBuilder,
+              cluster,
+              logicalCorrelate.getTraitSet,
+              leftNode,
+              rightNode,
+              leftTimeAttribute,
+              rightPrimaryKeyExpression)
+          }
+          else {
+            LogicalTemporalTableJoin.createRowtime(
+              rexBuilder,
+              cluster,
+              logicalCorrelate.getTraitSet,
+              leftNode,
+              rightNode,
+              leftTimeAttribute,
+              rightTimeIndicatorExpression,
+              rightPrimaryKeyExpression)
+          })
+        call.transformTo(relBuilder.build())
+    }
+  }
+
+  /**
+    * Creates an input reference to the named field of the right input, with its index shifted
+    * by the number of fields of the left input.
+    *
+    * @throws ValidationException if the right input has no field with the given name
+    */
+  private def createRightExpression(
+      rexBuilder: RexBuilder,
+      leftNode: RelNode,
+      rightNode: RelNode,
+      field: String): RexNode = {
+    val rightReferencesOffset = leftNode.getRowType.getFieldCount
+    // RelDataType.getField yields null for an unknown name; report that as a validation error
+    // instead of failing later with a NullPointerException.
+    val rightDataTypeField = Option(rightNode.getRowType.getField(field, false, false))
+      .getOrElse(throw new ValidationException(
+        s"Can not find field [$field] in the underlying history table of TemporalTableFunction"))
+    rexBuilder.makeInputRef(
+      rightDataTypeField.getType, rightReferencesOffset + rightDataTypeField.getIndex)
+  }
+}
+
+/** Holds the [[LogicalCorrelateToTemporalTableJoinRule]] instance exposed as `INSTANCE`. */
+object LogicalCorrelateToTemporalTableJoinRule {
+ val INSTANCE: RelOptRule = new LogicalCorrelateToTemporalTableJoinRule
+}
+
+/**
+ * Simple holder pairing an extracted [[TemporalTableFunction]] with the time attribute
+ * expression taken from the [[RexNode]] that calls it.
+ */
+case class TemporalTableFunctionCall(
+ var temporalTableFunction: TemporalTableFunction,
+ var timeAttribute: RexNode) {
+}
+
+/**
+  * Finds a [[TemporalTableFunction]] call and runs [[CorrelatedFieldAccessRemoval]] on its
+  * operand.
+  *
+  * The visitor is created as non-deep and only overrides `visitCall`; `visitCall` yields null
+  * for a call whose operator is not a [[TableSqlFunction]] wrapping a [[TemporalTableFunction]].
+  */
+class GetTemporalTableFunctionCall(
+    var rexBuilder: RexBuilder,
+    var leftSide: RelNode)
+  extends RexVisitorImpl[TemporalTableFunctionCall](false) {
+
+  /**
+    * Visits the node and wraps the visitor's result in an [[Option]], so that a null result
+    * becomes [[None]].
+    */
+  def visit(node: RexNode): Option[TemporalTableFunctionCall] =
+    Option(node.accept(this))
+
+  /**
+    * Returns the extracted call for a [[TemporalTableFunction]] invocation, or null for any
+    * other call. The single operand is rewritten with [[CorrelatedFieldAccessRemoval]].
+    */
+  override def visitCall(rexCall: RexCall): TemporalTableFunctionCall = {
+    rexCall.getOperator match {
+      case tableFunction: TableSqlFunction =>
+        tableFunction.getTableFunction match {
+          case temporalTableFunction: TemporalTableFunction =>
+            checkState(
+              rexCall.getOperands.size() == 1,
+              "TemporalTableFunction call [%s] must have exactly one argument",
+              rexCall)
+            val correlatedFieldAccessRemoval =
+              new CorrelatedFieldAccessRemoval(temporalTableFunction, rexBuilder, leftSide)
+            TemporalTableFunctionCall(
+              temporalTableFunction,
+              rexCall.getOperands.get(0).accept(correlatedFieldAccessRemoval))
+          case _ =>
+            // A table function, but not a temporal one.
+            null
+        }
+      case _ =>
+        // Not a table function call at all.
+        null
+    }
+  }
+}
+
+/**
+ * This converts field accesses like `$cor0.o_rowtime` to valid input references
+ * for join condition context without `$cor` reference.
+ *
+ * A field access is replaced by an input reference to the field with the same position in the
+ * row type of `leftSide`; an [[IllegalStateException]] is thrown if the field is not found
+ * there. Input references are returned unchanged. Every other kind of node is rejected with a
+ * [[ValidationException]].
+ */
+class CorrelatedFieldAccessRemoval(
+ var temporalTableFunction: TemporalTableFunction,
+ var rexBuilder: RexBuilder,
+ var leftSide: RelNode) extends RexDefaultVisitor[RexNode] {
+
+ override def visitFieldAccess(fieldAccess: RexFieldAccess): RexNode = {
+ val leftIndex = leftSide.getRowType.getFieldList.indexOf(fieldAccess.getField)
+ if (leftIndex < 0) {
+ throw new IllegalStateException(
+ s"Failed to find reference to field [${fieldAccess.getField}] in node [$leftSide]")
+ }
+ rexBuilder.makeInputRef(leftSide, leftIndex)
+ }
+
+ override def visitInputRef(inputRef: RexInputRef): RexNode = {
+ inputRef
+ }
+
+ override def visitNode(rexNode: RexNode): RexNode = {
+ throw new ValidationException(
+ s"Unsupported argument [$rexNode] " +
+ s"in ${classOf[TemporalTableFunction].getSimpleName} call of " +
+ s"[${temporalTableFunction.getUnderlyingHistoryTable}] table")
+ }
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexDefaultVisitor.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexDefaultVisitor.scala
new file mode 100644
index 0000000..7c44616
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexDefaultVisitor.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.util
+
+import org.apache.calcite.rex._
+
+/**
+ * Implementation of [[RexVisitor]] that redirects all calls into the generic
+ * [[RexDefaultVisitor#visitNode(org.apache.calcite.rex.RexNode)]] method.
+ *
+ * Subclasses implement `visitNode` and override only those `visitXxx` methods
+ * for which they need node-type specific behaviour.
+ *
+ * @tparam R result type produced by the visitor
+ */
+abstract class RexDefaultVisitor[R] extends RexVisitor[R] {
+
+ override def visitFieldAccess(fieldAccess: RexFieldAccess): R =
+ visitNode(fieldAccess)
+
+ override def visitCall(call: RexCall): R =
+ visitNode(call)
+
+ override def visitInputRef(inputRef: RexInputRef): R =
+ visitNode(inputRef)
+
+ override def visitOver(over: RexOver): R =
+ visitNode(over)
+
+ override def visitCorrelVariable(correlVariable: RexCorrelVariable): R =
+ visitNode(correlVariable)
+
+ override def visitLocalRef(localRef: RexLocalRef): R =
+ visitNode(localRef)
+
+ override def visitDynamicParam(dynamicParam: RexDynamicParam): R =
+ visitNode(dynamicParam)
+
+ override def visitRangeRef(rangeRef: RexRangeRef): R =
+ visitNode(rangeRef)
+
+ override def visitTableInputRef(tableRef: RexTableInputRef): R =
+ visitNode(tableRef)
+
+ override def visitPatternFieldRef(patternFieldRef: RexPatternFieldRef): R =
+ visitNode(patternFieldRef)
+
+ override def visitSubQuery(subQuery: RexSubQuery): R =
+ visitNode(subQuery)
+
+ override def visitLiteral(literal: RexLiteral): R =
+ visitNode(literal)
+
+ /**
+ * Generic handler that receives every node whose dedicated `visitXxx` method
+ * has not been overridden by the subclass.
+ */
+ def visitNode(rexNode: RexNode): R
+}
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/TemporalTableJoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/TemporalTableJoinTest.scala
new file mode 100644
index 0000000..cca733b
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/TemporalTableJoinTest.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.api.batch.sql
+
+import java.sql.Timestamp
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.{TableException, ValidationException}
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.utils.TableTestUtil._
+import org.apache.flink.table.utils._
+import org.hamcrest.Matchers.startsWith
+import org.junit.Test
+
+/**
+ * Plan tests for temporal table joins written in SQL against the batch test util.
+ * Every test in this class expects a [[TableException]] to be thrown.
+ */
+class TemporalTableJoinTest extends TableTestBase {
+
+ val util: TableTestUtil = batchTestUtil()
+
+ val orders = util.addTable[(Long, String, Timestamp)](
+ "Orders", 'o_amount, 'o_currency, 'o_rowtime)
+
+ val ratesHistory = util.addTable[(String, Int, Timestamp)](
+ "RatesHistory", 'currency, 'rate, 'rowtime)
+
+ // Temporal table function over RatesHistory, registered under the name "Rates".
+ val rates = util.addFunction(
+ "Rates",
+ ratesHistory.createTemporalTableFunction('rowtime, 'currency))
+
+ // A plain correlated temporal table join is expected to fail with a TableException.
+ @Test
+ def testSimpleJoin(): Unit = {
+ expectedException.expect(classOf[TableException])
+ expectedException.expectMessage("Cannot generate a valid execution plan for the given query")
+
+ val sqlQuery = "SELECT " +
+ "o_amount * rate as rate " +
+ "FROM Orders AS o, " +
+ "LATERAL TABLE (Rates(o_rowtime)) AS r " +
+ "WHERE currency = o_currency";
+
+ util.printSql(sqlQuery)
+ }
+
+ /**
+ * Test temporal table joins with a more complicated query.
+ * The important thing here is that we have a complex OR join condition
+ * and that there are some columns that are not being used (are being pruned).
+ *
+ * Uses its own local test util and tables, which shadow the class-level ones.
+ */
+ @Test(expected = classOf[TableException])
+ def testComplexJoin(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(String, Int)]("Table3", 't3_comment, 't3_secondary_key)
+ util.addTable[(Timestamp, String, Long, String, Int)](
+ "Orders", 'o_rowtime, 'o_comment, 'o_amount, 'o_currency, 'o_secondary_key)
+
+ val ratesHistory = util.addTable[(Timestamp, String, String, Int, Int)](
+ "RatesHistory", 'rowtime, 'comment, 'currency, 'rate, 'secondary_key)
+ val rates = ratesHistory.createTemporalTableFunction('rowtime, 'currency)
+ util.addFunction("Rates", rates)
+
+ val sqlQuery =
+ "SELECT * FROM " +
+ "(SELECT " +
+ "o_amount * rate as rate, " +
+ "secondary_key as secondary_key " +
+ "FROM Orders AS o, " +
+ "LATERAL TABLE (Rates(o_rowtime)) AS r " +
+ "WHERE currency = o_currency OR secondary_key = o_secondary_key), " +
+ "Table3 " +
+ "WHERE t3_secondary_key = secondary_key";
+
+ util.printSql(sqlQuery)
+ }
+
+ // The function is called with a timestamp literal instead of a column of Orders.
+ @Test
+ def testUncorrelatedJoin(): Unit = {
+ expectedException.expect(classOf[TableException])
+ expectedException.expectMessage(startsWith("Cannot generate a valid execution plan"))
+
+ val sqlQuery = "SELECT " +
+ "o_amount * rate as rate " +
+ "FROM Orders AS o, " +
+ "LATERAL TABLE (Rates(TIMESTAMP '2016-06-27 10:10:42.123')) AS r " +
+ "WHERE currency = o_currency";
+
+ util.printSql(sqlQuery)
+ }
+
+ // The function is scanned on its own, without being joined to another table.
+ @Test
+ def testTemporalTableFunctionScan(): Unit = {
+ expectedException.expect(classOf[TableException])
+ expectedException.expectMessage(startsWith("Cannot generate a valid execution plan"))
+
+ val sqlQuery = "SELECT * FROM LATERAL TABLE (Rates(TIMESTAMP '2016-06-27 10:10:42.123'))";
+
+ util.printSql(sqlQuery)
+ }
+}
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/TemporalTableJoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/TemporalTableJoinTest.scala
new file mode 100644
index 0000000..190eebe
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/TemporalTableJoinTest.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.api.batch.table
+
+import java.sql.Timestamp
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.{TableException, ValidationException}
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.utils._
+import org.hamcrest.Matchers.startsWith
+import org.junit.Test
+
+/**
+ * Plan tests for temporal table joins written with the Table API against the
+ * batch test util. Every test in this class expects an exception to be thrown.
+ */
+class TemporalTableJoinTest extends TableTestBase {
+
+ val util: TableTestUtil = batchTestUtil()
+
+ val orders = util.addTable[(Long, String, Timestamp)](
+ "Orders", 'o_amount, 'o_currency, 'o_rowtime)
+
+ val ratesHistory = util.addTable[(String, Int, Timestamp)](
+ "RatesHistory", 'currency, 'rate, 'rowtime)
+
+ // Temporal table function over RatesHistory, registered under the name "Rates".
+ val rates = util.addFunction(
+ "Rates",
+ ratesHistory.createTemporalTableFunction('rowtime, 'currency))
+
+ // A plain correlated temporal table join is expected to fail with a TableException.
+ @Test
+ def testSimpleJoin(): Unit = {
+ expectedException.expect(classOf[TableException])
+ expectedException.expectMessage("Cannot generate a valid execution plan for the given query")
+
+ val result = orders
+ .join(rates('o_rowtime), "currency = o_currency")
+ .select("o_amount * rate").as("rate")
+
+ util.printTable(result)
+ }
+
+ // Calling the function with a timestamp literal is expected to be rejected
+ // with an "Unsupported argument" ValidationException.
+ @Test
+ def testUncorrelatedJoin(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage(startsWith("Unsupported argument"))
+
+ val result = orders
+ .join(rates(java.sql.Timestamp.valueOf("2016-06-27 10:10:42.123")), "o_currency = currency")
+ .select("o_amount * rate")
+
+ util.printTable(result)
+ }
+
+ // Using the function call on its own as a table is expected to be rejected
+ // with a ValidationException.
+ @Test
+ def testTemporalTableFunctionScan(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage(
+ "Cannot translate a query with an unbounded table function call.")
+
+ val result = rates(java.sql.Timestamp.valueOf("2016-06-27 10:10:42.123"))
+
+ util.printTable(result)
+ }
+}
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/TemporalTableJoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/TemporalTableJoinTest.scala
new file mode 100644
index 0000000..3c47f56
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/TemporalTableJoinTest.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.api.stream.sql
+
+import java.sql.Timestamp
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.TableException
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.stream.table.TemporalTableJoinTest._
+import org.apache.flink.table.utils._
+import org.hamcrest.Matchers.startsWith
+import org.junit.Test
+
+/**
+ * Plan tests for temporal table joins written in SQL against the stream test util.
+ * The expected plans are the `getExpected...Plan()` methods imported from
+ * `org.apache.flink.table.api.stream.table.TemporalTableJoinTest`.
+ */
+class TemporalTableJoinTest extends TableTestBase {
+
+ val util: TableTestUtil = streamTestUtil()
+
+ val orders = util.addTable[(Long, String, Timestamp)](
+ "Orders", 'o_amount, 'o_currency, 'o_rowtime.rowtime)
+
+ val ratesHistory = util.addTable[(String, Int, Timestamp)](
+ "RatesHistory", 'currency, 'rate, 'rowtime.rowtime)
+
+ // Temporal table function over RatesHistory, registered under the name "Rates".
+ val rates = util.addFunction(
+ "Rates",
+ ratesHistory.createTemporalTableFunction('rowtime, 'currency))
+
+ val proctimeOrders = util.addTable[(Long, String)](
+ "ProctimeOrders", 'o_amount, 'o_currency, 'o_proctime.proctime)
+
+ val proctimeRatesHistory = util.addTable[(String, Int)](
+ "ProctimeRatesHistory", 'currency, 'rate, 'proctime.proctime)
+
+ // Temporal table function over ProctimeRatesHistory, registered as "ProctimeRates".
+ val proctimeRates = util.addFunction(
+ "ProctimeRates",
+ proctimeRatesHistory.createTemporalTableFunction('proctime, 'currency))
+
+ // Join on the rowtime attribute; verified against the expected simple join plan.
+ @Test
+ def testSimpleJoin(): Unit = {
+ val sqlQuery = "SELECT " +
+ "o_amount * rate as rate " +
+ "FROM Orders AS o, " +
+ "LATERAL TABLE (Rates(o.o_rowtime)) AS r " +
+ "WHERE currency = o_currency";
+
+ util.verifySql(sqlQuery, getExpectedSimpleJoinPlan())
+ }
+
+ // Join on the proctime attribute; verified against the expected proctime join plan.
+ @Test
+ def testSimpleProctimeJoin(): Unit = {
+ val sqlQuery = "SELECT " +
+ "o_amount * rate as rate " +
+ "FROM ProctimeOrders AS o, " +
+ "LATERAL TABLE (ProctimeRates(o.o_proctime)) AS r " +
+ "WHERE currency = o_currency";
+
+ util.verifySql(sqlQuery, getExpectedSimpleProctimeJoinPlan())
+ }
+
+ /**
+ * Test versioned joins with a more complicated query.
+ * The important thing here is that we have a complex OR join condition
+ * and that there are some columns that are not being used (are being pruned).
+ *
+ * Uses its own local test util and tables, which shadow the class-level ones.
+ */
+ @Test
+ def testComplexJoin(): Unit = {
+ val util = streamTestUtil()
+ util.addTable[(String, Int)]("Table3", 't3_comment, 't3_secondary_key)
+ util.addTable[(Timestamp, String, Long, String, Int)](
+ "Orders", 'o_rowtime.rowtime, 'o_comment, 'o_amount, 'o_currency, 'o_secondary_key)
+
+ val ratesHistory = util.addTable[(Timestamp, String, String, Int, Int)](
+ "RatesHistory", 'rowtime.rowtime, 'comment, 'currency, 'rate, 'secondary_key)
+ val rates = ratesHistory.createTemporalTableFunction('rowtime, 'currency)
+ util.addFunction("Rates", rates)
+
+ val sqlQuery =
+ "SELECT * FROM " +
+ "(SELECT " +
+ "o_amount * rate as rate, " +
+ "secondary_key as secondary_key " +
+ "FROM Orders AS o, " +
+ "LATERAL TABLE (Rates(o_rowtime)) AS r " +
+ "WHERE currency = o_currency OR secondary_key = o_secondary_key), " +
+ "Table3 " +
+ "WHERE t3_secondary_key = secondary_key";
+
+ util.verifySql(sqlQuery, getExpectedComplexJoinPlan())
+ }
+
+ // The function is called with a timestamp literal instead of a column of Orders;
+ // a TableException is expected.
+ @Test
+ def testUncorrelatedJoin(): Unit = {
+ expectedException.expect(classOf[TableException])
+ expectedException.expectMessage(startsWith("Cannot generate a valid execution plan"))
+
+ val sqlQuery = "SELECT " +
+ "o_amount * rate as rate " +
+ "FROM Orders AS o, " +
+ "LATERAL TABLE (Rates(TIMESTAMP '2016-06-27 10:10:42.123')) AS r " +
+ "WHERE currency = o_currency";
+
+ util.printSql(sqlQuery)
+ }
+
+ // The function is scanned on its own, without being joined to another table;
+ // a TableException is expected.
+ @Test
+ def testTemporalTableFunctionScan(): Unit = {
+ expectedException.expect(classOf[TableException])
+ expectedException.expectMessage(startsWith("Cannot generate a valid execution plan"))
+
+ val sqlQuery = "SELECT * FROM LATERAL TABLE (Rates(TIMESTAMP '2016-06-27 10:10:42.123'))";
+
+ util.printSql(sqlQuery)
+ }
+}
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/TemporalTableJoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/TemporalTableJoinTest.scala
index 0942dd3..82fa251 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/TemporalTableJoinTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/TemporalTableJoinTest.scala
@@ -15,6 +15,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package org.apache.flink.table.api.stream.table
import java.sql.Timestamp
@@ -22,9 +23,13 @@ import java.sql.Timestamp
import org.apache.flink.api.scala._
import org.apache.flink.table.api.{TableSchema, ValidationException}
import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.stream.table.TemporalTableJoinTest._
import org.apache.flink.table.expressions.ResolvedFieldReference
import org.apache.flink.table.functions.TemporalTableFunction
+import org.apache.flink.table.plan.logical.rel.LogicalTemporalTableJoin._
+import org.apache.flink.table.utils.TableTestUtil._
import org.apache.flink.table.utils._
+import org.hamcrest.Matchers.startsWith
import org.junit.Assert.{assertArrayEquals, assertEquals, assertTrue}
import org.junit.Test
@@ -32,6 +37,9 @@ class TemporalTableJoinTest extends TableTestBase {
val util: TableTestUtil = streamTestUtil()
+ val orders = util.addTable[(Long, String, Timestamp)](
+ "Orders", 'o_amount, 'o_currency, 'o_rowtime.rowtime)
+
val ratesHistory = util.addTable[(String, Int, Timestamp)](
"RatesHistory", 'currency, 'rate, 'rowtime.rowtime)
@@ -39,6 +47,91 @@ class TemporalTableJoinTest extends TableTestBase {
"Rates",
ratesHistory.createTemporalTableFunction('rowtime, 'currency))
+ val proctimeOrders = util.addTable[(Long, String)](
+ "ProctimeOrders", 'o_amount, 'o_currency, 'o_proctime.proctime)
+
+ val proctimeRatesHistory = util.addTable[(String, Int)](
+ "ProctimeRatesHistory", 'currency, 'rate, 'proctime.proctime)
+
+ val proctimeRates = proctimeRatesHistory.createTemporalTableFunction('proctime, 'currency)
+
+ // Joins orders with the rates function on 'o_rowtime and verifies the result
+ // against the expected simple join plan.
+ @Test
+ def testSimpleJoin(): Unit = {
+ val result = orders
+ .join(rates('o_rowtime), "currency = o_currency")
+ .select("o_amount * rate").as("rate")
+
+ util.verifyTable(result, getExpectedSimpleJoinPlan())
+ }
+
+ // Same as the simple join, but with the proctime tables and 'o_proctime as argument.
+ @Test
+ def testSimpleProctimeJoin(): Unit = {
+ val result = proctimeOrders
+ .join(proctimeRates('o_proctime), "currency = o_currency")
+ .select("o_amount * rate").as("rate")
+
+ util.verifyTable(result, getExpectedSimpleProctimeJoinPlan())
+ }
+
+ /**
+ * Test versioned joins with a more complicated query.
+ * The important thing here is that we have a complex OR join condition
+ * and that there are some columns that are not being used (are being pruned).
+ *
+ * Uses its own local test util and tables, which shadow the class-level ones.
+ */
+ @Test
+ def testComplexJoin(): Unit = {
+ val util = streamTestUtil()
+ val thirdTable = util.addTable[(String, Int)]("ThirdTable", 't3_comment, 't3_secondary_key)
+ val orders = util.addTable[(Timestamp, String, Long, String, Int)](
+ "Orders", 'o_rowtime.rowtime, 'o_comment, 'o_amount, 'o_currency, 'o_secondary_key)
+
+ val ratesHistory = util.addTable[(Timestamp, String, String, Int, Int)](
+ "RatesHistory", 'rowtime.rowtime, 'comment, 'currency, 'rate, 'secondary_key)
+ val rates = ratesHistory.createTemporalTableFunction('rowtime, 'currency)
+ util.addFunction("Rates", rates)
+
+ val result = orders
+ .join(rates('o_rowtime))
+ .filter('currency === 'o_currency || 'secondary_key === 'o_secondary_key)
+ .select('o_amount * 'rate, 'secondary_key).as('rate, 'secondary_key)
+ .join(thirdTable, 't3_secondary_key === 'secondary_key)
+
+ util.verifyTable(result, getExpectedComplexJoinPlan())
+ }
+
+ // Builds the temporal table function from a filtered and projected view of
+ // ratesHistory instead of the registered table itself, then joins orders with it.
+ @Test
+ def testTemporalTableFunctionOnTopOfQuery(): Unit = {
+ val filteredRatesHistory = ratesHistory
+ .filter('rate > 100)
+ .select('currency, 'rate * 2, 'rowtime)
+ .as('currency, 'rate, 'rowtime)
+
+ val filteredRates = util.addFunction(
+ "FilteredRates",
+ filteredRatesHistory.createTemporalTableFunction('rowtime, 'currency))
+
+ val result = orders
+ .join(filteredRates('o_rowtime), "currency = o_currency")
+ .select("o_amount * rate")
+ .as('rate)
+
+ util.verifyTable(result, getExpectedTemporalTableFunctionOnTopOfQueryPlan())
+ }
+
+ // Calling the function with a timestamp literal is expected to be rejected
+ // with an "Unsupported argument" ValidationException.
+ @Test
+ def testUncorrelatedJoin(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage(startsWith("Unsupported argument"))
+
+ val result = orders
+ .join(rates(
+ java.sql.Timestamp.valueOf("2016-06-27 10:10:42.123")),
+ "o_currency = currency")
+ .select("o_amount * rate")
+
+ util.printTable(result)
+ }
+
@Test
def testTemporalTableFunctionScan(): Unit = {
expectedException.expect(classOf[ValidationException])
@@ -51,11 +144,7 @@ class TemporalTableJoinTest extends TableTestBase {
@Test
def testProcessingTimeIndicatorVersion(): Unit = {
- val util: TableTestUtil = streamTestUtil()
- val ratesHistory = util.addTable[(String, Int)](
- "RatesHistory", 'currency, 'rate, 'proctime.proctime)
- val rates = ratesHistory.createTemporalTableFunction('proctime, 'currency)
- assertRatesFunction(ratesHistory.getSchema, rates, true)
+ assertRatesFunction(proctimeRatesHistory.getSchema, proctimeRates, true)
}
@Test
@@ -82,3 +171,106 @@ class TemporalTableJoinTest extends TableTestBase {
}
}
+/**
+ * Expected plan strings for the temporal table join tests. Each method assembles
+ * its plan from the node/term string helpers (`unaryNode`, `binaryNode`,
+ * `streamTableNode`, `term`).
+ */
+object TemporalTableJoinTest {
+ // Expected plan: a calc on top of a temporal table join of stream tables 0 and 1,
+ // with a rowtime temporal join condition.
+ def getExpectedSimpleJoinPlan(): String = {
+ unaryNode(
+ "DataStreamCalc",
+ binaryNode(
+ "DataStreamTemporalTableJoin",
+ streamTableNode(0),
+ streamTableNode(1),
+ term("where",
+ "AND(" +
+ s"${TEMPORAL_JOIN_CONDITION.getName}(o_rowtime, rowtime, currency), " +
+ "=(currency, o_currency))"),
+ term("join", "o_amount", "o_currency", "o_rowtime", "currency", "rate", "rowtime"),
+ term("joinType", "InnerJoin")
+ ),
+ term("select", "*(o_amount, rate) AS rate")
+ )
+ }
+
+ // Expected plan for the proctime variant: the right input is a calc selecting only
+ // currency and rate, and the temporal join condition has no right-side time argument.
+ def getExpectedSimpleProctimeJoinPlan(): String = {
+ unaryNode(
+ "DataStreamCalc",
+ binaryNode(
+ "DataStreamTemporalTableJoin",
+ streamTableNode(2),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(3),
+ term("select", "currency, rate")),
+ term("where",
+ "AND(" +
+ s"${TEMPORAL_JOIN_CONDITION.getName}(o_proctime, currency), " +
+ "=(currency, o_currency))"),
+ term("join", "o_amount", "o_currency", "o_proctime", "currency", "rate"),
+ term("joinType", "InnerJoin")
+ ),
+ term("select", "*(o_amount, rate) AS rate")
+ )
+ }
+
+ // Expected plan for the complex join: both temporal join inputs are calcs that drop
+ // the comment columns, and the result is joined with stream table 0.
+ def getExpectedComplexJoinPlan(): String = {
+ binaryNode(
+ "DataStreamJoin",
+ unaryNode(
+ "DataStreamCalc",
+ binaryNode(
+ "DataStreamTemporalTableJoin",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(1),
+ term("select", "o_rowtime, o_amount, o_currency, o_secondary_key")
+ ),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(2),
+ term("select", "rowtime, currency, rate, secondary_key")
+ ),
+ term("where",
+ "AND(" +
+ s"${TEMPORAL_JOIN_CONDITION.getName}(o_rowtime, rowtime, currency), " +
+ "OR(=(currency, o_currency), =(secondary_key, o_secondary_key)))"),
+ term("join",
+ "o_rowtime",
+ "o_amount",
+ "o_currency",
+ "o_secondary_key",
+ "rowtime",
+ "currency",
+ "rate",
+ "secondary_key"),
+ term("joinType", "InnerJoin")
+ ),
+ term("select", "*(o_amount, rate) AS rate", "secondary_key")
+ ),
+ streamTableNode(0),
+ term("where", "=(t3_secondary_key, secondary_key)"),
+ term("join", "rate, secondary_key, t3_comment, t3_secondary_key"),
+ term("joinType", "InnerJoin")
+ )
+ }
+
+ // Expected plan when the temporal table function is defined on top of a query:
+ // the right input is a calc carrying that query's projection and filter.
+ def getExpectedTemporalTableFunctionOnTopOfQueryPlan(): String = {
+ unaryNode(
+ "DataStreamCalc",
+ binaryNode(
+ "DataStreamTemporalTableJoin",
+ streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(1),
+ term("select", "currency", "*(rate, 2) AS rate", "rowtime"),
+ term("where", ">(rate, 100)")),
+ term("where",
+ "AND(" +
+ s"${TEMPORAL_JOIN_CONDITION.getName}(o_rowtime, rowtime, currency), " +
+ "=(currency, o_currency))"),
+ term("join", "o_amount", "o_currency", "o_rowtime", "currency", "rate", "rowtime"),
+ term("joinType", "InnerJoin")
+ ),
+ term("select", "*(o_amount, rate) AS rate")
+ )
+ }
+}
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/TemporalTableJoinValidationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/TemporalTableJoinValidationTest.scala
index 71b1585..ab282ec 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/TemporalTableJoinValidationTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/validation/TemporalTableJoinValidationTest.scala
@@ -52,5 +52,3 @@ class TemporalTableJoinValidationTest extends TableTestBase {
ratesHistory.createTemporalTableFunction("rowtime", "foobar")
}
}
-
-
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala
index 6a77f12..1706169 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/TimeIndicatorConversionTest.scala
@@ -394,6 +394,116 @@ class TimeIndicatorConversionTest extends TableTestBase {
util.verifyTable(result, expected)
}
+ // Windows on 'proctime, which is a field of ProctimeRatesHistory (the table the
+ // temporal table function is built from). The calc below the two topmost nodes is
+ // expected to project it as "PROCTIME(proctime) AS proctime"; the calc's input is
+ // not checked (anySubtree).
+ @Test
+ def testMaterializeRightSideOfTemporalTableJoin(): Unit = {
+ val util = streamTestUtil()
+
+ val proctimeOrders = util.addTable[(Long, String)](
+ "ProctimeOrders", 'o_amount, 'o_currency, 'o_proctime.proctime)
+
+ val proctimeRatesHistory = util.addTable[(String, Int)](
+ "ProctimeRatesHistory", 'currency, 'rate, 'proctime.proctime)
+
+ val proctimeRates = proctimeRatesHistory.createTemporalTableFunction('proctime, 'currency)
+
+ val result = proctimeOrders
+ .join(proctimeRates('o_proctime), "currency = o_currency")
+ .select("o_amount * rate, currency, proctime").as("converted_amount")
+ .window(Tumble over 1.second on 'proctime as 'w)
+ .groupBy('w, 'currency)
+ .select('converted_amount.sum)
+
+ val expected =
+ unaryAnyNode(
+ unaryAnyNode(
+ unaryNode(
+ "DataStreamCalc",
+ anySubtree(),
+ term(
+ "select",
+ "*(o_amount, rate) AS converted_amount",
+ "currency",
+ "PROCTIME(proctime) AS proctime")
+ )
+ )
+ )
+
+ util.verifyTable(result, expected)
+ }
+
+ // Windows on 'o_proctime, which is a field of ProctimeOrders (the left join input).
+ // The calc below the two topmost nodes is expected to pass it through as plain
+ // "o_proctime"; the calc's input is not checked (anySubtree).
+ @Test
+ def testDoNotMaterializeLeftSideOfTemporalTableJoin(): Unit = {
+ val util = streamTestUtil()
+
+ val proctimeOrders = util.addTable[(Long, String)](
+ "ProctimeOrders", 'o_amount, 'o_currency, 'o_proctime.proctime)
+
+ val proctimeRatesHistory = util.addTable[(String, Int)](
+ "ProctimeRatesHistory", 'currency, 'rate, 'proctime.proctime)
+
+ val proctimeRates = proctimeRatesHistory.createTemporalTableFunction('proctime, 'currency)
+
+ val result = proctimeOrders
+ .join(proctimeRates('o_proctime), "currency = o_currency")
+ .select("o_amount * rate, currency, o_proctime").as("converted_amount")
+ .window(Tumble over 1.second on 'o_proctime as 'w)
+ .groupBy('w, 'currency)
+ .select('converted_amount.sum)
+
+ val expected =
+ unaryAnyNode(
+ unaryAnyNode(
+ unaryNode(
+ "DataStreamCalc",
+ anySubtree(),
+ term(
+ "select",
+ "*(o_amount, rate) AS converted_amount",
+ "currency",
+ "o_proctime")
+ )
+ )
+ )
+
+ util.verifyTable(result, expected)
+ }
+
+ // ProctimeOrders carries both 'o_proctime and 'o_rowtime here. The join uses
+ // 'o_proctime while the window is defined on 'o_rowtime. The calc below the two
+ // topmost nodes is expected to project "CAST(o_rowtime) AS o_rowtime"; the calc's
+ // input is not checked (anySubtree).
+ @Test
+ def testMaterializeLeftRowtimeWithProcessingTimeTemporalTableJoin(): Unit = {
+ val util = streamTestUtil()
+
+ val proctimeOrders = util.addTable[(Long, String)](
+ "ProctimeOrders", 'o_amount, 'o_currency, 'o_proctime.proctime, 'o_rowtime.rowtime)
+
+ val proctimeRatesHistory = util.addTable[(String, Int)](
+ "ProctimeRatesHistory", 'currency, 'rate, 'proctime.proctime)
+
+ val proctimeRates = proctimeRatesHistory.createTemporalTableFunction('proctime, 'currency)
+
+ val result = proctimeOrders
+ .join(proctimeRates('o_proctime), "currency = o_currency")
+ .select("o_amount * rate, currency, o_proctime, o_rowtime").as("converted_amount")
+ .window(Tumble over 1.second on 'o_rowtime as 'w)
+ .groupBy('w, 'currency)
+ .select('converted_amount.sum)
+
+ val expected =
+ unaryAnyNode(
+ unaryAnyNode(
+ unaryNode(
+ "DataStreamCalc",
+ anySubtree(),
+ term(
+ "select",
+ "*(o_amount, rate) AS converted_amount",
+ "currency",
+ "CAST(o_rowtime) AS o_rowtime")
+ )
+ )
+ )
+
+ util.verifyTable(result, expected)
+ }
}
object TimeIndicatorConversionTest {
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
index b987e34..42b16e9 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
@@ -33,10 +33,12 @@ import org.apache.flink.table.api.{Table, TableEnvironment, TableSchema}
import org.apache.flink.table.expressions.Expression
import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, TableFunction}
import org.junit.Assert.assertEquals
-import org.junit.Rule
+import org.junit.{ComparisonFailure, Rule}
import org.junit.rules.ExpectedException
import org.mockito.Mockito.{mock, when}
+import util.control.Breaks._
+
/**
* Test base for testing Table API / SQL plans.
*/
@@ -99,17 +101,48 @@ abstract class TableTestUtil {
// we remove the charset for testing because it
// depends on the native machine (Little/Big Endian)
val actualNoCharset = actual.replace("_UTF-16LE'", "'").replace("_UTF-16BE'", "'")
- assertEquals(
- expected.split("\n").map(_.trim).mkString("\n"),
- actualNoCharset.split("\n").map(_.trim).mkString("\n"))
+
+ val expectedLines = expected.split("\n").map(_.trim)
+ val actualLines = actualNoCharset.split("\n").map(_.trim)
+
+ val expectedMessage = expectedLines.mkString("\n")
+ val actualMessage = actualLines.mkString("\n")
+
+ // Only the expected lines in front of the first ANY_SUBTREE placeholder are
+ // compared; everything from that placeholder on is accepted as-is.
+ val subtreeIndex = expectedLines.indexOf(TableTestUtil.ANY_SUBTREE)
+ val comparedLines =
+   if (subtreeIndex < 0) expectedLines else expectedLines.take(subtreeIndex)
+
+ // zip() silently truncates to the shorter side, so the line counts have to be
+ // checked explicitly: without a subtree placeholder both plans must have the same
+ // number of lines, with one the actual plan must at least cover the compared prefix.
+ val lengthMismatch =
+   if (subtreeIndex < 0) {
+     actualLines.length != expectedLines.length
+   } else {
+     actualLines.length < subtreeIndex
+   }
+
+ // An ANY_NODE placeholder matches any single actual line.
+ val lineMismatch = comparedLines.zip(actualLines).exists {
+   case (expectedLine, actualLine) =>
+     expectedLine != TableTestUtil.ANY_NODE && expectedLine != actualLine
+ }
+
+ if (lengthMismatch || lineMismatch) {
+   throw new ComparisonFailure(null, expectedMessage, actualMessage)
+ }
}
+
+ /**
+ * Returns a textual explanation of the given table.
+ * Implemented by the concrete batch and stream test utils.
+ */
+ def explain(resultTable: Table): String
}
object TableTestUtil {
+ // Placeholder line in an expected plan: the plan comparison accepts any actual
+ // line at this position.
+ val ANY_NODE = "%ANY_NODE%"
+
+ // Placeholder line in an expected plan: the plan comparison stops when it reaches
+ // this line, so the remaining actual lines are not checked.
+ val ANY_SUBTREE = "%ANY_SUBTREE%"
// these methods are currently just for simplifying string construction,
// we could replace them with logic later
+ /**
+ * Builds an expected-plan fragment consisting of an [[ANY_NODE]] placeholder line
+ * followed by the given input plan; the trailing line end is stripped.
+ */
+ def unaryAnyNode(input: String): String = {
+ s"""$ANY_NODE
+ |$input
+ |""".stripMargin.stripLineEnd
+ }
+
+ /**
+ * Returns the [[ANY_SUBTREE]] placeholder for use as the input of an expected node.
+ */
+ def anySubtree(): String = {
+ ANY_SUBTREE
+ }
+
def unaryNode(node: String, input: String, term: String*): String = {
s"""$node(${term.mkString(", ")})
|$input
@@ -230,6 +263,10 @@ case class BatchTableTestUtil() extends TableTestUtil {
def printSql(query: String): Unit = {
printTable(tableEnv.sqlQuery(query))
}
+
+ // Delegates to the explain method of this util's table environment.
+ def explain(resultTable: Table): String = {
+ tableEnv.explain(resultTable)
+ }
}
case class StreamTableTestUtil() extends TableTestUtil {
@@ -318,6 +355,10 @@ case class StreamTableTestUtil() extends TableTestUtil {
def printSql(query: String): Unit = {
printTable(tableEnv.sqlQuery(query))
}
+
+ // Delegates to the explain method of this util's table environment.
+ def explain(resultTable: Table): String = {
+ tableEnv.explain(resultTable)
+ }
}
class EmptySource[T]() extends SourceFunction[T] {