Posted to commits@flink.apache.org by jc...@apache.org on 2024/01/12 01:47:24 UTC

(flink) branch master updated: [FLINK-31788][table] TableAggregateFunction supports emitUpdateWithRetract

This is an automated email from the ASF dual-hosted git repository.

jchan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 01569644aed [FLINK-31788][table] TableAggregateFunction supports emitUpdateWithRetract
01569644aed is described below

commit 01569644aedb56f792c7f7e04f84612d405b0bdf
Author: Jane Chan <qi...@gmail.com>
AuthorDate: Fri Jan 12 09:47:16 2024 +0800

    [FLINK-31788][table] TableAggregateFunction supports emitUpdateWithRetract
    
    This closes #24051
---
 .../exec/stream/StreamExecGroupTableAggregate.java |   1 +
 .../codegen/agg/AggsHandlerCodeGenerator.scala     |  52 +++++++++-
 .../planner/codegen/agg/ImperativeAggCodeGen.scala |  21 +++-
 .../utils/JavaUserDefinedTableAggFunctions.java    | 114 +++++++++++++++++++++
 .../stream/table/TableAggregateITCase.scala        |  80 ++++++++++++++-
 .../operators/aggregate/GroupTableAggFunction.java |   9 +-
 6 files changed, 268 insertions(+), 9 deletions(-)
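
For orientation before the diff: a minimal Java sketch, not part of this commit, of how the new code path is reached from the Table API. It reuses the IncrementalTop2 test function added in this patch; the input rows mirror the new ITCase, and the print sink is only illustrative.

    import static org.apache.flink.table.api.Expressions.$;
    import static org.apache.flink.table.api.Expressions.call;

    import org.apache.flink.table.api.DataTypes;
    import org.apache.flink.table.api.EnvironmentSettings;
    import org.apache.flink.table.api.Table;
    import org.apache.flink.table.api.TableEnvironment;
    import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedTableAggFunctions;
    import org.apache.flink.types.Row;

    public class IncrementalTop2Usage {
        public static void main(String[] args) {
            TableEnvironment tEnv =
                    TableEnvironment.create(EnvironmentSettings.inStreamingMode());

            // register a table aggregate function that implements emitUpdateWithRetract
            tEnv.createTemporarySystemFunction(
                    "incrementalTop2", new JavaUserDefinedTableAggFunctions.IncrementalTop2());

            Table prices =
                    tEnv.fromValues(
                            DataTypes.ROW(
                                    DataTypes.FIELD("name", DataTypes.STRING()),
                                    DataTypes.FIELD("price", DataTypes.INT())),
                            Row.of("Latte", 6),
                            Row.of("Milk", 3),
                            Row.of("Breve", 5));

            // flatAggregate now routes through emitUpdateWithRetract, so the changelog
            // retracts and re-emits only the (price, rank) pairs that actually changed
            Table top2 =
                    prices.flatAggregate(call("incrementalTop2", $("price")).as("top_price", "rank"))
                            .select($("top_price"), $("rank"));

            top2.execute().print();
        }
    }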

diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupTableAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupTableAggregate.java
index 0f4f80f5c94..1d9e454bd11 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupTableAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupTableAggregate.java
@@ -156,6 +156,7 @@ public class StreamExecGroupTableAggregate extends ExecNodeBase<RowData>
                         accTypes,
                         inputCountIndex,
                         generateUpdateBefore,
+                        generator.isIncrementalUpdate(),
                         config.getStateRetentionTime());
         final OneInputStreamOperator<RowData, RowData> operator =
                 new KeyedProcessOperator<>(aggFunction);
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala
index 583e49bd035..84dd8d83858 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala
@@ -21,13 +21,15 @@ import org.apache.flink.api.common.typeutils.TypeSerializer
 import org.apache.flink.table.api.{DataTypes, TableException}
 import org.apache.flink.table.data.GenericRowData
 import org.apache.flink.table.expressions._
-import org.apache.flink.table.functions.{DeclarativeAggregateFunction, ImperativeAggregateFunction}
+import org.apache.flink.table.functions.{DeclarativeAggregateFunction, ImperativeAggregateFunction, TableAggregateFunction, UserDefinedFunctionHelper}
+import org.apache.flink.table.functions.TableAggregateFunction.RetractableCollector
 import org.apache.flink.table.planner.JLong
 import org.apache.flink.table.planner.codegen._
 import org.apache.flink.table.planner.codegen.CodeGenUtils._
 import org.apache.flink.table.planner.codegen.Indenter.toISC
 import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator._
 import org.apache.flink.table.planner.expressions.DeclarativeExpressionResolver.toRexInputRef
+import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils
 import org.apache.flink.table.planner.plan.utils.AggregateInfoList
 import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala
 import org.apache.flink.table.runtime.dataview.{DataViewSpec, ListViewSpec, MapViewSpec, StateListView, StateMapView}
@@ -86,6 +88,7 @@ class AggsHandlerCodeGenerator(
   private var isRetractNeeded = false
   private var isMergeNeeded = false
   private var isWindowSizeNeeded = false
+  private var isIncrementalUpdateNeeded = false
 
   var valueType: RowType = _
 
@@ -166,6 +169,14 @@ class AggsHandlerCodeGenerator(
     this
   }
 
+  /**
+   * Whether the aggregation result is updated incrementally. The value is true only for a
+   * TableAggregateFunction that implements the emitUpdateWithRetract method.
+   */
+  def isIncrementalUpdate: Boolean = {
+    isIncrementalUpdateNeeded
+  }
+
   /**
    * Tells the generator to generate `merge(..)` method with the merged accumulator information for
    * the [[AggsHandleFunction]] and [[NamespaceAggsHandleFunction]]. Default not generate
@@ -234,6 +245,20 @@ class AggsHandlerCodeGenerator(
               constants,
               relBuilder)
           case _: ImperativeAggregateFunction[_, _] =>
+            aggInfo.function match {
+              case tableAggFunc: TableAggregateFunction[_, _] =>
+                // If the user implements both the emitValue and emitUpdateWithRetract methods,
+                // emitUpdateWithRetract takes precedence.
+                if (
+                  UserDefinedFunctionUtils.ifMethodExistInFunction(
+                    UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT_RETRACT,
+                    tableAggFunc)
+                ) {
+                  this.isIncrementalUpdateNeeded = true
+                }
+
+              case _ =>
+            }
             new ImperativeAggCodeGen(
               ctx,
               aggInfo,
@@ -247,7 +272,8 @@ class AggsHandlerCodeGenerator(
               hasNamespace,
               mergedAccOnHeap,
               mergedAccExternalTypes(aggBufferOffset),
-              copyInputField)
+              copyInputField,
+              isIncrementalUpdateNeeded)
         }
         aggBufferOffset = aggBufferOffset + aggInfo.externalAccTypes.length
         codegen
@@ -447,6 +473,23 @@ class AggsHandlerCodeGenerator(
     val recordInputName = newName("recordInput")
     val recordToRowDataCode = genRecordToRowData(aggExternalType, recordInputName)
 
+    // for emitUpdateWithRetract, the collector needs to implement RetractableCollector
+    // and override retract method
+    val (collectorClassName, collectorRetractCode) =
+      if (isIncrementalUpdateNeeded)
+        (
+          RETRACTABLE_COLLECTOR,
+          s"""
+             |@Override
+             |public void retract(Object $recordInputName) throws Exception {
+             |  $ROW_DATA tempRowData = convertToRowData($recordInputName);
+             |  result.replace(key, tempRowData);
+             |  result.setRowKind($ROW_KIND.DELETE);
+             |  $COLLECTOR_TERM.collect(result);
+             |}
+             |""".stripMargin)
+      else (COLLECTOR, "")
+
     val functionName = newName(name)
     val functionCode =
       j"""
@@ -527,7 +570,7 @@ class AggsHandlerCodeGenerator(
             ${ctx.reuseCloseCode()}
           }
 
-          private class $CONVERT_COLLECTOR_TYPE_TERM implements $COLLECTOR {
+          private class $CONVERT_COLLECTOR_TYPE_TERM implements $collectorClassName {
             private $COLLECTOR<$ROW_DATA> $COLLECTOR_TERM;
             private $ROW_DATA key;
             private $JOINED_ROW result;
@@ -562,6 +605,8 @@ class AggsHandlerCodeGenerator(
               $COLLECTOR_TERM.collect(result);
             }
 
+            $collectorRetractCode
+
             @Override
             public void close() {
               $COLLECTOR_TERM.close();
@@ -1255,6 +1300,7 @@ object AggsHandlerCodeGenerator {
   val STORE_TERM = "store"
 
   val COLLECTOR: String = className[Collector[_]]
+  val RETRACTABLE_COLLECTOR: String = className[RetractableCollector[_]]
   val COLLECTOR_TERM = "out"
   val MEMBER_COLLECTOR_TERM = "convertCollector"
   val CONVERT_COLLECTOR_TYPE_TERM = "ConvertCollector"
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala
index 533c956d3a3..6add23ac9a8 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala
@@ -20,6 +20,7 @@ package org.apache.flink.table.planner.codegen.agg
 import org.apache.flink.table.data.{GenericRowData, RowData, UpdatableRowData}
 import org.apache.flink.table.expressions.Expression
 import org.apache.flink.table.functions.{FunctionContext, ImperativeAggregateFunction, UserDefinedFunctionHelper}
+import org.apache.flink.table.functions.TableAggregateFunction.RetractableCollector
 import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, ExprCodeGenerator, GeneratedExpression}
 import org.apache.flink.table.planner.codegen.CodeGenUtils._
 import org.apache.flink.table.planner.codegen.GenerateUtils.generateFieldAccess
@@ -69,6 +70,9 @@ import scala.collection.mutable.ArrayBuffer
  *   whether the accumulators state has namespace
  * @param inputFieldCopy
  *   copy input field element if true (only mutable type will be copied)
+ * @param isIncrementalUpdateNeeded
+ *   whether the aggregate emits incremental updates; true only for a TableAggregateFunction
+ *   that implements emitUpdateWithRetract, otherwise false.
  */
 class ImperativeAggCodeGen(
     ctx: CodeGeneratorContext,
@@ -83,7 +87,8 @@ class ImperativeAggCodeGen(
     hasNamespace: Boolean,
     mergedAccOnHeap: Boolean,
     mergedAccExternalType: DataType,
-    inputFieldCopy: Boolean)
+    inputFieldCopy: Boolean,
+    isIncrementalUpdateNeeded: Boolean)
   extends AggCodeGen {
 
   private val SINGLE_ITERABLE = className[SingleElementIterator[_]]
@@ -488,10 +493,14 @@ class ImperativeAggCodeGen(
     }
 
     if (needEmitValue) {
+      val (emitMethod, collectorClass) =
+        if (isIncrementalUpdateNeeded)
+          (UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT_RETRACT, classOf[RetractableCollector[_]])
+        else (UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT, classOf[Collector[_]])
       UserDefinedFunctionHelper.validateClassForRuntime(
         function.getClass,
-        UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT,
-        accumulatorClass ++ Array(classOf[Collector[_]]),
+        emitMethod,
+        accumulatorClass ++ Array(collectorClass),
         classOf[Unit],
         functionName
       )
@@ -500,7 +509,11 @@ class ImperativeAggCodeGen(
 
   def emitValue: String = {
     val accTerm = if (isAccTypeInternal) accInternalTerm else accExternalTerm
-    s"$functionTerm.emitValue($accTerm, $MEMBER_COLLECTOR_TERM);"
+    val finalEmitMethodName =
+      if (isIncrementalUpdateNeeded) UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT_RETRACT
+      else UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT
+
+    s"$functionTerm.$finalEmitMethodName($accTerm, $MEMBER_COLLECTOR_TERM);"
   }
 
   override def setWindowSize(generator: ExprCodeGenerator): String = {
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedTableAggFunctions.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedTableAggFunctions.java
new file mode 100644
index 00000000000..3f02991b0ce
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedTableAggFunctions.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.runtime.utils;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.table.functions.TableAggregateFunction;
+import org.apache.flink.util.Collector;
+
+/** Test table aggregate functions. */
+public class JavaUserDefinedTableAggFunctions {
+
+    /** Mutable accumulator of structured type for the table aggregate function. */
+    public static class Top2Accumulator {
+
+        public Integer first;
+        public Integer second;
+        public Integer previousFirst;
+        public Integer previousSecond;
+    }
+
+    /**
+     * Function that takes (value INT), stores intermediate results in a structured type of {@link
+     * Top2Accumulator}, and returns the result as a structured type of {@link Tuple2} for value and
+     * rank.
+     */
+    public static class Top2
+            extends TableAggregateFunction<Tuple2<Integer, Integer>, Top2Accumulator> {
+
+        @Override
+        public Top2Accumulator createAccumulator() {
+            Top2Accumulator acc = new Top2Accumulator();
+            acc.first = Integer.MIN_VALUE;
+            acc.second = Integer.MIN_VALUE;
+            return acc;
+        }
+
+        public void accumulate(Top2Accumulator acc, Integer value) {
+            if (value > acc.first) {
+                acc.second = acc.first;
+                acc.first = value;
+            } else if (value > acc.second) {
+                acc.second = value;
+            }
+        }
+
+        public void merge(Top2Accumulator acc, Iterable<Top2Accumulator> it) {
+            for (Top2Accumulator otherAcc : it) {
+                accumulate(acc, otherAcc.first);
+                accumulate(acc, otherAcc.second);
+            }
+        }
+
+        public void emitValue(Top2Accumulator acc, Collector<Tuple2<Integer, Integer>> out) {
+            // emit the value and rank
+            if (acc.first != Integer.MIN_VALUE) {
+                out.collect(Tuple2.of(acc.first, 1));
+            }
+            if (acc.second != Integer.MIN_VALUE) {
+                out.collect(Tuple2.of(acc.second, 2));
+            }
+        }
+    }
+
+    /** Subclass of {@link Top2} that supports emitting incremental changes. */
+    public static class IncrementalTop2 extends Top2 {
+        @Override
+        public Top2Accumulator createAccumulator() {
+            Top2Accumulator acc = super.createAccumulator();
+            acc.previousFirst = Integer.MIN_VALUE;
+            acc.previousSecond = Integer.MIN_VALUE;
+            return acc;
+        }
+
+        @Override
+        public void accumulate(Top2Accumulator acc, Integer value) {
+            acc.previousFirst = acc.first;
+            acc.previousSecond = acc.second;
+            super.accumulate(acc, value);
+        }
+
+        public void emitUpdateWithRetract(
+                Top2Accumulator acc, RetractableCollector<Tuple2<Integer, Integer>> out) {
+            // emit the value and rank only if they have changed
+            if (!acc.first.equals(acc.previousFirst)) {
+                if (!acc.previousFirst.equals(Integer.MIN_VALUE)) {
+                    out.retract(Tuple2.of(acc.previousFirst, 1));
+                }
+                out.collect(Tuple2.of(acc.first, 1));
+            }
+            if (!acc.second.equals(acc.previousSecond)) {
+                if (!acc.previousSecond.equals(Integer.MIN_VALUE)) {
+                    out.retract(Tuple2.of(acc.previousSecond, 2));
+                }
+                out.collect(Tuple2.of(acc.second, 2));
+            }
+        }
+    }
+}
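
A small hand-written driver, not part of the commit, that exercises IncrementalTop2 directly against a trivial in-memory RetractableCollector; it shows which rows are retracted and re-emitted per input and matches the incrementalTop2 expectations in the ITCase below.

    import java.util.ArrayList;
    import java.util.List;

    import org.apache.flink.api.java.tuple.Tuple2;
    import org.apache.flink.table.functions.TableAggregateFunction.RetractableCollector;
    import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedTableAggFunctions.IncrementalTop2;
    import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedTableAggFunctions.Top2Accumulator;

    public class IncrementalTop2Demo {

        /** Collector that records collects ("+") and retracts ("-") for inspection. */
        static final class RecordingCollector
                implements RetractableCollector<Tuple2<Integer, Integer>> {
            final List<String> log = new ArrayList<>();

            @Override
            public void collect(Tuple2<Integer, Integer> record) {
                log.add("+ " + record);
            }

            @Override
            public void retract(Tuple2<Integer, Integer> record) {
                log.add("- " + record);
            }

            @Override
            public void close() {}
        }

        public static void main(String[] args) {
            IncrementalTop2 top2 = new IncrementalTop2();
            Top2Accumulator acc = top2.createAccumulator();
            RecordingCollector out = new RecordingCollector();

            for (int price : new int[] {6, 3, 5, 8, 4}) {
                top2.accumulate(acc, price);
                top2.emitUpdateWithRetract(acc, out);
            }

            // after 6: + (6,1)
            // after 3: + (3,2)
            // after 5: - (3,2), + (5,2)
            // after 8: - (6,1), + (8,1), - (5,2), + (6,2)
            // after 4: no change, nothing emitted
            out.log.forEach(System.out::println);
        }
    }
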
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala
index 18d11d732f4..ca565b14eb9 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala
@@ -21,7 +21,7 @@ import org.apache.flink.api.common.time.Time
 import org.apache.flink.api.scala._
 import org.apache.flink.table.api._
 import org.apache.flink.table.api.bridge.scala._
-import org.apache.flink.table.planner.runtime.utils.{StreamingWithStateTestBase, TestingRetractSink}
+import org.apache.flink.table.planner.runtime.utils.{JavaUserDefinedTableAggFunctions, StreamingWithStateTestBase, TestingRetractSink}
 import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.OverloadedDoubleMaxFunction
 import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
 import org.apache.flink.table.planner.runtime.utils.TestData.tupleData3
@@ -45,6 +45,84 @@ class TableAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTes
     tEnv.getConfig.setIdleStateRetention(Duration.ofHours(1))
   }
 
+  @TestTemplate
+  def testFlagAggregateWithOrWithoutIncrementalUpdate(): Unit = {
+    // Create a Table from the given rows
+    val table = tEnv.fromValues(
+      DataTypes.ROW(
+        DataTypes.FIELD("id", DataTypes.INT),
+        DataTypes.FIELD("name", DataTypes.STRING),
+        DataTypes.FIELD("price", DataTypes.INT)),
+      row(1, "Latte", 6: java.lang.Integer),
+      row(2, "Milk", 3: java.lang.Integer),
+      row(3, "Breve", 5: java.lang.Integer),
+      row(4, "Mocha", 8: java.lang.Integer),
+      row(5, "Tea", 4: java.lang.Integer)
+    )
+
+    // Register the table aggregate functions
+    tEnv.createTemporarySystemFunction("top2", new JavaUserDefinedTableAggFunctions.Top2)
+    tEnv.createTemporarySystemFunction(
+      "incrementalTop2",
+      new JavaUserDefinedTableAggFunctions.IncrementalTop2)
+
+    checkRank(
+      "top2",
+      List(
+        // output triggered by (1, "Latte", 6)
+        "(true,6,1)",
+        // output triggered by (2, "Milk", 3)
+        "(false,6,1)",
+        "(true,6,1)",
+        "(true,3,2)",
+        // output triggered by (3, "Breve", 5)
+        "(false,6,1)",
+        "(false,3,2)",
+        "(true,6,1)",
+        "(true,5,2)",
+        // output triggered by (4, "Mocha", 8)
+        "(false,6,1)",
+        "(false,5,2)",
+        "(true,8,1)",
+        "(true,6,2)",
+        // output triggered by (5, "Tea", 4)
+        "(false,8,1)",
+        "(false,6,2)",
+        "(true,8,1)",
+        "(true,6,2)"
+      )
+    )
+    checkRank(
+      "incrementalTop2",
+      List(
+        // output triggered by (1, "Latte", 6)
+        "(true,6,1)",
+        // output triggered by (2, "Milk", 3)
+        "(true,3,2)",
+        // output triggered by (3, "Breve", 5)
+        "(false,3,2)",
+        "(true,5,2)",
+        // output triggered by (4, "Mocha", 8)
+        "(false,6,1)",
+        "(true,8,1)",
+        "(false,5,2)",
+        "(true,6,2)"
+      )
+    )
+
+    def checkRank(func: String, expectedResult: List[String]): Unit = {
+      val resultTable =
+        table
+          .flatAggregate(call(func, $("price")).as("top_price", "rank"))
+          .select($("top_price"), $("rank"))
+
+      val sink = new TestingRetractSink()
+      resultTable.toRetractStream[Row].addSink(sink).setParallelism(1)
+      env.execute()
+      assertThat(sink.getRawResults).isEqualTo(expectedResult)
+    }
+  }
+
   @TestTemplate
   def testGroupByFlatAggregate(): Unit = {
     val top3 = new Top3
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java
index ad26d48b703..0c36db1cc7e 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java
@@ -51,6 +51,8 @@ public class GroupTableAggFunction extends KeyedProcessFunction<RowData, RowData
     /** Whether this operator will generate UPDATE_BEFORE messages. */
     private final boolean generateUpdateBefore;
 
+    private final boolean incrementalUpdate;
+
     /** State idle retention time which unit is MILLISECONDS. */
     private final long stateRetentionTime;
 
@@ -69,6 +71,7 @@ public class GroupTableAggFunction extends KeyedProcessFunction<RowData, RowData
      *     contain COUNT(*), i.e. doesn't contain retraction messages. We make sure there is a
      *     COUNT(*) if input stream contains retraction.
      * @param generateUpdateBefore Whether this operator will generate UPDATE_BEFORE messages.
+     * @param incrementalUpdate Whether the aggregation result is updated incrementally.
      * @param stateRetentionTime state idle retention time which unit is MILLISECONDS.
      */
     public GroupTableAggFunction(
@@ -76,11 +79,13 @@ public class GroupTableAggFunction extends KeyedProcessFunction<RowData, RowData
             LogicalType[] accTypes,
             int indexOfCountStar,
             boolean generateUpdateBefore,
+            boolean incrementalUpdate,
             long stateRetentionTime) {
         this.genAggsHandler = genAggsHandler;
         this.accTypes = accTypes;
         this.recordCounter = RecordCounter.of(indexOfCountStar);
         this.generateUpdateBefore = generateUpdateBefore;
+        this.incrementalUpdate = incrementalUpdate;
         this.stateRetentionTime = stateRetentionTime;
     }
 
@@ -117,7 +122,9 @@ public class GroupTableAggFunction extends KeyedProcessFunction<RowData, RowData
         // set accumulators to handler first
         function.setAccumulators(accumulators);
 
-        if (!firstRow && generateUpdateBefore) {
+        // when incrementalUpdate is enabled, there is no need to retract the
+        // previously emitted data that has not changed
+        if (!firstRow && !incrementalUpdate && generateUpdateBefore) {
             function.emitValue(out, currentKey, true);
         }
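
To summarize the operator-side effect of the new flag: a hand-written condensation, not the committed class, of the emit decision in GroupTableAggFunction#processElement. The unconditional emit that follows the hunk above in the committed method is included here for context.

    import org.apache.flink.table.data.RowData;
    import org.apache.flink.table.runtime.generated.TableAggsHandleFunction;
    import org.apache.flink.util.Collector;

    final class EmitPathSketch {
        static void emit(
                TableAggsHandleFunction function,
                Collector<RowData> out,
                RowData currentKey,
                boolean firstRow,
                boolean generateUpdateBefore,
                boolean incrementalUpdate)
                throws Exception {
            if (!firstRow && !incrementalUpdate && generateUpdateBefore) {
                // emitValue-based functions re-emit the full result, so every previously
                // emitted row must be retracted first (isRetract = true)
                function.emitValue(out, currentKey, true);
            }
            // with incrementalUpdate the generated handler delegates to
            // emitUpdateWithRetract, which retracts only the rows that changed
            function.emitValue(out, currentKey, false);
        }
    }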