You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ja...@apache.org on 2019/08/29 06:43:38 UTC
[flink] 04/08: [FLINK-13774][table-planner-blink] Modify filterable
table source to accept ResolvedExpression
This is an automated email from the ASF dual-hosted git repository.
jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit de22d7c0d5afd3233ab8e174ed4d837e08438ab3
Author: JingsongLi <lz...@aliyun.com>
AuthorDate: Thu Aug 22 12:46:52 2019 +0200
[FLINK-13774][table-planner-blink] Modify filterable table source to accept ResolvedExpression
---
.../planner/plan/utils/RexNodeExtractor.scala | 52 ++++++++++++----------
.../table/planner/utils/testTableSources.scala | 12 ++---
2 files changed, 35 insertions(+), 29 deletions(-)
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala
index f938c79..b4535bf 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala
@@ -28,6 +28,7 @@ import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.utils.Logging
import org.apache.flink.table.runtime.functions.SqlDateTimeUtils.unixTimestampToLocalDateTime
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromLogicalTypeToDataType
+import org.apache.flink.table.types.DataType
import org.apache.flink.table.types.logical.LogicalTypeRoot._
import org.apache.flink.util.Preconditions
@@ -294,9 +295,9 @@ class RexNodeToExpressionConverter(
inputNames: Array[String],
functionCatalog: FunctionCatalog,
timeZone: TimeZone)
- extends RexVisitor[Option[Expression]] {
+ extends RexVisitor[Option[ResolvedExpression]] {
- override def visitInputRef(inputRef: RexInputRef): Option[Expression] = {
+ override def visitInputRef(inputRef: RexInputRef): Option[ResolvedExpression] = {
Preconditions.checkArgument(inputRef.getIndex < inputNames.length)
Some(new FieldReferenceExpression(
inputNames(inputRef.getIndex),
@@ -306,14 +307,14 @@ class RexNodeToExpressionConverter(
))
}
- override def visitTableInputRef(rexTableInputRef: RexTableInputRef): Option[Expression] =
+ override def visitTableInputRef(rexTableInputRef: RexTableInputRef): Option[ResolvedExpression] =
visitInputRef(rexTableInputRef)
- override def visitLocalRef(localRef: RexLocalRef): Option[Expression] = {
+ override def visitLocalRef(localRef: RexLocalRef): Option[ResolvedExpression] = {
throw new TableException("Bug: RexLocalRef should have been expanded")
}
- override def visitLiteral(literal: RexLiteral): Option[Expression] = {
+ override def visitLiteral(literal: RexLiteral): Option[ResolvedExpression] = {
// TODO support SqlTrimFunction.Flag
literal.getValue match {
case _: SqlTrimFunction.Flag => return None
@@ -384,53 +385,58 @@ class RexNodeToExpressionConverter(
fromLogicalTypeToDataType(literalType)))
}
- override def visitCall(rexCall: RexCall): Option[Expression] = {
+ override def visitCall(rexCall: RexCall): Option[ResolvedExpression] = {
val operands = rexCall.getOperands.map(
operand => operand.accept(this).orNull
)
+ val outputType = fromLogicalTypeToDataType(FlinkTypeFactory.toLogicalType(rexCall.getType))
+
// return null if we cannot translate all the operands of the call
if (operands.contains(null)) {
None
} else {
rexCall.getOperator match {
case SqlStdOperatorTable.OR =>
- Option(operands.reduceLeft { (l, r) => unresolvedCall(OR, l, r) })
+ Option(operands.reduceLeft((l, r) => new CallExpression(OR, Seq(l, r), outputType)))
case SqlStdOperatorTable.AND =>
- Option(operands.reduceLeft { (l, r) => unresolvedCall(AND, l, r) })
+ Option(operands.reduceLeft((l, r) => new CallExpression(AND, Seq(l, r), outputType)))
case SqlStdOperatorTable.CAST =>
- Option(unresolvedCall(CAST, operands.head,
- typeLiteral(fromLogicalTypeToDataType(
- FlinkTypeFactory.toLogicalType(rexCall.getType)))))
+ Option(new CallExpression(CAST, Seq(operands.head, typeLiteral(outputType)), outputType))
case function: SqlFunction =>
- lookupFunction(replace(function.getName), operands)
+ lookupFunction(replace(function.getName), operands, outputType)
case postfix: SqlPostfixOperator =>
- lookupFunction(replace(postfix.getName), operands)
+ lookupFunction(replace(postfix.getName), operands, outputType)
case operator@_ =>
- lookupFunction(replace(s"${operator.getKind}"), operands)
+ lookupFunction(replace(s"${operator.getKind}"), operands, outputType)
}
}
}
- override def visitFieldAccess(fieldAccess: RexFieldAccess): Option[Expression] = None
+ override def visitFieldAccess(fieldAccess: RexFieldAccess): Option[ResolvedExpression] = None
- override def visitCorrelVariable(correlVariable: RexCorrelVariable): Option[Expression] = None
+ override def visitCorrelVariable(
+ correlVariable: RexCorrelVariable): Option[ResolvedExpression] = None
- override def visitRangeRef(rangeRef: RexRangeRef): Option[Expression] = None
+ override def visitRangeRef(rangeRef: RexRangeRef): Option[ResolvedExpression] = None
- override def visitSubQuery(subQuery: RexSubQuery): Option[Expression] = None
+ override def visitSubQuery(subQuery: RexSubQuery): Option[ResolvedExpression] = None
- override def visitDynamicParam(dynamicParam: RexDynamicParam): Option[Expression] = None
+ override def visitDynamicParam(dynamicParam: RexDynamicParam): Option[ResolvedExpression] = None
- override def visitOver(over: RexOver): Option[Expression] = None
+ override def visitOver(over: RexOver): Option[ResolvedExpression] = None
- override def visitPatternFieldRef(fieldRef: RexPatternFieldRef): Option[Expression] = None
+ override def visitPatternFieldRef(
+ fieldRef: RexPatternFieldRef): Option[ResolvedExpression] = None
- private def lookupFunction(name: String, operands: Seq[Expression]): Option[Expression] = {
+ private def lookupFunction(
+ name: String,
+ operands: Seq[ResolvedExpression],
+ outputType: DataType): Option[ResolvedExpression] = {
Try(functionCatalog.lookupFunction(name)) match {
case Success(f: java.util.Optional[FunctionLookup.Result]) =>
if (f.isPresent) {
- Some(unresolvedCall(f.get().getFunctionDefinition, operands: _*))
+ Some(new CallExpression(f.get().getFunctionDefinition, operands, outputType))
} else {
None
}
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/testTableSources.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/testTableSources.scala
index 44cb4eb..24fab42 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/testTableSources.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/testTableSources.scala
@@ -29,7 +29,7 @@ import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
import org.apache.flink.table.api.{TableSchema, Types}
import org.apache.flink.table.expressions.utils.ApiExpressionUtils.unresolvedCall
-import org.apache.flink.table.expressions.{Expression, FieldReferenceExpression, UnresolvedCallExpression, ValueLiteralExpression}
+import org.apache.flink.table.expressions.{CallExpression, Expression, FieldReferenceExpression, ValueLiteralExpression}
import org.apache.flink.table.functions.BuiltInFunctionDefinitions
import org.apache.flink.table.functions.BuiltInFunctionDefinitions.AND
import org.apache.flink.table.planner.runtime.utils.BatchTestBase.row
@@ -398,12 +398,12 @@ class TestFilterableTableSource(
private def shouldPushDown(expr: Expression): Boolean = {
expr match {
- case expr: UnresolvedCallExpression if expr.getChildren.size() == 2 => shouldPushDown(expr)
+ case expr: CallExpression if expr.getChildren.size() == 2 => shouldPushDown(expr)
case _ => false
}
}
- private def shouldPushDown(binExpr: UnresolvedCallExpression): Boolean = {
+ private def shouldPushDown(binExpr: CallExpression): Boolean = {
val children = binExpr.getChildren
require(children.size() == 2)
(children.head, children.last) match {
@@ -419,13 +419,13 @@ class TestFilterableTableSource(
private def shouldKeep(row: Row): Boolean = {
filterPredicates.isEmpty || filterPredicates.forall {
- case expr: UnresolvedCallExpression if expr.getChildren.size() == 2 =>
+ case expr: CallExpression if expr.getChildren.size() == 2 =>
binaryFilterApplies(expr, row)
case expr => throw new RuntimeException(expr + " not supported!")
}
}
- private def binaryFilterApplies(binExpr: UnresolvedCallExpression, row: Row): Boolean = {
+ private def binaryFilterApplies(binExpr: CallExpression, row: Row): Boolean = {
val children = binExpr.getChildren
require(children.size() == 2)
val (lhsValue, rhsValue) = extractValues(binExpr, row)
@@ -447,7 +447,7 @@ class TestFilterableTableSource(
}
private def extractValues(
- binExpr: UnresolvedCallExpression,
+ binExpr: CallExpression,
row: Row): (Comparable[Any], Comparable[Any]) = {
val children = binExpr.getChildren
require(children.size() == 2)