You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by pn...@apache.org on 2018/09/21 11:43:37 UTC

[flink] 04/11: [hotfix][table] Deduplicate RelTimeIndicatorConverter logic

This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 3cca8b654b6a245085a0a911a39ff3a8ef2b3ed6
Author: Piotr Nowojski <pi...@gmail.com>
AuthorDate: Thu Sep 20 13:15:43 2018 +0200

    [hotfix][table] Deduplicate RelTimeIndicatorConverter logic
---
 .../table/calcite/RelTimeIndicatorConverter.scala  | 160 +++++++++++----------
 1 file changed, 86 insertions(+), 74 deletions(-)

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
index 4f3fbaa..f67b715 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
@@ -42,10 +42,7 @@ import scala.collection.mutable
   */
 class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle {
 
-  private val timestamp = rexBuilder
-      .getTypeFactory
-      .asInstanceOf[FlinkTypeFactory]
-      .createTypeFromTypeInfo(SqlTimeTypeInfo.TIMESTAMP, isNullable = false)
+  val materializerUtils = new RexTimeIndicatorMaterializerUtils(rexBuilder)
 
   override def visit(intersect: LogicalIntersect): RelNode =
     throw new TableException("Logical intersect in a stream environment is not supported yet.")
@@ -213,23 +210,9 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle {
 
     // add a project to materialize aggregation arguments/grouping keys
 
-    val refIndices = mutable.Set[Int]()
-
-    // check arguments of agg calls
-    aggregate.getAggCallList.foreach(call => if (call.getArgList.size() == 0) {
-        // count(*) has an empty argument list
-        (0 until input.getRowType.getFieldCount).foreach(refIndices.add)
-      } else {
-        // for other aggregations
-        call.getArgList.map(_.asInstanceOf[Int]).foreach(refIndices.add)
-      })
+    val indicesToMaterialize = gatherIndicesToMaterialize(aggregate)
 
-    // check grouping sets
-    aggregate.getGroupSets.foreach(set =>
-      set.asList().map(_.asInstanceOf[Int]).foreach(refIndices.add)
-    )
-
-    val needsMaterialization = refIndices.exists(idx =>
+    val needsMaterialization = indicesToMaterialize.exists(idx =>
       isTimeIndicatorType(input.getRowType.getFieldList.get(idx).getType))
 
     // create project if necessary
@@ -242,17 +225,7 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle {
         // merge
         case lp: LogicalProject =>
           val projects = lp.getProjects.zipWithIndex.map { case (expr, idx) =>
-            if (isTimeIndicatorType(expr.getType) && refIndices.contains(idx)) {
-              if (isRowtimeIndicatorType(expr.getType)) {
-                // cast rowtime indicator to regular timestamp
-                rexBuilder.makeAbstractCast(timestamp, expr)
-              } else {
-                // generate proctime access
-                rexBuilder.makeCall(ProctimeSqlFunction, expr)
-              }
-            } else {
-              expr
-            }
+            materializerUtils.materializeIfContains(expr, idx, indicesToMaterialize)
           }
 
           LogicalProject.create(
@@ -262,28 +235,7 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle {
 
         // new project
         case _ =>
-          val projects = input.getRowType.getFieldList.map { field =>
-            if (isTimeIndicatorType(field.getType) && refIndices.contains(field.getIndex)) {
-              if (isRowtimeIndicatorType(field.getType)) {
-                // cast rowtime indicator to regular timestamp
-                rexBuilder.makeAbstractCast(
-                  timestamp,
-                  new RexInputRef(field.getIndex, field.getType))
-              } else {
-                // generate proctime access
-                rexBuilder.makeCall(
-                  ProctimeSqlFunction,
-                  new RexInputRef(field.getIndex, field.getType))
-              }
-            } else {
-              new RexInputRef(field.getIndex, field.getType)
-            }
-          }
-
-          LogicalProject.create(
-            input,
-            projects,
-            input.getRowType.getFieldNames)
+          materializerUtils.projectAndMaterializeFields(input, indicesToMaterialize)
       }
     } else {
       // no project necessary
@@ -293,7 +245,7 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle {
     // remove time indicator type as agg call return type
     val updatedAggCalls = aggregate.getAggCallList.map { call =>
       val callType = if (isTimeIndicatorType(call.getType)) {
-        timestamp
+        materializerUtils.getTimestamp
       } else {
         call.getType
       }
@@ -314,6 +266,25 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder) extends RelShuttle {
       updatedAggCalls)
   }
 
+  // Collects the input field indices referenced by agg call arguments and grouping sets.
+  private def gatherIndicesToMaterialize(aggregate: Aggregate): Set[Int] = {
+    val indicesToMaterialize = mutable.Set[Int]()
+    // Indices address fields of the aggregate's input, so count the input's fields: the
+    // aggregate's own row type is its output (group keys + agg results), not its input.
+    val inputFieldCount = aggregate.getInput.getRowType.getFieldCount
+    aggregate.getAggCallList.foreach(call => if (call.getArgList.size() == 0) {
+      // count(*) has an empty argument list, so every input field is added
+      (0 until inputFieldCount).foreach(indicesToMaterialize.add)
+    } else {
+      // for other aggregations
+      call.getArgList.map(_.asInstanceOf[Int]).foreach(indicesToMaterialize.add)
+    })
+    // check grouping sets
+    aggregate.getGroupSets.foreach(set =>
+      set.asList().map(_.asInstanceOf[Int]).foreach(indicesToMaterialize.add)
+    )
+    indicesToMaterialize.toSet
+  }
 }
 
 object RelTimeIndicatorConverter {
@@ -365,20 +336,21 @@ object RelTimeIndicatorConverter {
   }
 }
 
+/**
+  * Takes the `newResolvedInput` types of a [[RexNode]]'s inputs and, if those types have
+  * changed, rewrites the [[RexNode]] to make it consistent with the new types.
+  */
 class RexTimeIndicatorMaterializer(
   private val rexBuilder: RexBuilder,
-  private val input: Seq[RelDataType])
+  private val newResolvedInput: Seq[RelDataType])
   extends RexShuttle {
 
-  private val timestamp = rexBuilder
-    .getTypeFactory
-    .asInstanceOf[FlinkTypeFactory]
-    .createTypeFromTypeInfo(SqlTimeTypeInfo.TIMESTAMP, isNullable = false)
+  private val materializerUtils = new RexTimeIndicatorMaterializerUtils(rexBuilder)
 
   override def visitInputRef(inputRef: RexInputRef): RexNode = {
     // reference is interesting
     if (isTimeIndicatorType(inputRef.getType)) {
-      val resolvedRefType = input(inputRef.getIndex)
+      val resolvedRefType = newResolvedInput(inputRef.getIndex)
       // input is a valid time indicator
       if (isTimeIndicatorType(resolvedRefType)) {
         inputRef
@@ -405,19 +377,7 @@ class RexTimeIndicatorMaterializer(
         updatedCall.getOperands.toList
 
       case _ =>
-        updatedCall.getOperands.map { o =>
-          if (isTimeIndicatorType(o.getType)) {
-            if (isRowtimeIndicatorType(o.getType)) {
-              // cast rowtime indicator to regular timestamp
-              rexBuilder.makeAbstractCast(timestamp, o)
-            } else {
-              // generate proctime access
-              rexBuilder.makeCall(ProctimeSqlFunction, o)
-            }
-          } else {
-            o
-          }
-        }
+        updatedCall.getOperands.map { materializerUtils.materialize }
     }
 
     // remove time indicator return type
@@ -442,7 +402,7 @@ class RexTimeIndicatorMaterializer(
 
       // materialize function's result and operands
       case _ if isTimeIndicatorType(updatedCall.getType) =>
-        updatedCall.clone(timestamp, materializedOperands)
+        updatedCall.clone(materializerUtils.getTimestamp, materializedOperands)
 
       // materialize function's operands only
       case _ =>
@@ -450,3 +410,55 @@ class RexTimeIndicatorMaterializer(
     }
   }
 }
+
+/**
+  * Shared helpers for turning time indicator attributes of [[RelNode]]s and [[RexNode]]s into
+  * regular expressions: a cast for rowtime indicators, a function call for proctime ones.
+  */
+class RexTimeIndicatorMaterializerUtils(rexBuilder: RexBuilder) {
+
+  // non-nullable timestamp type, created once through the builder's FlinkTypeFactory
+  private val timestamp = {
+    val typeFactory = rexBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+    typeFactory.createTypeFromTypeInfo(SqlTimeTypeInfo.TIMESTAMP, isNullable = false)
+  }
+
+  /** Type that replaces a time indicator type once it has been materialized. */
+  def getTimestamp: RelDataType = timestamp
+
+  /** Wraps `input` in a project that forwards all fields, materializing the selected ones. */
+  def projectAndMaterializeFields(input: RelNode, indicesToMaterialize: Set[Int]) : RelNode = {
+    val rowType = input.getRowType
+    val projects = rowType.getFieldList.map { field =>
+      val ref = new RexInputRef(field.getIndex, field.getType)
+      materializeIfContains(ref, field.getIndex, indicesToMaterialize)
+    }
+
+    LogicalProject.create(input, projects, rowType.getFieldNames)
+  }
+
+  /** Materializes `expr` only if `index` is one of `indicesToMaterialize`. */
+  def materializeIfContains(expr: RexNode, index: Int, indicesToMaterialize: Set[Int]): RexNode = {
+    if (!indicesToMaterialize.contains(index)) {
+      expr
+    } else {
+      materialize(expr)
+    }
+  }
+
+  /**
+    * Returns `expr` unchanged unless its type is a time indicator.
+    */
+  def materialize(expr: RexNode): RexNode = {
+    val exprType = expr.getType
+    if (!isTimeIndicatorType(exprType)) {
+      expr
+    } else if (isRowtimeIndicatorType(exprType)) {
+      // cast rowtime indicator to regular timestamp
+      rexBuilder.makeAbstractCast(timestamp, expr)
+    } else {
+      // generate proctime access
+      rexBuilder.makeCall(ProctimeSqlFunction, expr)
+    }
+  }
+}