You are viewing a plain text version of this content. The canonical link for it was a hyperlink ("here") that is not preserved in this text version.
Posted to commits@flink.apache.org by fh...@apache.org on 2016/12/15 10:49:57 UTC

[7/7] flink git commit: [FLINK-5266] [table] Inject projection of unused fields before aggregations.

[FLINK-5266] [table] Inject projection of unused fields before aggregations.

This closes #2961.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/15e7f0a8
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/15e7f0a8
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/15e7f0a8

Branch: refs/heads/master
Commit: 15e7f0a8c7fd161d5847e7b2afae35b212ea23f0
Parents: 5dab934
Author: Kurt Young <yk...@gmail.com>
Authored: Thu Dec 8 10:35:43 2016 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Thu Dec 15 11:36:40 2016 +0100

----------------------------------------------------------------------
 .../api/table/plan/ProjectionTranslator.scala   | 105 ++++--
 .../org/apache/flink/api/table/table.scala      |  83 ++---
 .../org/apache/flink/api/table/windows.scala    |   2 +-
 .../scala/stream/table/GroupWindowTest.scala    | 120 +++++--
 .../api/table/plan/FieldProjectionTest.scala    | 317 +++++++++++++++++++
 .../flink/api/table/utils/TableTestBase.scala   |   4 +
 6 files changed, 551 insertions(+), 80 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
index 22b77b4..a25c402 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
@@ -29,31 +29,19 @@ object ProjectionTranslator {
 
   /**
     * Extracts and deduplicates all aggregation and window property expressions (zero, one, or more)
-    * from all expressions and replaces the original expressions by field accesses expressions.
+    * from the given expressions, mapping each extracted expression to a name.
     *
-    * @param exprs a list of expressions to convert
+    * @param exprs    the expressions to scan for aggregations and window properties
     * @param tableEnv the TableEnvironment
-    * @return a Tuple3, the first field contains the converted expressions, the second field the
-    *         extracted and deduplicated aggregations, and the third field the extracted and
-    *         deduplicated window properties.
+    * @return a Tuple2 of maps from each deduplicated aggregation (first field) and each
+    *         deduplicated window property (second field) to the name assigned to it.
     */
   def extractAggregationsAndProperties(
       exprs: Seq[Expression],
-      tableEnv: TableEnvironment)
-  : (Seq[NamedExpression], Seq[NamedExpression], Seq[NamedExpression]) = {
-
-    val (aggNames, propNames) =
-      exprs.foldLeft( (Map[Expression, String](), Map[Expression, String]()) ) {
-        (x, y) => identifyAggregationsAndProperties(y, tableEnv, x._1, x._2)
-      }
-
-    val replaced = exprs
-      .map(replaceAggregationsAndProperties(_, tableEnv, aggNames, propNames))
-      .map(UnresolvedAlias)
-    val aggs = aggNames.map( a => Alias(a._1, a._2)).toSeq
-    val props = propNames.map( p => Alias(p._1, p._2)).toSeq
-
-    (replaced, aggs, props)
+      tableEnv: TableEnvironment): (Map[Expression, String], Map[Expression, String]) = {
+    exprs.foldLeft((Map[Expression, String](), Map[Expression, String]())) {
+      (x, y) => identifyAggregationsAndProperties(y, tableEnv, x._1, x._2)
+    }
   }
 
   /** Identifies and deduplicates aggregation functions and window properties. */
@@ -106,7 +94,24 @@ object ProjectionTranslator {
     }
   }
 
-  /** Replaces aggregations and projections by named field references. */
+  /**
+    * Replaces the aggregations and window properties inside each expression by their names.
+    *
+    * @param exprs     the expressions to rewrite
+    * @param tableEnv  the TableEnvironment
+    * @param aggNames  deduplicated aggregations, each mapped to its name
+    * @param propNames deduplicated window properties, each mapped to its name
+    * @return the rewritten expressions, each wrapped in an [[UnresolvedAlias]]
+    */
+  def replaceAggregationsAndProperties(
+      exprs: Seq[Expression],
+      tableEnv: TableEnvironment,
+      aggNames: Map[Expression, String],
+      propNames: Map[Expression, String]): Seq[NamedExpression] =
+    exprs.map { expr =>
+      UnresolvedAlias(replaceAggregationsAndProperties(expr, tableEnv, aggNames, propNames))
+    }
+
   private def replaceAggregationsAndProperties(
       exp: Expression,
       tableEnv: TableEnvironment,
@@ -197,4 +202,62 @@ object ProjectionTranslator {
     }
     projectList
   }
