You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by dw...@apache.org on 2019/07/30 13:20:05 UTC

[flink] branch release-1.9 updated: [FLINK-12249][table] Fix type equivalence check problems for Window Aggregates

This is an automated email from the ASF dual-hosted git repository.

dwysakowicz pushed a commit to branch release-1.9
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.9 by this push:
     new 1de0053  [FLINK-12249][table] Fix type equivalence check problems for Window Aggregates
1de0053 is described below

commit 1de005392022404adbce4cd8b9de90f157052bd0
Author: hequn8128 <ch...@gmail.com>
AuthorDate: Fri Jul 19 16:53:17 2019 +0800

    [FLINK-12249][table] Fix type equivalence check problems for Window Aggregates
    
    This closes #9141
---
 .../flink/sql/tests/StreamSQLTestProgram.java      |   4 +-
 .../logical/LogicalWindowAggregateRuleBase.scala   | 100 ++++++++++++++++---
 .../plan/batch/sql/agg/WindowAggregateTest.xml     | 107 +++++++++++++++++++++
 .../plan/stream/sql/agg/WindowAggregateTest.xml    |  35 +++++++
 .../plan/batch/sql/agg/WindowAggregateTest.scala   |  22 +++++
 .../plan/stream/sql/agg/WindowAggregateTest.scala  |  21 ++++
 .../rules/common/LogicalWindowAggregateRule.scala  |  96 ++++++++++++++++--
 .../table/api/batch/sql/GroupWindowTest.scala      |  41 ++++++++
 .../table/api/stream/sql/GroupWindowTest.scala     |  43 +++++++++
 9 files changed, 448 insertions(+), 21 deletions(-)

