You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by di...@apache.org on 2020/09/17 01:30:44 UTC
[flink] branch master updated: [FLINK-19173][python] Introduce
BatchArrowPythonGroupAggregateFunctionOperator
This is an automated email from the ASF dual-hosted git repository.
dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 7925647 [FLINK-19173][python] Introduce BatchArrowPythonGroupAggregateFunctionOperator
7925647 is described below
commit 7925647e3ae67a0a8914b4eb1644c318023eeb88
Author: huangxingbo <hx...@gmail.com>
AuthorDate: Thu Sep 10 14:41:11 2020 +0800
[FLINK-19173][python] Introduce BatchArrowPythonGroupAggregateFunctionOperator
This closes #13369
---
.../pyflink/fn_execution/beam/beam_coders.py | 2 +
flink-python/pyflink/fn_execution/coders.py | 1 +
.../DataStreamPythonReduceFunctionOperator.java | 1 +
.../DataStreamPythonStatelessFunctionOperator.java | 1 +
...eamTwoInputPythonStatelessFunctionOperator.java | 1 +
.../python/AbstractPythonFunctionOperatorBase.java | 10 +-
.../python/AbstractStatelessFunctionOperator.java | 3 +-
...stractArrowPythonAggregateFunctionOperator.java | 168 ++++++++++++++
...tBatchArrowPythonAggregateFunctionOperator.java | 128 +++++++++++
...hArrowPythonGroupAggregateFunctionOperator.java | 104 +++++++++
...rowPythonAggregateFunctionOperatorTestBase.java | 92 ++++++++
...owPythonGroupAggregateFunctionOperatorTest.java | 255 +++++++++++++++++++++
.../PassThroughPythonAggregateFunctionRunner.java | 107 +++++++++
13 files changed, 870 insertions(+), 3 deletions(-)
diff --git a/flink-python/pyflink/fn_execution/beam/beam_coders.py b/flink-python/pyflink/fn_execution/beam/beam_coders.py
index 2c2f4e5..cc14db8d 100644
--- a/flink-python/pyflink/fn_execution/beam/beam_coders.py
+++ b/flink-python/pyflink/fn_execution/beam/beam_coders.py
@@ -151,6 +151,8 @@ class ArrowCoder(FastCoder):
import pandas as pd
return pd.Series
+ @Coder.register_urn(coders.FLINK_SCHEMA_ARROW_CODER_URN,
+ flink_fn_execution_pb2.Schema)
@Coder.register_urn(coders.FLINK_SCALAR_FUNCTION_SCHEMA_ARROW_CODER_URN,
flink_fn_execution_pb2.Schema)
def _pickle_from_runner_api_parameter(schema_proto, unused_components, unused_context):
diff --git a/flink-python/pyflink/fn_execution/coders.py b/flink-python/pyflink/fn_execution/coders.py
index 57788c6..dd29302 100644
--- a/flink-python/pyflink/fn_execution/coders.py
+++ b/flink-python/pyflink/fn_execution/coders.py
@@ -37,6 +37,7 @@ __all__ = ['RowCoder', 'BigIntCoder', 'TinyIntCoder', 'BooleanCoder',
FLINK_SCALAR_FUNCTION_SCHEMA_CODER_URN = "flink:coder:schema:scalar_function:v1"
FLINK_TABLE_FUNCTION_SCHEMA_CODER_URN = "flink:coder:schema:table_function:v1"
FLINK_SCALAR_FUNCTION_SCHEMA_ARROW_CODER_URN = "flink:coder:schema:scalar_function:arrow:v1"
+FLINK_SCHEMA_ARROW_CODER_URN = "flink:coder:schema:arrow:v1"
FLINK_MAP_FUNCTION_DATA_STREAM_CODER_URN = "flink:coder:datastream:map_function:v1"
FLINK_FLAT_MAP_FUNCTION_DATA_STREAM_CODER_URN = "flink:coder:datastream:flatmap_function:v1"
diff --git a/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonReduceFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonReduceFunctionOperator.java
index e3596dc..b73da63 100644
--- a/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonReduceFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonReduceFunctionOperator.java
@@ -90,6 +90,7 @@ public class DataStreamPythonReduceFunctionOperator<OUT>
runnerInputTypeSerializer.serialize(reuseRow, baosWrapper);
pythonFunctionRunner.process(baos.toByteArray());
baos.reset();
+ elementCount++;
checkInvokeFinishBundleByCount();
emitResults();
}
diff --git a/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonStatelessFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonStatelessFunctionOperator.java
index a2ec04c..84d43f2 100644
--- a/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonStatelessFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamPythonStatelessFunctionOperator.java
@@ -147,6 +147,7 @@ public class DataStreamPythonStatelessFunctionOperator<IN, OUT>
inputTypeSerializer.serialize(element.getValue(), baosWrapper);
pythonFunctionRunner.process(baos.toByteArray());
baos.reset();
+ elementCount++;
checkInvokeFinishBundleByCount();
emitResults();
}
diff --git a/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamTwoInputPythonStatelessFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamTwoInputPythonStatelessFunctionOperator.java
index 96ef311..e35f070 100644
--- a/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamTwoInputPythonStatelessFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamTwoInputPythonStatelessFunctionOperator.java
@@ -206,6 +206,7 @@ public class DataStreamTwoInputPythonStatelessFunctionOperator<IN1, IN2, OUT>
runnerInputTypeSerializer.serialize(reuseRow, baosWrapper);
pythonFunctionRunner.process(baos.toByteArray());
baos.reset();
+ elementCount++;
checkInvokeFinishBundleByCount();
emitResults();
}
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
index 5a989ba..ca30232 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
@@ -214,7 +214,7 @@ public abstract class AbstractPythonFunctionOperatorBase<OUT>
if (mark.getTimestamp() == Long.MAX_VALUE) {
invokeFinishBundle();
super.processWatermark(mark);
- } else if (elementCount == 0) {
+ } else if (isBundleFinished()) {
// forward the watermark immediately if the bundle is already finished.
super.processWatermark(mark);
} else {
@@ -234,6 +234,13 @@ public abstract class AbstractPythonFunctionOperatorBase<OUT>
}
/**
+ * Returns whether the bundle is finished.
+ */
+ public boolean isBundleFinished() {
+ return elementCount == 0;
+ }
+
+ /**
* Reset the {@link PythonConfig} if needed.
* */
public void setPythonConfig(PythonConfig pythonConfig) {
@@ -299,7 +306,6 @@ public abstract class AbstractPythonFunctionOperatorBase<OUT>
* Checks whether to invoke finishBundle by elements count. Called in processElement.
*/
protected void checkInvokeFinishBundleByCount() throws Exception {
- elementCount++;
if (elementCount >= maxBundleSize) {
invokeFinishBundle();
}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractStatelessFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractStatelessFunctionOperator.java
index a3ae165..eb46961 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractStatelessFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/AbstractStatelessFunctionOperator.java
@@ -145,6 +145,7 @@ public abstract class AbstractStatelessFunctionOperator<IN, OUT, UDFIN>
IN value = element.getValue();
bufferInput(value);
processElementInternal(value);
+ elementCount++;
checkInvokeFinishBundleByCount();
emitResults();
}
@@ -185,7 +186,7 @@ public abstract class AbstractStatelessFunctionOperator<IN, OUT, UDFIN>
* Buffers the specified input, it will be used to construct
* the operator result together with the user-defined function execution result.
*/
- public abstract void bufferInput(IN input);
+ public abstract void bufferInput(IN input) throws Exception;
public abstract UDFIN getFunctionInput(IN element);
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/AbstractArrowPythonAggregateFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/AbstractArrowPythonAggregateFunctionOperator.java
new file mode 100644
index 0000000..398a17d
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/AbstractArrowPythonAggregateFunctionOperator.java
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.python.aggregate.arrow;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.fnexecution.v1.FlinkFnApi;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.data.JoinedRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.python.PythonEnv;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
+import org.apache.flink.table.planner.codegen.ProjectionCodeGenerator;
+import org.apache.flink.table.runtime.arrow.serializers.ArrowSerializer;
+import org.apache.flink.table.runtime.arrow.serializers.RowDataArrowSerializer;
+import org.apache.flink.table.runtime.generated.GeneratedProjection;
+import org.apache.flink.table.runtime.generated.Projection;
+import org.apache.flink.table.runtime.operators.python.AbstractStatelessFunctionOperator;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * Abstract base class of the Arrow based aggregate operators for Pandas {@link AggregateFunction}s.
+ *
+ * <p>On {@link #open()} it creates the projection which extracts the aggregate function arguments
+ * from an input row and an {@link ArrowSerializer} working on the operator's {@code bais} and
+ * {@code baos} buffers. This class itself never sends data to the Python function; deciding when
+ * a batch is complete is left to the subclasses.
+ */
+@Internal
+public abstract class AbstractArrowPythonAggregateFunctionOperator
+        extends AbstractStatelessFunctionOperator<RowData, RowData, RowData> {
+
+    private static final long serialVersionUID = 1L;
+
+    /**
+     * The coder URN returned by {@link #getInputOutputCoderUrn()}.
+     */
+    private static final String SCHEMA_ARROW_CODER_URN = "flink:coder:schema:arrow:v1";
+
+    /**
+     * The function URN returned by {@link #getFunctionUrn()}.
+     */
+    private static final String PANDAS_AGGREGATE_FUNCTION_URN = "flink:transform:aggregate_function:arrow:v1";
+
+    /**
+     * The Pandas {@link AggregateFunction}s to be executed.
+     */
+    private final PythonFunctionInfo[] pandasAggFunctions;
+
+    /**
+     * The grouping set passed to the constructor; it is only stored here and used by subclasses.
+     */
+    protected final int[] groupingSet;
+
+    /**
+     * The serializer created in {@link #open()} and released in {@link #dispose()}.
+     */
+    protected transient ArrowSerializer<RowData> arrowSerializer;
+
+    /**
+     * The collector used to collect records.
+     */
+    protected transient StreamRecordRowDataWrappingCollector rowDataWrapper;
+
+    /**
+     * The reused JoinedRowData which holds the execution result.
+     */
+    protected transient JoinedRowData reuseJoinedRow;
+
+    /**
+     * The current number of elements to be included in an arrow batch.
+     */
+    protected transient int currentBatchCount;
+
+    /**
+     * The Projection which projects the udaf input fields from the input row.
+     */
+    private transient Projection<RowData, BinaryRowData> udafInputProjection;
+
+    /**
+     * Creates the operator.
+     *
+     * @param config the configuration forwarded to the parent operator
+     * @param pandasAggFunctions the Pandas aggregate functions to execute, must not be null
+     * @param inputType the type of the input rows
+     * @param outputType the type of the rows emitted by the operator
+     * @param groupingSet the grouping set, must not be null
+     * @param udafInputOffsets the offsets of the aggregate function arguments in the input row
+     */
+    public AbstractArrowPythonAggregateFunctionOperator(
+            Configuration config,
+            PythonFunctionInfo[] pandasAggFunctions,
+            RowType inputType,
+            RowType outputType,
+            int[] groupingSet,
+            int[] udafInputOffsets) {
+        super(config, inputType, outputType, udafInputOffsets);
+        this.pandasAggFunctions = Preconditions.checkNotNull(pandasAggFunctions);
+        this.groupingSet = Preconditions.checkNotNull(groupingSet);
+    }
+
+    /**
+     * Opens the parent operator and then initializes all transient members of this class.
+     */
+    @Override
+    public void open() throws Exception {
+        super.open();
+        rowDataWrapper = new StreamRecordRowDataWrappingCollector(output);
+        reuseJoinedRow = new JoinedRowData();
+
+        udafInputProjection = createUdafInputProjection();
+        arrowSerializer = new RowDataArrowSerializer(userDefinedFunctionInputType, userDefinedFunctionOutputType);
+        arrowSerializer.open(bais, baos);
+        currentBatchCount = 0;
+    }
+
+    /**
+     * Disposes the parent operator and releases the Arrow serializer.
+     *
+     * <p>The serializer is closed even if disposing the parent fails. It is skipped when
+     * {@link #open()} did not get far enough to create it, so that a failure during open is not
+     * hidden behind a {@link NullPointerException} raised here.
+     */
+    @Override
+    public void dispose() throws Exception {
+        try {
+            super.dispose();
+        } finally {
+            if (arrowSerializer != null) {
+                arrowSerializer.close();
+                // drop the reference so that a repeated dispose does not close it twice
+                arrowSerializer = null;
+            }
+        }
+    }
+
+    /**
+     * Buffers the input row, passes it to {@code processElementInternal} and emits the results
+     * which are available. Note that {@code elementCount} is not updated here.
+     */
+    @Override
+    public void processElement(StreamRecord<RowData> element) throws Exception {
+        RowData value = element.getValue();
+        bufferInput(value);
+        processElementInternal(value);
+        emitResults();
+    }
+
+    /**
+     * Returns true only if there are neither counted elements nor rows of a pending arrow batch.
+     */
+    @Override
+    public boolean isBundleFinished() {
+        return elementCount == 0 && currentBatchCount == 0;
+    }
+
+    /**
+     * Returns the Python environment of the first aggregate function.
+     */
+    @Override
+    public PythonEnv getPythonEnv() {
+        return pandasAggFunctions[0].getPythonFunction().getPythonEnv();
+    }
+
+    @Override
+    public String getFunctionUrn() {
+        return PANDAS_AGGREGATE_FUNCTION_URN;
+    }
+
+    @Override
+    public String getInputOutputCoderUrn() {
+        return SCHEMA_ARROW_CODER_URN;
+    }
+
+    /**
+     * Projects the aggregate function arguments out of the given input row.
+     */
+    @Override
+    public RowData getFunctionInput(RowData element) {
+        return udafInputProjection.apply(element);
+    }
+
+    /**
+     * Builds the proto which contains one entry per Pandas aggregate function plus the metric flag
+     * taken from the Python config.
+     */
+    @Override
+    public FlinkFnApi.UserDefinedFunctions getUserDefinedFunctionsProto() {
+        FlinkFnApi.UserDefinedFunctions.Builder builder = FlinkFnApi.UserDefinedFunctions.newBuilder();
+        // add udaf proto
+        for (PythonFunctionInfo pythonFunctionInfo : pandasAggFunctions) {
+            builder.addUdfs(getUserDefinedFunctionProto(pythonFunctionInfo));
+        }
+        builder.setMetricEnabled(getPythonConfig().isMetricEnabled());
+        return builder.build();
+    }
+
+    /**
+     * Generates and instantiates the projection from the input row to the aggregate function
+     * arguments, using the context class loader of the current thread.
+     */
+    private Projection<RowData, BinaryRowData> createUdafInputProjection() {
+        final GeneratedProjection generatedProjection = ProjectionCodeGenerator.generateProjection(
+            CodeGeneratorContext.apply(new TableConfig()),
+            "UadfInputProjection",
+            inputType,
+            userDefinedFunctionInputType,
+            userDefinedFunctionInputOffsets);
+        // noinspection unchecked
+        return generatedProjection.newInstance(Thread.currentThread().getContextClassLoader());
+    }
+}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/AbstractBatchArrowPythonAggregateFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/AbstractBatchArrowPythonAggregateFunctionOperator.java
new file mode 100644
index 0000000..6bd0d19
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/AbstractBatchArrowPythonAggregateFunctionOperator.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.python.aggregate.arrow.batch;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.data.binary.BinaryRowDataUtil;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
+import org.apache.flink.table.planner.codegen.ProjectionCodeGenerator;
+import org.apache.flink.table.runtime.generated.GeneratedProjection;
+import org.apache.flink.table.runtime.generated.Projection;
+import org.apache.flink.table.runtime.operators.python.aggregate.arrow.AbstractArrowPythonAggregateFunctionOperator;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.util.Preconditions;
+
+import java.util.Arrays;
+import java.util.stream.Collectors;
+
+/**
+ * Abstract base class of the batch Arrow aggregate operators for Pandas {@link AggregateFunction}s.
+ *
+ * <p>It keeps the group key and group set of the last seen row and offers {@link #isNewKey} to
+ * detect a change of the group key. What happens with the rows collected so far is defined by
+ * the subclass in {@link #invokeCurrentBatch()}, which is called at the end of the input and when
+ * the operator is closed.
+ */
+@Internal
+abstract class AbstractBatchArrowPythonAggregateFunctionOperator
+        extends AbstractArrowPythonAggregateFunctionOperator {
+
+    private static final long serialVersionUID = 1L;
+
+    /**
+     * The indexes of the group key fields in the input row.
+     */
+    private final int[] groupKey;
+
+    /**
+     * Last group key value, {@code null} as long as no key has been stored.
+     */
+    transient BinaryRowData lastGroupKey;
+
+    /**
+     * Last group set value, {@code null} as long as no group set has been stored.
+     */
+    transient BinaryRowData lastGroupSet;
+
+    /**
+     * The Projection which projects the group key fields from the input row.
+     */
+    transient Projection<RowData, BinaryRowData> groupKeyProjection;
+
+    /**
+     * The Projection which projects the group set fields (group key and aux group key) from the input row.
+     */
+    transient Projection<RowData, BinaryRowData> groupSetProjection;
+
+    /**
+     * Creates the operator.
+     *
+     * @param groupKey the indexes of the group key fields in the input row, must not be null
+     * @param groupingSet the indexes of the group set fields in the input row
+     */
+    AbstractBatchArrowPythonAggregateFunctionOperator(
+            Configuration config,
+            PythonFunctionInfo[] pandasAggFunctions,
+            RowType inputType,
+            RowType outputType,
+            int[] groupKey,
+            int[] groupingSet,
+            int[] udafInputOffsets) {
+        super(config, pandasAggFunctions, inputType, outputType, groupingSet, udafInputOffsets);
+        this.groupKey = Preconditions.checkNotNull(groupKey);
+    }
+
+    /**
+     * Opens the parent, creates both projections and resets the remembered key and group set.
+     */
+    @Override
+    public void open() throws Exception {
+        super.open();
+        groupKeyProjection = createProjection("GroupKey", groupKey);
+        groupSetProjection = createProjection("GroupSet", groupingSet);
+        lastGroupKey = null;
+        lastGroupSet = null;
+    }
+
+    /**
+     * Calls {@link #invokeCurrentBatch()} before the end of input is forwarded to the parent.
+     */
+    @Override
+    public void endInput() throws Exception {
+        invokeCurrentBatch();
+        super.endInput();
+    }
+
+    /**
+     * Calls {@link #invokeCurrentBatch()} before the parent is closed.
+     */
+    @Override
+    public void close() throws Exception {
+        invokeCurrentBatch();
+        super.close();
+    }
+
+    /**
+     * Handles the rows which have been collected for the current batch. It is called from
+     * {@link #endInput()} and {@link #close()} one after the other, so implementations have to
+     * cope with being called when nothing is pending.
+     */
+    protected abstract void invokeCurrentBatch() throws Exception;
+
+    /**
+     * Checks whether the given key differs from the last stored group key.
+     *
+     * <p>The keys are compared by their binary representation: first the sizes, then the leading
+     * {@code sizeInBytes} bytes of the heap memory of the first segment of each row.
+     * NOTE(review): this presumably relies on both rows starting at offset 0 of a single heap
+     * segment, which holds for copied rows - confirm before passing other rows.
+     *
+     * @param currentKey the group key of the current row
+     * @return true if no key has been stored yet or if the stored key differs from the given one
+     */
+    boolean isNewKey(BinaryRowData currentKey) {
+        if (lastGroupKey == null) {
+            // open() resets the key to null, without a previous key every key is a new one
+            return true;
+        }
+        return lastGroupKey.getSizeInBytes() != currentKey.getSizeInBytes() ||
+            !(BinaryRowDataUtil.byteArrayEquals(
+                currentKey.getSegments()[0].getHeapMemory(),
+                lastGroupKey.getSegments()[0].getHeapMemory(),
+                currentKey.getSizeInBytes()));
+    }
+
+    /**
+     * Generates and instantiates a projection which picks the given fields from the input row.
+     *
+     * @param name the name passed to the code generator
+     * @param fields the indexes of the projected fields in the input row
+     */
+    private Projection<RowData, BinaryRowData> createProjection(String name, int[] fields) {
+        // the projected row type consists of the input fields at the given indexes, in that order
+        final RowType forwardedFieldType = new RowType(
+            Arrays.stream(fields)
+                .mapToObj(i -> inputType.getFields().get(i))
+                .collect(Collectors.toList()));
+        final GeneratedProjection generatedProjection = ProjectionCodeGenerator.generateProjection(
+            CodeGeneratorContext.apply(new TableConfig()),
+            name,
+            inputType,
+            forwardedFieldType,
+            fields);
+        // noinspection unchecked
+        return generatedProjection.newInstance(Thread.currentThread().getContextClassLoader());
+    }
+}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupAggregateFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupAggregateFunctionOperator.java
new file mode 100644
index 0000000..e7f8371
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupAggregateFunctionOperator.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.python.aggregate.arrow.batch;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.types.logical.RowType;
+
+/**
+ * Batch operator which evaluates Pandas {@link AggregateFunction}s for a group aggregation.
+ *
+ * <p>The rows of one group are written into one arrow batch. When the group key changes, the
+ * batch is handed to the Python function runner. Each result row is joined with the group set
+ * which was queued when the corresponding group started.
+ */
+@Internal
+public class BatchArrowPythonGroupAggregateFunctionOperator
+        extends AbstractBatchArrowPythonAggregateFunctionOperator {
+
+    private static final long serialVersionUID = 1L;
+
+    public BatchArrowPythonGroupAggregateFunctionOperator(
+            Configuration config,
+            PythonFunctionInfo[] pandasAggFunctions,
+            RowType inputType,
+            RowType outputType,
+            int[] groupKey,
+            int[] groupingSet,
+            int[] udafInputOffsets) {
+        super(config, pandasAggFunctions, inputType, outputType, groupKey, groupingSet, udafInputOffsets);
+    }
+
+    /**
+     * Derives the output type of the Python function, i.e. the output fields behind the leading
+     * group set fields, and opens the parent afterwards.
+     */
+    @Override
+    public void open() throws Exception {
+        final int groupSetFieldCount = groupingSet.length;
+        final int outputFieldCount = outputType.getFieldCount();
+        userDefinedFunctionOutputType =
+            new RowType(outputType.getFields().subList(groupSetFieldCount, outputFieldCount));
+        super.open();
+    }
+
+    /**
+     * Finishes the pending arrow batch and passes it to the Python function runner. Nothing
+     * happens if no row has been written since the last call.
+     */
+    @Override
+    protected void invokeCurrentBatch() throws Exception {
+        if (currentBatchCount <= 0) {
+            return;
+        }
+        arrowSerializer.finishCurrentBatch();
+        pythonFunctionRunner.process(baos.toByteArray());
+        baos.reset();
+        // the rows of the batch count as processed elements only from now on
+        elementCount += currentBatchCount;
+        checkInvokeFinishBundleByCount();
+        currentBatchCount = 0;
+    }
+
+    /**
+     * Tracks the group the given row belongs to. For the first row and for each row with a changed
+     * group key, the key and the group set are remembered and the group set is queued for the
+     * result. On a key change the batch of the previous group is invoked first.
+     */
+    @Override
+    public void bufferInput(RowData input) throws Exception {
+        final BinaryRowData currentKey = groupKeyProjection.apply(input).copy();
+        final boolean firstRow = lastGroupKey == null;
+        if (!firstRow && !isNewKey(currentKey)) {
+            // still the same group, nothing to remember
+            return;
+        }
+        if (!firstRow) {
+            invokeCurrentBatch();
+        }
+        lastGroupKey = currentKey;
+        lastGroupSet = groupSetProjection.apply(input).copy();
+        forwardedInputQueue.add(lastGroupSet);
+    }
+
+    /**
+     * Writes the aggregate function arguments of the row into the current arrow batch.
+     */
+    @Override
+    public void processElementInternal(RowData value) {
+        final RowData functionInput = getFunctionInput(value);
+        arrowSerializer.write(functionInput);
+        currentBatchCount++;
+    }
+
+    /**
+     * Loads the result batch and emits one joined row per result row. The left side of the joined
+     * row is the next queued group set, the right side is the result of the Python function.
+     */
+    @Override
+    @SuppressWarnings("ConstantConditions")
+    public void emitResult(Tuple2<byte[], Integer> resultTuple) throws Exception {
+        bais.setBuffer(resultTuple.f0, 0, resultTuple.f1);
+        final int resultRowCount = arrowSerializer.load();
+        for (int row = 0; row < resultRowCount; row++) {
+            final RowData groupSet = forwardedInputQueue.poll();
+            reuseJoinedRow.setRowKind(groupSet.getRowKind());
+            rowDataWrapper.collect(reuseJoinedRow.replace(groupSet, arrowSerializer.read(row)));
+        }
+    }
+}
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/ArrowPythonAggregateFunctionOperatorTestBase.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/ArrowPythonAggregateFunctionOperatorTestBase.java
new file mode 100644
index 0000000..78a32f0
--- /dev/null
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/ArrowPythonAggregateFunctionOperatorTestBase.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.python.aggregate.arrow;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.runtime.operators.python.scalar.PythonScalarFunctionOperatorTestBase;
+import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
+import org.apache.flink.table.runtime.util.RowDataHarnessAssertor;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.types.RowKind;
+
+import java.util.Collection;
+
+import static org.apache.flink.table.runtime.util.StreamRecordUtils.row;
+
+/**
+ * Common parts of the tests of the Arrow Python aggregate function operators: creation of the
+ * test harness, creation of test rows and comparison of the operator output.
+ */
+public abstract class ArrowPythonAggregateFunctionOperatorTestBase {
+
+    private RowDataHarnessAssertor assertor = new RowDataHarnessAssertor(getOutputLogicalType());
+
+    /**
+     * Creates a test harness around the operator under test. The operator runs a single dummy
+     * Python function, the grouping set is field 0 and the function input is field 2.
+     */
+    protected OneInputStreamOperatorTestHarness<RowData, RowData> getTestHarness(
+            Configuration config) throws Exception {
+        final RowType inputType = getInputType();
+        final RowType outputType = getOutputType();
+        final PythonFunctionInfo[] functionInfos = new PythonFunctionInfo[]{
+            new PythonFunctionInfo(
+                PythonScalarFunctionOperatorTestBase.DummyPythonFunction.INSTANCE,
+                new Integer[]{0})};
+        final int[] groupingSet = new int[]{0};
+        final int[] udafInputOffsets = new int[]{2};
+        final AbstractArrowPythonAggregateFunctionOperator operator = getTestOperator(
+            config,
+            functionInfos,
+            inputType,
+            outputType,
+            groupingSet,
+            udafInputOffsets);
+
+        final OneInputStreamOperatorTestHarness<RowData, RowData> harness =
+            new OneInputStreamOperatorTestHarness<>(operator);
+        harness.getStreamConfig().setManagedMemoryFraction(0.5);
+        harness.setup(new RowDataSerializer(outputType));
+        return harness;
+    }
+
+    /**
+     * Creates a row of the given fields. The row kind is only changed for retract messages, which
+     * are marked as {@link RowKind#DELETE}.
+     *
+     * @param accumulateMsg false if the row is a retract message
+     * @param fields the field values of the row
+     */
+    protected RowData newRow(boolean accumulateMsg, Object... fields) {
+        final RowData rowData = row(fields);
+        if (!accumulateMsg) {
+            rowData.setRowKind(RowKind.DELETE);
+        }
+        return rowData;
+    }
+
+    /**
+     * Compares the expected with the actual output by means of the harness assertor.
+     */
+    protected void assertOutputEquals(String message, Collection<Object> expected, Collection<Object> actual) {
+        assertor.assertOutputEquals(message, expected, actual);
+    }
+
+    /**
+     * Returns the logical types of the output fields; they are used to create the assertor.
+     */
+    public abstract LogicalType[] getOutputLogicalType();
+
+    /**
+     * Returns the type of the rows sent to the operator.
+     */
+    public abstract RowType getInputType();
+
+    /**
+     * Returns the type of the rows emitted by the operator.
+     */
+    public abstract RowType getOutputType();
+
+    /**
+     * Creates the operator under test.
+     */
+    public abstract AbstractArrowPythonAggregateFunctionOperator getTestOperator(
+        Configuration config,
+        PythonFunctionInfo[] pandasAggregateFunctions,
+        RowType inputType,
+        RowType outputType,
+        int[] groupingSet,
+        int[] udafInputOffsets);
+}
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupAggregateFunctionOperatorTest.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupAggregateFunctionOperatorTest.java
new file mode 100644
index 0000000..7b9fd19
--- /dev/null
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupAggregateFunctionOperatorTest.java
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.python.aggregate.arrow.batch;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.python.PythonFunctionRunner;
+import org.apache.flink.python.PythonOptions;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.runtime.operators.python.aggregate.arrow.AbstractArrowPythonAggregateFunctionOperator;
+import org.apache.flink.table.runtime.operators.python.aggregate.arrow.ArrowPythonAggregateFunctionOperatorTestBase;
+import org.apache.flink.table.runtime.utils.PassThroughPythonAggregateFunctionRunner;
+import org.apache.flink.table.runtime.utils.PythonTestUtils;
+import org.apache.flink.table.types.logical.BigIntType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.VarCharType;
+
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+
+/**
+ * Tests for {@link BatchArrowPythonGroupAggregateFunctionOperator}. The operator under test is
+ * created with a {@link PassThroughPythonAggregateFunctionRunner} in place of its regular Python
+ * function runner. The tests verify that:
+ *
+ * <ul>
+ * <li>One result row is emitted per grouping key</li>
+ * <li>FinishBundle is called when a checkpoint is encountered</li>
+ * <li>FinishBundle is called when the configured bundle size or bundle time is reached</li>
+ * <li>Watermarks are buffered and only sent downstream when finishBundle is triggered</li>
+ * </ul>
+ *
+ * <p>In the bundle tests the result for the last key ("c2") is only expected after the
+ * harness has been closed.
+ */
+public class BatchArrowPythonGroupAggregateFunctionOperatorTest
+ extends ArrowPythonAggregateFunctionOperatorTestBase {
+
+ @Test
+ public void testGroupAggregateFunction() throws Exception {
+ OneInputStreamOperatorTestHarness<RowData, RowData> testHarness = getTestHarness(
+ new Configuration());
+ long initialTime = 0L;
+ ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+
+ testHarness.open();
+
+ testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1));
+ testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c4", 1L), initialTime + 2));
+ testHarness.processElement(new StreamRecord<>(newRow(true, "c2", "c6", 2L), initialTime + 3));
+ testHarness.close(); // the output is only inspected after the harness has been closed
+
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c1", 0L)));
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c2", 2L)));
+
+ assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+ }
+
+ @Test
+ public void testFinishBundleTriggeredOnCheckpoint() throws Exception {
+ Configuration conf = new Configuration();
+ conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 10);
+ OneInputStreamOperatorTestHarness<RowData, RowData> testHarness = getTestHarness(conf);
+
+ long initialTime = 0L;
+ ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+
+ testHarness.open();
+
+ testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1));
+ testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c4", 1L), initialTime + 2));
+ testHarness.processElement(new StreamRecord<>(newRow(true, "c2", "c6", 2L), initialTime + 3));
+ // the pre-barrier hook of the checkpoint triggers finishBundle
+ testHarness.prepareSnapshotPreBarrier(0L);
+
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c1", 0L))); // only the "c1" result is expected at this point
+
+ assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+
+ testHarness.close();
+
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c2", 2L))); // the "c2" result is expected once the harness is closed
+
+ assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+ }
+
+ @Test
+ public void testFinishBundleTriggeredByCount() throws Exception {
+ Configuration conf = new Configuration();
+ conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 2); // small bundle size so that the bundle is finished by count
+ OneInputStreamOperatorTestHarness<RowData, RowData> testHarness = getTestHarness(conf);
+
+ long initialTime = 0L;
+ ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+
+ testHarness.open();
+
+ testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1));
+ testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 1L), initialTime + 2));
+ assertOutputEquals("FinishBundle should not be triggered.", expectedOutput, testHarness.getOutput()); // expectedOutput is still empty here
+
+ testHarness.processElement(new StreamRecord<>(newRow(true, "c2", "c6", 2L), initialTime + 2));
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c1", 0L)));
+
+ assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+
+ testHarness.close();
+
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c2", 2L)));
+
+ assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+ }
+
+ @Test
+ public void testFinishBundleTriggeredByTime() throws Exception {
+ Configuration conf = new Configuration();
+ conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 10);
+ conf.setLong(PythonOptions.MAX_BUNDLE_TIME_MILLS, 1000L); // matches the processing time set further below
+ OneInputStreamOperatorTestHarness<RowData, RowData> testHarness = getTestHarness(conf);
+
+ long initialTime = 0L;
+ ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+
+ testHarness.open();
+
+ testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1));
+ testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 1L), initialTime + 2));
+ testHarness.processElement(new StreamRecord<>(newRow(true, "c2", "c6", 2L), initialTime + 2));
+ assertOutputEquals("FinishBundle should not be triggered.", expectedOutput, testHarness.getOutput()); // expectedOutput is still empty here
+
+ testHarness.setProcessingTime(1000L); // advancing processing time to the bundle time limit is expected to finish the bundle
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c1", 0L)));
+ assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+
+ testHarness.close();
+
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c2", 2L)));
+
+ assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+ }
+
+ @Test
+ public void testWatermarkProcessedOnFinishBundle() throws Exception {
+ Configuration conf = new Configuration();
+ conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 10);
+ OneInputStreamOperatorTestHarness<RowData, RowData> testHarness = getTestHarness(conf);
+ long initialTime = 0L;
+ ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
+
+ testHarness.open();
+
+ testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), initialTime + 1));
+ testHarness.processElement(new StreamRecord<>(newRow(true, "c2", "c6", 2L), initialTime + 2));
+ testHarness.processWatermark(initialTime + 2);
+ assertOutputEquals("Watermark has been processed", expectedOutput, testHarness.getOutput()); // neither a record nor the watermark is expected yet
+
+ // the pre-barrier hook of the checkpoint triggers finishBundle
+ testHarness.prepareSnapshotPreBarrier(0L);
+
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c1", 0L)));
+ expectedOutput.add(new Watermark(initialTime + 2)); // the buffered watermark is expected after the "c1" result
+
+ assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
+
+ testHarness.close(); // the output produced by close() is not asserted in this test
+ }
+
+ @Override
+ public LogicalType[] getOutputLogicalType() {
+ return new LogicalType[]{
+ DataTypes.STRING().getLogicalType(),
+ DataTypes.BIGINT().getLogicalType()
+ };
+ }
+
+ @Override
+ public RowType getInputType() {
+ return new RowType(Arrays.asList(
+ new RowType.RowField("f1", new VarCharType()),
+ new RowType.RowField("f2", new VarCharType()),
+ new RowType.RowField("f3", new BigIntType())));
+ }
+
+ @Override
+ public RowType getOutputType() {
+ return new RowType(Arrays.asList(
+ new RowType.RowField("f1", new VarCharType()),
+ new RowType.RowField("f2", new BigIntType())));
+ }
+
+ @Override
+ public AbstractArrowPythonAggregateFunctionOperator getTestOperator(
+ Configuration config,
+ PythonFunctionInfo[] pandasAggregateFunctions,
+ RowType inputType,
+ RowType outputType,
+ int[] groupingSet,
+ int[] udafInputOffsets) {
+ return new PassThroughBatchArrowPythonGroupAggregateFunctionOperator(
+ config,
+ pandasAggregateFunctions,
+ inputType,
+ outputType,
+ groupingSet, // passed as groupKey: the test uses the same columns for groupKey and groupingSet
+ groupingSet,
+ udafInputOffsets);
+ }
+
+ private static class PassThroughBatchArrowPythonGroupAggregateFunctionOperator
+ extends BatchArrowPythonGroupAggregateFunctionOperator {
+
+ PassThroughBatchArrowPythonGroupAggregateFunctionOperator(
+ Configuration config,
+ PythonFunctionInfo[] pandasAggregateFunctions,
+ RowType inputType,
+ RowType outputType,
+ int[] groupKey,
+ int[] groupingSet,
+ int[] udafInputOffsets) {
+ super(config, pandasAggregateFunctions, inputType, outputType, groupKey, groupingSet, udafInputOffsets);
+ }
+
+ @Override
+ public PythonFunctionRunner createPythonFunctionRunner() {
+ return new PassThroughPythonAggregateFunctionRunner( // test runner substituted for the operator's own runner
+ getRuntimeContext().getTaskName(),
+ PythonTestUtils.createTestEnvironmentManager(),
+ userDefinedFunctionInputType,
+ userDefinedFunctionOutputType,
+ getFunctionUrn(),
+ getUserDefinedFunctionsProto(),
+ getInputOutputCoderUrn(),
+ new HashMap<>(), // jobOptions: left empty for the test
+ PythonTestUtils.createMockFlinkMetricContainer()
+ );
+ }
+ }
+}
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/utils/PassThroughPythonAggregateFunctionRunner.java b/flink-python/src/test/java/org/apache/flink/table/runtime/utils/PassThroughPythonAggregateFunctionRunner.java
new file mode 100644
index 0000000..2f474b9
--- /dev/null
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/utils/PassThroughPythonAggregateFunctionRunner.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.utils;
+
+import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.fnexecution.v1.FlinkFnApi;
+import org.apache.flink.python.PythonConfig;
+import org.apache.flink.python.env.PythonEnvironmentManager;
+import org.apache.flink.python.metric.FlinkMetricContainer;
+import org.apache.flink.table.runtime.arrow.serializers.RowDataArrowSerializer;
+import org.apache.flink.table.runtime.runners.python.beam.BeamTablePythonStatelessFunctionRunner;
+import org.apache.flink.table.types.logical.RowType;
+
+import org.apache.beam.runners.fnexecution.control.JobBundleFactory;
+import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Struct;
+
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A test runner that replaces the main input receiver with a pass-through implementation:
+ * for every received input batch it returns the first element of that batch as the
+ * execution result.
+ */
+public class PassThroughPythonAggregateFunctionRunner extends BeamTablePythonStatelessFunctionRunner {
+
+ /**
+ * Serialized results produced since the last call to {@link #flush()}.
+ */
+ private final List<byte[]> buffer = new LinkedList<>();
+
+ /**
+ * Arrow serializer used to decode the received batches and to encode the results.
+ */
+ private final RowDataArrowSerializer arrowSerializer;
+
+ /**
+ * Reusable InputStream holding the received input batch to be deserialized.
+ */
+ private transient ByteArrayInputStreamWithPos bais;
+
+ /**
+ * Reusable OutputStream holding the serialized result of the current batch.
+ */
+ private transient ByteArrayOutputStreamWithPos baos;
+
+ public PassThroughPythonAggregateFunctionRunner(
+ String taskName,
+ PythonEnvironmentManager environmentManager,
+ RowType inputType,
+ RowType outputType,
+ String functionUrn,
+ FlinkFnApi.UserDefinedFunctions userDefinedFunctions,
+ String coderUrn,
+ Map<String, String> jobOptions,
+ FlinkMetricContainer flinkMetricContainer) {
+ super(
+ taskName,
+ environmentManager,
+ inputType,
+ outputType,
+ functionUrn,
+ userDefinedFunctions,
+ coderUrn,
+ jobOptions,
+ flinkMetricContainer);
+ this.arrowSerializer = new RowDataArrowSerializer(inputType, outputType);
+ }
+
+ @Override
+ public void open(PythonConfig config) throws Exception {
+ super.open(config);
+ // the serializer reads from bais and writes to baos for the lifetime of the runner
+ this.bais = new ByteArrayInputStreamWithPos();
+ this.baos = new ByteArrayOutputStreamWithPos();
+ this.arrowSerializer.open(this.bais, this.baos);
+ }
+
+ @Override
+ protected void startBundle() {
+ super.startBundle();
+ // replace the receiver so that each batch is answered with its own first element
+ this.mainInputReceiver = receivedBatch -> {
+ final byte[] payload = receivedBatch.getValue();
+ this.bais.setBuffer(payload, 0, payload.length);
+ this.arrowSerializer.load();
+ this.arrowSerializer.write(this.arrowSerializer.read(0));
+ this.arrowSerializer.finishCurrentBatch();
+ this.buffer.add(this.baos.toByteArray());
+ this.baos.reset();
+ };
+ }
+
+ @Override
+ public void flush() throws Exception {
+ super.flush();
+ // hand over everything collected since the previous flush
+ this.resultBuffer.addAll(this.buffer);
+ this.buffer.clear();
+ }
+
+ @Override
+ public JobBundleFactory createJobBundleFactory(Struct pipelineOptions) {
+ return PythonTestUtils.createMockJobBundleFactory();
+ }
+}