You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by go...@apache.org on 2022/09/15 10:30:04 UTC

[flink] 01/03: [FLINK-28569][table-planner] Add projectRowType to RowTypeUtils and deprecate AggCodeGenHelper#projectRowType

This is an automated email from the ASF dual-hosted git repository.

godfrey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit dbcd2d7b86fcb7fa7a26e181f1719ea4c6dad828
Author: lincoln.lil <li...@gmail.com>
AuthorDate: Thu Sep 8 18:08:58 2022 +0800

    [FLINK-28569][table-planner] Add projectRowType to RowTypeUtils and deprecate AggCodeGenHelper#projectRowType
    
    This closes #20791
---
 .../table/planner/typeutils/RowTypeUtils.java      | 35 ++++++++++++++++
 .../codegen/agg/batch/AggCodeGenHelper.scala       |  4 --
 .../codegen/agg/batch/HashAggCodeGenerator.scala   |  3 +-
 .../codegen/agg/batch/SortAggCodeGenerator.scala   |  3 +-
 .../codegen/agg/batch/WindowCodeGenerator.scala    |  3 +-
 .../table/planner/typeutils/RowTypeUtilsTest.java  | 46 ++++++++++++++++++++++
 6 files changed, 87 insertions(+), 7 deletions(-)

diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/typeutils/RowTypeUtils.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/typeutils/RowTypeUtils.java
index ffb9a68a131..4d9879d7b99 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/typeutils/RowTypeUtils.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/typeutils/RowTypeUtils.java
@@ -18,7 +18,13 @@
 
 package org.apache.flink.table.planner.typeutils;
 
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.RowType;
+
+import javax.annotation.Nonnull;
+
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
@@ -46,4 +52,33 @@ public class RowTypeUtils {
         }
         return result;
     }
+
+    /**
+     * Returns the {@link RowType} made of the source fields at the given indexes, in that order.
+     *
+     * @param rowType source row type
+     * @param projection indexes of the fields to keep
+     * @return projected {@link RowType}
+     * @throws IllegalArgumentException if an index is negative or not below the field count
+     */
+    public static RowType projectRowType(@Nonnull RowType rowType, @Nonnull int[] projection)
+            throws IllegalArgumentException {
+        final int fieldCnt = rowType.getFieldCount();
+        return RowType.of(
+                Arrays.stream(projection)
+                        .mapToObj(
+                                index -> {
+                                    if (index < 0 || index >= fieldCnt) {
+                                        throw new IllegalArgumentException(
+                                                String.format(
+                                                        "Invalid projection index: %d of source rowType size: %d",
+                                                        index, fieldCnt));
+                                    }
+                                    return rowType.getTypeAt(index);
+                                })
+                        .toArray(LogicalType[]::new),
+                Arrays.stream(projection)
+                        .mapToObj(index -> rowType.getFieldNames().get(index))
+                        .toArray(String[]::new));
+    }
 }
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala
index bb1135fc236..c401e50ea24 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala
@@ -93,10 +93,6 @@ object AggCodeGenHelper {
       .asInstanceOf[Map[AggregateFunction[_, _], String]]
   }
 
-  def projectRowType(rowType: RowType, mapping: Array[Int]): RowType = {
-    RowType.of(mapping.map(rowType.getTypeAt), mapping.map(rowType.getFieldNames.get(_)))
-  }
-
   /** Add agg handler to class member and open it. */
   private[flink] def addAggsHandler(
       aggsHandler: GeneratedAggsHandleFunction,
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
index 550a93df3fb..c768ffd5f70 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
@@ -25,6 +25,7 @@ import org.apache.flink.table.functions.AggregateFunction
 import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, ProjectionCodeGenerator}
 import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
 import org.apache.flink.table.planner.plan.utils.{AggregateInfo, AggregateInfoList}
+import org.apache.flink.table.planner.typeutils.RowTypeUtils
 import org.apache.flink.table.runtime.generated.GeneratedOperator
 import org.apache.flink.table.runtime.operators.TableStreamOperator
 import org.apache.flink.table.runtime.operators.aggregate.BytesHashMapSpillMemorySegmentPool
@@ -60,7 +61,7 @@ class HashAggCodeGenerator(
   private lazy val aggBufferTypes: Array[Array[LogicalType]] =
     AggCodeGenHelper.getAggBufferTypes(inputType, auxGrouping, aggInfos)
 
-  private lazy val groupKeyRowType = AggCodeGenHelper.projectRowType(inputType, grouping)
+  private lazy val groupKeyRowType = RowTypeUtils.projectRowType(inputType, grouping)
   private lazy val aggBufferRowType = RowType.of(aggBufferTypes.flatten, aggBufferNames.flatten)
 
   def genWithKeys(): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala
index 3a183dc4183..02d44d733ed 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala
@@ -25,6 +25,7 @@ import org.apache.flink.table.functions.AggregateFunction
 import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, ProjectionCodeGenerator}
 import org.apache.flink.table.planner.codegen.OperatorCodeGenerator.generateCollect
 import org.apache.flink.table.planner.plan.utils.AggregateInfoList
