You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by dw...@apache.org on 2019/07/30 13:20:05 UTC
[flink] branch release-1.9 updated: [FLINK-12249][table] Fix type
equivalence check problems for Window Aggregates
This is an automated email from the ASF dual-hosted git repository.
dwysakowicz pushed a commit to branch release-1.9
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.9 by this push:
new 1de0053 [FLINK-12249][table] Fix type equivalence check problems for Window Aggregates
1de0053 is described below
commit 1de005392022404adbce4cd8b9de90f157052bd0
Author: hequn8128 <ch...@gmail.com>
AuthorDate: Fri Jul 19 16:53:17 2019 +0800
[FLINK-12249][table] Fix type equivalence check problems for Window Aggregates
This closes #9141
---
.../flink/sql/tests/StreamSQLTestProgram.java | 4 +-
.../logical/LogicalWindowAggregateRuleBase.scala | 100 ++++++++++++++++---
.../plan/batch/sql/agg/WindowAggregateTest.xml | 107 +++++++++++++++++++++
.../plan/stream/sql/agg/WindowAggregateTest.xml | 35 +++++++
.../plan/batch/sql/agg/WindowAggregateTest.scala | 22 +++++
.../plan/stream/sql/agg/WindowAggregateTest.scala | 21 ++++
.../rules/common/LogicalWindowAggregateRule.scala | 96 ++++++++++++++++--
.../table/api/batch/sql/GroupWindowTest.scala | 41 ++++++++
.../table/api/stream/sql/GroupWindowTest.scala | 43 +++++++++
9 files changed, 448 insertions(+), 21 deletions(-)
diff --git a/flink-end-to-end-tests/flink-stream-sql-test/src/main/java/org/apache/flink/sql/tests/StreamSQLTestProgram.java b/flink-end-to-end-tests/flink-stream-sql-test/src/main/java/org/apache/flink/sql/tests/StreamSQLTestProgram.java
index cde040d..47bca8e 100644
--- a/flink-end-to-end-tests/flink-stream-sql-test/src/main/java/org/apache/flink/sql/tests/StreamSQLTestProgram.java
+++ b/flink-end-to-end-tests/flink-stream-sql-test/src/main/java/org/apache/flink/sql/tests/StreamSQLTestProgram.java
@@ -106,9 +106,7 @@ public class StreamSQLTestProgram {
String tumbleQuery = String.format(
"SELECT " +
" key, " +
- //TODO: The "WHEN -1 THEN NULL" part is a temporary workaround, to make the test pass, for
- // https://issues.apache.org/jira/browse/FLINK-12249. We should remove it once the issue is fixed.
- " CASE SUM(cnt) / COUNT(*) WHEN 101 THEN 1 WHEN -1 THEN NULL ELSE 99 END AS correct, " +
+ " CASE SUM(cnt) / COUNT(*) WHEN 101 THEN 1 ELSE 99 END AS correct, " +
" TUMBLE_START(rowtime, INTERVAL '%d' SECOND) AS wStart, " +
" TUMBLE_ROWTIME(rowtime, INTERVAL '%d' SECOND) AS rowtime " +
"FROM (%s) " +
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalWindowAggregateRuleBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalWindowAggregateRuleBase.scala
index 6c24296..ee24adb 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalWindowAggregateRuleBase.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalWindowAggregateRuleBase.scala
@@ -33,8 +33,10 @@ import org.apache.calcite.plan._
import org.apache.calcite.plan.hep.HepRelVertex
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.Aggregate.Group
+import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
import org.apache.calcite.rel.logical.{LogicalAggregate, LogicalProject}
import org.apache.calcite.rex._
+import org.apache.calcite.sql.`type`.SqlTypeUtil
import org.apache.calcite.util.ImmutableBitSet
import _root_.java.math.BigDecimal
@@ -84,15 +86,31 @@ abstract class LogicalWindowAggregateRuleBase(description: String)
.project(project.getChildExps.updated(windowExprIdx, inAggGroupExpression))
.build()
+ // Currently, this rule removes the window from the GROUP BY operation, which may change
+ // the types of the AggCalls and cause the type equivalence checks to fail.
+ // To solve the problem, we change the types to the inferred types in the Aggregate and then
+ // cast back to the original types in the project after the Aggregate.
+ val indexAndTypes = getIndexAndInferredTypesIfChanged(agg)
+ val finalCalls = adjustTypes(agg, indexAndTypes)
+
// we don't use the builder here because it uses RelMetadataQuery which affects the plan
val newAgg = LogicalAggregate.create(
newProject,
agg.indicator,
newGroupSet,
ImmutableList.of(newGroupSet),
- agg.getAggCallList)
+ finalCalls)
+
+ val transformed = call.builder()
+ val windowAgg = LogicalWindowAggregate.create(
+ window,
+ Seq[PlannerNamedWindowProperty](),
+ newAgg)
+ transformed.push(windowAgg)
- // create an additional project to conform with types
+ // The transformation adds an additional LogicalProject at the top to ensure
+ // that the types are equivalent.
+ // 1. ensure group key types, create an additional project to conform with types
val outAggGroupExpression0 = getOutAggregateGroupExpression(rexBuilder, windowExpr)
// fix up the nullability if it is changed.
val outAggGroupExpression = if (windowExpr.getType.isNullable !=
@@ -103,20 +121,80 @@ abstract class LogicalWindowAggregateRuleBase(description: String)
} else {
outAggGroupExpression0
}
- val transformed = call.builder()
- val windowAgg = LogicalWindowAggregate.create(
- window,
- Seq[PlannerNamedWindowProperty](),
- newAgg)
- // The transformation adds an additional LogicalProject at the top to ensure
- // that the types are equivalent.
- transformed.push(windowAgg)
- .project(transformed.fields().patch(windowExprIdx, Seq(outAggGroupExpression), 0))
+ val projectsEnsureGroupKeyTypes =
+ transformed.fields.patch(windowExprIdx, Seq(outAggGroupExpression), 0)
+ // 2. ensure aggCall types
+ val projectsEnsureAggCallTypes =
+ projectsEnsureGroupKeyTypes.zipWithIndex.map {
+ case (aggCall, index) =>
+ val aggCallIndex = index - agg.getGroupCount
+ if (indexAndTypes.containsKey(aggCallIndex)) {
+ rexBuilder.makeCast(agg.getAggCallList.get(aggCallIndex).`type`, aggCall, true)
+ } else {
+ aggCall
+ }
+ }
+ transformed.project(projectsEnsureAggCallTypes)
val result = transformed.build()
call.transformTo(result)
}
+ /**
+ * Change the types of [[AggregateCall]] to the corresponding inferred types.
+ */
+ private def adjustTypes(
+ agg: LogicalAggregate,
+ indexAndTypes: Map[Int, RelDataType]) = {
+
+ agg.getAggCallList.zipWithIndex.map {
+ case (aggCall, index) =>
+ if (indexAndTypes.containsKey(index)) {
+ AggregateCall.create(
+ aggCall.getAggregation,
+ aggCall.isDistinct,
+ aggCall.isApproximate,
+ aggCall.ignoreNulls(),
+ aggCall.getArgList,
+ aggCall.filterArg,
+ aggCall.collation,
+ agg.getGroupCount,
+ agg.getInput,
+ indexAndTypes(index),
+ aggCall.name)
+ } else {
+ aggCall
+ }
+ }
+ }
+
+ /**
+ * Check if there are any types of [[AggregateCall]] that need to be changed. Return the
+ * [[AggregateCall]] indexes and the corresponding inferred types.
+ */
+ private def getIndexAndInferredTypesIfChanged(
+ agg: LogicalAggregate)
+ : Map[Int, RelDataType] = {
+
+ agg.getAggCallList.zipWithIndex.flatMap {
+ case (aggCall, index) =>
+ val origType = aggCall.`type`
+ val aggCallBinding = new Aggregate.AggCallBinding(
+ agg.getCluster.getTypeFactory,
+ aggCall.getAggregation,
+ SqlTypeUtil.projectTypes(agg.getInput.getRowType, aggCall.getArgList),
+ 0,
+ aggCall.hasFilter)
+ val inferredType = aggCall.getAggregation.inferReturnType(aggCallBinding)
+
+ if (origType != inferredType && agg.getGroupCount == 1) {
+ Some(index, inferredType)
+ } else {
+ None
+ }
+ }.toMap
+ }
+
private[table] def getWindowExpressions(agg: LogicalAggregate): Seq[(RexCall, Int)] = {
val project = agg.getInput.asInstanceOf[HepRelVertex].getCurrentRel.asInstanceOf[LogicalProject]
val groupKeys = agg.getGroupSet
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.xml
index 88ac82a..6286349 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.xml
@@ -1409,4 +1409,111 @@ Calc(select=[w$end AS EXPR$0])
]]>
</Resource>
</TestCase>
+ <TestCase name="testReturnTypeInferenceForWindowAgg[aggStrategy=AUTO]">
+ <Resource name="sql">
+ <![CDATA[
+SELECT
+ SUM(correct) AS s,
+ AVG(correct) AS a,
+ TUMBLE_START(b, INTERVAL '15' MINUTE) AS wStart
+FROM (
+ SELECT CASE a
+ WHEN 1 THEN 1
+ ELSE 99
+ END AS correct, b
+ FROM MyTable
+)
+GROUP BY TUMBLE(b, INTERVAL '15' MINUTE)
+ ]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(s=[$1], a=[$2], wStart=[TUMBLE_START($0)])
++- LogicalAggregate(group=[{0}], s=[SUM($1)], a=[AVG($1)])
+ +- LogicalProject($f0=[TUMBLE($1, 900000:INTERVAL MINUTE)], correct=[CASE(=($0, 1), 1, 99)])
+ +- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[CAST(CASE(=($f1, 0), null:INTEGER, s)) AS s, CAST(/(CAST(CASE(=($f1, 0), null:INTEGER, s)), $f1)) AS a, w$start AS wStart])
++- HashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Final_$SUM0(sum$0) AS s, Final_COUNT(count1$1) AS $f1])
+ +- Exchange(distribution=[single])
+ +- LocalHashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Partial_$SUM0($f1) AS sum$0, Partial_COUNT(*) AS count1$1])
+ +- Calc(select=[b, CASE(=(a, 1), 1, 99) AS $f1])
+ +- TableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testReturnTypeInferenceForWindowAgg[aggStrategy=ONE_PHASE]">
+ <Resource name="sql">
+ <![CDATA[
+SELECT
+ SUM(correct) AS s,
+ AVG(correct) AS a,
+ TUMBLE_START(b, INTERVAL '15' MINUTE) AS wStart
+FROM (
+ SELECT CASE a
+ WHEN 1 THEN 1
+ ELSE 99
+ END AS correct, b
+ FROM MyTable
+)
+GROUP BY TUMBLE(b, INTERVAL '15' MINUTE)
+ ]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(s=[$1], a=[$2], wStart=[TUMBLE_START($0)])
++- LogicalAggregate(group=[{0}], s=[SUM($1)], a=[AVG($1)])
+ +- LogicalProject($f0=[TUMBLE($1, 900000:INTERVAL MINUTE)], correct=[CASE(=($0, 1), 1, 99)])
+ +- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[CAST(CASE(=($f1, 0), null:INTEGER, s)) AS s, CAST(/(CAST(CASE(=($f1, 0), null:INTEGER, s)), $f1)) AS a, w$start AS wStart])
++- HashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[$SUM0($f1) AS s, COUNT(*) AS $f1])
+ +- Exchange(distribution=[single])
+ +- Calc(select=[b, CASE(=(a, 1), 1, 99) AS $f1])
+ +- TableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testReturnTypeInferenceForWindowAgg[aggStrategy=TWO_PHASE]">
+ <Resource name="sql">
+ <![CDATA[
+SELECT
+ SUM(correct) AS s,
+ AVG(correct) AS a,
+ TUMBLE_START(b, INTERVAL '15' MINUTE) AS wStart
+FROM (
+ SELECT CASE a
+ WHEN 1 THEN 1
+ ELSE 99
+ END AS correct, b
+ FROM MyTable
+)
+GROUP BY TUMBLE(b, INTERVAL '15' MINUTE)
+ ]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(s=[$1], a=[$2], wStart=[TUMBLE_START($0)])
++- LogicalAggregate(group=[{0}], s=[SUM($1)], a=[AVG($1)])
+ +- LogicalProject($f0=[TUMBLE($1, 900000:INTERVAL MINUTE)], correct=[CASE(=($0, 1), 1, 99)])
+ +- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[CAST(CASE(=($f1, 0), null:INTEGER, s)) AS s, CAST(/(CAST(CASE(=($f1, 0), null:INTEGER, s)), $f1)) AS a, w$start AS wStart])
++- HashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Final_$SUM0(sum$0) AS s, Final_COUNT(count1$1) AS $f1])
+ +- Exchange(distribution=[single])
+ +- LocalHashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Partial_$SUM0($f1) AS sum$0, Partial_COUNT(*) AS count1$1])
+ +- Calc(select=[b, CASE(=(a, 1), 1, 99) AS $f1])
+ +- TableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d])
+]]>
+ </Resource>
+ </TestCase>
</Root>
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.xml
index 5847055..9f10574e 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.xml
@@ -467,4 +467,39 @@ Calc(select=[EXPR$0, wAvg, w$start AS EXPR$2, w$end AS EXPR$3])
]]>
</Resource>
</TestCase>
+ <TestCase name="testReturnTypeInferenceForWindowAgg">
+ <Resource name="sql">
+ <![CDATA[
+SELECT
+ SUM(correct) AS s,
+ AVG(correct) AS a,
+ TUMBLE_START(rowtime, INTERVAL '15' MINUTE) AS wStart
+FROM (
+ SELECT CASE a
+ WHEN 1 THEN 1
+ ELSE 99
+ END AS correct, rowtime
+ FROM MyTable
+)
+GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE)
+ ]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(s=[$1], a=[$2], wStart=[TUMBLE_START($0)])
++- LogicalAggregate(group=[{0}], s=[SUM($1)], a=[AVG($1)])
+ +- LogicalProject($f0=[TUMBLE($4, 900000:INTERVAL MINUTE)], correct=[CASE(=($0, 1), 1, 99)])
+ +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[CAST(CASE(=($f1, 0), null:INTEGER, s)) AS s, CAST(/(CAST(CASE(=($f1, 0), null:INTEGER, s)), $f1)) AS a, w$start AS wStart])
++- GroupWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime, w$proctime], select=[$SUM0($f1) AS s, COUNT(*) AS $f1, start('w$) AS w$start, end('w$) AS w$end, rowtime('w$) AS w$rowtime, proctime('w$) AS w$proctime])
+ +- Exchange(distribution=[single])
+ +- Calc(select=[rowtime, CASE(=(a, 1), 1, 99) AS $f1])
+ +- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime])
+]]>
+ </Resource>
+ </TestCase>
</Root>
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.scala
index be1ad8b..0b021b7 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.scala
@@ -300,6 +300,28 @@ class WindowAggregateTest(aggStrategy: AggregatePhaseStrategy) extends TableTest
""".stripMargin
util.verifyPlan(sql)
}
+
+ @Test
+ def testReturnTypeInferenceForWindowAgg() = {
+
+ val sql =
+ """
+ |SELECT
+ | SUM(correct) AS s,
+ | AVG(correct) AS a,
+ | TUMBLE_START(b, INTERVAL '15' MINUTE) AS wStart
+ |FROM (
+ | SELECT CASE a
+ | WHEN 1 THEN 1
+ | ELSE 99
+ | END AS correct, b
+ | FROM MyTable
+ |)
+ |GROUP BY TUMBLE(b, INTERVAL '15' MINUTE)
+ """.stripMargin
+
+ util.verifyPlan(sql)
+ }
}
object WindowAggregateTest {
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.scala
index 414450e..3c773ca 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.scala
@@ -295,4 +295,25 @@ class WindowAggregateTest extends TableTestBase {
util.verifyPlan(sql)
}
+ @Test
+ def testReturnTypeInferenceForWindowAgg() = {
+
+ val sql =
+ """
+ |SELECT
+ | SUM(correct) AS s,
+ | AVG(correct) AS a,
+ | TUMBLE_START(rowtime, INTERVAL '15' MINUTE) AS wStart
+ |FROM (
+ | SELECT CASE a
+ | WHEN 1 THEN 1
+ | ELSE 99
+ | END AS correct, rowtime
+ | FROM MyTable
+ |)
+ |GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE)
+ """.stripMargin
+
+ util.verifyPlan(sql)
+ }
}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/LogicalWindowAggregateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/LogicalWindowAggregateRule.scala
index 8c0f0c0..431fe9e 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/LogicalWindowAggregateRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/LogicalWindowAggregateRule.scala
@@ -21,8 +21,10 @@ import com.google.common.collect.ImmutableList
import org.apache.calcite.plan._
import org.apache.calcite.plan.hep.HepRelVertex
import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
import org.apache.calcite.rel.logical.{LogicalAggregate, LogicalProject}
import org.apache.calcite.rex._
+import org.apache.calcite.sql.`type`.SqlTypeUtil
import org.apache.calcite.util.ImmutableBitSet
import org.apache.flink.table.api._
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
@@ -78,24 +80,104 @@ abstract class LogicalWindowAggregateRule(ruleName: String)
.project(project.getChildExps.updated(windowExprIdx, inAggGroupExpression))
.build()
+ // Currently, this rule removes the window from the GROUP BY operation, which may change
+ // the types of the AggCalls and cause the type equivalence checks to fail.
+ // To solve the problem, we change the types to the inferred types in the Aggregate and then
+ // cast back to the original types in the project after the Aggregate.
+ val indexAndTypes = getIndexAndInferredTypesIfChanged(agg)
+ val finalCalls = adjustTypes(agg, indexAndTypes)
+
// we don't use the builder here because it uses RelMetadataQuery which affects the plan
val newAgg = LogicalAggregate.create(
newProject,
agg.indicator,
newGroupSet,
ImmutableList.of(newGroupSet),
- agg.getAggCallList)
+ finalCalls)
- // create an additional project to conform with types
- val outAggGroupExpression = getOutAggregateGroupExpression(rexBuilder, windowExpr)
val transformed = call.builder()
- transformed.push(LogicalWindowAggregate.create(
+ val windowAgg = LogicalWindowAggregate.create(
window,
Seq[NamedWindowProperty](),
- newAgg))
- .project(transformed.fields().patch(windowExprIdx, Seq(outAggGroupExpression), 0))
+ newAgg)
+ transformed.push(windowAgg)
- call.transformTo(transformed.build())
+ // The transformation adds an additional LogicalProject at the top to ensure
+ // that the types are equivalent.
+ // 1. ensure group key types, create an additional project to conform with types
+ val outAggGroupExpression = getOutAggregateGroupExpression(rexBuilder, windowExpr)
+ val projectsEnsureGroupKeyTypes =
+ transformed.fields.patch(windowExprIdx, Seq(outAggGroupExpression), 0)
+ // 2. ensure aggCall types
+ val projectsEnsureAggCallTypes =
+ projectsEnsureGroupKeyTypes.zipWithIndex.map {
+ case (aggCall, index) =>
+ val aggCallIndex = index - agg.getGroupCount
+ if (indexAndTypes.containsKey(aggCallIndex)) {
+ rexBuilder.makeCast(agg.getAggCallList.get(aggCallIndex).`type`, aggCall, true)
+ } else {
+ aggCall
+ }
+ }
+ transformed.project(projectsEnsureAggCallTypes)
+
+ val result = transformed.build()
+ call.transformTo(result)
+ }
+
+ /**
+ * Change the types of [[AggregateCall]] to the corresponding inferred types.
+ */
+ private def adjustTypes(
+ agg: LogicalAggregate,
+ indexAndTypes: Map[Int, RelDataType]) = {
+
+ agg.getAggCallList.zipWithIndex.map {
+ case (aggCall, index) =>
+ if (indexAndTypes.containsKey(index)) {
+ AggregateCall.create(
+ aggCall.getAggregation,
+ aggCall.isDistinct,
+ aggCall.isApproximate,
+ aggCall.ignoreNulls(),
+ aggCall.getArgList,
+ aggCall.filterArg,
+ aggCall.collation,
+ agg.getGroupCount,
+ agg.getInput,
+ indexAndTypes(index),
+ aggCall.name)
+ } else {
+ aggCall
+ }
+ }
+ }
+
+ /**
+ * Check if there are any types of [[AggregateCall]] that need to be changed. Return the
+ * [[AggregateCall]] indexes and the corresponding inferred types.
+ */
+ private def getIndexAndInferredTypesIfChanged(
+ agg: LogicalAggregate)
+ : Map[Int, RelDataType] = {
+
+ agg.getAggCallList.zipWithIndex.flatMap {
+ case (aggCall, index) =>
+ val origType = aggCall.`type`
+ val aggCallBinding = new Aggregate.AggCallBinding(
+ agg.getCluster.getTypeFactory,
+ aggCall.getAggregation,
+ SqlTypeUtil.projectTypes(agg.getInput.getRowType, aggCall.getArgList),
+ 0,
+ aggCall.hasFilter)
+ val inferredType = aggCall.getAggregation.inferReturnType(aggCallBinding)
+
+ if (origType != inferredType && agg.getGroupCount == 1) {
+ Some(index, inferredType)
+ } else {
+ None
+ }
+ }.toMap
}
private[table] def getWindowExpressions(agg: LogicalAggregate): Seq[(RexCall, Int)] = {
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala
index 07c4067..b5091ee 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala
@@ -346,4 +346,45 @@ class GroupWindowTest extends TableTestBase {
util.verifySql(sql, expected)
}
+
+ @Test
+ def testReturnTypeInferenceForWindowAgg(): Unit = {
+ val util = batchTestUtil()
+ val table = util.addTable[(Int, Long, String, Timestamp)]("MyTable", 'a, 'b, 'c, 'rowtime)
+
+ val innerQuery =
+ """
+ |SELECT
+ | CASE a WHEN 1 THEN 1 ELSE 99 END AS correct,
+ | rowtime
+ |FROM MyTable
+ """.stripMargin
+
+ val sqlQuery =
+ "SELECT " +
+ " sum(correct) as s, " +
+ " avg(correct) as a, " +
+ " TUMBLE_START(rowtime, INTERVAL '15' MINUTE) as wStart " +
+ s"FROM ($innerQuery) " +
+ "GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE)"
+
+ val expected =
+ unaryNode(
+ "DataSetCalc",
+ unaryNode(
+ "DataSetWindowAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(table),
+ term("select", "CASE(=(a, 1), 1, 99) AS correct, rowtime")
+ ),
+ term("window", "TumblingGroupWindow('w$, 'rowtime, 900000.millis)"),
+ term("select", "SUM(correct) AS s, AVG(correct) AS a, start('w$) AS w$start," +
+ " end('w$) AS w$end, rowtime('w$) AS w$rowtime")
+ ),
+ term("select", "CAST(s) AS s", "CAST(a) AS a", "CAST(w$start) AS wStart")
+ )
+
+ util.verifySql(sqlQuery, expected)
+ }
}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala
index a8c456f..5acef08 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala
@@ -301,4 +301,47 @@ class GroupWindowTest extends TableTestBase {
)
streamUtil.verifySql(sql, expected)
}
+
+ @Test
+ def testReturnTypeInferenceForWindowAgg() = {
+
+ val innerQuery =
+ """
+ |SELECT
+ | CASE a WHEN 1 THEN 1 ELSE 99 END AS correct,
+ | rowtime
+ |FROM MyTable
+ """.stripMargin
+
+ val sql =
+ "SELECT " +
+ " sum(correct) as s, " +
+ " avg(correct) as a, " +
+ " TUMBLE_START(rowtime, INTERVAL '15' MINUTE) as wStart " +
+ s"FROM ($innerQuery) " +
+ "GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE)"
+
+ val expected =
+ unaryNode(
+ "DataStreamCalc",
+ unaryNode(
+ "DataStreamGroupWindowAggregate",
+ unaryNode(
+ "DataStreamCalc",
+ streamTableNode(table),
+ term("select", "CASE(=(a, 1), 1, 99) AS correct", "rowtime")
+ ),
+ term("window", "TumblingGroupWindow('w$, 'rowtime, 900000.millis)"),
+ term("select",
+ "SUM(correct) AS s",
+ "AVG(correct) AS a",
+ "start('w$) AS w$start",
+ "end('w$) AS w$end",
+ "rowtime('w$) AS w$rowtime",
+ "proctime('w$) AS w$proctime")
+ ),
+ term("select", "CAST(s) AS s", "CAST(a) AS a", "w$start AS wStart")
+ )
+ streamUtil.verifySql(sql, expected)
+ }
}