You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2016/12/16 15:47:16 UTC

[47/47] flink git commit: [FLINK-5255] [table] Generalize detection of single row inputs for DataSetSingleRowJoinRule.

[FLINK-5255] [table] Generalize detection of single row inputs for DataSetSingleRowJoinRule.

- Add support for projections and filters following a global aggregation.

This closes #3009.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/d5c7bf6a
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/d5c7bf6a
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/d5c7bf6a

Branch: refs/heads/master
Commit: d5c7bf6ac4807b718f5eb780520f74e11a794b74
Parents: cc34c14
Author: Alexander Shoshin <Al...@epam.com>
Authored: Wed Dec 14 13:59:08 2016 +0300
Committer: Fabian Hueske <fh...@apache.org>
Committed: Fri Dec 16 16:41:21 2016 +0100

----------------------------------------------------------------------
 .../dataSet/DataSetSingleRowJoinRule.scala      | 29 +++++++------
 .../flink/table/runtime/MapJoinLeftRunner.scala |  8 +++-
 .../table/runtime/MapJoinRightRunner.scala      |  8 +++-
 .../flink/table/runtime/MapSideJoinRunner.scala | 13 +++++-
 .../table/api/scala/batch/sql/JoinITCase.scala  | 14 +++++++
 .../api/scala/batch/sql/SingleRowJoinTest.scala | 43 ++++++++++++++++++++
 6 files changed, 97 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/d5c7bf6a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala
index 1f5c91a..dcd02d9 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala
@@ -23,7 +23,7 @@ import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall}
 import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rel.convert.ConverterRule
 import org.apache.calcite.rel.core.JoinRelType
-import org.apache.calcite.rel.logical.{LogicalAggregate, LogicalJoin}
+import org.apache.calcite.rel.logical._
 import org.apache.flink.table.plan.nodes.dataset.{DataSetConvention, DataSetSingleRowJoin}
 
 class DataSetSingleRowJoinRule
@@ -31,14 +31,13 @@ class DataSetSingleRowJoinRule
       classOf[LogicalJoin],
       Convention.NONE,
       DataSetConvention.INSTANCE,
