Posted to commits@flink.apache.org by hx...@apache.org on 2022/08/08 09:58:39 UTC

[flink] branch master updated: [FLINK-28836][python] Support broadcast in Thread Mode

This is an automated email from the ASF dual-hosted git repository.

hxb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new ff91aa53cfc [FLINK-28836][python] Support broadcast in Thread Mode
ff91aa53cfc is described below

commit ff91aa53cfc9327ed591cbe0c1b3dcf9116dc3a5
Author: huangxingbo <hx...@apache.org>
AuthorDate: Mon Aug 8 14:44:08 2022 +0800

    [FLINK-28836][python] Support broadcast in Thread Mode
    
    This closes #20490.
---
 flink-python/pyflink/datastream/data_stream.py     |   3 +-
 .../pyflink/datastream/tests/test_data_stream.py   | 256 ++++++++++-----------
 .../fn_execution/datastream/embedded/operations.py |  65 +++++-
 .../datastream/embedded/process_function.py        | 128 ++++++++++-
 .../fn_execution/datastream/embedded/state_impl.py |  87 ++++++-
 .../pyflink/fn_execution/embedded/converters.py    |  13 +-
 .../fn_execution/embedded/operation_utils.py       |  12 +-
 .../pyflink/fn_execution/embedded/operations.py    |  12 +-
 ...ractOneInputEmbeddedPythonFunctionOperator.java |   4 +-
 ...ractTwoInputEmbeddedPythonFunctionOperator.java |   4 +-
 ...eddedPythonBatchCoBroadcastProcessOperator.java |  82 +++++++
 ...PythonBatchKeyedCoBroadcastProcessOperator.java |  80 +++++++
 ...thonBroadcastStateTransformationTranslator.java |  68 ++++--
 ...eyedBroadcastStateTransformationTranslator.java |  68 ++++--
 14 files changed, 691 insertions(+), 191 deletions(-)
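
Editor's note: for context, the feature this commit enables can be exercised end to end as sketched below. This snippet is illustrative, not part of the commit; it assumes the configuration key "python.execution-mode" behind PythonOptions.PYTHON_EXECUTION_MODE (the value "thread" is confirmed by the translator changes further down) and the standard PyFlink BroadcastProcessFunction API.

    from pyflink.common import Configuration
    from pyflink.common.typeinfo import Types
    from pyflink.datastream import StreamExecutionEnvironment
    from pyflink.datastream.state import MapStateDescriptor

    config = Configuration()
    # "thread" selects the embedded (Thread Mode) Python runner; any other
    # value falls back to the external Process Mode operators.
    config.set_string("python.execution-mode", "thread")
    env = StreamExecutionEnvironment.get_execution_environment(config)

    ds = env.from_collection([1, 2, 3], type_info=Types.INT())
    ds_broadcast = env.from_collection(
        [(0, "a"), (1, "b")],
        type_info=Types.TUPLE([Types.INT(), Types.STRING()]))

    map_state_desc = MapStateDescriptor(
        "mapping", key_type_info=Types.INT(), value_type_info=Types.STRING())
    # .process(MyBroadcastProcessFunction(map_state_desc), ...) as in the
    # tests below; with the change in this commit it now runs in Thread Mode.
    connected = ds.connect(ds_broadcast.broadcast(map_state_desc))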

diff --git a/flink-python/pyflink/datastream/data_stream.py b/flink-python/pyflink/datastream/data_stream.py
index 4e8b5cbb229..6938aa0d4da 100644
--- a/flink-python/pyflink/datastream/data_stream.py
+++ b/flink-python/pyflink/datastream/data_stream.py
@@ -2636,7 +2636,8 @@ class BroadcastConnectedStream(object):
             jvm.String, [i.get_name() for i in self.broadcast_state_descriptors]
         )
         j_state_descriptors = JPythonConfigUtil.convertStateNamesToStateDescriptors(j_state_names)
-        j_conf = jvm.org.apache.flink.configuration.Configuration()
+        j_conf = get_j_env_configuration(
+            self.broadcast_stream.input_stream._j_data_stream.getExecutionEnvironment())
         j_data_stream_python_function_info = _create_j_data_stream_python_function_info(
             func, func_type
         )
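
Editor's note: this one-line fix matters because a freshly constructed Java Configuration is empty, so the execution mode chosen on the environment never reached the broadcast transformation and the translators below could not select the embedded operators. A hedged sketch of what the replacement returns, assuming pyflink.util.java_utils.get_j_env_configuration as used elsewhere in flink-python:

    from pyflink.datastream import StreamExecutionEnvironment
    from pyflink.util.java_utils import get_j_env_configuration

    env = StreamExecutionEnvironment.get_execution_environment()
    # The returned object is the environment's live Java Configuration, so
    # options set on the environment (e.g. python.execution-mode) are
    # visible to translation-time code that reads it.
    j_conf = get_j_env_configuration(env._j_stream_execution_environment)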
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py
index c52da695d4e..5b62715e3af 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -398,6 +398,134 @@ class DataStreamTests(object):
                     "<Row('on_timer', 4)>"]
         self.assert_equals_sorted(expected, results)
 
+    def test_co_broadcast_process(self):
+        ds = self.env.from_collection([1, 2, 3, 4, 5], type_info=Types.INT())  # type: DataStream
+        ds_broadcast = self.env.from_collection(
+            [(0, "a"), (1, "b")], type_info=Types.TUPLE([Types.INT(), Types.STRING()])
+        )  # type: DataStream
+
+        class MyBroadcastProcessFunction(BroadcastProcessFunction):
+            def __init__(self, map_state_desc):
+                self._map_state_desc = map_state_desc
+                self._cache = defaultdict(list)
+
+            def process_element(self, value: int, ctx: BroadcastProcessFunction.ReadOnlyContext):
+                ro_broadcast_state = ctx.get_broadcast_state(self._map_state_desc)
+                key = value % 2
+                if ro_broadcast_state.contains(key):
+                    if self._cache.get(key) is not None:
+                        for v in self._cache[key]:
+                            yield ro_broadcast_state.get(key) + str(v)
+                        self._cache[key].clear()
+                    yield ro_broadcast_state.get(key) + str(value)
+                else:
+                    self._cache[key].append(value)
+
+            def process_broadcast_element(
+                self, value: Tuple[int, str], ctx: BroadcastProcessFunction.Context
+            ):
+                key = value[0]
+                yield str(key) + value[1]
+                broadcast_state = ctx.get_broadcast_state(self._map_state_desc)
+                broadcast_state.put(key, value[1])
+                if self._cache.get(key) is not None:
+                    for v in self._cache[key]:
+                        yield value[1] + str(v)
+                    self._cache[key].clear()
+
+        map_state_desc = MapStateDescriptor(
+            "mapping", key_type_info=Types.INT(), value_type_info=Types.STRING()
+        )
+        ds.connect(ds_broadcast.broadcast(map_state_desc)).process(
+            MyBroadcastProcessFunction(map_state_desc), output_type=Types.STRING()
+        ).add_sink(self.test_sink)
+
+        self.env.execute("test_co_broadcast_process")
+        expected = ["0a", "0a", "1b", "1b", "a2", "a4", "b1", "b3", "b5"]
+        self.assert_equals_sorted(expected, self.test_sink.get_results())
+
+    def test_keyed_co_broadcast_process(self):
+        ds = self.env.from_collection(
+            [(1, '1603708211000'),
+             (2, '1603708212000'),
+             (3, '1603708213000'),
+             (4, '1603708214000')],
+            type_info=Types.ROW([Types.INT(), Types.STRING()]))  # type: DataStream
+        ds_broadcast = self.env.from_collection(
+            [(0, '1603708215000', 'a'),
+             (1, '1603708215000', 'b')],
+            type_info=Types.ROW([Types.INT(), Types.STRING(), Types.STRING()])
+        )  # type: DataStream
+        watermark_strategy = WatermarkStrategy.for_monotonous_timestamps() \
+            .with_timestamp_assigner(SecondColumnTimestampAssigner())
+        ds = ds.assign_timestamps_and_watermarks(watermark_strategy)
+        ds_broadcast = ds_broadcast.assign_timestamps_and_watermarks(watermark_strategy)
+
+        def _create_string(s, t):
+            return 'value: {}, ts: {}'.format(s, t)
+
+        class MyKeyedBroadcastProcessFunction(KeyedBroadcastProcessFunction):
+            def __init__(self, map_state_desc):
+                self._map_state_desc = map_state_desc
+                self._cache = None
+
+            def open(self, runtime_context: RuntimeContext):
+                self._cache = defaultdict(list)
+
+            def process_element(
+                self, value: Tuple[int, str], ctx: KeyedBroadcastProcessFunction.ReadOnlyContext
+            ):
+                ro_broadcast_state = ctx.get_broadcast_state(self._map_state_desc)
+                key = value[0] % 2
+                if ro_broadcast_state.contains(key):
+                    if self._cache.get(key) is not None:
+                        for v in self._cache[key]:
+                            yield _create_string(ro_broadcast_state.get(key) + str(v[0]), v[1])
+                        self._cache[key].clear()
+                    yield _create_string(ro_broadcast_state.get(key) + str(value[0]), value[1])
+                else:
+                    self._cache[key].append(value)
+                ctx.timer_service().register_event_time_timer(ctx.timestamp() + 10000)
+
+            def process_broadcast_element(
+                self, value: Tuple[int, str, str], ctx: KeyedBroadcastProcessFunction.Context
+            ):
+                key = value[0]
+                yield _create_string(str(key) + value[2], ctx.timestamp())
+                broadcast_state = ctx.get_broadcast_state(self._map_state_desc)
+                broadcast_state.put(key, value[2])
+                if self._cache.get(key) is not None:
+                    for v in self._cache[key]:
+                        yield _create_string(value[2] + str(v[0]), v[1])
+                    self._cache[key].clear()
+
+            def on_timer(self, timestamp: int, ctx: KeyedBroadcastProcessFunction.OnTimerContext):
+                yield _create_string(ctx.get_current_key(), timestamp)
+
+        map_state_desc = MapStateDescriptor(
+            "mapping", key_type_info=Types.INT(), value_type_info=Types.STRING()
+        )
+        ds.key_by(lambda t: t[0]).connect(ds_broadcast.broadcast(map_state_desc)).process(
+            MyKeyedBroadcastProcessFunction(map_state_desc), output_type=Types.STRING()
+        ).add_sink(self.test_sink)
+
+        self.env.execute("test_keyed_co_broadcast_process")
+        expected = [
+            'value: 0a, ts: 1603708215000',
+            'value: 0a, ts: 1603708215000',
+            'value: 1, ts: 1603708221000',
+            'value: 1b, ts: 1603708215000',
+            'value: 1b, ts: 1603708215000',
+            'value: 2, ts: 1603708222000',
+            'value: 3, ts: 1603708223000',
+            'value: 4, ts: 1603708224000',
+            'value: a2, ts: 1603708212000',
+            'value: a4, ts: 1603708214000',
+            'value: b1, ts: 1603708211000',
+            'value: b3, ts: 1603708213000'
+        ]
+        self.assert_equals_sorted(expected, self.test_sink.get_results())
+
 
 class DataStreamStreamingTests(DataStreamTests):
 
@@ -1027,134 +1155,6 @@ class ProcessDataStreamTests(DataStreamTests):
         side_expected = ['1', '1', '2', '2', '3', '3', '4', '4']
         self.assert_equals_sorted(side_expected, side_sink.get_results())
 
-    def test_co_broadcast_process(self):
-        ds = self.env.from_collection([1, 2, 3, 4, 5], type_info=Types.INT())  # type: DataStream
-        ds_broadcast = self.env.from_collection(
-            [(0, "a"), (1, "b")], type_info=Types.TUPLE([Types.INT(), Types.STRING()])
-        )  # type: DataStream
-
-        class MyBroadcastProcessFunction(BroadcastProcessFunction):
-            def __init__(self, map_state_desc):
-                self._map_state_desc = map_state_desc
-                self._cache = defaultdict(list)
-
-            def process_element(self, value: int, ctx: BroadcastProcessFunction.ReadOnlyContext):
-                ro_broadcast_state = ctx.get_broadcast_state(self._map_state_desc)
-                key = value % 2
-                if ro_broadcast_state.contains(key):
-                    if self._cache.get(key) is not None:
-                        for v in self._cache[key]:
-                            yield ro_broadcast_state.get(key) + str(v)
-                        self._cache[key].clear()
-                    yield ro_broadcast_state.get(key) + str(value)
-                else:
-                    self._cache[key].append(value)
-
-            def process_broadcast_element(
-                self, value: Tuple[int, str], ctx: BroadcastProcessFunction.Context
-            ):
-                key = value[0]
-                yield str(key) + value[1]
-                broadcast_state = ctx.get_broadcast_state(self._map_state_desc)
-                broadcast_state.put(key, value[1])
-                if self._cache.get(key) is not None:
-                    for v in self._cache[key]:
-                        yield value[1] + str(v)
-                    self._cache[key].clear()
-
-        map_state_desc = MapStateDescriptor(
-            "mapping", key_type_info=Types.INT(), value_type_info=Types.STRING()
-        )
-        ds.connect(ds_broadcast.broadcast(map_state_desc)).process(
-            MyBroadcastProcessFunction(map_state_desc), output_type=Types.STRING()
-        ).add_sink(self.test_sink)
-
-        self.env.execute("test_co_broadcast_process")
-        expected = ["0a", "0a", "1b", "1b", "a2", "a4", "b1", "b3", "b5"]
-        self.assert_equals_sorted(expected, self.test_sink.get_results())
-
-    def test_keyed_co_broadcast_process(self):
-        ds = self.env.from_collection(
-            [(1, '1603708211000'),
-             (2, '1603708212000'),
-             (3, '1603708213000'),
-             (4, '1603708214000')],
-            type_info=Types.ROW([Types.INT(), Types.STRING()]))  # type: DataStream
-        ds_broadcast = self.env.from_collection(
-            [(0, '1603708215000', 'a'),
-             (1, '1603708215000', 'b')],
-            type_info=Types.ROW([Types.INT(), Types.STRING(), Types.STRING()])
-        )  # type: DataStream
-        watermark_strategy = WatermarkStrategy.for_monotonous_timestamps() \
-            .with_timestamp_assigner(SecondColumnTimestampAssigner())
-        ds = ds.assign_timestamps_and_watermarks(watermark_strategy)
-        ds_broadcast = ds_broadcast.assign_timestamps_and_watermarks(watermark_strategy)
-
-        def _create_string(s, t):
-            return 'value: {}, ts: {}'.format(s, t)
-
-        class MyKeyedBroadcastProcessFunction(KeyedBroadcastProcessFunction):
-            def __init__(self, map_state_desc):
-                self._map_state_desc = map_state_desc
-                self._cache = None
-
-            def open(self, runtime_context: RuntimeContext):
-                self._cache = defaultdict(list)
-
-            def process_element(
-                self, value: Tuple[int, str], ctx: KeyedBroadcastProcessFunction.ReadOnlyContext
-            ):
-                ro_broadcast_state = ctx.get_broadcast_state(self._map_state_desc)
-                key = value[0] % 2
-                if ro_broadcast_state.contains(key):
-                    if self._cache.get(key) is not None:
-                        for v in self._cache[key]:
-                            yield _create_string(ro_broadcast_state.get(key) + str(v[0]), v[1])
-                        self._cache[key].clear()
-                    yield _create_string(ro_broadcast_state.get(key) + str(value[0]), value[1])
-                else:
-                    self._cache[key].append(value)
-                ctx.timer_service().register_event_time_timer(ctx.timestamp() + 10000)
-
-            def process_broadcast_element(
-                self, value: Tuple[int, str, str], ctx: KeyedBroadcastProcessFunction.Context
-            ):
-                key = value[0]
-                yield _create_string(str(key) + value[2], ctx.timestamp())
-                broadcast_state = ctx.get_broadcast_state(self._map_state_desc)
-                broadcast_state.put(key, value[2])
-                if self._cache.get(key) is not None:
-                    for v in self._cache[key]:
-                        yield _create_string(value[2] + str(v[0]), v[1])
-                    self._cache[key].clear()
-
-            def on_timer(self, timestamp: int, ctx: KeyedBroadcastProcessFunction.OnTimerContext):
-                yield _create_string(ctx.get_current_key(), timestamp)
-
-        map_state_desc = MapStateDescriptor(
-            "mapping", key_type_info=Types.INT(), value_type_info=Types.STRING()
-        )
-        ds.key_by(lambda t: t[0]).connect(ds_broadcast.broadcast(map_state_desc)).process(
-            MyKeyedBroadcastProcessFunction(map_state_desc), output_type=Types.STRING()
-        ).add_sink(self.test_sink)
-
-        self.env.execute("test_keyed_co_broadcast_process")
-        expected = [
-            'value: 0a, ts: 1603708215000',
-            'value: 0a, ts: 1603708215000',
-            'value: 1, ts: 1603708221000',
-            'value: 1b, ts: 1603708215000',
-            'value: 1b, ts: 1603708215000',
-            'value: 2, ts: 1603708222000',
-            'value: 3, ts: 1603708223000',
-            'value: 4, ts: 1603708224000',
-            'value: a2, ts: 1603708212000',
-            'value: a4, ts: 1603708214000',
-            'value: b1, ts: 1603708211000',
-            'value: b3, ts: 1603708213000'
-        ]
-        self.assert_equals_sorted(expected, self.test_sink.get_results())
-
 
 class ProcessDataStreamStreamingTests(DataStreamStreamingTests, ProcessDataStreamTests,
                                       PyFlinkStreamingTestCase):
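
Editor's note: both relocated tests (moved from ProcessDataStreamTests into the shared DataStreamTests base so they run in Process Mode and Thread Mode alike) lean on the same cache-then-flush pattern: regular-side elements that arrive before the matching broadcast entry are buffered, then emitted once process_broadcast_element installs it. A self-contained sketch of that pattern, stripped of Flink; all names here are illustrative:

    from collections import defaultdict

    cache = defaultdict(list)       # elements waiting for a broadcast entry
    broadcast_state = {}            # stands in for the broadcast MapState

    def process_element(value):
        key = value % 2
        if key in broadcast_state:
            yield broadcast_state[key] + str(value)
        else:
            cache[key].append(value)            # buffer until broadcast arrives

    def process_broadcast_element(key, tag):
        broadcast_state[key] = tag
        yield str(key) + tag
        for v in cache.pop(key, []):            # flush buffered elements
            yield tag + str(v)

    out = list(process_element(1))              # buffered, nothing emitted yet
    out += list(process_broadcast_element(1, "b"))
    assert out == ["1b", "b1"]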
diff --git a/flink-python/pyflink/fn_execution/datastream/embedded/operations.py b/flink-python/pyflink/fn_execution/datastream/embedded/operations.py
index 69757e7f128..beef4be98a8 100644
--- a/flink-python/pyflink/fn_execution/datastream/embedded/operations.py
+++ b/flink-python/pyflink/fn_execution/datastream/embedded/operations.py
@@ -21,7 +21,11 @@ from pyflink.fn_execution.coders import TimeWindowCoder, CountWindowCoder
 from pyflink.fn_execution.datastream import operations
 from pyflink.fn_execution.datastream.embedded.process_function import (
     InternalProcessFunctionContext, InternalKeyedProcessFunctionContext,
-    InternalKeyedProcessFunctionOnTimerContext, InternalWindowTimerContext)
+    InternalKeyedProcessFunctionOnTimerContext, InternalWindowTimerContext,
+    InternalBroadcastProcessFunctionContext, InternalBroadcastProcessFunctionReadOnlyContext,
+    InternalKeyedBroadcastProcessFunctionContext,
+    InternalKeyedBroadcastProcessFunctionReadOnlyContext,
+    InternalKeyedBroadcastProcessFunctionOnTimerContext)
 from pyflink.fn_execution.datastream.embedded.runtime_context import StreamingRuntimeContext
 from pyflink.fn_execution.datastream.embedded.timerservice_impl import InternalTimerServiceImpl
 from pyflink.fn_execution.datastream.window.window_operator import WindowOperator
@@ -80,7 +84,7 @@ class TwoInputOperation(operations.TwoInputOperation):
 
 def extract_process_function(
         user_defined_function_proto, j_runtime_context, j_function_context, j_timer_context,
-        job_parameters, j_keyed_state_backend):
+        job_parameters, j_keyed_state_backend, j_operator_state_backend):
     from pyflink.fn_execution import flink_fn_execution_pb2
 
     user_defined_func = pickle.loads(user_defined_function_proto.payload)
@@ -151,6 +155,31 @@ def extract_process_function(
         return TwoInputOperation(
             open_func, close_func, process_element_func1, process_element_func2)
 
+    elif func_type == UserDefinedDataStreamFunction.CO_BROADCAST_PROCESS:
+
+        broadcast_ctx = InternalBroadcastProcessFunctionContext(
+            j_function_context, j_operator_state_backend)
+
+        read_only_broadcast_ctx = InternalBroadcastProcessFunctionReadOnlyContext(
+            j_function_context, j_operator_state_backend)
+
+        process_element = user_defined_func.process_element
+
+        process_broadcast_element = user_defined_func.process_broadcast_element
+
+        def process_element_func1(value):
+            elements = process_element(value, read_only_broadcast_ctx)
+            if elements:
+                yield from elements
+
+        def process_element_func2(value):
+            elements = process_broadcast_element(value, broadcast_ctx)
+            if elements:
+                yield from elements
+
+        return TwoInputOperation(
+            open_func, close_func, process_element_func1, process_element_func2)
+
     elif func_type == UserDefinedDataStreamFunction.KEYED_CO_PROCESS:
 
         function_context = InternalKeyedProcessFunctionContext(
@@ -185,6 +214,38 @@ def extract_process_function(
         return TwoInputOperation(
             open_func, close_func, process_element_func1, process_element_func2, on_timer_func)
 
+    elif func_type == UserDefinedDataStreamFunction.KEYED_CO_BROADCAST_PROCESS:
+        broadcast_ctx = InternalKeyedBroadcastProcessFunctionContext(
+            j_function_context, j_operator_state_backend)
+
+        read_only_broadcast_ctx = InternalKeyedBroadcastProcessFunctionReadOnlyContext(
+            j_function_context, user_defined_function_proto.key_type_info, j_operator_state_backend)
+
+        timer_context = InternalKeyedBroadcastProcessFunctionOnTimerContext(
+            j_timer_context, user_defined_function_proto.key_type_info, j_operator_state_backend)
+
+        process_element = user_defined_func.process_element
+
+        process_broadcast_element = user_defined_func.process_broadcast_element
+
+        on_timer = user_defined_func.on_timer
+
+        def process_element_func1(value):
+            elements = process_element(value[1], read_only_broadcast_ctx)
+            if elements:
+                yield from elements
+
+        def process_element_func2(value):
+            elements = process_broadcast_element(value, broadcast_ctx)
+            if elements:
+                yield from elements
+
+        def on_timer_func(timestamp):
+            yield from on_timer(timestamp, timer_context)
+
+        return TwoInputOperation(
+            open_func, close_func, process_element_func1, process_element_func2, on_timer_func)
+
     elif func_type == UserDefinedDataStreamFunction.WINDOW:
 
         window_operation_descriptor = (
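
Editor's note: the process_element_func wrappers added above guard with "if elements:" before "yield from". The reason: a user function written with plain returns (side effects only) hands back None, while a generator function always hands back a generator object, even before producing anything. A minimal standalone illustration of that dispatch; the names are illustrative:

    def returns_none(value, ctx):
        pass                        # side effects only, no output

    def yields_values(value, ctx):
        yield value * 2

    def wrap(fn, value, ctx):
        elements = fn(value, ctx)
        if elements:                # None -> emit nothing
            yield from elements

    assert list(wrap(returns_none, 1, None)) == []
    assert list(wrap(yields_values, 1, None)) == [2]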
diff --git a/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py b/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py
index 9cccce3ae94..b23b2b3110f 100644
--- a/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py
+++ b/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py
@@ -15,10 +15,18 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 ################################################################################
+from abc import ABC
+
 from pyflink.datastream import (ProcessFunction, KeyedProcessFunction, CoProcessFunction,
                                 KeyedCoProcessFunction, TimerService, TimeDomain)
+from pyflink.datastream.functions import (BaseBroadcastProcessFunction, BroadcastProcessFunction,
+                                          KeyedBroadcastProcessFunction)
+from pyflink.datastream.state import MapStateDescriptor, BroadcastState, ReadOnlyBroadcastState
+from pyflink.fn_execution.datastream.embedded.state_impl import (ReadOnlyBroadcastStateImpl,
+                                                                 BroadcastStateImpl)
 from pyflink.fn_execution.datastream.embedded.timerservice_impl import TimerServiceImpl
-from pyflink.fn_execution.embedded.converters import from_type_info
+from pyflink.fn_execution.embedded.converters import from_type_info_proto, from_type_info
+from pyflink.fn_execution.embedded.java_utils import to_java_state_descriptor
 
 
 class InternalProcessFunctionContext(ProcessFunction.Context, CoProcessFunction.Context,
@@ -106,3 +114,121 @@ class InternalWindowTimerContext(object):
 
     def get_current_key(self):
         return self._key_converter.to_internal(self._context.getCurrentKey())
+
+
+class InternalBaseBroadcastProcessFunctionContext(BaseBroadcastProcessFunction.Context, ABC):
+
+    def __init__(self, context, operator_state_backend):
+        self._context = context
+        self._operator_state_backend = operator_state_backend
+
+    def timestamp(self) -> int:
+        return self._context.timestamp()
+
+    def current_processing_time(self) -> int:
+        return self._context.currentProcessingTime()
+
+    def current_watermark(self) -> int:
+        return self._context.currentWatermark()
+
+
+class InternalBroadcastProcessFunctionContext(InternalBaseBroadcastProcessFunctionContext,
+                                              BroadcastProcessFunction.Context):
+
+    def __init__(self, context, operator_state_backend):
+        super(InternalBroadcastProcessFunctionContext, self).__init__(
+            context, operator_state_backend)
+
+    def get_broadcast_state(self, state_descriptor: MapStateDescriptor) -> BroadcastState:
+        return BroadcastStateImpl(
+            self._operator_state_backend.getBroadcastState(
+                to_java_state_descriptor(state_descriptor)),
+            from_type_info(state_descriptor.type_info))
+
+
+class InternalBroadcastProcessFunctionReadOnlyContext(InternalBaseBroadcastProcessFunctionContext,
+                                                      BroadcastProcessFunction.ReadOnlyContext):
+
+    def __init__(self, context, operator_state_backend):
+        super(InternalBroadcastProcessFunctionReadOnlyContext, self).__init__(
+            context, operator_state_backend)
+
+    def get_broadcast_state(self, state_descriptor: MapStateDescriptor) -> ReadOnlyBroadcastState:
+        return ReadOnlyBroadcastStateImpl(
+            self._operator_state_backend.getBroadcastState(
+                to_java_state_descriptor(state_descriptor)),
+            from_type_info(state_descriptor.type_info))
+
+
+class InternalKeyedBroadcastProcessFunctionContext(InternalBaseBroadcastProcessFunctionContext,
+                                                   KeyedBroadcastProcessFunction.Context):
+
+    def __init__(self, context, operator_state_backend):
+        super(InternalKeyedBroadcastProcessFunctionContext, self).__init__(
+            context, operator_state_backend)
+
+    def get_broadcast_state(self, state_descriptor: MapStateDescriptor) -> BroadcastState:
+        return BroadcastStateImpl(
+            self._operator_state_backend.getBroadcastState(
+                to_java_state_descriptor(state_descriptor)),
+            from_type_info(state_descriptor.type_info))
+
+
+class InternalKeyedBroadcastProcessFunctionReadOnlyContext(
+    InternalBaseBroadcastProcessFunctionContext,
+    KeyedBroadcastProcessFunction.ReadOnlyContext
+):
+
+    def __init__(self, context, key_type_info, operator_state_backend):
+        super(InternalKeyedBroadcastProcessFunctionReadOnlyContext, self).__init__(
+            context, operator_state_backend)
+        self._key_converter = from_type_info_proto(key_type_info)
+        self._timer_service = TimerServiceImpl(self._context.timerService())
+
+    def get_broadcast_state(self, state_descriptor: MapStateDescriptor) -> ReadOnlyBroadcastState:
+        return ReadOnlyBroadcastStateImpl(
+            self._operator_state_backend.getBroadcastState(
+                to_java_state_descriptor(state_descriptor)),
+            from_type_info(state_descriptor.type_info))
+
+    def timer_service(self) -> TimerService:
+        return self._timer_service
+
+    def get_current_key(self):
+        return self._key_converter.to_internal(self._context.getCurrentKey())
+
+
+class InternalKeyedBroadcastProcessFunctionOnTimerContext(
+    InternalBaseBroadcastProcessFunctionContext,
+    KeyedBroadcastProcessFunction.OnTimerContext,
+):
+
+    def __init__(self, context, key_type_info, operator_state_backend):
+        super(InternalKeyedBroadcastProcessFunctionOnTimerContext, self).__init__(
+            context, operator_state_backend)
+        self._timer_service = TimerServiceImpl(self._context.timerService())
+        self._key_converter = from_type_info_proto(key_type_info)
+
+    def get_broadcast_state(self, state_descriptor: MapStateDescriptor) -> ReadOnlyBroadcastState:
+        return ReadOnlyBroadcastStateImpl(
+            self._operator_state_backend.getBroadcastState(
+                to_java_state_descriptor(state_descriptor)),
+            from_type_info(state_descriptor.type_info))
+
+    def current_processing_time(self) -> int:
+        return self._timer_service.current_processing_time()
+
+    def current_watermark(self) -> int:
+        return self._timer_service.current_watermark()
+
+    def timer_service(self) -> TimerService:
+        return self._timer_service
+
+    def timestamp(self) -> int:
+        return self._context.timestamp()
+
+    def time_domain(self) -> TimeDomain:
+        return TimeDomain(self._context.timeDomain())
+
+    def get_current_key(self):
+        return self._key_converter.to_internal(self._context.getCurrentKey())
diff --git a/flink-python/pyflink/fn_execution/datastream/embedded/state_impl.py b/flink-python/pyflink/fn_execution/datastream/embedded/state_impl.py
index 6973be038eb..8eb28eaf380 100644
--- a/flink-python/pyflink/fn_execution/datastream/embedded/state_impl.py
+++ b/flink-python/pyflink/fn_execution/datastream/embedded/state_impl.py
@@ -19,33 +19,42 @@ from abc import ABC
 from typing import List, Iterable, Tuple, Dict, Collection
 
 from pyflink.datastream import ReduceFunction, AggregateFunction
-from pyflink.datastream.state import (T, IN, OUT, V, K)
+from pyflink.datastream.state import (T, IN, OUT, V, K, State)
 from pyflink.fn_execution.embedded.converters import (DataConverter, DictDataConverter,
                                                       ListDataConverter)
 from pyflink.fn_execution.internal_state import (InternalValueState, InternalKvState,
                                                  InternalListState, InternalReducingState,
                                                  InternalAggregatingState, InternalMapState,
-                                                 N)
+                                                 N, InternalReadOnlyBroadcastState,
+                                                 InternalBroadcastState)
 
 
-class StateImpl(InternalKvState, ABC):
+class StateImpl(State, ABC):
     def __init__(self,
                  state,
-                 value_converter: DataConverter,
-                 window_converter: DataConverter = None):
+                 value_converter: DataConverter):
         self._state = state
         self._value_converter = value_converter
-        self._window_converter = window_converter
 
     def clear(self):
         self._state.clear()
 
+
+class KeyedStateImpl(StateImpl, InternalKvState, ABC):
+
+    def __init__(self,
+                 state,
+                 value_converter: DataConverter,
+                 window_converter: DataConverter = None):
+        super(KeyedStateImpl, self).__init__(state, value_converter)
+        self._window_converter = window_converter
+
     def set_current_namespace(self, namespace) -> None:
         j_window = self._window_converter.to_external(namespace)
         self._state.setCurrentNamespace(j_window)
 
 
-class ValueStateImpl(StateImpl, InternalValueState):
+class ValueStateImpl(KeyedStateImpl, InternalValueState):
     def __init__(self,
                  value_state,
                  value_converter: DataConverter,
@@ -59,7 +68,7 @@ class ValueStateImpl(StateImpl, InternalValueState):
         self._state.update(self._value_converter.to_external(value))
 
 
-class ListStateImpl(StateImpl, InternalListState):
+class ListStateImpl(KeyedStateImpl, InternalListState):
 
     def __init__(self,
                  list_state,
@@ -88,7 +97,7 @@ class ListStateImpl(StateImpl, InternalListState):
         self._state.mergeNamespaces(j_target, j_sources)
 
 
-class ReducingStateImpl(StateImpl, InternalReducingState):
+class ReducingStateImpl(KeyedStateImpl, InternalReducingState):
 
     def __init__(self,
                  value_state,
@@ -135,7 +144,7 @@ class ReducingStateImpl(StateImpl, InternalReducingState):
             self._state.update(self._value_converter.to_external(merged))
 
 
-class AggregatingStateImpl(StateImpl, InternalAggregatingState):
+class AggregatingStateImpl(KeyedStateImpl, InternalAggregatingState):
     def __init__(self,
                  value_state,
                  value_converter,
@@ -185,7 +194,7 @@ class AggregatingStateImpl(StateImpl, InternalAggregatingState):
             self._state.update(self._value_converter.to_external(merged))
 
 
-class MapStateImpl(StateImpl, InternalMapState):
+class MapStateImpl(KeyedStateImpl, InternalMapState):
     def __init__(self,
                  map_state,
                  map_converter: DictDataConverter,
@@ -231,3 +240,59 @@ class MapStateImpl(StateImpl, InternalMapState):
 
     def is_empty(self) -> bool:
         return self._state.isEmpty()
+
+
+class ReadOnlyBroadcastStateImpl(StateImpl, InternalReadOnlyBroadcastState):
+
+    def __init__(self,
+                 map_state,
+                 map_converter: DictDataConverter):
+        super(ReadOnlyBroadcastStateImpl, self).__init__(map_state, map_converter)
+        self._k_converter = map_converter._key_converter
+        self._v_converter = map_converter._value_converter
+
+    def get(self, key: K) -> V:
+        return self._v_converter.to_internal(
+            self._state.get(self._k_converter.to_external(key)))
+
+    def contains(self, key: K) -> bool:
+        return self._state.contains(self._k_converter.to_external(key))
+
+    def items(self) -> Iterable[Tuple[K, V]]:
+        entries = self._state.entries()
+        for entry in entries:
+            yield (self._k_converter.to_internal(entry.getKey()),
+                   self._v_converter.to_internal(entry.getValue()))
+
+    def keys(self) -> Iterable[K]:
+        for k in self._state.keys():
+            yield self._k_converter.to_internal(k)
+
+    def values(self) -> Iterable[V]:
+        for v in self._state.values():
+            yield self._v_converter.to_internal(v)
+
+    def is_empty(self) -> bool:
+        return self._state.isEmpty()
+
+
+class BroadcastStateImpl(ReadOnlyBroadcastStateImpl, InternalBroadcastState):
+    def __init__(self,
+                 map_state,
+                 map_converter: DictDataConverter):
+        super(BroadcastStateImpl, self).__init__(map_state, map_converter)
+        self._map_converter = map_converter
+        self._k_converter = map_converter._key_converter
+        self._v_converter = map_converter._value_converter
+
+    def to_read_only_broadcast_state(self) -> InternalReadOnlyBroadcastState[K, V]:
+        return ReadOnlyBroadcastStateImpl(self._state, self._map_converter)
+
+    def put(self, key: K, value: V) -> None:
+        self._state.put(self._k_converter.to_external(key), self._v_converter.to_external(value))
+
+    def put_all(self, dict_value: Dict[K, V]) -> None:
+        self._state.putAll(self._value_converter.to_external(dict_value))
+
+    def remove(self, key: K) -> None:
+        self._state.remove(self._k_converter.to_external(key))
diff --git a/flink-python/pyflink/fn_execution/embedded/converters.py b/flink-python/pyflink/fn_execution/embedded/converters.py
index d2ba951f884..8829d48074e 100644
--- a/flink-python/pyflink/fn_execution/embedded/converters.py
+++ b/flink-python/pyflink/fn_execution/embedded/converters.py
@@ -95,18 +95,19 @@ class RowDataConverter(DataConverter):
 
     def __init__(self, field_data_converters: List[DataConverter], field_names: List[str]):
         self._field_data_converters = field_data_converters
-        self._reuse_row = Row()
-        self._reuse_row.set_field_names(field_names)
+        self._field_names = field_names
 
     def to_internal(self, value) -> IN:
         if value is None:
             return None
 
-        self._reuse_row._values = [self._field_data_converters[i].to_internal(item)
-                                   for i, item in enumerate(value[1])]
-        self._reuse_row.set_row_kind(RowKind(value[0]))
+        row = Row()
+        row._values = [self._field_data_converters[i].to_internal(item)
+                       for i, item in enumerate(value[1])]
+        row.set_field_names(self._field_names)
+        row.set_row_kind(RowKind(value[0]))
 
-        return self._reuse_row
+        return row
 
     def to_external(self, value: Row) -> OUT:
         if value is None:
diff --git a/flink-python/pyflink/fn_execution/embedded/operation_utils.py b/flink-python/pyflink/fn_execution/embedded/operation_utils.py
index 0c0580671ec..0dadbb359b9 100644
--- a/flink-python/pyflink/fn_execution/embedded/operation_utils.py
+++ b/flink-python/pyflink/fn_execution/embedded/operation_utils.py
@@ -103,7 +103,8 @@ def create_table_operation_from_proto(proto, input_coder_info, output_coder_into
 
 def create_one_input_user_defined_data_stream_function_from_protos(
         function_infos, input_coder_info, output_coder_info, runtime_context,
-        function_context, timer_context, job_parameters, keyed_state_backend):
+        function_context, timer_context, job_parameters, keyed_state_backend,
+        operator_state_backend):
     serialized_fns = [pare_user_defined_data_stream_function_proto(proto)
                       for proto in function_infos]
     input_data_converter = (
@@ -119,14 +120,16 @@ def create_one_input_user_defined_data_stream_function_from_protos(
         function_context,
         timer_context,
         job_parameters,
-        keyed_state_backend)
+        keyed_state_backend,
+        operator_state_backend)
 
     return function_operation
 
 
 def create_two_input_user_defined_data_stream_function_from_protos(
         function_infos, input_coder_info1, input_coder_info2, output_coder_info, runtime_context,
-        function_context, timer_context, job_parameters, keyed_state_backend):
+        function_context, timer_context, job_parameters, keyed_state_backend,
+        operator_state_backend):
     serialized_fns = [pare_user_defined_data_stream_function_proto(proto)
                       for proto in function_infos]
 
@@ -148,6 +151,7 @@ def create_two_input_user_defined_data_stream_function_from_protos(
         function_context,
         timer_context,
         job_parameters,
-        keyed_state_backend)
+        keyed_state_backend,
+        operator_state_backend)
 
     return function_operation
diff --git a/flink-python/pyflink/fn_execution/embedded/operations.py b/flink-python/pyflink/fn_execution/embedded/operations.py
index 88eee565224..be7771f8aca 100644
--- a/flink-python/pyflink/fn_execution/embedded/operations.py
+++ b/flink-python/pyflink/fn_execution/embedded/operations.py
@@ -67,7 +67,8 @@ class OneInputFunctionOperation(FunctionOperation):
                  function_context,
                  timer_context,
                  job_parameters,
-                 keyed_state_backend):
+                 keyed_state_backend,
+                 operator_state_backend):
         operations = (
             [extract_process_function(
                 serialized_fn,
@@ -75,7 +76,8 @@ class OneInputFunctionOperation(FunctionOperation):
                 function_context,
                 timer_context,
                 job_parameters,
-                keyed_state_backend)
+                keyed_state_backend,
+                operator_state_backend)
                 for serialized_fn in serialized_fns])
         super(OneInputFunctionOperation, self).__init__(operations, output_data_converter)
         self._input_data_converter = input_data_converter
@@ -99,7 +101,8 @@ class TwoInputFunctionOperation(FunctionOperation):
                  function_context,
                  timer_context,
                  job_parameters,
-                 keyed_state_backend):
+                 keyed_state_backend,
+                 operator_state_backend):
         operations = (
             [extract_process_function(
                 serialized_fn,
@@ -107,7 +110,8 @@ class TwoInputFunctionOperation(FunctionOperation):
                 function_context,
                 timer_context,
                 job_parameters,
-                keyed_state_backend)
+                keyed_state_backend,
+                operator_state_backend)
                 for serialized_fn in serialized_fns])
         super(TwoInputFunctionOperation, self).__init__(operations, output_data_converter)
         self._input_data_converter1 = input_data_converter1
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java
index 296f0e5a697..b02f3bfcb32 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java
@@ -116,6 +116,7 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
         interpreter.set("timer_context", getTimerContext());
         interpreter.set("job_parameters", getJobParameters());
         interpreter.set("keyed_state_backend", getKeyedStateBackend());
+        interpreter.set("operator_state_backend", getOperatorStateBackend());
 
         interpreter.exec(
                 "from pyflink.fn_execution.embedded.operation_utils import create_one_input_user_defined_data_stream_function_from_protos");
@@ -129,7 +130,8 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
                         + "function_context,"
                         + "timer_context,"
                         + "job_parameters,"
-                        + "keyed_state_backend)");
+                        + "keyed_state_backend,"
+                        + "operator_state_backend)");
 
         interpreter.invokeMethod("operation", "open");
     }
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractTwoInputEmbeddedPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractTwoInputEmbeddedPythonFunctionOperator.java
index 232ccda4575..b0767a309fe 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractTwoInputEmbeddedPythonFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractTwoInputEmbeddedPythonFunctionOperator.java
@@ -137,6 +137,7 @@ public abstract class AbstractTwoInputEmbeddedPythonFunctionOperator<IN1, IN2, O
         interpreter.set("timer_context", getTimerContext());
         interpreter.set("keyed_state_backend", getKeyedStateBackend());
         interpreter.set("job_parameters", getJobParameters());
+        interpreter.set("operator_state_backend", getOperatorStateBackend());
 
         interpreter.exec(
                 "from pyflink.fn_execution.embedded.operation_utils import create_two_input_user_defined_data_stream_function_from_protos");
@@ -151,7 +152,8 @@ public abstract class AbstractTwoInputEmbeddedPythonFunctionOperator<IN1, IN2, O
                         + "function_context,"
                         + "timer_context,"
                         + "job_parameters,"
-                        + "keyed_state_backend)");
+                        + "keyed_state_backend,"
+                        + "operator_state_backend)");
 
         interpreter.invokeMethod("operation", "open");
     }
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonBatchCoBroadcastProcessOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonBatchCoBroadcastProcessOperator.java
new file mode 100644
index 00000000000..bd7ebfcdc6d
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonBatchCoBroadcastProcessOperator.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators.python.embedded;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.InputSelectable;
+import org.apache.flink.streaming.api.operators.InputSelection;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * The {@link EmbeddedPythonBatchCoBroadcastProcessOperator} is responsible for executing the Python
+ * CoBroadcastProcess function in BATCH mode; {@link EmbeddedPythonCoProcessOperator} is used in
+ * STREAMING mode. This operator forces the broadcast side to be fully consumed before any data
+ * from the regular side is processed.
+ *
+ * @param <IN1> The input type of the regular stream
+ * @param <IN2> The input type of the broadcast stream
+ * @param <OUT> The output type of the CoBroadcastProcess function
+ */
+@Internal
+public class EmbeddedPythonBatchCoBroadcastProcessOperator<IN1, IN2, OUT>
+        extends EmbeddedPythonCoProcessOperator<IN1, IN2, OUT>
+        implements BoundedMultiInput, InputSelectable {
+
+    private static final long serialVersionUID = 1L;
+
+    private transient volatile boolean isBroadcastSideDone = false;
+
+    public EmbeddedPythonBatchCoBroadcastProcessOperator(
+            Configuration config,
+            DataStreamPythonFunctionInfo pythonFunctionInfo,
+            TypeInformation<IN1> inputTypeInfo1,
+            TypeInformation<IN2> inputTypeInfo2,
+            TypeInformation<OUT> outputTypeInfo) {
+        super(config, pythonFunctionInfo, inputTypeInfo1, inputTypeInfo2, outputTypeInfo);
+    }
+
+    @Override
+    public void endInput(int inputId) throws Exception {
+        if (inputId == 2) {
+            isBroadcastSideDone = true;
+        }
+    }
+
+    @Override
+    public InputSelection nextSelection() {
+        if (!isBroadcastSideDone) {
+            return InputSelection.SECOND;
+        } else {
+            return InputSelection.FIRST;
+        }
+    }
+
+    @Override
+    public void processElement1(StreamRecord<IN1> element) throws Exception {
+        Preconditions.checkState(
+                isBroadcastSideDone,
+                "Should not process regular input before broadcast side is done.");
+
+        super.processElement1(element);
+    }
+}
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonBatchKeyedCoBroadcastProcessOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonBatchKeyedCoBroadcastProcessOperator.java
new file mode 100644
index 00000000000..a5335780756
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonBatchKeyedCoBroadcastProcessOperator.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators.python.embedded;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.InputSelectable;
+import org.apache.flink.streaming.api.operators.InputSelection;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * The {@link EmbeddedPythonBatchKeyedCoBroadcastProcessOperator} is responsible for executing the
+ * Python CoBroadcastProcess function in BATCH mode; {@link EmbeddedPythonKeyedCoProcessOperator}
+ * is used in STREAMING mode. This operator forces the broadcast side to be fully consumed before
+ * any data from the regular side is processed.
+ *
+ * @param <OUT> The output type of the CoBroadcastProcess function
+ */
+@Internal
+public class EmbeddedPythonBatchKeyedCoBroadcastProcessOperator<K, IN1, IN2, OUT>
+        extends EmbeddedPythonKeyedCoProcessOperator<K, IN1, IN2, OUT>
+        implements BoundedMultiInput, InputSelectable {
+
+    private static final long serialVersionUID = 1L;
+
+    private transient volatile boolean isBroadcastSideDone = false;
+
+    public EmbeddedPythonBatchKeyedCoBroadcastProcessOperator(
+            Configuration config,
+            DataStreamPythonFunctionInfo pythonFunctionInfo,
+            TypeInformation<IN1> inputTypeInfo1,
+            TypeInformation<IN2> inputTypeInfo2,
+            TypeInformation<OUT> outputTypeInfo) {
+        super(config, pythonFunctionInfo, inputTypeInfo1, inputTypeInfo2, outputTypeInfo);
+    }
+
+    @Override
+    public void endInput(int inputId) throws Exception {
+        if (inputId == 2) {
+            isBroadcastSideDone = true;
+        }
+    }
+
+    @Override
+    public InputSelection nextSelection() {
+        if (!isBroadcastSideDone) {
+            return InputSelection.SECOND;
+        } else {
+            return InputSelection.FIRST;
+        }
+    }
+
+    @Override
+    public void processElement1(StreamRecord<IN1> element) throws Exception {
+        Preconditions.checkState(
+                isBroadcastSideDone,
+                "Should not process regular input before broadcast side is done.");
+
+        super.processElement1(element);
+    }
+}
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonBroadcastStateTransformationTranslator.java b/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonBroadcastStateTransformationTranslator.java
index 5854fb096ed..32d9a259754 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonBroadcastStateTransformationTranslator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonBroadcastStateTransformationTranslator.java
@@ -18,7 +18,12 @@
 package org.apache.flink.streaming.runtime.translators.python;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.python.PythonOptions;
 import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonBatchCoBroadcastProcessOperator;
+import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonCoProcessOperator;
 import org.apache.flink.streaming.api.operators.python.process.ExternalPythonBatchCoBroadcastProcessOperator;
 import org.apache.flink.streaming.api.operators.python.process.ExternalPythonCoProcessOperator;
 import org.apache.flink.streaming.api.transformations.python.PythonBroadcastStateTransformation;
@@ -29,8 +34,10 @@ import java.util.Collection;
 
 /**
  * A {@link org.apache.flink.streaming.api.graph.TransformationTranslator} that translates {@link
- * PythonBroadcastStateTransformation} into {@link ExternalPythonCoProcessOperator} or {@link
- * ExternalPythonBatchCoBroadcastProcessOperator}.
+ * PythonBroadcastStateTransformation} into {@link ExternalPythonCoProcessOperator}/{@link
+ * EmbeddedPythonCoProcessOperator} in streaming mode or {@link
+ * ExternalPythonBatchCoBroadcastProcessOperator}/{@link
+ * EmbeddedPythonBatchCoBroadcastProcessOperator} in batch mode.
  */
 @Internal
 public class PythonBroadcastStateTransformationTranslator<IN1, IN2, OUT>
@@ -43,13 +50,27 @@ public class PythonBroadcastStateTransformationTranslator<IN1, IN2, OUT>
         Preconditions.checkNotNull(transformation);
         Preconditions.checkNotNull(context);
 
-        ExternalPythonBatchCoBroadcastProcessOperator operator =
-                new ExternalPythonBatchCoBroadcastProcessOperator(
-                        transformation.getConfiguration(),
-                        transformation.getDataStreamPythonFunctionInfo(),
-                        transformation.getRegularInput().getOutputType(),
-                        transformation.getBroadcastInput().getOutputType(),
-                        transformation.getOutputType());
+        Configuration config = transformation.getConfiguration();
+
+        StreamOperator<OUT> operator;
+
+        if (config.get(PythonOptions.PYTHON_EXECUTION_MODE).equals("thread")) {
+            operator =
+                    new EmbeddedPythonBatchCoBroadcastProcessOperator<>(
+                            transformation.getConfiguration(),
+                            transformation.getDataStreamPythonFunctionInfo(),
+                            transformation.getRegularInput().getOutputType(),
+                            transformation.getBroadcastInput().getOutputType(),
+                            transformation.getOutputType());
+        } else {
+            operator =
+                    new ExternalPythonBatchCoBroadcastProcessOperator<>(
+                            transformation.getConfiguration(),
+                            transformation.getDataStreamPythonFunctionInfo(),
+                            transformation.getRegularInput().getOutputType(),
+                            transformation.getBroadcastInput().getOutputType(),
+                            transformation.getOutputType());
+        }
 
         return translateInternal(
                 transformation,
@@ -68,13 +89,28 @@ public class PythonBroadcastStateTransformationTranslator<IN1, IN2, OUT>
         Preconditions.checkNotNull(transformation);
         Preconditions.checkNotNull(context);
 
-        ExternalPythonCoProcessOperator<IN1, IN2, OUT> operator =
-                new ExternalPythonCoProcessOperator<>(
-                        transformation.getConfiguration(),
-                        transformation.getDataStreamPythonFunctionInfo(),
-                        transformation.getRegularInput().getOutputType(),
-                        transformation.getBroadcastInput().getOutputType(),
-                        transformation.getOutputType());
+        Configuration config = transformation.getConfiguration();
+
+        StreamOperator<OUT> operator;
+
+        if (config.get(PythonOptions.PYTHON_EXECUTION_MODE).equals("thread")) {
+            operator =
+                    new EmbeddedPythonCoProcessOperator<>(
+                            transformation.getConfiguration(),
+                            transformation.getDataStreamPythonFunctionInfo(),
+                            transformation.getRegularInput().getOutputType(),
+                            transformation.getBroadcastInput().getOutputType(),
+                            transformation.getOutputType());
+        } else {
+
+            operator =
+                    new ExternalPythonCoProcessOperator<>(
+                            transformation.getConfiguration(),
+                            transformation.getDataStreamPythonFunctionInfo(),
+                            transformation.getRegularInput().getOutputType(),
+                            transformation.getBroadcastInput().getOutputType(),
+                            transformation.getOutputType());
+        }
 
         return translateInternal(
                 transformation,
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonKeyedBroadcastStateTransformationTranslator.java b/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonKeyedBroadcastStateTransformationTranslator.java
index d24cc1374de..cdbf89c1420 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonKeyedBroadcastStateTransformationTranslator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonKeyedBroadcastStateTransformationTranslator.java
@@ -18,7 +18,12 @@
 package org.apache.flink.streaming.runtime.translators.python;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.python.PythonOptions;
 import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonBatchKeyedCoBroadcastProcessOperator;
+import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonKeyedCoProcessOperator;
 import org.apache.flink.streaming.api.operators.python.process.ExternalPythonBatchKeyedCoBroadcastProcessOperator;
 import org.apache.flink.streaming.api.operators.python.process.ExternalPythonKeyedCoProcessOperator;
 import org.apache.flink.streaming.api.transformations.python.PythonKeyedBroadcastStateTransformation;
@@ -30,8 +35,10 @@ import java.util.Collection;
 
 /**
  * A {@link org.apache.flink.streaming.api.graph.TransformationTranslator} that translates {@link
- * PythonKeyedBroadcastStateTransformation} into {@link ExternalPythonKeyedCoProcessOperator} or
- * {@link ExternalPythonBatchKeyedCoBroadcastProcessOperator}.
+ * PythonKeyedBroadcastStateTransformation} into {@link ExternalPythonKeyedCoProcessOperator}/{@link
+ * EmbeddedPythonKeyedCoProcessOperator} in streaming mode or {@link
+ * ExternalPythonBatchKeyedCoBroadcastProcessOperator}/{@link
+ * EmbeddedPythonBatchKeyedCoBroadcastProcessOperator} in batch mode.
  */
 @Internal
 public class PythonKeyedBroadcastStateTransformationTranslator<OUT>
@@ -44,13 +51,27 @@ public class PythonKeyedBroadcastStateTransformationTranslator<OUT>
         Preconditions.checkNotNull(transformation);
         Preconditions.checkNotNull(context);
 
-        ExternalPythonKeyedCoProcessOperator<OUT> operator =
-                new ExternalPythonBatchKeyedCoBroadcastProcessOperator<>(
-                        transformation.getConfiguration(),
-                        transformation.getDataStreamPythonFunctionInfo(),
-                        transformation.getRegularInput().getOutputType(),
-                        transformation.getBroadcastInput().getOutputType(),
-                        transformation.getOutputType());
+        Configuration config = transformation.getConfiguration();
+
+        StreamOperator<OUT> operator;
+
+        if (config.get(PythonOptions.PYTHON_EXECUTION_MODE).equals("thread")) {
+            operator =
+                    new EmbeddedPythonBatchKeyedCoBroadcastProcessOperator<>(
+                            transformation.getConfiguration(),
+                            transformation.getDataStreamPythonFunctionInfo(),
+                            transformation.getRegularInput().getOutputType(),
+                            transformation.getBroadcastInput().getOutputType(),
+                            transformation.getOutputType());
+        } else {
+            operator =
+                    new ExternalPythonBatchKeyedCoBroadcastProcessOperator<>(
+                            transformation.getConfiguration(),
+                            transformation.getDataStreamPythonFunctionInfo(),
+                            transformation.getRegularInput().getOutputType(),
+                            transformation.getBroadcastInput().getOutputType(),
+                            transformation.getOutputType());
+        }
 
         return translateInternal(
                 transformation,
@@ -69,13 +90,28 @@ public class PythonKeyedBroadcastStateTransformationTranslator<OUT>
         Preconditions.checkNotNull(transformation);
         Preconditions.checkNotNull(context);
 
-        ExternalPythonKeyedCoProcessOperator<OUT> operator =
-                new ExternalPythonKeyedCoProcessOperator<>(
-                        transformation.getConfiguration(),
-                        transformation.getDataStreamPythonFunctionInfo(),
-                        transformation.getRegularInput().getOutputType(),
-                        transformation.getBroadcastInput().getOutputType(),
-                        transformation.getOutputType());
+        Configuration config = transformation.getConfiguration();
+
+        StreamOperator<OUT> operator;
+
+        if (config.get(PythonOptions.PYTHON_EXECUTION_MODE).equals("thread")) {
+            operator =
+                    new EmbeddedPythonKeyedCoProcessOperator<>(
+                            transformation.getConfiguration(),
+                            transformation.getDataStreamPythonFunctionInfo(),
+                            transformation.getRegularInput().getOutputType(),
+                            transformation.getBroadcastInput().getOutputType(),
+                            transformation.getOutputType());
+
+        } else {
+            operator =
+                    new ExternalPythonKeyedCoProcessOperator<>(
+                            transformation.getConfiguration(),
+                            transformation.getDataStreamPythonFunctionInfo(),
+                            transformation.getRegularInput().getOutputType(),
+                            transformation.getBroadcastInput().getOutputType(),
+                            transformation.getOutputType());
+        }
 
         return translateInternal(
                 transformation,