Posted to commits@flink.apache.org by tw...@apache.org on 2017/02/14 13:34:38 UTC

flink git commit: [FLINK-5406] [table] Add normalization phase for predicate logical plan rewriting

Repository: flink
Updated Branches:
  refs/heads/master 186b12309 -> 8efacf588


[FLINK-5406] [table] Add normalization phase for predicate logical plan rewriting

This closes #3101.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/8efacf58
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/8efacf58
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/8efacf58

Branch: refs/heads/master
Commit: 8efacf588cee45eb99a24628136ced308c4fb418
Parents: 186b123
Author: godfreyhe <go...@163.com>
Authored: Thu Jan 12 18:42:49 2017 +0800
Committer: twalthr <tw...@apache.org>
Committed: Tue Feb 14 14:30:04 2017 +0100

----------------------------------------------------------------------
 .../flink/table/api/BatchTableEnvironment.scala |  52 +++----
 .../table/api/StreamTableEnvironment.scala      |  48 ++++---
 .../flink/table/api/TableEnvironment.scala      | 103 ++++++++++++--
 .../flink/table/calcite/CalciteConfig.scala     |  90 ++++++++++---
 .../table/plan/nodes/dataset/DataSetCalc.scala  |   4 +-
 .../flink/table/plan/rules/FlinkRuleSets.scala  |  32 +++--
 .../api/java/batch/TableEnvironmentITCase.java  |   2 +-
 .../apache/flink/table/AggregationTest.scala    |   4 +-
 .../flink/table/CalciteConfigBuilderTest.scala  | 135 +++++++++++++++----
 .../flink/table/TableEnvironmentTest.scala      |   4 +-
 .../api/scala/batch/sql/SetOperatorsTest.scala  |   2 +-
 .../scala/batch/table/FieldProjectionTest.scala |   4 +-
 .../expressions/utils/ExpressionTestBase.scala  |  25 +++-
 .../plan/rules/NormalizationRulesTest.scala     |  98 ++++++++++++++
 14 files changed, 477 insertions(+), 126 deletions(-)
----------------------------------------------------------------------

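For context: this commit splits the former single rule set into a normalization
rule set, applied by a HEP planner, and an optimization rule set, applied by the
Volcano planner. A minimal sketch of extending both phases through the new
builder methods (assuming an existing BatchTableEnvironment named tableEnv; the
variable name is illustrative):

    import org.apache.calcite.rel.rules.{FilterMergeRule, ReduceExpressionsRule}
    import org.apache.calcite.tools.RuleSets
    import org.apache.flink.table.calcite.{CalciteConfig, CalciteConfigBuilder}

    // Extend (rather than replace) the built-in rule sets of both phases.
    val cc: CalciteConfig = new CalciteConfigBuilder()
      .addNormRuleSet(RuleSets.ofList(ReduceExpressionsRule.FILTER_INSTANCE)) // HEP phase
      .addOptRuleSet(RuleSets.ofList(FilterMergeRule.INSTANCE))               // Volcano phase
      .build()

    tableEnv.getConfig.setCalciteConfig(cc)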

http://git-wip-us.apache.org/repos/asf/flink/blob/8efacf58/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala
index 2dec00e..b48e9f9 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala
@@ -20,12 +20,12 @@ package org.apache.flink.table.api
 
 import _root_.java.util.concurrent.atomic.AtomicInteger
 
-import org.apache.calcite.plan.RelOptPlanner.CannotPlanException
 import org.apache.calcite.plan.RelOptUtil
+import org.apache.calcite.plan.hep.HepMatchOrder
 import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.sql2rel.RelDecorrelator
-import org.apache.calcite.tools.{Programs, RuleSet}
+import org.apache.calcite.tools.RuleSet
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.io.DiscardingOutputFormat
 import org.apache.flink.api.java.typeutils.GenericTypeInfo
@@ -199,9 +199,14 @@ abstract class BatchTableEnvironment(
   }
 
   /**
-    * Returns the built-in rules that are defined by the environment.
+    * Returns the built-in normalization rules that are defined by the environment.
     */
-  protected def getBuiltInRuleSet: RuleSet = FlinkRuleSets.DATASET_OPT_RULES
+  protected def getBuiltInNormRuleSet: RuleSet = FlinkRuleSets.DATASET_NORM_RULES
+
+  /**
+    * Returns the built-in optimization rules that are defined by the environment.
+    */
+  protected def getBuiltInOptRuleSet: RuleSet = FlinkRuleSets.DATASET_OPT_RULES
 
   /**
     * Generates the optimized [[RelNode]] tree from the original relational node tree.
@@ -211,32 +216,27 @@ abstract class BatchTableEnvironment(
     */
   private[flink] def optimize(relNode: RelNode): RelNode = {
 
-    // decorrelate
+    // 1. decorrelate
     val decorPlan = RelDecorrelator.decorrelateQuery(relNode)
 
-    // optimize the logical Flink plan
-    val optProgram = Programs.ofRules(getRuleSet)
-    val flinkOutputProps = relNode.getTraitSet.replace(DataSetConvention.INSTANCE).simplify()
+    // 2. normalize the logical plan
+    val normRuleSet = getNormRuleSet
+    val normalizedPlan = if (normRuleSet.iterator().hasNext) {
+      runHepPlanner(HepMatchOrder.BOTTOM_UP, normRuleSet, decorPlan, decorPlan.getTraitSet)
+    } else {
+      decorPlan
+    }
 
-    val dataSetPlan = try {
-      optProgram.run(getPlanner, decorPlan, flinkOutputProps)
-    } catch {
-      case e: CannotPlanException =>
-        throw new TableException(
-          s"Cannot generate a valid execution plan for the given query: \n\n" +
-            s"${RelOptUtil.toString(relNode)}\n" +
-            s"This exception indicates that the query uses an unsupported SQL feature.\n" +
-            s"Please check the documentation for the set of currently supported SQL features.")
-      case t: TableException =>
-        throw new TableException(
-          s"Cannot generate a valid execution plan for the given query: \n\n" +
-            s"${RelOptUtil.toString(relNode)}\n" +
-            s"${t.msg}\n" +
-            s"Please check the documentation for the set of currently supported SQL features.")
-      case a: AssertionError =>
-        throw a.getCause
+    // 3. optimize the logical Flink plan
+    val optRuleSet = getOptRuleSet
+    val flinkOutputProps = relNode.getTraitSet.replace(DataSetConvention.INSTANCE).simplify()
+    val optimizedPlan = if (optRuleSet.iterator().hasNext) {
+      runVolcanoPlanner(optRuleSet, normalizedPlan, flinkOutputProps)
+    } else {
+      normalizedPlan
     }
-    dataSetPlan
+
+    optimizedPlan
   }
 
   /**

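Note that both phases are now guarded by iterator().hasNext, so an empty rule
set skips the corresponding planner instead of running it with nothing to do.
The new NormalizationRulesTest (further down) exploits this to inspect the
purely normalized plan; a hedged sketch of the same trick:

    import org.apache.calcite.tools.RuleSets
    import org.apache.flink.table.calcite.CalciteConfigBuilder

    // Replace the optimization rules with an empty set: only the
    // normalization (HEP) phase will transform the plan.
    val cc = new CalciteConfigBuilder()
      .replaceOptRuleSet(RuleSets.ofList())
      .build()
    tableEnv.getConfig.setCalciteConfig(cc)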
http://git-wip-us.apache.org/repos/asf/flink/blob/8efacf58/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala
index 19c4af1..d927c3a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala
@@ -20,14 +20,14 @@ package org.apache.flink.table.api
 
 import _root_.java.util.concurrent.atomic.AtomicInteger
 
-import org.apache.calcite.plan.RelOptPlanner.CannotPlanException
 import org.apache.calcite.plan.RelOptUtil
+import org.apache.calcite.plan.hep.HepMatchOrder
 import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.sql2rel.RelDecorrelator
-import org.apache.calcite.tools.{Programs, RuleSet}
+import org.apache.calcite.tools.RuleSet
 import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.java.typeutils.{GenericTypeInfo, RowTypeInfo}
+import org.apache.flink.api.java.typeutils.GenericTypeInfo
 import org.apache.flink.streaming.api.datastream.DataStream
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
 import org.apache.flink.table.explain.PlanJsonParser
@@ -211,9 +211,14 @@ abstract class StreamTableEnvironment(
   }
 
   /**
-    * Returns the built-in rules that are defined by the environment.
+    * Returns the built-in normalization rules that are defined by the environment.
     */
-  protected def getBuiltInRuleSet: RuleSet = FlinkRuleSets.DATASTREAM_OPT_RULES
+  protected def getBuiltInNormRuleSet: RuleSet = FlinkRuleSets.DATASTREAM_NORM_RULES
+
+  /**
+    * Returns the built-in optimization rules that are defined by the environment.
+    */
+  protected def getBuiltInOptRuleSet: RuleSet = FlinkRuleSets.DATASTREAM_OPT_RULES
 
   /**
     * Generates the optimized [[RelNode]] tree from the original relational node tree.
@@ -222,25 +227,28 @@ abstract class StreamTableEnvironment(
     * @return The optimized [[RelNode]] tree
     */
   private[flink] def optimize(relNode: RelNode): RelNode = {
-    // decorrelate
-    val decorPlan = RelDecorrelator.decorrelateQuery(relNode)
 
-    // optimize the logical Flink plan
-    val optProgram = Programs.ofRules(getRuleSet)
-    val flinkOutputProps = relNode.getTraitSet.replace(DataStreamConvention.INSTANCE).simplify()
+    // 1. decorrelate
+    val decorPlan = RelDecorrelator.decorrelateQuery(relNode)
 
-    val dataStreamPlan = try {
-      optProgram.run(getPlanner, decorPlan, flinkOutputProps)
+    // 2. normalize the logical plan
+    val normRuleSet = getNormRuleSet
+    val normalizedPlan = if (normRuleSet.iterator().hasNext) {
+      runHepPlanner(HepMatchOrder.BOTTOM_UP, normRuleSet, decorPlan, decorPlan.getTraitSet)
+    } else {
+      decorPlan
     }
-    catch {
-      case e: CannotPlanException =>
-        throw TableException(
-          s"Cannot generate a valid execution plan for the given query: \n\n" +
-            s"${RelOptUtil.toString(relNode)}\n" +
-            s"This exception indicates that the query uses an unsupported SQL feature.\n" +
-            s"Please check the documentation for the set of currently supported SQL features.", e)
+
+    // 3. optimize the logical Flink plan
+    val optRuleSet = getOptRuleSet
+    val flinkOutputProps = relNode.getTraitSet.replace(DataStreamConvention.INSTANCE).simplify()
+    val optimizedPlan = if (optRuleSet.iterator().hasNext) {
+      runVolcanoPlanner(optRuleSet, normalizedPlan, flinkOutputProps)
+    } else {
+      normalizedPlan
     }
-    dataStreamPlan
+
+    optimizedPlan
   }
 
 

http://git-wip-us.apache.org/repos/asf/flink/blob/8efacf58/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
index b36441a..4a36320 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
@@ -23,14 +23,17 @@ import _root_.java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.calcite.config.Lex
 import org.apache.calcite.jdbc.CalciteSchema
-import org.apache.calcite.plan.RelOptPlanner
+import org.apache.calcite.plan.RelOptPlanner.CannotPlanException
+import org.apache.calcite.plan.hep.{HepMatchOrder, HepPlanner, HepProgramBuilder}
+import org.apache.calcite.plan.{RelOptPlanner, RelOptUtil, RelTraitSet}
+import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.schema.SchemaPlus
 import org.apache.calcite.schema.impl.AbstractTable
 import org.apache.calcite.sql.SqlOperatorTable
 import org.apache.calcite.sql.parser.SqlParser
 import org.apache.calcite.sql.util.ChainedSqlOperatorTable
-import org.apache.calcite.tools.{FrameworkConfig, Frameworks, RuleSet, RuleSets}
+import org.apache.calcite.tools._
 import org.apache.flink.api.common.functions.MapFunction
 import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
 import org.apache.flink.api.common.typeutils.CompositeType
@@ -119,20 +122,41 @@ abstract class TableEnvironment(val config: TableConfig) {
   }
 
   /**
-    * Returns the rule set for this environment including a custom Calcite configuration.
+    * Returns the normalization rule set for this environment
+    * including a custom RuleSet configuration.
     */
-  protected def getRuleSet: RuleSet = {
+  protected def getNormRuleSet: RuleSet = {
     val calciteConfig = config.getCalciteConfig
-    calciteConfig.getRuleSet match {
+    calciteConfig.getNormRuleSet match {
 
       case None =>
-        getBuiltInRuleSet
+        getBuiltInNormRuleSet
 
       case Some(ruleSet) =>
-        if (calciteConfig.replacesRuleSet) {
+        if (calciteConfig.replacesNormRuleSet) {
           ruleSet
         } else {
-          RuleSets.ofList((getBuiltInRuleSet.asScala ++ ruleSet.asScala).asJava)
+          RuleSets.ofList((getBuiltInNormRuleSet.asScala ++ ruleSet.asScala).asJava)
+        }
+    }
+  }
+
+  /**
+    * Returns the optimization rule set for this environment
+    * including a custom RuleSet configuration.
+    */
+  protected def getOptRuleSet: RuleSet = {
+    val calciteConfig = config.getCalciteConfig
+    calciteConfig.getOptRuleSet match {
+
+      case None =>
+        getBuiltInOptRuleSet
+
+      case Some(ruleSet) =>
+        if (calciteConfig.replacesOptRuleSet) {
+          ruleSet
+        } else {
+          RuleSets.ofList((getBuiltInOptRuleSet.asScala ++ ruleSet.asScala).asJava)
         }
     }
   }
@@ -158,9 +182,68 @@ abstract class TableEnvironment(val config: TableConfig) {
   }
 
   /**
-    * Returns the built-in rules that are defined by the environment.
+    * Returns the built-in normalization rules that are defined by the environment.
     */
-  protected def getBuiltInRuleSet: RuleSet
+  protected def getBuiltInNormRuleSet: RuleSet
+
+  /**
+    * Returns the built-in optimization rules that are defined by the environment.
+    */
+  protected def getBuiltInOptRuleSet: RuleSet
+
+  /**
+    * run HEP planner
+    */
+  protected def runHepPlanner(
+    hepMatchOrder: HepMatchOrder,
+    ruleSet: RuleSet,
+    input: RelNode,
+    targetTraits: RelTraitSet): RelNode = {
+    val builder = new HepProgramBuilder
+    builder.addMatchOrder(hepMatchOrder)
+
+    val it = ruleSet.iterator()
+    while (it.hasNext) {
+      builder.addRuleInstance(it.next())
+    }
+
+    val planner = new HepPlanner(builder.build, frameworkConfig.getContext)
+    planner.setRoot(input)
+    if (input.getTraitSet != targetTraits) {
+      planner.changeTraits(input, targetTraits.simplify)
+    }
+    planner.findBestExp
+  }
+
+  /**
+    * run VOLCANO planner
+    */
+  protected def runVolcanoPlanner(
+    ruleSet: RuleSet,
+    input: RelNode,
+    targetTraits: RelTraitSet): RelNode = {
+    val optProgram = Programs.ofRules(ruleSet)
+
+    val output = try {
+      optProgram.run(getPlanner, input, targetTraits)
+    } catch {
+      case e: CannotPlanException =>
+        throw new TableException(
+          s"Cannot generate a valid execution plan for the given query: \n\n" +
+            s"${RelOptUtil.toString(input)}\n" +
+            s"This exception indicates that the query uses an unsupported SQL feature.\n" +
+            s"Please check the documentation for the set of currently supported SQL features.")
+      case t: TableException =>
+        throw new TableException(
+          s"Cannot generate a valid execution plan for the given query: \n\n" +
+            s"${RelOptUtil.toString(input)}\n" +
+            s"${t.msg}\n" +
+            s"Please check the documentation for the set of currently supported SQL features.")
+      case a: AssertionError =>
+        throw a.getCause
+    }
+    output
+  }
 
   /**
     * Registers a [[ScalarFunction]] under a unique name. Replaces already existing

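The two getters above share one resolution order: no custom set falls back to
the built-in rules, a replacing set wins outright, and an additive set is
appended to the built-ins. A standalone distillation of that logic, with
resolve as a hypothetical helper name:

    import org.apache.calcite.tools.{RuleSet, RuleSets}
    import scala.collection.JavaConverters._

    // Hypothetical sketch of what getNormRuleSet / getOptRuleSet resolve to.
    def resolve(custom: Option[RuleSet], replaces: Boolean, builtIn: RuleSet): RuleSet =
      custom match {
        case None                 => builtIn  // no custom rules configured
        case Some(rs) if replaces => rs       // custom set replaces the built-ins
        case Some(rs)             => RuleSets.ofList((builtIn.asScala ++ rs.asScala).asJava)
      }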
http://git-wip-us.apache.org/repos/asf/flink/blob/8efacf58/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/CalciteConfig.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/CalciteConfig.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/CalciteConfig.scala
index f646caf..65a61b2 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/CalciteConfig.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/CalciteConfig.scala
@@ -31,8 +31,11 @@ import scala.collection.JavaConverters._
   * Builder for creating a Calcite configuration.
   */
 class CalciteConfigBuilder {
-  private var replaceRules: Boolean = false
-  private var ruleSets: List[RuleSet] = Nil
+  private var replaceNormRules: Boolean = false
+  private var normRuleSets: List[RuleSet] = Nil
+
+  private var replaceOptRules: Boolean = false
+  private var optRuleSets: List[RuleSet] = Nil
 
   private var replaceOperatorTable: Boolean = false
   private var operatorTables: List[SqlOperatorTable] = Nil
@@ -40,21 +43,40 @@ class CalciteConfigBuilder {
   private var replaceSqlParserConfig: Option[SqlParser.Config] = None
 
   /**
-    * Replaces the built-in rule set with the given rule set.
+    * Replaces the built-in normalization rule set with the given rule set.
+    */
+  def replaceNormRuleSet(replaceRuleSet: RuleSet): CalciteConfigBuilder = {
+    Preconditions.checkNotNull(replaceRuleSet)
+    normRuleSets = List(replaceRuleSet)
+    replaceNormRules = true
+    this
+  }
+
+  /**
+    * Appends the given normalization rule set to the built-in rule set.
+    */
+  def addNormRuleSet(addedRuleSet: RuleSet): CalciteConfigBuilder = {
+    Preconditions.checkNotNull(addedRuleSet)
+    normRuleSets = addedRuleSet :: normRuleSets
+    this
+  }
+
+  /**
+    * Replaces the built-in optimization rule set with the given rule set.
     */
-  def replaceRuleSet(replaceRuleSet: RuleSet): CalciteConfigBuilder = {
+  def replaceOptRuleSet(replaceRuleSet: RuleSet): CalciteConfigBuilder = {
     Preconditions.checkNotNull(replaceRuleSet)
-    ruleSets = List(replaceRuleSet)
-    replaceRules = true
+    optRuleSets = List(replaceRuleSet)
+    replaceOptRules = true
     this
   }
 
   /**
-    * Appends the given rule set to the built-in rule set.
+    * Appends the given optimization rule set to the built-in rule set.
     */
-  def addRuleSet(addedRuleSet: RuleSet): CalciteConfigBuilder = {
+  def addOptRuleSet(addedRuleSet: RuleSet): CalciteConfigBuilder = {
     Preconditions.checkNotNull(addedRuleSet)
-    ruleSets = addedRuleSet :: ruleSets
+    optRuleSets = addedRuleSet :: optRuleSets
     this
   }
 
@@ -87,32 +109,45 @@ class CalciteConfigBuilder {
   }
 
   private class CalciteConfigImpl(
-      val getRuleSet: Option[RuleSet],
-      val replacesRuleSet: Boolean,
-      val getSqlOperatorTable: Option[SqlOperatorTable],
-      val replacesSqlOperatorTable: Boolean,
-      val getSqlParserConfig: Option[SqlParser.Config])
+    val getNormRuleSet: Option[RuleSet],
+    val replacesNormRuleSet: Boolean,
+    val getOptRuleSet: Option[RuleSet],
+    val replacesOptRuleSet: Boolean,
+    val getSqlOperatorTable: Option[SqlOperatorTable],
+    val replacesSqlOperatorTable: Boolean,
+    val getSqlParserConfig: Option[SqlParser.Config])
     extends CalciteConfig
 
   /**
     * Builds a new [[CalciteConfig]].
     */
   def build(): CalciteConfig = new CalciteConfigImpl(
-        ruleSets match {
+    normRuleSets match {
       case Nil => None
       case h :: Nil => Some(h)
       case _ =>
         // concat rule sets
-        val concatRules = ruleSets.foldLeft(Nil: Iterable[RelOptRule])( (c, r) => r.asScala ++ c)
+        val concatRules =
+          normRuleSets.foldLeft(Nil: Iterable[RelOptRule])((c, r) => r.asScala ++ c)
         Some(RuleSets.ofList(concatRules.asJava))
     },
-    this.replaceRules,
+    replaceNormRules,
+    optRuleSets match {
+      case Nil => None
+      case h :: Nil => Some(h)
+      case _ =>
+        // concat rule sets
+        val concatRules =
+          optRuleSets.foldLeft(Nil: Iterable[RelOptRule])((c, r) => r.asScala ++ c)
+        Some(RuleSets.ofList(concatRules.asJava))
+    },
+    replaceOptRules,
     operatorTables match {
       case Nil => None
       case h :: Nil => Some(h)
       case _ =>
         // chain operator tables
-        Some(operatorTables.reduce( (x, y) => ChainedSqlOperatorTable.of(x, y)))
+        Some(operatorTables.reduce((x, y) => ChainedSqlOperatorTable.of(x, y)))
     },
     this.replaceOperatorTable,
     replaceSqlParserConfig)
@@ -122,15 +157,26 @@ class CalciteConfigBuilder {
   * Calcite configuration for defining a custom Calcite configuration for Table and SQL API.
   */
 trait CalciteConfig {
+
+  /**
+    * Returns whether this configuration replaces the built-in normalization rule set.
+    */
+  def replacesNormRuleSet: Boolean
+
+  /**
+    * Returns a custom normalization rule set.
+    */
+  def getNormRuleSet: Option[RuleSet]
+
   /**
-    * Returns whether this configuration replaces the built-in rule set.
+    * Returns whether this configuration replaces the built-in optimization rule set.
     */
-  def replacesRuleSet: Boolean
+  def replacesOptRuleSet: Boolean
 
   /**
-    * Returns a custom rule set.
+    * Returns a custom optimization rule set.
     */
-  def getRuleSet: Option[RuleSet]
+  def getOptRuleSet: Option[RuleSet]
 
   /**
     * Returns whether this configuration replaces the built-in SQL operator table.

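The builder tracks the replace flag and the accumulated rule sets separately,
so replacing and then adding yields a replacing configuration that contains
both sets (pinned down by testReplaceNormalizationAddRules below). A hedged
sketch:

    import org.apache.calcite.rel.rules.ReduceExpressionsRule
    import org.apache.calcite.tools.RuleSets
    import org.apache.flink.table.calcite.CalciteConfigBuilder

    val cc = new CalciteConfigBuilder()
      .replaceNormRuleSet(RuleSets.ofList(ReduceExpressionsRule.FILTER_INSTANCE))
      .addNormRuleSet(RuleSets.ofList(ReduceExpressionsRule.PROJECT_INSTANCE))
      .build()

    // cc.replacesNormRuleSet == true
    // cc.getNormRuleSet contains both FILTER_INSTANCE and PROJECT_INSTANCE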
http://git-wip-us.apache.org/repos/asf/flink/blob/8efacf58/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala
index 245a038..9b3ff63 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala
@@ -24,7 +24,6 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery
 import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
 import org.apache.calcite.rex._
 import org.apache.flink.api.common.functions.FlatMapFunction
-import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.DataSet
 import org.apache.flink.table.api.BatchTableEnvironment
 import org.apache.flink.table.calcite.FlinkTypeFactory
@@ -78,9 +77,12 @@ class DataSetCalc(
 
     // compute number of expressions that do not access a field or literal, i.e. computations,
     //   conditions, etc. We only want to account for computations, not for simple projections.
+    // CASTs in RexProgram are reduced as far as possible by ReduceExpressionsRule
+    // in normalization stage. So we should ignore CASTs here in optimization stage.
     val compCnt = calcProgram.getExprList.asScala.toList.count {
       case i: RexInputRef => false
       case l: RexLiteral => false
+      case c: RexCall if c.getOperator.getName.equals("CAST") => false
       case _ => true
     }
 

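The DataSetCalc cost change can be read in isolation: CAST calls are no longer
charged as computations, because the normalization phase has already reduced
every reducible CAST. A self-contained sketch of the counting predicate:

    import org.apache.calcite.rex.{RexCall, RexInputRef, RexLiteral, RexNode}

    // Count genuine computations only: input refs, literals, and the CASTs
    // that survive normalization do not add to the Calc's estimated cost.
    def computationCount(exprs: List[RexNode]): Int = exprs.count {
      case _: RexInputRef => false
      case _: RexLiteral  => false
      case c: RexCall if c.getOperator.getName.equals("CAST") => false
      case _ => true
    }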
http://git-wip-us.apache.org/repos/asf/flink/blob/8efacf58/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
index 0b60848..a24a06d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
@@ -28,6 +28,17 @@ import org.apache.flink.table.plan.rules.datastream.{DataStreamCalcRule, DataStr
 object FlinkRuleSets {
 
   /**
+    * RuleSet to normalize plans for batch / DataSet execution
+    */
+  val DATASET_NORM_RULES: RuleSet = RuleSets.ofList(
+    // simplify expressions rules
+    ReduceExpressionsRule.FILTER_INSTANCE,
+    ReduceExpressionsRule.PROJECT_INSTANCE,
+    ReduceExpressionsRule.CALC_INSTANCE,
+    ReduceExpressionsRule.JOIN_INSTANCE
+  )
+
+  /**
     * RuleSet to optimize plans for batch / DataSet execution
     */
   val DATASET_OPT_RULES: RuleSet = RuleSets.ofList(
@@ -75,12 +86,6 @@ object FlinkRuleSets {
     // remove unnecessary sort rule
     SortRemoveRule.INSTANCE,
 
-    // simplify expressions rules
-    ReduceExpressionsRule.FILTER_INSTANCE,
-    ReduceExpressionsRule.PROJECT_INSTANCE,
-    ReduceExpressionsRule.CALC_INSTANCE,
-    ReduceExpressionsRule.JOIN_INSTANCE,
-
     // prune empty results rules
     PruneEmptyRules.AGGREGATE_INSTANCE,
     PruneEmptyRules.FILTER_INSTANCE,
@@ -117,6 +122,16 @@ object FlinkRuleSets {
   )
 
   /**
+    * RuleSet to normalize plans for stream / DataStream execution
+    */
+  val DATASTREAM_NORM_RULES: RuleSet = RuleSets.ofList(
+    // simplify expressions rules
+    ReduceExpressionsRule.FILTER_INSTANCE,
+    ReduceExpressionsRule.PROJECT_INSTANCE,
+    ReduceExpressionsRule.CALC_INSTANCE
+  )
+
+  /**
   * RuleSet to optimize plans for stream / DataStream execution
   */
   val DATASTREAM_OPT_RULES: RuleSet = RuleSets.ofList(
@@ -142,11 +157,6 @@ object FlinkRuleSets {
       FilterProjectTransposeRule.INSTANCE,
       ProjectRemoveRule.INSTANCE,
 
-      // simplify expressions rules
-      ReduceExpressionsRule.FILTER_INSTANCE,
-      ReduceExpressionsRule.PROJECT_INSTANCE,
-      ReduceExpressionsRule.CALC_INSTANCE,
-
       // merge and push unions rules
       UnionEliminatorRule.INSTANCE,
 

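Note that the batch and stream normalization sets differ only in join
handling: DATASET_NORM_RULES additionally carries
ReduceExpressionsRule.JOIN_INSTANCE. A quick, hedged check:

    import scala.collection.JavaConverters._
    import org.apache.calcite.rel.rules.ReduceExpressionsRule
    import org.apache.flink.table.plan.rules.FlinkRuleSets

    val batchNorm  = FlinkRuleSets.DATASET_NORM_RULES.asScala.toSet
    val streamNorm = FlinkRuleSets.DATASTREAM_NORM_RULES.asScala.toSet
    // Expected: the only difference is the join expression reducer.
    assert((batchNorm -- streamNorm) == Set(ReduceExpressionsRule.JOIN_INSTANCE))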
http://git-wip-us.apache.org/repos/asf/flink/blob/8efacf58/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/TableEnvironmentITCase.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/TableEnvironmentITCase.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/TableEnvironmentITCase.java
index dece295..ebe79fa 100644
--- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/TableEnvironmentITCase.java
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/TableEnvironmentITCase.java
@@ -482,7 +482,7 @@ public class TableEnvironmentITCase extends TableProgramsCollectionTestBase {
 		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
 		BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
 
-		CalciteConfig cc = new CalciteConfigBuilder().replaceRuleSet(RuleSets.ofList()).build();
+		CalciteConfig cc = new CalciteConfigBuilder().replaceOptRuleSet(RuleSets.ofList()).build();
 		tableEnv.getConfig().setCalciteConfig(cc);
 
 		DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);

http://git-wip-us.apache.org/repos/asf/flink/blob/8efacf58/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala
index 708e007..aad3403 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/AggregationTest.scala
@@ -230,7 +230,9 @@ class AggregationTest extends TableTestBase {
     val calcNode = unaryNode(
       "DataSetCalc",
       batchTableNode(0),
-      term("select", "a", "b", "c"),
+      // ReduceExpressionsRule will add cast for Project node by force
+      // if the input of the Project node has constant expression.
+      term("select", "CAST(1) AS a", "b", "c"),
       term("where", "=(a, 1)")
     )
 

http://git-wip-us.apache.org/repos/asf/flink/blob/8efacf58/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/CalciteConfigBuilderTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/CalciteConfigBuilderTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/CalciteConfigBuilderTest.scala
index e69bd11..6c07e28 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/CalciteConfigBuilderTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/CalciteConfigBuilderTest.scala
@@ -18,12 +18,12 @@
 
 package org.apache.flink.table
 
-import org.apache.calcite.rel.rules.{CalcSplitRule, CalcMergeRule, FilterMergeRule}
-import org.apache.calcite.sql.fun.{SqlStdOperatorTable, OracleSqlOperatorTable}
+import org.apache.calcite.rel.rules._
+import org.apache.calcite.sql.fun.{OracleSqlOperatorTable, SqlStdOperatorTable}
 import org.apache.calcite.tools.RuleSets
-import org.apache.flink.table.calcite.{CalciteConfigBuilder, CalciteConfig}
-import org.junit.Test
+import org.apache.flink.table.calcite.{CalciteConfig, CalciteConfigBuilder}
 import org.junit.Assert._
+import org.junit.Test
 
 import scala.collection.JavaConverters._
 
@@ -32,38 +32,117 @@ class CalciteConfigBuilderTest {
   @Test
   def testDefaultRules(): Unit = {
 
+    val cc: CalciteConfig = new CalciteConfigBuilder().build()
+
+    assertFalse(cc.replacesNormRuleSet)
+    assertFalse(cc.getNormRuleSet.isDefined)
+
+    assertFalse(cc.replacesOptRuleSet)
+    assertFalse(cc.getOptRuleSet.isDefined)
+  }
+
+  @Test
+  def testRules(): Unit = {
+
+    val cc: CalciteConfig = new CalciteConfigBuilder()
+      .addNormRuleSet(RuleSets.ofList(ReduceExpressionsRule.FILTER_INSTANCE))
+      .replaceOptRuleSet(RuleSets.ofList(FilterMergeRule.INSTANCE))
+      .build()
+
+    assertFalse(cc.replacesNormRuleSet)
+    assertTrue(cc.getNormRuleSet.isDefined)
+
+    assertTrue(cc.replacesOptRuleSet)
+    assertTrue(cc.getOptRuleSet.isDefined)
+  }
+
+  @Test
+  def testReplaceNormalizationRules(): Unit = {
+
+    val cc: CalciteConfig = new CalciteConfigBuilder()
+      .replaceNormRuleSet(RuleSets.ofList(ReduceExpressionsRule.FILTER_INSTANCE))
+      .build()
+
+    assertEquals(true, cc.replacesNormRuleSet)
+    assertTrue(cc.getNormRuleSet.isDefined)
+    val cSet = cc.getNormRuleSet.get.iterator().asScala.toSet
+    assertEquals(1, cSet.size)
+    assertTrue(cSet.contains(ReduceExpressionsRule.FILTER_INSTANCE))
+  }
+
+  @Test
+  def testReplaceNormalizationAddRules(): Unit = {
+
+    val cc: CalciteConfig = new CalciteConfigBuilder()
+      .replaceNormRuleSet(RuleSets.ofList(ReduceExpressionsRule.FILTER_INSTANCE))
+      .addNormRuleSet(RuleSets.ofList(ReduceExpressionsRule.PROJECT_INSTANCE))
+      .build()
+
+    assertEquals(true, cc.replacesNormRuleSet)
+    assertTrue(cc.getNormRuleSet.isDefined)
+    val cSet = cc.getNormRuleSet.get.iterator().asScala.toSet
+    assertEquals(2, cSet.size)
+    assertTrue(cSet.contains(ReduceExpressionsRule.FILTER_INSTANCE))
+    assertTrue(cSet.contains(ReduceExpressionsRule.PROJECT_INSTANCE))
+  }
+
+  @Test
+  def testAddNormalizationRules(): Unit = {
+
+    val cc: CalciteConfig = new CalciteConfigBuilder()
+      .addNormRuleSet(RuleSets.ofList(ReduceExpressionsRule.FILTER_INSTANCE))
+      .build()
+
+    assertEquals(false, cc.replacesNormRuleSet)
+    assertTrue(cc.getNormRuleSet.isDefined)
+    val cSet = cc.getNormRuleSet.get.iterator().asScala.toSet
+    assertEquals(1, cSet.size)
+    assertTrue(cSet.contains(ReduceExpressionsRule.FILTER_INSTANCE))
+  }
+
+  @Test
+  def testAddAddNormalizationRules(): Unit = {
+
     val cc: CalciteConfig = new CalciteConfigBuilder()
+      .addNormRuleSet(RuleSets.ofList(ReduceExpressionsRule.FILTER_INSTANCE))
+      .addNormRuleSet(RuleSets.ofList(ReduceExpressionsRule.PROJECT_INSTANCE,
+        ReduceExpressionsRule.CALC_INSTANCE))
       .build()
 
-    assertEquals(false, cc.replacesRuleSet)
-    assertFalse(cc.getRuleSet.isDefined)
+    assertEquals(false, cc.replacesNormRuleSet)
+    assertTrue(cc.getNormRuleSet.isDefined)
+    val cList = cc.getNormRuleSet.get.iterator().asScala.toList
+    assertEquals(3, cList.size)
+    assertEquals(cList.head, ReduceExpressionsRule.FILTER_INSTANCE)
+    assertEquals(cList(1), ReduceExpressionsRule.PROJECT_INSTANCE)
+    assertEquals(cList(2), ReduceExpressionsRule.CALC_INSTANCE)
   }
 
   @Test
-  def testReplaceRules(): Unit = {
+  def testReplaceOptimizationRules(): Unit = {
 
     val cc: CalciteConfig = new CalciteConfigBuilder()
-      .replaceRuleSet(RuleSets.ofList(FilterMergeRule.INSTANCE))
+      .replaceOptRuleSet(RuleSets.ofList(FilterMergeRule.INSTANCE))
       .build()
 
-    assertEquals(true, cc.replacesRuleSet)
-    assertTrue(cc.getRuleSet.isDefined)
-    val cSet = cc.getRuleSet.get.iterator().asScala.toSet
+    assertEquals(true, cc.replacesOptRuleSet)
+    assertTrue(cc.getOptRuleSet.isDefined)
+    val cSet = cc.getOptRuleSet.get.iterator().asScala.toSet
     assertEquals(1, cSet.size)
     assertTrue(cSet.contains(FilterMergeRule.INSTANCE))
   }
 
   @Test
-  def testReplaceAddRules(): Unit = {
+  def testReplaceOptimizationAddRules(): Unit = {
 
     val cc: CalciteConfig = new CalciteConfigBuilder()
-      .replaceRuleSet(RuleSets.ofList(FilterMergeRule.INSTANCE))
-      .addRuleSet(RuleSets.ofList(CalcMergeRule.INSTANCE, CalcSplitRule.INSTANCE))
+      .replaceOptRuleSet(RuleSets.ofList(FilterMergeRule.INSTANCE))
+      .addOptRuleSet(RuleSets.ofList(CalcMergeRule.INSTANCE, CalcSplitRule.INSTANCE))
       .build()
 
-    assertEquals(true, cc.replacesRuleSet)
-    assertTrue(cc.getRuleSet.isDefined)
-    val cSet = cc.getRuleSet.get.iterator().asScala.toSet
+    assertEquals(true, cc.replacesOptRuleSet)
+    assertTrue(cc.getOptRuleSet.isDefined)
+    val cSet = cc.getOptRuleSet.get.iterator().asScala.toSet
     assertEquals(3, cSet.size)
     assertTrue(cSet.contains(FilterMergeRule.INSTANCE))
     assertTrue(cSet.contains(CalcMergeRule.INSTANCE))
@@ -71,30 +150,30 @@ class CalciteConfigBuilderTest {
   }
 
   @Test
-  def testAddRules(): Unit = {
+  def testAddOptimizationRules(): Unit = {
 
     val cc: CalciteConfig = new CalciteConfigBuilder()
-      .addRuleSet(RuleSets.ofList(FilterMergeRule.INSTANCE))
+      .addOptRuleSet(RuleSets.ofList(FilterMergeRule.INSTANCE))
       .build()
 
-    assertEquals(false, cc.replacesRuleSet)
-    assertTrue(cc.getRuleSet.isDefined)
-    val cSet = cc.getRuleSet.get.iterator().asScala.toSet
+    assertEquals(false, cc.replacesOptRuleSet)
+    assertTrue(cc.getOptRuleSet.isDefined)
+    val cSet = cc.getOptRuleSet.get.iterator().asScala.toSet
     assertEquals(1, cSet.size)
     assertTrue(cSet.contains(FilterMergeRule.INSTANCE))
   }
 
   @Test
-  def testAddAddRules(): Unit = {
+  def testAddAddOptimizationRules(): Unit = {
 
     val cc: CalciteConfig = new CalciteConfigBuilder()
-      .addRuleSet(RuleSets.ofList(FilterMergeRule.INSTANCE))
-      .addRuleSet(RuleSets.ofList(CalcMergeRule.INSTANCE, CalcSplitRule.INSTANCE))
+      .addOptRuleSet(RuleSets.ofList(FilterMergeRule.INSTANCE))
+      .addOptRuleSet(RuleSets.ofList(CalcMergeRule.INSTANCE, CalcSplitRule.INSTANCE))
       .build()
 
-    assertEquals(false, cc.replacesRuleSet)
-    assertTrue(cc.getRuleSet.isDefined)
-    val cSet = cc.getRuleSet.get.iterator().asScala.toSet
+    assertEquals(false, cc.replacesOptRuleSet)
+    assertTrue(cc.getOptRuleSet.isDefined)
+    val cSet = cc.getOptRuleSet.get.iterator().asScala.toSet
     assertEquals(3, cSet.size)
     assertTrue(cSet.contains(FilterMergeRule.INSTANCE))
     assertTrue(cSet.contains(CalcMergeRule.INSTANCE))

http://git-wip-us.apache.org/repos/asf/flink/blob/8efacf58/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableEnvironmentTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableEnvironmentTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableEnvironmentTest.scala
index 1f73427..8ce27e8 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableEnvironmentTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableEnvironmentTest.scala
@@ -285,7 +285,9 @@ class MockTableEnvironment extends TableEnvironment(new TableConfig) {
 
   override protected def checkValidTableName(name: String): Unit = ???
 
-  override protected def getBuiltInRuleSet: RuleSet = ???
+  override protected def getBuiltInNormRuleSet: RuleSet = ???
+
+  override protected def getBuiltInOptRuleSet: RuleSet = ???
 
   override def registerTableSource(name: String, tableSource: TableSource[_]) = ???
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/8efacf58/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SetOperatorsTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SetOperatorsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SetOperatorsTest.scala
index be98a89..d70a32a 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SetOperatorsTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SetOperatorsTest.scala
@@ -79,7 +79,7 @@ class SetOperatorsTest extends TableTestBase {
               term("join", "b_long", "a_long"),
               term("joinType", "InnerJoin")
             ),
-            term("select", "true AS $f0", "a_long")
+            term("select", "a_long", "true AS $f0")
           ),
           term("groupBy", "a_long"),
           term("select", "a_long", "MIN($f0) AS $f1")

http://git-wip-us.apache.org/repos/asf/flink/blob/8efacf58/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala
index a7da5b5..d053b9f 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala
@@ -297,8 +297,8 @@ class FieldProjectionTest extends TableTestBase {
           term("groupBy", "word"),
           term("select", "word", "SUM(frequency) AS TMP_0")
         ),
-        term("select", "word, frequency"),
-        term("where", "=(frequency, 2)")
+        term("select", "word, TMP_0 AS frequency"),
+        term("where", "=(TMP_0, 2)")
       )
 
     util.verifyTable(resultTable, expected)

http://git-wip-us.apache.org/repos/asf/flink/blob/8efacf58/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
index b4327ec..679942c 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
@@ -18,6 +18,7 @@
 
 package org.apache.flink.table.expressions.utils
 
+import org.apache.calcite.plan.hep.{HepMatchOrder, HepPlanner, HepProgramBuilder}
 import org.apache.calcite.rex.RexNode
 import org.apache.calcite.sql.`type`.SqlTypeName._
 import org.apache.calcite.sql2rel.RelDecorrelator
@@ -58,6 +59,16 @@ abstract class ExpressionTestBase {
     context._2.getTypeFactory)
   private val optProgram = Programs.ofRules(FlinkRuleSets.DATASET_OPT_RULES)
 
+  private def hepPlanner = {
+    val builder = new HepProgramBuilder
+    builder.addMatchOrder(HepMatchOrder.BOTTOM_UP)
+    val it = FlinkRuleSets.DATASET_NORM_RULES.iterator()
+    while (it.hasNext) {
+      builder.addRuleInstance(it.next())
+    }
+    new HepPlanner(builder.build, context._2.getFrameworkConfig.getContext)
+  }
+
   private def prepareContext(typeInfo: TypeInformation[Any]): (RelBuilder, TableEnvironment) = {
     // create DataSetTable
     val dataSetMock = mock(classOf[DataSet[Any]])
@@ -140,10 +151,20 @@ abstract class ExpressionTestBase {
     val validated = planner.validate(parsed)
     val converted = planner.rel(validated).rel
 
-    // create DataSetCalc
     val decorPlan = RelDecorrelator.decorrelateQuery(converted)
+
+    // normalize
+    val normalizedPlan = if (FlinkRuleSets.DATASET_NORM_RULES.iterator().hasNext) {
+      val planner = hepPlanner
+      planner.setRoot(decorPlan)
+      planner.findBestExp
+    } else {
+      decorPlan
+    }
+
+    // create DataSetCalc
     val flinkOutputProps = converted.getTraitSet.replace(DataSetConvention.INSTANCE).simplify()
-    val dataSetCalc = optProgram.run(context._2.getPlanner, decorPlan, flinkOutputProps)
+    val dataSetCalc = optProgram.run(context._2.getPlanner, normalizedPlan, flinkOutputProps)
 
     // extract RexNode
     val calcProgram = dataSetCalc

http://git-wip-us.apache.org/repos/asf/flink/blob/8efacf58/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/NormalizationRulesTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/NormalizationRulesTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/NormalizationRulesTest.scala
new file mode 100644
index 0000000..8b6d6cf
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/NormalizationRulesTest.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.plan.rules
+
+import org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule
+import org.apache.calcite.tools.RuleSets
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.calcite.{CalciteConfig, CalciteConfigBuilder}
+import org.apache.flink.table.utils.TableTestBase
+import org.apache.flink.table.utils.TableTestUtil._
+import org.junit.Test
+
+class NormalizationRulesTest extends TableTestBase {
+
+  @Test
+  def testApplyNormalizationRuleForForBatchSQL(): Unit = {
+    val util = batchTestUtil()
+
+    // rewrite distinct aggregate
+    val cc: CalciteConfig = new CalciteConfigBuilder()
+      .replaceNormRuleSet(RuleSets.ofList(AggregateExpandDistinctAggregatesRule.JOIN))
+      .replaceOptRuleSet(RuleSets.ofList())
+      .build()
+    util.tEnv.getConfig.setCalciteConfig(cc)
+
+    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+
+    val sqlQuery = "SELECT " +
+      "COUNT(DISTINCT a)" +
+      "FROM MyTable group by b"
+
+    // expect double aggregate
+    val expected = unaryNode("LogicalProject",
+      unaryNode("LogicalAggregate",
+        unaryNode("LogicalAggregate",
+          unaryNode("LogicalProject",
+            values("LogicalTableScan", term("table", "[MyTable]")),
+            term("b", "$1"), term("a", "$0")),
+          term("group", "{0, 1}")),
+        term("group", "{0}"), term("EXPR$0", "COUNT($1)")
+      ),
+      term("EXPR$0", "$1")
+    )
+
+    util.verifySql(sqlQuery, expected)
+  }
+
+  @Test
+  def testApplyNormalizationRuleForForStreamSQL(): Unit = {
+    val util = streamTestUtil()
+
+    // rewrite distinct aggregate
+    val cc: CalciteConfig = new CalciteConfigBuilder()
+      .replaceNormRuleSet(RuleSets.ofList(AggregateExpandDistinctAggregatesRule.JOIN))
+      .replaceOptRuleSet(RuleSets.ofList())
+      .build()
+    util.tEnv.getConfig.setCalciteConfig(cc)
+
+    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+
+    val sqlQuery = "SELECT " +
+      "COUNT(DISTINCT a)" +
+      "FROM MyTable group by b"
+
+    // expect double aggregate
+    val expected = unaryNode(
+      "LogicalProject",
+      unaryNode("LogicalAggregate",
+        unaryNode("LogicalAggregate",
+          unaryNode("LogicalProject",
+            values("LogicalTableScan", term("table", "[MyTable]")),
+            term("b", "$1"), term("a", "$0")),
+          term("group", "{0, 1}")),
+        term("group", "{0}"), term("EXPR$0", "COUNT($1)")
+      ),
+      term("EXPR$0", "$1")
+    )
+
+    util.verifySql(sqlQuery, expected)
+  }
+
+}