+
+  /**
+    * Collects the distinct field references used by the given expressions.
+    *
+    * @param exprs the expressions to scan
+    * @return the referenced fields, each wrapped in an [[UnresolvedAlias]]; the order is
+    *         that of the underlying Set and not necessarily the order of appearance
+    */
+  def extractFieldReferences(exprs: Seq[Expression]): Seq[NamedExpression] =
+    exprs
+      .foldLeft(Set.empty[NamedExpression])((refs, expr) => identifyFieldReferences(expr, refs))
+      .toSeq
+
+  /**
+    * Recursively adds the field references used by `expr` to `fieldReferences`.
+    * Field references inside window properties are ignored.
+    */
+  private def identifyFieldReferences(
+      expr: Expression,
+      fieldReferences: Set[NamedExpression]): Set[NamedExpression] = expr match {
+
+    case f: UnresolvedFieldReference =>
+      fieldReferences + UnresolvedAlias(f)
+
+    case b: BinaryExpression =>
+      identifyFieldReferences(b.right, identifyFieldReferences(b.left, fieldReferences))
+
+    // function calls and array constructors hold their operands in a Seq
+    case Call(_, args) => identifyFieldReferencesIn(args, fieldReferences)
+    case ScalarFunctionCall(_, args) => identifyFieldReferencesIn(args, fieldReferences)
+    case ArrayConstructor(args) => identifyFieldReferencesIn(args, fieldReferences)
+
+    // ignore fields from window property
+    case _: WindowProperty =>
+      fieldReferences
+
+    // keep this case after all unwanted unary expressions
+    case u: UnaryExpression =>
+      identifyFieldReferences(u.child, fieldReferences)
+
+    // general expression: visit children, including children held in collections
+    case e: Expression =>
+      e.productIterator.foldLeft(fieldReferences) {
+        case (acc, child: Expression) => identifyFieldReferences(child, acc)
+        case (acc, children: Traversable[_]) => identifyFieldReferencesIn(children, acc)
+        case (acc, _) => acc
+      }
+  }
+
+  /** Adds the field references of every [[Expression]] in `items`; other items are skipped. */
+  private def identifyFieldReferencesIn(
+      items: Traversable[_],
+      fieldReferences: Set[NamedExpression]): Set[NamedExpression] =
+    items.foldLeft(fieldReferences) {
+      case (acc, child: Expression) => identifyFieldReferences(child, acc)
+      case (acc, _) => acc
+    }
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
index b74ddb0..94c8e8c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala
@@ -20,10 +20,9 @@ package org.apache.flink.api.table
 import org.apache.calcite.rel.RelNode
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.operators.join.JoinType
-import org.apache.flink.api.table.plan.logical.Minus
-import org.apache.flink.api.table.expressions.{Alias, Asc, Call, Expression, ExpressionParser, Ordering, TableFunctionCall}
+import org.apache.flink.api.table.expressions._
 import org.apache.flink.api.table.plan.ProjectionTranslator._
-import org.apache.flink.api.table.plan.logical._
+import org.apache.flink.api.table.plan.logical.{Minus, _}
 import org.apache.flink.api.table.sinks.TableSink
 
 import scala.collection.JavaConverters._
@@ -77,21 +76,27 @@ class Table(
     * }}}
     */
   def select(fields: Expression*): Table = {
-
     val expandedFields = expandProjectList(fields, logicalPlan, tableEnv)
-    val (projection, aggs, props) = extractAggregationsAndProperties(expandedFields, tableEnv)
-
-    if (props.nonEmpty) {
+    val (aggNames, propNames) = extractAggregationsAndProperties(expandedFields, tableEnv)
+    if (propNames.nonEmpty) {
       throw ValidationException("Window properties can only be used on windowed tables.")
     }
 
-    if (aggs.nonEmpty) {
+    if (aggNames.nonEmpty) {
+      val projectsOnAgg = replaceAggregationsAndProperties(
+        expandedFields, tableEnv, aggNames, propNames)
+      val projectFields = extractFieldReferences(expandedFields)
+      // inject a Project below the Aggregate so that unreferenced fields are pruned early
       new Table(tableEnv,
-        Project(projection,
-          Aggregate(Nil, aggs, logicalPlan).validate(tableEnv)).validate(tableEnv))
+        Project(projectsOnAgg,
+          Aggregate(Nil, aggNames.map(a => Alias(a._1, a._2)).toSeq,
+            Project(projectFields, logicalPlan).validate(tableEnv)
+          ).validate(tableEnv)
+        ).validate(tableEnv)
+      )
     } else {
       new Table(tableEnv,
-        Project(projection, logicalPlan).validate(tableEnv))
+        Project(expandedFields.map(UnresolvedAlias), logicalPlan).validate(tableEnv))
     }
   }
 