-      "DataSetSingleRowCrossRule") {
+      "DataSetSingleRowJoinRule") {
 
   override def matches(call: RelOptRuleCall): Boolean = {
     val join = call.rel(0).asInstanceOf[LogicalJoin]
 
     if (isInnerJoin(join)) {
-      isGlobalAggregation(join.getRight.asInstanceOf[RelSubset].getOriginal) ||
-        isGlobalAggregation(join.getLeft.asInstanceOf[RelSubset].getOriginal)
+      isSingleRow(join.getRight) || isSingleRow(join.getLeft)
     } else {
       false
     }
@@ -48,13 +47,19 @@ class DataSetSingleRowJoinRule
     join.getJoinType == JoinRelType.INNER
   }
 
-  private def isGlobalAggregation(node: RelNode) = {
-    node.isInstanceOf[LogicalAggregate] &&
-      isSingleRow(node.asInstanceOf[LogicalAggregate])
-  }
-
-  private def isSingleRow(agg: LogicalAggregate) = {
-    agg.getGroupSet.isEmpty
+  /**
+    * Recursively checks if a [[RelNode]] returns at most a single row.
+    * Only a global aggregation, possibly followed by projections, filters, or calcs, qualifies.
+    */
+  private def isSingleRow(node: RelNode): Boolean = {
+    node match {
+      case ss: RelSubset => isSingleRow(ss.getOriginal)
+      case lp: LogicalProject => isSingleRow(lp.getInput)
+      case lf: LogicalFilter => isSingleRow(lf.getInput)
+      case lc: LogicalCalc => isSingleRow(lc.getInput)
+      case la: LogicalAggregate => la.getGroupSet.isEmpty
+      case _ => false
+    }
   }
 
   override def convert(rel: RelNode): RelNode = {
@@ -62,7 +67,7 @@ class DataSetSingleRowJoinRule
     val traitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)
     val dataSetLeftNode = RelOptRule.convert(join.getLeft, DataSetConvention.INSTANCE)
     val dataSetRightNode = RelOptRule.convert(join.getRight, DataSetConvention.INSTANCE)
-    val leftIsSingle = isGlobalAggregation(join.getLeft.asInstanceOf[RelSubset].getOriginal)
+    val leftIsSingle = isSingleRow(join.getLeft)
 
     new DataSetSingleRowJoin(
       rel.getCluster,

http://git-wip-us.apache.org/repos/asf/flink/blob/d5c7bf6a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala
index cf32404..644e855 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala
@@ -28,6 +28,10 @@ class MapJoinLeftRunner[IN1, IN2, OUT](
     broadcastSetName: String)
   extends MapSideJoinRunner[IN1, IN2, IN2, IN1, OUT](name, code, returnType, broadcastSetName) {
 
-  override def flatMap(multiInput: IN1, out: Collector[OUT]): Unit =
-    function.join(multiInput, singleInput, out)
+  override def flatMap(multiInput: IN1, out: Collector[OUT]): Unit = {
+    broadcastSet match {
+      case Some(singleInput) => function.join(multiInput, singleInput, out)
+      case None =>
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/d5c7bf6a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala
index c4bc0d1..eee38d1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala
@@ -28,6 +28,10 @@ class MapJoinRightRunner[IN1, IN2, OUT](
     broadcastSetName: String)
   extends MapSideJoinRunner[IN1, IN2, IN1, IN2, OUT](name, code, returnType, broadcastSetName) {
 
-  override def flatMap(multiInput: IN2, out: Collector[OUT]): Unit =
-    function.join(singleInput, multiInput, out)
+  override def flatMap(multiInput: IN2, out: Collector[OUT]): Unit = {
+    broadcastSet match {
+      case Some(singleInput) => function.join(singleInput, multiInput, out)
+      case None =>
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/d5c7bf6a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapSideJoinRunner.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapSideJoinRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapSideJoinRunner.scala
index f12590f..090e184 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapSideJoinRunner.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapSideJoinRunner.scala
@@ -37,14 +37,23 @@ abstract class MapSideJoinRunner[IN1, IN2, SINGLE_IN, MULTI_IN, OUT](
   val LOG = LoggerFactory.getLogger(this.getClass)
 
   protected var function: FlatJoinFunction[IN1, IN2, OUT] = _
-  protected var singleInput: SINGLE_IN = _
+  protected var broadcastSet: Option[SINGLE_IN] = _
 
   override def open(parameters: Configuration): Unit = {
     LOG.debug(s"Compiling FlatJoinFunction: $name \n\n Code:\n$code")
     val clazz = compile(getRuntimeContext.getUserCodeClassLoader, name, code)
     LOG.debug("Instantiating FlatJoinFunction.")
     function = clazz.newInstance()
-    singleInput = getRuntimeContext.getBroadcastVariable(broadcastSetName).get(0)
+    broadcastSet = retrieveBroadcastSet
+  }
+
+  private def retrieveBroadcastSet: Option[SINGLE_IN] = {
+    val broadcastSet = getRuntimeContext.getBroadcastVariable(broadcastSetName)
+    if (!broadcastSet.isEmpty) {
+      Option(broadcastSet.get(0))
+    } else {
+      Option.empty
+    }
   }
 
   override def getProducedType: TypeInformation[OUT] = returnType

http://git-wip-us.apache.org/repos/asf/flink/blob/d5c7bf6a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala
index 344428b..96beea5 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala
@@ -363,4 +363,18 @@ class JoinITCase(
     val result = tEnv.sql(sqlQuery1).collect()
     TestBaseUtils.compareResultAsText(result.asJava, expected)
   }
+
+  @Test
+  def testCrossJoinWithEmptySingleRowInput(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val table = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a1, 'a2, 'a3)
+    tEnv.registerTable("A", table)
+
+    val sqlQuery1 = "SELECT * FROM A CROSS JOIN (SELECT count(*) FROM A HAVING count(*) < 0)"
+    val result = tEnv.sql(sqlQuery1).count()
+
+    Assert.assertEquals(0, result)
+  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/d5c7bf6a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SingleRowJoinTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SingleRowJoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SingleRowJoinTest.scala
index ecc685d..27e3853 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SingleRowJoinTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SingleRowJoinTest.scala
@@ -27,6 +27,49 @@ import org.junit.Test
 class SingleRowJoinTest extends TableTestBase {
 
   @Test
+  def testSingleRowJoinWithCalcInput(): Unit = {
+    val util = batchTestUtil()
+    util.addTable[(Int, Int)]("A", 'a1, 'a2)
+
+    val query =
+      "SELECT a1, asum " +
+      "FROM A, (SELECT sum(a1) + sum(a2) AS asum FROM A)"
+
+    val expected =
+      binaryNode(
+        "DataSetSingleRowJoin",
+        unaryNode(
+          "DataSetCalc",
+          batchTableNode(0),
+          term("select", "a1")
+        ),
+        unaryNode(
+          "DataSetCalc",
+          unaryNode(
+            "DataSetAggregate",
+            unaryNode(
+              "DataSetUnion",
+              unaryNode(
+                "DataSetValues",
+                batchTableNode(0),
+                tuples(List(null, null)),
+                term("values", "a1", "a2")
+              ),
+              term("union","a1","a2")
+            ),
+            term("select", "SUM(a1) AS $f0", "SUM(a2) AS $f1")
+          ),
+          term("select", "+($f0, $f1) AS asum")
+        ),
+        term("where", "true"),
+        term("join", "a1", "asum"),
+        term("joinType", "NestedLoopJoin")
+      )
+
+    util.verifySql(query, expected)
+  }
+
+  @Test
   def testSingleRowEquiJoin(): Unit = {
     val util = batchTestUtil()
     util.addTable[(Int, String)]("A", 'a1, 'a2)