You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by hx...@apache.org on 2022/08/08 09:58:39 UTC
[flink] branch master updated: [FLINK-28836][python] Support broadcast in Thread Mode
This is an automated email from the ASF dual-hosted git repository.
hxb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new ff91aa53cfc [FLINK-28836][python] Support broadcast in Thread Mode
ff91aa53cfc is described below
commit ff91aa53cfc9327ed591cbe0c1b3dcf9116dc3a5
Author: huangxingbo <hx...@apache.org>
AuthorDate: Mon Aug 8 14:44:08 2022 +0800
[FLINK-28836][python] Support broadcast in Thread Mode
This closes #20490.
---
flink-python/pyflink/datastream/data_stream.py | 3 +-
.../pyflink/datastream/tests/test_data_stream.py | 256 ++++++++++-----------
.../fn_execution/datastream/embedded/operations.py | 65 +++++-
.../datastream/embedded/process_function.py | 128 ++++++++++-
.../fn_execution/datastream/embedded/state_impl.py | 87 ++++++-
.../pyflink/fn_execution/embedded/converters.py | 13 +-
.../fn_execution/embedded/operation_utils.py | 12 +-
.../pyflink/fn_execution/embedded/operations.py | 12 +-
...ractOneInputEmbeddedPythonFunctionOperator.java | 4 +-
...ractTwoInputEmbeddedPythonFunctionOperator.java | 4 +-
...eddedPythonBatchCoBroadcastProcessOperator.java | 82 +++++++
...PythonBatchKeyedCoBroadcastProcessOperator.java | 80 +++++++
...thonBroadcastStateTransformationTranslator.java | 68 ++++--
...eyedBroadcastStateTransformationTranslator.java | 68 ++++--
14 files changed, 691 insertions(+), 191 deletions(-)
diff --git a/flink-python/pyflink/datastream/data_stream.py b/flink-python/pyflink/datastream/data_stream.py
index 4e8b5cbb229..6938aa0d4da 100644
--- a/flink-python/pyflink/datastream/data_stream.py
+++ b/flink-python/pyflink/datastream/data_stream.py
@@ -2636,7 +2636,8 @@ class BroadcastConnectedStream(object):
jvm.String, [i.get_name() for i in self.broadcast_state_descriptors]
)
j_state_descriptors = JPythonConfigUtil.convertStateNamesToStateDescriptors(j_state_names)
- j_conf = jvm.org.apache.flink.configuration.Configuration()
+ j_conf = get_j_env_configuration(
+ self.broadcast_stream.input_stream._j_data_stream.getExecutionEnvironment())
j_data_stream_python_function_info = _create_j_data_stream_python_function_info(
func, func_type
)
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py
index c52da695d4e..5b62715e3af 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -398,6 +398,134 @@ class DataStreamTests(object):
"<Row('on_timer', 4)>"]
self.assert_equals_sorted(expected, results)
+ def test_co_broadcast_process(self):
+ ds = self.env.from_collection([1, 2, 3, 4, 5], type_info=Types.INT()) # type: DataStream
+ ds_broadcast = self.env.from_collection(
+ [(0, "a"), (1, "b")], type_info=Types.TUPLE([Types.INT(), Types.STRING()])
+ ) # type: DataStream
+
+ class MyBroadcastProcessFunction(BroadcastProcessFunction):
+ def __init__(self, map_state_desc):
+ self._map_state_desc = map_state_desc
+ self._cache = defaultdict(list)
+
+ def process_element(self, value: int, ctx: BroadcastProcessFunction.ReadOnlyContext):
+ ro_broadcast_state = ctx.get_broadcast_state(self._map_state_desc)
+ key = value % 2
+ if ro_broadcast_state.contains(key):
+ if self._cache.get(key) is not None:
+ for v in self._cache[key]:
+ yield ro_broadcast_state.get(key) + str(v)
+ self._cache[key].clear()
+ yield ro_broadcast_state.get(key) + str(value)
+ else:
+ self._cache[key].append(value)
+
+ def process_broadcast_element(
+ self, value: Tuple[int, str], ctx: BroadcastProcessFunction.Context
+ ):
+ key = value[0]
+ yield str(key) + value[1]
+ broadcast_state = ctx.get_broadcast_state(self._map_state_desc)
+ broadcast_state.put(key, value[1])
+ if self._cache.get(key) is not None:
+ for v in self._cache[key]:
+ yield value[1] + str(v)
+ self._cache[key].clear()
+
+ map_state_desc = MapStateDescriptor(
+ "mapping", key_type_info=Types.INT(), value_type_info=Types.STRING()
+ )
+ ds.connect(ds_broadcast.broadcast(map_state_desc)).process(
+ MyBroadcastProcessFunction(map_state_desc), output_type=Types.STRING()
+ ).add_sink(self.test_sink)
+
+ self.env.execute("test_co_broadcast_process")
+ expected = ["0a", "0a", "1b", "1b", "a2", "a4", "b1", "b3", "b5"]
+ self.assert_equals_sorted(expected, self.test_sink.get_results())
+
+ def test_keyed_co_broadcast_process(self):
+ ds = self.env.from_collection(
+ [(1, '1603708211000'),
+ (2, '1603708212000'),
+ (3, '1603708213000'),
+ (4, '1603708214000')],
+ type_info=Types.ROW([Types.INT(), Types.STRING()])) # type: DataStream
+ ds_broadcast = self.env.from_collection(
+ [(0, '1603708215000', 'a'),
+ (1, '1603708215000', 'b')],
+ type_info=Types.ROW([Types.INT(), Types.STRING(), Types.STRING()])
+ ) # type: DataStream
+ watermark_strategy = WatermarkStrategy.for_monotonous_timestamps() \
+ .with_timestamp_assigner(SecondColumnTimestampAssigner())
+ ds = ds.assign_timestamps_and_watermarks(watermark_strategy)
+ ds_broadcast = ds_broadcast.assign_timestamps_and_watermarks(watermark_strategy)
+
+ def _create_string(s, t):
+ return 'value: {}, ts: {}'.format(s, t)
+
+ class MyKeyedBroadcastProcessFunction(KeyedBroadcastProcessFunction):
+ def __init__(self, map_state_desc):
+ self._map_state_desc = map_state_desc
+ self._cache = None
+
+ def open(self, runtime_context: RuntimeContext):
+ self._cache = defaultdict(list)
+
+ def process_element(
+ self, value: Tuple[int, str], ctx: KeyedBroadcastProcessFunction.ReadOnlyContext
+ ):
+ ro_broadcast_state = ctx.get_broadcast_state(self._map_state_desc)
+ key = value[0] % 2
+ if ro_broadcast_state.contains(key):
+ if self._cache.get(key) is not None:
+ for v in self._cache[key]:
+ yield _create_string(ro_broadcast_state.get(key) + str(v[0]), v[1])
+ self._cache[key].clear()
+ yield _create_string(ro_broadcast_state.get(key) + str(value[0]), value[1])
+ else:
+ self._cache[key].append(value)
+ ctx.timer_service().register_event_time_timer(ctx.timestamp() + 10000)
+
+ def process_broadcast_element(
+ self, value: Tuple[int, str, str], ctx: KeyedBroadcastProcessFunction.Context
+ ):
+ key = value[0]
+ yield _create_string(str(key) + value[2], ctx.timestamp())
+ broadcast_state = ctx.get_broadcast_state(self._map_state_desc)
+ broadcast_state.put(key, value[2])
+ if self._cache.get(key) is not None:
+ for v in self._cache[key]:
+ yield _create_string(value[2] + str(v[0]), v[1])
+ self._cache[key].clear()
+
+ def on_timer(self, timestamp: int, ctx: KeyedBroadcastProcessFunction.OnTimerContext):
+ yield _create_string(ctx.get_current_key(), timestamp)
+
+ map_state_desc = MapStateDescriptor(
+ "mapping", key_type_info=Types.INT(), value_type_info=Types.STRING()
+ )
+ ds.key_by(lambda t: t[0]).connect(ds_broadcast.broadcast(map_state_desc)).process(
+ MyKeyedBroadcastProcessFunction(map_state_desc), output_type=Types.STRING()
+ ).add_sink(self.test_sink)
+
+ self.env.execute("test_keyed_co_broadcast_process")
+ expected = [
+ 'value: 0a, ts: 1603708215000',
+ 'value: 0a, ts: 1603708215000',
+ 'value: 1, ts: 1603708221000',
+ 'value: 1b, ts: 1603708215000',
+ 'value: 1b, ts: 1603708215000',
+ 'value: 2, ts: 1603708222000',
+ 'value: 3, ts: 1603708223000',
+ 'value: 4, ts: 1603708224000',
+ 'value: a2, ts: 1603708212000',
+ 'value: a4, ts: 1603708214000',
+ 'value: b1, ts: 1603708211000',
+ 'value: b3, ts: 1603708213000'
+ ]
+ self.assert_equals_sorted(expected, self.test_sink.get_results())
+
class DataStreamStreamingTests(DataStreamTests):
@@ -1027,134 +1155,6 @@ class ProcessDataStreamTests(DataStreamTests):
side_expected = ['1', '1', '2', '2', '3', '3', '4', '4']
self.assert_equals_sorted(side_expected, side_sink.get_results())
- def test_co_broadcast_process(self):
- ds = self.env.from_collection([1, 2, 3, 4, 5], type_info=Types.INT()) # type: DataStream
- ds_broadcast = self.env.from_collection(
- [(0, "a"), (1, "b")], type_info=Types.TUPLE([Types.INT(), Types.STRING()])
- ) # type: DataStream
-
- class MyBroadcastProcessFunction(BroadcastProcessFunction):
- def __init__(self, map_state_desc):
- self._map_state_desc = map_state_desc
- self._cache = defaultdict(list)
-
- def process_element(self, value: int, ctx: BroadcastProcessFunction.ReadOnlyContext):
- ro_broadcast_state = ctx.get_broadcast_state(self._map_state_desc)
- key = value % 2
- if ro_broadcast_state.contains(key):
- if self._cache.get(key) is not None:
- for v in self._cache[key]:
- yield ro_broadcast_state.get(key) + str(v)
- self._cache[key].clear()
- yield ro_broadcast_state.get(key) + str(value)
- else:
- self._cache[key].append(value)
-
- def process_broadcast_element(
- self, value: Tuple[int, str], ctx: BroadcastProcessFunction.Context
- ):
- key = value[0]
- yield str(key) + value[1]
- broadcast_state = ctx.get_broadcast_state(self._map_state_desc)
- broadcast_state.put(key, value[1])
- if self._cache.get(key) is not None:
- for v in self._cache[key]:
- yield value[1] + str(v)
- self._cache[key].clear()
-
- map_state_desc = MapStateDescriptor(
- "mapping", key_type_info=Types.INT(), value_type_info=Types.STRING()
- )
- ds.connect(ds_broadcast.broadcast(map_state_desc)).process(
- MyBroadcastProcessFunction(map_state_desc), output_type=Types.STRING()
- ).add_sink(self.test_sink)
-
- self.env.execute("test_co_broadcast_process")
- expected = ["0a", "0a", "1b", "1b", "a2", "a4", "b1", "b3", "b5"]
- self.assert_equals_sorted(expected, self.test_sink.get_results())
-
- def test_keyed_co_broadcast_process(self):
- ds = self.env.from_collection(
- [(1, '1603708211000'),
- (2, '1603708212000'),
- (3, '1603708213000'),
- (4, '1603708214000')],
- type_info=Types.ROW([Types.INT(), Types.STRING()])) # type: DataStream
- ds_broadcast = self.env.from_collection(
- [(0, '1603708215000', 'a'),
- (1, '1603708215000', 'b')],
- type_info=Types.ROW([Types.INT(), Types.STRING(), Types.STRING()])
- ) # type: DataStream
- watermark_strategy = WatermarkStrategy.for_monotonous_timestamps() \
- .with_timestamp_assigner(SecondColumnTimestampAssigner())
- ds = ds.assign_timestamps_and_watermarks(watermark_strategy)
- ds_broadcast = ds_broadcast.assign_timestamps_and_watermarks(watermark_strategy)
-
- def _create_string(s, t):
- return 'value: {}, ts: {}'.format(s, t)
-
- class MyKeyedBroadcastProcessFunction(KeyedBroadcastProcessFunction):
- def __init__(self, map_state_desc):
- self._map_state_desc = map_state_desc
- self._cache = None
-
- def open(self, runtime_context: RuntimeContext):
- self._cache = defaultdict(list)
-
- def process_element(
- self, value: Tuple[int, str], ctx: KeyedBroadcastProcessFunction.ReadOnlyContext
- ):
- ro_broadcast_state = ctx.get_broadcast_state(self._map_state_desc)
- key = value[0] % 2
- if ro_broadcast_state.contains(key):
- if self._cache.get(key) is not None:
- for v in self._cache[key]:
- yield _create_string(ro_broadcast_state.get(key) + str(v[0]), v[1])
- self._cache[key].clear()
- yield _create_string(ro_broadcast_state.get(key) + str(value[0]), value[1])
- else:
- self._cache[key].append(value)
- ctx.timer_service().register_event_time_timer(ctx.timestamp() + 10000)
-
- def process_broadcast_element(
- self, value: Tuple[int, str, str], ctx: KeyedBroadcastProcessFunction.Context
- ):
- key = value[0]
- yield _create_string(str(key) + value[2], ctx.timestamp())
- broadcast_state = ctx.get_broadcast_state(self._map_state_desc)
- broadcast_state.put(key, value[2])
- if self._cache.get(key) is not None:
- for v in self._cache[key]:
- yield _create_string(value[2] + str(v[0]), v[1])
- self._cache[key].clear()
-
- def on_timer(self, timestamp: int, ctx: KeyedBroadcastProcessFunction.OnTimerContext):
- yield _create_string(ctx.get_current_key(), timestamp)
-
- map_state_desc = MapStateDescriptor(
- "mapping", key_type_info=Types.INT(), value_type_info=Types.STRING()
- )
- ds.key_by(lambda t: t[0]).connect(ds_broadcast.broadcast(map_state_desc)).process(
- MyKeyedBroadcastProcessFunction(map_state_desc), output_type=Types.STRING()
- ).add_sink(self.test_sink)
-
- self.env.execute("test_keyed_co_broadcast_process")
- expected = [
- 'value: 0a, ts: 1603708215000',
- 'value: 0a, ts: 1603708215000',
- 'value: 1, ts: 1603708221000',
- 'value: 1b, ts: 1603708215000',
- 'value: 1b, ts: 1603708215000',
- 'value: 2, ts: 1603708222000',
- 'value: 3, ts: 1603708223000',
- 'value: 4, ts: 1603708224000',
- 'value: a2, ts: 1603708212000',
- 'value: a4, ts: 1603708214000',
- 'value: b1, ts: 1603708211000',
- 'value: b3, ts: 1603708213000'
- ]
- self.assert_equals_sorted(expected, self.test_sink.get_results())
-
class ProcessDataStreamStreamingTests(DataStreamStreamingTests, ProcessDataStreamTests,
PyFlinkStreamingTestCase):
diff --git a/flink-python/pyflink/fn_execution/datastream/embedded/operations.py b/flink-python/pyflink/fn_execution/datastream/embedded/operations.py
index 69757e7f128..beef4be98a8 100644
--- a/flink-python/pyflink/fn_execution/datastream/embedded/operations.py
+++ b/flink-python/pyflink/fn_execution/datastream/embedded/operations.py
@@ -21,7 +21,11 @@ from pyflink.fn_execution.coders import TimeWindowCoder, CountWindowCoder
from pyflink.fn_execution.datastream import operations
from pyflink.fn_execution.datastream.embedded.process_function import (
InternalProcessFunctionContext, InternalKeyedProcessFunctionContext,
- InternalKeyedProcessFunctionOnTimerContext, InternalWindowTimerContext)
+ InternalKeyedProcessFunctionOnTimerContext, InternalWindowTimerContext,
+ InternalBroadcastProcessFunctionContext, InternalBroadcastProcessFunctionReadOnlyContext,
+ InternalKeyedBroadcastProcessFunctionContext,
+ InternalKeyedBroadcastProcessFunctionReadOnlyContext,
+ InternalKeyedBroadcastProcessFunctionOnTimerContext)
from pyflink.fn_execution.datastream.embedded.runtime_context import StreamingRuntimeContext
from pyflink.fn_execution.datastream.embedded.timerservice_impl import InternalTimerServiceImpl
from pyflink.fn_execution.datastream.window.window_operator import WindowOperator
@@ -80,7 +84,7 @@ class TwoInputOperation(operations.TwoInputOperation):
def extract_process_function(
user_defined_function_proto, j_runtime_context, j_function_context, j_timer_context,
- job_parameters, j_keyed_state_backend):
+ job_parameters, j_keyed_state_backend, j_operator_state_backend):
from pyflink.fn_execution import flink_fn_execution_pb2
user_defined_func = pickle.loads(user_defined_function_proto.payload)
@@ -151,6 +155,31 @@ def extract_process_function(
return TwoInputOperation(
open_func, close_func, process_element_func1, process_element_func2)
+ elif func_type == UserDefinedDataStreamFunction.CO_BROADCAST_PROCESS:
+
+ broadcast_ctx = InternalBroadcastProcessFunctionContext(
+ j_function_context, j_operator_state_backend)
+
+ read_only_broadcast_ctx = InternalBroadcastProcessFunctionReadOnlyContext(
+ j_function_context, j_operator_state_backend)
+
+ process_element = user_defined_func.process_element
+
+ process_broadcast_element = user_defined_func.process_broadcast_element
+
+ def process_element_func1(value):
+ elements = process_element(value, read_only_broadcast_ctx)
+ if elements:
+ yield from elements
+
+ def process_element_func2(value):
+ elements = process_broadcast_element(value, broadcast_ctx)
+ if elements:
+ yield from elements
+
+ return TwoInputOperation(
+ open_func, close_func, process_element_func1, process_element_func2)
+
elif func_type == UserDefinedDataStreamFunction.KEYED_CO_PROCESS:
function_context = InternalKeyedProcessFunctionContext(
@@ -185,6 +214,38 @@ def extract_process_function(
return TwoInputOperation(
open_func, close_func, process_element_func1, process_element_func2, on_timer_func)
+ elif func_type == UserDefinedDataStreamFunction.KEYED_CO_BROADCAST_PROCESS:
+ broadcast_ctx = InternalKeyedBroadcastProcessFunctionContext(
+ j_function_context, j_operator_state_backend)
+
+ read_only_broadcast_ctx = InternalKeyedBroadcastProcessFunctionReadOnlyContext(
+ j_function_context, user_defined_function_proto.key_type_info, j_operator_state_backend)
+
+ timer_context = InternalKeyedBroadcastProcessFunctionOnTimerContext(
+ j_timer_context, user_defined_function_proto.key_type_info, j_operator_state_backend)
+
+ process_element = user_defined_func.process_element
+
+ process_broadcast_element = user_defined_func.process_broadcast_element
+
+ on_timer = user_defined_func.on_timer
+
+ def process_element_func1(value):
+ elements = process_element(value[1], read_only_broadcast_ctx)
+ if elements:
+ yield from elements
+
+ def process_element_func2(value):
+ elements = process_broadcast_element(value, broadcast_ctx)
+ if elements:
+ yield from elements
+
+ def on_timer_func(timestamp):
+ yield from on_timer(timestamp, timer_context)
+
+ return TwoInputOperation(
+ open_func, close_func, process_element_func1, process_element_func2, on_timer_func)
+
elif func_type == UserDefinedDataStreamFunction.WINDOW:
window_operation_descriptor = (
diff --git a/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py b/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py
index 9cccce3ae94..b23b2b3110f 100644
--- a/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py
+++ b/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py
@@ -15,10 +15,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
+from abc import ABC
+
from pyflink.datastream import (ProcessFunction, KeyedProcessFunction, CoProcessFunction,
KeyedCoProcessFunction, TimerService, TimeDomain)
+from pyflink.datastream.functions import (BaseBroadcastProcessFunction, BroadcastProcessFunction,
+ KeyedBroadcastProcessFunction)
+from pyflink.datastream.state import MapStateDescriptor, BroadcastState, ReadOnlyBroadcastState
+from pyflink.fn_execution.datastream.embedded.state_impl import (ReadOnlyBroadcastStateImpl,
+ BroadcastStateImpl)
from pyflink.fn_execution.datastream.embedded.timerservice_impl import TimerServiceImpl
-from pyflink.fn_execution.embedded.converters import from_type_info
+from pyflink.fn_execution.embedded.converters import from_type_info_proto, from_type_info
+from pyflink.fn_execution.embedded.java_utils import to_java_state_descriptor
class InternalProcessFunctionContext(ProcessFunction.Context, CoProcessFunction.Context,
@@ -106,3 +114,121 @@ class InternalWindowTimerContext(object):
def get_current_key(self):
return self._key_converter.to_internal(self._context.getCurrentKey())
+
+
+class InternalBaseBroadcastProcessFunctionContext(BaseBroadcastProcessFunction.Context, ABC):
+
+ def __init__(self, context, operator_state_backend):
+ self._context = context
+ self._operator_state_backend = operator_state_backend
+
+ def timestamp(self) -> int:
+ return self._context.timestamp()
+
+ def current_processing_time(self) -> int:
+ return self._context.currentProcessingTime()
+
+ def current_watermark(self) -> int:
+ return self._context.currentWatermark()
+
+
+class InternalBroadcastProcessFunctionContext(InternalBaseBroadcastProcessFunctionContext,
+ BroadcastProcessFunction.Context):
+
+ def __init__(self, context, operator_state_backend):
+ super(InternalBroadcastProcessFunctionContext, self).__init__(
+ context, operator_state_backend)
+
+ def get_broadcast_state(self, state_descriptor: MapStateDescriptor) -> BroadcastState:
+ return BroadcastStateImpl(
+ self._operator_state_backend.getBroadcastState(
+ to_java_state_descriptor(state_descriptor)),
+ from_type_info(state_descriptor.type_info))
+
+
+class InternalBroadcastProcessFunctionReadOnlyContext(InternalBaseBroadcastProcessFunctionContext,
+ BroadcastProcessFunction.ReadOnlyContext):
+
+ def __init__(self, context, operator_state_backend):
+ super(InternalBroadcastProcessFunctionReadOnlyContext, self).__init__(
+ context, operator_state_backend)
+
+ def get_broadcast_state(self, state_descriptor: MapStateDescriptor) -> ReadOnlyBroadcastState:
+ return ReadOnlyBroadcastStateImpl(
+ self._operator_state_backend.getBroadcastState(
+ to_java_state_descriptor(state_descriptor)),
+ from_type_info(state_descriptor.type_info))
+
+
+class InternalKeyedBroadcastProcessFunctionContext(InternalBaseBroadcastProcessFunctionContext,
+ KeyedBroadcastProcessFunction.Context):
+
+ def __init__(self, context, operator_state_backend):
+ super(InternalKeyedBroadcastProcessFunctionContext, self).__init__(
+ context, operator_state_backend)
+
+ def get_broadcast_state(self, state_descriptor: MapStateDescriptor) -> BroadcastState:
+ return BroadcastStateImpl(
+ self._operator_state_backend.getBroadcastState(
+ to_java_state_descriptor(state_descriptor)),
+ from_type_info(state_descriptor.type_info))
+
+
+class InternalKeyedBroadcastProcessFunctionReadOnlyContext(
+ InternalBaseBroadcastProcessFunctionContext,
+ KeyedBroadcastProcessFunction.ReadOnlyContext
+):
+
+ def __init__(self, context, key_type_info, operator_state_backend):
+ super(InternalKeyedBroadcastProcessFunctionReadOnlyContext, self).__init__(
+ context, operator_state_backend)
+ self._key_converter = from_type_info_proto(key_type_info)
+ self._timer_service = TimerServiceImpl(self._context.timerService())
+
+ def get_broadcast_state(self, state_descriptor: MapStateDescriptor) -> ReadOnlyBroadcastState:
+ return ReadOnlyBroadcastStateImpl(
+ self._operator_state_backend.getBroadcastState(
+ to_java_state_descriptor(state_descriptor)),
+ from_type_info(state_descriptor.type_info))
+
+ def timer_service(self) -> TimerService:
+ return self._timer_service
+
+ def get_current_key(self):
+ return self._key_converter.to_internal(self._context.getCurrentKey())
+
+
+class InternalKeyedBroadcastProcessFunctionOnTimerContext(
+ InternalBaseBroadcastProcessFunctionContext,
+ KeyedBroadcastProcessFunction.OnTimerContext,
+):
+
+ def __init__(self, context, key_type_info, operator_state_backend):
+ super(InternalKeyedBroadcastProcessFunctionOnTimerContext, self).__init__(
+ context, operator_state_backend)
+ self._timer_service = TimerServiceImpl(self._context.timerService())
+ self._key_converter = from_type_info_proto(key_type_info)
+
+ def get_broadcast_state(self, state_descriptor: MapStateDescriptor) -> ReadOnlyBroadcastState:
+ return ReadOnlyBroadcastStateImpl(
+ self._operator_state_backend.getBroadcastState(
+ to_java_state_descriptor(state_descriptor)),
+ from_type_info(state_descriptor.type_info))
+
+ def current_processing_time(self) -> int:
+ return self._timer_service.current_processing_time()
+
+ def current_watermark(self) -> int:
+ return self._timer_service.current_watermark()
+
+ def timer_service(self) -> TimerService:
+ return self._timer_service
+
+ def timestamp(self) -> int:
+ return self._context.timestamp()
+
+ def time_domain(self) -> TimeDomain:
+ return TimeDomain(self._context.timeDomain())
+
+ def get_current_key(self):
+ return self._key_converter.to_internal(self._context.getCurrentKey())
diff --git a/flink-python/pyflink/fn_execution/datastream/embedded/state_impl.py b/flink-python/pyflink/fn_execution/datastream/embedded/state_impl.py
index 6973be038eb..8eb28eaf380 100644
--- a/flink-python/pyflink/fn_execution/datastream/embedded/state_impl.py
+++ b/flink-python/pyflink/fn_execution/datastream/embedded/state_impl.py
@@ -19,33 +19,42 @@ from abc import ABC
from typing import List, Iterable, Tuple, Dict, Collection
from pyflink.datastream import ReduceFunction, AggregateFunction
-from pyflink.datastream.state import (T, IN, OUT, V, K)
+from pyflink.datastream.state import (T, IN, OUT, V, K, State)
from pyflink.fn_execution.embedded.converters import (DataConverter, DictDataConverter,
ListDataConverter)
from pyflink.fn_execution.internal_state import (InternalValueState, InternalKvState,
InternalListState, InternalReducingState,
InternalAggregatingState, InternalMapState,
- N)
+ N, InternalReadOnlyBroadcastState,
+ InternalBroadcastState)
-class StateImpl(InternalKvState, ABC):
+class StateImpl(State, ABC):
def __init__(self,
state,
- value_converter: DataConverter,
- window_converter: DataConverter = None):
+ value_converter: DataConverter):
self._state = state
self._value_converter = value_converter
- self._window_converter = window_converter
def clear(self):
self._state.clear()
+
+class KeyedStateImpl(StateImpl, InternalKvState, ABC):
+
+ def __init__(self,
+ state,
+ value_converter: DataConverter,
+ window_converter: DataConverter = None):
+ super(KeyedStateImpl, self).__init__(state, value_converter)
+ self._window_converter = window_converter
+
def set_current_namespace(self, namespace) -> None:
j_window = self._window_converter.to_external(namespace)
self._state.setCurrentNamespace(j_window)
-class ValueStateImpl(StateImpl, InternalValueState):
+class ValueStateImpl(KeyedStateImpl, InternalValueState):
def __init__(self,
value_state,
value_converter: DataConverter,
@@ -59,7 +68,7 @@ class ValueStateImpl(StateImpl, InternalValueState):
self._state.update(self._value_converter.to_external(value))
-class ListStateImpl(StateImpl, InternalListState):
+class ListStateImpl(KeyedStateImpl, InternalListState):
def __init__(self,
list_state,
@@ -88,7 +97,7 @@ class ListStateImpl(StateImpl, InternalListState):
self._state.mergeNamespaces(j_target, j_sources)
-class ReducingStateImpl(StateImpl, InternalReducingState):
+class ReducingStateImpl(KeyedStateImpl, InternalReducingState):
def __init__(self,
value_state,
@@ -135,7 +144,7 @@ class ReducingStateImpl(StateImpl, InternalReducingState):
self._state.update(self._value_converter.to_external(merged))
-class AggregatingStateImpl(StateImpl, InternalAggregatingState):
+class AggregatingStateImpl(KeyedStateImpl, InternalAggregatingState):
def __init__(self,
value_state,
value_converter,
@@ -185,7 +194,7 @@ class AggregatingStateImpl(StateImpl, InternalAggregatingState):
self._state.update(self._value_converter.to_external(merged))
-class MapStateImpl(StateImpl, InternalMapState):
+class MapStateImpl(KeyedStateImpl, InternalMapState):
def __init__(self,
map_state,
map_converter: DictDataConverter,
@@ -231,3 +240,59 @@ class MapStateImpl(StateImpl, InternalMapState):
def is_empty(self) -> bool:
return self._state.isEmpty()
+
+
+class ReadOnlyBroadcastStateImpl(StateImpl, InternalReadOnlyBroadcastState):
+
+ def __init__(self,
+ map_state,
+ map_converter: DictDataConverter):
+ super(ReadOnlyBroadcastStateImpl, self).__init__(map_state, map_converter)
+ self._k_converter = map_converter._key_converter
+ self._v_converter = map_converter._value_converter
+
+ def get(self, key: K) -> V:
+ return self._v_converter.to_internal(
+ self._state.get(self._k_converter.to_external(key)))
+
+ def contains(self, key: K) -> bool:
+ return self._state.contains(self._k_converter.to_external(key))
+
+ def items(self) -> Iterable[Tuple[K, V]]:
+ entries = self._state.entries()
+ for entry in entries:
+ yield (self._k_converter.to_internal(entry.getKey()),
+ self._v_converter.to_internal(entry.getValue()))
+
+ def keys(self) -> Iterable[K]:
+ for k in self._state.keys():
+ yield self._k_converter.to_internal(k)
+
+ def values(self) -> Iterable[V]:
+ for v in self._state.values():
+ yield self._v_converter.to_internal(v)
+
+ def is_empty(self) -> bool:
+ return self._state.isEmpty()
+
+
+class BroadcastStateImpl(ReadOnlyBroadcastStateImpl, InternalBroadcastState):
+ def __init__(self,
+ map_state,
+ map_converter: DictDataConverter):
+ super(BroadcastStateImpl, self).__init__(map_state, map_converter)
+ self._map_converter = map_converter
+ self._k_converter = map_converter._key_converter
+ self._v_converter = map_converter._value_converter
+
+ def to_read_only_broadcast_state(self) -> InternalReadOnlyBroadcastState[K, V]:
+ return ReadOnlyBroadcastStateImpl(self._state, self._map_converter)
+
+ def put(self, key: K, value: V) -> None:
+ self._state.put(self._k_converter.to_external(key), self._v_converter.to_external(value))
+
+ def put_all(self, dict_value: Dict[K, V]) -> None:
+ self._state.putAll(self._value_converter.to_external(dict_value))
+
+ def remove(self, key: K) -> None:
+ self._state.remove(self._k_converter.to_external(key))
diff --git a/flink-python/pyflink/fn_execution/embedded/converters.py b/flink-python/pyflink/fn_execution/embedded/converters.py
index d2ba951f884..8829d48074e 100644
--- a/flink-python/pyflink/fn_execution/embedded/converters.py
+++ b/flink-python/pyflink/fn_execution/embedded/converters.py
@@ -95,18 +95,19 @@ class RowDataConverter(DataConverter):
def __init__(self, field_data_converters: List[DataConverter], field_names: List[str]):
self._field_data_converters = field_data_converters
- self._reuse_row = Row()
- self._reuse_row.set_field_names(field_names)
+ self._field_names = field_names
def to_internal(self, value) -> IN:
if value is None:
return None
- self._reuse_row._values = [self._field_data_converters[i].to_internal(item)
- for i, item in enumerate(value[1])]
- self._reuse_row.set_row_kind(RowKind(value[0]))
+ row = Row()
+ row._values = [self._field_data_converters[i].to_internal(item)
+ for i, item in enumerate(value[1])]
+ row.set_field_names(self._field_names)
+ row.set_row_kind(RowKind(value[0]))
- return self._reuse_row
+ return row
def to_external(self, value: Row) -> OUT:
if value is None:
diff --git a/flink-python/pyflink/fn_execution/embedded/operation_utils.py b/flink-python/pyflink/fn_execution/embedded/operation_utils.py
index 0c0580671ec..0dadbb359b9 100644
--- a/flink-python/pyflink/fn_execution/embedded/operation_utils.py
+++ b/flink-python/pyflink/fn_execution/embedded/operation_utils.py
@@ -103,7 +103,8 @@ def create_table_operation_from_proto(proto, input_coder_info, output_coder_into
def create_one_input_user_defined_data_stream_function_from_protos(
function_infos, input_coder_info, output_coder_info, runtime_context,
- function_context, timer_context, job_parameters, keyed_state_backend):
+ function_context, timer_context, job_parameters, keyed_state_backend,
+ operator_state_backend):
serialized_fns = [pare_user_defined_data_stream_function_proto(proto)
for proto in function_infos]
input_data_converter = (
@@ -119,14 +120,16 @@ def create_one_input_user_defined_data_stream_function_from_protos(
function_context,
timer_context,
job_parameters,
- keyed_state_backend)
+ keyed_state_backend,
+ operator_state_backend)
return function_operation
def create_two_input_user_defined_data_stream_function_from_protos(
function_infos, input_coder_info1, input_coder_info2, output_coder_info, runtime_context,
- function_context, timer_context, job_parameters, keyed_state_backend):
+ function_context, timer_context, job_parameters, keyed_state_backend,
+ operator_state_backend):
serialized_fns = [pare_user_defined_data_stream_function_proto(proto)
for proto in function_infos]
@@ -148,6 +151,7 @@ def create_two_input_user_defined_data_stream_function_from_protos(
function_context,
timer_context,
job_parameters,
- keyed_state_backend)
+ keyed_state_backend,
+ operator_state_backend)
return function_operation
diff --git a/flink-python/pyflink/fn_execution/embedded/operations.py b/flink-python/pyflink/fn_execution/embedded/operations.py
index 88eee565224..be7771f8aca 100644
--- a/flink-python/pyflink/fn_execution/embedded/operations.py
+++ b/flink-python/pyflink/fn_execution/embedded/operations.py
@@ -67,7 +67,8 @@ class OneInputFunctionOperation(FunctionOperation):
function_context,
timer_context,
job_parameters,
- keyed_state_backend):
+ keyed_state_backend,
+ operator_state_backend):
operations = (
[extract_process_function(
serialized_fn,
@@ -75,7 +76,8 @@ class OneInputFunctionOperation(FunctionOperation):
function_context,
timer_context,
job_parameters,
- keyed_state_backend)
+ keyed_state_backend,
+ operator_state_backend)
for serialized_fn in serialized_fns])
super(OneInputFunctionOperation, self).__init__(operations, output_data_converter)
self._input_data_converter = input_data_converter
@@ -99,7 +101,8 @@ class TwoInputFunctionOperation(FunctionOperation):
function_context,
timer_context,
job_parameters,
- keyed_state_backend):
+ keyed_state_backend,
+ operator_state_backend):
operations = (
[extract_process_function(
serialized_fn,
@@ -107,7 +110,8 @@ class TwoInputFunctionOperation(FunctionOperation):
function_context,
timer_context,
job_parameters,
- keyed_state_backend)
+ keyed_state_backend,
+ operator_state_backend)
for serialized_fn in serialized_fns])
super(TwoInputFunctionOperation, self).__init__(operations, output_data_converter)
self._input_data_converter1 = input_data_converter1
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java
index 296f0e5a697..b02f3bfcb32 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java
@@ -116,6 +116,7 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
interpreter.set("timer_context", getTimerContext());
interpreter.set("job_parameters", getJobParameters());
interpreter.set("keyed_state_backend", getKeyedStateBackend());
+ interpreter.set("operator_state_backend", getOperatorStateBackend());
interpreter.exec(
"from pyflink.fn_execution.embedded.operation_utils import create_one_input_user_defined_data_stream_function_from_protos");
@@ -129,7 +130,8 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
+ "function_context,"
+ "timer_context,"
+ "job_parameters,"
- + "keyed_state_backend)");
+ + "keyed_state_backend,"
+ + "operator_state_backend)");
interpreter.invokeMethod("operation", "open");
}
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractTwoInputEmbeddedPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractTwoInputEmbeddedPythonFunctionOperator.java
index 232ccda4575..b0767a309fe 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractTwoInputEmbeddedPythonFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractTwoInputEmbeddedPythonFunctionOperator.java
@@ -137,6 +137,7 @@ public abstract class AbstractTwoInputEmbeddedPythonFunctionOperator<IN1, IN2, O
interpreter.set("timer_context", getTimerContext());
interpreter.set("keyed_state_backend", getKeyedStateBackend());
interpreter.set("job_parameters", getJobParameters());
+ interpreter.set("operator_state_backend", getOperatorStateBackend());
interpreter.exec(
"from pyflink.fn_execution.embedded.operation_utils import create_two_input_user_defined_data_stream_function_from_protos");
@@ -151,7 +152,8 @@ public abstract class AbstractTwoInputEmbeddedPythonFunctionOperator<IN1, IN2, O
+ "function_context,"
+ "timer_context,"
+ "job_parameters,"
- + "keyed_state_backend)");
+ + "keyed_state_backend,"
+ + "operator_state_backend)");
interpreter.invokeMethod("operation", "open");
}
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonBatchCoBroadcastProcessOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonBatchCoBroadcastProcessOperator.java
new file mode 100644
index 00000000000..bd7ebfcdc6d
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonBatchCoBroadcastProcessOperator.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators.python.embedded;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.InputSelectable;
+import org.apache.flink.streaming.api.operators.InputSelection;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * The {@link EmbeddedPythonBatchCoBroadcastProcessOperator} is responsible for executing the Python
+ * CoBroadcastProcess function under BATCH mode; {@link EmbeddedPythonCoProcessOperator} is used
+ * under STREAMING mode. This operator consumes all data from the broadcast side first, and then
+ * processes data from the regular side.
+ *
+ * @param <IN1> The input type of the regular stream
+ * @param <IN2> The input type of the broadcast stream
+ * @param <OUT> The output type of the CoBroadcastProcess function
+ */
+@Internal
+public class EmbeddedPythonBatchCoBroadcastProcessOperator<IN1, IN2, OUT>
+ extends EmbeddedPythonCoProcessOperator<IN1, IN2, OUT>
+ implements BoundedMultiInput, InputSelectable {
+
+ private static final long serialVersionUID = 1L;
+
+ private transient volatile boolean isBroadcastSideDone = false;
+
+ public EmbeddedPythonBatchCoBroadcastProcessOperator(
+ Configuration config,
+ DataStreamPythonFunctionInfo pythonFunctionInfo,
+ TypeInformation<IN1> inputTypeInfo1,
+ TypeInformation<IN2> inputTypeInfo2,
+ TypeInformation<OUT> outputTypeInfo) {
+ super(config, pythonFunctionInfo, inputTypeInfo1, inputTypeInfo2, outputTypeInfo);
+ }
+
+ @Override
+ public void endInput(int inputId) throws Exception {
+ if (inputId == 2) {
+ isBroadcastSideDone = true;
+ }
+ }
+
+ @Override
+ public InputSelection nextSelection() {
+ if (!isBroadcastSideDone) {
+ return InputSelection.SECOND;
+ } else {
+ return InputSelection.FIRST;
+ }
+ }
+
+ @Override
+ public void processElement1(StreamRecord<IN1> element) throws Exception {
+ Preconditions.checkState(
+ isBroadcastSideDone,
+ "Should not process regular input before broadcast side is done.");
+
+ super.processElement1(element);
+ }
+}
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonBatchKeyedCoBroadcastProcessOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonBatchKeyedCoBroadcastProcessOperator.java
new file mode 100644
index 00000000000..a5335780756
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonBatchKeyedCoBroadcastProcessOperator.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators.python.embedded;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.InputSelectable;
+import org.apache.flink.streaming.api.operators.InputSelection;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * The {@link EmbeddedPythonBatchKeyedCoBroadcastProcessOperator} is responsible for executing the
+ * Python CoBroadcastProcess function under BATCH mode; {@link EmbeddedPythonKeyedCoProcessOperator}
+ * is used under STREAMING mode. This operator consumes all data from the broadcast side first, and
+ * then processes data from the regular side.
+ *
+ * @param <OUT> The output type of the CoBroadcastProcess function
+ */
+@Internal
+public class EmbeddedPythonBatchKeyedCoBroadcastProcessOperator<K, IN1, IN2, OUT>
+ extends EmbeddedPythonKeyedCoProcessOperator<K, IN1, IN2, OUT>
+ implements BoundedMultiInput, InputSelectable {
+
+ private static final long serialVersionUID = 1L;
+
+ private transient volatile boolean isBroadcastSideDone = false;
+
+ public EmbeddedPythonBatchKeyedCoBroadcastProcessOperator(
+ Configuration config,
+ DataStreamPythonFunctionInfo pythonFunctionInfo,
+ TypeInformation<IN1> inputTypeInfo1,
+ TypeInformation<IN2> inputTypeInfo2,
+ TypeInformation<OUT> outputTypeInfo) {
+ super(config, pythonFunctionInfo, inputTypeInfo1, inputTypeInfo2, outputTypeInfo);
+ }
+
+ @Override
+ public void endInput(int inputId) throws Exception {
+ if (inputId == 2) {
+ isBroadcastSideDone = true;
+ }
+ }
+
+ @Override
+ public InputSelection nextSelection() {
+ if (!isBroadcastSideDone) {
+ return InputSelection.SECOND;
+ } else {
+ return InputSelection.FIRST;
+ }
+ }
+
+ @Override
+ public void processElement1(StreamRecord<IN1> element) throws Exception {
+ Preconditions.checkState(
+ isBroadcastSideDone,
+ "Should not process regular input before broadcast side is done.");
+
+ super.processElement1(element);
+ }
+}
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonBroadcastStateTransformationTranslator.java b/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonBroadcastStateTransformationTranslator.java
index 5854fb096ed..32d9a259754 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonBroadcastStateTransformationTranslator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonBroadcastStateTransformationTranslator.java
@@ -18,7 +18,12 @@
package org.apache.flink.streaming.runtime.translators.python;
import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.python.PythonOptions;
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonBatchCoBroadcastProcessOperator;
+import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonCoProcessOperator;
import org.apache.flink.streaming.api.operators.python.process.ExternalPythonBatchCoBroadcastProcessOperator;
import org.apache.flink.streaming.api.operators.python.process.ExternalPythonCoProcessOperator;
import org.apache.flink.streaming.api.transformations.python.PythonBroadcastStateTransformation;
@@ -29,8 +34,10 @@ import java.util.Collection;
/**
* A {@link org.apache.flink.streaming.api.graph.TransformationTranslator} that translates {@link
- * PythonBroadcastStateTransformation} into {@link ExternalPythonCoProcessOperator} or {@link
- * ExternalPythonBatchCoBroadcastProcessOperator}.
+ * PythonBroadcastStateTransformation} into {@link ExternalPythonCoProcessOperator}/{@link
+ * EmbeddedPythonCoProcessOperator} in streaming mode or {@link
+ * ExternalPythonBatchCoBroadcastProcessOperator}/{@link
+ * EmbeddedPythonBatchCoBroadcastProcessOperator} in batch mode.
*/
@Internal
public class PythonBroadcastStateTransformationTranslator<IN1, IN2, OUT>
@@ -43,13 +50,27 @@ public class PythonBroadcastStateTransformationTranslator<IN1, IN2, OUT>
Preconditions.checkNotNull(transformation);
Preconditions.checkNotNull(context);
- ExternalPythonBatchCoBroadcastProcessOperator operator =
- new ExternalPythonBatchCoBroadcastProcessOperator(
- transformation.getConfiguration(),
- transformation.getDataStreamPythonFunctionInfo(),
- transformation.getRegularInput().getOutputType(),
- transformation.getBroadcastInput().getOutputType(),
- transformation.getOutputType());
+ Configuration config = transformation.getConfiguration();
+
+ StreamOperator<OUT> operator;
+
+ if (config.get(PythonOptions.PYTHON_EXECUTION_MODE).equals("thread")) {
+ operator =
+ new EmbeddedPythonBatchCoBroadcastProcessOperator<>(
+ transformation.getConfiguration(),
+ transformation.getDataStreamPythonFunctionInfo(),
+ transformation.getRegularInput().getOutputType(),
+ transformation.getBroadcastInput().getOutputType(),
+ transformation.getOutputType());
+ } else {
+ operator =
+ new ExternalPythonBatchCoBroadcastProcessOperator<>(
+ transformation.getConfiguration(),
+ transformation.getDataStreamPythonFunctionInfo(),
+ transformation.getRegularInput().getOutputType(),
+ transformation.getBroadcastInput().getOutputType(),
+ transformation.getOutputType());
+ }
return translateInternal(
transformation,
@@ -68,13 +89,28 @@ public class PythonBroadcastStateTransformationTranslator<IN1, IN2, OUT>
Preconditions.checkNotNull(transformation);
Preconditions.checkNotNull(context);
- ExternalPythonCoProcessOperator<IN1, IN2, OUT> operator =
- new ExternalPythonCoProcessOperator<>(
- transformation.getConfiguration(),
- transformation.getDataStreamPythonFunctionInfo(),
- transformation.getRegularInput().getOutputType(),
- transformation.getBroadcastInput().getOutputType(),
- transformation.getOutputType());
+ Configuration config = transformation.getConfiguration();
+
+ StreamOperator<OUT> operator;
+
+ if (config.get(PythonOptions.PYTHON_EXECUTION_MODE).equals("thread")) {
+ operator =
+ new EmbeddedPythonCoProcessOperator<>(
+ transformation.getConfiguration(),
+ transformation.getDataStreamPythonFunctionInfo(),
+ transformation.getRegularInput().getOutputType(),
+ transformation.getBroadcastInput().getOutputType(),
+ transformation.getOutputType());
+ } else {
+
+ operator =
+ new ExternalPythonCoProcessOperator<>(
+ transformation.getConfiguration(),
+ transformation.getDataStreamPythonFunctionInfo(),
+ transformation.getRegularInput().getOutputType(),
+ transformation.getBroadcastInput().getOutputType(),
+ transformation.getOutputType());
+ }
return translateInternal(
transformation,
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonKeyedBroadcastStateTransformationTranslator.java b/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonKeyedBroadcastStateTransformationTranslator.java
index d24cc1374de..cdbf89c1420 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonKeyedBroadcastStateTransformationTranslator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/runtime/translators/python/PythonKeyedBroadcastStateTransformationTranslator.java
@@ -18,7 +18,12 @@
package org.apache.flink.streaming.runtime.translators.python;
import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.python.PythonOptions;
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonBatchKeyedCoBroadcastProcessOperator;
+import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonKeyedCoProcessOperator;
import org.apache.flink.streaming.api.operators.python.process.ExternalPythonBatchKeyedCoBroadcastProcessOperator;
import org.apache.flink.streaming.api.operators.python.process.ExternalPythonKeyedCoProcessOperator;
import org.apache.flink.streaming.api.transformations.python.PythonKeyedBroadcastStateTransformation;
@@ -30,8 +35,10 @@ import java.util.Collection;
/**
* A {@link org.apache.flink.streaming.api.graph.TransformationTranslator} that translates {@link
- * PythonKeyedBroadcastStateTransformation} into {@link ExternalPythonKeyedCoProcessOperator} or
- * {@link ExternalPythonBatchKeyedCoBroadcastProcessOperator}.
+ * PythonKeyedBroadcastStateTransformation} into {@link ExternalPythonKeyedCoProcessOperator}/{@link
+ * EmbeddedPythonKeyedCoProcessOperator} in streaming mode or {@link
+ * ExternalPythonBatchKeyedCoBroadcastProcessOperator}/{@link
+ * EmbeddedPythonBatchKeyedCoBroadcastProcessOperator} in batch mode.
*/
@Internal
public class PythonKeyedBroadcastStateTransformationTranslator<OUT>
@@ -44,13 +51,27 @@ public class PythonKeyedBroadcastStateTransformationTranslator<OUT>
Preconditions.checkNotNull(transformation);
Preconditions.checkNotNull(context);
- ExternalPythonKeyedCoProcessOperator<OUT> operator =
- new ExternalPythonBatchKeyedCoBroadcastProcessOperator<>(
- transformation.getConfiguration(),
- transformation.getDataStreamPythonFunctionInfo(),
- transformation.getRegularInput().getOutputType(),
- transformation.getBroadcastInput().getOutputType(),
- transformation.getOutputType());
+ Configuration config = transformation.getConfiguration();
+
+ StreamOperator<OUT> operator;
+
+ if (config.get(PythonOptions.PYTHON_EXECUTION_MODE).equals("thread")) {
+ operator =
+ new EmbeddedPythonBatchKeyedCoBroadcastProcessOperator<>(
+ transformation.getConfiguration(),
+ transformation.getDataStreamPythonFunctionInfo(),
+ transformation.getRegularInput().getOutputType(),
+ transformation.getBroadcastInput().getOutputType(),
+ transformation.getOutputType());
+ } else {
+ operator =
+ new ExternalPythonBatchKeyedCoBroadcastProcessOperator<>(
+ transformation.getConfiguration(),
+ transformation.getDataStreamPythonFunctionInfo(),
+ transformation.getRegularInput().getOutputType(),
+ transformation.getBroadcastInput().getOutputType(),
+ transformation.getOutputType());
+ }
return translateInternal(
transformation,
@@ -69,13 +90,28 @@ public class PythonKeyedBroadcastStateTransformationTranslator<OUT>
Preconditions.checkNotNull(transformation);
Preconditions.checkNotNull(context);
- ExternalPythonKeyedCoProcessOperator<OUT> operator =
- new ExternalPythonKeyedCoProcessOperator<>(
- transformation.getConfiguration(),
- transformation.getDataStreamPythonFunctionInfo(),
- transformation.getRegularInput().getOutputType(),
- transformation.getBroadcastInput().getOutputType(),
- transformation.getOutputType());
+ Configuration config = transformation.getConfiguration();
+
+ StreamOperator<OUT> operator;
+
+ if (config.get(PythonOptions.PYTHON_EXECUTION_MODE).equals("thread")) {
+ operator =
+ new EmbeddedPythonKeyedCoProcessOperator<>(
+ transformation.getConfiguration(),
+ transformation.getDataStreamPythonFunctionInfo(),
+ transformation.getRegularInput().getOutputType(),
+ transformation.getBroadcastInput().getOutputType(),
+ transformation.getOutputType());
+
+ } else {
+ operator =
+ new ExternalPythonKeyedCoProcessOperator<>(
+ transformation.getConfiguration(),
+ transformation.getDataStreamPythonFunctionInfo(),
+ transformation.getRegularInput().getOutputType(),
+ transformation.getBroadcastInput().getOutputType(),
+ transformation.getOutputType());
+ }
return translateInternal(
transformation,