You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ja...@apache.org on 2019/08/29 06:43:38 UTC

[flink] 04/08: [FLINK-13774][table-planner-blink] Modify filterable table source to accept ResolvedExpression

This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit de22d7c0d5afd3233ab8e174ed4d837e08438ab3
Author: JingsongLi <lz...@aliyun.com>
AuthorDate: Thu Aug 22 12:46:52 2019 +0200

    [FLINK-13774][table-planner-blink] Modify filterable table source to accept ResolvedExpression
---
 .../planner/plan/utils/RexNodeExtractor.scala      | 52 ++++++++++++----------
 .../table/planner/utils/testTableSources.scala     | 12 ++---
 2 files changed, 35 insertions(+), 29 deletions(-)

diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala
index f938c79..b4535bf 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala
@@ -28,6 +28,7 @@ import org.apache.flink.table.planner.calcite.FlinkTypeFactory
 import org.apache.flink.table.planner.utils.Logging
 import org.apache.flink.table.runtime.functions.SqlDateTimeUtils.unixTimestampToLocalDateTime
 import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromLogicalTypeToDataType
+import org.apache.flink.table.types.DataType
 import org.apache.flink.table.types.logical.LogicalTypeRoot._
 import org.apache.flink.util.Preconditions
 
@@ -294,9 +295,9 @@ class RexNodeToExpressionConverter(
     inputNames: Array[String],
     functionCatalog: FunctionCatalog,
     timeZone: TimeZone)