@@ -806,24 +811,21 @@ class GroupedTable(
     * }}}
     */
   def select(fields: Expression*): Table = {
-
-    val (projection, aggs, props) = extractAggregationsAndProperties(fields, table.tableEnv)
-
-    if (props.nonEmpty) {
+    val (aggNames, propNames) = extractAggregationsAndProperties(fields, table.tableEnv)
+    if (propNames.nonEmpty) {
       throw ValidationException("Window properties can only be used on windowed tables.")
     }
 
-    val logical =
-      Project(
-        projection,
-        Aggregate(
-          groupKey,
-          aggs,
-          table.logicalPlan
-        ).validate(table.tableEnv)
-      ).validate(table.tableEnv)
+    val projectsOnAgg = replaceAggregationsAndProperties(
+      fields, table.tableEnv, aggNames, propNames)
+    val projectFields = extractFieldReferences(fields ++ groupKey)  // keys need their fields too
 
-    new Table(table.tableEnv, logical)
+    new Table(table.tableEnv,
+      Project(projectsOnAgg,
+        Aggregate(groupKey, aggNames.map(a => Alias(a._1, a._2)).toSeq,
+          Project(projectFields, table.logicalPlan).validate(table.tableEnv)
+        ).validate(table.tableEnv)
+      ).validate(table.tableEnv))
   }
 
   /**
@@ -877,24 +879,29 @@ class GroupWindowedTable(
     * }}}
     */
   def select(fields: Expression*): Table = {
+    val (aggNames, propNames) = extractAggregationsAndProperties(fields, table.tableEnv)
+    val projectsOnAgg = replaceAggregationsAndProperties(
+      fields, table.tableEnv, aggNames, propNames)
+    // fields kept by the projection injected below the window aggregation
+    val projectFields = (table.tableEnv, window) match {
+      // in batch tables the event-time field can be any field, so it must be kept as well
+      case (_: BatchTableEnvironment, w: EventTimeWindow) =>
+        extractFieldReferences(fields ++ groupKey ++ Seq(w.timeField))
+      case (_, _) =>
+        extractFieldReferences(fields ++ groupKey)
+    }
 
-    val (projection, aggs, props) = extractAggregationsAndProperties(fields, table.tableEnv)
-
-    val groupWindow = window.toLogicalWindow
-
-    val logical =
+    new Table(table.tableEnv,
       Project(
-        projection,
+        projectsOnAgg,
         WindowAggregate(
           groupKey,
-          groupWindow,
-          props,
-          aggs,
-          table.logicalPlan
+          window.toLogicalWindow,
+          propNames.map(a => Alias(a._1, a._2)).toSeq,
+          aggNames.map(a => Alias(a._1, a._2)).toSeq,
+          Project(projectFields, table.logicalPlan).validate(table.tableEnv)
         ).validate(table.tableEnv)
-      ).validate(table.tableEnv)
-
-    new Table(table.tableEnv, logical)
+      ).validate(table.tableEnv))
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/windows.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/windows.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/windows.scala
index 32d67d7..5637d7a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/windows.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/windows.scala
@@ -48,7 +48,7 @@ trait GroupWindow {
   * @param timeField defines the time mode for streaming tables. For batch table it defines the
   *                  time attribute on which is grouped.
   */
-abstract class EventTimeWindow(timeField: Expression) extends GroupWindow {
+abstract class EventTimeWindow(val timeField: Expression) extends GroupWindow {
 
   protected var name: Option[Expression] = None
 

http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/GroupWindowTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/GroupWindowTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/GroupWindowTest.scala
index b59b151..9c2d6b3 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/GroupWindowTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/GroupWindowTest.scala
@@ -164,7 +164,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "string", "int")
+      ),
       term("groupBy", "string"),
       term("window", ProcessingTimeTumblingGroupWindow(None, 50.milli)),
       term("select", "string", "COUNT(int) AS TMP_0")
@@ -185,7 +189,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "string", "int")
+      ),
       term("groupBy", "string"),
       term("window", ProcessingTimeTumblingGroupWindow(None, 2.rows)),
       term("select", "string", "COUNT(int) AS TMP_0")
