You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ja...@apache.org on 2022/08/06 12:47:17 UTC
[flink] 02/02: [FLINK-27619][sql] fix use wrong constant for over window.
This is an automated email from the ASF dual-hosted git repository.
jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 8ac7e9724a60d9f90ede81b8e1bb13b116623ff3
Author: luoyuxia <lu...@alumni.sjtu.edu.cn>
AuthorDate: Sun Jul 24 12:34:02 2022 +0800
[FLINK-27619][sql] fix use wrong constant for over window.
---
.../batch/BatchPhysicalOverAggregateRule.scala | 80 +++++++++++++++++++---
.../plan/batch/sql/agg/OverAggregateTest.xml | 29 ++++++++
.../plan/batch/sql/agg/OverAggregateTest.scala | 12 ++++
3 files changed, 111 insertions(+), 10 deletions(-)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalOverAggregateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalOverAggregateRule.scala
index 0449fe849bc..6d20ce229ec 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalOverAggregateRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalOverAggregateRule.scala
@@ -29,14 +29,18 @@ import org.apache.flink.table.planner.plan.utils.PythonUtil.isPythonAggregate
import org.apache.flink.table.planner.typeutils.RowTypeUtils
import org.apache.flink.table.planner.utils.ShortcutUtils
-import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelOptRuleCall}
+import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelOptRuleCall, RelOptUtil}
import org.apache.calcite.plan.RelOptRule._
import org.apache.calcite.rel._
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.{AggregateCall, Window}
import org.apache.calcite.rel.core.Window.Group
+import org.apache.calcite.rex.{RexInputRef, RexNode, RexShuttle}
+import org.apache.calcite.sql.SqlAggFunction
import org.apache.calcite.tools.ValidationException
+import java.util
+
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
@@ -55,15 +59,10 @@ class BatchPhysicalOverAggregateRule
val logicWindow: FlinkLogicalOverAggregate = call.rel(0)
var input: RelNode = call.rel(1)
var inputRowType = logicWindow.getInput.getRowType
+ val originInputSize = inputRowType.getFieldCount
val typeFactory = logicWindow.getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
val constants = logicWindow.constants.asScala
- val constantTypes = constants.map(c => FlinkTypeFactory.toLogicalType(c.getType))
- val inputNamesWithConstants = inputRowType.getFieldNames ++ constants.indices.map(i => s"TMP$i")
- val inputTypesWithConstants = inputRowType.getFieldList
- .map(i => FlinkTypeFactory.toLogicalType(i.getType)) ++ constantTypes
- val inputTypeWithConstants =
- typeFactory.buildRelNodeRowType(inputNamesWithConstants, inputTypesWithConstants)
var overWindowAgg: BatchPhysicalOverAggregateBase = null
@@ -92,12 +91,16 @@ class BatchPhysicalOverAggregateRule
val newInput = RelOptRule.convert(input, requiredTrait)
- val groupToAggCallToAggFunction = groupBuffer.map {
- group =>
+ val groupToAggCallToAggFunction = groupBuffer.zipWithIndex.map {
+ case (_, idx) =>
+ // The input's field count may have changed, so the argument indices of the
+ // aggregate calls in this group may need to be adjusted.
+ adjustGroup(groupBuffer, idx, originInputSize, newInput.getRowType.getFieldCount)
+ val group = groupBuffer.get(idx)
val aggregateCalls = group.getAggregateCalls(logicWindow)
val (_, _, aggregates) = AggregateUtil.transformToBatchAggregateFunctions(
ShortcutUtils.unwrapTypeFactory(input),
- FlinkTypeFactory.toLogicalRowType(inputTypeWithConstants),
+ FlinkTypeFactory.toLogicalRowType(generateInputTypeWithConstants()),
aggregateCalls,
sortSpec.getFieldIndices)
val aggCallToAggFunction = aggregateCalls.zip(aggregates)
@@ -152,6 +155,15 @@ class BatchPhysicalOverAggregateRule
inputRowType = outputRowType
}
+ def generateInputTypeWithConstants(): RelDataType = {
+ val constantTypes = constants.map(c => FlinkTypeFactory.toLogicalType(c.getType))
+ val inputNamesWithConstants = inputRowType.getFieldNames ++
+ constants.indices.map(i => s"TMP$i")
+ val inputTypesWithConstants = inputRowType.getFieldList
+ .map(i => FlinkTypeFactory.toLogicalType(i.getType)) ++ constantTypes
+ typeFactory.buildRelNodeRowType(inputNamesWithConstants, inputTypesWithConstants)
+ }
+
logicWindow.groups.foreach {
group =>
validate(group)
@@ -202,6 +214,54 @@ class BatchPhysicalOverAggregateRule
typeFactory.createStructType(inputTypeList ++ aggTypes, inputNameList ++ aggNames)
}
+ private def adjustGroup(
+ groupBuffer: ArrayBuffer[Window.Group],
+ groupIdx: Int,
+ originInputSize: Int,
+ newInputSize: Int): Unit = {
+ val inputSizeDiff = newInputSize - originInputSize
+ if (inputSizeDiff > 0) {
+ // The input now has more fields than the original input, so shift every input ref at or
+ // beyond the original input size by the difference so it still refers to the original value.
+ var hasAdjust = false
+ val indexAdjustment = new RexShuttle() {
+ override def visitInputRef(inputRef: RexInputRef): RexNode = {
+ if (inputRef.getIndex >= originInputSize) {
+ hasAdjust = true
+ new RexInputRef(inputRef.getIndex + inputSizeDiff, inputRef.getType)
+ } else {
+ inputRef
+ }
+ }
+ }
+ val group = groupBuffer.get(groupIdx)
+ val newAggCalls = new util.ArrayList[Window.RexWinAggCall]()
+ group.aggCalls.forEach(
+ aggCall => {
+ val newOperands = indexAdjustment.visitList(aggCall.operands);
+ newAggCalls.add(
+ new Window.RexWinAggCall(
+ aggCall.getOperator.asInstanceOf[SqlAggFunction],
+ aggCall.getType,
+ newOperands,
+ aggCall.ordinal,
+ aggCall.distinct,
+ aggCall.ignoreNulls))
+ })
+ if (hasAdjust) {
+ groupBuffer.set(
+ groupIdx,
+ new Group(
+ group.keys,
+ group.isRows,
+ group.lowerBound,
+ group.upperBound,
+ group.orderKeys,
+ newAggCalls))
+ }
+ }
+ }
+
// Spark/PostgreSQL don't support DISTINCT on over(), and Hive only supports DISTINCT without
// a window frame, because DISTINCT on over() is complicated.
private def validate(group: Group): Unit = {
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml
index a3c3538a66d..c58b35754ae 100644
--- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml
@@ -418,6 +418,35 @@ Calc(select=[CASE((w0$o0 > 0), w0$o1, null:INTEGER) AS EXPR$0, w0$o2 AS EXPR$1])
]]>
</Resource>
</TestCase>
+ <TestCase name="testOverWindowWithConstants3">
+ <Resource name="sql">
+ <![CDATA[
+SELECT
+ COUNT(2) OVER (ORDER BY a),
+ COUNT(1) OVER (PARTITION BY c ORDER BY a)
+FROM MyTable
+ ]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(EXPR$0=[COUNT(2) OVER (ORDER BY $0 NULLS FIRST)], EXPR$1=[COUNT(1) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST)])
++- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]])
+]]>
+ </Resource>
+ <Resource name="optimized exec plan">
+ <![CDATA[
+Calc(select=[w0$o0 AS $0, w1$o0 AS $1])
++- OverAggregate(partitionBy=[c], orderBy=[a ASC], window#0=[COUNT(1) AS w1$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, w0$o0, w1$o0])
+ +- Sort(orderBy=[c ASC, a ASC])
+ +- Exchange(distribution=[hash[c]])
+ +- OverAggregate(orderBy=[a ASC], window#0=[COUNT(2) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, w0$o0])
+ +- Sort(orderBy=[a ASC])
+ +- Exchange(distribution=[single])
+ +- Calc(select=[a, c])
+ +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+]]>
+ </Resource>
+ </TestCase>
<TestCase name="testSamePartitionKeysWithDiffOrderKeys1">
<Resource name="sql">
<![CDATA[
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala
index 35401b41367..d0753274fc0 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala
@@ -308,6 +308,18 @@ class OverAggregateTest extends TableTestBase {
util.verifyExecPlan(sqlQuery)
}
+ @Test
+ def testOverWindowWithConstants3(): Unit = {
+ val sqlQuery =
+ """
+ |SELECT
+ | COUNT(2) OVER (ORDER BY a),
+ | COUNT(1) OVER (PARTITION BY c ORDER BY a)
+ |FROM MyTable
+ """.stripMargin
+ util.verifyExecPlan(sqlQuery)
+ }
+
@Test(expected = classOf[RuntimeException])
def testDistinct(): Unit = {
val sqlQuery =