You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2016/11/29 14:42:57 UTC

[3/3] flink git commit: [FLINK-4832] [table] Fix global aggregation of empty tables (Count/Sum = 0).

[FLINK-4832] [table] Fix global aggregation of empty tables (Count/Sum = 0).

- The fix injects a union with an all-null record in front of the global (non-grouped) aggregation.

This closes #2840


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/ecfb5b5f
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/ecfb5b5f
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/ecfb5b5f

Branch: refs/heads/master
Commit: ecfb5b5f6fd6bf1555c7240d77dd9aca982f4416
Parents: 0bb6847
Author: Anton Mushin <an...@epam.com>
Authored: Mon Nov 21 15:49:41 2016 +0400
Committer: Fabian Hueske <fh...@apache.org>
Committed: Tue Nov 29 13:30:51 2016 +0100

----------------------------------------------------------------------
 .../api/table/plan/rules/FlinkRuleSets.scala    |   1 +
 .../rules/dataSet/DataSetAggregateRule.scala    |   6 +
 .../DataSetAggregateWithNullValuesRule.scala    |  96 +++++++
 .../scala/batch/sql/AggregationsITCase.scala    |  39 +++
 .../flink/api/table/AggregationTest.scala       | 261 +++++++++++++++++++
 .../flink/api/table/utils/TableTestBase.scala   |   9 +-
 6 files changed, 410 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