-  extends RexVisitor[Option[Expression]] {
+  extends RexVisitor[Option[ResolvedExpression]] {
 
-  override def visitInputRef(inputRef: RexInputRef): Option[Expression] = {
+  override def visitInputRef(inputRef: RexInputRef): Option[ResolvedExpression] = {
     Preconditions.checkArgument(inputRef.getIndex < inputNames.length)
     Some(new FieldReferenceExpression(
       inputNames(inputRef.getIndex),
@@ -306,14 +307,14 @@ class RexNodeToExpressionConverter(
     ))
   }
 
-  override def visitTableInputRef(rexTableInputRef: RexTableInputRef): Option[Expression] =
+  override def visitTableInputRef(rexTableInputRef: RexTableInputRef): Option[ResolvedExpression] =
     visitInputRef(rexTableInputRef)
 
-  override def visitLocalRef(localRef: RexLocalRef): Option[Expression] = {
+  override def visitLocalRef(localRef: RexLocalRef): Option[ResolvedExpression] = {
     throw new TableException("Bug: RexLocalRef should have been expanded")
   }
 
-  override def visitLiteral(literal: RexLiteral): Option[Expression] = {
+  override def visitLiteral(literal: RexLiteral): Option[ResolvedExpression] = {
     // TODO support SqlTrimFunction.Flag
     literal.getValue match {
       case _: SqlTrimFunction.Flag => return None
@@ -384,53 +385,58 @@ class RexNodeToExpressionConverter(
       fromLogicalTypeToDataType(literalType)))
   }
 
-  override def visitCall(rexCall: RexCall): Option[Expression] = {
+  override def visitCall(rexCall: RexCall): Option[ResolvedExpression] = {
     val operands = rexCall.getOperands.map(
       operand => operand.accept(this).orNull
     )
 
+    val outputType = fromLogicalTypeToDataType(FlinkTypeFactory.toLogicalType(rexCall.getType))
+
     // return null if we cannot translate all the operands of the call
     if (operands.contains(null)) {
       None
     } else {
       rexCall.getOperator match {
         case SqlStdOperatorTable.OR =>
-          Option(operands.reduceLeft { (l, r) => unresolvedCall(OR, l, r) })
+          Option(operands.reduceLeft((l, r) => new CallExpression(OR, Seq(l, r), outputType)))
         case SqlStdOperatorTable.AND =>
-          Option(operands.reduceLeft { (l, r) => unresolvedCall(AND, l, r) })
+          Option(operands.reduceLeft((l, r) => new CallExpression(AND, Seq(l, r), outputType)))
         case SqlStdOperatorTable.CAST =>
-          Option(unresolvedCall(CAST, operands.head,
-            typeLiteral(fromLogicalTypeToDataType(
-              FlinkTypeFactory.toLogicalType(rexCall.getType)))))
+          Option(new CallExpression(CAST, Seq(operands.head, typeLiteral(outputType)), outputType))
         case function: SqlFunction =>
-          lookupFunction(replace(function.getName), operands)
+          lookupFunction(replace(function.getName), operands, outputType)
         case postfix: SqlPostfixOperator =>
-          lookupFunction(replace(postfix.getName), operands)
+          lookupFunction(replace(postfix.getName), operands, outputType)
         case operator@_ =>
-          lookupFunction(replace(s"${operator.getKind}"), operands)
+          lookupFunction(replace(s"${operator.getKind}"), operands, outputType)
       }
     }
   }
 
-  override def visitFieldAccess(fieldAccess: RexFieldAccess): Option[Expression] = None
+  override def visitFieldAccess(fieldAccess: RexFieldAccess): Option[ResolvedExpression] = None
 
-  override def visitCorrelVariable(correlVariable: RexCorrelVariable): Option[Expression] = None
+  override def visitCorrelVariable(
+      correlVariable: RexCorrelVariable): Option[ResolvedExpression] = None
 
-  override def visitRangeRef(rangeRef: RexRangeRef): Option[Expression] = None
+  override def visitRangeRef(rangeRef: RexRangeRef): Option[ResolvedExpression] = None
 
-  override def visitSubQuery(subQuery: RexSubQuery): Option[Expression] = None
+  override def visitSubQuery(subQuery: RexSubQuery): Option[ResolvedExpression] = None
 
-  override def visitDynamicParam(dynamicParam: RexDynamicParam): Option[Expression] = None
+  override def visitDynamicParam(dynamicParam: RexDynamicParam): Option[ResolvedExpression] = None
 
-  override def visitOver(over: RexOver): Option[Expression] = None
+  override def visitOver(over: RexOver): Option[ResolvedExpression] = None
 
-  override def visitPatternFieldRef(fieldRef: RexPatternFieldRef): Option[Expression] = None
+  override def visitPatternFieldRef(
+      fieldRef: RexPatternFieldRef): Option[ResolvedExpression] = None
 
-  private def lookupFunction(name: String, operands: Seq[Expression]): Option[Expression] = {
+  private def lookupFunction(
+      name: String,
+      operands: Seq[ResolvedExpression],
+      outputType: DataType): Option[ResolvedExpression] = {
     Try(functionCatalog.lookupFunction(name)) match {
       case Success(f: java.util.Optional[FunctionLookup.Result]) =>
         if (f.isPresent) {
-          Some(unresolvedCall(f.get().getFunctionDefinition, operands: _*))
+          Some(new CallExpression(f.get().getFunctionDefinition, operands, outputType))
         } else {
           None
         }
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/testTableSources.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/testTableSources.scala
index 44cb4eb..24fab42 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/testTableSources.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/testTableSources.scala
@@ -29,7 +29,7 @@ import org.apache.flink.streaming.api.datastream.DataStream
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
 import org.apache.flink.table.api.{TableSchema, Types}
 import org.apache.flink.table.expressions.utils.ApiExpressionUtils.unresolvedCall
-import org.apache.flink.table.expressions.{Expression, FieldReferenceExpression, UnresolvedCallExpression, ValueLiteralExpression}
+import org.apache.flink.table.expressions.{CallExpression, Expression, FieldReferenceExpression, ValueLiteralExpression}
 import org.apache.flink.table.functions.BuiltInFunctionDefinitions
 import org.apache.flink.table.functions.BuiltInFunctionDefinitions.AND
 import org.apache.flink.table.planner.runtime.utils.BatchTestBase.row
@@ -398,12 +398,12 @@ class TestFilterableTableSource(
 
   private def shouldPushDown(expr: Expression): Boolean = {
     expr match {
-      case expr: UnresolvedCallExpression if expr.getChildren.size() == 2 => shouldPushDown(expr)
+      case expr: CallExpression if expr.getChildren.size() == 2 => shouldPushDown(expr)
       case _ => false
     }
   }
 
-  private def shouldPushDown(binExpr: UnresolvedCallExpression): Boolean = {
+  private def shouldPushDown(binExpr: CallExpression): Boolean = {
     val children = binExpr.getChildren
     require(children.size() == 2)
     (children.head, children.last) match {
@@ -419,13 +419,13 @@ class TestFilterableTableSource(
 
   private def shouldKeep(row: Row): Boolean = {
     filterPredicates.isEmpty || filterPredicates.forall {
-      case expr: UnresolvedCallExpression if expr.getChildren.size() == 2 =>
+      case expr: CallExpression if expr.getChildren.size() == 2 =>
         binaryFilterApplies(expr, row)
       case expr => throw new RuntimeException(expr + " not supported!")
     }
   }
 
-  private def binaryFilterApplies(binExpr: UnresolvedCallExpression, row: Row): Boolean = {
+  private def binaryFilterApplies(binExpr: CallExpression, row: Row): Boolean = {
     val children = binExpr.getChildren
     require(children.size() == 2)
     val (lhsValue, rhsValue) = extractValues(binExpr, row)
@@ -447,7 +447,7 @@ class TestFilterableTableSource(
   }
 
   private def extractValues(
-      binExpr: UnresolvedCallExpression,
+      binExpr: CallExpression,
       row: Row): (Comparable[Any], Comparable[Any]) = {
     val children = binExpr.getChildren
     require(children.size() == 2)