Posted to commits@flink.apache.org by ja...@apache.org on 2022/08/06 12:47:17 UTC

[flink] 02/02: [FLINK-27619][sql] fix using wrong constant for over window.

This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 8ac7e9724a60d9f90ede81b8e1bb13b116623ff3
Author: luoyuxia <lu...@alumni.sjtu.edu.cn>
AuthorDate: Sun Jul 24 12:34:02 2022 +0800

    [FLINK-27619][sql] fix using wrong constant for over window.
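
    In short (as the rule's code below shows): the rule can split an over
    window whose groups have different partition keys into a chain of
    BatchPhysicalOverAggregate nodes, each of which appends its aggregate
    output columns to its input row. Constants from the logical window are
    referenced by field indexes placed after the original input fields, so
    once an upstream node in the chain widens the row, constant references
    in the remaining groups point at the wrong column. This commit shifts
    such RexInputRefs by the size difference before translating each group.
    The added test covers the case (two windows with different partition
    keys, so they are planned as separate nodes):

        SELECT
            COUNT(2) OVER (ORDER BY a),
            COUNT(1) OVER (PARTITION BY c ORDER BY a)
        FROM MyTable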
---
 .../batch/BatchPhysicalOverAggregateRule.scala     | 80 +++++++++++++++++++---
 .../plan/batch/sql/agg/OverAggregateTest.xml       | 29 ++++++++
 .../plan/batch/sql/agg/OverAggregateTest.scala     | 12 ++++
 3 files changed, 111 insertions(+), 10 deletions(-)

diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalOverAggregateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalOverAggregateRule.scala
index 0449fe849bc..6d20ce229ec 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalOverAggregateRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalOverAggregateRule.scala
@@ -29,14 +29,18 @@ import org.apache.flink.table.planner.plan.utils.PythonUtil.isPythonAggregate
 import org.apache.flink.table.planner.typeutils.RowTypeUtils
 import org.apache.flink.table.planner.utils.ShortcutUtils
 
-import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelOptRuleCall}
+import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelOptRuleCall, RelOptUtil}
 import org.apache.calcite.plan.RelOptRule._
 import org.apache.calcite.rel._
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.core.{AggregateCall, Window}
 import org.apache.calcite.rel.core.Window.Group
+import org.apache.calcite.rex.{RexInputRef, RexNode, RexShuttle}
+import org.apache.calcite.sql.SqlAggFunction
 import org.apache.calcite.tools.ValidationException
 
+import java.util
+
 import scala.collection.JavaConversions._
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
@@ -55,15 +59,10 @@ class BatchPhysicalOverAggregateRule
     val logicWindow: FlinkLogicalOverAggregate = call.rel(0)
     var input: RelNode = call.rel(1)
     var inputRowType = logicWindow.getInput.getRowType
+    val originInputSize = inputRowType.getFieldCount
     val typeFactory = logicWindow.getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
 
     val constants = logicWindow.constants.asScala
-    val constantTypes = constants.map(c => FlinkTypeFactory.toLogicalType(c.getType))
-    val inputNamesWithConstants = inputRowType.getFieldNames ++ constants.indices.map(i => s"TMP$i")
-    val inputTypesWithConstants = inputRowType.getFieldList
-      .map(i => FlinkTypeFactory.toLogicalType(i.getType)) ++ constantTypes
-    val inputTypeWithConstants =
-      typeFactory.buildRelNodeRowType(inputNamesWithConstants, inputTypesWithConstants)
 
     var overWindowAgg: BatchPhysicalOverAggregateBase = null
 
@@ -92,12 +91,16 @@ class BatchPhysicalOverAggregateRule
 
       val newInput = RelOptRule.convert(input, requiredTrait)
 
-      val groupToAggCallToAggFunction = groupBuffer.map {
-        group =>
+      val groupToAggCallToAggFunction = groupBuffer.zipWithIndex.map {
+        case (_, idx) =>
+          // we may need to adjust the arg indexes of the AggregateCalls in this group
+          // because the input's size may have changed
+          adjustGroup(groupBuffer, idx, originInputSize, newInput.getRowType.getFieldCount)
+          val group = groupBuffer.get(idx)
           val aggregateCalls = group.getAggregateCalls(logicWindow)
           val (_, _, aggregates) = AggregateUtil.transformToBatchAggregateFunctions(
             ShortcutUtils.unwrapTypeFactory(input),
-            FlinkTypeFactory.toLogicalRowType(inputTypeWithConstants),
+            FlinkTypeFactory.toLogicalRowType(generateInputTypeWithConstants()),
             aggregateCalls,
             sortSpec.getFieldIndices)
           val aggCallToAggFunction = aggregateCalls.zip(aggregates)
@@ -152,6 +155,15 @@ class BatchPhysicalOverAggregateRule
       inputRowType = outputRowType
     }
 
+    def generateInputTypeWithConstants(): RelDataType = {
+      val constantTypes = constants.map(c => FlinkTypeFactory.toLogicalType(c.getType))
+      val inputNamesWithConstants = inputRowType.getFieldNames ++
+        constants.indices.map(i => s"TMP$i")
+      val inputTypesWithConstants = inputRowType.getFieldList
+        .map(i => FlinkTypeFactory.toLogicalType(i.getType)) ++ constantTypes
+      typeFactory.buildRelNodeRowType(inputNamesWithConstants, inputTypesWithConstants)
+    }
+
     logicWindow.groups.foreach {
       group =>
         validate(group)
@@ -202,6 +214,54 @@ class BatchPhysicalOverAggregateRule
     typeFactory.createStructType(inputTypeList ++ aggTypes, inputNameList ++ aggNames)
   }
 
+  private def adjustGroup(
+      groupBuffer: ArrayBuffer[Window.Group],
+      groupIdx: Int,
+      originInputSize: Int,
+      newInputSize: Int): Unit = {
+    val inputSizeDiff = newInputSize - originInputSize
+    if (inputSizeDiff > 0) {
+      // the input's size for this group has increased; shift the arg indexes of the
+      // agg calls in the group so that each index still refers to the original field
+      var hasAdjust = false
+      val indexAdjustment = new RexShuttle() {
+        override def visitInputRef(inputRef: RexInputRef): RexNode = {
+          if (inputRef.getIndex >= originInputSize) {
+            hasAdjust = true
+            new RexInputRef(inputRef.getIndex + inputSizeDiff, inputRef.getType)
+          } else {
+            inputRef
+          }
+        }
+      }
+      val group = groupBuffer.get(groupIdx)
+      val newAggCalls = new util.ArrayList[Window.RexWinAggCall]()
+      group.aggCalls.forEach(
+        aggCall => {
+          val newOperands = indexAdjustment.visitList(aggCall.operands)
+          newAggCalls.add(
+            new Window.RexWinAggCall(
+              aggCall.getOperator.asInstanceOf[SqlAggFunction],
+              aggCall.getType,
+              newOperands,
+              aggCall.ordinal,
+              aggCall.distinct,
+              aggCall.ignoreNulls))
+        })
+      if (hasAdjust) {
+        groupBuffer.set(
+          groupIdx,
+          new Group(
+            group.keys,
+            group.isRows,
+            group.lowerBound,
+            group.upperBound,
+            group.orderKeys,
+            newAggCalls))
+      }
+    }
+  }
+
  // Spark/PostgreSQL don't support DISTINCT on OVER(), and Hive only supports DISTINCT
  // without a window frame, because DISTINCT on OVER() is complicated.
   private def validate(group: Group): Unit = {
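
A note on the shuttle above: the adjustment is just an index shift over the
group's RexInputRefs. Below is a minimal standalone sketch of the same idea,
with a hypothetical Ref case class standing in for Calcite's RexInputRef
(names here are illustrative only, not part of the commit):

    object ShiftDemo extends App {
      // Stand-in for Calcite's RexInputRef: just a column index.
      final case class Ref(index: Int)

      // Refs below originInputSize point at original input fields and stay put;
      // refs at or above it point at the constants appended after the input, so
      // they move by however many columns the upstream over aggregate added.
      def shiftConstantRefs(refs: Seq[Ref], originInputSize: Int, newInputSize: Int): Seq[Ref] = {
        val diff = newInputSize - originInputSize
        if (diff <= 0) refs
        else refs.map(r => if (r.index >= originInputSize) Ref(r.index + diff) else r)
      }

      // With input (a, c), constants referenced at indexes 2 and 3. After the
      // first OverAggregate appends w0$o0, the row becomes (a, c, w0$o0), so a
      // ref to the second constant must move from index 3 to index 4.
      assert(shiftConstantRefs(Seq(Ref(0), Ref(3)), 2, 3) == Seq(Ref(0), Ref(4)))
      println("refs shifted as expected")
    }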
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml
index a3c3538a66d..c58b35754ae 100644
--- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml
@@ -418,6 +418,35 @@ Calc(select=[CASE((w0$o0 > 0), w0$o1, null:INTEGER) AS EXPR$0, w0$o2 AS EXPR$1])
 ]]>
     </Resource>
   </TestCase>
+  <TestCase name="testOverWindowWithConstants3">
+    <Resource name="sql">
+      <![CDATA[
+SELECT
+    COUNT(2) OVER (ORDER BY a),
+    COUNT(1) OVER (PARTITION BY c ORDER BY a)
+FROM MyTable
+      ]]>
+    </Resource>
+    <Resource name="ast">
+      <![CDATA[
+LogicalProject(EXPR$0=[COUNT(2) OVER (ORDER BY $0 NULLS FIRST)], EXPR$1=[COUNT(1) OVER (PARTITION BY $2 ORDER BY $0 NULLS FIRST)])
++- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]])
+]]>
+    </Resource>
+    <Resource name="optimized exec plan">
+      <![CDATA[
+Calc(select=[w0$o0 AS $0, w1$o0 AS $1])
++- OverAggregate(partitionBy=[c], orderBy=[a ASC], window#0=[COUNT(1) AS w1$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, w0$o0, w1$o0])
+   +- Sort(orderBy=[c ASC, a ASC])
+      +- Exchange(distribution=[hash[c]])
+         +- OverAggregate(orderBy=[a ASC], window#0=[COUNT(2) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, w0$o0])
+            +- Sort(orderBy=[a ASC])
+               +- Exchange(distribution=[single])
+                  +- Calc(select=[a, c])
+                     +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+]]>
+    </Resource>
+  </TestCase>
   <TestCase name="testSamePartitionKeysWithDiffOrderKeys1">
     <Resource name="sql">
       <![CDATA[
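
Reading the expected plan above as a worked example of the indexing (assuming,
as in the rule's code, that constants are referenced after the input fields
and registered in query order, 2 then 1): after the Calc the input row is
(a, c), so constants 2 and 1 are referenced at indexes 2 and 3. The first
OverAggregate appends w0$o0, widening the row to (a, c, w0$o0). Without the
adjustment, the second group's reference to constant 1 (index 3) would now
resolve to constant 2, which is exactly the wrong constant named in
FLINK-27619; with the shift it moves to index 4 and resolves correctly.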
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala
index 35401b41367..d0753274fc0 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala
@@ -308,6 +308,18 @@ class OverAggregateTest extends TableTestBase {
     util.verifyExecPlan(sqlQuery)
   }
 
+  @Test
+  def testOverWindowWithConstants3(): Unit = {
+    val sqlQuery =
+      """
+        |SELECT
+        |    COUNT(2) OVER (ORDER BY a),
+        |    COUNT(1) OVER (PARTITION BY c ORDER BY a)
+        |FROM MyTable
+      """.stripMargin
+    util.verifyExecPlan(sqlQuery)
+  }
+
   @Test(expected = classOf[RuntimeException])
   def testDistinct(): Unit = {
     val sqlQuery =