You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by jc...@apache.org on 2024/01/12 01:47:24 UTC
(flink) branch master updated: [FLINK-31788][table] TableAggregateFunction supports emitUpdateWithRetract
This is an automated email from the ASF dual-hosted git repository.
jchan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 01569644aed [FLINK-31788][table] TableAggregateFunction supports emitUpdateWithRetract
01569644aed is described below
commit 01569644aedb56f792c7f7e04f84612d405b0bdf
Author: Jane Chan <qi...@gmail.com>
AuthorDate: Fri Jan 12 09:47:16 2024 +0800
[FLINK-31788][table] TableAggregateFunction supports emitUpdateWithRetract
This closes #24051
---
.../exec/stream/StreamExecGroupTableAggregate.java | 1 +
.../codegen/agg/AggsHandlerCodeGenerator.scala | 52 +++++++++-
.../planner/codegen/agg/ImperativeAggCodeGen.scala | 21 +++-
.../utils/JavaUserDefinedTableAggFunctions.java | 114 +++++++++++++++++++++
.../stream/table/TableAggregateITCase.scala | 80 ++++++++++++++-
.../operators/aggregate/GroupTableAggFunction.java | 9 +-
6 files changed, 268 insertions(+), 9 deletions(-)
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupTableAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupTableAggregate.java
index 0f4f80f5c94..1d9e454bd11 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupTableAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupTableAggregate.java
@@ -156,6 +156,7 @@ public class StreamExecGroupTableAggregate extends ExecNodeBase<RowData>
accTypes,
inputCountIndex,
generateUpdateBefore,
+ generator.isIncrementalUpdate(),
config.getStateRetentionTime());
final OneInputStreamOperator<RowData, RowData> operator =
new KeyedProcessOperator<>(aggFunction);
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala
index 583e49bd035..84dd8d83858 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala
@@ -21,13 +21,15 @@ import org.apache.flink.api.common.typeutils.TypeSerializer
import org.apache.flink.table.api.{DataTypes, TableException}
import org.apache.flink.table.data.GenericRowData
import org.apache.flink.table.expressions._
-import org.apache.flink.table.functions.{DeclarativeAggregateFunction, ImperativeAggregateFunction}
+import org.apache.flink.table.functions.{DeclarativeAggregateFunction, ImperativeAggregateFunction, TableAggregateFunction, UserDefinedFunctionHelper}
+import org.apache.flink.table.functions.TableAggregateFunction.RetractableCollector
import org.apache.flink.table.planner.JLong
import org.apache.flink.table.planner.codegen._
import org.apache.flink.table.planner.codegen.CodeGenUtils._
import org.apache.flink.table.planner.codegen.Indenter.toISC
import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator._
import org.apache.flink.table.planner.expressions.DeclarativeExpressionResolver.toRexInputRef
+import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils
import org.apache.flink.table.planner.plan.utils.AggregateInfoList
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala
import org.apache.flink.table.runtime.dataview.{DataViewSpec, ListViewSpec, MapViewSpec, StateListView, StateMapView}
@@ -86,6 +88,7 @@ class AggsHandlerCodeGenerator(
private var isRetractNeeded = false
private var isMergeNeeded = false
private var isWindowSizeNeeded = false
+ private var isIncrementalUpdateNeeded = false
var valueType: RowType = _
@@ -166,6 +169,14 @@ class AggsHandlerCodeGenerator(
this
}
+ /**
+ * Whether the aggregate result is emitted incrementally. False by default; set to true only for a
+ * TableAggregateFunction that implements the emitUpdateWithRetract method.
+ */
+ def isIncrementalUpdate: Boolean = {
+ isIncrementalUpdateNeeded
+ }
+
/**
* Tells the generator to generate `merge(..)` method with the merged accumulator information for
* the [[AggsHandleFunction]] and [[NamespaceAggsHandleFunction]]. Default not generate
@@ -234,6 +245,20 @@ class AggsHandlerCodeGenerator(
constants,
relBuilder)
case _: ImperativeAggregateFunction[_, _] =>
+ aggInfo.function match {
+ case tableAggFunc: TableAggregateFunction[_, _] =>
+ // If the user implements both the emitValue and emitUpdateWithRetract methods,
+ // the emitUpdateWithRetract method takes precedence.
+ if (
+ UserDefinedFunctionUtils.ifMethodExistInFunction(
+ UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT_RETRACT,
+ tableAggFunc)
+ ) {
+ this.isIncrementalUpdateNeeded = true
+ }
+
+ case _ =>
+ }
new ImperativeAggCodeGen(
ctx,
aggInfo,
@@ -247,7 +272,8 @@ class AggsHandlerCodeGenerator(
hasNamespace,
mergedAccOnHeap,
mergedAccExternalTypes(aggBufferOffset),
- copyInputField)
+ copyInputField,
+ isIncrementalUpdateNeeded)
}
aggBufferOffset = aggBufferOffset + aggInfo.externalAccTypes.length
codegen
@@ -447,6 +473,23 @@ class AggsHandlerCodeGenerator(
val recordInputName = newName("recordInput")
val recordToRowDataCode = genRecordToRowData(aggExternalType, recordInputName)
+ // for emitUpdateWithRetract, the collector needs to implement RetractableCollector
+ // and override retract method
+ val (collectorClassName, collectorRetractCode) =
+ if (isIncrementalUpdateNeeded)
+ (
+ RETRACTABLE_COLLECTOR,
+ s"""
+ |@Override
+ |public void retract(Object $recordInputName) throws Exception {
+ | $ROW_DATA tempRowData = convertToRowData($recordInputName);
+ | result.replace(key, tempRowData);
+ | result.setRowKind($ROW_KIND.DELETE);
+ | $COLLECTOR_TERM.collect(result);
+ |}
+ |""".stripMargin)
+ else (COLLECTOR, "")
+
val functionName = newName(name)
val functionCode =
j"""
@@ -527,7 +570,7 @@ class AggsHandlerCodeGenerator(
${ctx.reuseCloseCode()}
}
- private class $CONVERT_COLLECTOR_TYPE_TERM implements $COLLECTOR {
+ private class $CONVERT_COLLECTOR_TYPE_TERM implements $collectorClassName {
private $COLLECTOR<$ROW_DATA> $COLLECTOR_TERM;
private $ROW_DATA key;
private $JOINED_ROW result;
@@ -562,6 +605,8 @@ class AggsHandlerCodeGenerator(
$COLLECTOR_TERM.collect(result);
}
+ $collectorRetractCode
+
@Override
public void close() {
$COLLECTOR_TERM.close();
@@ -1255,6 +1300,7 @@ object AggsHandlerCodeGenerator {
val STORE_TERM = "store"
val COLLECTOR: String = className[Collector[_]]
+ val RETRACTABLE_COLLECTOR: String = className[RetractableCollector[_]]
val COLLECTOR_TERM = "out"
val MEMBER_COLLECTOR_TERM = "convertCollector"
val CONVERT_COLLECTOR_TYPE_TERM = "ConvertCollector"
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala
index 533c956d3a3..6add23ac9a8 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/ImperativeAggCodeGen.scala
@@ -20,6 +20,7 @@ package org.apache.flink.table.planner.codegen.agg
import org.apache.flink.table.data.{GenericRowData, RowData, UpdatableRowData}
import org.apache.flink.table.expressions.Expression
import org.apache.flink.table.functions.{FunctionContext, ImperativeAggregateFunction, UserDefinedFunctionHelper}
+import org.apache.flink.table.functions.TableAggregateFunction.RetractableCollector
import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, ExprCodeGenerator, GeneratedExpression}
import org.apache.flink.table.planner.codegen.CodeGenUtils._
import org.apache.flink.table.planner.codegen.GenerateUtils.generateFieldAccess
@@ -69,6 +70,9 @@ import scala.collection.mutable.ArrayBuffer
* whether the accumulators state has namespace
* @param inputFieldCopy
* copy input field element if true (only mutable type will be copied)
+ * @param isIncrementalUpdateNeeded
+ * whether the agg supports emitting incremental update, true for TableAggregateFunction if
+ * user-defined function implements emitUpdateWithRetract, otherwise false.
*/
class ImperativeAggCodeGen(
ctx: CodeGeneratorContext,
@@ -83,7 +87,8 @@ class ImperativeAggCodeGen(
hasNamespace: Boolean,
mergedAccOnHeap: Boolean,
mergedAccExternalType: DataType,
- inputFieldCopy: Boolean)
+ inputFieldCopy: Boolean,
+ isIncrementalUpdateNeeded: Boolean)
extends AggCodeGen {
private val SINGLE_ITERABLE = className[SingleElementIterator[_]]
@@ -488,10 +493,14 @@ class ImperativeAggCodeGen(
}
if (needEmitValue) {
+ val (emitMethod, collectorClass) =
+ if (isIncrementalUpdateNeeded)
+ (UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT_RETRACT, classOf[RetractableCollector[_]])
+ else (UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT, classOf[Collector[_]])
UserDefinedFunctionHelper.validateClassForRuntime(
function.getClass,
- UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT,
- accumulatorClass ++ Array(classOf[Collector[_]]),
+ emitMethod,
+ accumulatorClass ++ Array(collectorClass),
classOf[Unit],
functionName
)
@@ -500,7 +509,11 @@ class ImperativeAggCodeGen(
def emitValue: String = {
val accTerm = if (isAccTypeInternal) accInternalTerm else accExternalTerm
- s"$functionTerm.emitValue($accTerm, $MEMBER_COLLECTOR_TERM);"
+ val finalEmitMethodName =
+ if (isIncrementalUpdateNeeded) UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT_RETRACT
+ else UserDefinedFunctionHelper.TABLE_AGGREGATE_EMIT
+
+ s"$functionTerm.$finalEmitMethodName($accTerm, $MEMBER_COLLECTOR_TERM);"
}
override def setWindowSize(generator: ExprCodeGenerator): String = {
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedTableAggFunctions.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedTableAggFunctions.java
new file mode 100644
index 00000000000..3f02991b0ce
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedTableAggFunctions.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.runtime.utils;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.table.functions.TableAggregateFunction;
+import org.apache.flink.util.Collector;
+
+/** Table aggregate functions used in tests. */
+public class JavaUserDefinedTableAggFunctions {
+
+    /** Mutable accumulator of structured type: current top two values plus the previous top two. */
+    public static class Top2Accumulator {
+
+        public Integer first;
+        public Integer second;
+        public Integer previousFirst;
+        public Integer previousSecond;
+    }
+
+    /**
+     * Function that takes (value INT), keeps the two largest values seen so far in a {@link
+     * Top2Accumulator}, and emits each of them as a {@link Tuple2} of (value, rank). A slot still
+     * holding {@link Integer#MIN_VALUE} is treated as unset and is not emitted.
+     */
+    public static class Top2
+            extends TableAggregateFunction<Tuple2<Integer, Integer>, Top2Accumulator> {
+
+        @Override
+        public Top2Accumulator createAccumulator() {
+            Top2Accumulator acc = new Top2Accumulator();
+            acc.first = Integer.MIN_VALUE;
+            acc.second = Integer.MIN_VALUE;
+            return acc;
+        }
+
+        public void accumulate(Top2Accumulator acc, Integer value) {
+            if (value > acc.first) {
+                acc.second = acc.first;
+                acc.first = value;
+            } else if (value > acc.second) {
+                acc.second = value;
+            }
+        }
+
+        public void merge(Top2Accumulator acc, Iterable<Top2Accumulator> it) {
+            for (Top2Accumulator otherAcc : it) {
+                accumulate(acc, otherAcc.first);
+                accumulate(acc, otherAcc.second);
+            }
+        }
+
+        public void emitValue(Top2Accumulator acc, Collector<Tuple2<Integer, Integer>> out) {
+            // emit (value, rank) for every slot that has been set
+            if (acc.first != Integer.MIN_VALUE) {
+                out.collect(Tuple2.of(acc.first, 1));
+            }
+            if (acc.second != Integer.MIN_VALUE) {
+                out.collect(Tuple2.of(acc.second, 2));
+            }
+        }
+    }
+
+    /** Subclass of {@link Top2} that emits only incremental changes via emitUpdateWithRetract. */
+    public static class IncrementalTop2 extends Top2 {
+        @Override
+        public Top2Accumulator createAccumulator() {
+            Top2Accumulator acc = super.createAccumulator();
+            acc.previousFirst = Integer.MIN_VALUE;
+            acc.previousSecond = Integer.MIN_VALUE;
+            return acc;
+        }
+
+        @Override
+        public void accumulate(Top2Accumulator acc, Integer value) {
+            acc.previousFirst = acc.first;
+            acc.previousSecond = acc.second;
+            super.accumulate(acc, value);
+        }
+
+        public void emitUpdateWithRetract(
+                Top2Accumulator acc, RetractableCollector<Tuple2<Integer, Integer>> out) {
+            // for each rank that changed, retract the previous row (if one was set) and emit the new one
+            if (!acc.first.equals(acc.previousFirst)) {
+                if (!acc.previousFirst.equals(Integer.MIN_VALUE)) {
+                    out.retract(Tuple2.of(acc.previousFirst, 1));
+                }
+                out.collect(Tuple2.of(acc.first, 1));
+            }
+            if (!acc.second.equals(acc.previousSecond)) {
+                if (!acc.previousSecond.equals(Integer.MIN_VALUE)) {
+                    out.retract(Tuple2.of(acc.previousSecond, 2));
+                }
+                out.collect(Tuple2.of(acc.second, 2));
+            }
+        }
+    }
+}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala
index 18d11d732f4..ca565b14eb9 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/TableAggregateITCase.scala
@@ -21,7 +21,7 @@ import org.apache.flink.api.common.time.Time
import org.apache.flink.api.scala._
import org.apache.flink.table.api._
import org.apache.flink.table.api.bridge.scala._
-import org.apache.flink.table.planner.runtime.utils.{StreamingWithStateTestBase, TestingRetractSink}
+import org.apache.flink.table.planner.runtime.utils.{JavaUserDefinedTableAggFunctions, StreamingWithStateTestBase, TestingRetractSink}
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.OverloadedDoubleMaxFunction
import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
import org.apache.flink.table.planner.runtime.utils.TestData.tupleData3
@@ -45,6 +45,84 @@ class TableAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTes
tEnv.getConfig.setIdleStateRetention(Duration.ofHours(1))
}
+ @TestTemplate
+ def testFlagAggregateWithOrWithoutIncrementalUpdate(): Unit = {
+ // Build a five-row input table with columns (id, name, price) from literal rows
+ val table = tEnv.fromValues(
+ DataTypes.ROW(
+ DataTypes.FIELD("id", DataTypes.INT),
+ DataTypes.FIELD("name", DataTypes.STRING),
+ DataTypes.FIELD("price", DataTypes.INT)),
+ row(1, "Latte", 6: java.lang.Integer),
+ row(2, "Milk", 3: java.lang.Integer),
+ row(3, "Breve", 5: java.lang.Integer),
+ row(4, "Mocha", 8: java.lang.Integer),
+ row(5, "Tea", 4: java.lang.Integer)
+ )
+
+ // Register both table aggregate functions: Top2 and its subclass IncrementalTop2
+ tEnv.createTemporarySystemFunction("top2", new JavaUserDefinedTableAggFunctions.Top2)
+ tEnv.createTemporarySystemFunction(
+ "incrementalTop2",
+ new JavaUserDefinedTableAggFunctions.IncrementalTop2)
+
+ checkRank(
+ "top2",
+ List(
+ // output triggered by (1, "Latte", 6)
+ "(true,6,1)",
+ // output triggered by (2, "Milk", 3)
+ "(false,6,1)",
+ "(true,6,1)",
+ "(true,3,2)",
+ // output triggered by (3, "Breve", 5)
+ "(false,6,1)",
+ "(false,3,2)",
+ "(true,6,1)",
+ "(true,5,2)",
+ // output triggered by (4, "Mocha", 8)
+ "(false,6,1)",
+ "(false,5,2)",
+ "(true,8,1)",
+ "(true,6,2)",
+ // output triggered by (5, "Tea", 4)
+ "(false,8,1)",
+ "(false,6,2)",
+ "(true,8,1)",
+ "(true,6,2)"
+ )
+ )
+ checkRank(
+ "incrementalTop2",
+ List(
+ // output triggered by (1, "Latte", 6)
+ "(true,6,1)",
+ // output triggered by (2, "Milk", 3)
+ "(true,3,2)",
+ // output triggered by (3, "Breve", 5)
+ "(false,3,2)",
+ "(true,5,2)",
+ // output triggered by (4, "Mocha", 8)
+ "(false,6,1)",
+ "(true,8,1)",
+ "(false,5,2)",
+ "(true,6,2)"
+ )
+ )
+
+ def checkRank(func: String, expectedResult: List[String]): Unit = {
+ val resultTable =
+ table
+ .flatAggregate(call(func, $("price")).as("top_price", "rank"))
+ .select($("top_price"), $("rank"))
+
+ val sink = new TestingRetractSink()
+ resultTable.toRetractStream[Row].addSink(sink).setParallelism(1)
+ env.execute()
+ assertThat(sink.getRawResults).isEqualTo(expectedResult)
+ }
+ }
+
@TestTemplate
def testGroupByFlatAggregate(): Unit = {
val top3 = new Top3
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java
index ad26d48b703..0c36db1cc7e 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupTableAggFunction.java
@@ -51,6 +51,8 @@ public class GroupTableAggFunction extends KeyedProcessFunction<RowData, RowData
/** Whether this operator will generate UPDATE_BEFORE messages. */
private final boolean generateUpdateBefore;
+ private final boolean incrementalUpdate;
+
/** State idle retention time which unit is MILLISECONDS. */
private final long stateRetentionTime;
@@ -69,6 +71,7 @@ public class GroupTableAggFunction extends KeyedProcessFunction<RowData, RowData
* contain COUNT(*), i.e. doesn't contain retraction messages. We make sure there is a
* COUNT(*) if input stream contains retraction.
* @param generateUpdateBefore Whether this operator will generate UPDATE_BEFORE messages.
+ * @param incrementalUpdate Whether to update acc result incrementally.
* @param stateRetentionTime state idle retention time which unit is MILLISECONDS.
*/
public GroupTableAggFunction(
@@ -76,11 +79,13 @@ public class GroupTableAggFunction extends KeyedProcessFunction<RowData, RowData
LogicalType[] accTypes,
int indexOfCountStar,
boolean generateUpdateBefore,
+ boolean incrementalUpdate,
long stateRetentionTime) {
this.genAggsHandler = genAggsHandler;
this.accTypes = accTypes;
this.recordCounter = RecordCounter.of(indexOfCountStar);
this.generateUpdateBefore = generateUpdateBefore;
+ this.incrementalUpdate = incrementalUpdate;
this.stateRetentionTime = stateRetentionTime;
}
@@ -117,7 +122,9 @@ public class GroupTableAggFunction extends KeyedProcessFunction<RowData, RowData
// set accumulators to handler first
function.setAccumulators(accumulators);
- if (!firstRow && generateUpdateBefore) {
+ // when incrementalUpdate is enabled, there is no need to retract
+ // previously sent data that has not changed
+ if (!firstRow && !incrementalUpdate && generateUpdateBefore) {
function.emitValue(out, currentKey, true);
}