Posted to commits@flink.apache.org by fh...@apache.org on 2016/11/29 14:42:57 UTC
[3/3] flink git commit: [FLINK-4832] [table] Fix global aggregation of empty tables (Count/Sum = 0).
[FLINK-4832] [table] Fix global aggregation of empty tables (Count/Sum = 0).
- The fix injects a union with a null record before the global aggregation.
This closes #2840
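
For context, a minimal sketch of the behavior this change fixes (setup borrowed from the tests below; env, tEnv, and the registered table MyTable(a, b) are assumed):

    // Scala, as exercised in AggregationsITCase below
    val result = tEnv.sql("SELECT avg(a), sum(a), count(b) FROM MyTable WHERE a = 4")
    // Before this fix: the global aggregation over the fully filtered (empty)
    // input produced no row at all.
    // After this fix: it produces the single row "null,null,0", since AVG and
    // SUM of an empty input are NULL and COUNT is 0, matching SQL semantics.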
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/ecfb5b5f
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/ecfb5b5f
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/ecfb5b5f
Branch: refs/heads/master
Commit: ecfb5b5f6fd6bf1555c7240d77dd9aca982f4416
Parents: 0bb6847
Author: Anton Mushin <an...@epam.com>
Authored: Mon Nov 21 15:49:41 2016 +0400
Committer: Fabian Hueske <fh...@apache.org>
Committed: Tue Nov 29 13:30:51 2016 +0100
----------------------------------------------------------------------
.../api/table/plan/rules/FlinkRuleSets.scala | 1 +
.../rules/dataSet/DataSetAggregateRule.scala | 6 +
.../DataSetAggregateWithNullValuesRule.scala | 96 +++++++
.../scala/batch/sql/AggregationsITCase.scala | 39 +++
.../flink/api/table/AggregationTest.scala | 261 +++++++++++++++++++
.../flink/api/table/utils/TableTestBase.scala | 9 +-
6 files changed, 410 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
index 5653083..26c025e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala
@@ -98,6 +98,7 @@ object FlinkRuleSets {
// translate to Flink DataSet nodes
DataSetAggregateRule.INSTANCE,
+ DataSetAggregateWithNullValuesRule.INSTANCE,
DataSetCalcRule.INSTANCE,
DataSetJoinRule.INSTANCE,
DataSetScanRule.INSTANCE,
http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala
index 72ed27e..0311c48 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala
@@ -37,6 +37,12 @@ class DataSetAggregateRule
override def matches(call: RelOptRuleCall): Boolean = {
val agg: LogicalAggregate = call.rel(0).asInstanceOf[LogicalAggregate]
+ // Non-grouped aggregations need a null row attached to the source data,
+ // which is handled by DataSetAggregateWithNullValuesRule, so don't match here.
+ if (agg.getGroupSet.isEmpty) {
+ return false
+ }
+
// check if we have distinct aggregates
val distinctAggs = agg.getAggCallList.exists(_.isDistinct)
if (distinctAggs) {
http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
new file mode 100644
index 0000000..54cb8d1
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan.rules.dataSet
+
+import org.apache.calcite.plan._
+import scala.collection.JavaConversions._
+import com.google.common.collect.ImmutableList
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.calcite.rel.logical.{LogicalValues, LogicalUnion, LogicalAggregate}
+import org.apache.calcite.rex.RexLiteral
+import org.apache.flink.api.table._
+import org.apache.flink.api.table.plan.nodes.dataset.{DataSetAggregate, DataSetConvention}
+
+/**
+ * Rule that injects a [[Row]] of null values into the input of a [[DataSetAggregate]].
+ * It applies only to non-grouped aggregate queries.
+ */
+class DataSetAggregateWithNullValuesRule
+ extends ConverterRule(
+ classOf[LogicalAggregate],
+ Convention.NONE,
+ DataSetConvention.INSTANCE,
+ "DataSetAggregateWithNullValuesRule")
+{
+
+ override def matches(call: RelOptRuleCall): Boolean = {
+ val agg: LogicalAggregate = call.rel(0).asInstanceOf[LogicalAggregate]
+
+ // Grouped aggregations must not have a null row attached;
+ // they are handled by other rules, e.g. [[DataSetAggregateRule]].
+ if (!agg.getGroupSet.isEmpty) {
+ return false
+ }
+
+ // check if we have distinct aggregates
+ val distinctAggs = agg.getAggCallList.exists(_.isDistinct)
+ if (distinctAggs) {
+ throw TableException("DISTINCT aggregates are currently not supported.")
+ }
+
+ // check if we have grouping sets
+ val groupSets = agg.getGroupSets.size() == 0 || agg.getGroupSets.get(0) != agg.getGroupSet
+ if (groupSets || agg.indicator) {
+ throw TableException("GROUPING SETS are currently not supported.")
+ }
+ !distinctAggs && !groupSets && !agg.indicator
+ }
+
+ override def convert(rel: RelNode): RelNode = {
+ val agg: LogicalAggregate = rel.asInstanceOf[LogicalAggregate]
+ val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)
+ val cluster: RelOptCluster = rel.getCluster
+
+ val fieldTypes = agg.getInput.getRowType.getFieldList.map(_.getType)
+ val nullLiterals: ImmutableList[ImmutableList[RexLiteral]] =
+ ImmutableList.of(ImmutableList.copyOf[RexLiteral](
+ for (fieldType <- fieldTypes)
+ yield {
+ cluster.getRexBuilder.
+ makeLiteral(null, fieldType, false).asInstanceOf[RexLiteral]
+ }))
+
+ val logicalValues = LogicalValues.create(cluster, agg.getInput.getRowType, nullLiterals)
+ val logicalUnion = LogicalUnion.create(List(logicalValues, agg.getInput), true)
+
+ new DataSetAggregate(
+ cluster,
+ traitSet,
+ RelOptRule.convert(logicalUnion, DataSetConvention.INSTANCE),
+ agg.getNamedAggCalls,
+ rel.getRowType,
+ agg.getInput.getRowType,
+ agg.getGroupSet.toArray
+ )
+ }
+}
+
+object DataSetAggregateWithNullValuesRule {
+ val INSTANCE: RelOptRule = new DataSetAggregateWithNullValuesRule
+}
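
In effect, the rule plans a non-grouped aggregation over the union of the original input and a single all-null row: the null row contributes nothing to the aggregates (COUNT ignores NULL values) but guarantees that the aggregate always receives at least one record and therefore always emits a result row. The resulting plan shape, as asserted by the expected plans in AggregationTest below:

    DataSetAggregate(select=[AVG(a) AS EXPR$0, SUM(b) AS EXPR$1, COUNT(c) AS EXPR$2])
      DataSetUnion(union=[a, b, c])
        DataSetValues(tuples=[[{ null, null, null }]], values=[a, b, c])
          DataSetScan(table=[[_DataSetTable_0]])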
http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala
index 2dce751..35bb7dc 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala
@@ -258,4 +258,43 @@ class AggregationsITCase(
// must fail. grouping sets are not supported
tEnv.sql(sqlQuery).toDataSet[Row]
}
+
+ @Test
+ def testAggregateEmptyDataSets(): Unit = {
+
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+ val sqlQuery = "SELECT avg(a), sum(a), count(b) " +
+ "FROM MyTable where a = 4 group by a"
+
+ val sqlQuery2 = "SELECT avg(a), sum(a), count(b) " +
+ "FROM MyTable where a = 4"
+
+ val sqlQuery3 = "SELECT avg(a), sum(a), count(b) " +
+ "FROM MyTable"
+
+ val ds = env.fromElements(
+ (1: Byte, 1: Short),
+ (2: Byte, 2: Short))
+ .toTable(tEnv, 'a, 'b)
+
+ tEnv.registerTable("MyTable", ds)
+
+ val result = tEnv.sql(sqlQuery)
+ val result2 = tEnv.sql(sqlQuery2)
+ val result3 = tEnv.sql(sqlQuery3)
+
+ val results = result.toDataSet[Row].collect()
+ val expected = Seq.empty
+ val results2 = result2.toDataSet[Row].collect()
+ val expected2 = "null,null,0"
+ val results3 = result3.toDataSet[Row].collect()
+ val expected3 = "1,3,2"
+
+ assert(results.equals(expected),
+ "Expected an empty result for the grouped aggregation on empty input, but got: " + results)
+ TestBaseUtils.compareResultAsText(results2.asJava, expected2)
+ TestBaseUtils.compareResultAsText(results3.asJava, expected3)
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/AggregationTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/AggregationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/AggregationTest.scala
new file mode 100644
index 0000000..6c9d2e8
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/AggregationTest.scala
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table
+
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.utils.TableTestBase
+import org.apache.flink.api.table.utils.TableTestUtil._
+import org.junit.Test
+
+/**
+ * Tests for aggregate query plans.
+ */
+class AggregationTest extends TableTestBase {
+
+ @Test
+ def testAggregateQueryBatchSQL(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+ val sqlQuery = "SELECT avg(a), sum(b), count(c) FROM MyTable"
+
+ val setValues = unaryNode(
+ "DataSetValues",
+ batchTableNode(0),
+ tuples(List(null, null, null)),
+ term("values", "a", "b", "c")
+ )
+ val union = unaryNode(
+ "DataSetUnion",
+ setValues,
+ term("union","a","b","c")
+ )
+
+ val aggregate = unaryNode(
+ "DataSetAggregate",
+ union,
+ term("select",
+ "AVG(a) AS EXPR$0",
+ "SUM(b) AS EXPR$1",
+ "COUNT(c) AS EXPR$2")
+ )
+ util.verifySql(sqlQuery, aggregate)
+ }
+
+ @Test
+ def testAggregateWithFilterQueryBatchSQL(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+ val sqlQuery = "SELECT avg(a), sum(b), count(c) FROM MyTable WHERE a = 1"
+
+ val calcNode = unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "b", "c"),
+ term("where", "=(a, 1)")
+ )
+
+ val setValues = unaryNode(
+ "DataSetValues",
+ calcNode,
+ tuples(List(null, null, null)),
+ term("values", "a", "b", "c")
+ )
+
+ val union = unaryNode(
+ "DataSetUnion",
+ setValues,
+ term("union","a","b","c")
+ )
+
+ val aggregate = unaryNode(
+ "DataSetAggregate",
+ union,
+ term("select",
+ "AVG(a) AS EXPR$0",
+ "SUM(b) AS EXPR$1",
+ "COUNT(c) AS EXPR$2")
+ )
+ util.verifySql(sqlQuery, aggregate)
+ }
+
+ @Test
+ def testAggregateGroupQueryBatchSQL(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+ val sqlQuery = "SELECT avg(a), sum(b), count(c) FROM MyTable GROUP BY a"
+
+ val aggregate = unaryNode(
+ "DataSetAggregate",
+ batchTableNode(0),
+ term("groupBy", "a"),
+ term("select",
+ "a",
+ "AVG(a) AS EXPR$0",
+ "SUM(b) AS EXPR$1",
+ "COUNT(c) AS EXPR$2")
+ )
+ val expected = unaryNode(
+ "DataSetCalc",
+ aggregate,
+ term("select",
+ "EXPR$0",
+ "EXPR$1",
+ "EXPR$2")
+ )
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testAggregateGroupWithFilterQueryBatchSQL(): Unit = {
+ val util = batchTestUtil()
+ util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+ val sqlQuery = "SELECT avg(a), sum(b), count(c) FROM MyTable WHERE a = 1 GROUP BY a"
+
+ val calcNode = unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select","a", "b", "c") ,
+ term("where","=(a, 1)")
+ )
+
+ val aggregate = unaryNode(
+ "DataSetAggregate",
+ calcNode,
+ term("groupBy", "a"),
+ term("select",
+ "a",
+ "AVG(a) AS EXPR$0",
+ "SUM(b) AS EXPR$1",
+ "COUNT(c) AS EXPR$2")
+ )
+ val expected = unaryNode(
+ "DataSetCalc",
+ aggregate,
+ term("select",
+ "EXPR$0",
+ "EXPR$1",
+ "EXPR$2")
+ )
+ util.verifySql(sqlQuery, expected)
+ }
+
+ @Test
+ def testAggregateGroupWithFilterTableApi(): Unit = {
+
+ val util = batchTestUtil()
+ val sourceTable = util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+ val resultTable = sourceTable.groupBy('a)
+ .select('a, 'a.avg, 'b.sum, 'c.count)
+ .where('a === 1)
+
+ val calcNode = unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "b", "c"),
+ term("where", "=(a, 1)")
+ )
+
+ val expected = unaryNode(
+ "DataSetAggregate",
+ calcNode,
+ term("groupBy", "a"),
+ term("select",
+ "a",
+ "AVG(a) AS TMP_0",
+ "SUM(b) AS TMP_1",
+ "COUNT(c) AS TMP_2")
+ )
+
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test
+ def testAggregateTableApi(): Unit = {
+ val util = batchTestUtil()
+ val sourceTable = util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+ val resultTable = sourceTable.select('a.avg, 'b.sum, 'c.count)
+
+ val setValues = unaryNode(
+ "DataSetValues",
+ batchTableNode(0),
+ tuples(List(null, null, null)),
+ term("values", "a", "b", "c")
+ )
+ val union = unaryNode(
+ "DataSetUnion",
+ setValues,
+ term("union","a","b","c")
+ )
+
+ val expected = unaryNode(
+ "DataSetAggregate",
+ union,
+ term("select",
+ "AVG(a) AS TMP_0",
+ "SUM(b) AS TMP_1",
+ "COUNT(c) AS TMP_2")
+ )
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test
+ def testAggregateWithFilterTableApi(): Unit = {
+ val util = batchTestUtil()
+ val sourceTable = util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+ val resultTable = sourceTable.select('a, 'b, 'c).where('a === 1)
+ .select('a.avg, 'b.sum, 'c.count)
+
+ val calcNode = unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "b", "c"),
+ term("where", "=(a, 1)")
+ )
+
+ val setValues = unaryNode(
+ "DataSetValues",
+ calcNode,
+ tuples(List(null, null, null)),
+ term("values", "a", "b", "c")
+ )
+
+ val union = unaryNode(
+ "DataSetUnion",
+ setValues,
+ term("union","a","b","c")
+ )
+
+ val expected = unaryNode(
+ "DataSetAggregate",
+ union,
+ term("select",
+ "AVG(a) AS TMP_0",
+ "SUM(b) AS TMP_1",
+ "COUNT(c) AS TMP_2")
+ )
+
+ util.verifyTable(resultTable, expected)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
index 2ea15a0..539bb61 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
@@ -70,20 +70,25 @@ object TableTestUtil {
def unaryNode(node: String, input: String, term: String*): String = {
s"""$node(${term.mkString(", ")})
|$input
- |""".stripMargin
+ |""".stripMargin.stripLineEnd
}
def binaryNode(node: String, left: String, right: String, term: String*): String = {
s"""$node(${term.mkString(", ")})
|$left
|$right
- |""".stripMargin
+ |""".stripMargin.stripLineEnd
}
def term(term: AnyRef, value: AnyRef*): String = {
s"$term=[${value.mkString(", ")}]"
}
+ def tuples(value: List[AnyRef]*): String = {
+ val listValues = value.map(listValue => s"{ ${listValue.mkString(", ")} }")
+ term("tuples", "[" + listValues.mkString(", ") + "]")
+ }
+
def batchTableNode(idx: Int): String = {
s"DataSetScan(table=[[_DataSetTable_$idx]])"
}