Posted to commits@flink.apache.org by di...@apache.org on 2020/09/17 01:30:44 UTC

[flink] branch master updated: [FLINK-19173][python] Introduce BatchArrowPythonGroupAggregateFunctionOperator

This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 7925647  [FLINK-19173][python] Introduce BatchArrowPythonGroupAggregateFunctionOperator
7925647 is described below

commit 7925647e3ae67a0a8914b4eb1644c318023eeb88
Author: huangxingbo <hx...@gmail.com>
AuthorDate: Thu Sep 10 14:41:11 2020 +0800

    [FLINK-19173][python] Introduce BatchArrowPythonGroupAggregateFunctionOperator
    
    This closes #13369
---
 .../pyflink/fn_execution/beam/beam_coders.py       |   2 +
 flink-python/pyflink/fn_execution/coders.py        |   1 +
 .../DataStreamPythonReduceFunctionOperator.java    |   1 +
 .../DataStreamPythonStatelessFunctionOperator.java |   1 +
 ...eamTwoInputPythonStatelessFunctionOperator.java |   1 +
 .../python/AbstractPythonFunctionOperatorBase.java |  10 +-
 .../python/AbstractStatelessFunctionOperator.java  |   3 +-
 ...stractArrowPythonAggregateFunctionOperator.java | 168 ++++++++++++++
 ...tBatchArrowPythonAggregateFunctionOperator.java | 128 +++++++++++
 ...hArrowPythonGroupAggregateFunctionOperator.java | 104 +++++++++
 ...rowPythonAggregateFunctionOperatorTestBase.java |  92 ++++++++
 ...owPythonGroupAggregateFunctionOperatorTest.java | 255 +++++++++++++++++++++
 .../PassThroughPythonAggregateFunctionRunner.java  | 107 +++++++++
 13 files changed, 870 insertions(+), 3 deletions(-)

diff --git a/flink-python/pyflink/fn_execution/beam/beam_coders.py b/flink-python/pyflink/fn_execution/beam/beam_coders.py
index 2c2f4e5..cc14db8d 100644
--- a/flink-python/pyflink/fn_execution/beam/beam_coders.py
+++ b/flink-python/pyflink/fn_execution/beam/beam_coders.py
@@ -151,6 +151,8 @@ class ArrowCoder(FastCoder):
         import pandas as pd
         return pd.Series
 
+    @Coder.register_urn(coders.FLINK_SCHEMA_ARROW_CODER_URN,
+                        flink_fn_execution_pb2.Schema)
     @Coder.register_urn(coders.FLINK_SCALAR_FUNCTION_SCHEMA_ARROW_CODER_URN,
                         flink_fn_execution_pb2.Schema)
     def _pickle_from_runner_api_parameter(schema_proto, unused_components, unused_context):
diff --git a/flink-python/pyflink/fn_execution/coders.py b/flink-python/pyflink/fn_execution/coders.py
index 57788c6..dd29302 100644
--- a/flink-python/pyflink/fn_execution/coders.py
+++ b/flink-python/pyflink/fn_execution/coders.py
@@ -37,6 +37,7 @@ __all__ = ['RowCoder', 'BigIntCoder', 'TinyIntCoder', 'BooleanCoder',
 FLINK_SCALAR_FUNCTION_SCHEMA_CODER_URN = "flink:coder:schema:scalar_function:v1"
 FLINK_TABLE_FUNCTION_SCHEMA_CODER_URN = "flink:coder:schema:table_function:v1"
 FLINK_SCALAR_FUNCTION_SCHEMA_ARROW_CODER_URN = "flink:coder:schema:scalar_function:arrow:v1"
+FLINK_SCHEMA_ARROW_CODER_URN = "flink:coder:schema:arrow:v1"
 FLINK_MAP_FUNCTION_DATA_STREAM_CODER_URN = "flink:coder:datastream:map_function:v1"
 FLINK_FLAT_MAP_FUNCTION_DATA_STREAM_CODER_URN = "flink:coder:datastream:flatmap_function:v1"
 
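
The new FLINK_SCHEMA_ARROW_CODER_URN registers the Arrow coder for the aggregate
path, which is what ultimately lets a vectorized (Pandas) aggregate function run
behind the operator introduced below. As a rough sketch of the end-to-end usage
this enables, assuming the udaf decorator with func_type="pandas" from FLIP-137
(of which this commit is one part), and where table_env stands for any configured
TableEnvironment:

    from pyflink.table import DataTypes
    from pyflink.table.udf import udaf

    # A vectorized aggregate: v arrives as a pandas.Series per group, carried
    # between the JVM and the Python worker through the Arrow coder above.
    @udaf(result_type=DataTypes.FLOAT(), func_type="pandas")
    def mean_udaf(v):
        return v.mean()

    t = table_env.from_elements([(1, 2.0), (1, 4.0), (2, 6.0)], ['a', 'b'])
    result = t.group_by(t.a).select(t.a, mean_udaf(t.b))
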
diff --git a/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonReduceFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonReduceFunctionOperator.java
index e3596dc..b73da63 100644
--- a/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonReduceFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonReduceFunctionOperator.java
@@ -90,6 +90,7 @@ public class DataStreamPythonReduceFunctionOperator<OUT>
 			runnerInputTypeSerializer.serialize(reuseRow, baosWrapper);
 			pythonFunctionRunner.process(baos.toByteArray());
 			baos.reset();
+			elementCount++;
 			checkInvokeFinishBundleByCount();
 			emitResults();
 		}
diff --git a/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonStatelessFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonStatelessFunctionOperator.java
index a2ec04c..84d43f2 100644
--- a/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonStatelessFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonStatelessFunctionOperator.java
@@ -147,6 +147,7 @@ public class DataStreamPythonStatelessFunctionOperator<IN, OUT>
 		inputTypeSerializer.serialize(element.getValue(), baosWrapper);
 		pythonFunctionRunner.process(baos.toByteArray());
 		baos.reset();
+		elementCount++;
 		checkInvokeFinishBundleByCount();
 		emitResults();
 	}
diff --git a/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamTwoInputPythonStatelessFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamTwoInputPythonStatelessFunctionOperator.java
index 96ef311..e35f070 100644
--- a/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamTwoInputPythonStatelessFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamTwoInputPythonStatelessFunctionOperator.java
@@ -206,6 +206,7 @@ public class DataStreamTwoInputPythonStatelessFunctionOperator<IN1, IN2, OUT>
 		runnerInputTypeSerializer.serialize(reuseRow, baosWrapper);
 		pythonFunctionRunner.process(baos.toByteArray());
 		baos.reset();
+		elementCount++;
 		checkInvokeFinishBundleByCount();
 		emitResults();
 	}
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
index 5a989ba..ca30232 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
@@ -214,7 +214,7 @@ public abstract class AbstractPythonFunctionOperatorBase<OUT>
 		if (mark.getTimestamp() == Long.MAX_VALUE) {
 			invokeFinishBundle();
 			super.processWatermark(mark);
-		} else if (elementCount == 0) {
+		} else if (isBundleFinished()) {
 			// forward the watermark immediately if the bundle is already finished.
 			super.processWatermark(mark);
 		} else {
@@ -234,6 +234,13 @@ public abstract class AbstractPythonFunctionOperatorBase<OUT>
 	}
 
 	/**
+	 * Returns whether the bundle is finished.
+	 */
+	public boolean isBundleFinished() {
+		return elementCount == 0;
+	}
+
+	/**
 	 * Reset the {@link PythonConfig} if needed.
 	 */
 	public void setPythonConfig(PythonConfig pythonConfig) {
@@ -299,7 +306,6 @@ public abstract class AbstractPythonFunctionOperatorBase<OUT>
 	 * Checks whether to invoke finishBundle by elements count. Called in processElement.
 	 */
 	protected void checkInvokeFinishBundleByCount() throws Exception {
-		elementCount++;
 		if (elementCount >= maxBundleSize) {
 			invokeFinishBundle();
 		}
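
The two changes above split the bundle accounting (elementCount is now incremented
by each caller instead of inside checkInvokeFinishBundleByCount) and make the
watermark hold-back condition overridable through isBundleFinished(). A minimal
sketch of the resulting processWatermark branching, in plain Python with
illustrative names rather than the Flink API:

    MAX_WATERMARK = 2**63 - 1  # Long.MAX_VALUE, the end-of-input watermark

    def process_watermark(ts, is_bundle_finished, finish_bundle, forward, hold):
        if ts == MAX_WATERMARK:
            finish_bundle()   # flush everything before the final watermark
            forward(ts)
        elif is_bundle_finished():
            forward(ts)       # nothing pending, safe to forward immediately
        else:
            hold(ts)          # re-emitted once the current bundle finishes

Overriding isBundleFinished() matters for the Arrow operators below, where rows
can be staged in an unfinished Arrow batch while elementCount is still 0.
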
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractStatelessFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractStatelessFunctionOperator.java
index a3ae165..eb46961 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractStatelessFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractStatelessFunctionOperator.java
@@ -145,6 +145,7 @@ public abstract class AbstractStatelessFunctionOperator<IN, OUT, UDFIN>
 		IN value = element.getValue();
 		bufferInput(value);
 		processElementInternal(value);
+		elementCount++;
 		checkInvokeFinishBundleByCount();
 		emitResults();
 	}
@@ -185,7 +186,7 @@ public abstract class AbstractStatelessFunctionOperator<IN, OUT, UDFIN>
 	 * Buffers the specified input, which will be used to construct
 	 * the operator result together with the user-defined function execution result.
 	 */
-	public abstract void bufferInput(IN input);
+	public abstract void bufferInput(IN input) throws Exception;
 
 	public abstract UDFIN getFunctionInput(IN element);
 
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/AbstractArrowPythonAggregateFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/AbstractArrowPythonAggregateFunctionOperator.java
new file mode 100644
index 0000000..398a17d
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/AbstractArrowPythonAggregateFunctionOperator.java
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.python.aggregate.arrow;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.fnexecution.v1.FlinkFnApi;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.data.JoinedRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.python.PythonEnv;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
+import org.apache.flink.table.planner.codegen.ProjectionCodeGenerator;
+import org.apache.flink.table.runtime.arrow.serializers.ArrowSerializer;
+import org.apache.flink.table.runtime.arrow.serializers.RowDataArrowSerializer;
+import org.apache.flink.table.runtime.generated.GeneratedProjection;
+import org.apache.flink.table.runtime.generated.Projection;
+import org.apache.flink.table.runtime.operators.python.AbstractStatelessFunctionOperator;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * The abstract base class of Arrow aggregate operators for Pandas {@link AggregateFunction}s.
+ */
+@Internal
+public abstract class AbstractArrowPythonAggregateFunctionOperator
+	extends AbstractStatelessFunctionOperator<RowData, RowData, RowData> {
+
+	private static final long serialVersionUID = 1L;
+
+	private static final String SCHEMA_ARROW_CODER_URN = "flink:coder:schema:arrow:v1";
+
+	private static final String PANDAS_AGGREGATE_FUNCTION_URN = "flink:transform:aggregate_function:arrow:v1";
+
+	/**
+	 * The Pandas {@link AggregateFunction}s to be executed.
+	 */
+	private final PythonFunctionInfo[] pandasAggFunctions;
+
+	protected final int[] groupingSet;
+
+	protected transient ArrowSerializer<RowData> arrowSerializer;
+
+	/**
+	 * The collector used to collect records.
+	 */
+	protected transient StreamRecordRowDataWrappingCollector rowDataWrapper;
+
+	/**
+	 * The reused JoinedRowData holding the execution result.
+	 */
+	protected transient JoinedRowData reuseJoinedRow;
+
+	/**
+	 * The number of elements buffered in the current Arrow batch.
+	 */
+	protected transient int currentBatchCount;
+
+	/**
+	 * The Projection which projects the UDAF input fields from the input row.
+	 */
+	private transient Projection<RowData, BinaryRowData> udafInputProjection;
+
+	public AbstractArrowPythonAggregateFunctionOperator(
+		Configuration config,
+		PythonFunctionInfo[] pandasAggFunctions,
+		RowType inputType,
+		RowType outputType,
+		int[] groupingSet,
+		int[] udafInputOffsets) {
+		super(config, inputType, outputType, udafInputOffsets);
+		this.pandasAggFunctions = Preconditions.checkNotNull(pandasAggFunctions);
+		this.groupingSet = Preconditions.checkNotNull(groupingSet);
+	}
+
+	@Override
+	public void open() throws Exception {
+		super.open();
+		rowDataWrapper = new StreamRecordRowDataWrappingCollector(output);
+		reuseJoinedRow = new JoinedRowData();
+
+		udafInputProjection = createUdafInputProjection();
+		arrowSerializer = new RowDataArrowSerializer(userDefinedFunctionInputType, userDefinedFunctionOutputType);
+		arrowSerializer.open(bais, baos);
+		currentBatchCount = 0;
+	}
+
+	@Override
+	public void dispose() throws Exception {
+		super.dispose();
+		arrowSerializer.close();
+	}
+
+	@Override
+	public void processElement(StreamRecord<RowData> element) throws Exception {
+		RowData value = element.getValue();
+		bufferInput(value);
+		processElementInternal(value);
+		emitResults();
+	}
+
+	@Override
+	public boolean isBundleFinished() {
+		return elementCount == 0 && currentBatchCount == 0;
+	}
+
+	@Override
+	public PythonEnv getPythonEnv() {
+		return pandasAggFunctions[0].getPythonFunction().getPythonEnv();
+	}
+
+	@Override
+	public String getFunctionUrn() {
+		return PANDAS_AGGREGATE_FUNCTION_URN;
+	}
+
+	@Override
+	public String getInputOutputCoderUrn() {
+		return SCHEMA_ARROW_CODER_URN;
+	}
+
+	@Override
+	public RowData getFunctionInput(RowData element) {
+		return udafInputProjection.apply(element);
+	}
+
+	@Override
+	public FlinkFnApi.UserDefinedFunctions getUserDefinedFunctionsProto() {
+		FlinkFnApi.UserDefinedFunctions.Builder builder = FlinkFnApi.UserDefinedFunctions.newBuilder();
+		// add udaf proto
+		for (PythonFunctionInfo pythonFunctionInfo : pandasAggFunctions) {
+			builder.addUdfs(getUserDefinedFunctionProto(pythonFunctionInfo));
+		}
+		builder.setMetricEnabled(getPythonConfig().isMetricEnabled());
+		return builder.build();
+	}
+
+	private Projection<RowData, BinaryRowData> createUdafInputProjection() {
+		final GeneratedProjection generatedProjection = ProjectionCodeGenerator.generateProjection(
+			CodeGeneratorContext.apply(new TableConfig()),
+			"UadfInputProjection",
+			inputType,
+			userDefinedFunctionInputType,
+			userDefinedFunctionInputOffsets);
+		// noinspection unchecked
+		return generatedProjection.newInstance(Thread.currentThread().getContextClassLoader());
+	}
+}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/AbstractBatchArrowPythonAggregateFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/AbstractBatchArrowPythonAggregateFunctionOperator.java
new file mode 100644
index 0000000..6bd0d19
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/AbstractBatchArrowPythonAggregateFunctionOperator.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.python.aggregate.arrow.batch;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.data.binary.BinaryRowDataUtil;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
+import org.apache.flink.table.planner.codegen.ProjectionCodeGenerator;
+import org.apache.flink.table.runtime.generated.GeneratedProjection;
+import org.apache.flink.table.runtime.generated.Projection;
+import org.apache.flink.table.runtime.operators.python.aggregate.arrow.AbstractArrowPythonAggregateFunctionOperator;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.util.Preconditions;
+
+import java.util.Arrays;
+import java.util.stream.Collectors;
+
+/**
+ * The abstract base class of batch Arrow aggregate operators for Pandas {@link AggregateFunction}s.
+ */
+@Internal
+abstract class AbstractBatchArrowPythonAggregateFunctionOperator
+	extends AbstractArrowPythonAggregateFunctionOperator {
+
+	private static final long serialVersionUID = 1L;
+
+	private final int[] groupKey;
+
+	/**
+	 * Last group key value.
+	 */
+	transient BinaryRowData lastGroupKey;
+
+	/**
+	 * Last group set value.
+	 */
+	transient BinaryRowData lastGroupSet;
+
+	/**
+	 * The Projection which projects the group key fields from the input row.
+	 */
+	transient Projection<RowData, BinaryRowData> groupKeyProjection;
+
+	/**
+	 * The Projection which projects the group set fields (group key and auxiliary group key) from the input row.
+	 */
+	transient Projection<RowData, BinaryRowData> groupSetProjection;
+
+	AbstractBatchArrowPythonAggregateFunctionOperator(
+		Configuration config,
+		PythonFunctionInfo[] pandasAggFunctions,
+		RowType inputType,
+		RowType outputType,
+		int[] groupKey,
+		int[] groupingSet,
+		int[] udafInputOffsets) {
+		super(config, pandasAggFunctions, inputType, outputType, groupingSet, udafInputOffsets);
+		this.groupKey = Preconditions.checkNotNull(groupKey);
+	}
+
+	@Override
+	public void open() throws Exception {
+		super.open();
+		groupKeyProjection = createProjection("GroupKey", groupKey);
+		groupSetProjection = createProjection("GroupSet", groupingSet);
+		lastGroupKey = null;
+		lastGroupSet = null;
+	}
+
+	@Override
+	public void endInput() throws Exception {
+		invokeCurrentBatch();
+		super.endInput();
+	}
+
+	@Override
+	public void close() throws Exception {
+		invokeCurrentBatch();
+		super.close();
+	}
+
+	protected abstract void invokeCurrentBatch() throws Exception;
+
+	boolean isNewKey(BinaryRowData currentKey) {
+		return lastGroupKey.getSizeInBytes() != currentKey.getSizeInBytes() ||
+			!(BinaryRowDataUtil.byteArrayEquals(
+				currentKey.getSegments()[0].getHeapMemory(),
+				lastGroupKey.getSegments()[0].getHeapMemory(),
+				currentKey.getSizeInBytes()));
+	}
+
+	private Projection<RowData, BinaryRowData> createProjection(String name, int[] fields) {
+		final RowType forwardedFieldType = new RowType(
+			Arrays.stream(fields)
+				.mapToObj(i -> inputType.getFields().get(i))
+				.collect(Collectors.toList()));
+		final GeneratedProjection generatedProjection = ProjectionCodeGenerator.generateProjection(
+			CodeGeneratorContext.apply(new TableConfig()),
+			name,
+			inputType,
+			forwardedFieldType,
+			fields);
+		// noinspection unchecked
+		return generatedProjection.newInstance(Thread.currentThread().getContextClassLoader());
+	}
+}
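
isNewKey above compares the serialized bytes of the current and last group keys
rather than comparing fields one by one: since BinaryRowData stores a row as a
contiguous binary region, two keys are equal exactly when their binary
representations agree in length and content. A sketch of the same check over raw
bytes, in plain Python with hypothetical names:

    def is_new_key(last_key_bytes: bytes, current_key_bytes: bytes) -> bool:
        # Mirrors the Java check: length test first, then byte-wise equality
        # (in Python, bytes inequality already implies the length test).
        if len(last_key_bytes) != len(current_key_bytes):
            return True
        return last_key_bytes != current_key_bytes
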
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupAggregateFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupAggregateFunctionOperator.java
new file mode 100644
index 0000000..e7f8371
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupAggregateFunctionOperator.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.python.aggregate.arrow.batch;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.types.logical.RowType;
+
+/**
+ * The Batch Arrow Python {@link AggregateFunction} Operator for Group Aggregation.
+ */
+@Internal
+public class BatchArrowPythonGroupAggregateFunctionOperator
+	extends AbstractBatchArrowPythonAggregateFunctionOperator {
+
+	private static final long serialVersionUID = 1L;
+
+	public BatchArrowPythonGroupAggregateFunctionOperator(
+		Configuration config,
+		PythonFunctionInfo[] pandasAggFunctions,
+		RowType inputType,
+		RowType outputType,
+		int[] groupKey,
+		int[] groupingSet,
+		int[] udafInputOffsets) {
+		super(config, pandasAggFunctions, inputType, outputType, groupKey, groupingSet, udafInputOffsets);
+	}
+
+	@Override
+	public void open() throws Exception {
+		userDefinedFunctionOutputType = new RowType(
+			outputType.getFields().subList(groupingSet.length, outputType.getFieldCount()));
+		super.open();
+	}
+
+	@Override
+	protected void invokeCurrentBatch() throws Exception {
+		if (currentBatchCount > 0) {
+			arrowSerializer.finishCurrentBatch();
+			pythonFunctionRunner.process(baos.toByteArray());
+			baos.reset();
+			elementCount += currentBatchCount;
+			checkInvokeFinishBundleByCount();
+			currentBatchCount = 0;
+		}
+	}
+
+	@Override
+	public void bufferInput(RowData input) throws Exception {
+		BinaryRowData currentKey = groupKeyProjection.apply(input).copy();
+		if (lastGroupKey == null) {
+			lastGroupKey = currentKey;
+			lastGroupSet = groupSetProjection.apply(input).copy();
+			forwardedInputQueue.add(lastGroupSet);
+		} else if (isNewKey(currentKey)) {
+			invokeCurrentBatch();
+			lastGroupKey = currentKey;
+			lastGroupSet = groupSetProjection.apply(input).copy();
+			forwardedInputQueue.add(lastGroupSet);
+		}
+	}
+
+	@Override
+	public void processElementInternal(RowData value) {
+		arrowSerializer.write(getFunctionInput(value));
+		currentBatchCount++;
+	}
+
+	@Override
+	@SuppressWarnings("ConstantConditions")
+	public void emitResult(Tuple2<byte[], Integer> resultTuple) throws Exception {
+		byte[] udafResult = resultTuple.f0;
+		int length = resultTuple.f1;
+		bais.setBuffer(udafResult, 0, length);
+		int rowCount = arrowSerializer.load();
+		for (int i = 0; i < rowCount; i++) {
+			RowData key = forwardedInputQueue.poll();
+			reuseJoinedRow.setRowKind(key.getRowKind());
+			RowData result = arrowSerializer.read(i);
+			rowDataWrapper.collect(reuseJoinedRow.replace(key, result));
+		}
+	}
+}
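
The contract of bufferInput/processElementInternal/emitResult above can be read
as: input arrives sorted by group key, consecutive rows with the same key
accumulate into one Arrow batch, a key change flushes the batch, and each
aggregate result is joined back to the group key queued when its batch started.
A small self-contained sketch of that batching-by-key-change loop, in plain
Python with hypothetical names:

    def batch_by_key_change(rows, key_of):
        """Yield (group_key, batch) pairs from rows sorted by group key."""
        batch, last_key = [], None
        for row in rows:
            key = key_of(row)
            if batch and key != last_key:
                yield last_key, batch   # key change: flush the finished group
                batch = []
            batch.append(row)
            last_key = key
        if batch:
            yield last_key, batch       # endInput/close flush the final group

    # With rows shaped like those in the tests below:
    # list(batch_by_key_change([("c1", 0), ("c1", 1), ("c2", 2)], lambda r: r[0]))
    # -> [('c1', [('c1', 0), ('c1', 1)]), ('c2', [('c2', 2)])]
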
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/ArrowPythonAggregateFunctionOperatorTestBase.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/ArrowPythonAggregateFunctionOperatorTestBase.java
new file mode 100644
index 0000000..78a32f0
--- /dev/null
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/ArrowPythonAggregateFunctionOperatorTestBase.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.python.aggregate.arrow;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.runtime.operators.python.scalar.PythonScalarFunctionOperatorTestBase;
+import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
+import org.apache.flink.table.runtime.util.RowDataHarnessAssertor;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.types.RowKind;
+
+import java.util.Collection;
+
+import static org.apache.flink.table.runtime.util.StreamRecordUtils.row;
+
+/**
+ * Base class for Arrow Python aggregate function operator tests.
+ */
+public abstract class ArrowPythonAggregateFunctionOperatorTestBase {
+
+	private RowDataHarnessAssertor assertor = new RowDataHarnessAssertor(getOutputLogicalType());
+
+	protected OneInputStreamOperatorTestHarness<RowData, RowData> getTestHarness(
+		Configuration config) throws Exception {
+		RowType inputType = getInputType();
+		RowType outputType = getOutputType();
+		AbstractArrowPythonAggregateFunctionOperator operator = getTestOperator(
+			config,
+			new PythonFunctionInfo[]{
+				new PythonFunctionInfo(
+					PythonScalarFunctionOperatorTestBase.DummyPythonFunction.INSTANCE,
+					new Integer[]{0})},
+			inputType,
+			outputType,
+			new int[]{0},
+			new int[]{2});
+
+		OneInputStreamOperatorTestHarness<RowData, RowData> testHarness =
+			new OneInputStreamOperatorTestHarness<>(operator);
+		testHarness.getStreamConfig().setManagedMemoryFraction(0.5);
+		testHarness.setup(new RowDataSerializer(outputType));
+		return testHarness;
+	}
+
+	protected RowData newRow(boolean accumulateMsg, Object... fields) {
+		if (accumulateMsg) {
+			return row(fields);
+		} else {
+			RowData row = row(fields);
+			row.setRowKind(RowKind.DELETE);
+			return row;
+		}
+	}
+
+	protected void assertOutputEquals(String message, Collection<Object> expected, Collection<Object> actual) {
+		assertor.assertOutputEquals(message, expected, actual);
+	}
+
+	public abstract LogicalType[] getOutputLogicalType();
+
+	public abstract RowType getInputType();
+
+	public abstract RowType getOutputType();
+
+	public abstract AbstractArrowPythonAggregateFunctionOperator getTestOperator(
+		Configuration config,
+		PythonFunctionInfo[] pandasAggregateFunctions,
+		RowType inputType,
+		RowType outputType,
+		int[] groupingSet,
+		int[] udafInputOffsets);
+}
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupAggregateFunctionOperatorTest.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupAggregateFunctionOperatorTest.java
new file mode 100644
index 0000000..7b9fd19
--- /dev/null
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupAggregateFunctionOperatorTest.java
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.python.aggregate.arrow.batch;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.python.PythonFunctionRunner;
+import org.apache.flink.python.PythonOptions;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.runtime.operators.python.aggregate.arrow.AbstractArrowPythonAggregateFunctionOperator;
+import org.apache.flink.table.runtime.operators.python.aggregate.arrow.ArrowPythonAggregateFunctionOperatorTestBase;
+import org.apache.flink.table.runtime.utils.PassThroughPythonAggregateFunctionRunner;
+import org.apache.flink.table.runtime.utils.PythonTestUtils;
+import org.apache.flink.table.types.logical.BigIntType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.VarCharType;
+
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+
+/**
+ * Tests for {@link BatchArrowPythonGroupAggregateFunctionOperator}. These verify that:
+ *
+ * <ul>
+ * <li>finishBundle is called when a checkpoint is encountered</li>
+ * <li>Watermarks are buffered and only sent downstream when finishBundle is triggered</li>
+ * </ul>
+ */
+public class BatchArrowPythonGroupAggregateFunctionOperatorTest
+	extends ArrowPythonAggregateFunctionOperatorTestBase {
+
+	@Test
+	public void testGroupAggregateFunction() throws Exception {
+		OneInputStreamOperatorTestHarness<RowData, RowData> testHarness = getTestHarness(
+			new Configuration());
+		long initialTime = 0L;
+		ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+
+		testHarness.open();
+
+		testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1));
+		testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c4", 1L), initialTime + 2));
+		testHarness.processElement(new StreamRecord<>(newRow(true, "c2", "c6", 2L), initialTime + 3));
+		testHarness.close();
+
+		expectedOutput.add(new StreamRecord<>(newRow(true, "c1", 0L)));
+		expectedOutput.add(new StreamRecord<>(newRow(true, "c2", 2L)));
+
+		assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+	}
+
+	@Test
+	public void testFinishBundleTriggeredOnCheckpoint() throws Exception {
+		Configuration conf = new Configuration();
+		conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 10);
+		OneInputStreamOperatorTestHarness<RowData, RowData> testHarness = getTestHarness(conf);
+
+		long initialTime = 0L;
+		ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+
+		testHarness.open();
+
+		testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1));
+		testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c4", 1L), initialTime + 2));
+		testHarness.processElement(new StreamRecord<>(newRow(true, "c2", "c6", 2L), initialTime + 3));
+		// checkpoint triggers finishBundle
+		testHarness.prepareSnapshotPreBarrier(0L);
+
+		expectedOutput.add(new StreamRecord<>(newRow(true, "c1", 0L)));
+
+		assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+
+		testHarness.close();
+
+		expectedOutput.add(new StreamRecord<>(newRow(true, "c2", 2L)));
+
+		assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+	}
+
+	@Test
+	public void testFinishBundleTriggeredByCount() throws Exception {
+		Configuration conf = new Configuration();
+		conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 2);
+		OneInputStreamOperatorTestHarness<RowData, RowData> testHarness = getTestHarness(conf);
+
+		long initialTime = 0L;
+		ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+
+		testHarness.open();
+
+		testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1));
+		testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 1L), initialTime + 2));
+		assertOutputEquals("FinishBundle should not be triggered.", expectedOutput, testHarness.getOutput());
+
+		testHarness.processElement(new StreamRecord<>(newRow(true, "c2", "c6", 2L), initialTime + 2));
+		expectedOutput.add(new StreamRecord<>(newRow(true, "c1", 0L)));
+
+		assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+
+		testHarness.close();
+
+		expectedOutput.add(new StreamRecord<>(newRow(true, "c2", 2L)));
+
+		assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+	}
+
+	@Test
+	public void testFinishBundleTriggeredByTime() throws Exception {
+		Configuration conf = new Configuration();
+		conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 10);
+		conf.setLong(PythonOptions.MAX_BUNDLE_TIME_MILLS, 1000L);
+		OneInputStreamOperatorTestHarness<RowData, RowData> testHarness = getTestHarness(conf);
+
+		long initialTime = 0L;
+		ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+
+		testHarness.open();
+
+		testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1));
+		testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 1L), initialTime + 2));
+		testHarness.processElement(new StreamRecord<>(newRow(true, "c2", "c6", 2L), initialTime + 2));
+		assertOutputEquals("FinishBundle should not be triggered.", expectedOutput, testHarness.getOutput());
+
+		testHarness.setProcessingTime(1000L);
+		expectedOutput.add(new StreamRecord<>(newRow(true, "c1", 0L)));
+		assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+
+		testHarness.close();
+
+		expectedOutput.add(new StreamRecord<>(newRow(true, "c2", 2L)));
+
+		assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+	}
+
+	@Test
+	public void testWatermarkProcessedOnFinishBundle() throws Exception {
+		Configuration conf = new Configuration();
+		conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 10);
+		OneInputStreamOperatorTestHarness<RowData, RowData> testHarness = getTestHarness(conf);
+		long initialTime = 0L;
+		ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+
+		testHarness.open();
+
+		testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1));
+		testHarness.processElement(new StreamRecord<>(newRow(true, "c2", "c6", 2L), initialTime + 2));
+		testHarness.processWatermark(initialTime + 2);
+		assertOutputEquals("Watermark has been processed", expectedOutput, testHarness.getOutput());
+
+		// checkpoint triggers finishBundle
+		testHarness.prepareSnapshotPreBarrier(0L);
+
+		expectedOutput.add(new StreamRecord<>(newRow(true, "c1", 0L)));
+		expectedOutput.add(new Watermark(initialTime + 2));
+
+		assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+
+		testHarness.close();
+	}
+
+	@Override
+	public LogicalType[] getOutputLogicalType() {
+		return new LogicalType[]{
+			DataTypes.STRING().getLogicalType(),
+			DataTypes.BIGINT().getLogicalType()
+		};
+	}
+
+	@Override
+	public RowType getInputType() {
+		return new RowType(Arrays.asList(
+			new RowType.RowField("f1", new VarCharType()),
+			new RowType.RowField("f2", new VarCharType()),
+			new RowType.RowField("f3", new BigIntType())));
+	}
+
+	@Override
+	public RowType getOutputType() {
+		return new RowType(Arrays.asList(
+			new RowType.RowField("f1", new VarCharType()),
+			new RowType.RowField("f2", new BigIntType())));
+	}
+
+	@Override
+	public AbstractArrowPythonAggregateFunctionOperator getTestOperator(
+		Configuration config,
+		PythonFunctionInfo[] pandasAggregateFunctions,
+		RowType inputType,
+		RowType outputType,
+		int[] groupingSet,
+		int[] udafInputOffsets) {
+		return new PassThroughBatchArrowPythonGroupAggregateFunctionOperator(
+			config,
+			pandasAggregateFunctions,
+			inputType,
+			outputType,
+			groupingSet,
+			groupingSet,
+			udafInputOffsets);
+	}
+
+	private static class PassThroughBatchArrowPythonGroupAggregateFunctionOperator
+		extends BatchArrowPythonGroupAggregateFunctionOperator {
+
+		PassThroughBatchArrowPythonGroupAggregateFunctionOperator(
+			Configuration config,
+			PythonFunctionInfo[] pandasAggregateFunctions,
+			RowType inputType,
+			RowType outputType,
+			int[] groupKey,
+			int[] groupingSet,
+			int[] udafInputOffsets) {
+			super(config, pandasAggregateFunctions, inputType, outputType, groupKey, groupingSet, udafInputOffsets);
+		}
+
+		@Override
+		public PythonFunctionRunner createPythonFunctionRunner() {
+			return new PassThroughPythonAggregateFunctionRunner(
+				getRuntimeContext().getTaskName(),
+				PythonTestUtils.createTestEnvironmentManager(),
+				userDefinedFunctionInputType,
+				userDefinedFunctionOutputType,
+				getFunctionUrn(),
+				getUserDefinedFunctionsProto(),
+				getInputOutputCoderUrn(),
+				new HashMap<>(),
+				PythonTestUtils.createMockFlinkMetricContainer()
+			);
+		}
+	}
+}
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/utils/PassThroughPythonAggregateFunctionRunner.java b/flink-python/src/test/java/org/apache/flink/table/runtime/utils/PassThroughPythonAggregateFunctionRunner.java
new file mode 100644
index 0000000..2f474b9
--- /dev/null
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/utils/PassThroughPythonAggregateFunctionRunner.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.utils;
+
+import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.fnexecution.v1.FlinkFnApi;
+import org.apache.flink.python.PythonConfig;
+import org.apache.flink.python.env.PythonEnvironmentManager;
+import org.apache.flink.python.metric.FlinkMetricContainer;
+import org.apache.flink.table.runtime.arrow.serializers.RowDataArrowSerializer;
+import org.apache.flink.table.runtime.runners.python.beam.BeamTablePythonStatelessFunctionRunner;
+import org.apache.flink.table.types.logical.RowType;
+
+import org.apache.beam.runners.fnexecution.control.JobBundleFactory;
+import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Struct;
+
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A {@link PassThroughPythonAggregateFunctionRunner} that simply returns the first input
+ * element of each group as the execution result.
+ */
+public class PassThroughPythonAggregateFunctionRunner extends BeamTablePythonStatelessFunctionRunner {
+
+	private final List<byte[]> buffer;
+
+	private final RowDataArrowSerializer arrowSerializer;
+
+	/**
+	 * Reusable InputStream used to hold the input elements to be deserialized.
+	 */
+	private transient ByteArrayInputStreamWithPos bais;
+
+	/**
+	 * Reusable OutputStream used to hold the serialized execution results.
+	 */
+	private transient ByteArrayOutputStreamWithPos baos;
+
+	public PassThroughPythonAggregateFunctionRunner(
+		String taskName,
+		PythonEnvironmentManager environmentManager,
+		RowType inputType,
+		RowType outputType,
+		String functionUrn,
+		FlinkFnApi.UserDefinedFunctions userDefinedFunctions,
+		String coderUrn,
+		Map<String, String> jobOptions,
+		FlinkMetricContainer flinkMetricContainer) {
+		super(taskName, environmentManager, inputType, outputType, functionUrn, userDefinedFunctions,
+			coderUrn, jobOptions, flinkMetricContainer);
+		this.buffer = new LinkedList<>();
+		arrowSerializer = new RowDataArrowSerializer(inputType, outputType);
+	}
+
+	@Override
+	public void open(PythonConfig config) throws Exception {
+		super.open(config);
+		bais = new ByteArrayInputStreamWithPos();
+		baos = new ByteArrayOutputStreamWithPos();
+		arrowSerializer.open(bais, baos);
+	}
+
+	@Override
+	protected void startBundle() {
+		super.startBundle();
+		this.mainInputReceiver = input -> {
+			byte[] data = input.getValue();
+			bais.setBuffer(data, 0, data.length);
+			arrowSerializer.load();
+			arrowSerializer.write(arrowSerializer.read(0));
+			arrowSerializer.finishCurrentBatch();
+			buffer.add(baos.toByteArray());
+			baos.reset();
+		};
+	}
+
+	@Override
+	public void flush() throws Exception {
+		super.flush();
+		resultBuffer.addAll(buffer);
+		buffer.clear();
+	}
+
+	@Override
+	public JobBundleFactory createJobBundleFactory(Struct pipelineOptions) {
+		return PythonTestUtils.createMockJobBundleFactory();
+	}
+}