You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by pn...@apache.org on 2018/09/21 11:43:40 UTC

[flink] 07/11: [hotfix][table] Extract DataStreamJoinToCoProcessTranslator

This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 6ff0bb6c0f23f3810a1265546d217054d5d06417
Author: Piotr Nowojski <pi...@gmail.com>
AuthorDate: Tue Jul 17 18:59:31 2018 +0200

    [hotfix][table] Extract DataStreamJoinToCoProcessTranslator
---
 .../plan/nodes/datastream/DataStreamJoin.scala     | 147 +++++---------------
 .../DataStreamJoinToCoProcessTranslator.scala      | 154 +++++++++++++++++++++
 2 files changed, 190 insertions(+), 111 deletions(-)

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoin.scala
index 1e2311f..d54fd78 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoin.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoin.scala
@@ -23,19 +23,13 @@ import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.core.{JoinInfo, JoinRelType}
 import org.apache.calcite.rel.{BiRel, RelNode, RelWriter}
 import org.apache.calcite.rex.RexNode
-import org.apache.flink.api.common.functions.FlatJoinFunction
 import org.apache.flink.streaming.api.datastream.DataStream
 import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment, TableException}
-import org.apache.flink.table.codegen.FunctionCodeGenerator
 import org.apache.flink.table.plan.nodes.CommonJoin
 import org.apache.flink.table.plan.schema.RowSchema
-import org.apache.flink.table.runtime.CRowKeySelector
-import org.apache.flink.table.runtime.join._
 import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
-import org.apache.flink.types.Row
 
 import scala.collection.JavaConversions._