@@ -206,7 +214,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "string", "int")
+      ),
       term("groupBy", "string"),
       term("window", EventTimeTumblingGroupWindow(None, RowtimeAttribute(), 5.milli)),
       term("select", "string", "COUNT(int) AS TMP_0")
@@ -249,7 +261,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "string", "int")
+      ),
       term("groupBy", "string"),
       term("window", ProcessingTimeSlidingGroupWindow(None, 50.milli, 50.milli)),
       term("select", "string", "COUNT(int) AS TMP_0")
@@ -270,7 +286,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "string", "int")
+      ),
       term("groupBy", "string"),
       term("window", ProcessingTimeSlidingGroupWindow(None, 2.rows, 1.rows)),
       term("select", "string", "COUNT(int) AS TMP_0")
@@ -291,7 +311,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "string", "int")
+      ),
       term("groupBy", "string"),
       term("window", EventTimeSlidingGroupWindow(None, RowtimeAttribute(), 8.milli, 10.milli)),
       term("select", "string", "COUNT(int) AS TMP_0")
@@ -334,7 +358,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "string", "int")
+      ),
       term("groupBy", "string"),
       term("window", EventTimeSessionGroupWindow(None, RowtimeAttribute(), 7.milli)),
       term("select", "string", "COUNT(int) AS TMP_0")
@@ -355,7 +383,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "string", "int")
+      ),
       term("groupBy", "string"),
       term("window", ProcessingTimeTumblingGroupWindow(None, 50.milli)),
       term("select", "string", "COUNT(int) AS TMP_0")
@@ -375,7 +407,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "int")
+      ),
       term("window", ProcessingTimeTumblingGroupWindow(None, 2.rows)),
       term("select", "COUNT(int) AS TMP_0")
     )
@@ -394,7 +430,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "int")
+      ),
       term("window", EventTimeTumblingGroupWindow(None, RowtimeAttribute(), 5.milli)),
       term("select", "COUNT(int) AS TMP_0")
     )
@@ -414,7 +454,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "int")
+      ),
       term("window", EventTimeTumblingGroupWindow(None, RowtimeAttribute(), 2.rows)),
       term("select", "COUNT(int) AS TMP_0")
     )
@@ -434,7 +478,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "int")
+      ),
       term("window", ProcessingTimeSlidingGroupWindow(None, 50.milli, 50.milli)),
       term("select", "COUNT(int) AS TMP_0")
     )
@@ -453,7 +501,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "int")
+      ),
       term("window", ProcessingTimeSlidingGroupWindow(None, 2.rows, 1.rows)),
       term("select", "COUNT(int) AS TMP_0")
     )
@@ -472,7 +524,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "int")
+      ),
       term("window", EventTimeSlidingGroupWindow(None, RowtimeAttribute(), 8.milli, 10.milli)),
       term("select", "COUNT(int) AS TMP_0")
     )
@@ -492,7 +548,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "int")
+      ),
       term("window", EventTimeSlidingGroupWindow(None, RowtimeAttribute(), 2.rows, 1.rows)),
       term("select", "COUNT(int) AS TMP_0")
     )
@@ -511,7 +571,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "int")
+      ),
       term("window", EventTimeSessionGroupWindow(None, RowtimeAttribute(), 7.milli)),
       term("select", "COUNT(int) AS TMP_0")
     )