index 5653083..26c025e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
@@ -98,6 +98,7 @@ object FlinkRuleSets {
 
     // translate to Flink DataSet nodes
     DataSetAggregateRule.INSTANCE,
+    DataSetAggregateWithNullValuesRule.INSTANCE,
     DataSetCalcRule.INSTANCE,
     DataSetJoinRule.INSTANCE,
     DataSetScanRule.INSTANCE,

http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala
index 72ed27e..0311c48 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala
@@ -37,6 +37,12 @@ class DataSetAggregateRule
   override def matches(call: RelOptRuleCall): Boolean = {
     val agg: LogicalAggregate = call.rel(0).asInstanceOf[LogicalAggregate]
 
+    // Non-grouped aggregates need a null row attached to their input, so they
+    // are handled by DataSetAggregateWithNullValuesRule instead of this rule.
+    if (agg.getGroupSet.isEmpty) {
+      return false
+    }
+
     // check if we have distinct aggregates
     val distinctAggs = agg.getAggCallList.exists(_.isDistinct)
     if (distinctAggs) {

http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
new file mode 100644
index 0000000..54cb8d1
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.rules.dataSet
+
+import org.apache.calcite.plan._
+import scala.collection.JavaConversions._
+import com.google.common.collect.ImmutableList
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.calcite.rel.logical.{LogicalValues, LogicalUnion, LogicalAggregate}
+import org.apache.calcite.rex.RexLiteral
+import org.apache.flink.api.table._
+import org.apache.flink.api.table.plan.nodes.dataset.{DataSetAggregate, DataSetConvention}
+
+/**
+  * Converts a non-grouped [[LogicalAggregate]] into a [[DataSetAggregate]] whose input is
+  * the original input unioned with one row of null literals.
+  */
+class DataSetAggregateWithNullValuesRule
+  extends ConverterRule(
+    classOf[LogicalAggregate],
+    Convention.NONE,
+    DataSetConvention.INSTANCE,
+    "DataSetAggregateWithNullValuesRule")
+{
+
+  override def matches(call: RelOptRuleCall): Boolean = {
+    val agg: LogicalAggregate = call.rel(0).asInstanceOf[LogicalAggregate]
+
+    // Grouped aggregates must not get the extra null row;
+    // they are left to other rules, e.g. [[DataSetAggregateRule]].
+    if (!agg.getGroupSet.isEmpty) {
+      return false
+    }
+
+    // reject DISTINCT aggregates by throwing rather than by returning false
+    val distinctAggs = agg.getAggCallList.exists(_.isDistinct)
+    if (distinctAggs) {
+      throw TableException("DISTINCT aggregates are currently not supported.")
+    }
+
+    // reject GROUPING SETS (no sets, first set differs from the group set, or indicator set)
+    val groupSets = agg.getGroupSets.size() == 0 || agg.getGroupSets.get(0) != agg.getGroupSet
+    if (groupSets || agg.indicator) {
+      throw TableException("GROUPING SETS are currently not supported.")
+    }
+    !distinctAggs && !groupSets && !agg.indicator // always true here: the other cases threw above
+  }
+  // NOTE(review): the injected row is visible to all aggregates - confirm COUNT(*) ignores it
+  override def convert(rel: RelNode): RelNode = {
+    val agg: LogicalAggregate = rel.asInstanceOf[LogicalAggregate]
+    val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)
+    val cluster: RelOptCluster = rel.getCluster
+    // one null literal per input field, i.e. a single all-null row of the input's row type
+    val fieldTypes = agg.getInput.getRowType.getFieldList.map(_.getType)
+    val nullLiterals: ImmutableList[ImmutableList[RexLiteral]] =
+      ImmutableList.of(ImmutableList.copyOf[RexLiteral](
+        for (fieldType <- fieldTypes)
+          yield {
+            cluster.getRexBuilder.
+              makeLiteral(null, fieldType, false).asInstanceOf[RexLiteral]
+          }))
+    // UNION ALL (second argument all = true) of the null row and the original input
+    val logicalValues = LogicalValues.create(cluster, agg.getInput.getRowType, nullLiterals)
+    val logicalUnion = LogicalUnion.create(List(logicalValues, agg.getInput), true)
+    // aggregate over the union; row types and group set come from the original aggregate
+    new DataSetAggregate(
+      cluster,
+      traitSet,
+      RelOptRule.convert(logicalUnion, DataSetConvention.INSTANCE),
+      agg.getNamedAggCalls,
+      rel.getRowType,
+      agg.getInput.getRowType,
+      agg.getGroupSet.toArray
+    )
+  }
+}
+
+object DataSetAggregateWithNullValuesRule {
+  val INSTANCE: RelOptRule = new DataSetAggregateWithNullValuesRule // shared rule instance
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala
index 2dce751..35bb7dc 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala
@@ -258,4 +258,43 @@ class AggregationsITCase(
     // must fail. grouping sets are not supported
     tEnv.sql(sqlQuery).toDataSet[Row]
   }
+
+  @Test
+  def testAggregateEmptyDataSets(): Unit = {
+    // Aggregates over an empty input (a = 4 matches no row), with and without GROUP BY.
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+    // grouped aggregate over the empty input: expects no result row
+    val sqlQuery = "SELECT avg(a), sum(a), count(b) " +
+      "FROM MyTable where a = 4 group by a"
+    // global aggregate over the empty input: expects one row, "null,null,0"
+    val sqlQuery2 = "SELECT avg(a), sum(a), count(b) " +
+      "FROM MyTable where a = 4"
+    // global aggregate over both rows: expects "1,3,2"
+    val sqlQuery3 = "SELECT avg(a), sum(a), count(b) " +
+      "FROM MyTable"
+    // the only rows have a = 1 and a = 2
+    val ds = env.fromElements(
+      (1: Byte, 1: Short),
+      (2: Byte, 2: Short))
+      .toTable(tEnv, 'a, 'b)
+
+    tEnv.registerTable("MyTable", ds)
+
+    val result = tEnv.sql(sqlQuery)
+    val result2 = tEnv.sql(sqlQuery2)
+    val result3 = tEnv.sql(sqlQuery3)
+
+    val results = result.toDataSet[Row].collect()
+    val expected = Seq.empty
+    val results2 =  result2.toDataSet[Row].collect()
+    val expected2 = "null,null,0"
+    val results3 = result3.toDataSet[Row].collect()
+    val expected3 = "1,3,2"
+
+    assert(results.equals(expected),
+      "Empty result is expected for grouped set, but actual: " + results)
+    TestBaseUtils.compareResultAsText(results2.asJava, expected2)
+    TestBaseUtils.compareResultAsText(results3.asJava, expected3)
+  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/AggregationTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/AggregationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/AggregationTest.scala
new file mode 100644
index 0000000..6c9d2e8
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/AggregationTest.scala
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table
+
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.utils.TableTestBase
+import org.apache.flink.api.table.utils.TableTestUtil._
+import org.junit.Test
+
+/**
+  * Checks the batch plans of grouped and non-grouped aggregations (SQL and Table API).
+  */
+class AggregationTest extends TableTestBase {
+
+  @Test
+  def testAggregateQueryBatchSQL(): Unit = {
+    val util = batchTestUtil()
+    util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+    val sqlQuery = "SELECT avg(a), sum(b), count(c) FROM MyTable"
+    // non-grouped: a null-row DataSetValues and the scan are expected under a DataSetUnion
+    val setValues = unaryNode(
+      "DataSetValues",
+      batchTableNode(0),
+      tuples(List(null,null,null)),
+      term("values","a","b","c")
+    )
+    val union = unaryNode(
+      "DataSetUnion",
+      setValues,
+      term("union","a","b","c")
+    )
+
+    val aggregate = unaryNode(
+      "DataSetAggregate",
+      union,
+      term("select",
+        "AVG(a) AS EXPR$0",
+        "SUM(b) AS EXPR$1",
+        "COUNT(c) AS EXPR$2")
+    )
+    util.verifySql(sqlQuery, aggregate)
+  }
+
+  @Test
+  def testAggregateWithFilterQueryBatchSQL(): Unit = {
+    val util = batchTestUtil()
+    util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+    val sqlQuery = "SELECT avg(a), sum(b), count(c) FROM MyTable WHERE a = 1"
+
+    val calcNode = unaryNode(
+      "DataSetCalc",
+      batchTableNode(0),
+      term("select", "a", "b", "c"),
+      term("where", "=(a, 1)")
+    )
+    // here the null row is unioned with the output of the filter
+    val setValues =  unaryNode(
+        "DataSetValues",
+        calcNode,
+        tuples(List(null,null,null)),
+        term("values","a","b","c")
+    )
+
+    val union = unaryNode(
+      "DataSetUnion",
+      setValues,
+      term("union","a","b","c")
+    )
+
+    val aggregate = unaryNode(
+      "DataSetAggregate",
+      union,
+      term("select",
+        "AVG(a) AS EXPR$0",
+        "SUM(b) AS EXPR$1",
+        "COUNT(c) AS EXPR$2")
+    )
+    util.verifySql(sqlQuery, aggregate)
+  }
+
+  @Test
+  def testAggregateGroupQueryBatchSQL(): Unit = {
+    val util = batchTestUtil()
+    util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+    val sqlQuery = "SELECT avg(a), sum(b), count(c) FROM MyTable GROUP BY a"
+    // grouped: the aggregate reads the scan directly, no DataSetValues/DataSetUnion
+    val aggregate = unaryNode(
+        "DataSetAggregate",
+        batchTableNode(0),
+        term("groupBy", "a"),
+        term("select",
+          "a",
+          "AVG(a) AS EXPR$0",
+          "SUM(b) AS EXPR$1",
+          "COUNT(c) AS EXPR$2")
+    )
+    val expected = unaryNode(
+        "DataSetCalc",
+        aggregate,
+        term("select",
+          "EXPR$0",
+          "EXPR$1",
+          "EXPR$2")
+    )
+    util.verifySql(sqlQuery, expected)
+  }
+
+  @Test
+  def testAggregateGroupWithFilterQueryBatchSQL(): Unit = {
+    val util = batchTestUtil()
+    util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+    val sqlQuery = "SELECT avg(a), sum(b), count(c) FROM MyTable WHERE a = 1 GROUP BY a"
+
+    val calcNode = unaryNode(
+      "DataSetCalc",
+      batchTableNode(0),
+      term("select","a", "b", "c") ,
+      term("where","=(a, 1)")
+    )
+    // grouped: the aggregate reads the calc directly, no DataSetValues/DataSetUnion
+    val aggregate = unaryNode(
+        "DataSetAggregate",
+        calcNode,
+        term("groupBy", "a"),
+        term("select",
+          "a",
+          "AVG(a) AS EXPR$0",
+          "SUM(b) AS EXPR$1",
+          "COUNT(c) AS EXPR$2")
+    )
+    val expected = unaryNode(
+        "DataSetCalc",
+        aggregate,
+        term("select",
+          "EXPR$0",
+          "EXPR$1",
+          "EXPR$2")
+    )
+    util.verifySql(sqlQuery, expected)
+  }
+
+  @Test
+  def testAggregateGroupWithFilterTableApi(): Unit = {
+
+    val util = batchTestUtil()
+    val sourceTable = util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+    val resultTable = sourceTable.groupBy('a)
+      .select('a, 'a.avg, 'b.sum, 'c.count)
+      .where('a === 1)
+    // the where() written after the aggregation is expected as a calc below the aggregate
+    val calcNode = unaryNode(
+      "DataSetCalc",
+      batchTableNode(0),
+      term("select", "a", "b", "c"),
+      term("where", "=(a, 1)")
+    )
+
+    val expected = unaryNode(
+      "DataSetAggregate",
+      calcNode,
+      term("groupBy", "a"),
+      term("select",
+        "a",
+        "AVG(a) AS TMP_0",
+        "SUM(b) AS TMP_1",
+        "COUNT(c) AS TMP_2")
+    )
+
+    util.verifyTable(resultTable,expected)
+  }
+
+  @Test
+  def testAggregateTableApi(): Unit = {
+    val util = batchTestUtil()
+    val sourceTable = util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+    val resultTable = sourceTable.select('a.avg,'b.sum,'c.count)
+    // non-grouped: a null-row DataSetValues and the scan are expected under a DataSetUnion
+    val setValues = unaryNode(
+      "DataSetValues",
+      batchTableNode(0),
+      tuples(List(null,null,null)),
+      term("values","a","b","c")
+    )
+    val union = unaryNode(
+      "DataSetUnion",
+      setValues,
+      term("union","a","b","c")
+    )
+
+    val expected = unaryNode(
+      "DataSetAggregate",
+      union,
+      term("select",
+        "AVG(a) AS TMP_0",
+        "SUM(b) AS TMP_1",
+        "COUNT(c) AS TMP_2")
+    )
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testAggregateWithFilterTableApi(): Unit = {
+    val util = batchTestUtil()
+    val sourceTable = util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+    val resultTable = sourceTable.select('a,'b,'c).where('a === 1)
+      .select('a.avg,'b.sum,'c.count)
+    // non-grouped aggregate over a filtered table: the null row is unioned with the calc output
+    val calcNode = unaryNode(
+      "DataSetCalc",
+      batchTableNode(0),
+      term("select", "a", "b", "c"),
+      term("where", "=(a, 1)")
+    )
+
+    val setValues =  unaryNode(
+      "DataSetValues",
+      calcNode,
+      tuples(List(null,null,null)),
+      term("values","a","b","c")
+    )
+
+    val union = unaryNode(
+      "DataSetUnion",
+      setValues,
+      term("union","a","b","c")
+    )
+
+    val expected = unaryNode(
+      "DataSetAggregate",
+      union,
+      term("select",
+        "AVG(a) AS TMP_0",
+        "SUM(b) AS TMP_1",
+        "COUNT(c) AS TMP_2")
+    )
+
+    util.verifyTable(resultTable, expected)
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
index 2ea15a0..539bb61 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
@@ -70,20 +70,25 @@ object TableTestUtil {
   def unaryNode(node: String, input: String, term: String*): String = {
     s"""$node(${term.mkString(", ")})
        |$input
-       |""".stripMargin
+       |""".stripMargin.stripLineEnd
   }
 
   def binaryNode(node: String, left: String, right: String, term: String*): String = {
     s"""$node(${term.mkString(", ")})
        |$left
        |$right
-       |""".stripMargin
+       |""".stripMargin.stripLineEnd
   }
 
   def term(term: AnyRef, value: AnyRef*): String = {
     s"$term=[${value.mkString(", ")}]"
   }
 
+  def tuples(value:List[AnyRef]*): String={ // renders e.g. tuples=[[{ a, b }, { c, d }]]
+    val listValues = value.map( listValue => s"{ ${listValue.mkString(", ")} }")
+    term("tuples","[" + listValues.mkString(", ") + "]")
+  }
+
   def batchTableNode(idx: Int): String = {
     s"DataSetScan(table=[[_DataSetTable_$idx]])"
   }