diff --git a/flink-end-to-end-tests/flink-stream-sql-test/src/main/java/org/apache/flink/sql/tests/StreamSQLTestProgram.java b/flink-end-to-end-tests/flink-stream-sql-test/src/main/java/org/apache/flink/sql/tests/StreamSQLTestProgram.java
index cde040d..47bca8e 100644
--- a/flink-end-to-end-tests/flink-stream-sql-test/src/main/java/org/apache/flink/sql/tests/StreamSQLTestProgram.java
+++ b/flink-end-to-end-tests/flink-stream-sql-test/src/main/java/org/apache/flink/sql/tests/StreamSQLTestProgram.java
@@ -106,9 +106,7 @@ public class StreamSQLTestProgram {
 		String tumbleQuery = String.format(
 			"SELECT " +
 			"  key, " +
-			//TODO: The "WHEN -1 THEN NULL" part is a temporary workaround, to make the test pass, for
-			// https://issues.apache.org/jira/browse/FLINK-12249. We should remove it once the issue is fixed.
-			"  CASE SUM(cnt) / COUNT(*) WHEN 101 THEN 1 WHEN -1 THEN NULL ELSE 99 END AS correct, " +
+			"  CASE SUM(cnt) / COUNT(*) WHEN 101 THEN 1 ELSE 99 END AS correct, " +
 			"  TUMBLE_START(rowtime, INTERVAL '%d' SECOND) AS wStart, " +
 			"  TUMBLE_ROWTIME(rowtime, INTERVAL '%d' SECOND) AS rowtime " +
 			"FROM (%s) " +
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalWindowAggregateRuleBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalWindowAggregateRuleBase.scala
index 6c24296..ee24adb 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalWindowAggregateRuleBase.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalWindowAggregateRuleBase.scala
@@ -33,8 +33,10 @@ import org.apache.calcite.plan._
 import org.apache.calcite.plan.hep.HepRelVertex
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.core.Aggregate.Group
+import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
 import org.apache.calcite.rel.logical.{LogicalAggregate, LogicalProject}
 import org.apache.calcite.rex._
+import org.apache.calcite.sql.`type`.SqlTypeUtil
 import org.apache.calcite.util.ImmutableBitSet
 
 import _root_.java.math.BigDecimal
@@ -84,15 +86,31 @@ abstract class LogicalWindowAggregateRuleBase(description: String)
       .project(project.getChildExps.updated(windowExprIdx, inAggGroupExpression))
       .build()
 
+    // This rule removes the window from the GROUP BY keys, which may change the inferred types
+    // of the AggCalls and make the type equivalence checks fail.
+    // To solve the problem, we change the types to the inferred types in the Aggregate and then
+    // cast back to the original types in the project after the Aggregate.
+    val indexAndTypes = getIndexAndInferredTypesIfChanged(agg)
+    val finalCalls = adjustTypes(agg, indexAndTypes)
+
     // we don't use the builder here because it uses RelMetadataQuery which affects the plan
     val newAgg = LogicalAggregate.create(
       newProject,
       agg.indicator,
       newGroupSet,
       ImmutableList.of(newGroupSet),
-      agg.getAggCallList)
+      finalCalls)
+
+    val transformed = call.builder()
+    val windowAgg = LogicalWindowAggregate.create(
+      window,
+      Seq[PlannerNamedWindowProperty](),
+      newAgg)
+    transformed.push(windowAgg)
 
-    // create an additional project to conform with types
+    // The transformation adds an additional LogicalProject at the top to ensure
+    // that the types are equivalent.
+    // 1. ensure group key types, create an additional project to conform with types
     val outAggGroupExpression0 = getOutAggregateGroupExpression(rexBuilder, windowExpr)
     // fix up the nullability if it is changed.
     val outAggGroupExpression = if (windowExpr.getType.isNullable !=
@@ -103,20 +121,80 @@ abstract class LogicalWindowAggregateRuleBase(description: String)
     } else {
       outAggGroupExpression0
     }
-    val transformed = call.builder()
-    val windowAgg = LogicalWindowAggregate.create(
-      window,
-      Seq[PlannerNamedWindowProperty](),
-      newAgg)
-    // The transformation adds an additional LogicalProject at the top to ensure
-    // that the types are equivalent.
-    transformed.push(windowAgg)
-      .project(transformed.fields().patch(windowExprIdx, Seq(outAggGroupExpression), 0))
+    val projectsEnsureGroupKeyTypes =
+      transformed.fields.patch(windowExprIdx, Seq(outAggGroupExpression), 0)
+    // 2. ensure aggCall types
+    val projectsEnsureAggCallTypes =
+      projectsEnsureGroupKeyTypes.zipWithIndex.map {
+        case (aggCall, index) =>
+          val aggCallIndex = index - agg.getGroupCount
+          if (indexAndTypes.containsKey(aggCallIndex)) {
+            rexBuilder.makeCast(agg.getAggCallList.get(aggCallIndex).`type`, aggCall, true)
+          } else {
+            aggCall
+          }
+      }
+    transformed.project(projectsEnsureAggCallTypes)
 
     val result = transformed.build()
     call.transformTo(result)
   }
 