+import org.apache.flink.table.planner.typeutils.RowTypeUtils
 import org.apache.flink.table.runtime.generated.GeneratedOperator
 import org.apache.flink.table.runtime.operators.TableStreamOperator
 import org.apache.flink.table.types.logical.RowType
@@ -63,7 +64,7 @@ object SortAggCodeGenerator {
     val currentKeyTerm = "currentKey"
     val currentKeyWriterTerm = "currentKeyWriter"
 
-    val groupKeyRowType = AggCodeGenHelper.projectRowType(inputType, grouping)
+    val groupKeyRowType = RowTypeUtils.projectRowType(inputType, grouping)
     val keyProjectionCode = ProjectionCodeGenerator
       .generateProjectionExpression(
         ctx,
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala
index ab4a4221cbb..2df6ae64b14 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala
@@ -38,6 +38,7 @@ import org.apache.flink.table.planner.expressions.ExpressionBuilder._
 import org.apache.flink.table.planner.expressions.converter.ExpressionConverter
 import org.apache.flink.table.planner.plan.logical.{LogicalWindow, SlidingGroupWindow, TumblingGroupWindow}
 import org.apache.flink.table.planner.plan.utils.{AggregateInfo, AggregateInfoList, AggregateUtil}
+import org.apache.flink.table.planner.typeutils.RowTypeUtils
 import org.apache.flink.table.planner.utils.ShortcutUtils.unwrapTypeFactory
 import org.apache.flink.table.runtime.groupwindow.NamedWindowProperty
 import org.apache.flink.table.runtime.operators.window.TimeWindow
@@ -82,7 +83,7 @@ abstract class WindowCodeGenerator(
     AggCodeGenHelper.getAggBufferTypes(inputRowType, auxGrouping, aggInfos)
 
   protected lazy val groupKeyRowType: RowType =
-    AggCodeGenHelper.projectRowType(inputRowType, grouping)
+    RowTypeUtils.projectRowType(inputRowType, grouping)
 
   protected lazy val timestampInternalType: LogicalType =
     if (inputTimeIsDate) new IntType() else new BigIntType()
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/typeutils/RowTypeUtilsTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/typeutils/RowTypeUtilsTest.java
index b0e754a037b..7449d83a903 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/typeutils/RowTypeUtilsTest.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/typeutils/RowTypeUtilsTest.java
@@ -18,7 +18,16 @@
 
 package org.apache.flink.table.planner.typeutils;
 
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.types.logical.BigIntType;
+import org.apache.flink.table.types.logical.IntType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.VarCharType;
+
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.ExpectedException;
 
 import java.util.Arrays;
 
@@ -27,6 +36,13 @@ import static org.assertj.core.api.Assertions.assertThat;
 /** Tests for {@link RowTypeUtils}. */
 public class RowTypeUtilsTest {
 
+    @Rule public ExpectedException expectedException = ExpectedException.none();
+
+    private final RowType srcType =
+            RowType.of(
+                    new LogicalType[] {new IntType(), new VarCharType(), new BigIntType()},
+                    new String[] {"f0", "f1", "f2"});
+
     @Test
     public void testGetUniqueName() {
         assertThat(
@@ -39,4 +55,34 @@ public class RowTypeUtilsTest {
                                 Arrays.asList("Alice", "Bob")))
                 .isEqualTo(Arrays.asList("Bob_0", "Bob_1", "Dave", "Alice_0"));
     }
+
+    @Test
+    public void testProjectRowType() {
+        // Projecting a single index keeps only that field's type and name.
+        final RowType first = RowType.of(new LogicalType[] {new IntType()}, new String[] {"f0"});
+        assertThat(RowTypeUtils.projectRowType(srcType, new int[] {0})).isEqualTo(first);
+        // Projecting indexes 0 and 2 skips the middle field.
+        final LogicalType[] types = {new IntType(), new BigIntType()};
+        final RowType firstAndLast = RowType.of(types, new String[] {"f0", "f2"});
+        assertThat(RowTypeUtils.projectRowType(srcType, new int[] {0, 2})).isEqualTo(firstAndLast);
+        // Projecting every index in order yields a row type equal to the source.
+        final int[] identity = {0, 1, 2};
+        assertThat(RowTypeUtils.projectRowType(srcType, identity)).isEqualTo(srcType);
+    }
+
+    @Test
+    public void testInvalidProjectRowType() {
+        // Assert each case on its own: an ExpectedException rule ends the test at the first throw.
+        org.assertj.core.api.Assertions.assertThatThrownBy(
+                        () -> RowTypeUtils.projectRowType(srcType, new int[] {0, 1, 2, 3}))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("Invalid projection index: 3");
+        org.assertj.core.api.Assertions.assertThatThrownBy(
+                        () -> RowTypeUtils.projectRowType(srcType, new int[] {0, 1, 3}))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("Invalid projection index: 3");
+        org.assertj.core.api.Assertions.assertThatThrownBy(
+                        () -> RowTypeUtils.projectRowType(srcType, new int[] {0, 0, 0, 0}))
+                .isInstanceOf(ValidationException.class)
+                .hasMessageContaining("Field names must be unique. Found duplicates");
+    }
 }