You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ja...@apache.org on 2020/06/04 07:30:42 UTC

[flink] 01/02: [FLINK-16451][table-planner-blink] Fix IndexOutOfBoundsException for DISTINCT AGG with constants

This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch release-1.10
in repository https://gitbox.apache.org/repos/asf/flink.git

commit fefc92ace2e644f7bd0619bc5103095719f1a7bb
Author: Jark Wu <ja...@apache.org>
AuthorDate: Thu Jun 4 11:57:58 2020 +0800

    [FLINK-16451][table-planner-blink] Fix IndexOutOfBoundsException for DISTINCT AGG with constants
    
    This closes #12432
---
 .../types/logical/utils/LogicalTypeChecks.java     | 35 ++++++++++++++++++++++
 .../codegen/agg/AggsHandlerCodeGenerator.scala     |  1 +
 .../planner/codegen/agg/DistinctAggCodeGen.scala   | 25 +++++++++++-----
 .../runtime/stream/sql/OverWindowITCase.scala      |  8 +++--
 4 files changed, 60 insertions(+), 9 deletions(-)

diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java
index c1d35a2..3437835 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java
@@ -30,7 +30,9 @@ import org.apache.flink.table.types.logical.LocalZonedTimestampType;
 import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.LogicalTypeFamily;
 import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.table.types.logical.SmallIntType;
+import org.apache.flink.table.types.logical.StructuredType;
 import org.apache.flink.table.types.logical.TimeType;
 import org.apache.flink.table.types.logical.TimestampKind;
 import org.apache.flink.table.types.logical.TimestampType;
@@ -66,6 +68,8 @@ public final class LogicalTypeChecks {
 
 	private static final SingleFieldIntervalExtractor SINGLE_FIELD_INTERVAL_EXTRACTOR = new SingleFieldIntervalExtractor();
 
+	private static final FieldCountExtractor FIELD_COUNT_EXTRACTOR = new FieldCountExtractor();
+
 	public static boolean hasRoot(LogicalType logicalType, LogicalTypeRoot typeRoot) {
 		return logicalType.getTypeRoot() == typeRoot;
 	}
@@ -105,6 +109,13 @@ public final class LogicalTypeChecks {
 		return logicalType.accept(LENGTH_EXTRACTOR);
 	}
 
+	/**
+	 * Returns the field count of row and structured types (distinct types use their source type).
+	 */
+	public static int getFieldCount(LogicalType logicalType) {
+		return logicalType.accept(FIELD_COUNT_EXTRACTOR);
+	}
+
 	public static boolean hasLength(LogicalType logicalType, int length) {
 		return getLength(logicalType) == length;
 	}
@@ -340,6 +351,30 @@ public final class LogicalTypeChecks {
 		}
 	}
 
+	private static class FieldCountExtractor extends Extractor<Integer> {
+
+		@Override
+		public Integer visit(RowType rowType) {
+			return rowType.getFieldCount();
+		}
+
+		@Override
+		public Integer visit(StructuredType structuredType) {
+			int fieldCount = 0;
+			StructuredType currentType = structuredType;
+			while (currentType != null) {
+				fieldCount += currentType.getAttributes().size();
+				currentType = currentType.getSuperType().orElse(null);
+			}
+			return fieldCount;
+		}
+
+		@Override
+		public Integer visit(DistinctType distinctType) {
+			return distinctType.getSourceType().accept(this);
+		}
+	}
+
 	private static class SingleFieldIntervalExtractor extends Extractor<Boolean> {
 
 		@Override
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala
index ba81b14..a53b7c0 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala
@@ -259,6 +259,7 @@ class AggsHandlerCodeGenerator(
           index,
           innerCodeGens,
           filterExpr.toArray,
+          constantExprs,
           mergedAccOffset,
           aggBufferOffset,
           aggBufferSize,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala
index 9f53438..58f069d 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala
@@ -31,6 +31,7 @@ import org.apache.flink.table.planner.expressions.converter.ExpressionConverter
 import org.apache.flink.table.planner.plan.utils.DistinctInfo
 import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType
 import org.apache.flink.table.types.DataType
+import org.apache.flink.table.types.logical.utils.LogicalTypeChecks
 import org.apache.flink.table.types.logical.{LogicalType, RowType}
 import org.apache.flink.util.Preconditions
 import org.apache.flink.util.Preconditions.checkArgument
@@ -63,6 +64,7 @@ class DistinctAggCodeGen(
   distinctIndex: Int,
   innerAggCodeGens: Array[AggCodeGen],
   filterExpressions: Array[Option[Expression]],
+  constantExpressions: Seq[GeneratedExpression],
   mergedAccOffset: Int,
   aggBufferOffset: Int,
   aggBufferSize: Int,
@@ -371,13 +373,22 @@ class DistinctAggCodeGen(
   private def generateKeyExpression(
       ctx: CodeGeneratorContext,
       generator: ExprCodeGenerator): GeneratedExpression = {
-    val fieldExprs = distinctInfo.argIndexes.map(generateInputAccess(
-      ctx,
-      generator.input1Type,
-      generator.input1Term,
-      _,
-      nullableInput = false,
-      deepCopy = inputFieldCopy))
+    val fieldExprs = distinctInfo.argIndexes.map(argIndex => {
+      val inputFieldCount = LogicalTypeChecks.getFieldCount(generator.input1Type)
+      if (argIndex >= inputFieldCount) {
+        // arg index at or beyond the input field count refers to a constant expression
+        constantExpressions(argIndex - inputFieldCount)
+      } else {
+        // arg index below the input field count refers to an input field
+        generateInputAccess(
+          ctx,
+          generator.input1Type,
+          generator.input1Term,
+          argIndex,
+          nullableInput = false,
+          deepCopy = inputFieldCopy)
+      }
+    })
 
     // the key expression of MapView
     if (fieldExprs.length > 1) {
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverWindowITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverWindowITCase.scala
index 2e8cf75..a0c13e1 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverWindowITCase.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverWindowITCase.scala
@@ -280,14 +280,18 @@ class OverWindowITCase(mode: StateBackendMode) extends StreamingWithStateTestBas
     tEnv.registerTable("T1", t1)
 
     val sqlQuery = "SELECT " +
-      "count(a) OVER (ORDER BY proctime ROWS BETWEEN UNBOUNDED preceding AND CURRENT ROW) " +
+      "listagg(distinct c, '|') " +
+      "  OVER (ORDER BY proctime ROWS BETWEEN UNBOUNDED preceding AND CURRENT ROW), " +
+      "count(a) " +
+      "  OVER (ORDER BY proctime ROWS BETWEEN UNBOUNDED preceding AND CURRENT ROW) " +
       "from T1"
 
     val sink = new TestingAppendSink
     tEnv.sqlQuery(sqlQuery).toAppendStream[Row].addSink(sink)
     env.execute()
 
-    val expected = List("1", "2", "3", "4", "5", "6", "7", "8", "9")
+    val expected = List("Hello,1", "Hello,2", "Hello,3", "Hello,4", "Hello,5", "Hello,6",
+      "Hello|Hello World,7", "Hello|Hello World,8", "Hello|Hello World,9")
     assertEquals(expected.sorted, sink.getAppendResults.sorted)
   }