@@ -531,7 +595,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "string", "int")
+      ),
       term("groupBy", "string"),
       term("window",
         EventTimeTumblingGroupWindow(
@@ -560,7 +628,11 @@ class GroupWindowTest extends TableTestBase {
 
     val expected = unaryNode(
       "DataStreamAggregate",
-      streamTableNode(0),
+      unaryNode(
+        "DataStreamCalc",
+        streamTableNode(0),
+        term("select", "string", "int")
+      ),
       term("groupBy", "string"),
       term("window",
         EventTimeSlidingGroupWindow(
@@ -592,7 +664,11 @@ class GroupWindowTest extends TableTestBase {
       "DataStreamCalc",
       unaryNode(
         "DataStreamAggregate",
-        streamTableNode(0),
+        unaryNode(
+          "DataStreamCalc",
+          streamTableNode(0),
+          term("select", "string", "int")
+        ),
         term("groupBy", "string"),
         term("window",
           EventTimeSessionGroupWindow(
@@ -626,7 +702,11 @@ class GroupWindowTest extends TableTestBase {
       "DataStreamCalc",
       unaryNode(
         "DataStreamAggregate",
-        streamTableNode(0),
+        unaryNode(
+          "DataStreamCalc",
+          streamTableNode(0),
+          term("select", "string", "int")
+        ),
         term("groupBy", "string"),
         term("window",
           EventTimeTumblingGroupWindow(

http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/FieldProjectionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/FieldProjectionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/FieldProjectionTest.scala
new file mode 100644
index 0000000..1cefb8a
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/FieldProjectionTest.scala
@@ -0,0 +1,317 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.plan
+
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.ValidationException
+import org.apache.flink.api.table.expressions.{RowtimeAttribute, Upper, WindowReference}
+import org.apache.flink.api.table.functions.ScalarFunction
+import org.apache.flink.api.table.plan.FieldProjectionTest._
+import org.apache.flink.api.table.plan.logical.EventTimeTumblingGroupWindow
+import org.apache.flink.api.table.utils.TableTestBase
+import org.apache.flink.api.table.utils.TableTestUtil._
+import org.junit.Test
+
+/**
+  * Tests for the situations in which unused fields can be projected away, e.g. when only a
+  * few fields are selected from a source with many fields.
+  */
+class FieldProjectionTest extends TableTestBase {
+
+  val util = batchTestUtil()
+
+  val streamUtil = streamTestUtil()
+
+  @Test
+  def testSimpleSelect(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable.select('a, 'b)
+
+    val expected = unaryNode(
+      "DataSetCalc",
+      batchTableNode(0),
+      term("select", "a", "b")
+    )
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectAllFields(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable1 = sourceTable.select('*)
+    val resultTable2 = sourceTable.select('a, 'b, 'c, 'd)
+
+    val expected = batchTableNode(0)
+
+    util.verifyTable(resultTable1, expected)
+    util.verifyTable(resultTable2, expected)
+  }
+
+  @Test
+  def testSelectAggregation(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable.select('a.sum, 'b.max)
+
+    val expected = unaryNode(
+      "DataSetAggregate",
+      binaryNode(
+        "DataSetUnion",
+        values(
+          "DataSetValues",
+          tuples(List(null, null)),
+          term("values", "a", "b")
+        ),
+        unaryNode(
+          "DataSetCalc",
+          batchTableNode(0),
+          term("select", "a", "b")
+        ),
+        term("union", "a", "b")
+      ),
+      term("select", "SUM(a) AS TMP_0", "MAX(b) AS TMP_1")
+    )
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectFunction(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+
+    util.tEnv.registerFunction("hashCode", MyHashCode)
+
+    val resultTable = sourceTable.select("hashCode(c), b")
+
+    val expected = unaryNode(
+      "DataSetCalc",
+      batchTableNode(0),
+      term("select", s"${MyHashCode.getClass.getCanonicalName}(c) AS _c0", "b")
+    )
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectFromGroupedTable(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable.groupBy('a, 'c).select('a)
+
+    val expected = unaryNode(
+      "DataSetCalc",
+      unaryNode(
+        "DataSetAggregate",
+        unaryNode(
+          "DataSetCalc",
+          batchTableNode(0),
+          term("select", "a", "c")
+        ),
+        term("groupBy", "a", "c"),
+        term("select", "a", "c")
+      ),
+      term("select", "a")
+    )
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectAllFieldsFromGroupedTable(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable.groupBy('a, 'c).select('a, 'c)
+
+    val expected = unaryNode(
+      "DataSetAggregate",
+      unaryNode(
+        "DataSetCalc",
+        batchTableNode(0),
+        term("select", "a", "c")
+      ),
+      term("groupBy", "a", "c"),
+      term("select", "a", "c")
+    )
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectAggregationFromGroupedTable(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable.groupBy('c).select('a.sum)
+
+    val expected =
+      unaryNode(
+        "DataSetCalc",
+        unaryNode(
+          "DataSetAggregate",
+          unaryNode(
+            "DataSetCalc",
+            batchTableNode(0),
+            term("select", "a", "c")
+          ),
+          term("groupBy", "c"),
+          term("select", "c", "SUM(a) AS TMP_0")
+        ),
+        term("select", "TMP_0 AS TMP_1")
+      )
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectFromGroupedTableWithNonTrivialKey(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable.groupBy(Upper('c) as 'k).select('a.sum)
+
+    val expected =
+      unaryNode(
+        "DataSetCalc",
+        unaryNode(
+          "DataSetAggregate",
+          unaryNode(
+            "DataSetCalc",
+            batchTableNode(0),
+            term("select", "a", "c", "UPPER(c) AS k")
+          ),
+          term("groupBy", "k"),
+          term("select", "k", "SUM(a) AS TMP_0")
+        ),
+        term("select", "TMP_0 AS TMP_1")
+      )
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectFromGroupedTableWithFunctionKey(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable.groupBy(MyHashCode('c) as 'k).select('a.sum)
+
+    val expected =
+      unaryNode(
+        "DataSetCalc",
+        unaryNode(
+          "DataSetAggregate",
+          unaryNode(
+            "DataSetCalc",
+            batchTableNode(0),
+            term("select", "a", "c", s"${MyHashCode.getClass.getCanonicalName}(c) AS k")
+          ),
+          term("groupBy", "k"),
+          term("select", "k", "SUM(a) AS TMP_0")
+        ),
+        term("select", "TMP_0 AS TMP_1")
+      )
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectFromStreamingWindow(): Unit = {
+    val sourceTable = streamUtil.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable
+        .window(Tumble over 5.millis on 'rowtime as 'w)
+        .select(Upper('c).count, 'a.sum)
+
+    val expected =
+      unaryNode(
+        "DataStreamAggregate",
+        unaryNode(
+          "DataStreamCalc",
+          streamTableNode(0),
+          term("select", "c", "a", "UPPER(c) AS $f2")
+        ),
+        term("window",
+          EventTimeTumblingGroupWindow(
+            Some(WindowReference("w")),
+            RowtimeAttribute(),
+            5.millis)),
+        term("select", "COUNT($f2) AS TMP_0", "SUM(a) AS TMP_1")
+      )
+
+    streamUtil.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectFromStreamingGroupedWindow(): Unit = {
+    val sourceTable = streamUtil.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable
+        .groupBy('b)
+        .window(Tumble over 5.millis on 'rowtime as 'w)
+        .select(Upper('c).count, 'a.sum, 'b)
+
+    val expected = unaryNode(
+        "DataStreamCalc",
+        unaryNode(
+          "DataStreamAggregate",
+          unaryNode(
+            "DataStreamCalc",
+            streamTableNode(0),
+            term("select", "c", "a", "b", "UPPER(c) AS $f3")
+          ),
+          term("groupBy", "b"),
+          term("window",
+            EventTimeTumblingGroupWindow(
+              Some(WindowReference("w")),
+              RowtimeAttribute(),
+              5.millis)),
+          term("select", "b", "COUNT($f3) AS TMP_0", "SUM(a) AS TMP_1")
+        ),
+        term("select", "TMP_0 AS TMP_2", "TMP_1 AS TMP_3", "b")
+    )
+
+    streamUtil.verifyTable(resultTable, expected)
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testSelectFromBatchWindow1(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+
+    // the time field 'a is also used in the selection; a ValidationException is expected
+    val resultTable = sourceTable
+        .window(Tumble over 5.millis on 'a as 'w)
+        .select('a.sum, 'c.count)
+
+    val expected = "TODO"  // placeholder plan; the test only checks for the exception
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testSelectFromBatchWindow2(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+
+    // the time field 'a is not used in the selection; a ValidationException is expected
+    val resultTable = sourceTable
+        .window(Tumble over 5.millis on 'a as 'w)
+        .select('c.count)
+
+    val expected = "TODO"  // placeholder plan; the test only checks for the exception
+
+    util.verifyTable(resultTable, expected)
+  }
+}
+
+object FieldProjectionTest {
+
+  /** Scalar function used by the tests: returns the hash code of its string argument. */
+  object MyHashCode extends ScalarFunction {
+    def eval(s: String): Int = s.hashCode
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
index 4eaba90..b281dfc 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
@@ -91,6 +91,10 @@ object TableTestUtil {
        |""".stripMargin.stripLineEnd
   }
 
+  /** Formats a values node as `node(term1, term2, ...)`. */
+  def values(node: String, term: String*): String =
+    term.mkString(s"$node(", ", ", ")")
+
   def term(term: AnyRef, value: AnyRef*): String = {
     s"$term=[${value.mkString(", ")}]"
   }