-import scala.collection.mutable.ArrayBuffer
 
 /**
   * RelNode for a non-windowed stream join.
@@ -103,27 +97,42 @@ class DataStreamJoin(
       tableEnv: StreamTableEnvironment,
       queryConfig: StreamQueryConfig): DataStream[CRow] = {
 
-    val config = tableEnv.getConfig
-    val returnType = schema.typeInfo
-    val keyPairs = joinInfo.pairs().toList
+    validateKeyTypes()
 
-    // get the equality keys
-    val leftKeys = ArrayBuffer.empty[Int]
-    val rightKeys = ArrayBuffer.empty[Int]
+    val leftDataStream =
+      left.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig)
+    val rightDataStream =
+      right.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig)
+
+    val connectOperator = leftDataStream.connect(rightDataStream)
+
+    val joinTranslator = createTranslator(tableEnv)
+
+    val joinOpName = joinToString(getRowType, joinCondition, joinType, getExpressionString)
+    val coProcessFunction = joinTranslator.getCoProcessFunction(
+      joinType,
+      schema.fieldNames,
+      ruleDescription,
+      queryConfig)
+    connectOperator
+      .keyBy(
+        joinTranslator.getLeftKeySelector(),
+        joinTranslator.getRightKeySelector())
+      .process(coProcessFunction)
+      .name(joinOpName)
+      .returns(CRowTypeInfo(schema.typeInfo))
+  }
 
+  private def validateKeyTypes(): Unit = {
     // at least one equality expression
     val leftFields = left.getRowType.getFieldList
     val rightFields = right.getRowType.getFieldList
 
-    keyPairs.foreach(pair => {
+    joinInfo.pairs().toList.foreach(pair => {
       val leftKeyType = leftFields.get(pair.source).getType.getSqlTypeName
       val rightKeyType = rightFields.get(pair.target).getType.getSqlTypeName
       // check if keys are compatible
-      if (leftKeyType == rightKeyType) {
-        // add key pair
-        leftKeys.add(pair.source)
-        rightKeys.add(pair.target)
-      } else {
+      if (leftKeyType != rightKeyType) {
         throw TableException(
           "Equality join predicate on incompatible types.\n" +
             s"\tLeft: $left,\n" +
@@ -133,100 +142,16 @@ class DataStreamJoin(
         )
       }
     })
+  }
 
-    val leftDataStream =
-      left.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig)
-    val rightDataStream =
-      right.asInstanceOf[DataStreamRel].translateToPlan(tableEnv, queryConfig)
-
-    val connectOperator = leftDataStream.connect(rightDataStream)
-    // input must not be nullable, because the runtime join function will make sure
-    // the code-generated function won't process null inputs
-    val generator = new FunctionCodeGenerator(
-      config,
-      nullableInput = false,
-      leftSchema.typeInfo,
-      Some(rightSchema.typeInfo))
-    val conversion = generator.generateConverterResultExpression(
+  protected def createTranslator(
+      tableEnv: StreamTableEnvironment): DataStreamJoinToCoProcessTranslator = {
+    new DataStreamJoinToCoProcessTranslator(
+      tableEnv.getConfig,
       schema.typeInfo,
-      schema.fieldNames)
-
-    val body = if (joinInfo.isEqui) {
-      // only equality condition
-      s"""
-         |${conversion.code}
-         |${generator.collectorTerm}.collect(${conversion.resultTerm});
-         |""".stripMargin
-    } else {
-      val nonEquiPredicates = joinInfo.getRemaining(this.cluster.getRexBuilder)
-      val condition = generator.generateExpression(nonEquiPredicates)
-      s"""
-         |${condition.code}
-         |if (${condition.resultTerm}) {
-         |  ${conversion.code}
-         |  ${generator.collectorTerm}.collect(${conversion.resultTerm});
-         |}
-         |""".stripMargin
-    }
-
-    val genFunction = generator.generateFunction(
-      ruleDescription,
-      classOf[FlatJoinFunction[Row, Row, Row]],
-      body,
-      returnType)
-
-    val coMapFun = joinType match {
-      case JoinRelType.INNER =>
-        new NonWindowInnerJoin(
-          leftSchema.typeInfo,
-          rightSchema.typeInfo,
-          CRowTypeInfo(returnType),
-          genFunction.name,
-          genFunction.code,
-          queryConfig)
-      case JoinRelType.LEFT | JoinRelType.RIGHT if joinInfo.isEqui =>
-        new NonWindowLeftRightJoin(
-          leftSchema.typeInfo,
-          rightSchema.typeInfo,
-          CRowTypeInfo(returnType),
-          genFunction.name,
-          genFunction.code,
-          joinType == JoinRelType.LEFT,
-          queryConfig)
-      case JoinRelType.LEFT | JoinRelType.RIGHT =>
-        new NonWindowLeftRightJoinWithNonEquiPredicates(
-          leftSchema.typeInfo,
-          rightSchema.typeInfo,
-          CRowTypeInfo(returnType),
-          genFunction.name,
-          genFunction.code,
-          joinType == JoinRelType.LEFT,
-          queryConfig)
-      case JoinRelType.FULL if joinInfo.isEqui =>
-        new NonWindowFullJoin(
-          leftSchema.typeInfo,
-          rightSchema.typeInfo,
-          CRowTypeInfo(returnType),
-          genFunction.name,
-          genFunction.code,
-          queryConfig)
-      case JoinRelType.FULL =>
-        new NonWindowFullJoinWithNonEquiPredicates(
-          leftSchema.typeInfo,
-          rightSchema.typeInfo,
-          CRowTypeInfo(returnType),
-          genFunction.name,
-          genFunction.code,
-          queryConfig)
-    }
-
-    val joinOpName = joinToString(getRowType, joinCondition, joinType, getExpressionString)
-    connectOperator
-      .keyBy(
-        new CRowKeySelector(leftKeys.toArray, leftSchema.projectedTypeInfo(leftKeys.toArray)),
-        new CRowKeySelector(rightKeys.toArray, rightSchema.projectedTypeInfo(rightKeys.toArray)))
-      .process(coMapFun)
-      .name(joinOpName)
-      .returns(CRowTypeInfo(returnType))
+      leftSchema,
+      rightSchema,
+      joinInfo,
+      cluster.getRexBuilder)
   }
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoinToCoProcessTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoinToCoProcessTranslator.scala
new file mode 100644
index 0000000..5a8d1a4
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamJoinToCoProcessTranslator.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.nodes.datastream
+
+import org.apache.calcite.rel.core.{JoinInfo, JoinRelType}
+import org.apache.calcite.rex.{RexBuilder, RexNode}
+import org.apache.flink.api.common.functions.FlatJoinFunction
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction
+import org.apache.flink.table.api.{StreamQueryConfig, TableConfig}
+import org.apache.flink.table.codegen.{FunctionCodeGenerator, GeneratedFunction}
+import org.apache.flink.table.plan.schema.RowSchema
+import org.apache.flink.table.runtime.CRowKeySelector
+import org.apache.flink.table.runtime.join._
+import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
+import org.apache.flink.types.Row
+
+class DataStreamJoinToCoProcessTranslator(
+    config: TableConfig,
+    returnType: TypeInformation[Row],
+    leftSchema: RowSchema,
+    rightSchema: RowSchema,
+    joinInfo: JoinInfo,
+    rexBuilder: RexBuilder) {
+
+  val nonEquiJoinPredicates: Option[RexNode] = if (joinInfo.isEqui) {
+    None
+  }
+  else {
+    Some(joinInfo.getRemaining(rexBuilder))
+  }
+
+  def getLeftKeySelector(): CRowKeySelector = {
+    new CRowKeySelector(
+      joinInfo.leftKeys.toIntArray,
+      leftSchema.projectedTypeInfo(joinInfo.leftKeys.toIntArray))
+  }
+
+  def getRightKeySelector(): CRowKeySelector = {
+    new CRowKeySelector(
+      joinInfo.rightKeys.toIntArray,
+      rightSchema.projectedTypeInfo(joinInfo.rightKeys.toIntArray))
+  }
+
+  def getCoProcessFunction(
+      joinType: JoinRelType,
+      returnFieldNames: Seq[String],
+      ruleDescription: String,
+      queryConfig: StreamQueryConfig): CoProcessFunction[CRow, CRow, CRow] = {
+    // input must not be nullable, because the runtime join function will make sure
+    // the code-generated function won't process null inputs
+    val generator = new FunctionCodeGenerator(
+      config,
+      nullableInput = false,
+      leftSchema.typeInfo,
+      Some(rightSchema.typeInfo))
+    val conversion = generator.generateConverterResultExpression(
+      returnType,
+      returnFieldNames)
+
+    val body = if (nonEquiJoinPredicates.isEmpty) {
+      // only equality condition
+      s"""
+         |${conversion.code}
+         |${generator.collectorTerm}.collect(${conversion.resultTerm});
+         |""".stripMargin
+    } else {
+      val condition = generator.generateExpression(nonEquiJoinPredicates.get)
+      s"""
+         |${condition.code}
+         |if (${condition.resultTerm}) {
+         |  ${conversion.code}
+         |  ${generator.collectorTerm}.collect(${conversion.resultTerm});
+         |}
+         |""".stripMargin
+    }
+
+    val genFunction = generator.generateFunction(
+      ruleDescription,
+      classOf[FlatJoinFunction[Row, Row, Row]],
+      body,
+      returnType)
+
+    createCoProcessFunction(joinType, queryConfig, genFunction)
+  }
+
+  protected def createCoProcessFunction(
+    joinType: JoinRelType,
+    queryConfig: StreamQueryConfig,
+    genFunction: GeneratedFunction[FlatJoinFunction[Row, Row, Row], Row])
+    : CoProcessFunction[CRow, CRow, CRow] = {
+
+    joinType match {
+      case JoinRelType.INNER =>
+        new NonWindowInnerJoin(
+          leftSchema.typeInfo,
+          rightSchema.typeInfo,
+          CRowTypeInfo(returnType),
+          genFunction.name,
+          genFunction.code,
+          queryConfig)
+      case JoinRelType.LEFT | JoinRelType.RIGHT if joinInfo.isEqui =>
+        new NonWindowLeftRightJoin(
+          leftSchema.typeInfo,
+          rightSchema.typeInfo,
+          CRowTypeInfo(returnType),
+          genFunction.name,
+          genFunction.code,
+          joinType == JoinRelType.LEFT,
+          queryConfig)
+      case JoinRelType.LEFT | JoinRelType.RIGHT =>
+        new NonWindowLeftRightJoinWithNonEquiPredicates(
+          leftSchema.typeInfo,
+          rightSchema.typeInfo,
+          CRowTypeInfo(returnType),
+          genFunction.name,
+          genFunction.code,
+          joinType == JoinRelType.LEFT,
+          queryConfig)
+      case JoinRelType.FULL if joinInfo.isEqui =>
+        new NonWindowFullJoin(
+          leftSchema.typeInfo,
+          rightSchema.typeInfo,
+          CRowTypeInfo(returnType),
+          genFunction.name,
+          genFunction.code,
+          queryConfig)
+      case JoinRelType.FULL =>
+        new NonWindowFullJoinWithNonEquiPredicates(
+          leftSchema.typeInfo,
+          rightSchema.typeInfo,
+          CRowTypeInfo(returnType),
+          genFunction.name,
+          genFunction.code,
+          queryConfig)
+    }
+  }
+}