You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by he...@apache.org on 2020/08/16 09:12:03 UTC
[flink] 01/02: [FLINK-18943][python] Support CoMapFunction for
Python DataStream API
This is an automated email from the ASF dual-hosted git repository.
hequn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 175038b705f82655a4a082de64a73f989f6abea4
Author: hequn.chq <he...@alibaba-inc.com>
AuthorDate: Sat Aug 15 12:39:41 2020 +0800
[FLINK-18943][python] Support CoMapFunction for Python DataStream API
---
flink-python/dev/glibc_version_fix.h | 0
flink-python/pyflink/datastream/data_stream.py | 105 ++++++-
flink-python/pyflink/datastream/functions.py | 35 +++
.../pyflink/datastream/tests/test_data_stream.py | 38 +++
.../pyflink/fn_execution/flink_fn_execution_pb2.py | 90 +++---
.../pyflink/fn_execution/operation_utils.py | 6 +
.../pyflink/proto/flink-fn-execution.proto | 2 +
...eamTwoInputPythonStatelessFunctionOperator.java | 196 +++++++++++++
.../python/AbstractPythonFunctionOperator.java | 320 +--------------------
...ava => AbstractPythonFunctionOperatorBase.java} | 18 +-
.../AbstractTwoInputPythonFunctionOperator.java | 44 +++
11 files changed, 482 insertions(+), 372 deletions(-)
diff --git a/flink-python/dev/glibc_version_fix.h b/flink-python/dev/glibc_version_fix.h
old mode 100644
new mode 100755
diff --git a/flink-python/pyflink/datastream/data_stream.py b/flink-python/pyflink/datastream/data_stream.py
index 67b2903..17abb70 100644
--- a/flink-python/pyflink/datastream/data_stream.py
+++ b/flink-python/pyflink/datastream/data_stream.py
@@ -24,7 +24,7 @@ from pyflink.common.typeinfo import TypeInformation
from pyflink.datastream.functions import _get_python_env, FlatMapFunctionWrapper, FlatMapFunction, \
MapFunction, MapFunctionWrapper, Function, FunctionWrapper, SinkFunction, FilterFunction, \
FilterFunctionWrapper, KeySelectorFunctionWrapper, KeySelector, ReduceFunction, \
- ReduceFunctionWrapper
+ ReduceFunctionWrapper, CoMapFunction
from pyflink.java_gateway import get_gateway
@@ -359,6 +359,17 @@ class DataStream(object):
j_united_stream = self._j_data_stream.union(j_data_stream_arr)
return DataStream(j_data_stream=j_united_stream)
+ def connect(self, ds: 'DataStream') -> 'ConnectedStreams':
+ """
+ Creates a new 'ConnectedStreams' by connecting 'DataStream' outputs of (possibly)
+ different types with each other. The DataStreams connected using this operator can
+ be used with CoFunctions to apply joint transformations.
+
+ :param ds: The DataStream with which this stream will be connected.
+ :return: The `ConnectedStreams`.
+ """
+ return ConnectedStreams(self, ds)
+
def shuffle(self) -> 'DataStream':
"""
Sets the partitioning of the DataStream so that the output elements are shuffled uniformly
@@ -687,6 +698,9 @@ class KeyedStream(DataStream):
j_python_data_stream_scalar_function_operator
))
+ def connect(self, ds: 'KeyedStream') -> 'ConnectedStreams':
+ raise Exception('Connect on KeyedStream has not been supported yet.')
+
def filter(self, func: Union[Callable, FilterFunction]) -> 'DataStream':
return self._values().filter(func)
@@ -763,3 +777,92 @@ class KeyedStream(DataStream):
def slot_sharing_group(self, slot_sharing_group: str) -> 'DataStream':
raise Exception("Setting slot sharing group for KeyedStream is not supported.")
+
+
+class ConnectedStreams(object):
+ """
+ ConnectedStreams represent two connected streams of (possibly) different data types.
+ Connected streams are useful for cases where operations on one stream directly
+ affect the operations on the other stream, usually via shared state between the streams.
+
+ An example for the use of connected streams would be to apply rules that change over time
+ onto another stream. One of the connected streams has the rules, the other stream the
+ elements to apply the rules to. The operation on the connected stream maintains the
+ current set of rules in the state. It may receive either a rule update and update the state
+ or a data element and apply the rules in the state to the element.
+
+ The connected stream can be conceptually viewed as a union stream of an Either type, that
+ holds either the first stream's type or the second stream's type.
+ """
+
+ def __init__(self, stream1: DataStream, stream2: DataStream):
+ self.stream1 = stream1
+ self.stream2 = stream2
+
+ def map(self, func: CoMapFunction, type_info: TypeInformation = None) \
+ -> 'DataStream':
+ """
+ Applies a CoMap transformation on a `ConnectedStreams` and maps the output to a common
+ type. The transformation calls a `CoMapFunction.map1` for each element of the first
+ input and `CoMapFunction.map2` for each element of the second input. Each CoMapFunction
+ call returns exactly one element.
+
+ :param func: The CoMapFunction used to jointly transform the two input DataStreams
+ :param type_info: `TypeInformation` for the result type of the function.
+ :return: The transformed `DataStream`
+ """
+ if not isinstance(func, CoMapFunction):
+ raise TypeError("The input function must be a CoMapFunction!")
+ func_name = str(func)
+
+ # get connected stream
+ j_connected_stream = self.stream1._j_data_stream.connect(self.stream2._j_data_stream)
+ from pyflink.fn_execution import flink_fn_execution_pb2
+ j_operator, j_output_type = self._get_connected_stream_operator(
+ func, type_info, func_name, flink_fn_execution_pb2.UserDefinedDataStreamFunction.CO_MAP)
+ return DataStream(j_connected_stream.transform("Co-Process", j_output_type, j_operator))
+
+ def _get_connected_stream_operator(self, func: Union[Function, FunctionWrapper],
+ type_info: TypeInformation, func_name: str,
+ func_type: int):
+ gateway = get_gateway()
+ import cloudpickle
+ serialized_func = cloudpickle.dumps(func)
+
+ j_input_types1 = self.stream1._j_data_stream.getTransformation().getOutputType()
+ j_input_types2 = self.stream2._j_data_stream.getTransformation().getOutputType()
+
+ if type_info is None:
+ output_type_info = PickledBytesTypeInfo.PICKLED_BYTE_ARRAY_TYPE_INFO()
+ else:
+ if isinstance(type_info, list):
+ output_type_info = RowTypeInfo(type_info)
+ else:
+ output_type_info = type_info
+
+ DataStreamPythonFunction = gateway.jvm.org.apache.flink.datastream.runtime.functions \
+ .python.DataStreamPythonFunction
+ j_python_data_stream_scalar_function = DataStreamPythonFunction(
+ func_name,
+ bytearray(serialized_func),
+ _get_python_env())
+
+ DataStreamPythonFunctionInfo = gateway.jvm. \
+ org.apache.flink.datastream.runtime.functions.python \
+ .DataStreamPythonFunctionInfo
+
+ j_python_data_stream_function_info = DataStreamPythonFunctionInfo(
+ j_python_data_stream_scalar_function,
+ func_type)
+
+ j_conf = gateway.jvm.org.apache.flink.configuration.Configuration()
+ DataStreamPythonFunctionOperator = gateway.jvm.org.apache.flink.datastream.runtime \
+ .operators.python.DataStreamTwoInputPythonStatelessFunctionOperator
+ j_python_data_stream_function_operator = DataStreamPythonFunctionOperator(
+ j_conf,
+ j_input_types1,
+ j_input_types2,
+ output_type_info.get_java_type_info(),
+ j_python_data_stream_function_info)
+
+ return j_python_data_stream_function_operator, output_type_info.get_java_type_info()
diff --git a/flink-python/pyflink/datastream/functions.py b/flink-python/pyflink/datastream/functions.py
index c5049e8..10433ee 100644
--- a/flink-python/pyflink/datastream/functions.py
+++ b/flink-python/pyflink/datastream/functions.py
@@ -56,6 +56,41 @@ class MapFunction(Function):
pass
+class CoMapFunction(Function):
+ """
+ A CoMapFunction implements a map() transformation over two connected streams.
+
+ The same instance of the transformation function is used to transform both of
+ the connected streams. That way, the stream transformations can share state.
+
+ The basic syntax for using a CoMapFunction is as follows:
+ ::
+ >>> ds1 = ...
+ >>> ds2 = ...
+ >>> new_ds = ds1.connect(ds2).map(MyCoMapFunction())
+ """
+
+ @abc.abstractmethod
+ def map1(self, value):
+ """
+ This method is called for each element in the first of the connected streams.
+
+ :param value: The stream element
+ :return: The resulting element
+ """
+ pass
+
+ @abc.abstractmethod
+ def map2(self, value):
+ """
+ This method is called for each element in the second of the connected streams.
+
+ :param value: The stream element
+ :return: The resulting element
+ """
+ pass
+
+
class FlatMapFunction(Function):
"""
Base class for flatMap functions. FlatMap functions take elements and transform them into zero,
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py
index 937d1db..d05b2e7 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -22,6 +22,7 @@ from pyflink.datastream import StreamExecutionEnvironment
from pyflink.datastream.functions import FilterFunction
from pyflink.datastream.functions import KeySelector
from pyflink.datastream.functions import MapFunction, FlatMapFunction
+from pyflink.datastream.functions import CoMapFunction
from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction
from pyflink.java_gateway import get_gateway
from pyflink.testing.test_case_utils import PyFlinkTestCase
@@ -105,6 +106,34 @@ class DataStreamTests(PyFlinkTestCase):
results.sort()
self.assertEqual(expected, results)
+ def test_co_map_function_without_data_types(self):
+ self.env.set_parallelism(1)
+ ds1 = self.env.from_collection([(1, 1), (2, 2), (3, 3)],
+ type_info=Types.ROW([Types.INT(), Types.INT()]))
+ ds2 = self.env.from_collection([("a", "a"), ("b", "b"), ("c", "c")],
+ type_info=Types.ROW([Types.STRING(), Types.STRING()]))
+ ds1.connect(ds2).map(MyCoMapFunction()).add_sink(self.test_sink)
+ self.env.execute('co_map_function_test')
+ results = self.test_sink.get_results(True)
+ expected = ['2', '3', '4', 'a', 'b', 'c']
+ expected.sort()
+ results.sort()
+ self.assertEqual(expected, results)
+
+ def test_co_map_function_with_data_types(self):
+ self.env.set_parallelism(1)
+ ds1 = self.env.from_collection([(1, 1), (2, 2), (3, 3)],
+ type_info=Types.ROW([Types.INT(), Types.INT()]))
+ ds2 = self.env.from_collection([("a", "a"), ("b", "b"), ("c", "c")],
+ type_info=Types.ROW([Types.STRING(), Types.STRING()]))
+ ds1.connect(ds2).map(MyCoMapFunction(), type_info=Types.STRING()).add_sink(self.test_sink)
+ self.env.execute('co_map_function_test')
+ results = self.test_sink.get_results(False)
+ expected = ['2', '3', '4', 'a', 'b', 'c']
+ expected.sort()
+ results.sort()
+ self.assertEqual(expected, results)
+
def test_map_function_with_data_types_and_function_object(self):
ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
type_info=Types.ROW([Types.STRING(), Types.INT()]))
@@ -431,3 +460,12 @@ class MyFilterFunction(FilterFunction):
def filter(self, value):
return value[0] % 2 == 0
+
+
+class MyCoMapFunction(CoMapFunction):
+
+ def map1(self, value):
+ return str(value[0] + 1)
+
+ def map2(self, value):
+ return value[0]
diff --git a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
index 6a7945e..f90071d 100644
--- a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
+++ b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
@@ -36,7 +36,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
name='flink-fn-execution.proto',
package='org.apache.flink.fn_execution.v1',
syntax='proto3',
- serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12 org.apache.flink.fn_execution.v1\"\xfc\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12K\n\x06inputs\x18\x02 \x03(\x0b\x32;.org.apache.flink.fn_execution.v1.UserDefinedFunction.Input\x1a\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02 \x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03 \x01(\x0cH\x00\x42\x07\ [...]
+ serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12 org.apache.flink.fn_execution.v1\"\xfc\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12K\n\x06inputs\x18\x02 \x03(\x0b\x32;.org.apache.flink.fn_execution.v1.UserDefinedFunction.Input\x1a\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02 \x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03 \x01(\x0cH\x00\x42\x07\ [...]
)
@@ -59,11 +59,19 @@ _USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE = _descriptor.EnumDescriptor(
name='REDUCE', index=2, number=2,
options=None,
type=None),
+ _descriptor.EnumValueDescriptor(
+ name='CO_MAP', index=3, number=3,
+ options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='CO_FLAT_MAP', index=4, number=4,
+ options=None,
+ type=None),
],
containing_type=None,
options=None,
serialized_start=585,
- serialized_end=634,
+ serialized_end=663,
)
_sym_db.RegisterEnumDescriptor(_USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE)
@@ -160,8 +168,8 @@ _SCHEMA_TYPENAME = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
- serialized_start=2514,
- serialized_end=2797,
+ serialized_start=2543,
+ serialized_end=2826,
)
_sym_db.RegisterEnumDescriptor(_SCHEMA_TYPENAME)
@@ -246,8 +254,8 @@ _TYPEINFO_TYPENAME = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
- serialized_start=3318,
- serialized_end=3549,
+ serialized_start=3347,
+ serialized_end=3578,
)
_sym_db.RegisterEnumDescriptor(_TYPEINFO_TYPENAME)
@@ -410,7 +418,7 @@ _USERDEFINEDDATASTREAMFUNCTION = _descriptor.Descriptor(
oneofs=[
],
serialized_start=435,
- serialized_end=634,
+ serialized_end=663,
)
@@ -447,8 +455,8 @@ _USERDEFINEDDATASTREAMFUNCTIONS = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=637,
- serialized_end=772,
+ serialized_start=666,
+ serialized_end=801,
)
@@ -485,8 +493,8 @@ _SCHEMA_MAPINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=850,
- serialized_end=1001,
+ serialized_start=879,
+ serialized_end=1030,
)
_SCHEMA_TIMEINFO = _descriptor.Descriptor(
@@ -515,8 +523,8 @@ _SCHEMA_TIMEINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1003,
- serialized_end=1032,
+ serialized_start=1032,
+ serialized_end=1061,
)
_SCHEMA_TIMESTAMPINFO = _descriptor.Descriptor(
@@ -545,8 +553,8 @@ _SCHEMA_TIMESTAMPINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1034,
- serialized_end=1068,
+ serialized_start=1063,
+ serialized_end=1097,
)
_SCHEMA_LOCALZONEDTIMESTAMPINFO = _descriptor.Descriptor(
@@ -575,8 +583,8 @@ _SCHEMA_LOCALZONEDTIMESTAMPINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1070,
- serialized_end=1114,
+ serialized_start=1099,
+ serialized_end=1143,
)
_SCHEMA_ZONEDTIMESTAMPINFO = _descriptor.Descriptor(
@@ -605,8 +613,8 @@ _SCHEMA_ZONEDTIMESTAMPINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1116,
- serialized_end=1155,
+ serialized_start=1145,
+ serialized_end=1184,
)
_SCHEMA_DECIMALINFO = _descriptor.Descriptor(
@@ -642,8 +650,8 @@ _SCHEMA_DECIMALINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1157,
- serialized_end=1204,
+ serialized_start=1186,
+ serialized_end=1233,
)
_SCHEMA_BINARYINFO = _descriptor.Descriptor(
@@ -672,8 +680,8 @@ _SCHEMA_BINARYINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1206,
- serialized_end=1234,
+ serialized_start=1235,
+ serialized_end=1263,
)
_SCHEMA_VARBINARYINFO = _descriptor.Descriptor(
@@ -702,8 +710,8 @@ _SCHEMA_VARBINARYINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1236,
- serialized_end=1267,
+ serialized_start=1265,
+ serialized_end=1296,
)
_SCHEMA_CHARINFO = _descriptor.Descriptor(
@@ -732,8 +740,8 @@ _SCHEMA_CHARINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1269,
- serialized_end=1295,
+ serialized_start=1298,
+ serialized_end=1324,
)
_SCHEMA_VARCHARINFO = _descriptor.Descriptor(
@@ -762,8 +770,8 @@ _SCHEMA_VARCHARINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=1297,
- serialized_end=1326,
+ serialized_start=1326,
+ serialized_end=1355,
)
_SCHEMA_FIELDTYPE = _descriptor.Descriptor(
@@ -886,8 +894,8 @@ _SCHEMA_FIELDTYPE = _descriptor.Descriptor(
name='type_info', full_name='org.apache.flink.fn_execution.v1.Schema.FieldType.type_info',
index=0, containing_type=None, fields=[]),
],
- serialized_start=1329,
- serialized_end=2401,
+ serialized_start=1358,
+ serialized_end=2430,
)
_SCHEMA_FIELD = _descriptor.Descriptor(
@@ -930,8 +938,8 @@ _SCHEMA_FIELD = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=2403,
- serialized_end=2511,
+ serialized_start=2432,
+ serialized_end=2540,
)
_SCHEMA = _descriptor.Descriptor(
@@ -961,8 +969,8 @@ _SCHEMA = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=775,
- serialized_end=2797,
+ serialized_start=804,
+ serialized_end=2826,
)
@@ -1016,8 +1024,8 @@ _TYPEINFO_FIELDTYPE = _descriptor.Descriptor(
name='type_info', full_name='org.apache.flink.fn_execution.v1.TypeInfo.FieldType.type_info',
index=0, containing_type=None, fields=[]),
],
- serialized_start=2878,
- serialized_end=3203,
+ serialized_start=2907,
+ serialized_end=3232,
)
_TYPEINFO_FIELD = _descriptor.Descriptor(
@@ -1060,8 +1068,8 @@ _TYPEINFO_FIELD = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=3205,
- serialized_end=3315,
+ serialized_start=3234,
+ serialized_end=3344,
)
_TYPEINFO = _descriptor.Descriptor(
@@ -1091,8 +1099,8 @@ _TYPEINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
- serialized_start=2800,
- serialized_end=3549,
+ serialized_start=2829,
+ serialized_end=3578,
)
_USERDEFINEDFUNCTION_INPUT.fields_by_name['udf'].message_type = _USERDEFINEDFUNCTION
diff --git a/flink-python/pyflink/fn_execution/operation_utils.py b/flink-python/pyflink/fn_execution/operation_utils.py
index 52426ce..c361681 100644
--- a/flink-python/pyflink/fn_execution/operation_utils.py
+++ b/flink-python/pyflink/fn_execution/operation_utils.py
@@ -108,6 +108,12 @@ def extract_data_stream_stateless_funcs(udfs):
def wrap_func(value):
return reduce_func(value[0], value[1])
func = wrap_func
+ elif func_type == udf.CO_MAP:
+ co_map_func = cloudpickle.loads(udfs[0].payload)
+
+ def wrap_func(value):
+ return co_map_func.map1(value[1]) if value[0] else co_map_func.map2(value[2])
+ func = wrap_func
return func
diff --git a/flink-python/pyflink/proto/flink-fn-execution.proto b/flink-python/pyflink/proto/flink-fn-execution.proto
index 9a1e64b..cd903f3 100644
--- a/flink-python/pyflink/proto/flink-fn-execution.proto
+++ b/flink-python/pyflink/proto/flink-fn-execution.proto
@@ -58,6 +58,8 @@ message UserDefinedDataStreamFunction {
MAP = 0;
FLAT_MAP = 1;
REDUCE = 2;
+ CO_MAP = 3;
+ CO_FLAT_MAP = 4;
}
FunctionType functionType = 1;
bytes payload = 2;
diff --git a/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamTwoInputPythonStatelessFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamTwoInputPythonStatelessFunctionOperator.java
new file mode 100644
index 0000000..7383076
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamTwoInputPythonStatelessFunctionOperator.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.datastream.runtime.operators.python;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.datastream.runtime.functions.python.DataStreamPythonFunctionInfo;
+import org.apache.flink.datastream.runtime.runners.python.beam.BeamDataStreamPythonStatelessFunctionRunner;
+import org.apache.flink.datastream.runtime.typeutils.python.PythonTypeUtils;
+import org.apache.flink.fnexecution.v1.FlinkFnApi;
+import org.apache.flink.python.PythonFunctionRunner;
+import org.apache.flink.streaming.api.operators.python.AbstractTwoInputPythonFunctionOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.functions.python.PythonEnv;
+import org.apache.flink.table.runtime.util.StreamRecordCollector;
+import org.apache.flink.types.Row;
+
+import com.google.protobuf.ByteString;
+
+import java.util.Map;
+
+/**
+ * {@link DataStreamTwoInputPythonStatelessFunctionOperator} is responsible for launching beam
+ * runner which will start a python harness to execute two-input user defined python function.
+ */
+public class DataStreamTwoInputPythonStatelessFunctionOperator<IN1, IN2, OUT>
+ extends AbstractTwoInputPythonFunctionOperator<IN1, IN2, OUT> {
+
+ private static final long serialVersionUID = 1L;
+
+ private static final String DATA_STREAM_STATELESS_PYTHON_FUNCTION_URN =
+ "flink:transform:datastream_stateless_function:v1";
+ private static final String DATA_STREAM_MAP_FUNCTION_CODER_URN = "flink:coder:datastream:map_function:v1";
+ private static final String DATA_STREAM_FLAT_MAP_FUNCTION_CODER_URN = "flink:coder:datastream:flatmap_function:v1";
+
+
+ private final DataStreamPythonFunctionInfo pythonFunctionInfo;
+
+ private final TypeInformation<OUT> outputTypeInfo;
+
+ private final Map<String, String> jobOptions;
+ private transient TypeSerializer<OUT> outputTypeSerializer;
+
+ private transient ByteArrayInputStreamWithPos bais;
+
+ private transient DataInputViewStreamWrapper baisWrapper;
+
+ private transient ByteArrayOutputStreamWithPos baos;
+
+ private transient DataOutputViewStreamWrapper baosWrapper;
+
+ private transient StreamRecordCollector streamRecordCollector;
+
+ private transient TypeSerializer<Row> runnerInputTypeSerializer;
+
+ private final TypeInformation<Row> runnerInputTypeInfo;
+
+ private transient Row reuseRow;
+
+ public DataStreamTwoInputPythonStatelessFunctionOperator(
+ Configuration config,
+ TypeInformation<IN1> inputTypeInfo1,
+ TypeInformation<IN2> inputTypeInfo2,
+ TypeInformation<OUT> outputTypeInfo,
+ DataStreamPythonFunctionInfo pythonFunctionInfo) {
+ super(config);
+ this.pythonFunctionInfo = pythonFunctionInfo;
+ jobOptions = config.toMap();
+ this.outputTypeInfo = outputTypeInfo;
+ // The row contains three fields. The first field indicates left input or right input.
+ // The second field contains the left input and the third field contains the right input.
+ runnerInputTypeInfo = new RowTypeInfo(Types.BOOLEAN, inputTypeInfo1, inputTypeInfo2);
+ }
+
+ @Override
+ public void open() throws Exception {
+ super.open();
+ bais = new ByteArrayInputStreamWithPos();
+ baisWrapper = new DataInputViewStreamWrapper(bais);
+
+ baos = new ByteArrayOutputStreamWithPos();
+ baosWrapper = new DataOutputViewStreamWrapper(baos);
+ this.outputTypeSerializer = PythonTypeUtils.TypeInfoToSerializerConverter
+ .typeInfoSerializerConverter(outputTypeInfo);
+ runnerInputTypeSerializer = PythonTypeUtils.TypeInfoToSerializerConverter
+ .typeInfoSerializerConverter(runnerInputTypeInfo);
+
+ reuseRow = new Row(3);
+ this.streamRecordCollector = new StreamRecordCollector(output);
+ }
+
+ @Override
+ public PythonFunctionRunner createPythonFunctionRunner() throws Exception {
+
+ String coderUrn;
+ int functionType = this.pythonFunctionInfo.getFunctionType();
+ if (functionType == FlinkFnApi.UserDefinedDataStreamFunction.FunctionType.CO_MAP.getNumber()) {
+ coderUrn = DATA_STREAM_MAP_FUNCTION_CODER_URN;
+ } else if (functionType == FlinkFnApi.UserDefinedDataStreamFunction.FunctionType.CO_FLAT_MAP.getNumber()) {
+ coderUrn = DATA_STREAM_FLAT_MAP_FUNCTION_CODER_URN;
+ } else {
+ throw new RuntimeException("Function Type for ConnectedStream should be Map or FlatMap");
+ }
+
+ return new BeamDataStreamPythonStatelessFunctionRunner(
+ getRuntimeContext().getTaskName(),
+ createPythonEnvironmentManager(),
+ runnerInputTypeInfo,
+ outputTypeInfo,
+ DATA_STREAM_STATELESS_PYTHON_FUNCTION_URN,
+ getUserDefinedDataStreamFunctionsProto(),
+ coderUrn,
+ jobOptions,
+ getFlinkMetricContainer()
+ );
+ }
+
+ @Override
+ public PythonEnv getPythonEnv() {
+ return pythonFunctionInfo.getPythonFunction().getPythonEnv();
+ }
+
+ @Override
+ public void emitResult(Tuple2<byte[], Integer> resultTuple) throws Exception {
+ byte[] rawResult = resultTuple.f0;
+ int length = resultTuple.f1;
+ bais.setBuffer(rawResult, 0, length);
+ streamRecordCollector.collect(outputTypeSerializer.deserialize(baisWrapper));
+ }
+
+ @Override
+ public void processElement1(StreamRecord<IN1> element) throws Exception {
+ // construct combined row.
+ reuseRow.setField(0, true);
+ reuseRow.setField(1, element.getValue());
+ reuseRow.setField(2, null); // need to set null since it is a reuse row.
+ processElement();
+ }
+
+ @Override
+ public void processElement2(StreamRecord<IN2> element) throws Exception {
+ // construct combined row.
+ reuseRow.setField(0, false);
+ reuseRow.setField(1, null); // need to set null since it is a reuse row.
+ reuseRow.setField(2, element.getValue());
+ processElement();
+ }
+
+ private void processElement() throws Exception {
+ runnerInputTypeSerializer.serialize(reuseRow, baosWrapper);
+ pythonFunctionRunner.process(baos.toByteArray());
+ baos.reset();
+ checkInvokeFinishBundleByCount();
+ emitResults();
+ }
+
+ protected FlinkFnApi.UserDefinedDataStreamFunctions getUserDefinedDataStreamFunctionsProto() {
+ FlinkFnApi.UserDefinedDataStreamFunctions.Builder builder = FlinkFnApi.UserDefinedDataStreamFunctions.newBuilder();
+ builder.addUdfs(getUserDefinedDataStreamFunctionProto(pythonFunctionInfo));
+ return builder.build();
+ }
+
+ private FlinkFnApi.UserDefinedDataStreamFunction getUserDefinedDataStreamFunctionProto(
+ DataStreamPythonFunctionInfo dataStreamPythonFunctionInfo) {
+ FlinkFnApi.UserDefinedDataStreamFunction.Builder builder =
+ FlinkFnApi.UserDefinedDataStreamFunction.newBuilder();
+ builder.setFunctionType(FlinkFnApi.UserDefinedDataStreamFunction.FunctionType.forNumber(
+ dataStreamPythonFunctionInfo.getFunctionType()));
+ builder.setPayload(ByteString.copyFrom(
+ dataStreamPythonFunctionInfo.getPythonFunction().getSerializedPythonFunction()));
+ return builder.build();
+ }
+}
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java
index 7361320..1295b8f 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java
@@ -19,340 +19,26 @@
package org.apache.flink.streaming.api.operators.python;
import org.apache.flink.annotation.Internal;
-import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
-import org.apache.flink.configuration.MemorySize;
-import org.apache.flink.python.PythonConfig;
-import org.apache.flink.python.PythonFunctionRunner;
-import org.apache.flink.python.PythonOptions;
-import org.apache.flink.python.env.PythonDependencyInfo;
-import org.apache.flink.python.env.PythonEnvironmentManager;
-import org.apache.flink.python.env.beam.ProcessPythonEnvironmentManager;
-import org.apache.flink.python.metric.FlinkMetricContainer;
-import org.apache.flink.runtime.memory.MemoryManager;
-import org.apache.flink.runtime.memory.MemoryReservationException;
-import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
-import org.apache.flink.streaming.api.operators.ChainingStrategy;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
-import org.apache.flink.streaming.api.watermark.Watermark;
-import org.apache.flink.table.functions.python.PythonEnv;
-import org.apache.flink.util.Preconditions;
-
-import java.io.IOException;
-import java.util.concurrent.ScheduledFuture;
/**
- * Base class for all stream operators to execute Python functions.
+ * Base class for all one input stream operators to execute Python functions.
*/
@Internal
public abstract class AbstractPythonFunctionOperator<IN, OUT>
- extends AbstractStreamOperator<OUT>
+ extends AbstractPythonFunctionOperatorBase<OUT>
implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
private static final long serialVersionUID = 1L;
- /**
- * The {@link PythonFunctionRunner} which is responsible for Python user-defined function execution.
- */
- protected transient PythonFunctionRunner pythonFunctionRunner;
-
- /**
- * Max number of elements to include in a bundle.
- */
- protected transient int maxBundleSize;
-
- /**
- * Number of processed elements in the current bundle.
- */
- private transient int elementCount;
-
- /**
- * Max duration of a bundle.
- */
- private transient long maxBundleTimeMills;
-
- /**
- * Time that the last bundle was finished.
- */
- private transient long lastFinishBundleTime;
-
- /**
- * A timer that finishes the current bundle after a fixed amount of time.
- */
- private transient ScheduledFuture<?> checkFinishBundleTimer;
-
- /**
- * Callback to be executed after the current bundle was finished.
- */
- private transient Runnable bundleFinishedCallback;
-
- /**
- * The size of the reserved memory from the MemoryManager.
- */
- private transient long reservedMemory;
-
- /**
- * The python config.
- */
- private PythonConfig config;
-
public AbstractPythonFunctionOperator(Configuration config) {
- this.config = new PythonConfig(Preconditions.checkNotNull(config));
- this.chainingStrategy = ChainingStrategy.ALWAYS;
- }
-
- public PythonConfig getPythonConfig() {
- return config;
- }
-
- @Override
- public void open() throws Exception {
- try {
-
- if (config.isUsingManagedMemory()) {
- reserveMemoryForPythonWorker();
- }
-
- this.maxBundleSize = config.getMaxBundleSize();
- if (this.maxBundleSize <= 0) {
- this.maxBundleSize = PythonOptions.MAX_BUNDLE_SIZE.defaultValue();
- LOG.error("Invalid value for the maximum bundle size. Using default value of " +
- this.maxBundleSize + '.');
- } else {
- LOG.info("The maximum bundle size is configured to {}.", this.maxBundleSize);
- }
-
- this.maxBundleTimeMills = config.getMaxBundleTimeMills();
- if (this.maxBundleTimeMills <= 0L) {
- this.maxBundleTimeMills = PythonOptions.MAX_BUNDLE_TIME_MILLS.defaultValue();
- LOG.error("Invalid value for the maximum bundle time. Using default value of " +
- this.maxBundleTimeMills + '.');
- } else {
- LOG.info("The maximum bundle time is configured to {} milliseconds.", this.maxBundleTimeMills);
- }
-
- this.pythonFunctionRunner = createPythonFunctionRunner();
- this.pythonFunctionRunner.open(config);
-
- this.elementCount = 0;
- this.lastFinishBundleTime = getProcessingTimeService().getCurrentProcessingTime();
-
- // Schedule timer to check timeout of finish bundle.
- long bundleCheckPeriod = Math.max(this.maxBundleTimeMills, 1);
- this.checkFinishBundleTimer =
- getProcessingTimeService()
- .scheduleAtFixedRate(
- // ProcessingTimeService callbacks are executed under the checkpointing lock
- timestamp -> checkInvokeFinishBundleByTime(), bundleCheckPeriod, bundleCheckPeriod);
- } finally {
- super.open();
- }
- }
-
- @Override
- public void close() throws Exception {
- try {
- invokeFinishBundle();
- } finally {
- super.close();
- }
- }
-
- @Override
- public void dispose() throws Exception {
- try {
- if (checkFinishBundleTimer != null) {
- checkFinishBundleTimer.cancel(true);
- checkFinishBundleTimer = null;
- }
- if (pythonFunctionRunner != null) {
- pythonFunctionRunner.close();
- pythonFunctionRunner = null;
- }
- if (reservedMemory > 0) {
- getContainingTask().getEnvironment().getMemoryManager().releaseMemory(this, reservedMemory);
- reservedMemory = -1;
- }
- } finally {
- super.dispose();
- }
+ super(config);
}
@Override
public void endInput() throws Exception {
invokeFinishBundle();
}
-
- @Override
- public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
- try {
- invokeFinishBundle();
- } finally {
- super.prepareSnapshotPreBarrier(checkpointId);
- }
- }
-
- @Override
- public void processWatermark(Watermark mark) throws Exception {
- // Due to the asynchronous communication with the SDK harness,
- // a bundle might still be in progress and not all items have
- // yet been received from the SDK harness. If we just set this
- // watermark as the new output watermark, we could violate the
- // order of the records, i.e. pending items in the SDK harness
- // could become "late" although they were "on time".
- //
- // We can solve this problem using one of the following options:
- //
- // 1) Finish the current bundle and emit this watermark as the
- // new output watermark. Finishing the bundle ensures that
- // all the items have been processed by the SDK harness and
- // the execution results sent to the downstream operator.
- //
- // 2) Hold on the output watermark for as long as the current
- // bundle has not been finished. We have to remember to manually
- // finish the bundle in case we receive the final watermark.
- // To avoid latency, we should process this watermark again as
- // soon as the current bundle is finished.
- //
- // Approach 1) is the easiest and gives better latency, yet 2)
- // gives better throughput due to the bundle not getting cut on
- // every watermark. So we have implemented 2) below.
- if (mark.getTimestamp() == Long.MAX_VALUE) {
- invokeFinishBundle();
- super.processWatermark(mark);
- } else if (elementCount == 0) {
- // forward the watermark immediately if the bundle is already finished.
- super.processWatermark(mark);
- } else {
- // It is not safe to advance the output watermark yet, so add a hold on the current
- // output watermark.
- bundleFinishedCallback =
- () -> {
- try {
- // at this point the bundle is finished, allow the watermark to pass
- super.processWatermark(mark);
- } catch (Exception e) {
- throw new RuntimeException(
- "Failed to process watermark after finished bundle.", e);
- }
- };
- }
- }
-
- /**
- * Reset the {@link PythonConfig} if needed.
- * */
- public void setPythonConfig(PythonConfig pythonConfig) {
- this.config = pythonConfig;
- }
-
- /**
- * Returns the {@link PythonConfig}.
- * */
- public PythonConfig getConfig() {
- return config;
- }
-
- /**
- * Creates the {@link PythonFunctionRunner} which is responsible for Python user-defined function execution.
- */
- public abstract PythonFunctionRunner createPythonFunctionRunner() throws Exception;
-
- /**
- * Returns the {@link PythonEnv} used to create PythonEnvironmentManager..
- */
- public abstract PythonEnv getPythonEnv();
-
- /**
- * Sends the execution result to the downstream operator.
- */
- public abstract void emitResult(Tuple2<byte[], Integer> resultTuple) throws Exception;
-
- /**
- * Reserves the memory used by the Python worker from the MemoryManager. This makes sure that
- * the memory used by the Python worker is managed by Flink.
- */
- private void reserveMemoryForPythonWorker() throws MemoryReservationException {
- long requiredPythonWorkerMemory = MemorySize.parse(config.getPythonFrameworkMemorySize())
- .add(MemorySize.parse(config.getPythonDataBufferMemorySize()))
- .getBytes();
- MemoryManager memoryManager = getContainingTask().getEnvironment().getMemoryManager();
- long availableManagedMemory = memoryManager.computeMemorySize(
- getOperatorConfig().getManagedMemoryFraction());
- if (requiredPythonWorkerMemory <= availableManagedMemory) {
- memoryManager.reserveMemory(this, requiredPythonWorkerMemory);
- LOG.info("Reserved memory {} for Python worker.", requiredPythonWorkerMemory);
- this.reservedMemory = requiredPythonWorkerMemory;
- // TODO enforce the memory limit of the Python worker
- } else {
- LOG.warn("Required Python worker memory {} exceeds the available managed off-heap " +
- "memory {}. Skipping reserving off-heap memory from the MemoryManager. This does " +
- "not affect the functionality. However, it may affect the stability of a job as " +
- "the memory used by the Python worker is not managed by Flink.",
- requiredPythonWorkerMemory, availableManagedMemory);
- this.reservedMemory = -1;
- }
- }
-
- protected void emitResults() throws Exception {
- Tuple2<byte[], Integer> resultTuple;
- while ((resultTuple = pythonFunctionRunner.pollResult()) != null) {
- emitResult(resultTuple);
- }
- }
-
- /**
- * Checks whether to invoke finishBundle by elements count. Called in processElement.
- */
- protected void checkInvokeFinishBundleByCount() throws Exception {
- elementCount++;
- if (elementCount >= maxBundleSize) {
- invokeFinishBundle();
- }
- }
-
- /**
- * Checks whether to invoke finishBundle by timeout.
- */
- private void checkInvokeFinishBundleByTime() throws Exception {
- long now = getProcessingTimeService().getCurrentProcessingTime();
- if (now - lastFinishBundleTime >= maxBundleTimeMills) {
- invokeFinishBundle();
- }
- }
-
- protected void invokeFinishBundle() throws Exception {
- if (elementCount > 0) {
- pythonFunctionRunner.flush();
- elementCount = 0;
- emitResults();
- lastFinishBundleTime = getProcessingTimeService().getCurrentProcessingTime();
- // callback only after current bundle was fully finalized
- if (bundleFinishedCallback != null) {
- bundleFinishedCallback.run();
- bundleFinishedCallback = null;
- }
- }
- }
-
- protected PythonEnvironmentManager createPythonEnvironmentManager() throws IOException {
- PythonDependencyInfo dependencyInfo = PythonDependencyInfo.create(
- config, getRuntimeContext().getDistributedCache());
- PythonEnv pythonEnv = getPythonEnv();
- if (pythonEnv.getExecType() == PythonEnv.ExecType.PROCESS) {
- return new ProcessPythonEnvironmentManager(
- dependencyInfo,
- getContainingTask().getEnvironment().getTaskManagerInfo().getTmpDirectories(),
- System.getenv());
- } else {
- throw new UnsupportedOperationException(String.format(
- "Execution type '%s' is not supported.", pythonEnv.getExecType()));
- }
- }
-
- protected FlinkMetricContainer getFlinkMetricContainer() {
- return this.config.isMetricEnabled() ?
- new FlinkMetricContainer(getRuntimeContext().getMetricGroup()) : null;
- }
}
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
similarity index 95%
copy from flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java
copy to flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
index 7361320..280bf0a 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
@@ -32,9 +32,7 @@ import org.apache.flink.python.metric.FlinkMetricContainer;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.runtime.memory.MemoryReservationException;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
-import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.ChainingStrategy;
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.table.functions.python.PythonEnv;
import org.apache.flink.util.Preconditions;
@@ -46,9 +44,8 @@ import java.util.concurrent.ScheduledFuture;
* Base class for all stream operators to execute Python functions.
*/
@Internal
-public abstract class AbstractPythonFunctionOperator<IN, OUT>
- extends AbstractStreamOperator<OUT>
- implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
+public abstract class AbstractPythonFunctionOperatorBase<OUT>
+ extends AbstractStreamOperator<OUT> {
private static final long serialVersionUID = 1L;
@@ -65,7 +62,7 @@ public abstract class AbstractPythonFunctionOperator<IN, OUT>
/**
* Number of processed elements in the current bundle.
*/
- private transient int elementCount;
+ protected transient int elementCount;
/**
* Max duration of a bundle.
@@ -85,7 +82,7 @@ public abstract class AbstractPythonFunctionOperator<IN, OUT>
/**
* Callback to be executed after the current bundle was finished.
*/
- private transient Runnable bundleFinishedCallback;
+ protected transient Runnable bundleFinishedCallback;
/**
* The size of the reserved memory from the MemoryManager.
@@ -97,7 +94,7 @@ public abstract class AbstractPythonFunctionOperator<IN, OUT>
*/
private PythonConfig config;
- public AbstractPythonFunctionOperator(Configuration config) {
+ public AbstractPythonFunctionOperatorBase(Configuration config) {
this.config = new PythonConfig(Preconditions.checkNotNull(config));
this.chainingStrategy = ChainingStrategy.ALWAYS;
}
@@ -180,11 +177,6 @@ public abstract class AbstractPythonFunctionOperator<IN, OUT>
}
@Override
- public void endInput() throws Exception {
- invokeFinishBundle();
- }
-
- @Override
public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
try {
invokeFinishBundle();
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractTwoInputPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractTwoInputPythonFunctionOperator.java
new file mode 100644
index 0000000..ed221c6
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractTwoInputPythonFunctionOperator.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators.python;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+
+/**
+ * Base class for all two input stream operators to execute Python functions.
+ */
+@Internal
+public abstract class AbstractTwoInputPythonFunctionOperator<IN1, IN2, OUT>
+ extends AbstractPythonFunctionOperatorBase<OUT>
+ implements TwoInputStreamOperator<IN1, IN2, OUT>, BoundedMultiInput {
+
+ private static final long serialVersionUID = 1L;
+
+ public AbstractTwoInputPythonFunctionOperator(Configuration config) {
+ super(config);
+ }
+
+ @Override
+ public void endInput(int inputId) throws Exception {
+ invokeFinishBundle();
+ }
+}