+  /**
+   * Recreates every [[AggregateCall]] listed in `indexAndTypes` with its inferred type.
+   */
+  private def adjustTypes(
+      agg: LogicalAggregate,
+      indexAndTypes: Map[Int, RelDataType]) = {
+    // Calls whose index is not a key of indexAndTypes are passed through untouched.
+    agg.getAggCallList.zipWithIndex.map {
+      case (aggCall, index) =>
+        if (indexAndTypes.contains(index)) {
+          AggregateCall.create(
+            aggCall.getAggregation,
+            aggCall.isDistinct,
+            aggCall.isApproximate,
+            aggCall.ignoreNulls(),
+            aggCall.getArgList,
+            aggCall.filterArg,
+            aggCall.collation,
+            agg.getGroupCount,
+            agg.getInput,
+            indexAndTypes(index),
+            aggCall.name)
+        } else {
+          aggCall
+        }
+    }
+  }
+
+  /**
+   * Finds the [[AggregateCall]]s whose declared type differs from the type inferred without
+   * group keys. Returns a map from the call's index to that inferred type (empty if none).
+   */
+  private def getIndexAndInferredTypesIfChanged(
+      agg: LogicalAggregate)
+    : Map[Int, RelDataType] = {
+    // The binding below is created with a group count of 0 (fourth constructor argument).
+    agg.getAggCallList.zipWithIndex.flatMap {
+      case (aggCall, index) =>
+        val origType = aggCall.`type`
+        val aggCallBinding = new Aggregate.AggCallBinding(
+          agg.getCluster.getTypeFactory,
+          aggCall.getAggregation,
+          SqlTypeUtil.projectTypes(agg.getInput.getRowType, aggCall.getArgList),
+          0,
+          aggCall.hasFilter)
+        val inferredType = aggCall.getAggregation.inferReturnType(aggCallBinding)
+        // Types are only reported when the aggregate has exactly one group key.
+        if (origType != inferredType && agg.getGroupCount == 1) {
+          Some(index -> inferredType)
+        } else {
+          None
+        }
+    }.toMap
+  }
+
   private[table] def getWindowExpressions(agg: LogicalAggregate): Seq[(RexCall, Int)] = {
     val project = agg.getInput.asInstanceOf[HepRelVertex].getCurrentRel.asInstanceOf[LogicalProject]
     val groupKeys = agg.getGroupSet
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.xml
index 88ac82a..6286349 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.xml
@@ -1409,4 +1409,111 @@ Calc(select=[w$end AS EXPR$0])
 ]]>
     </Resource>
   </TestCase>
