Posted to commits@flink.apache.org by di...@apache.org on 2020/09/23 07:44:12 UTC
[flink] branch master updated: [FLINK-19333][python] Introduce BatchArrowPythonOverWindowAggregateFunctionOperator
This is an automated email from the ASF dual-hosted git repository.
dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 1e6c187 [FLINK-19333][python] Introduce BatchArrowPythonOverWindowAggregateFunctionOperator
1e6c187 is described below
commit 1e6c187112bb802ad67c080ae8b4864eda829a97
Author: huangxingbo <hx...@gmail.com>
AuthorDate: Tue Sep 22 15:16:29 2020 +0800
[FLINK-19333][python] Introduce BatchArrowPythonOverWindowAggregateFunctionOperator
This closes #13451
---
.../pyflink/fn_execution/flink_fn_execution_pb2.py | 217 ++++++++++----
.../pyflink/proto/flink-fn-execution.proto | 21 ++
...stractArrowPythonAggregateFunctionOperator.java | 2 +-
...wPythonOverWindowAggregateFunctionOperator.java | 332 +++++++++++++++++++++
...owPythonGroupAggregateFunctionOperatorTest.java | 3 +-
...onGroupWindowAggregateFunctionOperatorTest.java | 3 +-
...onOverWindowAggregateFunctionOperatorTest.java} | 156 +++++-----
.../PassThroughPythonAggregateFunctionRunner.java | 45 ++-
8 files changed, 647 insertions(+), 132 deletions(-)
diff --git a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
index bd27ff7..6b1af39 100644
--- a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
+++ b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
@@ -36,11 +36,57 @@ DESCRIPTOR = _descriptor.FileDescriptor(
name='flink-fn-execution.proto',
package='org.apache.flink.fn_execution.v1',
syntax='proto3',
- serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12 org.apache.flink.fn_execution.v1\"\xfc\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12K\n\x06inputs\x18\x02 \x03(\x0b\x32;.org.apache.flink.fn_execution.v1.UserDefinedFunction.Input\x1a\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02 \x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03 \x01(\x0cH\x00\x42\x07\ [...]
+ serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12 org.apache.flink.fn_execution.v1\"\x92\x02\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12K\n\x06inputs\x18\x02 \x03(\x0b\x32;.org.apache.flink.fn_execution.v1.UserDefinedFunction.Input\x12\x14\n\x0cwindow_index\x18\x03 \x01(\x05\x1a\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02 \x01(\x05H\x00\x12\x17\n\ri [...]
)
+_OVERWINDOW_WINDOWTYPE = _descriptor.EnumDescriptor(
+ name='WindowType',
+ full_name='org.apache.flink.fn_execution.v1.OverWindow.WindowType',
+ filename=None,
+ file=DESCRIPTOR,
+ values=[
+ _descriptor.EnumValueDescriptor(
+ name='RANGE_UNBOUNDED', index=0, number=0,
+ options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='RANGE_UNBOUNDED_PRECEDING', index=1, number=1,
+ options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='RANGE_UNBOUNDED_FOLLOWING', index=2, number=2,
+ options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='RANGE_SLIDING', index=3, number=3,
+ options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='ROW_UNBOUNDED', index=4, number=4,
+ options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='ROW_UNBOUNDED_PRECEDING', index=5, number=5,
+ options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='ROW_UNBOUNDED_FOLLOWING', index=6, number=6,
+ options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='ROW_SLIDING', index=7, number=7,
+ options=None,
+ type=None),
+ ],
+ containing_type=None,
+ options=None,
+ serialized_start=662,
+ serialized_end=870,
+)
+_sym_db.RegisterEnumDescriptor(_OVERWINDOW_WINDOWTYPE)
+
_USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE = _descriptor.EnumDescriptor(
name='FunctionType',
full_name='org.apache.flink.fn_execution.v1.UserDefinedDataStreamFunction.FunctionType',
@@ -70,8 +116,8 @@ _USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
- serialized_start=585,
- serialized_end=663,
+ serialized_start=1023,
+ serialized_end=1101,
)
_sym_db.RegisterEnumDescriptor(_USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE)
@@ -168,8 +214,8 @@ _SCHEMA_TYPENAME = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
- serialized_start=2877,
- serialized_end=3160,
+ serialized_start=3315,
+ serialized_end=3598,
)
_sym_db.RegisterEnumDescriptor(_SCHEMA_TYPENAME)
@@ -254,8 +300,8 @@ _TYPEINFO_TYPENAME = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
- serialized_start=3681,
- serialized_end=3912,
+ serialized_start=4119,
+ serialized_end=4350,
)
_sym_db.RegisterEnumDescriptor(_TYPEINFO_TYPENAME)
@@ -303,8 +349,8 @@ _USERDEFINEDFUNCTION_INPUT = _descriptor.Descriptor(
name='input', full_name='org.apache.flink.fn_execution.v1.UserDefinedFunction.Input.input',
index=0, containing_type=None, fields=[]),
],
- serialized_start=181,
- serialized_end=315,
+ serialized_start=203,
+ serialized_end=337,
)
_USERDEFINEDFUNCTION = _descriptor.Descriptor(
@@ -328,6 +374,13 @@ _USERDEFINEDFUNCTION = _descriptor.Descriptor(
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='window_index', full_name='org.apache.flink.fn_execution.v1.UserDefinedFunction.window_index', index=2,
+ number=3, type=5, cpp_type=1, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
],
extensions=[
],
@@ -341,7 +394,7 @@ _USERDEFINEDFUNCTION = _descriptor.Descriptor(
oneofs=[
],
serialized_start=63,
- serialized_end=315,
+ serialized_end=337,
)
@@ -366,11 +419,64 @@ _USERDEFINEDFUNCTIONS = _descriptor.Descriptor(
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='windows', full_name='org.apache.flink.fn_execution.v1.UserDefinedFunctions.windows', index=2,
+ number=3, type=11, cpp_type=10, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ options=None,
+ is_extendable=False,
+ syntax='proto3',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=340,
+ serialized_end=518,
+)
+
+
+_OVERWINDOW = _descriptor.Descriptor(
+ name='OverWindow',
+ full_name='org.apache.flink.fn_execution.v1.OverWindow',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='window_type', full_name='org.apache.flink.fn_execution.v1.OverWindow.window_type', index=0,
+ number=1, type=14, cpp_type=8, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='lower_boundary', full_name='org.apache.flink.fn_execution.v1.OverWindow.lower_boundary', index=1,
+ number=2, type=3, cpp_type=2, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='upper_boundary', full_name='org.apache.flink.fn_execution.v1.OverWindow.upper_boundary', index=2,
+ number=3, type=3, cpp_type=2, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
+ _OVERWINDOW_WINDOWTYPE,
],
options=None,
is_extendable=False,
@@ -378,8 +484,8 @@ _USERDEFINEDFUNCTIONS = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=317,
- serialized_end=432,
+ serialized_start=521,
+ serialized_end=870,
)
@@ -417,8 +523,8 @@ _USERDEFINEDDATASTREAMFUNCTION = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=435,
- serialized_end=663,
+ serialized_start=873,
+ serialized_end=1101,
)
@@ -455,8 +561,8 @@ _USERDEFINEDDATASTREAMFUNCTIONS = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=666,
- serialized_end=801,
+ serialized_start=1104,
+ serialized_end=1239,
)
@@ -535,8 +641,8 @@ _USERDEFINEDAGGREGATEFUNCTIONS = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=804,
- serialized_end=1135,
+ serialized_start=1242,
+ serialized_end=1573,
)
@@ -573,8 +679,8 @@ _SCHEMA_MAPINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1213,
- serialized_end=1364,
+ serialized_start=1651,
+ serialized_end=1802,
)
_SCHEMA_TIMEINFO = _descriptor.Descriptor(
@@ -603,8 +709,8 @@ _SCHEMA_TIMEINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1366,
- serialized_end=1395,
+ serialized_start=1804,
+ serialized_end=1833,
)
_SCHEMA_TIMESTAMPINFO = _descriptor.Descriptor(
@@ -633,8 +739,8 @@ _SCHEMA_TIMESTAMPINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1397,
- serialized_end=1431,
+ serialized_start=1835,
+ serialized_end=1869,
)
_SCHEMA_LOCALZONEDTIMESTAMPINFO = _descriptor.Descriptor(
@@ -663,8 +769,8 @@ _SCHEMA_LOCALZONEDTIMESTAMPINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1433,
- serialized_end=1477,
+ serialized_start=1871,
+ serialized_end=1915,
)
_SCHEMA_ZONEDTIMESTAMPINFO = _descriptor.Descriptor(
@@ -693,8 +799,8 @@ _SCHEMA_ZONEDTIMESTAMPINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1479,
- serialized_end=1518,
+ serialized_start=1917,
+ serialized_end=1956,
)
_SCHEMA_DECIMALINFO = _descriptor.Descriptor(
@@ -730,8 +836,8 @@ _SCHEMA_DECIMALINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1520,
- serialized_end=1567,
+ serialized_start=1958,
+ serialized_end=2005,
)
_SCHEMA_BINARYINFO = _descriptor.Descriptor(
@@ -760,8 +866,8 @@ _SCHEMA_BINARYINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1569,
- serialized_end=1597,
+ serialized_start=2007,
+ serialized_end=2035,
)
_SCHEMA_VARBINARYINFO = _descriptor.Descriptor(
@@ -790,8 +896,8 @@ _SCHEMA_VARBINARYINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1599,
- serialized_end=1630,
+ serialized_start=2037,
+ serialized_end=2068,
)
_SCHEMA_CHARINFO = _descriptor.Descriptor(
@@ -820,8 +926,8 @@ _SCHEMA_CHARINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1632,
- serialized_end=1658,
+ serialized_start=2070,
+ serialized_end=2096,
)
_SCHEMA_VARCHARINFO = _descriptor.Descriptor(
@@ -850,8 +956,8 @@ _SCHEMA_VARCHARINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1660,
- serialized_end=1689,
+ serialized_start=2098,
+ serialized_end=2127,
)
_SCHEMA_FIELDTYPE = _descriptor.Descriptor(
@@ -974,8 +1080,8 @@ _SCHEMA_FIELDTYPE = _descriptor.Descriptor(
name='type_info', full_name='org.apache.flink.fn_execution.v1.Schema.FieldType.type_info',
index=0, containing_type=None, fields=[]),
],
- serialized_start=1692,
- serialized_end=2764,
+ serialized_start=2130,
+ serialized_end=3202,
)
_SCHEMA_FIELD = _descriptor.Descriptor(
@@ -1018,8 +1124,8 @@ _SCHEMA_FIELD = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=2766,
- serialized_end=2874,
+ serialized_start=3204,
+ serialized_end=3312,
)
_SCHEMA = _descriptor.Descriptor(
@@ -1049,8 +1155,8 @@ _SCHEMA = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1138,
- serialized_end=3160,
+ serialized_start=1576,
+ serialized_end=3598,
)
@@ -1104,8 +1210,8 @@ _TYPEINFO_FIELDTYPE = _descriptor.Descriptor(
name='type_info', full_name='org.apache.flink.fn_execution.v1.TypeInfo.FieldType.type_info',
index=0, containing_type=None, fields=[]),
],
- serialized_start=3241,
- serialized_end=3566,
+ serialized_start=3679,
+ serialized_end=4004,
)
_TYPEINFO_FIELD = _descriptor.Descriptor(
@@ -1148,8 +1254,8 @@ _TYPEINFO_FIELD = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3568,
- serialized_end=3678,
+ serialized_start=4006,
+ serialized_end=4116,
)
_TYPEINFO = _descriptor.Descriptor(
@@ -1179,8 +1285,8 @@ _TYPEINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3163,
- serialized_end=3912,
+ serialized_start=3601,
+ serialized_end=4350,
)
_USERDEFINEDFUNCTION_INPUT.fields_by_name['udf'].message_type = _USERDEFINEDFUNCTION
@@ -1196,6 +1302,9 @@ _USERDEFINEDFUNCTION_INPUT.oneofs_by_name['input'].fields.append(
_USERDEFINEDFUNCTION_INPUT.fields_by_name['inputConstant'].containing_oneof = _USERDEFINEDFUNCTION_INPUT.oneofs_by_name['input']
_USERDEFINEDFUNCTION.fields_by_name['inputs'].message_type = _USERDEFINEDFUNCTION_INPUT
_USERDEFINEDFUNCTIONS.fields_by_name['udfs'].message_type = _USERDEFINEDFUNCTION
+_USERDEFINEDFUNCTIONS.fields_by_name['windows'].message_type = _OVERWINDOW
+_OVERWINDOW.fields_by_name['window_type'].enum_type = _OVERWINDOW_WINDOWTYPE
+_OVERWINDOW_WINDOWTYPE.containing_type = _OVERWINDOW
_USERDEFINEDDATASTREAMFUNCTION.fields_by_name['functionType'].enum_type = _USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE
_USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE.containing_type = _USERDEFINEDDATASTREAMFUNCTION
_USERDEFINEDDATASTREAMFUNCTIONS.fields_by_name['udfs'].message_type = _USERDEFINEDDATASTREAMFUNCTION
@@ -1287,6 +1396,7 @@ _TYPEINFO.fields_by_name['field'].message_type = _TYPEINFO_FIELD
_TYPEINFO_TYPENAME.containing_type = _TYPEINFO
DESCRIPTOR.message_types_by_name['UserDefinedFunction'] = _USERDEFINEDFUNCTION
DESCRIPTOR.message_types_by_name['UserDefinedFunctions'] = _USERDEFINEDFUNCTIONS
+DESCRIPTOR.message_types_by_name['OverWindow'] = _OVERWINDOW
DESCRIPTOR.message_types_by_name['UserDefinedDataStreamFunction'] = _USERDEFINEDDATASTREAMFUNCTION
DESCRIPTOR.message_types_by_name['UserDefinedDataStreamFunctions'] = _USERDEFINEDDATASTREAMFUNCTIONS
DESCRIPTOR.message_types_by_name['UserDefinedAggregateFunctions'] = _USERDEFINEDAGGREGATEFUNCTIONS
@@ -1316,6 +1426,13 @@ UserDefinedFunctions = _reflection.GeneratedProtocolMessageType('UserDefinedFunc
))
_sym_db.RegisterMessage(UserDefinedFunctions)
+OverWindow = _reflection.GeneratedProtocolMessageType('OverWindow', (_message.Message,), dict(
+ DESCRIPTOR = _OVERWINDOW,
+ __module__ = 'flink_fn_execution_pb2'
+ # @@protoc_insertion_point(class_scope:org.apache.flink.fn_execution.v1.OverWindow)
+ ))
+_sym_db.RegisterMessage(OverWindow)
+
UserDefinedDataStreamFunction = _reflection.GeneratedProtocolMessageType('UserDefinedDataStreamFunction', (_message.Message,), dict(
DESCRIPTOR = _USERDEFINEDDATASTREAMFUNCTION,
__module__ = 'flink_fn_execution_pb2'
diff --git a/flink-python/pyflink/proto/flink-fn-execution.proto b/flink-python/pyflink/proto/flink-fn-execution.proto
index c102c0e..ef5d091 100644
--- a/flink-python/pyflink/proto/flink-fn-execution.proto
+++ b/flink-python/pyflink/proto/flink-fn-execution.proto
@@ -44,12 +44,33 @@ message UserDefinedFunction {
// 2. The result of another user-defined function
// 3. The constant value of the column
repeated Input inputs = 2;
+
+ // The index of the over window used in pandas batch over window aggregation
+ int32 window_index = 3;
}
// A list of user-defined functions to be executed in a batch.
message UserDefinedFunctions {
repeated UserDefinedFunction udfs = 1;
bool metric_enabled = 2;
+ repeated OverWindow windows = 3;
+}
+
+// Used to describe an over window used in pandas batch over window aggregation
+message OverWindow {
+ enum WindowType {
+ RANGE_UNBOUNDED = 0;
+ RANGE_UNBOUNDED_PRECEDING = 1;
+ RANGE_UNBOUNDED_FOLLOWING = 2;
+ RANGE_SLIDING = 3;
+ ROW_UNBOUNDED = 4;
+ ROW_UNBOUNDED_PRECEDING = 5;
+ ROW_UNBOUNDED_FOLLOWING = 6;
+ ROW_SLIDING = 7;
+ }
+ WindowType window_type = 1;
+ int64 lower_boundary = 2;
+ int64 upper_boundary = 3;
}
// User defined DataStream function definition.
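
For illustration, a bounded ROWS window could be encoded with the classes generated from this proto as follows. This is a minimal sketch, not part of this commit; the class name and boundary values are hypothetical, and the generated FlinkFnApi classes are assumed to be on the classpath. The sign convention follows the operator code below, where a preceding offset is added to the current row's time:

import org.apache.flink.fnexecution.v1.FlinkFnApi;

public class OverWindowProtoSketch {
    public static void main(String[] args) {
        // ROWS BETWEEN 1 PRECEDING AND 2 FOLLOWING: bounded on both sides -> ROW_SLIDING
        FlinkFnApi.OverWindow window = FlinkFnApi.OverWindow.newBuilder()
                .setWindowType(FlinkFnApi.OverWindow.WindowType.ROW_SLIDING)
                .setLowerBoundary(-1L) // hypothetical: 1 row preceding (negative offset)
                .setUpperBoundary(2L)  // hypothetical: 2 rows following (positive offset)
                .build();
        System.out.println(window);
    }
}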
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/AbstractArrowPythonAggregateFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/AbstractArrowPythonAggregateFunctionOperator.java
index 21fbb94..c55d3c5 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/AbstractArrowPythonAggregateFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/AbstractArrowPythonAggregateFunctionOperator.java
@@ -58,7 +58,7 @@ public abstract class AbstractArrowPythonAggregateFunctionOperator
/**
* The Pandas {@link AggregateFunction}s to be executed.
*/
- private final PythonFunctionInfo[] pandasAggFunctions;
+ protected final PythonFunctionInfo[] pandasAggFunctions;
protected final int[] groupingSet;
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonOverWindowAggregateFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonOverWindowAggregateFunctionOperator.java
new file mode 100644
index 0000000..95c5d79
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonOverWindowAggregateFunctionOperator.java
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.python.aggregate.arrow.batch;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.fnexecution.v1.FlinkFnApi;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.binary.BinaryRowData;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.functions.python.PythonFunctionInfo;
+import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
+import org.apache.flink.table.types.logical.RowType;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.ListIterator;
+
+import static org.apache.flink.table.runtime.operators.python.utils.PythonOperatorUtils.getUserDefinedFunctionProto;
+
+/**
+ * The Batch Arrow Python {@link AggregateFunction} Operator for Over Window Aggregation.
+ */
+@Internal
+public class BatchArrowPythonOverWindowAggregateFunctionOperator
+ extends AbstractBatchArrowPythonAggregateFunctionOperator {
+
+ private static final long serialVersionUID = 1L;
+
+ private static final String SCHEMA_OVER_WINDOW_ARROW_CODER_URN = "flink:coder:schema:batch_over_window:arrow:v1";
+
+ private static final String PANDAS_BATCH_OVER_WINDOW_AGG_FUNCTION_URN = "flink:transform:batch_over_window_aggregate_function:arrow:v1";
+
+ /**
+ * Used to serialize the window boundaries of the bounded range windows.
+ */
+ private static final IntSerializer windowBoundarySerializer = IntSerializer.INSTANCE;
+
+ /**
+ * The lower boundary of each over window, e.g. Long.MIN_VALUE means unbounded preceding.
+ */
+ private final long[] lowerBoundary;
+
+ /**
+ * The upper boundary of each over window, e.g. Long.MAX_VALUE means unbounded following.
+ */
+ private final long[] upperBoundary;
+
+ /**
+ * Whether the window at each position is a range window.
+ */
+ private final boolean[] isRangeWindows;
+
+ /**
+ * The index of the over window each aggregate function belongs to.
+ */
+ private final int[] aggWindowIndex;
+
+ /**
+ * The row time index of the input data.
+ */
+ private final int inputTimeFieldIndex;
+
+ /**
+ * The order of row time. True for ascending.
+ */
+ private final boolean asc;
+
+ /**
+ * The type serializer for the forwarded fields.
+ */
+ private transient RowDataSerializer forwardedInputSerializer;
+
+ /**
+ * Stores the start position of the last key data in forwardedInputQueue.
+ */
+ private transient int lastKeyDataStartPos;
+
+ /**
+ * Reusable OutputStream used to hold the window boundaries together with the input elements.
+ */
+ private transient ByteArrayOutputStreamWithPos windowBoundaryWithDataBaos;
+
+ /**
+ * OutputStream Wrapper.
+ */
+ private transient DataOutputViewStreamWrapper windowBoundaryWithDataWrapper;
+
+ /**
+ * Stores bounded range window boundaries.
+ */
+ private transient List<List<Integer>> boundedRangeWindowBoundaries;
+
+ /**
+ * Stores index of the bounded range window.
+ */
+ private transient ArrayList<Integer> boundedRangeWindowIndex;
+
+ public BatchArrowPythonOverWindowAggregateFunctionOperator(
+ Configuration config,
+ PythonFunctionInfo[] pandasAggFunctions,
+ RowType inputType,
+ RowType outputType,
+ long[] lowerBoundary,
+ long[] upperBoundary,
+ boolean[] isRangeWindows,
+ int[] aggWindowIndex,
+ int[] groupKey,
+ int[] groupingSet,
+ int[] udafInputOffsets,
+ int inputTimeFieldIndex,
+ boolean asc) {
+ super(config, pandasAggFunctions, inputType, outputType, groupKey, groupingSet, udafInputOffsets);
+ this.lowerBoundary = lowerBoundary;
+ this.upperBoundary = upperBoundary;
+ this.isRangeWindows = isRangeWindows;
+ this.aggWindowIndex = aggWindowIndex;
+ this.inputTimeFieldIndex = inputTimeFieldIndex;
+ this.asc = asc;
+ }
+
+ @Override
+ public void open() throws Exception {
+ userDefinedFunctionOutputType = new RowType(
+ outputType.getFields().subList(inputType.getFieldCount(), outputType.getFieldCount()));
+ forwardedInputSerializer = new RowDataSerializer(inputType);
+ this.lastKeyDataStartPos = 0;
+ windowBoundaryWithDataBaos = new ByteArrayOutputStreamWithPos();
+ windowBoundaryWithDataWrapper = new DataOutputViewStreamWrapper(windowBoundaryWithDataBaos);
+ boundedRangeWindowBoundaries = new ArrayList<>(lowerBoundary.length);
+ boundedRangeWindowIndex = new ArrayList<>();
+ for (int i = 0; i < lowerBoundary.length; i++) {
+ // range window with bounded preceding or bounded following
+ if (isRangeWindows[i] && (lowerBoundary[i] != Long.MIN_VALUE || upperBoundary[i] != Long.MAX_VALUE)) {
+ boundedRangeWindowIndex.add(i);
+ boundedRangeWindowBoundaries.add(new ArrayList<>());
+ }
+ }
+ super.open();
+ }
+
+ @Override
+ public void bufferInput(RowData input) throws Exception {
+ BinaryRowData currentKey = groupKeyProjection.apply(input).copy();
+ if (isNewKey(currentKey)) {
+ if (lastGroupKey != null) {
+ invokeCurrentBatch();
+ }
+ lastGroupKey = currentKey;
+ lastGroupSet = groupSetProjection.apply(input).copy();
+ }
+ RowData forwardedFields = forwardedInputSerializer.copy(input);
+ forwardedInputQueue.add(forwardedFields);
+ }
+
+ @Override
+ protected void invokeCurrentBatch() throws Exception {
+ if (currentBatchCount > 0) {
+ arrowSerializer.finishCurrentBatch();
+ ListIterator<RowData> iter = forwardedInputQueue.listIterator(lastKeyDataStartPos);
+ int[] lowerBoundaryPos = new int[boundedRangeWindowIndex.size()];
+ int[] upperBoundaryPos = new int[boundedRangeWindowIndex.size()];
+ while (iter.hasNext()) {
+ RowData curData = iter.next();
+ // loop every bounded range window
+ for (int j = 0; j < boundedRangeWindowIndex.size(); j++) {
+ int windowPos = boundedRangeWindowIndex.get(j);
+ long curMills = curData.getTimestamp(inputTimeFieldIndex, 3).getMillisecond();
+ List<Integer> curWindowBoundary = boundedRangeWindowBoundaries.get(j);
+ // bounded preceding
+ if (lowerBoundary[windowPos] != Long.MIN_VALUE) {
+ int curLowerBoundaryPos = lowerBoundaryPos[j];
+ long lowerBoundaryTime = curMills + lowerBoundary[windowPos];
+ while (isInCurrentOverWindow(forwardedInputQueue.get(curLowerBoundaryPos), lowerBoundaryTime, false)) {
+ curLowerBoundaryPos += 1;
+ }
+ lowerBoundaryPos[j] = curLowerBoundaryPos;
+ curWindowBoundary.add(curLowerBoundaryPos);
+ }
+ // bounded following
+ if (upperBoundary[windowPos] != Long.MAX_VALUE) {
+ int curUpperBoundaryPos = upperBoundaryPos[j];
+ long upperBoundaryTime = curMills + upperBoundary[windowPos];
+ while (curUpperBoundaryPos < forwardedInputQueue.size() &&
+ isInCurrentOverWindow(forwardedInputQueue.get(curUpperBoundaryPos), upperBoundaryTime, true)) {
+ curUpperBoundaryPos += 1;
+ }
+ upperBoundaryPos[j] = curUpperBoundaryPos;
+ curWindowBoundary.add(curUpperBoundaryPos);
+ }
+ }
+ }
+ // serialize the number of bounded range windows.
+ windowBoundarySerializer.serialize(boundedRangeWindowBoundaries.size(), windowBoundaryWithDataWrapper);
+ // serialize the boundaries of every bounded range window.
+ for (List<Integer> boundedRangeWindowBoundary : boundedRangeWindowBoundaries) {
+ windowBoundarySerializer.serialize(boundedRangeWindowBoundary.size(), windowBoundaryWithDataWrapper);
+ for (int ele : boundedRangeWindowBoundary) {
+ windowBoundarySerializer.serialize(ele, windowBoundaryWithDataWrapper);
+ }
+ boundedRangeWindowBoundary.clear();
+ }
+ // write arrow format data.
+ windowBoundaryWithDataBaos.write(baos.toByteArray());
+ baos.reset();
+ pythonFunctionRunner.process(windowBoundaryWithDataBaos.toByteArray());
+ windowBoundaryWithDataBaos.reset();
+ elementCount += currentBatchCount;
+ checkInvokeFinishBundleByCount();
+ currentBatchCount = 0;
+ }
+ lastKeyDataStartPos = forwardedInputQueue.size();
+ }
+
+ @Override
+ public void processElementInternal(RowData value) {
+ arrowSerializer.write(getFunctionInput(value));
+ currentBatchCount++;
+ }
+
+ @Override
+ @SuppressWarnings("ConstantConditions")
+ public void emitResult(Tuple2<byte[], Integer> resultTuple) throws Exception {
+ byte[] udafResult = resultTuple.f0;
+ int length = resultTuple.f1;
+ bais.setBuffer(udafResult, 0, length);
+ int rowCount = arrowSerializer.load();
+ for (int i = 0; i < rowCount; i++) {
+ RowData input = forwardedInputQueue.poll();
+ lastKeyDataStartPos--;
+ reuseJoinedRow.setRowKind(input.getRowKind());
+ rowDataWrapper.collect(reuseJoinedRow.replace(input, arrowSerializer.read(i)));
+ }
+ }
+
+ @Override
+ public FlinkFnApi.UserDefinedFunctions getUserDefinedFunctionsProto() {
+ FlinkFnApi.UserDefinedFunctions.Builder builder = FlinkFnApi.UserDefinedFunctions.newBuilder();
+ // add udaf proto
+ for (int i = 0; i < pandasAggFunctions.length; i++) {
+ FlinkFnApi.UserDefinedFunction.Builder functionBuilder =
+ getUserDefinedFunctionProto(pandasAggFunctions[i]).toBuilder();
+ functionBuilder.setWindowIndex(aggWindowIndex[i]);
+ builder.addUdfs(functionBuilder);
+ }
+ builder.setMetricEnabled(getPythonConfig().isMetricEnabled());
+ // add windows
+ for (int i = 0; i < lowerBoundary.length; i++) {
+ FlinkFnApi.OverWindow.Builder windowBuilder = FlinkFnApi.OverWindow.newBuilder();
+ if (isRangeWindows[i]) {
+ // range window
+ if (lowerBoundary[i] != Long.MIN_VALUE) {
+ if (upperBoundary[i] != Long.MAX_VALUE) {
+ // range sliding window
+ windowBuilder.setWindowType(FlinkFnApi.OverWindow.WindowType.RANGE_SLIDING);
+ } else {
+ // range unbounded following window
+ windowBuilder.setWindowType(FlinkFnApi.OverWindow.WindowType.RANGE_UNBOUNDED_FOLLOWING);
+ }
+ } else {
+ if (upperBoundary[i] != Long.MAX_VALUE) {
+ // range unbounded preceding window
+ windowBuilder.setWindowType(FlinkFnApi.OverWindow.WindowType.RANGE_UNBOUNDED_PRECEDING);
+ } else {
+ // range unbounded window
+ windowBuilder.setWindowType(FlinkFnApi.OverWindow.WindowType.RANGE_UNBOUNDED);
+ }
+ }
+ } else {
+ // row window
+ if (lowerBoundary[i] != Long.MIN_VALUE) {
+ windowBuilder.setLowerBoundary(lowerBoundary[i]);
+ if (upperBoundary[i] != Long.MAX_VALUE) {
+ // row sliding window
+ windowBuilder.setUpperBoundary(upperBoundary[i]);
+ windowBuilder.setWindowType(FlinkFnApi.OverWindow.WindowType.ROW_SLIDING);
+ } else {
+ // row unbounded following window
+ windowBuilder.setWindowType(FlinkFnApi.OverWindow.WindowType.ROW_UNBOUNDED_FOLLOWING);
+ }
+ } else {
+ if (upperBoundary[i] != Long.MAX_VALUE) {
+ // row unbounded preceding window
+ windowBuilder.setUpperBoundary(upperBoundary[i]);
+ windowBuilder.setWindowType(FlinkFnApi.OverWindow.WindowType.ROW_UNBOUNDED_PRECEDING);
+ } else {
+ // row unbounded window
+ windowBuilder.setWindowType(FlinkFnApi.OverWindow.WindowType.ROW_UNBOUNDED);
+ }
+ }
+ }
+ builder.addWindows(windowBuilder);
+ }
+ return builder.build();
+ }
+
+ @Override
+ public String getFunctionUrn() {
+ return PANDAS_BATCH_OVER_WINDOW_AGG_FUNCTION_URN;
+ }
+
+ @Override
+ public String getInputOutputCoderUrn() {
+ return SCHEMA_OVER_WINDOW_ARROW_CODER_URN;
+ }
+
+ private boolean isInCurrentOverWindow(RowData data, long time, boolean includeEqual) {
+ long dataTime = data.getTimestamp(inputTimeFieldIndex, 3).getMillisecond();
+ long diff = time - dataTime;
+ return (diff > 0 && asc) || (diff == 0 && includeEqual);
+ }
+}
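
The payload that invokeCurrentBatch() hands to pythonFunctionRunner.process() is thus framed as: the number of bounded range windows, then for each such window the count of boundary positions followed by the positions themselves (for a window bounded on both sides they alternate lower/upper per input row), then the raw Arrow batch. Below is a minimal sketch that decodes this framing using the same Flink serializer classes the commit uses; the class and method names are illustrative:

import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;

public class WindowBoundaryFramingSketch {
    static void decode(byte[] data) throws java.io.IOException {
        ByteArrayInputStreamWithPos bais = new ByteArrayInputStreamWithPos();
        bais.setBuffer(data, 0, data.length);
        DataInputViewStreamWrapper in = new DataInputViewStreamWrapper(bais);
        // number of bounded range windows
        int numWindows = IntSerializer.INSTANCE.deserialize(in);
        for (int i = 0; i < numWindows; i++) {
            // count of boundary positions for this window
            int numBoundaries = IntSerializer.INSTANCE.deserialize(in);
            for (int j = 0; j < numBoundaries; j++) {
                // each position indexes into the rows of the Arrow batch that follows
                int pos = IntSerializer.INSTANCE.deserialize(in);
            }
        }
        // the bytes remaining in bais are the Arrow batch for the current key
    }
}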
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupAggregateFunctionOperatorTest.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupAggregateFunctionOperatorTest.java
index 7f5c4a2..8d2feda 100644
--- a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupAggregateFunctionOperatorTest.java
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupAggregateFunctionOperatorTest.java
@@ -193,7 +193,8 @@ public class BatchArrowPythonGroupAggregateFunctionOperatorTest
getUserDefinedFunctionsProto(),
getInputOutputCoderUrn(),
new HashMap<>(),
- PythonTestUtils.createMockFlinkMetricContainer()
+ PythonTestUtils.createMockFlinkMetricContainer(),
+ false
);
}
}
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupWindowAggregateFunctionOperatorTest.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupWindowAggregateFunctionOperatorTest.java
index e42533b..aaf6791 100644
--- a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupWindowAggregateFunctionOperatorTest.java
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupWindowAggregateFunctionOperatorTest.java
@@ -245,7 +245,8 @@ public class BatchArrowPythonGroupWindowAggregateFunctionOperatorTest extends Ar
getUserDefinedFunctionsProto(),
getInputOutputCoderUrn(),
new HashMap<>(),
- PythonTestUtils.createMockFlinkMetricContainer()
+ PythonTestUtils.createMockFlinkMetricContainer(),
+ false
);
}
}
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupWindowAggregateFunctionOperatorTest.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonOverWindowAggregateFunctionOperatorTest.java
similarity index 60%
copy from flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupWindowAggregateFunctionOperatorTest.java
copy to flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonOverWindowAggregateFunctionOperatorTest.java
index e42533b..a57a0e8 100644
--- a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonGroupWindowAggregateFunctionOperatorTest.java
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/arrow/batch/BatchArrowPythonOverWindowAggregateFunctionOperatorTest.java
@@ -19,13 +19,13 @@
package org.apache.flink.table.runtime.operators.python.aggregate.arrow.batch;
import org.apache.flink.configuration.Configuration;
+import org.apache.flink.fnexecution.v1.FlinkFnApi;
import org.apache.flink.python.PythonFunctionRunner;
import org.apache.flink.python.PythonOptions;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.data.RowData;
-import org.apache.flink.table.data.TimestampData;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.runtime.operators.python.aggregate.arrow.AbstractArrowPythonAggregateFunctionOperator;
import org.apache.flink.table.runtime.operators.python.aggregate.arrow.ArrowPythonAggregateFunctionOperatorTestBase;
@@ -34,53 +34,48 @@ import org.apache.flink.table.runtime.utils.PythonTestUtils;
import org.apache.flink.table.types.logical.BigIntType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
-import org.apache.flink.table.types.logical.TimestampType;
import org.apache.flink.table.types.logical.VarCharType;
import org.junit.Test;
import java.util.Arrays;
import java.util.HashMap;
+import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
+import static org.junit.Assert.assertEquals;
+
/**
- * Test for {@link BatchArrowPythonGroupWindowAggregateFunctionOperator}. These test that:
+ * Test for {@link BatchArrowPythonOverWindowAggregateFunctionOperator}. These tests verify that:
*
* <ul>
* <li>FinishBundle is called when the bundled element count reaches the max bundle size</li>
* <li>FinishBundle is called when the bundle time reaches the max bundle time</li>
* </ul>
*/
-public class BatchArrowPythonGroupWindowAggregateFunctionOperatorTest extends ArrowPythonAggregateFunctionOperatorTestBase {
+public class BatchArrowPythonOverWindowAggregateFunctionOperatorTest extends ArrowPythonAggregateFunctionOperatorTestBase {
+
@Test
- public void testGroupAggregateFunction() throws Exception {
+ public void testOverWindowAggregateFunction() throws Exception {
OneInputStreamOperatorTestHarness<RowData, RowData> testHarness = getTestHarness(
new Configuration());
+
long initialTime = 0L;
ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
testHarness.open();
testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c1", "c2", 0L, 0L), initialTime + 1));
- testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c1", "c4", 1L, 6000L), initialTime + 2));
- testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c1", "c6", 2L, 10000L), initialTime + 3));
+ testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c1", "c4", 1L, 0L), initialTime + 2));
+ testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c1", "c6", 2L, 10L), initialTime + 3));
testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c2", "c8", 3L, 0L), initialTime + 3));
testHarness.close();
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c1", 0L, TimestampData.fromEpochMillis(-5000L), TimestampData.fromEpochMillis(5000L))));
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c1", 0L, TimestampData.fromEpochMillis(0L), TimestampData.fromEpochMillis(10000L))));
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c1", 1L, TimestampData.fromEpochMillis(5000L), TimestampData.fromEpochMillis(15000L))));
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c1", 2L, TimestampData.fromEpochMillis(10000L), TimestampData.fromEpochMillis(20000L))));
-
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c2", 3L, TimestampData.fromEpochMillis(-5000L), TimestampData.fromEpochMillis(5000L))));
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c2", 3L, TimestampData.fromEpochMillis(0L), TimestampData.fromEpochMillis(10000L))));
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L, 0L, 0L)));
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c4", 1L, 0L, 0L)));
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c6", 2L, 10L, 2L)));
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c2", "c8", 3L, 0L, 3L)));
assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
}
@@ -88,7 +83,7 @@ public class BatchArrowPythonGroupWindowAggregateFunctionOperatorTest extends Ar
@Test
public void testFinishBundleTriggeredByCount() throws Exception {
Configuration conf = new Configuration();
- conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 6);
+ conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 3);
OneInputStreamOperatorTestHarness<RowData, RowData> testHarness = getTestHarness(conf);
long initialTime = 0L;
@@ -97,30 +92,20 @@ public class BatchArrowPythonGroupWindowAggregateFunctionOperatorTest extends Ar
testHarness.open();
testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c1", "c2", 0L, 0L), initialTime + 1));
- testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c1", "c4", 1L, 6000L), initialTime + 2));
- testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c1", "c6", 2L, 10000L), initialTime + 3));
-
+ testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c1", "c4", 1L, 0L), initialTime + 2));
+ testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c1", "c6", 2L, 10L), initialTime + 3));
assertOutputEquals("FinishBundle should not be triggered.", expectedOutput, testHarness.getOutput());
testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c2", "c8", 3L, 0L), initialTime + 3));
-
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c1", 0L, TimestampData.fromEpochMillis(-5000L), TimestampData.fromEpochMillis(5000L))));
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c1", 0L, TimestampData.fromEpochMillis(0L), TimestampData.fromEpochMillis(10000L))));
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c1", 1L, TimestampData.fromEpochMillis(5000L), TimestampData.fromEpochMillis(15000L))));
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c1", 2L, TimestampData.fromEpochMillis(10000L), TimestampData.fromEpochMillis(20000L))));
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L, 0L, 0L)));
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c4", 1L, 0L, 0L)));
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c6", 2L, 10L, 2L)));
assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
testHarness.close();
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c2", 3L, TimestampData.fromEpochMillis(-5000L), TimestampData.fromEpochMillis(5000L))));
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c2", 3L, TimestampData.fromEpochMillis(0L), TimestampData.fromEpochMillis(10000L))));
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c2", "c8", 3L, 0L, 3L)));
assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
}
@@ -138,36 +123,53 @@ public class BatchArrowPythonGroupWindowAggregateFunctionOperatorTest extends Ar
testHarness.open();
testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c1", "c2", 0L, 0L), initialTime + 1));
- testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c1", "c4", 1L, 6000L), initialTime + 2));
- testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c1", "c6", 2L, 10000L), initialTime + 3));
+ testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c1", "c4", 1L, 0L), initialTime + 2));
+ testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c1", "c6", 2L, 10L), initialTime + 3));
testHarness.processElement(new StreamRecord<>(newBinaryRow(true, "c2", "c8", 3L, 0L), initialTime + 3));
assertOutputEquals("FinishBundle should not be triggered.", expectedOutput, testHarness.getOutput());
testHarness.setProcessingTime(1000L);
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c1", 0L, TimestampData.fromEpochMillis(-5000L), TimestampData.fromEpochMillis(5000L))));
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c1", 0L, TimestampData.fromEpochMillis(0L), TimestampData.fromEpochMillis(10000L))));
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c1", 1L, TimestampData.fromEpochMillis(5000L), TimestampData.fromEpochMillis(15000L))));
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c1", 2L, TimestampData.fromEpochMillis(10000L), TimestampData.fromEpochMillis(20000L))));
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L, 0L, 0L)));
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c4", 1L, 0L, 0L)));
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c6", 2L, 10L, 2L)));
+
assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
testHarness.close();
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c2", 3L, TimestampData.fromEpochMillis(-5000L), TimestampData.fromEpochMillis(5000L))));
- expectedOutput.add(new StreamRecord<>(
- newRow(true, "c2", 3L, TimestampData.fromEpochMillis(0L), TimestampData.fromEpochMillis(10000L))));
+ expectedOutput.add(new StreamRecord<>(newRow(true, "c2", "c8", 3L, 0L, 3L)));
assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
}
+ @Test
+ public void testUserDefinedFunctionsProto() throws Exception {
+ OneInputStreamOperatorTestHarness<RowData, RowData> testHarness = getTestHarness(
+ new Configuration());
+ testHarness.open();
+ BatchArrowPythonOverWindowAggregateFunctionOperator operator =
+ (BatchArrowPythonOverWindowAggregateFunctionOperator) testHarness.getOneInputOperator();
+ FlinkFnApi.UserDefinedFunctions functionsProto = operator.getUserDefinedFunctionsProto();
+ List<FlinkFnApi.OverWindow> windows = functionsProto.getWindowsList();
+ assertEquals(2, windows.size());
+
+ // first window is a range sliding window.
+ FlinkFnApi.OverWindow firstWindow = windows.get(0);
+ assertEquals(firstWindow.getWindowType(), FlinkFnApi.OverWindow.WindowType.RANGE_SLIDING);
+
+ // second window is a row unbounded preceding window.
+ FlinkFnApi.OverWindow secondWindow = windows.get(1);
+ assertEquals(secondWindow.getWindowType(), FlinkFnApi.OverWindow.WindowType.ROW_UNBOUNDED_PRECEDING);
+ assertEquals(secondWindow.getUpperBoundary(), 2L);
+ }
+
@Override
public LogicalType[] getOutputLogicalType() {
return new LogicalType[]{
DataTypes.STRING().getLogicalType(),
+ DataTypes.STRING().getLogicalType(),
+ DataTypes.BIGINT().getLogicalType(),
+ DataTypes.BIGINT().getLogicalType(),
DataTypes.BIGINT().getLogicalType()
};
}
@@ -185,9 +187,10 @@ public class BatchArrowPythonGroupWindowAggregateFunctionOperatorTest extends Ar
public RowType getOutputType() {
return new RowType(Arrays.asList(
new RowType.RowField("f1", new VarCharType()),
- new RowType.RowField("f2", new BigIntType()),
- new RowType.RowField("windowStart", new TimestampType(3)),
- new RowType.RowField("windowEnd", new TimestampType(3))));
+ new RowType.RowField("f2", new VarCharType()),
+ new RowType.RowField("f3", new BigIntType()),
+ new RowType.RowField("rowTime", new BigIntType()),
+ new RowType.RowField("agg", new BigIntType())));
}
@Override
@@ -198,40 +201,41 @@ public class BatchArrowPythonGroupWindowAggregateFunctionOperatorTest extends Ar
RowType outputType,
int[] groupingSet,
int[] udafInputOffsets) {
- // SlidingWindow(10000L, 5000L)
- return new PassThroughBatchArrowPythonGroupWindowAggregateFunctionOperator(
+ return new PassThroughBatchArrowPythonOverWindowAggregateFunctionOperator(
config,
pandasAggregateFunctions,
inputType,
outputType,
- 3,
- 100000,
- 10000L,
- 5000L,
- new int[]{0, 1},
+ new long[]{0L, Long.MIN_VALUE},
+ new long[]{0L, 2L},
+ new boolean[]{true, false},
+ new int[]{0},
groupingSet,
groupingSet,
- udafInputOffsets
- );
+ udafInputOffsets,
+ 3,
+ true);
}
- private static class PassThroughBatchArrowPythonGroupWindowAggregateFunctionOperator
- extends BatchArrowPythonGroupWindowAggregateFunctionOperator {
- PassThroughBatchArrowPythonGroupWindowAggregateFunctionOperator(
+ private static class PassThroughBatchArrowPythonOverWindowAggregateFunctionOperator
+ extends BatchArrowPythonOverWindowAggregateFunctionOperator {
+
+ PassThroughBatchArrowPythonOverWindowAggregateFunctionOperator(
Configuration config,
PythonFunctionInfo[] pandasAggFunctions,
RowType inputType,
RowType outputType,
- int inputTimeFieldIndex,
- int maxLimitSize,
- long windowSize,
- long slideSize,
- int[] namedProperties,
+ long[] lowerBoundary,
+ long[] upperBoundary,
+ boolean[] isRangeWindow,
+ int[] aggWindowIndex,
int[] groupKey,
int[] groupingSet,
- int[] udafInputOffsets) {
- super(config, pandasAggFunctions, inputType, outputType, inputTimeFieldIndex,
- maxLimitSize, windowSize, slideSize, namedProperties, groupKey, groupingSet, udafInputOffsets);
+ int[] udafInputOffsets,
+ int inputTimeFieldIndex,
+ boolean asc) {
+ super(config, pandasAggFunctions, inputType, outputType, lowerBoundary, upperBoundary,
+ isRangeWindow, aggWindowIndex, groupKey, groupingSet, udafInputOffsets, inputTimeFieldIndex, asc);
}
@Override
@@ -245,8 +249,8 @@ public class BatchArrowPythonGroupWindowAggregateFunctionOperatorTest extends Ar
getUserDefinedFunctionsProto(),
getInputOutputCoderUrn(),
new HashMap<>(),
- PythonTestUtils.createMockFlinkMetricContainer()
- );
+ PythonTestUtils.createMockFlinkMetricContainer(),
+ true);
}
}
}
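
For reference, here is how getUserDefinedFunctionsProto() classifies a window from its boundaries, and how the two windows configured in this test map onto the proto window types. This is a standalone sketch of the same branching, not new API; the class and method names are hypothetical:

import org.apache.flink.fnexecution.v1.FlinkFnApi;

public class WindowTypeClassificationSketch {
    static FlinkFnApi.OverWindow.WindowType classify(boolean isRange, long lower, long upper) {
        boolean lowerBounded = lower != Long.MIN_VALUE;
        boolean upperBounded = upper != Long.MAX_VALUE;
        if (isRange) {
            if (lowerBounded) {
                return upperBounded
                    ? FlinkFnApi.OverWindow.WindowType.RANGE_SLIDING
                    : FlinkFnApi.OverWindow.WindowType.RANGE_UNBOUNDED_FOLLOWING;
            }
            return upperBounded
                ? FlinkFnApi.OverWindow.WindowType.RANGE_UNBOUNDED_PRECEDING
                : FlinkFnApi.OverWindow.WindowType.RANGE_UNBOUNDED;
        }
        if (lowerBounded) {
            return upperBounded
                ? FlinkFnApi.OverWindow.WindowType.ROW_SLIDING
                : FlinkFnApi.OverWindow.WindowType.ROW_UNBOUNDED_FOLLOWING;
        }
        return upperBounded
            ? FlinkFnApi.OverWindow.WindowType.ROW_UNBOUNDED_PRECEDING
            : FlinkFnApi.OverWindow.WindowType.ROW_UNBOUNDED;
    }

    public static void main(String[] args) {
        // first test window: range, lower = 0L, upper = 0L -> RANGE_SLIDING
        System.out.println(classify(true, 0L, 0L));
        // second test window: row, lower = Long.MIN_VALUE, upper = 2L -> ROW_UNBOUNDED_PRECEDING
        System.out.println(classify(false, Long.MIN_VALUE, 2L));
    }
}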
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/utils/PassThroughPythonAggregateFunctionRunner.java b/flink-python/src/test/java/org/apache/flink/table/runtime/utils/PassThroughPythonAggregateFunctionRunner.java
index e4cffa5..6c81294 100644
--- a/flink-python/src/test/java/org/apache/flink/table/runtime/utils/PassThroughPythonAggregateFunctionRunner.java
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/utils/PassThroughPythonAggregateFunctionRunner.java
@@ -18,12 +18,15 @@
package org.apache.flink.table.runtime.utils;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.fnexecution.v1.FlinkFnApi;
import org.apache.flink.python.PythonConfig;
import org.apache.flink.python.env.PythonEnvironmentManager;
import org.apache.flink.python.metric.FlinkMetricContainer;
+import org.apache.flink.table.data.RowData;
import org.apache.flink.table.runtime.arrow.serializers.RowDataArrowSerializer;
import org.apache.flink.table.runtime.runners.python.beam.BeamTableStatelessPythonFunctionRunner;
import org.apache.flink.table.types.logical.RowType;
@@ -31,6 +34,7 @@ import org.apache.flink.table.types.logical.RowType;
import org.apache.beam.runners.fnexecution.control.JobBundleFactory;
import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Struct;
+import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
@@ -41,16 +45,28 @@ import java.util.Map;
*/
public class PassThroughPythonAggregateFunctionRunner extends BeamTableStatelessPythonFunctionRunner {
+ private static final IntSerializer windowBoundarySerializer = IntSerializer.INSTANCE;
+
private final List<byte[]> buffer;
private final RowDataArrowSerializer arrowSerializer;
/**
+ * Whether it is a batch over-window aggregation.
+ */
+ private final boolean isBatchOverWindow;
+
+ /**
* Reusable InputStream used to hold the execution results to be deserialized.
*/
private transient ByteArrayInputStreamWithPos bais;
/**
+ * InputStream Wrapper.
+ */
+ private transient DataInputViewStreamWrapper baisWrapper;
+
+ /**
* Reusable OutputStream used to hold the serialized input elements.
*/
private transient ByteArrayOutputStreamWithPos baos;
@@ -64,10 +80,12 @@ public class PassThroughPythonAggregateFunctionRunner extends BeamTableStateless
FlinkFnApi.UserDefinedFunctions userDefinedFunctions,
String coderUrn,
Map<String, String> jobOptions,
- FlinkMetricContainer flinkMetricContainer) {
+ FlinkMetricContainer flinkMetricContainer,
+ boolean isBatchOverWindow) {
super(taskName, environmentManager, inputType, outputType, functionUrn, userDefinedFunctions,
coderUrn, jobOptions, flinkMetricContainer);
this.buffer = new LinkedList<>();
+ this.isBatchOverWindow = isBatchOverWindow;
arrowSerializer = new RowDataArrowSerializer(inputType, outputType);
}
@@ -75,6 +93,7 @@ public class PassThroughPythonAggregateFunctionRunner extends BeamTableStateless
public void open(PythonConfig config) throws Exception {
super.open(config);
bais = new ByteArrayInputStreamWithPos();
+ baisWrapper = new DataInputViewStreamWrapper(bais);
baos = new ByteArrayOutputStreamWithPos();
arrowSerializer.open(bais, baos);
}
@@ -85,8 +104,28 @@ public class PassThroughPythonAggregateFunctionRunner extends BeamTableStateless
this.mainInputReceiver = input -> {
byte[] data = input.getValue();
bais.setBuffer(data, 0, data.length);
- arrowSerializer.load();
- arrowSerializer.write(arrowSerializer.read(0));
+ if (isBatchOverWindow) {
+ int windowSize = windowBoundarySerializer.deserialize(baisWrapper);
+ List<Integer> lowerBoundaries = new ArrayList<>();
+ for (int i = 0; i < windowSize; i++) {
+ int windowLength = windowBoundarySerializer.deserialize(baisWrapper);
+ for (int j = 0; j < windowLength; j++) {
+ if (j % 2 == 0) {
+ lowerBoundaries.add(windowBoundarySerializer.deserialize(baisWrapper));
+ } else {
+ windowBoundarySerializer.deserialize(baisWrapper);
+ }
+ }
+ }
+ arrowSerializer.load();
+ for (Integer lowerBoundary : lowerBoundaries) {
+ RowData firstData = arrowSerializer.read(lowerBoundary);
+ arrowSerializer.write(firstData);
+ }
+ } else {
+ arrowSerializer.load();
+ arrowSerializer.write(arrowSerializer.read(0));
+ }
arrowSerializer.finishCurrentBatch();
buffer.add(baos.toByteArray());
baos.reset();