You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2016/12/16 15:47:16 UTC
[47/47] flink git commit: [FLINK-5255] [table] Generalize detection of single row inputs for DataSetSingleRowJoinRule.
[FLINK-5255] [table] Generalize detection of single row inputs for DataSetSingleRowJoinRule.
- Add support for projections and filters following a global aggregation.
This closes #3009.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/d5c7bf6a
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/d5c7bf6a
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/d5c7bf6a
Branch: refs/heads/master
Commit: d5c7bf6ac4807b718f5eb780520f74e11a794b74
Parents: cc34c14
Author: Alexander Shoshin <Al...@epam.com>
Authored: Wed Dec 14 13:59:08 2016 +0300
Committer: Fabian Hueske <fh...@apache.org>
Committed: Fri Dec 16 16:41:21 2016 +0100
----------------------------------------------------------------------
.../dataSet/DataSetSingleRowJoinRule.scala | 29 +++++++------
.../flink/table/runtime/MapJoinLeftRunner.scala | 8 +++-
.../table/runtime/MapJoinRightRunner.scala | 8 +++-
.../flink/table/runtime/MapSideJoinRunner.scala | 13 +++++-
.../table/api/scala/batch/sql/JoinITCase.scala | 14 +++++++
.../api/scala/batch/sql/SingleRowJoinTest.scala | 43 ++++++++++++++++++++
6 files changed, 97 insertions(+), 18 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/d5c7bf6a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala
index 1f5c91a..dcd02d9 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala
@@ -23,7 +23,7 @@ import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.convert.ConverterRule
import org.apache.calcite.rel.core.JoinRelType
-import org.apache.calcite.rel.logical.{LogicalAggregate, LogicalJoin}
+import org.apache.calcite.rel.logical._
import org.apache.flink.table.plan.nodes.dataset.{DataSetConvention, DataSetSingleRowJoin}
class DataSetSingleRowJoinRule
@@ -31,14 +31,13 @@ class DataSetSingleRowJoinRule
classOf[LogicalJoin],
Convention.NONE,
DataSetConvention.INSTANCE,
- "DataSetSingleRowCrossRule") {
+ "DataSetSingleRowJoinRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val join = call.rel(0).asInstanceOf[LogicalJoin]
if (isInnerJoin(join)) {
- isGlobalAggregation(join.getRight.asInstanceOf[RelSubset].getOriginal) ||
- isGlobalAggregation(join.getLeft.asInstanceOf[RelSubset].getOriginal)
+ isSingleRow(join.getRight) || isSingleRow(join.getLeft)
} else {
false
}
@@ -48,13 +47,19 @@ class DataSetSingleRowJoinRule
join.getJoinType == JoinRelType.INNER
}
- private def isGlobalAggregation(node: RelNode) = {
- node.isInstanceOf[LogicalAggregate] &&
- isSingleRow(node.asInstanceOf[LogicalAggregate])
- }
-
- private def isSingleRow(agg: LogicalAggregate) = {
- agg.getGroupSet.isEmpty
+ /**
+ * Recursively checks if a [[RelNode]] returns at most a single row.
+ * Input must be a global aggregation possibly followed by projections or filters.
+ */
+ private def isSingleRow(node: RelNode): Boolean = {
+ node match {
+ case ss: RelSubset => isSingleRow(ss.getOriginal)
+ case lp: LogicalProject => isSingleRow(lp.getInput)
+ case lf: LogicalFilter => isSingleRow(lf.getInput)
+ case lc: LogicalCalc => isSingleRow(lc.getInput)
+ case la: LogicalAggregate => la.getGroupSet.isEmpty
+ case _ => false
+ }
}
override def convert(rel: RelNode): RelNode = {
@@ -62,7 +67,7 @@ class DataSetSingleRowJoinRule
val traitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)
val dataSetLeftNode = RelOptRule.convert(join.getLeft, DataSetConvention.INSTANCE)
val dataSetRightNode = RelOptRule.convert(join.getRight, DataSetConvention.INSTANCE)
- val leftIsSingle = isGlobalAggregation(join.getLeft.asInstanceOf[RelSubset].getOriginal)
+ val leftIsSingle = isSingleRow(join.getLeft)
new DataSetSingleRowJoin(
rel.getCluster,
http://git-wip-us.apache.org/repos/asf/flink/blob/d5c7bf6a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala
index cf32404..644e855 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala
@@ -28,6 +28,10 @@ class MapJoinLeftRunner[IN1, IN2, OUT](
broadcastSetName: String)
extends MapSideJoinRunner[IN1, IN2, IN2, IN1, OUT](name, code, returnType, broadcastSetName) {
- override def flatMap(multiInput: IN1, out: Collector[OUT]): Unit =
- function.join(multiInput, singleInput, out)
+ override def flatMap(multiInput: IN1, out: Collector[OUT]): Unit = {
+ broadcastSet match {
+ case Some(singleInput) => function.join(multiInput, singleInput, out)
+ case None =>
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/d5c7bf6a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala
index c4bc0d1..eee38d1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala
@@ -28,6 +28,10 @@ class MapJoinRightRunner[IN1, IN2, OUT](
broadcastSetName: String)
extends MapSideJoinRunner[IN1, IN2, IN1, IN2, OUT](name, code, returnType, broadcastSetName) {
- override def flatMap(multiInput: IN2, out: Collector[OUT]): Unit =
- function.join(singleInput, multiInput, out)
+ override def flatMap(multiInput: IN2, out: Collector[OUT]): Unit = {
+ broadcastSet match {
+ case Some(singleInput) => function.join(singleInput, multiInput, out)
+ case None =>
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/d5c7bf6a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapSideJoinRunner.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapSideJoinRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapSideJoinRunner.scala
index f12590f..090e184 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapSideJoinRunner.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapSideJoinRunner.scala
@@ -37,14 +37,23 @@ abstract class MapSideJoinRunner[IN1, IN2, SINGLE_IN, MULTI_IN, OUT](
val LOG = LoggerFactory.getLogger(this.getClass)
protected var function: FlatJoinFunction[IN1, IN2, OUT] = _
- protected var singleInput: SINGLE_IN = _
+ protected var broadcastSet: Option[SINGLE_IN] = _
override def open(parameters: Configuration): Unit = {
LOG.debug(s"Compiling FlatJoinFunction: $name \n\n Code:\n$code")
val clazz = compile(getRuntimeContext.getUserCodeClassLoader, name, code)
LOG.debug("Instantiating FlatJoinFunction.")
function = clazz.newInstance()
- singleInput = getRuntimeContext.getBroadcastVariable(broadcastSetName).get(0)
+ broadcastSet = retrieveBroadcastSet
+ }
+
+ private def retrieveBroadcastSet: Option[SINGLE_IN] = {
+ val broadcastSet = getRuntimeContext.getBroadcastVariable(broadcastSetName)
+ if (!broadcastSet.isEmpty) {
+ Option(broadcastSet.get(0))
+ } else {
+ Option.empty
+ }
}
override def getProducedType: TypeInformation[OUT] = returnType
http://git-wip-us.apache.org/repos/asf/flink/blob/d5c7bf6a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala
index 344428b..96beea5 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala
@@ -363,4 +363,18 @@ class JoinITCase(
val result = tEnv.sql(sqlQuery1).collect()
TestBaseUtils.compareResultAsText(result.asJava, expected)
}
+
+ @Test
+ def testCrossJoinWithEmptySingleRowInput(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+ val table = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a1, 'a2, 'a3)
+ tEnv.registerTable("A", table)
+
+ val sqlQuery1 = "SELECT * FROM A CROSS JOIN (SELECT count(*) FROM A HAVING count(*) < 0)"
+ val result = tEnv.sql(sqlQuery1).count()
+
+ Assert.assertEquals(0, result)
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/d5c7bf6a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SingleRowJoinTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SingleRowJoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SingleRowJoinTest.scala
index ecc685d..27e3853 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SingleRowJoinTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SingleRowJoinTest.scala
@@ -27,6 +27,49 @@ import org.junit.Test
class SingleRowJoinTest extends TableTestBase {
@Test
+ def testSingleRowJoinWithCalcInput(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Int)]("A", 'a1, 'a2)
+
+ val query =
+ "SELECT a1, asum " +
+ "FROM A, (SELECT sum(a1) + sum(a2) AS asum FROM A)"
+
+ val expected =
+ binaryNode(
+ "DataSetSingleRowJoin",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a1")
+ ),
+ unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetUnion",
+ unaryNode(
+ "DataSetValues",
+ batchTableNode(0),
+ tuples(List(null, null)),
+ term("values", "a1", "a2")
+ ),
+ term("union","a1","a2")
+ ),
+ term("select", "SUM(a1) AS $f0", "SUM(a2) AS $f1")
+ ),
+ term("select", "+($f0, $f1) AS asum")
+ ),
+ term("where", "true"),
+ term("join", "a1", "asum"),
+ term("joinType", "NestedLoopJoin")
+ )
+
+ util.verifySql(query, expected)
+ }
+
+ @Test
def testSingleRowEquiJoin(): Unit = {
val util = batchTestUtil()
util.addTable[(Int, String)]("A", 'a1, 'a2)