+    <TestCase name="testReturnTypeInferenceForWindowAgg[aggStrategy=AUTO]">
+        <Resource name="sql">
+            <![CDATA[
+SELECT
+  SUM(correct) AS s,
+  AVG(correct) AS a,
+  TUMBLE_START(b, INTERVAL '15' MINUTE) AS wStart
+FROM (
+  SELECT CASE a
+      WHEN 1 THEN 1
+      ELSE 99
+    END AS correct, b
+  FROM MyTable
+)
+GROUP BY TUMBLE(b, INTERVAL '15' MINUTE)
+      ]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(s=[$1], a=[$2], wStart=[TUMBLE_START($0)])
++- LogicalAggregate(group=[{0}], s=[SUM($1)], a=[AVG($1)])
+   +- LogicalProject($f0=[TUMBLE($1, 900000:INTERVAL MINUTE)], correct=[CASE(=($0, 1), 1, 99)])
+      +- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+Calc(select=[CAST(CASE(=($f1, 0), null:INTEGER, s)) AS s, CAST(/(CAST(CASE(=($f1, 0), null:INTEGER, s)), $f1)) AS a, w$start AS wStart])
++- HashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Final_$SUM0(sum$0) AS s, Final_COUNT(count1$1) AS $f1])
+   +- Exchange(distribution=[single])
+      +- LocalHashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Partial_$SUM0($f1) AS sum$0, Partial_COUNT(*) AS count1$1])
+         +- Calc(select=[b, CASE(=(a, 1), 1, 99) AS $f1])
+            +- TableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testReturnTypeInferenceForWindowAgg[aggStrategy=ONE_PHASE]">
+        <Resource name="sql">
+            <![CDATA[
+SELECT
+  SUM(correct) AS s,
+  AVG(correct) AS a,
+  TUMBLE_START(b, INTERVAL '15' MINUTE) AS wStart
+FROM (
+  SELECT CASE a
+      WHEN 1 THEN 1
+      ELSE 99
+    END AS correct, b
+  FROM MyTable
+)
+GROUP BY TUMBLE(b, INTERVAL '15' MINUTE)
+      ]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(s=[$1], a=[$2], wStart=[TUMBLE_START($0)])
++- LogicalAggregate(group=[{0}], s=[SUM($1)], a=[AVG($1)])
+   +- LogicalProject($f0=[TUMBLE($1, 900000:INTERVAL MINUTE)], correct=[CASE(=($0, 1), 1, 99)])
+      +- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+Calc(select=[CAST(CASE(=($f1, 0), null:INTEGER, s)) AS s, CAST(/(CAST(CASE(=($f1, 0), null:INTEGER, s)), $f1)) AS a, w$start AS wStart])
++- HashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[$SUM0($f1) AS s, COUNT(*) AS $f1])
+   +- Exchange(distribution=[single])
+      +- Calc(select=[b, CASE(=(a, 1), 1, 99) AS $f1])
+         +- TableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testReturnTypeInferenceForWindowAgg[aggStrategy=TWO_PHASE]">
+        <Resource name="sql">
+            <![CDATA[
+SELECT
+  SUM(correct) AS s,
+  AVG(correct) AS a,
+  TUMBLE_START(b, INTERVAL '15' MINUTE) AS wStart
+FROM (
+  SELECT CASE a
+      WHEN 1 THEN 1
+      ELSE 99
+    END AS correct, b
+  FROM MyTable
+)
+GROUP BY TUMBLE(b, INTERVAL '15' MINUTE)
+      ]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(s=[$1], a=[$2], wStart=[TUMBLE_START($0)])
++- LogicalAggregate(group=[{0}], s=[SUM($1)], a=[AVG($1)])
+   +- LogicalProject($f0=[TUMBLE($1, 900000:INTERVAL MINUTE)], correct=[CASE(=($0, 1), 1, 99)])
+      +- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+Calc(select=[CAST(CASE(=($f1, 0), null:INTEGER, s)) AS s, CAST(/(CAST(CASE(=($f1, 0), null:INTEGER, s)), $f1)) AS a, w$start AS wStart])
++- HashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Final_$SUM0(sum$0) AS s, Final_COUNT(count1$1) AS $f1])
+   +- Exchange(distribution=[single])
+      +- LocalHashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Partial_$SUM0($f1) AS sum$0, Partial_COUNT(*) AS count1$1])
+         +- Calc(select=[b, CASE(=(a, 1), 1, 99) AS $f1])
+            +- TableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
+]]>
+        </Resource>
+    </TestCase>
 </Root>
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.xml
index 5847055..9f10574e 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.xml
@@ -467,4 +467,39 @@ Calc(select=[EXPR$0, wAvg, w$start AS EXPR$2, w$end AS EXPR$3])
 ]]>
     </Resource>
   </TestCase>
