You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by fh...@apache.org on 2016/12/15 10:49:57 UTC
[7/7] flink git commit: [FLINK-5266] [table] Inject projection of
unused fields before aggregations.
[FLINK-5266] [table] Inject projection of unused fields before aggregations.
This closes #2961.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/15e7f0a8
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/15e7f0a8
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/15e7f0a8
Branch: refs/heads/master
Commit: 15e7f0a8c7fd161d5847e7b2afae35b212ea23f0
Parents: 5dab934
Author: Kurt Young <yk...@gmail.com>
Authored: Thu Dec 8 10:35:43 2016 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Thu Dec 15 11:36:40 2016 +0100
----------------------------------------------------------------------
.../api/table/plan/ProjectionTranslator.scala | 105 ++++--
.../org/apache/flink/api/table/table.scala | 83 ++---
.../org/apache/flink/api/table/windows.scala | 2 +-
.../scala/stream/table/GroupWindowTest.scala | 120 +++++--
.../api/table/plan/FieldProjectionTest.scala | 317 +++++++++++++++++++
.../flink/api/table/utils/TableTestBase.scala | 4 +
6 files changed, 551 insertions(+), 80 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
index 22b77b4..a25c402 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
@@ -29,31 +29,19 @@ object ProjectionTranslator {
/**
* Extracts and deduplicates all aggregation and window property expressions (zero, one, or more)
- * from all expressions and replaces the original expressions by field accesses expressions.
+ * from the given expressions.
*
- * @param exprs a list of expressions to convert
+ * @param exprs a list of expressions to extract aggregations and window properties from
* @param tableEnv the TableEnvironment
- * @return a Tuple3, the first field contains the converted expressions, the second field the
- * extracted and deduplicated aggregations, and the third field the extracted and
- * deduplicated window properties.
+ * @return a Tuple2, the first field contains the extracted and deduplicated aggregations,
+ * and the second field contains the extracted and deduplicated window properties.
*/
def extractAggregationsAndProperties(
exprs: Seq[Expression],
- tableEnv: TableEnvironment)
- : (Seq[NamedExpression], Seq[NamedExpression], Seq[NamedExpression]) = {
-
- val (aggNames, propNames) =
- exprs.foldLeft( (Map[Expression, String](), Map[Expression, String]()) ) {
- (x, y) => identifyAggregationsAndProperties(y, tableEnv, x._1, x._2)
- }
-
- val replaced = exprs
- .map(replaceAggregationsAndProperties(_, tableEnv, aggNames, propNames))
- .map(UnresolvedAlias)
- val aggs = aggNames.map( a => Alias(a._1, a._2)).toSeq
- val props = propNames.map( p => Alias(p._1, p._2)).toSeq
-
- (replaced, aggs, props)
+ tableEnv: TableEnvironment): (Map[Expression, String], Map[Expression, String]) = {
+ exprs.foldLeft((Map[Expression, String](), Map[Expression, String]())) {
+ (x, y) => identifyAggregationsAndProperties(y, tableEnv, x._1, x._2)
+ }
}
/** Identifies and deduplicates aggregation functions and window properties. */
@@ -106,7 +94,24 @@ object ProjectionTranslator {
}
}
- /** Replaces aggregations and projections by named field references. */
+ /**
+ * Replaces aggregations and window properties in the expressions by named field references.
+ *
+ * @param exprs a list of expressions to replace
+ * @param tableEnv the TableEnvironment
+ * @param aggNames the deduplicated aggregations
+ * @param propNames the deduplicated properties
+ * @return a list of replaced expressions
+ */
+ def replaceAggregationsAndProperties(
+ exprs: Seq[Expression],
+ tableEnv: TableEnvironment,
+ aggNames: Map[Expression, String],
+ propNames: Map[Expression, String]): Seq[NamedExpression] = {
+ exprs.map(replaceAggregationsAndProperties(_, tableEnv, aggNames, propNames))
+ .map(UnresolvedAlias)
+ }
+
private def replaceAggregationsAndProperties(
exp: Expression,
tableEnv: TableEnvironment,
@@ -197,4 +202,62 @@ object ProjectionTranslator {
}
projectList
}
+
+ /**
+ * Extracts all field references from the given expressions.
+ *
+ * @param exprs a list of expressions to extract field references from
+ * @return a deduplicated list of field references extracted from the given expressions
+ */
+ def extractFieldReferences(exprs: Seq[Expression]): Seq[NamedExpression] = {
+ exprs.foldLeft(Set[NamedExpression]()) {
+ (fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
+ }.toSeq
+ }
+
+ private def identifyFieldReferences(
+ expr: Expression,
+ fieldReferences: Set[NamedExpression]): Set[NamedExpression] = expr match {
+
+ case f: UnresolvedFieldReference =>
+ fieldReferences + UnresolvedAlias(f)
+
+ case b: BinaryExpression =>
+ val l = identifyFieldReferences(b.left, fieldReferences)
+ identifyFieldReferences(b.right, l)
+
+ // Function calls
+ case c @ Call(name, args) =>
+ args.foldLeft(fieldReferences) {
+ (fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
+ }
+ case sfc @ ScalarFunctionCall(clazz, args) =>
+ args.foldLeft(fieldReferences) {
+ (fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
+ }
+
+ // array constructor
+ case c @ ArrayConstructor(args) =>
+ args.foldLeft(fieldReferences) {
+ (fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences)
+ }
+
+ // ignore fields from window property
+ case w : WindowProperty =>
+ fieldReferences
+
+ // keep this case after all unwanted unary expressions
+ case u: UnaryExpression =>
+ identifyFieldReferences(u.child, fieldReferences)
+
+ // General expression
+ case e: Expression =>
+ e.productIterator.foldLeft(fieldReferences) {
+ (fieldReferences, expr) => expr match {
+ case e: Expression => identifyFieldReferences(e, fieldReferences)
+ case _ => fieldReferences
+ }
+ }
+ }
+
}
http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
index b74ddb0..94c8e8c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
@@ -20,10 +20,9 @@ package org.apache.flink.api.table
import org.apache.calcite.rel.RelNode
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.operators.join.JoinType
-import org.apache.flink.api.table.plan.logical.Minus
-import org.apache.flink.api.table.expressions.{Alias, Asc, Call, Expression, ExpressionParser, Ordering, TableFunctionCall}
+import org.apache.flink.api.table.expressions._
import org.apache.flink.api.table.plan.ProjectionTranslator._
-import org.apache.flink.api.table.plan.logical._
+import org.apache.flink.api.table.plan.logical.{Minus, _}
import org.apache.flink.api.table.sinks.TableSink
import scala.collection.JavaConverters._
@@ -77,21 +76,27 @@ class Table(
* }}}
*/
def select(fields: Expression*): Table = {
-
val expandedFields = expandProjectList(fields, logicalPlan, tableEnv)
- val (projection, aggs, props) = extractAggregationsAndProperties(expandedFields, tableEnv)
-
- if (props.nonEmpty) {
+ val (aggNames, propNames) = extractAggregationsAndProperties(expandedFields, tableEnv)
+ if (propNames.nonEmpty) {
throw ValidationException("Window properties can only be used on windowed tables.")
}
- if (aggs.nonEmpty) {
+ if (aggNames.nonEmpty) {
+ val projectsOnAgg = replaceAggregationsAndProperties(
+ expandedFields, tableEnv, aggNames, propNames)
+ val projectFields = extractFieldReferences(expandedFields)
+
new Table(tableEnv,
- Project(projection,
- Aggregate(Nil, aggs, logicalPlan).validate(tableEnv)).validate(tableEnv))
+ Project(projectsOnAgg,
+ Aggregate(Nil, aggNames.map(a => Alias(a._1, a._2)).toSeq,
+ Project(projectFields, logicalPlan).validate(tableEnv)
+ ).validate(tableEnv)
+ ).validate(tableEnv)
+ )
} else {
new Table(tableEnv,
- Project(projection, logicalPlan).validate(tableEnv))
+ Project(expandedFields.map(UnresolvedAlias), logicalPlan).validate(tableEnv))
}
}
@@ -806,24 +811,21 @@ class GroupedTable(
* }}}
*/
def select(fields: Expression*): Table = {
-
- val (projection, aggs, props) = extractAggregationsAndProperties(fields, table.tableEnv)
-
- if (props.nonEmpty) {
+ val (aggNames, propNames) = extractAggregationsAndProperties(fields, table.tableEnv)
+ if (propNames.nonEmpty) {
throw ValidationException("Window properties can only be used on windowed tables.")
}
- val logical =
- Project(
- projection,
- Aggregate(
- groupKey,
- aggs,
- table.logicalPlan
- ).validate(table.tableEnv)
- ).validate(table.tableEnv)
+ val projectsOnAgg = replaceAggregationsAndProperties(
+ fields, table.tableEnv, aggNames, propNames)
+ val projectFields = extractFieldReferences(fields ++ groupKey)
- new Table(table.tableEnv, logical)
+ new Table(table.tableEnv,
+ Project(projectsOnAgg,
+ Aggregate(groupKey, aggNames.map(a => Alias(a._1, a._2)).toSeq,
+ Project(projectFields, table.logicalPlan).validate(table.tableEnv)
+ ).validate(table.tableEnv)
+ ).validate(table.tableEnv))
}
/**
@@ -877,24 +879,29 @@ class GroupWindowedTable(
* }}}
*/
def select(fields: Expression*): Table = {
+ val (aggNames, propNames) = extractAggregationsAndProperties(fields, table.tableEnv)
+ val projectsOnAgg = replaceAggregationsAndProperties(
+ fields, table.tableEnv, aggNames, propNames)
+
+ val projectFields = (table.tableEnv, window) match {
+ // event time can be an arbitrary field in a batch environment
+ case (_: BatchTableEnvironment, w: EventTimeWindow) =>
+ extractFieldReferences(fields ++ groupKey ++ Seq(w.timeField))
+ case (_, _) =>
+ extractFieldReferences(fields ++ groupKey)
+ }
- val (projection, aggs, props) = extractAggregationsAndProperties(fields, table.tableEnv)
-
- val groupWindow = window.toLogicalWindow
-
- val logical =
+ new Table(table.tableEnv,
Project(
- projection,
+ projectsOnAgg,
WindowAggregate(
groupKey,
- groupWindow,
- props,
- aggs,
- table.logicalPlan
+ window.toLogicalWindow,
+ propNames.map(a => Alias(a._1, a._2)).toSeq,
+ aggNames.map(a => Alias(a._1, a._2)).toSeq,
+ Project(projectFields, table.logicalPlan).validate(table.tableEnv)
).validate(table.tableEnv)
- ).validate(table.tableEnv)
-
- new Table(table.tableEnv, logical)
+ ).validate(table.tableEnv))
}
/**
http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/windows.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/windows.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/windows.scala
index 32d67d7..5637d7a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/windows.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/windows.scala
@@ -48,7 +48,7 @@ trait GroupWindow {
* @param timeField defines the time mode for streaming tables. For batch table it defines the
* time attribute on which is grouped.
*/
-abstract class EventTimeWindow(timeField: Expression) extends GroupWindow {
+abstract class EventTimeWindow(val timeField: Expression) extends GroupWindow {
protected var name: Option[Expression] = None
http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/GroupWindowTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/GroupWindowTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/GroupWindowTest.scala
index b59b151..9c2d6b3 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/GroupWindowTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/GroupWindowTest.scala
@@ -164,7 +164,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "string", "int")
+ ),
term("groupBy", "string"),
term("window", ProcessingTimeTumblingGroupWindow(None, 50.milli)),
term("select", "string", "COUNT(int) AS TMP_0")
@@ -185,7 +189,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "string", "int")
+ ),
term("groupBy", "string"),
term("window", ProcessingTimeTumblingGroupWindow(None, 2.rows)),
term("select", "string", "COUNT(int) AS TMP_0")
@@ -206,7 +214,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "string", "int")
+ ),
term("groupBy", "string"),
term("window", EventTimeTumblingGroupWindow(None, RowtimeAttribute(), 5.milli)),
term("select", "string", "COUNT(int) AS TMP_0")
@@ -249,7 +261,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "string", "int")
+ ),
term("groupBy", "string"),
term("window", ProcessingTimeSlidingGroupWindow(None, 50.milli, 50.milli)),
term("select", "string", "COUNT(int) AS TMP_0")
@@ -270,7 +286,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "string", "int")
+ ),
term("groupBy", "string"),
term("window", ProcessingTimeSlidingGroupWindow(None, 2.rows, 1.rows)),
term("select", "string", "COUNT(int) AS TMP_0")
@@ -291,7 +311,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "string", "int")
+ ),
term("groupBy", "string"),
term("window", EventTimeSlidingGroupWindow(None, RowtimeAttribute(), 8.milli, 10.milli)),
term("select", "string", "COUNT(int) AS TMP_0")
@@ -334,7 +358,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "string", "int")
+ ),
term("groupBy", "string"),
term("window", EventTimeSessionGroupWindow(None, RowtimeAttribute(), 7.milli)),
term("select", "string", "COUNT(int) AS TMP_0")
@@ -355,7 +383,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "string", "int")
+ ),
term("groupBy", "string"),
term("window", ProcessingTimeTumblingGroupWindow(None, 50.milli)),
term("select", "string", "COUNT(int) AS TMP_0")
@@ -375,7 +407,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "int")
+ ),
term("window", ProcessingTimeTumblingGroupWindow(None, 2.rows)),
term("select", "COUNT(int) AS TMP_0")
)
@@ -394,7 +430,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "int")
+ ),
term("window", EventTimeTumblingGroupWindow(None, RowtimeAttribute(), 5.milli)),
term("select", "COUNT(int) AS TMP_0")
)
@@ -414,7 +454,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "int")
+ ),
term("window", EventTimeTumblingGroupWindow(None, RowtimeAttribute(), 2.rows)),
term("select", "COUNT(int) AS TMP_0")
)
@@ -434,7 +478,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "int")
+ ),
term("window", ProcessingTimeSlidingGroupWindow(None, 50.milli, 50.milli)),
term("select", "COUNT(int) AS TMP_0")
)
@@ -453,7 +501,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "int")
+ ),
term("window", ProcessingTimeSlidingGroupWindow(None, 2.rows, 1.rows)),
term("select", "COUNT(int) AS TMP_0")
)
@@ -472,7 +524,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "int")
+ ),
term("window", EventTimeSlidingGroupWindow(None, RowtimeAttribute(), 8.milli, 10.milli)),
term("select", "COUNT(int) AS TMP_0")
)
@@ -492,7 +548,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "int")
+ ),
term("window", EventTimeSlidingGroupWindow(None, RowtimeAttribute(), 2.rows, 1.rows)),
term("select", "COUNT(int) AS TMP_0")
)
@@ -511,7 +571,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "int")
+ ),
term("window", EventTimeSessionGroupWindow(None, RowtimeAttribute(), 7.milli)),
term("select", "COUNT(int) AS TMP_0")
)
@@ -531,7 +595,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "string", "int")
+ ),
term("groupBy", "string"),
term("window",
EventTimeTumblingGroupWindow(
@@ -560,7 +628,11 @@ class GroupWindowTest extends TableTestBase {
val expected = unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "string", "int")
+ ),
term("groupBy", "string"),
term("window",
EventTimeSlidingGroupWindow(
@@ -592,7 +664,11 @@ class GroupWindowTest extends TableTestBase {
"DataStreamCalc",
unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "string", "int")
+ ),
term("groupBy", "string"),
term("window",
EventTimeSessionGroupWindow(
@@ -626,7 +702,11 @@ class GroupWindowTest extends TableTestBase {
"DataStreamCalc",
unaryNode(
"DataStreamAggregate",
- streamTableNode(0),
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "string", "int")
+ ),
term("groupBy", "string"),
term("window",
EventTimeTumblingGroupWindow(
http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/FieldProjectionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/FieldProjectionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/FieldProjectionTest.scala
new file mode 100644
index 0000000..1cefb8a
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/FieldProjectionTest.scala
@@ -0,0 +1,317 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan
+
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.ValidationException
+import org.apache.flink.api.table.expressions.{RowtimeAttribute, Upper, WindowReference}
+import org.apache.flink.api.table.functions.ScalarFunction
+import org.apache.flink.api.table.plan.FieldProjectionTest._
+import org.apache.flink.api.table.plan.logical.EventTimeTumblingGroupWindow
+import org.apache.flink.api.table.utils.TableTestBase
+import org.apache.flink.api.table.utils.TableTestUtil._
+import org.junit.Test
+
+/**
+ * Tests for all the situations in which field projection can be applied, such as selecting
+ * a few fields from a source with a large number of fields.
+ */
+class FieldProjectionTest extends TableTestBase {
+
+ val util = batchTestUtil()
+
+ val streamUtil = streamTestUtil()
+
+ @Test
+ def testSimpleSelect(): Unit = {
+ val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+ val resultTable = sourceTable.select('a, 'b)
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "b")
+ )
+
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test
+ def testSelectAllFields(): Unit = {
+ val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+ val resultTable1 = sourceTable.select('*)
+ val resultTable2 = sourceTable.select('a, 'b, 'c, 'd)
+
+ val expected = batchTableNode(0)
+
+ util.verifyTable(resultTable1, expected)
+ util.verifyTable(resultTable2, expected)
+ }
+
+ @Test
+ def testSelectAggregation(): Unit = {
+ val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+ val resultTable = sourceTable.select('a.sum, 'b.max)
+
+ val expected = unaryNode(
+ "DataSetAggregate",
+ binaryNode(
+ "DataSetUnion",
+ values(
+ "DataSetValues",
+ tuples(List(null, null)),
+ term("values", "a", "b")
+ ),
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "b")
+ ),
+ term("union", "a", "b")
+ ),
+ term("select", "SUM(a) AS TMP_0", "MAX(b) AS TMP_1")
+ )
+
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test
+ def testSelectFunction(): Unit = {
+ val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+
+ util.tEnv.registerFunction("hashCode", MyHashCode)
+
+ val resultTable = sourceTable.select("hashCode(c), b")
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", s"${MyHashCode.getClass.getCanonicalName}(c) AS _c0", "b")
+ )
+
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test
+ def testSelectFromGroupedTable(): Unit = {
+ val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+ val resultTable = sourceTable.groupBy('a, 'c).select('a)
+
+ val expected = unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "c")
+ ),
+ term("groupBy", "a", "c"),
+ term("select", "a", "c")
+ ),
+ term("select", "a")
+ )
+
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test
+ def testSelectAllFieldsFromGroupedTable(): Unit = {
+ val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+ val resultTable = sourceTable.groupBy('a, 'c).select('a, 'c)
+
+ val expected = unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "c")
+ ),
+ term("groupBy", "a", "c"),
+ term("select", "a", "c")
+ )
+
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test
+ def testSelectAggregationFromGroupedTable(): Unit = {
+ val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+ val resultTable = sourceTable.groupBy('c).select('a.sum)
+
+ val expected =
+ unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "c")
+ ),
+ term("groupBy", "c"),
+ term("select", "c", "SUM(a) AS TMP_0")
+ ),
+ term("select", "TMP_0 AS TMP_1")
+ )
+
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test
+ def testSelectFromGroupedTableWithNonTrivialKey(): Unit = {
+ val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+ val resultTable = sourceTable.groupBy(Upper('c) as 'k).select('a.sum)
+
+ val expected =
+ unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "c", "UPPER(c) AS k")
+ ),
+ term("groupBy", "k"),
+ term("select", "k", "SUM(a) AS TMP_0")
+ ),
+ term("select", "TMP_0 AS TMP_1")
+ )
+
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test
+ def testSelectFromGroupedTableWithFunctionKey(): Unit = {
+ val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+ val resultTable = sourceTable.groupBy(MyHashCode('c) as 'k).select('a.sum)
+
+ val expected =
+ unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "a", "c", s"${MyHashCode.getClass.getCanonicalName}(c) AS k")
+ ),
+ term("groupBy", "k"),
+ term("select", "k", "SUM(a) AS TMP_0")
+ ),
+ term("select", "TMP_0 AS TMP_1")
+ )
+
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test
+ def testSelectFromStreamingWindow(): Unit = {
+ val sourceTable = streamUtil.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+ val resultTable = sourceTable
+ .window(Tumble over 5.millis on 'rowtime as 'w)
+ .select(Upper('c).count, 'a.sum)
+
+ val expected =
+ unaryNode(
+ "DataStreamAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "c", "a", "UPPER(c) AS $f2")
+ ),
+ term("window",
+ EventTimeTumblingGroupWindow(
+ Some(WindowReference("w")),
+ RowtimeAttribute(),
+ 5.millis)),
+ term("select", "COUNT($f2) AS TMP_0", "SUM(a) AS TMP_1")
+ )
+
+ streamUtil.verifyTable(resultTable, expected)
+ }
+
+ @Test
+ def testSelectFromStreamingGroupedWindow(): Unit = {
+ val sourceTable = streamUtil.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+ val resultTable = sourceTable
+ .groupBy('b)
+ .window(Tumble over 5.millis on 'rowtime as 'w)
+ .select(Upper('c).count, 'a.sum, 'b)
+
+ val expected = unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(0),
+ term("select", "c", "a", "b", "UPPER(c) AS $f3")
+ ),
+ term("groupBy", "b"),
+ term("window",
+ EventTimeTumblingGroupWindow(
+ Some(WindowReference("w")),
+ RowtimeAttribute(),
+ 5.millis)),
+ term("select", "b", "COUNT($f3) AS TMP_0", "SUM(a) AS TMP_1")
+ ),
+ term("select", "TMP_0 AS TMP_2", "TMP_1 AS TMP_3", "b")
+ )
+
+ streamUtil.verifyTable(resultTable, expected)
+ }
+
+ @Test(expected = classOf[ValidationException])
+ def testSelectFromBatchWindow1(): Unit = {
+ val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+
+ // time field is selected
+ val resultTable = sourceTable
+ .window(Tumble over 5.millis on 'a as 'w)
+ .select('a.sum, 'c.count)
+
+ val expected = "TODO"
+
+ util.verifyTable(resultTable, expected)
+ }
+
+ @Test(expected = classOf[ValidationException])
+ def testSelectFromBatchWindow2(): Unit = {
+ val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+
+ // time field is not selected
+ val resultTable = sourceTable
+ .window(Tumble over 5.millis on 'a as 'w)
+ .select('c.count)
+
+ val expected = "TODO"
+
+ util.verifyTable(resultTable, expected)
+ }
+}
+
+object FieldProjectionTest {
+
+ object MyHashCode extends ScalarFunction {
+ def eval(s: String): Int = s.hashCode()
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
index 4eaba90..b281dfc 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
@@ -91,6 +91,10 @@ object TableTestUtil {
|""".stripMargin.stripLineEnd
}
+ def values(node: String, term: String*): String = {
+ s"$node(${term.mkString(", ")})"
+ }
+
def term(term: AnyRef, value: AnyRef*): String = {
s"$term=[${value.mkString(", ")}]"
}