Posted to commits@flink.apache.org by he...@apache.org on 2020/08/16 09:12:03 UTC

[flink] 01/02: [FLINK-18943][python] Support CoMapFunction for Python DataStream API

This is an automated email from the ASF dual-hosted git repository.

hequn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 175038b705f82655a4a082de64a73f989f6abea4
Author: hequn.chq <he...@alibaba-inc.com>
AuthorDate: Sat Aug 15 12:39:41 2020 +0800

    [FLINK-18943][python] Support CoMapFunction for Python DataStream API
---
 flink-python/dev/glibc_version_fix.h               |   0
 flink-python/pyflink/datastream/data_stream.py     | 105 ++++++-
 flink-python/pyflink/datastream/functions.py       |  35 +++
 .../pyflink/datastream/tests/test_data_stream.py   |  38 +++
 .../pyflink/fn_execution/flink_fn_execution_pb2.py |  90 +++---
 .../pyflink/fn_execution/operation_utils.py        |   6 +
 .../pyflink/proto/flink-fn-execution.proto         |   2 +
 ...eamTwoInputPythonStatelessFunctionOperator.java | 196 +++++++++++++
 .../python/AbstractPythonFunctionOperator.java     | 320 +--------------------
 ...ava => AbstractPythonFunctionOperatorBase.java} |  18 +-
 .../AbstractTwoInputPythonFunctionOperator.java    |  44 +++
 11 files changed, 482 insertions(+), 372 deletions(-)
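
At a high level, this change lets a PyFlink job connect two DataStreams of
(possibly) different types and transform both with a single CoMapFunction
instance. A minimal sketch of the new API surface follows; the element
values, the string type info and the collection-based sources are chosen
for illustration only, and the sink is left out:

    from pyflink.common.typeinfo import Types
    from pyflink.datastream import StreamExecutionEnvironment
    from pyflink.datastream.functions import CoMapFunction


    class SwapCase(CoMapFunction):

        def map1(self, value):
            # invoked for every element of the first connected stream
            return value.upper()

        def map2(self, value):
            # invoked for every element of the second connected stream
            return value.lower()


    env = StreamExecutionEnvironment.get_execution_environment()
    ds1 = env.from_collection(['a', 'b'], type_info=Types.STRING())
    ds2 = env.from_collection(['C', 'D'], type_info=Types.STRING())
    # the result is an ordinary DataStream; attach a sink before executing
    result = ds1.connect(ds2).map(SwapCase(), type_info=Types.STRING())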

diff --git a/flink-python/dev/glibc_version_fix.h b/flink-python/dev/glibc_version_fix.h
old mode 100644
new mode 100755
diff --git a/flink-python/pyflink/datastream/data_stream.py b/flink-python/pyflink/datastream/data_stream.py
index 67b2903..17abb70 100644
--- a/flink-python/pyflink/datastream/data_stream.py
+++ b/flink-python/pyflink/datastream/data_stream.py
@@ -24,7 +24,7 @@ from pyflink.common.typeinfo import TypeInformation
 from pyflink.datastream.functions import _get_python_env, FlatMapFunctionWrapper, FlatMapFunction, \
     MapFunction, MapFunctionWrapper, Function, FunctionWrapper, SinkFunction, FilterFunction, \
     FilterFunctionWrapper, KeySelectorFunctionWrapper, KeySelector, ReduceFunction, \
-    ReduceFunctionWrapper
+    ReduceFunctionWrapper, CoMapFunction
 from pyflink.java_gateway import get_gateway
 
 
@@ -359,6 +359,17 @@ class DataStream(object):
         j_united_stream = self._j_data_stream.union(j_data_stream_arr)
         return DataStream(j_data_stream=j_united_stream)
 
+    def connect(self, ds: 'DataStream') -> 'ConnectedStreams':
+        """
+        Creates a new 'ConnectedStreams' by connecting 'DataStream' outputs of (possibly)
+        different types with each other. The DataStreams connected using this operator can
+        be used with CoFunctions to apply joint transformations.
+
+        :param ds: The DataStream with which this stream will be connected.
+        :return: The `ConnectedStreams`.
+        """
+        return ConnectedStreams(self, ds)
+
     def shuffle(self) -> 'DataStream':
         """
         Sets the partitioning of the DataStream so that the output elements are shuffled uniformly
@@ -687,6 +698,9 @@ class KeyedStream(DataStream):
             j_python_data_stream_scalar_function_operator
         ))
 
+    def connect(self, ds: 'KeyedStream') -> 'ConnectedStreams':
+        raise Exception('Connect on KeyedStream is not supported yet.')
+
     def filter(self, func: Union[Callable, FilterFunction]) -> 'DataStream':
         return self._values().filter(func)
 
@@ -763,3 +777,92 @@ class KeyedStream(DataStream):
 
     def slot_sharing_group(self, slot_sharing_group: str) -> 'DataStream':
         raise Exception("Setting slot sharing group for KeyedStream is not supported.")
+
+
+class ConnectedStreams(object):
+    """
+    ConnectedStreams represent two connected streams of (possibly) different data types.
+    Connected streams are useful for cases where operations on one stream directly
+    affect the operations on the other stream, usually via shared state between the streams.
+
+    An example for the use of connected streams would be to apply rules that change over time
+    onto another stream. One of the connected streams has the rules, the other stream the
+    elements to apply the rules to. The operation on the connected stream maintains the
+    current set of rules in the state. It may receive either a rule update and update the state
+    or a data element and apply the rules in the state to the element.
+
+    The connected stream can be conceptually viewed as a union stream of an Either type that
+    holds either the first stream's type or the second stream's type.
+    """
+
+    def __init__(self, stream1: DataStream, stream2: DataStream):
+        self.stream1 = stream1
+        self.stream2 = stream2
+
+    def map(self, func: CoMapFunction, type_info: TypeInformation = None) \
+            -> 'DataStream':
+        """
+        Applies a CoMap transformation on a `ConnectedStreams` and maps the output to a common
+        type. The transformation calls a `CoMapFunction.map1` for each element of the first
+        input and `CoMapFunction.map2` for each element of the second input. Each CoMapFunction
+        call returns exactly one element.
+
+        :param func: The CoMapFunction used to jointly transform the two input DataStreams
+        :param type_info: `TypeInformation` for the result type of the function.
+        :return: The transformed `DataStream`
+        """
+        if not isinstance(func, CoMapFunction):
+            raise TypeError("The input function must be a CoMapFunction!")
+        func_name = str(func)
+
+        # get connected stream
+        j_connected_stream = self.stream1._j_data_stream.connect(self.stream2._j_data_stream)
+        from pyflink.fn_execution import flink_fn_execution_pb2
+        j_operator, j_output_type = self._get_connected_stream_operator(
+            func, type_info, func_name, flink_fn_execution_pb2.UserDefinedDataStreamFunction.CO_MAP)
+        return DataStream(j_connected_stream.transform("Co-Map", j_output_type, j_operator))
+
+    def _get_connected_stream_operator(self, func: Union[Function, FunctionWrapper],
+                                       type_info: TypeInformation, func_name: str,
+                                       func_type: int):
+        gateway = get_gateway()
+        import cloudpickle
+        serialized_func = cloudpickle.dumps(func)
+
+        j_input_types1 = self.stream1._j_data_stream.getTransformation().getOutputType()
+        j_input_types2 = self.stream2._j_data_stream.getTransformation().getOutputType()
+
+        if type_info is None:
+            output_type_info = PickledBytesTypeInfo.PICKLED_BYTE_ARRAY_TYPE_INFO()
+        else:
+            if isinstance(type_info, list):
+                output_type_info = RowTypeInfo(type_info)
+            else:
+                output_type_info = type_info
+
+        DataStreamPythonFunction = gateway.jvm.org.apache.flink.datastream.runtime.functions \
+            .python.DataStreamPythonFunction
+        j_python_data_stream_scalar_function = DataStreamPythonFunction(
+            func_name,
+            bytearray(serialized_func),
+            _get_python_env())
+
+        DataStreamPythonFunctionInfo = gateway.jvm. \
+            org.apache.flink.datastream.runtime.functions.python \
+            .DataStreamPythonFunctionInfo
+
+        j_python_data_stream_function_info = DataStreamPythonFunctionInfo(
+            j_python_data_stream_scalar_function,
+            func_type)
+
+        j_conf = gateway.jvm.org.apache.flink.configuration.Configuration()
+        DataStreamPythonFunctionOperator = gateway.jvm.org.apache.flink.datastream.runtime \
+            .operators.python.DataStreamTwoInputPythonStatelessFunctionOperator
+        j_python_data_stream_function_operator = DataStreamPythonFunctionOperator(
+            j_conf,
+            j_input_types1,
+            j_input_types2,
+            output_type_info.get_java_type_info(),
+            j_python_data_stream_function_info)
+
+        return j_python_data_stream_function_operator, output_type_info.get_java_type_info()
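
The ConnectedStreams docstring above sketches a rules-plus-elements pattern;
here is a hedged illustration of that shape with the new map(). The names are
made up, and the instance attribute below is plain Python state held by the
one shared function instance, not Flink-managed (checkpointed) state:

    from pyflink.datastream.functions import CoMapFunction


    class ApplyRules(CoMapFunction):

        def __init__(self):
            self.rules = []  # visible to both map1 and map2

        def map1(self, rule):
            # the first input carries rule updates
            self.rules.append(rule)
            return 'rule added: %s' % rule

        def map2(self, element):
            # the second input carries elements to check against the rules
            return 'match' if element in self.rules else 'no match'


    # assuming rules_stream and data_stream are existing DataStreams
    new_ds = rules_stream.connect(data_stream).map(ApplyRules())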
diff --git a/flink-python/pyflink/datastream/functions.py b/flink-python/pyflink/datastream/functions.py
index c5049e8..10433ee 100644
--- a/flink-python/pyflink/datastream/functions.py
+++ b/flink-python/pyflink/datastream/functions.py
@@ -56,6 +56,41 @@ class MapFunction(Function):
         pass
 
 
+class CoMapFunction(Function):
+    """
+    A CoMapFunction implements a map() transformation over two connected streams.
+
+    The same instance of the transformation function is used to transform both of
+    the connected streams. That way, the stream transformations can share state.
+
+    The basic syntax for using a CoMapFunction is as follows:
+    ::
+        >>> ds1 = ...
+        >>> ds2 = ...
+        >>> new_ds = ds1.connect(ds2).map(MyCoMapFunction())
+    """
+
+    @abc.abstractmethod
+    def map1(self, value):
+        """
+        This method is called for each element in the first of the connected streams.
+
+        :param value: The stream element
+        :return: The resulting element
+        """
+        pass
+
+    @abc.abstractmethod
+    def map2(self, value):
+        """
+        This method is called for each element in the second of the connected streams.
+
+        :param value: The stream element
+        :return: The resulting element
+        """
+        pass
+
+
 class FlatMapFunction(Function):
     """
     Base class for flatMap functions. FlatMap functions take elements and transform them into zero,
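
Since map1 and map2 are declared with abc.abstractmethod (and assuming
Function derives from abc.ABC, as elsewhere in this module), a subclass that
implements only one of them fails at instantiation time rather than deep
inside the Beam harness. The class name below is hypothetical:

    from pyflink.datastream.functions import CoMapFunction


    class HalfMapper(CoMapFunction):

        def map1(self, value):
            return value


    HalfMapper()  # TypeError: Can't instantiate abstract class HalfMapper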
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py
index 937d1db..d05b2e7 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -22,6 +22,7 @@ from pyflink.datastream import StreamExecutionEnvironment
 from pyflink.datastream.functions import FilterFunction
 from pyflink.datastream.functions import KeySelector
 from pyflink.datastream.functions import MapFunction, FlatMapFunction
+from pyflink.datastream.functions import CoMapFunction
 from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction
 from pyflink.java_gateway import get_gateway
 from pyflink.testing.test_case_utils import PyFlinkTestCase
@@ -105,6 +106,34 @@ class DataStreamTests(PyFlinkTestCase):
         results.sort()
         self.assertEqual(expected, results)
 
+    def test_co_map_function_without_data_types(self):
+        self.env.set_parallelism(1)
+        ds1 = self.env.from_collection([(1, 1), (2, 2), (3, 3)],
+                                       type_info=Types.ROW([Types.INT(), Types.INT()]))
+        ds2 = self.env.from_collection([("a", "a"), ("b", "b"), ("c", "c")],
+                                       type_info=Types.ROW([Types.STRING(), Types.STRING()]))
+        ds1.connect(ds2).map(MyCoMapFunction()).add_sink(self.test_sink)
+        self.env.execute('co_map_function_test')
+        results = self.test_sink.get_results(True)
+        expected = ['2', '3', '4', 'a', 'b', 'c']
+        expected.sort()
+        results.sort()
+        self.assertEqual(expected, results)
+
+    def test_co_map_function_with_data_types(self):
+        self.env.set_parallelism(1)
+        ds1 = self.env.from_collection([(1, 1), (2, 2), (3, 3)],
+                                       type_info=Types.ROW([Types.INT(), Types.INT()]))
+        ds2 = self.env.from_collection([("a", "a"), ("b", "b"), ("c", "c")],
+                                       type_info=Types.ROW([Types.STRING(), Types.STRING()]))
+        ds1.connect(ds2).map(MyCoMapFunction(), type_info=Types.STRING()).add_sink(self.test_sink)
+        self.env.execute('co_map_function_test')
+        results = self.test_sink.get_results(False)
+        expected = ['2', '3', '4', 'a', 'b', 'c']
+        expected.sort()
+        results.sort()
+        self.assertEqual(expected, results)
+
     def test_map_function_with_data_types_and_function_object(self):
         ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
                                       type_info=Types.ROW([Types.STRING(), Types.INT()]))
@@ -431,3 +460,12 @@ class MyFilterFunction(FilterFunction):
 
     def filter(self, value):
         return value[0] % 2 == 0
+
+
+class MyCoMapFunction(CoMapFunction):
+
+    def map1(self, value):
+        return str(value[0] + 1)
+
+    def map2(self, value):
+        return value[0]
diff --git a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
index 6a7945e..f90071d 100644
--- a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
+++ b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py
@@ -36,7 +36,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
   name='flink-fn-execution.proto',
   package='org.apache.flink.fn_execution.v1',
   syntax='proto3',
-  serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12 org.apache.flink.fn_execution.v1\"\xfc\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12K\n\x06inputs\x18\x02 \x03(\x0b\x32;.org.apache.flink.fn_execution.v1.UserDefinedFunction.Input\x1a\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02 \x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03 \x01(\x0cH\x00\x42\x07\ [...]
+  serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12 org.apache.flink.fn_execution.v1\"\xfc\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12K\n\x06inputs\x18\x02 \x03(\x0b\x32;.org.apache.flink.fn_execution.v1.UserDefinedFunction.Input\x1a\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02 \x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03 \x01(\x0cH\x00\x42\x07\ [...]
 )
 
 
@@ -59,11 +59,19 @@ _USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE = _descriptor.EnumDescriptor(
       name='REDUCE', index=2, number=2,
       options=None,
       type=None),
+    _descriptor.EnumValueDescriptor(
+      name='CO_MAP', index=3, number=3,
+      options=None,
+      type=None),
+    _descriptor.EnumValueDescriptor(
+      name='CO_FLAT_MAP', index=4, number=4,
+      options=None,
+      type=None),
   ],
   containing_type=None,
   options=None,
   serialized_start=585,
-  serialized_end=634,
+  serialized_end=663,
 )
 _sym_db.RegisterEnumDescriptor(_USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE)
 
@@ -160,8 +168,8 @@ _SCHEMA_TYPENAME = _descriptor.EnumDescriptor(
   ],
   containing_type=None,
   options=None,
-  serialized_start=2514,
-  serialized_end=2797,
+  serialized_start=2543,
+  serialized_end=2826,
 )
 _sym_db.RegisterEnumDescriptor(_SCHEMA_TYPENAME)
 
@@ -246,8 +254,8 @@ _TYPEINFO_TYPENAME = _descriptor.EnumDescriptor(
   ],
   containing_type=None,
   options=None,
-  serialized_start=3318,
-  serialized_end=3549,
+  serialized_start=3347,
+  serialized_end=3578,
 )
 _sym_db.RegisterEnumDescriptor(_TYPEINFO_TYPENAME)
 
@@ -410,7 +418,7 @@ _USERDEFINEDDATASTREAMFUNCTION = _descriptor.Descriptor(
   oneofs=[
   ],
   serialized_start=435,
-  serialized_end=634,
+  serialized_end=663,
 )
 
 
@@ -447,8 +455,8 @@ _USERDEFINEDDATASTREAMFUNCTIONS = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=637,
-  serialized_end=772,
+  serialized_start=666,
+  serialized_end=801,
 )
 
 
@@ -485,8 +493,8 @@ _SCHEMA_MAPINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=850,
-  serialized_end=1001,
+  serialized_start=879,
+  serialized_end=1030,
 )
 
 _SCHEMA_TIMEINFO = _descriptor.Descriptor(
@@ -515,8 +523,8 @@ _SCHEMA_TIMEINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1003,
-  serialized_end=1032,
+  serialized_start=1032,
+  serialized_end=1061,
 )
 
 _SCHEMA_TIMESTAMPINFO = _descriptor.Descriptor(
@@ -545,8 +553,8 @@ _SCHEMA_TIMESTAMPINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1034,
-  serialized_end=1068,
+  serialized_start=1063,
+  serialized_end=1097,
 )
 
 _SCHEMA_LOCALZONEDTIMESTAMPINFO = _descriptor.Descriptor(
@@ -575,8 +583,8 @@ _SCHEMA_LOCALZONEDTIMESTAMPINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1070,
-  serialized_end=1114,
+  serialized_start=1099,
+  serialized_end=1143,
 )
 
 _SCHEMA_ZONEDTIMESTAMPINFO = _descriptor.Descriptor(
@@ -605,8 +613,8 @@ _SCHEMA_ZONEDTIMESTAMPINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1116,
-  serialized_end=1155,
+  serialized_start=1145,
+  serialized_end=1184,
 )
 
 _SCHEMA_DECIMALINFO = _descriptor.Descriptor(
@@ -642,8 +650,8 @@ _SCHEMA_DECIMALINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1157,
-  serialized_end=1204,
+  serialized_start=1186,
+  serialized_end=1233,
 )
 
 _SCHEMA_BINARYINFO = _descriptor.Descriptor(
@@ -672,8 +680,8 @@ _SCHEMA_BINARYINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1206,
-  serialized_end=1234,
+  serialized_start=1235,
+  serialized_end=1263,
 )
 
 _SCHEMA_VARBINARYINFO = _descriptor.Descriptor(
@@ -702,8 +710,8 @@ _SCHEMA_VARBINARYINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1236,
-  serialized_end=1267,
+  serialized_start=1265,
+  serialized_end=1296,
 )
 
 _SCHEMA_CHARINFO = _descriptor.Descriptor(
@@ -732,8 +740,8 @@ _SCHEMA_CHARINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1269,
-  serialized_end=1295,
+  serialized_start=1298,
+  serialized_end=1324,
 )
 
 _SCHEMA_VARCHARINFO = _descriptor.Descriptor(
@@ -762,8 +770,8 @@ _SCHEMA_VARCHARINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=1297,
-  serialized_end=1326,
+  serialized_start=1326,
+  serialized_end=1355,
 )
 
 _SCHEMA_FIELDTYPE = _descriptor.Descriptor(
@@ -886,8 +894,8 @@ _SCHEMA_FIELDTYPE = _descriptor.Descriptor(
       name='type_info', full_name='org.apache.flink.fn_execution.v1.Schema.FieldType.type_info',
       index=0, containing_type=None, fields=[]),
   ],
-  serialized_start=1329,
-  serialized_end=2401,
+  serialized_start=1358,
+  serialized_end=2430,
 )
 
 _SCHEMA_FIELD = _descriptor.Descriptor(
@@ -930,8 +938,8 @@ _SCHEMA_FIELD = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=2403,
-  serialized_end=2511,
+  serialized_start=2432,
+  serialized_end=2540,
 )
 
 _SCHEMA = _descriptor.Descriptor(
@@ -961,8 +969,8 @@ _SCHEMA = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=775,
-  serialized_end=2797,
+  serialized_start=804,
+  serialized_end=2826,
 )
 
 
@@ -1016,8 +1024,8 @@ _TYPEINFO_FIELDTYPE = _descriptor.Descriptor(
       name='type_info', full_name='org.apache.flink.fn_execution.v1.TypeInfo.FieldType.type_info',
       index=0, containing_type=None, fields=[]),
   ],
-  serialized_start=2878,
-  serialized_end=3203,
+  serialized_start=2907,
+  serialized_end=3232,
 )
 
 _TYPEINFO_FIELD = _descriptor.Descriptor(
@@ -1060,8 +1068,8 @@ _TYPEINFO_FIELD = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=3205,
-  serialized_end=3315,
+  serialized_start=3234,
+  serialized_end=3344,
 )
 
 _TYPEINFO = _descriptor.Descriptor(
@@ -1091,8 +1099,8 @@ _TYPEINFO = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=2800,
-  serialized_end=3549,
+  serialized_start=2829,
+  serialized_end=3578,
 )
 
 _USERDEFINEDFUNCTION_INPUT.fields_by_name['udf'].message_type = _USERDEFINEDFUNCTION
diff --git a/flink-python/pyflink/fn_execution/operation_utils.py b/flink-python/pyflink/fn_execution/operation_utils.py
index 52426ce..c361681 100644
--- a/flink-python/pyflink/fn_execution/operation_utils.py
+++ b/flink-python/pyflink/fn_execution/operation_utils.py
@@ -108,6 +108,12 @@ def extract_data_stream_stateless_funcs(udfs):
         def wrap_func(value):
             return reduce_func(value[0], value[1])
         func = wrap_func
+    elif func_type == udf.CO_MAP:
+        co_map_func = cloudpickle.loads(udfs[0].payload)
+
+        def wrap_func(value):
+            return co_map_func.map1(value[1]) if value[0] else co_map_func.map2(value[2])
+        func = wrap_func
     return func
 
 
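
The CO_MAP branch above relies on the wire format produced by the new Java
operator (DataStreamTwoInputPythonStatelessFunctionOperator, below): each
runner input is a three-field row (is_first_input, first_value, second_value)
where the unused value is None. A standalone sketch of the dispatch, with a
toy object standing in for the unpickled payload:

    class ToyCoMap(object):

        def map1(self, value):
            return value + 1

        def map2(self, value):
            return value.upper()


    co_map_func = ToyCoMap()

    def wrap_func(value):
        # value[0] is True for the first input, False for the second
        return co_map_func.map1(value[1]) if value[0] else co_map_func.map2(value[2])

    assert wrap_func((True, 1, None)) == 2
    assert wrap_func((False, None, 'a')) == 'A'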
diff --git a/flink-python/pyflink/proto/flink-fn-execution.proto b/flink-python/pyflink/proto/flink-fn-execution.proto
index 9a1e64b..cd903f3 100644
--- a/flink-python/pyflink/proto/flink-fn-execution.proto
+++ b/flink-python/pyflink/proto/flink-fn-execution.proto
@@ -58,6 +58,8 @@ message UserDefinedDataStreamFunction {
     MAP = 0;
     FLAT_MAP = 1;
     REDUCE = 2;
+    CO_MAP = 3;
+    CO_FLAT_MAP = 4;
   }
   FunctionType functionType = 1;
   bytes payload = 2;
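
On the Python side, the regenerated flink_fn_execution_pb2 module exposes the
two new enum values. A minimal sketch (the payload bytes are elided here; in
practice they carry the cloudpickle-serialized function):

    from pyflink.fn_execution import flink_fn_execution_pb2

    udf = flink_fn_execution_pb2.UserDefinedDataStreamFunction()
    udf.functionType = \
        flink_fn_execution_pb2.UserDefinedDataStreamFunction.CO_MAP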
diff --git a/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamTwoInputPythonStatelessFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamTwoInputPythonStatelessFunctionOperator.java
new file mode 100644
index 0000000..7383076
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/datastream/runtime/operators/python/DataStreamTwoInputPythonStatelessFunctionOperator.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.datastream.runtime.operators.python;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.datastream.runtime.functions.python.DataStreamPythonFunctionInfo;
+import org.apache.flink.datastream.runtime.runners.python.beam.BeamDataStreamPythonStatelessFunctionRunner;
+import org.apache.flink.datastream.runtime.typeutils.python.PythonTypeUtils;
+import org.apache.flink.fnexecution.v1.FlinkFnApi;
+import org.apache.flink.python.PythonFunctionRunner;
+import org.apache.flink.streaming.api.operators.python.AbstractTwoInputPythonFunctionOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.functions.python.PythonEnv;
+import org.apache.flink.table.runtime.util.StreamRecordCollector;
+import org.apache.flink.types.Row;
+
+import com.google.protobuf.ByteString;
+
+import java.util.Map;
+
+/**
+ * {@link DataStreamTwoInputPythonStatelessFunctionOperator} is responsible for launching the Beam
+ * runner, which starts a Python harness to execute the two-input user-defined Python function.
+ */
+public class DataStreamTwoInputPythonStatelessFunctionOperator<IN1, IN2, OUT>
+	extends AbstractTwoInputPythonFunctionOperator<IN1, IN2, OUT> {
+
+	private static final long serialVersionUID = 1L;
+
+	private static final String DATA_STREAM_STATELESS_PYTHON_FUNCTION_URN =
+		"flink:transform:datastream_stateless_function:v1";
+	private static final String DATA_STREAM_MAP_FUNCTION_CODER_URN = "flink:coder:datastream:map_function:v1";
+	private static final String DATA_STREAM_FLAT_MAP_FUNCTION_CODER_URN = "flink:coder:datastream:flatmap_function:v1";
+
+
+	private final DataStreamPythonFunctionInfo pythonFunctionInfo;
+
+	private final TypeInformation<OUT> outputTypeInfo;
+
+	private final Map<String, String> jobOptions;
+	private transient TypeSerializer<OUT> outputTypeSerializer;
+
+	private transient ByteArrayInputStreamWithPos bais;
+
+	private transient DataInputViewStreamWrapper baisWrapper;
+
+	private transient ByteArrayOutputStreamWithPos baos;
+
+	private transient DataOutputViewStreamWrapper baosWrapper;
+
+	private transient StreamRecordCollector streamRecordCollector;
+
+	private transient TypeSerializer<Row> runnerInputTypeSerializer;
+
+	private final TypeInformation<Row> runnerInputTypeInfo;
+
+	private transient Row reuseRow;
+
+	public DataStreamTwoInputPythonStatelessFunctionOperator(
+		Configuration config,
+		TypeInformation<IN1> inputTypeInfo1,
+		TypeInformation<IN2> inputTypeInfo2,
+		TypeInformation<OUT> outputTypeInfo,
+		DataStreamPythonFunctionInfo pythonFunctionInfo) {
+		super(config);
+		this.pythonFunctionInfo = pythonFunctionInfo;
+		jobOptions = config.toMap();
+		this.outputTypeInfo = outputTypeInfo;
+		// The row contains three fields. The first field indicates whether the record comes from
+		// the left input or the right input; the second field holds the left input and the third
+		// field holds the right input.
+		runnerInputTypeInfo = new RowTypeInfo(Types.BOOLEAN, inputTypeInfo1, inputTypeInfo2);
+	}
+
+	@Override
+	public void open() throws Exception {
+		super.open();
+		bais = new ByteArrayInputStreamWithPos();
+		baisWrapper = new DataInputViewStreamWrapper(bais);
+
+		baos = new ByteArrayOutputStreamWithPos();
+		baosWrapper = new DataOutputViewStreamWrapper(baos);
+		this.outputTypeSerializer = PythonTypeUtils.TypeInfoToSerializerConverter
+			.typeInfoSerializerConverter(outputTypeInfo);
+		runnerInputTypeSerializer = PythonTypeUtils.TypeInfoToSerializerConverter
+			.typeInfoSerializerConverter(runnerInputTypeInfo);
+
+		reuseRow = new Row(3);
+		this.streamRecordCollector = new StreamRecordCollector(output);
+	}
+
+	@Override
+	public PythonFunctionRunner createPythonFunctionRunner() throws Exception {
+
+		String coderUrn;
+		int functionType = this.pythonFunctionInfo.getFunctionType();
+		if (functionType == FlinkFnApi.UserDefinedDataStreamFunction.FunctionType.CO_MAP.getNumber()) {
+			coderUrn = DATA_STREAM_MAP_FUNCTION_CODER_URN;
+		} else if (functionType == FlinkFnApi.UserDefinedDataStreamFunction.FunctionType.CO_FLAT_MAP.getNumber()) {
+			coderUrn = DATA_STREAM_FLAT_MAP_FUNCTION_CODER_URN;
+		} else {
+			throw new RuntimeException("Function type for ConnectedStreams should be CO_MAP or CO_FLAT_MAP");
+		}
+
+		return new BeamDataStreamPythonStatelessFunctionRunner(
+			getRuntimeContext().getTaskName(),
+			createPythonEnvironmentManager(),
+			runnerInputTypeInfo,
+			outputTypeInfo,
+			DATA_STREAM_STATELESS_PYTHON_FUNCTION_URN,
+			getUserDefinedDataStreamFunctionsProto(),
+			coderUrn,
+			jobOptions,
+			getFlinkMetricContainer()
+		);
+	}
+
+	@Override
+	public PythonEnv getPythonEnv() {
+		return pythonFunctionInfo.getPythonFunction().getPythonEnv();
+	}
+
+	@Override
+	public void emitResult(Tuple2<byte[], Integer> resultTuple) throws Exception {
+		byte[] rawResult = resultTuple.f0;
+		int length = resultTuple.f1;
+		bais.setBuffer(rawResult, 0, length);
+		streamRecordCollector.collect(outputTypeSerializer.deserialize(baisWrapper));
+	}
+
+	@Override
+	public void processElement1(StreamRecord<IN1> element) throws Exception {
+		// construct combined row.
+		reuseRow.setField(0, true);
+		reuseRow.setField(1, element.getValue());
+		reuseRow.setField(2, null);  // need to set null since it is a reuse row.
+		processElement();
+	}
+
+	@Override
+	public void processElement2(StreamRecord<IN2> element) throws Exception {
+		// construct combined row.
+		reuseRow.setField(0, false);
+		reuseRow.setField(1, null); // need to set null since it is a reuse row.
+		reuseRow.setField(2, element.getValue());
+		processElement();
+	}
+
+	private void processElement() throws Exception {
+		runnerInputTypeSerializer.serialize(reuseRow, baosWrapper);
+		pythonFunctionRunner.process(baos.toByteArray());
+		baos.reset();
+		checkInvokeFinishBundleByCount();
+		emitResults();
+	}
+
+	protected FlinkFnApi.UserDefinedDataStreamFunctions getUserDefinedDataStreamFunctionsProto() {
+		FlinkFnApi.UserDefinedDataStreamFunctions.Builder builder = FlinkFnApi.UserDefinedDataStreamFunctions.newBuilder();
+		builder.addUdfs(getUserDefinedDataStreamFunctionProto(pythonFunctionInfo));
+		return builder.build();
+	}
+
+	private FlinkFnApi.UserDefinedDataStreamFunction getUserDefinedDataStreamFunctionProto(
+		DataStreamPythonFunctionInfo dataStreamPythonFunctionInfo) {
+		FlinkFnApi.UserDefinedDataStreamFunction.Builder builder =
+			FlinkFnApi.UserDefinedDataStreamFunction.newBuilder();
+		builder.setFunctionType(FlinkFnApi.UserDefinedDataStreamFunction.FunctionType.forNumber(
+			dataStreamPythonFunctionInfo.getFunctionType()));
+		builder.setPayload(ByteString.copyFrom(
+			dataStreamPythonFunctionInfo.getPythonFunction().getSerializedPythonFunction()));
+		return builder.build();
+	}
+}
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java
index 7361320..1295b8f 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java
@@ -19,340 +19,26 @@
 package org.apache.flink.streaming.api.operators.python;
 
 import org.apache.flink.annotation.Internal;
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.configuration.MemorySize;
-import org.apache.flink.python.PythonConfig;
-import org.apache.flink.python.PythonFunctionRunner;
-import org.apache.flink.python.PythonOptions;
-import org.apache.flink.python.env.PythonDependencyInfo;
-import org.apache.flink.python.env.PythonEnvironmentManager;
-import org.apache.flink.python.env.beam.ProcessPythonEnvironmentManager;
-import org.apache.flink.python.metric.FlinkMetricContainer;
-import org.apache.flink.runtime.memory.MemoryManager;
-import org.apache.flink.runtime.memory.MemoryReservationException;
-import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.BoundedOneInput;
-import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
-import org.apache.flink.streaming.api.watermark.Watermark;
-import org.apache.flink.table.functions.python.PythonEnv;
-import org.apache.flink.util.Preconditions;
-
-import java.io.IOException;
-import java.util.concurrent.ScheduledFuture;
 
 /**
- * Base class for all stream operators to execute Python functions.
+ * Base class for all one input stream operators to execute Python functions.
  */
 @Internal
 public abstract class AbstractPythonFunctionOperator<IN, OUT>
-	extends AbstractStreamOperator<OUT>
+	extends AbstractPythonFunctionOperatorBase<OUT>
 	implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
 
 	private static final long serialVersionUID = 1L;
 
-	/**
-	 * The {@link PythonFunctionRunner} which is responsible for Python user-defined function execution.
-	 */
-	protected transient PythonFunctionRunner pythonFunctionRunner;
-
-	/**
-	 * Max number of elements to include in a bundle.
-	 */
-	protected transient int maxBundleSize;
-
-	/**
-	 * Number of processed elements in the current bundle.
-	 */
-	private transient int elementCount;
-
-	/**
-	 * Max duration of a bundle.
-	 */
-	private transient long maxBundleTimeMills;
-
-	/**
-	 * Time that the last bundle was finished.
-	 */
-	private transient long lastFinishBundleTime;
-
-	/**
-	 * A timer that finishes the current bundle after a fixed amount of time.
-	 */
-	private transient ScheduledFuture<?> checkFinishBundleTimer;
-
-	/**
-	 * Callback to be executed after the current bundle was finished.
-	 */
-	private transient Runnable bundleFinishedCallback;
-
-	/**
-	 * The size of the reserved memory from the MemoryManager.
-	 */
-	private transient long reservedMemory;
-
-	/**
-	 * The python config.
-	 */
-	private PythonConfig config;
-
 	public AbstractPythonFunctionOperator(Configuration config) {
-		this.config = new PythonConfig(Preconditions.checkNotNull(config));
-		this.chainingStrategy = ChainingStrategy.ALWAYS;
-	}
-
-	public PythonConfig getPythonConfig() {
-		return config;
-	}
-
-	@Override
-	public void open() throws Exception {
-		try {
-
-			if (config.isUsingManagedMemory()) {
-				reserveMemoryForPythonWorker();
-			}
-
-			this.maxBundleSize = config.getMaxBundleSize();
-			if (this.maxBundleSize <= 0) {
-				this.maxBundleSize = PythonOptions.MAX_BUNDLE_SIZE.defaultValue();
-				LOG.error("Invalid value for the maximum bundle size. Using default value of " +
-					this.maxBundleSize + '.');
-			} else {
-				LOG.info("The maximum bundle size is configured to {}.", this.maxBundleSize);
-			}
-
-			this.maxBundleTimeMills = config.getMaxBundleTimeMills();
-			if (this.maxBundleTimeMills <= 0L) {
-				this.maxBundleTimeMills = PythonOptions.MAX_BUNDLE_TIME_MILLS.defaultValue();
-				LOG.error("Invalid value for the maximum bundle time. Using default value of " +
-					this.maxBundleTimeMills + '.');
-			} else {
-				LOG.info("The maximum bundle time is configured to {} milliseconds.", this.maxBundleTimeMills);
-			}
-
-			this.pythonFunctionRunner = createPythonFunctionRunner();
-			this.pythonFunctionRunner.open(config);
-
-			this.elementCount = 0;
-			this.lastFinishBundleTime = getProcessingTimeService().getCurrentProcessingTime();
-
-			// Schedule timer to check timeout of finish bundle.
-			long bundleCheckPeriod = Math.max(this.maxBundleTimeMills, 1);
-			this.checkFinishBundleTimer =
-				getProcessingTimeService()
-					.scheduleAtFixedRate(
-						// ProcessingTimeService callbacks are executed under the checkpointing lock
-						timestamp -> checkInvokeFinishBundleByTime(), bundleCheckPeriod, bundleCheckPeriod);
-		} finally {
-			super.open();
-		}
-	}
-
-	@Override
-	public void close() throws Exception {
-		try {
-			invokeFinishBundle();
-		} finally {
-			super.close();
-		}
-	}
-
-	@Override
-	public void dispose() throws Exception {
-		try {
-			if (checkFinishBundleTimer != null) {
-				checkFinishBundleTimer.cancel(true);
-				checkFinishBundleTimer = null;
-			}
-			if (pythonFunctionRunner != null) {
-				pythonFunctionRunner.close();
-				pythonFunctionRunner = null;
-			}
-			if (reservedMemory > 0) {
-				getContainingTask().getEnvironment().getMemoryManager().releaseMemory(this, reservedMemory);
-				reservedMemory = -1;
-			}
-		} finally {
-			super.dispose();
-		}
+		super(config);
 	}
 
 	@Override
 	public void endInput() throws Exception {
 		invokeFinishBundle();
 	}
-
-	@Override
-	public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
-		try {
-			invokeFinishBundle();
-		} finally {
-			super.prepareSnapshotPreBarrier(checkpointId);
-		}
-	}
-
-	@Override
-	public void processWatermark(Watermark mark) throws Exception {
-		// Due to the asynchronous communication with the SDK harness,
-		// a bundle might still be in progress and not all items have
-		// yet been received from the SDK harness. If we just set this
-		// watermark as the new output watermark, we could violate the
-		// order of the records, i.e. pending items in the SDK harness
-		// could become "late" although they were "on time".
-		//
-		// We can solve this problem using one of the following options:
-		//
-		// 1) Finish the current bundle and emit this watermark as the
-		//    new output watermark. Finishing the bundle ensures that
-		//    all the items have been processed by the SDK harness and
-		//    the execution results sent to the downstream operator.
-		//
-		// 2) Hold on the output watermark for as long as the current
-		//    bundle has not been finished. We have to remember to manually
-		//    finish the bundle in case we receive the final watermark.
-		//    To avoid latency, we should process this watermark again as
-		//    soon as the current bundle is finished.
-		//
-		// Approach 1) is the easiest and gives better latency, yet 2)
-		// gives better throughput due to the bundle not getting cut on
-		// every watermark. So we have implemented 2) below.
-		if (mark.getTimestamp() == Long.MAX_VALUE) {
-			invokeFinishBundle();
-			super.processWatermark(mark);
-		} else if (elementCount == 0) {
-			// forward the watermark immediately if the bundle is already finished.
-			super.processWatermark(mark);
-		} else {
-			// It is not safe to advance the output watermark yet, so add a hold on the current
-			// output watermark.
-			bundleFinishedCallback =
-				() -> {
-					try {
-						// at this point the bundle is finished, allow the watermark to pass
-						super.processWatermark(mark);
-					} catch (Exception e) {
-						throw new RuntimeException(
-							"Failed to process watermark after finished bundle.", e);
-					}
-				};
-		}
-	}
-
-	/**
-	 * Reset the {@link PythonConfig} if needed.
-	 * */
-	public void setPythonConfig(PythonConfig pythonConfig) {
-		this.config = pythonConfig;
-	}
-
-	/**
-	 * Returns the {@link PythonConfig}.
-	 * */
-	public PythonConfig getConfig() {
-		return config;
-	}
-
-	/**
-	 * Creates the {@link PythonFunctionRunner} which is responsible for Python user-defined function execution.
-	 */
-	public abstract PythonFunctionRunner createPythonFunctionRunner() throws Exception;
-
-	/**
-	 * Returns the {@link PythonEnv} used to create PythonEnvironmentManager..
-	 */
-	public abstract PythonEnv getPythonEnv();
-
-	/**
-	 * Sends the execution result to the downstream operator.
-	 */
-	public abstract void emitResult(Tuple2<byte[], Integer> resultTuple) throws Exception;
-
-	/**
-	 * Reserves the memory used by the Python worker from the MemoryManager. This makes sure that
-	 * the memory used by the Python worker is managed by Flink.
-	 */
-	private void reserveMemoryForPythonWorker() throws MemoryReservationException {
-		long requiredPythonWorkerMemory = MemorySize.parse(config.getPythonFrameworkMemorySize())
-			.add(MemorySize.parse(config.getPythonDataBufferMemorySize()))
-			.getBytes();
-		MemoryManager memoryManager = getContainingTask().getEnvironment().getMemoryManager();
-		long availableManagedMemory = memoryManager.computeMemorySize(
-			getOperatorConfig().getManagedMemoryFraction());
-		if (requiredPythonWorkerMemory <= availableManagedMemory) {
-			memoryManager.reserveMemory(this, requiredPythonWorkerMemory);
-			LOG.info("Reserved memory {} for Python worker.", requiredPythonWorkerMemory);
-			this.reservedMemory = requiredPythonWorkerMemory;
-			// TODO enforce the memory limit of the Python worker
-		} else {
-			LOG.warn("Required Python worker memory {} exceeds the available managed off-heap " +
-					"memory {}. Skipping reserving off-heap memory from the MemoryManager. This does " +
-					"not affect the functionality. However, it may affect the stability of a job as " +
-					"the memory used by the Python worker is not managed by Flink.",
-				requiredPythonWorkerMemory, availableManagedMemory);
-			this.reservedMemory = -1;
-		}
-	}
-
-	protected void emitResults() throws Exception {
-		Tuple2<byte[], Integer> resultTuple;
-		while ((resultTuple = pythonFunctionRunner.pollResult()) != null) {
-			emitResult(resultTuple);
-		}
-	}
-
-	/**
-	 * Checks whether to invoke finishBundle by elements count. Called in processElement.
-	 */
-	protected void checkInvokeFinishBundleByCount() throws Exception {
-		elementCount++;
-		if (elementCount >= maxBundleSize) {
-			invokeFinishBundle();
-		}
-	}
-
-	/**
-	 * Checks whether to invoke finishBundle by timeout.
-	 */
-	private void checkInvokeFinishBundleByTime() throws Exception {
-		long now = getProcessingTimeService().getCurrentProcessingTime();
-		if (now - lastFinishBundleTime >= maxBundleTimeMills) {
-			invokeFinishBundle();
-		}
-	}
-
-	protected void invokeFinishBundle() throws Exception {
-		if (elementCount > 0) {
-			pythonFunctionRunner.flush();
-			elementCount = 0;
-			emitResults();
-			lastFinishBundleTime = getProcessingTimeService().getCurrentProcessingTime();
-			// callback only after current bundle was fully finalized
-			if (bundleFinishedCallback != null) {
-				bundleFinishedCallback.run();
-				bundleFinishedCallback = null;
-			}
-		}
-	}
-
-	protected PythonEnvironmentManager createPythonEnvironmentManager() throws IOException {
-		PythonDependencyInfo dependencyInfo = PythonDependencyInfo.create(
-			config, getRuntimeContext().getDistributedCache());
-		PythonEnv pythonEnv = getPythonEnv();
-		if (pythonEnv.getExecType() == PythonEnv.ExecType.PROCESS) {
-			return new ProcessPythonEnvironmentManager(
-				dependencyInfo,
-				getContainingTask().getEnvironment().getTaskManagerInfo().getTmpDirectories(),
-				System.getenv());
-		} else {
-			throw new UnsupportedOperationException(String.format(
-				"Execution type '%s' is not supported.", pythonEnv.getExecType()));
-		}
-	}
-
-	protected FlinkMetricContainer getFlinkMetricContainer() {
-		return this.config.isMetricEnabled() ?
-			new FlinkMetricContainer(getRuntimeContext().getMetricGroup()) : null;
-	}
 }
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
similarity index 95%
copy from flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java
copy to flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
index 7361320..280bf0a 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractPythonFunctionOperatorBase.java
@@ -32,9 +32,7 @@ import org.apache.flink.python.metric.FlinkMetricContainer;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.memory.MemoryReservationException;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
-import org.apache.flink.streaming.api.operators.BoundedOneInput;
 import org.apache.flink.streaming.api.operators.ChainingStrategy;
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.table.functions.python.PythonEnv;
 import org.apache.flink.util.Preconditions;
@@ -46,9 +44,8 @@ import java.util.concurrent.ScheduledFuture;
  * Base class for all stream operators to execute Python functions.
  */
 @Internal
-public abstract class AbstractPythonFunctionOperator<IN, OUT>
-	extends AbstractStreamOperator<OUT>
-	implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
+public abstract class AbstractPythonFunctionOperatorBase<OUT>
+	extends AbstractStreamOperator<OUT> {
 
 	private static final long serialVersionUID = 1L;
 
@@ -65,7 +62,7 @@ public abstract class AbstractPythonFunctionOperator<IN, OUT>
 	/**
 	 * Number of processed elements in the current bundle.
 	 */
-	private transient int elementCount;
+	protected transient int elementCount;
 
 	/**
 	 * Max duration of a bundle.
@@ -85,7 +82,7 @@ public abstract class AbstractPythonFunctionOperator<IN, OUT>
 	/**
 	 * Callback to be executed after the current bundle was finished.
 	 */
-	private transient Runnable bundleFinishedCallback;
+	protected transient Runnable bundleFinishedCallback;
 
 	/**
 	 * The size of the reserved memory from the MemoryManager.
@@ -97,7 +94,7 @@ public abstract class AbstractPythonFunctionOperator<IN, OUT>
 	 */
 	private PythonConfig config;
 
-	public AbstractPythonFunctionOperator(Configuration config) {
+	public AbstractPythonFunctionOperatorBase(Configuration config) {
 		this.config = new PythonConfig(Preconditions.checkNotNull(config));
 		this.chainingStrategy = ChainingStrategy.ALWAYS;
 	}
@@ -180,11 +177,6 @@ public abstract class AbstractPythonFunctionOperator<IN, OUT>
 	}
 
 	@Override
-	public void endInput() throws Exception {
-		invokeFinishBundle();
-	}
-
-	@Override
 	public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
 		try {
 			invokeFinishBundle();
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractTwoInputPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractTwoInputPythonFunctionOperator.java
new file mode 100644
index 0000000..ed221c6
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/AbstractTwoInputPythonFunctionOperator.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators.python;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+
+/**
+ * Base class for all two input stream operators to execute Python functions.
+ */
+@Internal
+public abstract class AbstractTwoInputPythonFunctionOperator<IN1, IN2, OUT>
+	extends AbstractPythonFunctionOperatorBase<OUT>
+	implements TwoInputStreamOperator<IN1, IN2, OUT>, BoundedMultiInput {
+
+	private static final long serialVersionUID = 1L;
+
+	public AbstractTwoInputPythonFunctionOperator(Configuration config) {
+		super(config);
+	}
+
+	@Override
+	public void endInput(int inputId) throws Exception {
+		invokeFinishBundle();
+	}
+}