+    <TestCase name="testReturnTypeInferenceForWindowAgg">
+        <Resource name="sql">
+            <![CDATA[
+SELECT
+  SUM(correct) AS s,
+  AVG(correct) AS a,
+  TUMBLE_START(rowtime, INTERVAL '15' MINUTE) AS wStart
+FROM (
+  SELECT CASE a
+      WHEN 1 THEN 1
+      ELSE 99
+    END AS correct, rowtime
+  FROM MyTable
+)
+GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE)
+      ]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(s=[$1], a=[$2], wStart=[TUMBLE_START($0)])
++- LogicalAggregate(group=[{0}], s=[SUM($1)], a=[AVG($1)])
+   +- LogicalProject($f0=[TUMBLE($4, 900000:INTERVAL MINUTE)], correct=[CASE(=($0, 1), 1, 99)])
+      +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+Calc(select=[CAST(CASE(=($f1, 0), null:INTEGER, s)) AS s, CAST(/(CAST(CASE(=($f1, 0), null:INTEGER, s)), $f1)) AS a, w$start AS wStart])
++- GroupWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime, w$proctime], select=[$SUM0($f1) AS s, COUNT(*) AS $f1, start('w$) AS w$start, end('w$) AS w$end, rowtime('w$) AS w$rowtime, proctime('w$) AS w$proctime])
+   +- Exchange(distribution=[single])
+      +- Calc(select=[rowtime, CASE(=(a, 1), 1, 99) AS $f1])
+         +- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime])
+]]>
+        </Resource>
+    </TestCase>
 </Root>
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.scala
index be1ad8b..0b021b7 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.scala
@@ -300,6 +300,28 @@ class WindowAggregateTest(aggStrategy: AggregatePhaseStrategy) extends TableTest
       """.stripMargin
     util.verifyPlan(sql)
   }
+
+  @Test
+  def testReturnTypeInferenceForWindowAgg(): Unit = {
+    // FLINK-12249: SUM/AVG over a CASE expression, grouped only by a TUMBLE window.
+    val sql =
+      """
+        |SELECT
+        |  SUM(correct) AS s,
+        |  AVG(correct) AS a,
+        |  TUMBLE_START(b, INTERVAL '15' MINUTE) AS wStart
+        |FROM (
+        |  SELECT CASE a
+        |      WHEN 1 THEN 1
+        |      ELSE 99
+        |    END AS correct, b
+        |  FROM MyTable
+        |)
+        |GROUP BY TUMBLE(b, INTERVAL '15' MINUTE)
+      """.stripMargin
+    // Expected plans are presumably those stored under this test name in WindowAggregateTest.xml.
+    util.verifyPlan(sql)
+  }
 }
 
 object WindowAggregateTest {
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.scala
index 414450e..3c773ca 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.scala
@@ -295,4 +295,25 @@ class WindowAggregateTest extends TableTestBase {
     util.verifyPlan(sql)
   }
 
+  @Test
+  def testReturnTypeInferenceForWindowAgg(): Unit = {
+    // FLINK-12249: SUM/AVG over a CASE expression, grouped only by a TUMBLE window on rowtime.
+    val sql =
+      """
+        |SELECT
+        |  SUM(correct) AS s,
+        |  AVG(correct) AS a,
+        |  TUMBLE_START(rowtime, INTERVAL '15' MINUTE) AS wStart
+        |FROM (
+        |  SELECT CASE a
+        |      WHEN 1 THEN 1
+        |      ELSE 99
+        |    END AS correct, rowtime
+        |  FROM MyTable
+        |)
+        |GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE)
+      """.stripMargin
+    // Expected plans are presumably those stored under this test name in WindowAggregateTest.xml.
+    util.verifyPlan(sql)
+  }
 }
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/LogicalWindowAggregateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/LogicalWindowAggregateRule.scala
index 8c0f0c0..431fe9e 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/LogicalWindowAggregateRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/LogicalWindowAggregateRule.scala
@@ -21,8 +21,10 @@ import com.google.common.collect.ImmutableList
 import org.apache.calcite.plan._
 import org.apache.calcite.plan.hep.HepRelVertex
 import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
 import org.apache.calcite.rel.logical.{LogicalAggregate, LogicalProject}
 import org.apache.calcite.rex._
+import org.apache.calcite.sql.`type`.SqlTypeUtil
 import org.apache.calcite.util.ImmutableBitSet
 import org.apache.flink.table.api._
 import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
@@ -78,24 +80,104 @@ abstract class LogicalWindowAggregateRule(ruleName: String)
       .project(project.getChildExps.updated(windowExprIdx, inAggGroupExpression))
       .build()
 
+    // This rule removes the window from the GROUP BY keys, which may change the inferred types
+    // of the AggCalls and make the type equivalence checks fail.
+    // To solve the problem, we change the types to the inferred types in the Aggregate and then
+    // cast back to the original types in the project after the Aggregate.
+    val indexAndTypes = getIndexAndInferredTypesIfChanged(agg)
+    val finalCalls = adjustTypes(agg, indexAndTypes)
+
     // we don't use the builder here because it uses RelMetadataQuery which affects the plan
     val newAgg = LogicalAggregate.create(
       newProject,
       agg.indicator,
       newGroupSet,
       ImmutableList.of(newGroupSet),
-      agg.getAggCallList)
+      finalCalls)
 
-    // create an additional project to conform with types
-    val outAggGroupExpression = getOutAggregateGroupExpression(rexBuilder, windowExpr)
     val transformed = call.builder()
-    transformed.push(LogicalWindowAggregate.create(
+    val windowAgg = LogicalWindowAggregate.create(
       window,
       Seq[NamedWindowProperty](),
-      newAgg))
-      .project(transformed.fields().patch(windowExprIdx, Seq(outAggGroupExpression), 0))
+      newAgg)
+    transformed.push(windowAgg)
 
-    call.transformTo(transformed.build())
+    // The transformation adds an additional LogicalProject at the top to ensure
+    // that the types are equivalent.
+    // 1. ensure group key types, create an additional project to conform with types
+    val outAggGroupExpression = getOutAggregateGroupExpression(rexBuilder, windowExpr)
+    val projectsEnsureGroupKeyTypes =
+    transformed.fields.patch(windowExprIdx, Seq(outAggGroupExpression), 0)
+    // 2. ensure aggCall types
+    val projectsEnsureAggCallTypes =
+      projectsEnsureGroupKeyTypes.zipWithIndex.map {
+        case (aggCall, index) =>
+          val aggCallIndex = index - agg.getGroupCount
+          if (indexAndTypes.containsKey(aggCallIndex)) {
+            rexBuilder.makeCast(agg.getAggCallList.get(aggCallIndex).`type`, aggCall, true)
+          } else {
+            aggCall
+          }
+      }
+    transformed.project(projectsEnsureAggCallTypes)
+
+    val result = transformed.build()
+    call.transformTo(result)
+  }
+
+  /**
+   * Recreates every [[AggregateCall]] listed in `indexAndTypes` with its inferred type.
+   */
+  private def adjustTypes(
+      agg: LogicalAggregate,
+      indexAndTypes: Map[Int, RelDataType]) = {
+    // Calls whose index is not a key of indexAndTypes are passed through untouched.
+    agg.getAggCallList.zipWithIndex.map {
+      case (aggCall, index) =>
+        if (indexAndTypes.contains(index)) {
+          AggregateCall.create(
+            aggCall.getAggregation,
+            aggCall.isDistinct,
+            aggCall.isApproximate,
+            aggCall.ignoreNulls(),
+            aggCall.getArgList,
+            aggCall.filterArg,
+            aggCall.collation,
+            agg.getGroupCount,
+            agg.getInput,
+            indexAndTypes(index),
+            aggCall.name)
+        } else {
+          aggCall
+        }
+    }
+  }
+
+  /**
+   * Finds the [[AggregateCall]]s whose declared type differs from the type inferred without
+   * group keys. Returns a map from the call's index to that inferred type (empty if none).
+   */
+  private def getIndexAndInferredTypesIfChanged(
+      agg: LogicalAggregate)
+    : Map[Int, RelDataType] = {
+    // The binding below is created with a group count of 0 (fourth constructor argument).
+    agg.getAggCallList.zipWithIndex.flatMap {
+      case (aggCall, index) =>
+        val origType = aggCall.`type`
+        val aggCallBinding = new Aggregate.AggCallBinding(
+          agg.getCluster.getTypeFactory,
+          aggCall.getAggregation,
+          SqlTypeUtil.projectTypes(agg.getInput.getRowType, aggCall.getArgList),
+          0,
+          aggCall.hasFilter)
+        val inferredType = aggCall.getAggregation.inferReturnType(aggCallBinding)
+        // Types are only reported when the aggregate has exactly one group key.
+        if (origType != inferredType && agg.getGroupCount == 1) {
+          Some(index -> inferredType)
+        } else {
+          None
+        }
+    }.toMap
   }
 
   private[table] def getWindowExpressions(agg: LogicalAggregate): Seq[(RexCall, Int)] = {
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala
index 07c4067..b5091ee 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala
@@ -346,4 +346,45 @@ class GroupWindowTest extends TableTestBase {
 
     util.verifySql(sql, expected)
   }
+
+  @Test
+  def testReturnTypeInferenceForWindowAgg(): Unit = {
+    val util = batchTestUtil()
+    val table = util.addTable[(Int, Long, String, Timestamp)]("MyTable", 'a, 'b, 'c, 'rowtime)
+    // FLINK-12249: the inner query yields a CASE expression over literals plus rowtime.
+    val innerQuery =
+      """
+        |SELECT
+        | CASE a WHEN 1 THEN 1 ELSE 99 END AS correct,
+        | rowtime
+        |FROM MyTable
+      """.stripMargin
+    // The outer query aggregates that column with the TUMBLE window as the only group key.
+    val sqlQuery =
+      "SELECT " +
+        "  sum(correct) as s, " +
+        "  avg(correct) as a, " +
+        "  TUMBLE_START(rowtime, INTERVAL '15' MINUTE) as wStart " +
+        s"FROM ($innerQuery) " +
+        "GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE)"
+    // Expected plan: Calc -> window aggregate -> Calc that casts s, a and w$start.
+    val expected =
+      unaryNode(
+        "DataSetCalc",
+        unaryNode(
+          "DataSetWindowAggregate",
+          unaryNode(
+            "DataSetCalc",
+            batchTableNode(table),
+            term("select", "CASE(=(a, 1), 1, 99) AS correct, rowtime")
+          ),
+          term("window", "TumblingGroupWindow('w$, 'rowtime, 900000.millis)"),
+          term("select", "SUM(correct) AS s, AVG(correct) AS a, start('w$) AS w$start," +
+            " end('w$) AS w$end, rowtime('w$) AS w$rowtime")
+        ),
+        term("select", "CAST(s) AS s", "CAST(a) AS a", "CAST(w$start) AS wStart")
+      )
+
+    util.verifySql(sqlQuery, expected)
+  }
 }
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala
index a8c456f..5acef08 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala
@@ -301,4 +301,47 @@ class GroupWindowTest extends TableTestBase {
       )
     streamUtil.verifySql(sql, expected)
   }
+
+  @Test
+  def testReturnTypeInferenceForWindowAgg(): Unit = {
+    // FLINK-12249: SUM/AVG over a CASE expression, grouped only by a TUMBLE window on rowtime.
+    val innerQuery =
+      """
+        |SELECT
+        | CASE a WHEN 1 THEN 1 ELSE 99 END AS correct,
+        | rowtime
+        |FROM MyTable
+      """.stripMargin
+
+    val sql =
+      "SELECT " +
+        "  sum(correct) as s, " +
+        "  avg(correct) as a, " +
+        "  TUMBLE_START(rowtime, INTERVAL '15' MINUTE) as wStart " +
+        s"FROM ($innerQuery) " +
+        "GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE)"
+    // Expected plan: Calc -> group window aggregate -> Calc that casts s and a back.
+    val expected =
+      unaryNode(
+        "DataStreamCalc",
+        unaryNode(
+          "DataStreamGroupWindowAggregate",
+          unaryNode(
+            "DataStreamCalc",
+            streamTableNode(table),
+            term("select", "CASE(=(a, 1), 1, 99) AS correct", "rowtime")
+          ),
+          term("window", "TumblingGroupWindow('w$, 'rowtime, 900000.millis)"),
+          term("select",
+            "SUM(correct) AS s",
+            "AVG(correct) AS a",
+            "start('w$) AS w$start",
+            "end('w$) AS w$end",
+            "rowtime('w$) AS w$rowtime",
+            "proctime('w$) AS w$proctime")
+        ),
+        term("select", "CAST(s) AS s", "CAST(a) AS a", "w$start AS wStart")
+      )
+    streamUtil.verifySql(sql, expected)
+  }
 }