You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by go...@apache.org on 2022/09/15 10:30:04 UTC
[flink] 01/03: [FLINK-28569][table-planner] Add projectRowType to RowTypeUtils and deprecate AggCodeGenHelper#projectRowType
This is an automated email from the ASF dual-hosted git repository.
godfrey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit dbcd2d7b86fcb7fa7a26e181f1719ea4c6dad828
Author: lincoln.lil <li...@gmail.com>
AuthorDate: Thu Sep 8 18:08:58 2022 +0800
[FLINK-28569][table-planner] Add projectRowType to RowTypeUtils and deprecate AggCodeGenHelper#projectRowType
This closes #20791
---
.../table/planner/typeutils/RowTypeUtils.java | 35 ++++++++++++++++
.../codegen/agg/batch/AggCodeGenHelper.scala | 4 --
.../codegen/agg/batch/HashAggCodeGenerator.scala | 3 +-
.../codegen/agg/batch/SortAggCodeGenerator.scala | 3 +-
.../codegen/agg/batch/WindowCodeGenerator.scala | 3 +-
.../table/planner/typeutils/RowTypeUtilsTest.java | 46 ++++++++++++++++++++++
6 files changed, 87 insertions(+), 7 deletions(-)
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/typeutils/RowTypeUtils.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/typeutils/RowTypeUtils.java
index ffb9a68a131..4d9879d7b99 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/typeutils/RowTypeUtils.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/typeutils/RowTypeUtils.java
@@ -18,7 +18,13 @@
package org.apache.flink.table.planner.typeutils;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.RowType;
+
+import javax.annotation.Nonnull;
+
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@@ -46,4 +52,33 @@ public class RowTypeUtils {
}
return result;
}
+
+    /**
+     * Returns the {@link RowType} obtained by projecting the given field indexes out of the
+     * original {@link RowType}. The projected row follows the order of {@code projection}; each
+     * projected field keeps the type and the name of the source field at that index.
+     *
+     * <p>All indexes are validated before any projection work is done.
+     *
+     * @param rowType source row type
+     * @param projection indexes of the source fields to keep, in the desired output order
+     * @return projected {@link RowType}
+     * @throws IllegalArgumentException if any index is negative or not less than the field count
+     *     of {@code rowType}
+     */
+    public static RowType projectRowType(@Nonnull RowType rowType, @Nonnull int[] projection)
+            throws IllegalArgumentException {
+        final int fieldCnt = rowType.getFieldCount();
+        // Reject negative indexes as well as indexes past the end, so that every invalid index
+        // is reported as an IllegalArgumentException rather than failing inside the lookup.
+        for (int index : projection) {
+            if (index < 0 || index >= fieldCnt) {
+                throw new IllegalArgumentException(
+                        String.format(
+                                "Invalid projection index: %d of source rowType size: %d",
+                                index, fieldCnt));
+            }
+        }
+        // Fetch the field names once instead of once per projected index.
+        final List<String> fieldNames = rowType.getFieldNames();
+        return RowType.of(
+                Arrays.stream(projection)
+                        .mapToObj(rowType::getTypeAt)
+                        .toArray(LogicalType[]::new),
+                Arrays.stream(projection).mapToObj(fieldNames::get).toArray(String[]::new));
+    }
}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala
index bb1135fc236..c401e50ea24 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala
@@ -93,10 +93,6 @@ object AggCodeGenHelper {
.asInstanceOf[Map[AggregateFunction[_, _], String]]
}
- def projectRowType(rowType: RowType, mapping: Array[Int]): RowType = {
- RowType.of(mapping.map(rowType.getTypeAt), mapping.map(rowType.getFieldNames.get(_)))
- }
-
/** Add agg handler to class member and open it. */
private[flink] def addAggsHandler(
aggsHandler: GeneratedAggsHandleFunction,
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
index 550a93df3fb..c768ffd5f70 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
@@ -25,6 +25,7 @@ import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, ProjectionCodeGenerator}
import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
import org.apache.flink.table.planner.plan.utils.{AggregateInfo, AggregateInfoList}
+import org.apache.flink.table.planner.typeutils.RowTypeUtils
import org.apache.flink.table.runtime.generated.GeneratedOperator
import org.apache.flink.table.runtime.operators.TableStreamOperator
import org.apache.flink.table.runtime.operators.aggregate.BytesHashMapSpillMemorySegmentPool
@@ -60,7 +61,7 @@ class HashAggCodeGenerator(
private lazy val aggBufferTypes: Array[Array[LogicalType]] =
AggCodeGenHelper.getAggBufferTypes(inputType, auxGrouping, aggInfos)
- private lazy val groupKeyRowType = AggCodeGenHelper.projectRowType(inputType, grouping)
+ private lazy val groupKeyRowType = RowTypeUtils.projectRowType(inputType, grouping)
private lazy val aggBufferRowType = RowType.of(aggBufferTypes.flatten, aggBufferNames.flatten)
def genWithKeys(): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala
index 3a183dc4183..02d44d733ed 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala
@@ -25,6 +25,7 @@ import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, ProjectionCodeGenerator}
import org.apache.flink.table.planner.codegen.OperatorCodeGenerator.generateCollect
import org.apache.flink.table.planner.plan.utils.AggregateInfoList
+import org.apache.flink.table.planner.typeutils.RowTypeUtils
import org.apache.flink.table.runtime.generated.GeneratedOperator
import org.apache.flink.table.runtime.operators.TableStreamOperator
import org.apache.flink.table.types.logical.RowType
@@ -63,7 +64,7 @@ object SortAggCodeGenerator {
val currentKeyTerm = "currentKey"
val currentKeyWriterTerm = "currentKeyWriter"
- val groupKeyRowType = AggCodeGenHelper.projectRowType(inputType, grouping)
+ val groupKeyRowType = RowTypeUtils.projectRowType(inputType, grouping)
val keyProjectionCode = ProjectionCodeGenerator
.generateProjectionExpression(
ctx,
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala
index ab4a4221cbb..2df6ae64b14 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala
@@ -38,6 +38,7 @@ import org.apache.flink.table.planner.expressions.ExpressionBuilder._
import org.apache.flink.table.planner.expressions.converter.ExpressionConverter
import org.apache.flink.table.planner.plan.logical.{LogicalWindow, SlidingGroupWindow, TumblingGroupWindow}
import org.apache.flink.table.planner.plan.utils.{AggregateInfo, AggregateInfoList, AggregateUtil}
+import org.apache.flink.table.planner.typeutils.RowTypeUtils
import org.apache.flink.table.planner.utils.ShortcutUtils.unwrapTypeFactory
import org.apache.flink.table.runtime.groupwindow.NamedWindowProperty
import org.apache.flink.table.runtime.operators.window.TimeWindow
@@ -82,7 +83,7 @@ abstract class WindowCodeGenerator(
AggCodeGenHelper.getAggBufferTypes(inputRowType, auxGrouping, aggInfos)
protected lazy val groupKeyRowType: RowType =
- AggCodeGenHelper.projectRowType(inputRowType, grouping)
+ RowTypeUtils.projectRowType(inputRowType, grouping)
protected lazy val timestampInternalType: LogicalType =
if (inputTimeIsDate) new IntType() else new BigIntType()
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/typeutils/RowTypeUtilsTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/typeutils/RowTypeUtilsTest.java
index b0e754a037b..7449d83a903 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/typeutils/RowTypeUtilsTest.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/typeutils/RowTypeUtilsTest.java
@@ -18,7 +18,16 @@
package org.apache.flink.table.planner.typeutils;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.types.logical.BigIntType;
+import org.apache.flink.table.types.logical.IntType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.VarCharType;
+
+import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.ExpectedException;
import java.util.Arrays;
@@ -27,6 +36,13 @@ import static org.assertj.core.api.Assertions.assertThat;
/** Tests for {@link RowTypeUtils}. */
public class RowTypeUtilsTest {
+    @Rule public ExpectedException expectedException = ExpectedException.none();
+
+    /** Shared source row for the projection tests: f0 INT, f1 VARCHAR, f2 BIGINT. */
+    private final RowType srcType =
+            new RowType(
+                    Arrays.asList(
+                            new RowType.RowField("f0", new IntType()),
+                            new RowType.RowField("f1", new VarCharType()),
+                            new RowType.RowField("f2", new BigIntType())));
+
@Test
public void testGetUniqueName() {
assertThat(
@@ -39,4 +55,34 @@ public class RowTypeUtilsTest {
Arrays.asList("Alice", "Bob")))
.isEqualTo(Arrays.asList("Bob_0", "Bob_1", "Dave", "Alice_0"));
}
+
+    @Test
+    public void testProjectRowType() {
+        // Single field.
+        assertThat(RowTypeUtils.projectRowType(srcType, new int[] {0}))
+                .isEqualTo(RowType.of(new LogicalType[] {new IntType()}, new String[] {"f0"}));
+
+        // Subset that skips a field.
+        assertThat(RowTypeUtils.projectRowType(srcType, new int[] {0, 2}))
+                .isEqualTo(
+                        RowType.of(
+                                new LogicalType[] {new IntType(), new BigIntType()},
+                                new String[] {"f0", "f2"}));
+
+        // Reordering projection: the output must follow the projection order, not source order.
+        assertThat(RowTypeUtils.projectRowType(srcType, new int[] {2, 0, 1}))
+                .isEqualTo(
+                        RowType.of(
+                                new LogicalType[] {
+                                    new BigIntType(), new IntType(), new VarCharType()
+                                },
+                                new String[] {"f2", "f0", "f1"}));
+
+        // Identity projection yields an equal row type.
+        assertThat(RowTypeUtils.projectRowType(srcType, new int[] {0, 1, 2})).isEqualTo(srcType);
+    }
+
+    @Test
+    public void testInvalidProjectRowType() {
+        // Each case is asserted independently: with an ExpectedException rule only the first
+        // thrown exception is observed and every statement after it is silently skipped.
+        assertThat(projectionFailure(srcType, new int[] {0, 1, 2, 3}))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("Invalid projection index: 3");
+
+        assertThat(projectionFailure(srcType, new int[] {0, 1, 3}))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("Invalid projection index: 3");
+
+        // Repeating an index produces duplicate field names, which the RowType must reject.
+        assertThat(projectionFailure(srcType, new int[] {0, 0, 0, 0}))
+                .isInstanceOf(ValidationException.class)
+                .hasMessageContaining("Field names must be unique. Found duplicates");
+    }
+
+    /**
+     * Runs {@link RowTypeUtils#projectRowType} and returns the runtime exception it threw, or
+     * {@code null} if the projection succeeded (which then fails the caller's assertion).
+     */
+    private static Throwable projectionFailure(RowType rowType, int[] projection) {
+        try {
+            RowTypeUtils.projectRowType(rowType, projection);
+        } catch (RuntimeException e) {
+            return e;
+        }
+        return null;
+    }
}