Posted to commits@flink.apache.org by hx...@apache.org on 2022/07/29 11:27:25 UTC
[flink] branch master updated: [FLINK-28559][python] Support DataStream PythonKeyedProcessOperator in Thread Mode
This is an automated email from the ASF dual-hosted git repository.
hxb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 4cc33f97d13 [FLINK-28559][python] Support DataStream PythonKeyedProcessOperator in Thread Mode
4cc33f97d13 is described below
commit 4cc33f97d1351be4836e973a528042684475a069
Author: huangxingbo <hx...@apache.org>
AuthorDate: Wed Jul 27 16:21:46 2022 +0800
[FLINK-28559][python] Support DataStream PythonKeyedProcessOperator in Thread Mode
This closes #20375.
---
flink-python/dev/dev-requirements.txt | 2 +-
flink-python/pom.xml | 2 +-
flink-python/pyflink/datastream/data_stream.py | 60 +--
flink-python/pyflink/datastream/state.py | 4 +-
.../pyflink/datastream/tests/test_data_stream.py | 420 +++++++++++----------
.../fn_execution/datastream/embedded/operations.py | 82 ++--
.../datastream/embedded/process_function.py | 42 ++-
.../datastream/embedded/runtime_context.py | 34 +-
.../fn_execution/datastream/embedded/state_impl.py | 158 ++++++++
.../{process_function.py => timerservice_impl.py} | 30 +-
.../pyflink/fn_execution/embedded/converters.py | 56 ++-
.../pyflink/fn_execution/embedded/java_utils.py | 204 ++++++++++
.../fn_execution/embedded/operation_utils.py | 6 +-
.../pyflink/fn_execution/embedded/operations.py | 13 +-
flink-python/setup.py | 2 +-
.../chain/PythonOperatorChainingOptimizer.java | 4 +-
...ctEmbeddedDataStreamPythonFunctionOperator.java | 20 +-
...ractOneInputEmbeddedPythonFunctionOperator.java | 21 +-
.../EmbeddedPythonKeyedProcessOperator.java | 214 +++++++++++
.../embedded/EmbeddedPythonProcessOperator.java | 8 +-
.../table/EmbeddedPythonTableFunctionOperator.java | 3 +-
flink-python/src/main/resources/META-INF/NOTICE | 2 +-
22 files changed, 1057 insertions(+), 330 deletions(-)
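For context, thread mode executes the Python user-defined functions in an embedded interpreter (via pemja) inside the JVM process, instead of in a separate Python worker process as in the default process mode. A minimal sketch of how a job can opt into thread mode, mirroring the test setup in this commit (the import path of get_j_env_configuration is an assumption based on PyFlink's internal utilities):

    from pyflink.datastream import StreamExecutionEnvironment
    # assumed location of the helper also used by the tests in this commit
    from pyflink.util.java_utils import get_j_env_configuration

    env = StreamExecutionEnvironment.get_execution_environment()
    config = get_j_env_configuration(env._j_stream_execution_environment)
    # switch the Python runtime from the default "process" mode to "thread" mode
    config.setString("python.execution-mode", "thread")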
diff --git a/flink-python/dev/dev-requirements.txt b/flink-python/dev/dev-requirements.txt
index da91f005bb3..63a5852c946 100755
--- a/flink-python/dev/dev-requirements.txt
+++ b/flink-python/dev/dev-requirements.txt
@@ -31,6 +31,6 @@ numpy>=1.14.3,<1.20; python_version < '3.7'
fastavro>=1.1.0,<1.4.8
grpcio>=1.29.0,<1.47
grpcio-tools>=1.3.5,<=1.14.2
-pemja==0.2.0; python_version >= '3.7' and platform_system != 'Windows'
+pemja==0.2.2; python_version >= '3.7' and platform_system != 'Windows'
httplib2>=0.19.0,<=0.20.4
protobuf<3.18
\ No newline at end of file
diff --git a/flink-python/pom.xml b/flink-python/pom.xml
index fa5fe677285..78a08113aa7 100644
--- a/flink-python/pom.xml
+++ b/flink-python/pom.xml
@@ -110,7 +110,7 @@ under the License.
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>pemja</artifactId>
- <version>0.2.0</version>
+ <version>0.2.2</version>
</dependency>
<!-- Protobuf dependencies -->
diff --git a/flink-python/pyflink/datastream/data_stream.py b/flink-python/pyflink/datastream/data_stream.py
index 15418960590..b893308abc3 100644
--- a/flink-python/pyflink/datastream/data_stream.py
+++ b/flink-python/pyflink/datastream/data_stream.py
@@ -49,8 +49,8 @@ from pyflink.datastream.functions import (_get_python_env, FlatMapFunction, MapF
InternalSingleValueProcessAllWindowFunction)
from pyflink.datastream.output_tag import OutputTag
from pyflink.datastream.slot_sharing_group import SlotSharingGroup
-from pyflink.datastream.state import ValueStateDescriptor, ValueState, ListStateDescriptor, \
- StateDescriptor, ReducingStateDescriptor, AggregatingStateDescriptor, MapStateDescriptor
+from pyflink.datastream.state import (ListStateDescriptor, StateDescriptor, ReducingStateDescriptor,
+ AggregatingStateDescriptor, MapStateDescriptor, ReducingState)
from pyflink.datastream.utils import convert_to_python_obj
from pyflink.datastream.window import (CountTumblingWindowAssigner, CountSlidingWindowAssigner,
CountWindowSerializer, TimeWindowSerializer, Trigger,
@@ -1202,6 +1202,12 @@ class KeyedStream(DataStream):
output_type = _from_java_type(self._original_data_type_info.get_java_type_info())
+ gateway = get_gateway()
+ j_conf = get_j_env_configuration(self._j_data_stream.getExecutionEnvironment())
+ python_execution_mode = (
+ j_conf.getString(
+ gateway.jvm.org.apache.flink.python.PythonOptions.PYTHON_EXECUTION_MODE))
+
class ReduceProcessKeyedProcessFunctionAdapter(KeyedProcessFunction):
def __init__(self, reduce_function):
@@ -1213,40 +1219,47 @@ class KeyedStream(DataStream):
self._open_func = None
self._close_func = None
self._reduce_function = reduce_function
- self._reduce_value_state = None # type: ValueState
+ self._reduce_state = None # type: ReducingState
+ self._in_batch_execution_mode = True
def open(self, runtime_context: RuntimeContext):
if self._open_func:
self._open_func(runtime_context)
- self._reduce_value_state = runtime_context.get_state(
- ValueStateDescriptor("_reduce_state" + str(uuid.uuid4()), output_type))
- from pyflink.fn_execution.datastream.process.runtime_context import (
- StreamingRuntimeContext)
- self._in_batch_execution_mode = \
- cast(StreamingRuntimeContext, runtime_context)._in_batch_execution_mode
+ self._reduce_state = runtime_context.get_reducing_state(
+ ReducingStateDescriptor(
+ "_reduce_state" + str(uuid.uuid4()),
+ self._reduce_function,
+ output_type))
+
+ if python_execution_mode == "process":
+ from pyflink.fn_execution.datastream.process.runtime_context import (
+ StreamingRuntimeContext)
+ self._in_batch_execution_mode = (
+ cast(StreamingRuntimeContext, runtime_context)._in_batch_execution_mode)
+ else:
+ self._in_batch_execution_mode = runtime_context.get_job_parameter(
+ "inBatchExecutionMode", "false") == "true"
def close(self):
if self._close_func:
self._close_func()
def process_element(self, value, ctx: 'KeyedProcessFunction.Context'):
- reduce_value = self._reduce_value_state.value()
- if reduce_value is not None:
- reduce_value = self._reduce_function(reduce_value, value)
- else:
- # register a timer for emitting the result at the end when this is the
- # first input for this key
- if self._in_batch_execution_mode:
+ if self._in_batch_execution_mode:
+ reduce_value = self._reduce_state.get()
+ if reduce_value is None:
+ # register a timer for emitting the result at the end when this is the
+ # first input for this key
ctx.timer_service().register_event_time_timer(0x7fffffffffffffff)
- reduce_value = value
- self._reduce_value_state.update(reduce_value)
- if not self._in_batch_execution_mode:
+ self._reduce_state.add(value)
+ else:
+ self._reduce_state.add(value)
# only emitting the result when all the data for a key is received
- yield reduce_value
+ yield self._reduce_state.get()
def on_timer(self, timestamp: int, ctx: 'KeyedProcessFunction.OnTimerContext'):
- current_value = self._reduce_value_state.value()
+ current_value = self._reduce_state.get()
if current_value is not None:
yield current_value
@@ -2708,7 +2721,10 @@ def _get_one_input_stream_operator(data_stream: DataStream,
else:
JDataStreamPythonFunctionOperator = gateway.jvm.ExternalPythonProcessOperator
elif func_type == UserDefinedDataStreamFunction.KEYED_PROCESS: # type: ignore
- JDataStreamPythonFunctionOperator = gateway.jvm.ExternalPythonKeyedProcessOperator
+ if python_execution_mode == 'thread':
+ JDataStreamPythonFunctionOperator = gateway.jvm.EmbeddedPythonKeyedProcessOperator
+ else:
+ JDataStreamPythonFunctionOperator = gateway.jvm.ExternalPythonKeyedProcessOperator
elif func_type == UserDefinedDataStreamFunction.WINDOW: # type: ignore
window_serializer = typing.cast(WindowOperationDescriptor, func).window_serializer
if isinstance(window_serializer, TimeWindowSerializer):
diff --git a/flink-python/pyflink/datastream/state.py b/flink-python/pyflink/datastream/state.py
index b60bcdeb6ee..f38b9963bd6 100644
--- a/flink-python/pyflink/datastream/state.py
+++ b/flink-python/pyflink/datastream/state.py
@@ -899,14 +899,14 @@ class StateTtlConfig(object):
Configuration of cleanup strategy while taking the full snapshot.
"""
- def __init__(self, cleanup_size: int, run_cleanup_for_every_record: int):
+ def __init__(self, cleanup_size: int, run_cleanup_for_every_record: bool):
self._cleanup_size = cleanup_size
self._run_cleanup_for_every_record = run_cleanup_for_every_record
def get_cleanup_size(self) -> int:
return self._cleanup_size
- def run_cleanup_for_every_record(self) -> int:
+ def run_cleanup_for_every_record(self) -> bool:
return self._run_cleanup_for_every_record
class RocksdbCompactFilterCleanupStrategy(CleanupStrategy):
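The signature fix above matters because run_cleanup_for_every_record() feeds the boolean argument of the Java cleanupIncrementally(int, boolean) call in the new java_utils.py further down. A short sketch of configuring incremental cleanup and reading the flag back, assuming the standard PyFlink StateTtlConfig builder API:

    from pyflink.common import Time
    from pyflink.datastream.state import StateTtlConfig

    ttl_config = StateTtlConfig \
        .new_builder(Time.seconds(1)) \
        .cleanup_incrementally(10, True) \
        .build()

    strategy = ttl_config.get_cleanup_strategies().get_incremental_cleanup_strategy()
    assert strategy.get_cleanup_size() == 10
    assert strategy.run_cleanup_for_every_record() is True  # a bool after this fix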
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py
index d302e2bf18d..beba02e1720 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -164,6 +164,210 @@ class DataStreamTests(object):
self.env.execute('test_partition_custom')
+ def test_keyed_process_function_with_state(self):
+ self.env.get_config().set_auto_watermark_interval(2000)
+ self.env.set_stream_time_characteristic(TimeCharacteristic.EventTime)
+ data_stream = self.env.from_collection([(1, 'hi', '1603708211000'),
+ (2, 'hello', '1603708224000'),
+ (3, 'hi', '1603708226000'),
+ (4, 'hello', '1603708289000'),
+ (5, 'hi', '1603708291000'),
+ (6, 'hello', '1603708293000')],
+ type_info=Types.ROW([Types.INT(), Types.STRING(),
+ Types.STRING()]))
+
+ class MyTimestampAssigner(TimestampAssigner):
+
+ def extract_timestamp(self, value, record_timestamp) -> int:
+ return int(value[2])
+
+ class MyProcessFunction(KeyedProcessFunction):
+
+ def __init__(self):
+ self.value_state = None
+ self.list_state = None
+ self.map_state = None
+
+ def open(self, runtime_context: RuntimeContext):
+ value_state_descriptor = ValueStateDescriptor('value_state', Types.INT())
+ self.value_state = runtime_context.get_state(value_state_descriptor)
+ list_state_descriptor = ListStateDescriptor('list_state', Types.INT())
+ self.list_state = runtime_context.get_list_state(list_state_descriptor)
+ map_state_descriptor = MapStateDescriptor('map_state', Types.INT(), Types.STRING())
+ state_ttl_config = StateTtlConfig \
+ .new_builder(Time.seconds(1)) \
+ .set_update_type(StateTtlConfig.UpdateType.OnReadAndWrite) \
+ .set_state_visibility(
+ StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp) \
+ .disable_cleanup_in_background() \
+ .build()
+ map_state_descriptor.enable_time_to_live(state_ttl_config)
+ self.map_state = runtime_context.get_map_state(map_state_descriptor)
+
+ def process_element(self, value, ctx):
+ current_value = self.value_state.value()
+ self.value_state.update(value[0])
+ current_list = [_ for _ in self.list_state.get()]
+ self.list_state.add(value[0])
+ map_entries = {k: v for k, v in self.map_state.items()}
+ keys = sorted(map_entries.keys())
+ map_entries_string = [str(k) + ': ' + str(map_entries[k]) for k in keys]
+ map_entries_string = '{' + ', '.join(map_entries_string) + '}'
+ self.map_state.put(value[0], value[1])
+ current_key = ctx.get_current_key()
+ yield "current key: {}, current value state: {}, current list state: {}, " \
+ "current map state: {}, current value: {}".format(str(current_key),
+ str(current_value),
+ str(current_list),
+ map_entries_string,
+ str(value))
+
+ def on_timer(self, timestamp, ctx):
+ pass
+
+ watermark_strategy = WatermarkStrategy.for_monotonous_timestamps() \
+ .with_timestamp_assigner(MyTimestampAssigner())
+ data_stream.assign_timestamps_and_watermarks(watermark_strategy) \
+ .key_by(lambda x: x[1], key_type=Types.STRING()) \
+ .process(MyProcessFunction(), output_type=Types.STRING()) \
+ .add_sink(self.test_sink)
+ self.env.execute('test time stamp assigner with keyed process function')
+ results = self.test_sink.get_results()
+ expected = ["current key: hi, current value state: None, current list state: [], "
+ "current map state: {}, current value: Row(f0=1, f1='hi', "
+ "f2='1603708211000')",
+ "current key: hello, current value state: None, "
+ "current list state: [], current map state: {}, current value: Row(f0=2,"
+ " f1='hello', f2='1603708224000')",
+ "current key: hi, current value state: 1, current list state: [1], "
+ "current map state: {1: hi}, current value: Row(f0=3, f1='hi', "
+ "f2='1603708226000')",
+ "current key: hello, current value state: 2, current list state: [2], "
+ "current map state: {2: hello}, current value: Row(f0=4, f1='hello', "
+ "f2='1603708289000')",
+ "current key: hi, current value state: 3, current list state: [1, 3], "
+ "current map state: {1: hi, 3: hi}, current value: Row(f0=5, f1='hi', "
+ "f2='1603708291000')",
+ "current key: hello, current value state: 4, current list state: [2, 4],"
+ " current map state: {2: hello, 4: hello}, current value: Row(f0=6, "
+ "f1='hello', f2='1603708293000')"]
+ self.assert_equals_sorted(expected, results)
+
+ def test_reducing_state(self):
+ self.env.set_parallelism(2)
+ data_stream = self.env.from_collection([
+ (1, 'hi'), (2, 'hello'), (3, 'hi'), (4, 'hello'), (5, 'hi'), (6, 'hello')],
+ type_info=Types.TUPLE([Types.INT(), Types.STRING()]))
+
+ class MyProcessFunction(KeyedProcessFunction):
+
+ def __init__(self):
+ self.reducing_state = None # type: ReducingState
+
+ def open(self, runtime_context: RuntimeContext):
+ self.reducing_state = runtime_context.get_reducing_state(
+ ReducingStateDescriptor(
+ 'reducing_state', lambda i, i2: i + i2, Types.INT()))
+
+ def process_element(self, value, ctx):
+ self.reducing_state.add(value[0])
+ yield self.reducing_state.get(), value[1]
+
+ data_stream.key_by(lambda x: x[1], key_type=Types.STRING()) \
+ .process(MyProcessFunction(), output_type=Types.TUPLE([Types.INT(), Types.STRING()])) \
+ .add_sink(self.test_sink)
+ self.env.execute('test_reducing_state')
+ result = self.test_sink.get_results()
+ expected_result = ['(1,hi)', '(2,hello)', '(4,hi)', '(6,hello)', '(9,hi)', '(12,hello)']
+ result.sort()
+ expected_result.sort()
+ self.assertEqual(expected_result, result)
+
+ def test_aggregating_state(self):
+ self.env.set_parallelism(2)
+ data_stream = self.env.from_collection([
+ (1, 'hi'), (2, 'hello'), (3, 'hi'), (4, 'hello'), (5, 'hi'), (6, 'hello')],
+ type_info=Types.TUPLE([Types.INT(), Types.STRING()]))
+
+ class MyAggregateFunction(AggregateFunction):
+
+ def create_accumulator(self):
+ return 0
+
+ def add(self, value, accumulator):
+ return value + accumulator
+
+ def get_result(self, accumulator):
+ return accumulator
+
+ def merge(self, acc_a, acc_b):
+ return acc_a + acc_b
+
+ class MyProcessFunction(KeyedProcessFunction):
+
+ def __init__(self):
+ self.aggregating_state = None # type: AggregatingState
+
+ def open(self, runtime_context: RuntimeContext):
+ descriptor = AggregatingStateDescriptor(
+ 'aggregating_state', MyAggregateFunction(), Types.INT())
+ state_ttl_config = StateTtlConfig \
+ .new_builder(Time.seconds(1)) \
+ .set_update_type(StateTtlConfig.UpdateType.OnReadAndWrite) \
+ .disable_cleanup_in_background() \
+ .build()
+ descriptor.enable_time_to_live(state_ttl_config)
+ self.aggregating_state = runtime_context.get_aggregating_state(descriptor)
+
+ def process_element(self, value, ctx):
+ self.aggregating_state.add(value[0])
+ yield self.aggregating_state.get(), value[1]
+
+ config = Configuration(
+ j_configuration=get_j_env_configuration(self.env._j_stream_execution_environment))
+ config.set_integer("python.fn-execution.bundle.size", 1)
+ data_stream.key_by(lambda x: x[1], key_type=Types.STRING()) \
+ .process(MyProcessFunction(), output_type=Types.TUPLE([Types.INT(), Types.STRING()])) \
+ .add_sink(self.test_sink)
+ self.env.execute('test_aggregating_state')
+ results = self.test_sink.get_results()
+ expected = ['(1,hi)', '(2,hello)', '(4,hi)', '(6,hello)', '(9,hi)', '(12,hello)']
+ self.assert_equals_sorted(expected, results)
+
+
+class DataStreamStreamingTests(DataStreamTests):
+
+ def test_reduce_with_state(self):
+ ds = self.env.from_collection([('a', 0), ('c', 1), ('d', 1), ('b', 0), ('e', 1)],
+ type_info=Types.ROW([Types.STRING(), Types.INT()]))
+ keyed_stream = ds.key_by(MyKeySelector(), key_type=Types.INT())
+
+ with self.assertRaises(Exception):
+ keyed_stream.name("keyed stream")
+
+ keyed_stream.reduce(MyReduceFunction()).add_sink(self.test_sink)
+ self.env.execute('key_by_test')
+ results = self.test_sink.get_results(False)
+ expected = ['+I[a, 0]', '+I[ab, 0]', '+I[c, 1]', '+I[cd, 1]', '+I[cde, 1]']
+ self.assert_equals_sorted(expected, results)
+
+
+class DataStreamBatchTests(DataStreamTests):
+
+ def test_reduce_with_state(self):
+ ds = self.env.from_collection([('a', 0), ('c', 1), ('d', 1), ('b', 0), ('e', 1)],
+ type_info=Types.ROW([Types.STRING(), Types.INT()]))
+ keyed_stream = ds.key_by(MyKeySelector(), key_type=Types.INT())
+
+ with self.assertRaises(Exception):
+ keyed_stream.name("keyed stream")
+
+ keyed_stream.reduce(MyReduceFunction()).add_sink(self.test_sink)
+ self.env.execute('key_by_test')
+ results = self.test_sink.get_results(False)
+ expected = ['+I[ab, 0]', '+I[cde, 1]']
+ self.assert_equals_sorted(expected, results)
+
class ProcessDataStreamTests(DataStreamTests):
"""
@@ -640,176 +844,6 @@ class ProcessDataStreamTests(DataStreamTests):
"-9223372036854775808, current_value: Row(f0=4, f1='1603708289000')"]
self.assert_equals_sorted(expected, results)
- def test_keyed_process_function_with_state(self):
- self.env.get_config().set_auto_watermark_interval(2000)
- self.env.set_stream_time_characteristic(TimeCharacteristic.EventTime)
- data_stream = self.env.from_collection([(1, 'hi', '1603708211000'),
- (2, 'hello', '1603708224000'),
- (3, 'hi', '1603708226000'),
- (4, 'hello', '1603708289000'),
- (5, 'hi', '1603708291000'),
- (6, 'hello', '1603708293000')],
- type_info=Types.ROW([Types.INT(), Types.STRING(),
- Types.STRING()]))
-
- class MyTimestampAssigner(TimestampAssigner):
-
- def extract_timestamp(self, value, record_timestamp) -> int:
- return int(value[2])
-
- class MyProcessFunction(KeyedProcessFunction):
-
- def __init__(self):
- self.value_state = None
- self.list_state = None
- self.map_state = None
-
- def open(self, runtime_context: RuntimeContext):
- value_state_descriptor = ValueStateDescriptor('value_state', Types.INT())
- self.value_state = runtime_context.get_state(value_state_descriptor)
- list_state_descriptor = ListStateDescriptor('list_state', Types.INT())
- self.list_state = runtime_context.get_list_state(list_state_descriptor)
- map_state_descriptor = MapStateDescriptor('map_state', Types.INT(), Types.STRING())
- state_ttl_config = StateTtlConfig \
- .new_builder(Time.seconds(1)) \
- .set_update_type(StateTtlConfig.UpdateType.OnReadAndWrite) \
- .set_state_visibility(
- StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp) \
- .disable_cleanup_in_background() \
- .build()
- map_state_descriptor.enable_time_to_live(state_ttl_config)
- self.map_state = runtime_context.get_map_state(map_state_descriptor)
-
- def process_element(self, value, ctx):
- current_value = self.value_state.value()
- self.value_state.update(value[0])
- current_list = [_ for _ in self.list_state.get()]
- self.list_state.add(value[0])
- map_entries = {k: v for k, v in self.map_state.items()}
- keys = sorted(map_entries.keys())
- map_entries_string = [str(k) + ': ' + str(map_entries[k]) for k in keys]
- map_entries_string = '{' + ', '.join(map_entries_string) + '}'
- self.map_state.put(value[0], value[1])
- current_key = ctx.get_current_key()
- yield "current key: {}, current value state: {}, current list state: {}, " \
- "current map state: {}, current value: {}".format(str(current_key),
- str(current_value),
- str(current_list),
- map_entries_string,
- str(value))
-
- def on_timer(self, timestamp, ctx):
- pass
-
- watermark_strategy = WatermarkStrategy.for_monotonous_timestamps() \
- .with_timestamp_assigner(MyTimestampAssigner())
- data_stream.assign_timestamps_and_watermarks(watermark_strategy) \
- .key_by(lambda x: x[1], key_type=Types.STRING()) \
- .process(MyProcessFunction(), output_type=Types.STRING()) \
- .add_sink(self.test_sink)
- self.env.execute('test time stamp assigner with keyed process function')
- results = self.test_sink.get_results()
- expected = ["current key: hi, current value state: None, current list state: [], "
- "current map state: {}, current value: Row(f0=1, f1='hi', "
- "f2='1603708211000')",
- "current key: hello, current value state: None, "
- "current list state: [], current map state: {}, current value: Row(f0=2,"
- " f1='hello', f2='1603708224000')",
- "current key: hi, current value state: 1, current list state: [1], "
- "current map state: {1: hi}, current value: Row(f0=3, f1='hi', "
- "f2='1603708226000')",
- "current key: hello, current value state: 2, current list state: [2], "
- "current map state: {2: hello}, current value: Row(f0=4, f1='hello', "
- "f2='1603708289000')",
- "current key: hi, current value state: 3, current list state: [1, 3], "
- "current map state: {1: hi, 3: hi}, current value: Row(f0=5, f1='hi', "
- "f2='1603708291000')",
- "current key: hello, current value state: 4, current list state: [2, 4],"
- " current map state: {2: hello, 4: hello}, current value: Row(f0=6, "
- "f1='hello', f2='1603708293000')"]
- self.assert_equals_sorted(expected, results)
-
- def test_reducing_state(self):
- self.env.set_parallelism(2)
- data_stream = self.env.from_collection([
- (1, 'hi'), (2, 'hello'), (3, 'hi'), (4, 'hello'), (5, 'hi'), (6, 'hello')],
- type_info=Types.TUPLE([Types.INT(), Types.STRING()]))
-
- class MyProcessFunction(KeyedProcessFunction):
-
- def __init__(self):
- self.reducing_state = None # type: ReducingState
-
- def open(self, runtime_context: RuntimeContext):
- self.reducing_state = runtime_context.get_reducing_state(
- ReducingStateDescriptor(
- 'reducing_state', lambda i, i2: i + i2, Types.INT()))
-
- def process_element(self, value, ctx):
- self.reducing_state.add(value[0])
- yield self.reducing_state.get(), value[1]
-
- data_stream.key_by(lambda x: x[1], key_type=Types.STRING()) \
- .process(MyProcessFunction(), output_type=Types.TUPLE([Types.INT(), Types.STRING()])) \
- .add_sink(self.test_sink)
- self.env.execute('test_reducing_state')
- result = self.test_sink.get_results()
- expected_result = ['(1,hi)', '(2,hello)', '(4,hi)', '(6,hello)', '(9,hi)', '(12,hello)']
- result.sort()
- expected_result.sort()
- self.assertEqual(expected_result, result)
-
- def test_aggregating_state(self):
- self.env.set_parallelism(2)
- data_stream = self.env.from_collection([
- (1, 'hi'), (2, 'hello'), (3, 'hi'), (4, 'hello'), (5, 'hi'), (6, 'hello')],
- type_info=Types.TUPLE([Types.INT(), Types.STRING()]))
-
- class MyAggregateFunction(AggregateFunction):
-
- def create_accumulator(self):
- return 0
-
- def add(self, value, accumulator):
- return value + accumulator
-
- def get_result(self, accumulator):
- return accumulator
-
- def merge(self, acc_a, acc_b):
- return acc_a + acc_b
-
- class MyProcessFunction(KeyedProcessFunction):
-
- def __init__(self):
- self.aggregating_state = None # type: AggregatingState
-
- def open(self, runtime_context: RuntimeContext):
- descriptor = AggregatingStateDescriptor(
- 'aggregating_state', MyAggregateFunction(), Types.INT())
- state_ttl_config = StateTtlConfig \
- .new_builder(Time.seconds(1)) \
- .set_update_type(StateTtlConfig.UpdateType.OnReadAndWrite) \
- .disable_cleanup_in_background() \
- .build()
- descriptor.enable_time_to_live(state_ttl_config)
- self.aggregating_state = runtime_context.get_aggregating_state(descriptor)
-
- def process_element(self, value, ctx):
- self.aggregating_state.add(value[0])
- yield self.aggregating_state.get(), value[1]
-
- config = Configuration(
- j_configuration=get_j_env_configuration(self.env._j_stream_execution_environment))
- config.set_integer("python.fn-execution.bundle.size", 1)
- data_stream.key_by(lambda x: x[1], key_type=Types.STRING()) \
- .process(MyProcessFunction(), output_type=Types.TUPLE([Types.INT(), Types.STRING()])) \
- .add_sink(self.test_sink)
- self.env.execute('test_aggregating_state')
- results = self.test_sink.get_results()
- expected = ['(1,hi)', '(2,hello)', '(4,hi)', '(6,hello)', '(9,hi)', '(12,hello)']
- self.assert_equals_sorted(expected, results)
-
def test_process_side_output(self):
tag = OutputTag("side", Types.INT())
@@ -1122,7 +1156,8 @@ class ProcessDataStreamTests(DataStreamTests):
self.assert_equals_sorted(expected, self.test_sink.get_results())
-class StreamingModeDataStreamTests(ProcessDataStreamTests, PyFlinkStreamingTestCase):
+class ProcessDataStreamStreamingTests(DataStreamStreamingTests, ProcessDataStreamTests,
+ PyFlinkStreamingTestCase):
def test_data_stream_name(self):
ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')])
test_name = 'test_name'
@@ -1401,20 +1436,6 @@ class StreamingModeDataStreamTests(ProcessDataStreamTests, PyFlinkStreamingTestC
expected = ["+I[1, a]", "+I[3, a]", "+I[6, a]", "+I[4, b]"]
self.assert_equals_sorted(expected, results)
- def test_reduce_with_state(self):
- ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 1)],
- type_info=Types.ROW([Types.STRING(), Types.INT()]))
- keyed_stream = ds.key_by(MyKeySelector(), key_type=Types.INT())
-
- with self.assertRaises(Exception):
- keyed_stream.name("keyed stream")
-
- keyed_stream.reduce(MyReduceFunction()).add_sink(self.test_sink)
- self.env.execute('key_by_test')
- results = self.test_sink.get_results(False)
- expected = ['+I[a, 0]', '+I[ab, 0]', '+I[c, 1]', '+I[cd, 1]', '+I[cde, 1]']
- self.assert_equals_sorted(expected, results)
-
def test_keyed_sum(self):
self.env.set_parallelism(1)
ds = self.env.from_collection(
@@ -1529,7 +1550,8 @@ class StreamingModeDataStreamTests(ProcessDataStreamTests, PyFlinkStreamingTestC
self.assert_equals_sorted(expected, results)
-class BatchModeDataStreamTests(ProcessDataStreamTests, PyFlinkBatchTestCase):
+class ProcessDataStreamBatchTests(DataStreamBatchTests, ProcessDataStreamTests,
+ PyFlinkBatchTestCase):
def test_timestamp_assigner_and_watermark_strategy(self):
self.env.set_parallelism(1)
@@ -1587,20 +1609,6 @@ class BatchModeDataStreamTests(ProcessDataStreamTests, PyFlinkBatchTestCase):
expected = ["+I[6, a]", "+I[7, b]"]
self.assert_equals_sorted(expected, results)
- def test_reduce_with_state(self):
- ds = self.env.from_collection([('a', 0), ('c', 1), ('d', 1), ('b', 0), ('e', 1)],
- type_info=Types.ROW([Types.STRING(), Types.INT()]))
- keyed_stream = ds.key_by(MyKeySelector(), key_type=Types.INT())
-
- with self.assertRaises(Exception):
- keyed_stream.name("keyed stream")
-
- keyed_stream.reduce(MyReduceFunction()).add_sink(self.test_sink)
- self.env.execute('key_by_test')
- results = self.test_sink.get_results(False)
- expected = ['+I[ab, 0]', '+I[cde, 1]']
- self.assert_equals_sorted(expected, results)
-
def test_keyed_sum(self):
self.env.set_parallelism(1)
ds = self.env.from_collection(
@@ -1704,9 +1712,17 @@ class BatchModeDataStreamTests(ProcessDataStreamTests, PyFlinkBatchTestCase):
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7")
-class EmbeddedDataStreamTests(DataStreamTests, PyFlinkStreamingTestCase):
+class EmbeddedDataStreamStreamTests(DataStreamStreamingTests, PyFlinkStreamingTestCase):
+ def setUp(self):
+ super(EmbeddedDataStreamStreamTests, self).setUp()
+ config = get_j_env_configuration(self.env._j_stream_execution_environment)
+ config.setString("python.execution-mode", "thread")
+
+
+@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7")
+class EmbeddedDataStreamBatchTests(DataStreamBatchTests, PyFlinkBatchTestCase):
def setUp(self):
- super(EmbeddedDataStreamTests, self).setUp()
+ super(EmbeddedDataStreamBatchTests, self).setUp()
config = get_j_env_configuration(self.env._j_stream_execution_environment)
config.setString("python.execution-mode", "thread")
diff --git a/flink-python/pyflink/fn_execution/datastream/embedded/operations.py b/flink-python/pyflink/fn_execution/datastream/embedded/operations.py
index 50be1362075..a109d2fba94 100644
--- a/flink-python/pyflink/fn_execution/datastream/embedded/operations.py
+++ b/flink-python/pyflink/fn_execution/datastream/embedded/operations.py
@@ -18,41 +18,46 @@
from pyflink.fn_execution import pickle
from pyflink.fn_execution.datastream import operations
-from pyflink.fn_execution.datastream.embedded.process_function import InternalProcessFunctionContext
+from pyflink.fn_execution.datastream.embedded.process_function import (
+ InternalProcessFunctionContext, InternalKeyedProcessFunctionContext,
+ InternalKeyedProcessFunctionOnTimerContext)
from pyflink.fn_execution.datastream.embedded.runtime_context import StreamingRuntimeContext
-from pyflink.fn_execution.datastream.operations import DATA_STREAM_STATELESS_FUNCTION_URN
class OneInputOperation(operations.OneInputOperation):
- def __init__(self,
- function_urn,
- serialized_fn,
- runtime_context,
- function_context,
- job_parameters):
- (self.open_func,
- self.close_func,
- self.process_element_func
- ) = extract_one_input_process_function(
- function_urn=function_urn,
- user_defined_function_proto=serialized_fn,
- runtime_context=StreamingRuntimeContext.of(runtime_context, job_parameters),
- function_context=function_context)
+ def __init__(self, open_func, close_func, process_element_func, on_timer_func=None):
+ self._open_func = open_func
+ self._close_func = close_func
+ self._process_element_func = process_element_func
+ self._on_timer_func = on_timer_func
def open(self) -> None:
- self.open_func()
+ self._open_func()
def close(self) -> None:
- self.close_func()
+ self._close_func()
def process_element(self, value):
- return self.process_element_func(value)
+ return self._process_element_func(value)
+ def on_timer(self, timestamp):
+ if self._on_timer_func:
+ return self._on_timer_func(timestamp)
+
+
+def extract_process_function(
+ user_defined_function_proto, runtime_context, function_context, timer_context,
+ job_parameters):
+ from pyflink.fn_execution import flink_fn_execution_pb2
-def extract_one_input_process_function(
- function_urn, user_defined_function_proto, runtime_context, function_context):
user_defined_func = pickle.loads(user_defined_function_proto.payload)
+ func_type = user_defined_function_proto.function_type
+
+ UserDefinedDataStreamFunction = flink_fn_execution_pb2.UserDefinedDataStreamFunction
+
+ runtime_context = StreamingRuntimeContext.of(runtime_context, job_parameters)
+
def open_func():
if hasattr(user_defined_func, "open"):
user_defined_func.open(runtime_context)
@@ -61,12 +66,35 @@ def extract_one_input_process_function(
if hasattr(user_defined_func, "close"):
user_defined_func.close()
- process_element = user_defined_func.process_element
+ if func_type == UserDefinedDataStreamFunction.PROCESS:
+ function_context = InternalProcessFunctionContext(function_context)
+
+ process_element = user_defined_func.process_element
+
+ def process_element_func(value):
+ yield from process_element(value, function_context)
+
+ return OneInputOperation(open_func, close_func, process_element_func)
+
+ elif func_type == UserDefinedDataStreamFunction.KEYED_PROCESS:
+
+ function_context = InternalKeyedProcessFunctionContext(
+ function_context, user_defined_function_proto.key_type_info)
+
+ timer_context = InternalKeyedProcessFunctionOnTimerContext(
+ timer_context, user_defined_function_proto.key_type_info)
+
+ on_timer = user_defined_func.on_timer
+
+ def process_element(value, context):
+ return user_defined_func.process_element(value[1], context)
- if function_urn == DATA_STREAM_STATELESS_FUNCTION_URN:
- context = InternalProcessFunctionContext(function_context)
+ def on_timer_func(timestamp):
+ yield from on_timer(timestamp, timer_context)
- def process_element_func(value):
- yield from process_element(value, context)
+ def process_element_func(value):
+ yield from process_element(value, function_context)
- return open_func, close_func, process_element_func
+ return OneInputOperation(open_func, close_func, process_element_func, on_timer_func)
+ else:
+ raise Exception("Unknown function type {0}.".format(func_type))
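A note on the value[1] indexing above: in keyed mode the operator appears to hand the Python adapter a (key, value) pair, so process_element unwraps value[1] before calling the user function, while the current key, timestamp, and time domain are exposed to on_timer through the timer context defined in process_function.py below.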
diff --git a/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py b/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py
index 8c7a1608634..eee7b4b01d8 100644
--- a/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py
+++ b/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py
@@ -15,7 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
-from pyflink.datastream import ProcessFunction, TimerService
+from pyflink.datastream import ProcessFunction, TimerService, KeyedProcessFunction, TimeDomain
+from pyflink.fn_execution.datastream.embedded.timerservice_impl import TimerServiceImpl
+from pyflink.fn_execution.embedded.converters import from_type_info
class InternalProcessFunctionContext(ProcessFunction.Context, TimerService):
@@ -45,3 +47,41 @@ class InternalProcessFunctionContext(ProcessFunction.Context, TimerService):
def delete_event_time_timer(self, t: int):
raise Exception("Deleting timers is only supported on a keyed streams.")
+
+
+class InternalKeyedProcessFunctionContext(KeyedProcessFunction.Context):
+
+ def __init__(self, context, key_type_info):
+ self._context = context
+ self._timer_service = TimerServiceImpl(self._context.timerService())
+ self._key_converter = from_type_info(key_type_info)
+
+ def get_current_key(self):
+ return self._key_converter.to_internal(self._context.getCurrentKey())
+
+ def timer_service(self) -> TimerService:
+ return self._timer_service
+
+ def timestamp(self) -> int:
+ return self._context.timestamp()
+
+
+class InternalKeyedProcessFunctionOnTimerContext(KeyedProcessFunction.OnTimerContext,
+ KeyedProcessFunction.Context):
+
+ def __init__(self, context, key_type_info):
+ self._context = context
+ self._timer_service = TimerServiceImpl(self._context.timerService())
+ self._key_converter = from_type_info(key_type_info)
+
+ def timer_service(self) -> TimerService:
+ return self._timer_service
+
+ def timestamp(self) -> int:
+ return self._context.timestamp()
+
+ def time_domain(self) -> TimeDomain:
+ return TimeDomain(self._context.timeDomain())
+
+ def get_current_key(self):
+ return self._key_converter.to_internal(self._context.getCurrentKey())
diff --git a/flink-python/pyflink/fn_execution/datastream/embedded/runtime_context.py b/flink-python/pyflink/fn_execution/datastream/embedded/runtime_context.py
index 64abbf99b11..c7d9508446e 100644
--- a/flink-python/pyflink/fn_execution/datastream/embedded/runtime_context.py
+++ b/flink-python/pyflink/fn_execution/datastream/embedded/runtime_context.py
@@ -16,9 +16,15 @@
# limitations under the License.
################################################################################
from pyflink.datastream import RuntimeContext
-from pyflink.datastream.state import AggregatingStateDescriptor, AggregatingState, \
- ReducingStateDescriptor, ReducingState, MapStateDescriptor, MapState, ListStateDescriptor, \
- ListState, ValueStateDescriptor, ValueState
+from pyflink.datastream.state import (AggregatingStateDescriptor, AggregatingState,
+ ReducingStateDescriptor, ReducingState, MapStateDescriptor,
+ MapState, ListStateDescriptor, ListState,
+ ValueStateDescriptor, ValueState)
+from pyflink.fn_execution.datastream.embedded.state_impl import (ValueStateImpl, ListStateImpl,
+ MapStateImpl, ReducingStateImpl,
+ AggregatingStateImpl)
+from pyflink.fn_execution.embedded.converters import from_type_info
+from pyflink.fn_execution.embedded.java_utils import to_java_state_descriptor
class StreamingRuntimeContext(RuntimeContext):
@@ -76,20 +82,32 @@ class StreamingRuntimeContext(RuntimeContext):
return self._runtime_context.getMetricGroup()
def get_state(self, state_descriptor: ValueStateDescriptor) -> ValueState:
- pass
+ return ValueStateImpl(
+ self._runtime_context.getState(to_java_state_descriptor(state_descriptor)),
+ from_type_info(state_descriptor.type_info))
def get_list_state(self, state_descriptor: ListStateDescriptor) -> ListState:
- pass
+ return ListStateImpl(
+ self._runtime_context.getListState(to_java_state_descriptor(state_descriptor)),
+ from_type_info(state_descriptor.type_info))
def get_map_state(self, state_descriptor: MapStateDescriptor) -> MapState:
- pass
+ return MapStateImpl(
+ self._runtime_context.getMapState(to_java_state_descriptor(state_descriptor)),
+ from_type_info(state_descriptor.type_info))
def get_reducing_state(self, state_descriptor: ReducingStateDescriptor) -> ReducingState:
- pass
+ return ReducingStateImpl(
+ self._runtime_context.getState(to_java_state_descriptor(state_descriptor)),
+ from_type_info(state_descriptor.type_info),
+ state_descriptor.get_reduce_function())
def get_aggregating_state(self,
state_descriptor: AggregatingStateDescriptor) -> AggregatingState:
- pass
+ return AggregatingStateImpl(
+ self._runtime_context.getState(to_java_state_descriptor(state_descriptor)),
+ from_type_info(state_descriptor.type_info),
+ state_descriptor.get_agg_function())
@staticmethod
def of(runtime_context, job_parameters):
diff --git a/flink-python/pyflink/fn_execution/datastream/embedded/state_impl.py b/flink-python/pyflink/fn_execution/datastream/embedded/state_impl.py
new file mode 100644
index 00000000000..e88c6abc304
--- /dev/null
+++ b/flink-python/pyflink/fn_execution/datastream/embedded/state_impl.py
@@ -0,0 +1,158 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from typing import List, Iterable, Tuple, Dict
+
+from pyflink.datastream import ReduceFunction, AggregateFunction
+from pyflink.datastream.state import (ValueState, T, State, ListState, IN, OUT, ReducingState,
+ AggregatingState, MapState, V, K)
+from pyflink.fn_execution.embedded.converters import (DataConverter, DictDataConverter,
+ ListDataConverter)
+
+
+class StateImpl(State):
+ def __init__(self, state, value_converter: DataConverter):
+ self._state = state
+ self._value_converter = value_converter
+
+ def clear(self):
+ self._state.clear()
+
+
+class ValueStateImpl(StateImpl, ValueState):
+ def __init__(self, value_state, value_converter: DataConverter):
+ super(ValueStateImpl, self).__init__(value_state, value_converter)
+
+ def value(self) -> T:
+ return self._value_converter.to_internal(self._state.value())
+
+ def update(self, value: T) -> None:
+ self._state.update(self._value_converter.to_external(value))
+
+
+class ListStateImpl(StateImpl, ListState):
+
+ def __init__(self, list_state, value_converter: ListDataConverter):
+ super(ListStateImpl, self).__init__(list_state, value_converter)
+ self._element_converter = value_converter._field_converter
+
+ def update(self, values: List[T]) -> None:
+ self._state.update(self._value_converter.to_external(values))
+
+ def add_all(self, values: List[T]) -> None:
+ self._state.addAll(self._value_converter.to_external(values))
+
+ def get(self) -> OUT:
+ return self._value_converter.to_internal(self._state.get())
+
+ def add(self, value: IN) -> None:
+ self._state.add(self._element_converter.to_external(value))
+
+
+class ReducingStateImpl(StateImpl, ReducingState):
+
+ def __init__(self,
+ value_state,
+ value_converter: DataConverter,
+ reduce_function: ReduceFunction):
+ super(ReducingStateImpl, self).__init__(value_state, value_converter)
+ self._reduce_function = reduce_function
+
+ def get(self) -> OUT:
+ return self._value_converter.to_internal(self._state.value())
+
+ def add(self, value: IN) -> None:
+ if value is None:
+ self.clear()
+ else:
+ current_value = self.get()
+
+ if current_value is None:
+ reduce_value = value
+ else:
+ reduce_value = self._reduce_function.reduce(current_value, value)
+
+ self._state.update(self._value_converter.to_external(reduce_value))
+
+
+class AggregatingStateImpl(StateImpl, AggregatingState):
+ def __init__(self,
+ value_state,
+ value_converter,
+ agg_function: AggregateFunction):
+ super(AggregatingStateImpl, self).__init__(value_state, value_converter)
+ self._agg_function = agg_function
+
+ def get(self) -> OUT:
+ accumulator = self._value_converter.to_internal(self._state.value())
+
+ if accumulator is None:
+ return None
+ else:
+ return self._agg_function.get_result(accumulator)
+
+ def add(self, value: IN) -> None:
+ if value is None:
+ self.clear()
+ else:
+ accumulator = self._value_converter.to_internal(self._state.value())
+
+ if accumulator is None:
+ accumulator = self._agg_function.create_accumulator()
+
+ accumulator = self._agg_function.add(value, accumulator)
+ self._state.update(self._value_converter.to_external(accumulator))
+
+
+class MapStateImpl(StateImpl, MapState):
+ def __init__(self, map_state, map_converter: DictDataConverter):
+ super(MapStateImpl, self).__init__(map_state, map_converter)
+ self._k_converter = map_converter._key_converter
+ self._v_converter = map_converter._value_converter
+
+ def get(self, key: K) -> V:
+ return self._v_converter.to_internal(
+ self._state.get(self._k_converter.to_external(key)))
+
+ def put(self, key: K, value: V) -> None:
+ self._state.put(self._k_converter.to_external(key), self._v_converter.to_external(value))
+
+ def put_all(self, dict_value: Dict[K, V]) -> None:
+ self._state.putAll(self._value_converter.to_external(dict_value))
+
+ def remove(self, key: K) -> None:
+ self._state.remove(self._k_converter.to_external(key))
+
+ def contains(self, key: K) -> bool:
+ return self._state.contains(self._k_converter.to_external(key))
+
+ def items(self) -> Iterable[Tuple[K, V]]:
+ entries = self._state.entries()
+ for entry in entries:
+ yield (self._k_converter.to_internal(entry.getKey()),
+ self._v_converter.to_internal(entry.getValue()))
+
+ def keys(self) -> Iterable[K]:
+ for k in self._state.keys():
+ yield self._k_converter.to_internal(k)
+
+ def values(self) -> Iterable[V]:
+ for v in self._state.values():
+ yield self._v_converter.to_internal(v)
+
+ def is_empty(self) -> bool:
+ return self._state.isEmpty()
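ReducingStateImpl and AggregatingStateImpl layer the user's function over a plain Java value state: only the reduced value (or accumulator) is stored, and the function runs on the Python side of the embedded interpreter. A self-contained sketch of the add() semantics, using a hypothetical in-memory stand-in for the backing Java state (FakeJavaValueState is invented for illustration; IdentityDataConverter passes values through unchanged):

    from pyflink.datastream import ReduceFunction
    from pyflink.fn_execution.datastream.embedded.state_impl import ReducingStateImpl
    from pyflink.fn_execution.embedded.converters import IdentityDataConverter

    class FakeJavaValueState:
        # hypothetical stand-in for the Java value state behind ReducingStateImpl
        def __init__(self):
            self._v = None

        def value(self):
            return self._v

        def update(self, v):
            self._v = v

        def clear(self):
            self._v = None

    class Sum(ReduceFunction):
        def reduce(self, a, b):
            return a + b

    state = ReducingStateImpl(FakeJavaValueState(), IdentityDataConverter(), Sum())
    state.add(1)
    state.add(2)
    assert state.get() == 3
    state.add(None)          # adding None clears the state, as implemented above
    assert state.get() is None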
diff --git a/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py b/flink-python/pyflink/fn_execution/datastream/embedded/timerservice_impl.py
similarity index 57%
copy from flink-python/pyflink/fn_execution/datastream/embedded/process_function.py
copy to flink-python/pyflink/fn_execution/datastream/embedded/timerservice_impl.py
index 8c7a1608634..fab9440b230 100644
--- a/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py
+++ b/flink-python/pyflink/fn_execution/datastream/embedded/timerservice_impl.py
@@ -15,33 +15,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
-from pyflink.datastream import ProcessFunction, TimerService
+from pyflink.datastream import TimerService
-class InternalProcessFunctionContext(ProcessFunction.Context, TimerService):
- def __init__(self, context):
- self._context = context
-
- def timer_service(self) -> TimerService:
- return self
-
- def timestamp(self) -> int:
- return self._context.timestamp()
+class TimerServiceImpl(TimerService):
+ def __init__(self, timer_service):
+ self._timer_service = timer_service
def current_processing_time(self):
- return self._context.currentProcessingTime()
+ return self._timer_service.currentProcessingTime()
def current_watermark(self):
- return self._context.currentWatermark()
+ return self._timer_service.currentWatermark()
def register_processing_time_timer(self, timestamp: int):
- raise Exception("Register timers is only supported on a keyed stream.")
+ self._timer_service.registerProcessingTimeTimer(timestamp)
def register_event_time_timer(self, timestamp: int):
- raise Exception("Register timers is only supported on a keyed stream.")
+ self._timer_service.registerEventTimeTimer(timestamp)
- def delete_processing_time_timer(self, t: int):
- raise Exception("Deleting timers is only supported on a keyed streams.")
+ def delete_processing_time_timer(self, timestamp: int):
+ self._timer_service.deleteProcessingTimeTimer(timestamp)
- def delete_event_time_timer(self, t: int):
- raise Exception("Deleting timers is only supported on a keyed streams.")
+ def delete_event_time_timer(self, timestamp: int):
+ self._timer_service.deleteEventTimeTimer(timestamp)
diff --git a/flink-python/pyflink/fn_execution/embedded/converters.py b/flink-python/pyflink/fn_execution/embedded/converters.py
index 157a082c93a..ee75e6d5beb 100644
--- a/flink-python/pyflink/fn_execution/embedded/converters.py
+++ b/flink-python/pyflink/fn_execution/embedded/converters.py
@@ -20,7 +20,10 @@ from abc import ABC, abstractmethod
import pickle
from typing import TypeVar, List, Tuple
-from pyflink.common import Row, RowKind
+from pyflink.common import Row, RowKind, TypeInformation
+from pyflink.common.typeinfo import (PickledBytesTypeInfo, PrimitiveArrayTypeInfo,
+ BasicArrayTypeInfo, ObjectArrayTypeInfo, RowTypeInfo,
+ TupleTypeInfo, MapTypeInfo, ListTypeInfo)
IN = TypeVar('IN')
OUT = TypeVar('OUT')
@@ -70,13 +73,15 @@ class FlattenRowDataConverter(DataConverter):
if value is None:
return None
- return [self._field_data_converters[i].to_internal(item) for i, item in enumerate(value)]
+ return tuple([self._field_data_converters[i].to_internal(item)
+ for i, item in enumerate(value)])
def to_external(self, value) -> OUT:
if value is None:
return None
- return [self._field_data_converters[i].to_external(item) for i, item in enumerate(value)]
+ return tuple([self._field_data_converters[i].to_external(item)
+ for i, item in enumerate(value)])
class RowDataConverter(DataConverter):
@@ -84,8 +89,6 @@ class RowDataConverter(DataConverter):
def __init__(self, field_data_converters: List[DataConverter], field_names: List[str]):
self._field_data_converters = field_data_converters
self._reuse_row = Row()
- self._reuse_external_row_data = [None for _ in range(len(field_data_converters))]
- self._reuse_external_row = [None, self._reuse_external_row_data]
self._reuse_row.set_field_names(field_names)
def to_internal(self, value) -> IN:
@@ -102,11 +105,10 @@ class RowDataConverter(DataConverter):
if value is None:
return None
- self._reuse_external_row[0] = value.get_row_kind().value
values = value._values
- for i in range(len(values)):
- self._reuse_external_row_data[i] = self._field_data_converters[i].to_external(values[i])
- return self._reuse_external_row
+ fields = tuple([self._field_data_converters[i].to_external(values[i])
+ for i in range(len(values))])
+ return value.get_row_kind().value, fields
class TupleDataConverter(DataConverter):
@@ -125,8 +127,8 @@ class TupleDataConverter(DataConverter):
if value is None:
return None
- return [self._field_data_converters[i].to_external(item)
- for i, item in enumerate(value)]
+ return tuple([self._field_data_converters[i].to_external(item)
+ for i, item in enumerate(value)])
class ListDataConverter(DataConverter):
@@ -147,6 +149,17 @@ class ListDataConverter(DataConverter):
return [self._field_converter.to_external(item) for item in value]
+class ArrayDataConverter(ListDataConverter):
+ def __init__(self, field_converter: DataConverter):
+ super(ArrayDataConverter, self).__init__(field_converter)
+
+ def to_internal(self, value) -> IN:
+ return tuple(super(ArrayDataConverter, self).to_internal(value))
+
+ def to_external(self, value) -> OUT:
+ return tuple(super(ArrayDataConverter, self).to_external(value))
+
+
class DictDataConverter(DataConverter):
def __init__(self, key_converter: DataConverter, value_converter: DataConverter):
self._key_converter = key_converter
@@ -185,8 +198,9 @@ def from_type_info_proto(type_info):
[from_type_info_proto(field_type)
for field_type in type_info.tuple_type_info.field_types])
elif type_name in (type_info_name.BASIC_ARRAY,
- type_info_name.OBJECT_ARRAY,
- type_info_name.LIST):
+ type_info_name.OBJECT_ARRAY):
+ return ArrayDataConverter(from_type_info_proto(type_info.collection_element_type))
+ elif type_name == type_info_name.LIST:
return ListDataConverter(from_type_info_proto(type_info.collection_element_type))
elif type_name == type_info_name.MAP:
return DictDataConverter(from_type_info_proto(type_info.map_type_info.key_type),
@@ -214,9 +228,23 @@ def from_field_type_proto(field_type):
[from_field_type_proto(f.type) for f in field_type.row_schema.fields],
[f.name for f in field_type.row_schema.fields])
elif type_name == schema_type_name.BASIC_ARRAY:
- return ListDataConverter(from_field_type_proto(field_type.collection_element_type))
+ return ArrayDataConverter(from_field_type_proto(field_type.collection_element_type))
elif type_name == schema_type_name.MAP:
return DictDataConverter(from_field_type_proto(field_type.map_info.key_type),
from_field_type_proto(field_type.map_info.value_type))
return IdentityDataConverter()
+
+
+def from_type_info(type_info: TypeInformation):
+ if isinstance(type_info, (PickledBytesTypeInfo, RowTypeInfo, TupleTypeInfo)):
+ return PickleDataConverter()
+ elif isinstance(type_info, (PrimitiveArrayTypeInfo, BasicArrayTypeInfo, ObjectArrayTypeInfo)):
+ return ArrayDataConverter(from_type_info(type_info._element_type))
+ elif isinstance(type_info, ListTypeInfo):
+ return ListDataConverter(from_type_info(type_info.elem_type))
+ elif isinstance(type_info, MapTypeInfo):
+ return DictDataConverter(from_type_info(type_info._key_type_info),
+ from_type_info(type_info._value_type_info))
+
+ return IdentityDataConverter()
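The new from_type_info dispatch can be exercised directly. A quick sketch (PickleDataConverter and the other converters are defined earlier in this module):

    from pyflink.common.typeinfo import Types
    from pyflink.fn_execution.embedded.converters import (
        ArrayDataConverter, DictDataConverter, PickleDataConverter, from_type_info)

    # Rows, tuples and pickled byte arrays fall back to pickling, while arrays
    # and maps get element-wise converters.
    assert isinstance(from_type_info(Types.ROW([Types.INT()])), PickleDataConverter)
    assert isinstance(from_type_info(Types.OBJECT_ARRAY(Types.STRING())),
                      ArrayDataConverter)
    assert isinstance(from_type_info(Types.MAP(Types.INT(), Types.STRING())),
                      DictDataConverter)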
diff --git a/flink-python/pyflink/fn_execution/embedded/java_utils.py b/flink-python/pyflink/fn_execution/embedded/java_utils.py
new file mode 100644
index 00000000000..099cca8f1f3
--- /dev/null
+++ b/flink-python/pyflink/fn_execution/embedded/java_utils.py
@@ -0,0 +1,204 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from pemja import findClass
+
+from pyflink.common.typeinfo import (TypeInformation, Types, BasicTypeInfo, BasicType,
+ PrimitiveArrayTypeInfo, BasicArrayTypeInfo,
+ ObjectArrayTypeInfo, MapTypeInfo)
+from pyflink.datastream.state import (StateDescriptor, ValueStateDescriptor,
+ ReducingStateDescriptor,
+ AggregatingStateDescriptor, ListStateDescriptor,
+ MapStateDescriptor, StateTtlConfig)
+
+# Java Types Class
+JTypes = findClass('org.apache.flink.api.common.typeinfo.Types')
+JPrimitiveArrayTypeInfo = findClass('org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo')
+JBasicArrayTypeInfo = findClass('org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo')
+JPickledByteArrayTypeInfo = findClass('org.apache.flink.streaming.api.typeinfo.python.'
+ 'PickledByteArrayTypeInfo')
+JMapTypeInfo = findClass('org.apache.flink.api.java.typeutils.MapTypeInfo')
+
+# Java State Descriptor Class
+JValueStateDescriptor = findClass('org.apache.flink.api.common.state.ValueStateDescriptor')
+JListStateDescriptor = findClass('org.apache.flink.api.common.state.ListStateDescriptor')
+JMapStateDescriptor = findClass('org.apache.flink.api.common.state.MapStateDescriptor')
+
+# Java StateTtlConfig
+JStateTtlConfig = findClass('org.apache.flink.api.common.state.StateTtlConfig')
+JTime = findClass('org.apache.flink.api.common.time.Time')
+JUpdateType = findClass('org.apache.flink.api.common.state.StateTtlConfig$UpdateType')
+JStateVisibility = findClass('org.apache.flink.api.common.state.StateTtlConfig$StateVisibility')
+
+
+def to_java_typeinfo(type_info: TypeInformation):
+ if isinstance(type_info, BasicTypeInfo):
+ basic_type = type_info._basic_type
+
+ if basic_type == BasicType.STRING:
+ j_typeinfo = JTypes.STRING
+ elif basic_type == BasicType.BYTE:
+ j_typeinfo = JTypes.LONG
+ elif basic_type == BasicType.BOOLEAN:
+ j_typeinfo = JTypes.BOOLEAN
+ elif basic_type == BasicType.SHORT:
+ j_typeinfo = JTypes.LONG
+ elif basic_type == BasicType.INT:
+ j_typeinfo = JTypes.LONG
+ elif basic_type == BasicType.LONG:
+ j_typeinfo = JTypes.LONG
+ elif basic_type == BasicType.FLOAT:
+ j_typeinfo = JTypes.DOUBLE
+ elif basic_type == BasicType.DOUBLE:
+ j_typeinfo = JTypes.DOUBLE
+ elif basic_type == BasicType.CHAR:
+ j_typeinfo = JTypes.STRING
+ elif basic_type == BasicType.BIG_INT:
+ j_typeinfo = JTypes.BIG_INT
+ elif basic_type == BasicType.BIG_DEC:
+ j_typeinfo = JTypes.BIG_DEC
+ elif basic_type == BasicType.INSTANT:
+ j_typeinfo = JTypes.INSTANT
+ else:
+ raise TypeError("Invalid BasicType %s." % basic_type)
+
+ elif isinstance(type_info, PrimitiveArrayTypeInfo):
+ element_type = type_info._element_type
+
+ if element_type == Types.BOOLEAN():
+ j_typeinfo = JPrimitiveArrayTypeInfo.BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO
+ elif element_type == Types.BYTE():
+ j_typeinfo = JPrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO
+ elif element_type == Types.SHORT():
+ j_typeinfo = JPrimitiveArrayTypeInfo.SHORT_PRIMITIVE_ARRAY_TYPE_INFO
+ elif element_type == Types.INT():
+ j_typeinfo = JPrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO
+ elif element_type == Types.LONG():
+ j_typeinfo = JPrimitiveArrayTypeInfo.LONG_PRIMITIVE_ARRAY_TYPE_INFO
+ elif element_type == Types.FLOAT():
+ j_typeinfo = JPrimitiveArrayTypeInfo.FLOAT_PRIMITIVE_ARRAY_TYPE_INFO
+ elif element_type == Types.DOUBLE():
+ j_typeinfo = JPrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO
+ elif element_type == Types.CHAR():
+ j_typeinfo = JPrimitiveArrayTypeInfo.CHAR_PRIMITIVE_ARRAY_TYPE_INFO
+ else:
+ raise TypeError("Invalid element type for a primitive array.")
+
+ elif isinstance(type_info, BasicArrayTypeInfo):
+ element_type = type_info._element_type
+
+ if element_type == Types.BOOLEAN():
+ j_typeinfo = JBasicArrayTypeInfo.BOOLEAN_ARRAY_TYPE_INFO
+ elif element_type == Types.BYTE():
+ j_typeinfo = JBasicArrayTypeInfo.BYTE_ARRAY_TYPE_INFO
+ elif element_type == Types.SHORT():
+ j_typeinfo = JBasicArrayTypeInfo.SHORT_ARRAY_TYPE_INFO
+ elif element_type == Types.INT():
+ j_typeinfo = JBasicArrayTypeInfo.INT_ARRAY_TYPE_INFO
+ elif element_type == Types.LONG():
+ j_typeinfo = JBasicArrayTypeInfo.LONG_ARRAY_TYPE_INFO
+ elif element_type == Types.FLOAT():
+ j_typeinfo = JBasicArrayTypeInfo.FLOAT_ARRAY_TYPE_INFO
+ elif element_type == Types.DOUBLE():
+ j_typeinfo = JBasicArrayTypeInfo.DOUBLE_ARRAY_TYPE_INFO
+ elif element_type == Types.CHAR():
+ j_typeinfo = JBasicArrayTypeInfo.CHAR_ARRAY_TYPE_INFO
+ elif element_type == Types.STRING():
+ j_typeinfo = JBasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO
+ else:
+ raise TypeError("Invalid element type for a basic array.")
+
+ elif isinstance(type_info, ObjectArrayTypeInfo):
+ element_type = type_info._element_type
+
+ j_typeinfo = JTypes.OBJECT_ARRAY(to_java_typeinfo(element_type))
+
+ elif isinstance(type_info, MapTypeInfo):
+ j_key_typeinfo = to_java_typeinfo(type_info._key_type_info)
+ j_value_typeinfo = to_java_typeinfo(type_info._value_type_info)
+
+ j_typeinfo = JMapTypeInfo(j_key_typeinfo, j_value_typeinfo)
+ else:
+ j_typeinfo = JPickledByteArrayTypeInfo.PICKLED_BYTE_ARRAY_TYPE_INFO
+
+ return j_typeinfo
+
+
+def to_java_state_ttl_config(ttl_config: StateTtlConfig):
+    j_ttl_config_builder = JStateTtlConfig.newBuilder(
+        JTime.milliseconds(ttl_config.get_ttl().to_milliseconds()))
+
+    update_type = ttl_config.get_update_type()
+    if update_type == StateTtlConfig.UpdateType.Disabled:
+        j_ttl_config_builder.setUpdateType(JUpdateType.Disabled)
+    elif update_type == StateTtlConfig.UpdateType.OnCreateAndWrite:
+        j_ttl_config_builder.setUpdateType(JUpdateType.OnCreateAndWrite)
+    elif update_type == StateTtlConfig.UpdateType.OnReadAndWrite:
+        j_ttl_config_builder.setUpdateType(JUpdateType.OnReadAndWrite)
+
+    state_visibility = ttl_config.get_state_visibility()
+    if state_visibility == StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp:
+        j_ttl_config_builder.setStateVisibility(JStateVisibility.ReturnExpiredIfNotCleanedUp)
+    elif state_visibility == StateTtlConfig.StateVisibility.NeverReturnExpired:
+        j_ttl_config_builder.setStateVisibility(JStateVisibility.NeverReturnExpired)
+
+    cleanup_strategies = ttl_config.get_cleanup_strategies()
+    if not cleanup_strategies.is_cleanup_in_background():
+        j_ttl_config_builder.disableCleanupInBackground()
+
+    if cleanup_strategies.in_full_snapshot():
+        j_ttl_config_builder.cleanupFullSnapshot()
+
+    incremental_cleanup_strategy = cleanup_strategies.get_incremental_cleanup_strategy()
+    if incremental_cleanup_strategy:
+        j_ttl_config_builder.cleanupIncrementally(
+            incremental_cleanup_strategy.get_cleanup_size(),
+            incremental_cleanup_strategy.run_cleanup_for_every_record())
+
+    rocksdb_compact_filter_cleanup_strategy = \
+        cleanup_strategies.get_rocksdb_compact_filter_cleanup_strategy()
+
+    if rocksdb_compact_filter_cleanup_strategy:
+        j_ttl_config_builder.cleanupInRocksdbCompactFilter(
+            rocksdb_compact_filter_cleanup_strategy.get_query_time_after_num_entries())
+
+    return j_ttl_config_builder.build()
+
+
+def to_java_state_descriptor(state_descriptor: StateDescriptor):
+    if isinstance(state_descriptor,
+                  (ValueStateDescriptor, ReducingStateDescriptor, AggregatingStateDescriptor)):
+        value_type_info = to_java_typeinfo(state_descriptor.type_info)
+        j_state_descriptor = JValueStateDescriptor(state_descriptor.name, value_type_info)
+
+    elif isinstance(state_descriptor, ListStateDescriptor):
+        element_type_info = to_java_typeinfo(state_descriptor.type_info.elem_type)
+        j_state_descriptor = JListStateDescriptor(state_descriptor.name, element_type_info)
+
+    elif isinstance(state_descriptor, MapStateDescriptor):
+        key_type_info = to_java_typeinfo(state_descriptor.type_info._key_type_info)
+        value_type_info = to_java_typeinfo(state_descriptor.type_info._value_type_info)
+        j_state_descriptor = JMapStateDescriptor(
+            state_descriptor.name, key_type_info, value_type_info)
+    else:
+        raise TypeError("Unsupported state_descriptor {0}".format(state_descriptor))
+
+    if state_descriptor._ttl_config:
+        j_state_ttl_config = to_java_state_ttl_config(state_descriptor._ttl_config)
+        j_state_descriptor.enableTimeToLive(j_state_ttl_config)
+
+    return j_state_descriptor
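
For context, a minimal sketch of the PyFlink-side objects that these helpers translate (standard PyFlink API; the variable names are illustrative). to_java_state_descriptor maps the descriptor below onto a Java ValueStateDescriptor, and to_java_state_ttl_config replays its TTL settings on the Java StateTtlConfig builder:

    from pyflink.common import Time
    from pyflink.common.typeinfo import Types
    from pyflink.datastream.state import StateTtlConfig, ValueStateDescriptor

    # TTL configuration as a user would declare it in Python.
    ttl_config = StateTtlConfig \
        .new_builder(Time.hours(1)) \
        .set_update_type(StateTtlConfig.UpdateType.OnCreateAndWrite) \
        .set_state_visibility(StateTtlConfig.StateVisibility.NeverReturnExpired) \
        .cleanup_incrementally(10, True) \
        .build()

    descriptor = ValueStateDescriptor("count", Types.LONG())
    descriptor.enable_time_to_live(ttl_config)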
diff --git a/flink-python/pyflink/fn_execution/embedded/operation_utils.py b/flink-python/pyflink/fn_execution/embedded/operation_utils.py
index 0d666693fbc..c6a96d2ad67 100644
--- a/flink-python/pyflink/fn_execution/embedded/operation_utils.py
+++ b/flink-python/pyflink/fn_execution/embedded/operation_utils.py
@@ -101,8 +101,8 @@ def create_table_operation_from_proto(proto, input_coder_info, output_coder_into
def create_one_input_user_defined_data_stream_function_from_protos(
- function_urn, function_infos, input_coder_info, output_coder_info, runtime_context,
- function_context, job_parameters):
+ function_infos, input_coder_info, output_coder_info, runtime_context,
+ function_context, timer_context, job_parameters):
serialized_fns = [pare_user_defined_data_stream_function_proto(proto)
for proto in function_infos]
input_data_converter = (
@@ -111,12 +111,12 @@ def create_one_input_user_defined_data_stream_function_from_protos(
from_type_info_proto(parse_coder_proto(output_coder_info).raw_type.type_info))
function_operation = OneInputFunctionOperation(
- function_urn,
serialized_fns,
input_data_converter,
output_data_converter,
runtime_context,
function_context,
+ timer_context,
job_parameters)
return function_operation
diff --git a/flink-python/pyflink/fn_execution/embedded/operations.py b/flink-python/pyflink/fn_execution/embedded/operations.py
index 1ece4cf6e11..919195c16b6 100644
--- a/flink-python/pyflink/fn_execution/embedded/operations.py
+++ b/flink-python/pyflink/fn_execution/embedded/operations.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
-from pyflink.fn_execution.datastream.embedded.operations import OneInputOperation
+from pyflink.fn_execution.datastream.embedded.operations import extract_process_function
from pyflink.fn_execution.embedded.converters import DataConverter
@@ -36,22 +36,27 @@ class FunctionOperation(object):
for operation in self._operations:
operation.close()
+ def on_timer(self, timestamp):
+ for operation in self._operations:
+ for item in operation.on_timer(timestamp):
+ yield self._output_data_converter.to_external(item)
+
class OneInputFunctionOperation(FunctionOperation):
def __init__(self,
- function_urn,
serialized_fns,
input_data_converter: DataConverter,
output_data_converter: DataConverter,
runtime_context,
function_context,
+ timer_context,
job_parameters):
operations = (
- [OneInputOperation(
- function_urn,
+ [extract_process_function(
serialized_fn,
runtime_context,
function_context,
+ timer_context,
job_parameters)
for serialized_fn in serialized_fns])
super(OneInputFunctionOperation, self).__init__(
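
The on_timer path added above is plain generator delegation: the chained operation fans a fired timer out to every extracted process function and lazily converts each produced record. A stripped-down, runnable sketch of the same shape (class names hypothetical, identity conversion):

    class TimerAwareOperation(object):
        """Stand-in for one extracted process function."""

        def on_timer(self, timestamp):
            yield ("fired", timestamp)

    class ChainedFunctionOperation(object):
        """Mirrors FunctionOperation.on_timer: fan the timer out to each
        chained operation and convert every item to its external form."""

        def __init__(self, operations, to_external):
            self._operations = operations
            self._to_external = to_external

        def on_timer(self, timestamp):
            for operation in self._operations:
                for item in operation.on_timer(timestamp):
                    yield self._to_external(item)

    ops = [TimerAwareOperation(), TimerAwareOperation()]
    chained = ChainedFunctionOperation(ops, lambda x: x)
    print(list(chained.on_timer(42)))  # [('fired', 42), ('fired', 42)]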
diff --git a/flink-python/setup.py b/flink-python/setup.py
index 81c8b463412..acc377bb6d3 100644
--- a/flink-python/setup.py
+++ b/flink-python/setup.py
@@ -309,7 +309,7 @@ try:
'cloudpickle==2.1.0', 'avro-python3>=1.8.1,!=1.9.2,<1.10.0',
'pytz>=2018.3', 'fastavro>=1.1.0,<1.4.8', 'requests>=2.26.0',
'protobuf<3.18',
- 'pemja==0.2.0;'
+ 'pemja==0.2.2;'
'python_full_version >= "3.7" and platform_system != "Windows"',
'httplib2>=0.19.0,<=0.20.4', apache_flink_libraries_dependency]
diff --git a/flink-python/src/main/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizer.java b/flink-python/src/main/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizer.java
index 272e929c693..613cc6459c5 100644
--- a/flink-python/src/main/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizer.java
+++ b/flink-python/src/main/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizer.java
@@ -33,6 +33,7 @@ import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.api.operators.python.DataStreamPythonFunctionOperator;
import org.apache.flink.streaming.api.operators.python.embedded.AbstractEmbeddedDataStreamPythonFunctionOperator;
+import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonKeyedProcessOperator;
import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonProcessOperator;
import org.apache.flink.streaming.api.operators.python.process.AbstractExternalDataStreamPythonFunctionOperator;
import org.apache.flink.streaming.api.operators.python.process.ExternalPythonCoProcessOperator;
@@ -423,7 +424,8 @@ public class PythonOperatorChainingOptimizer {
|| upOperator instanceof ExternalPythonProcessOperator
|| upOperator instanceof ExternalPythonCoProcessOperator))
|| (downOperator instanceof EmbeddedPythonProcessOperator
- && upOperator instanceof EmbeddedPythonProcessOperator);
+ && (upOperator instanceof EmbeddedPythonProcessOperator
+ || upOperator instanceof EmbeddedPythonKeyedProcessOperator));
}
private static boolean arePythonOperatorsInSameExecutionEnvironment(
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractEmbeddedDataStreamPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractEmbeddedDataStreamPythonFunctionOperator.java
index 22df29dec4d..c5205987375 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractEmbeddedDataStreamPythonFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractEmbeddedDataStreamPythonFunctionOperator.java
@@ -21,6 +21,7 @@ package org.apache.flink.streaming.api.operators.python.embedded;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
import org.apache.flink.streaming.api.operators.python.DataStreamPythonFunctionOperator;
import org.apache.flink.util.Preconditions;
@@ -28,8 +29,7 @@ import org.apache.flink.util.Preconditions;
import java.util.HashMap;
import java.util.Map;
-import static org.apache.flink.python.Constants.STATEFUL_FUNCTION_URN;
-import static org.apache.flink.python.Constants.STATELESS_FUNCTION_URN;
+import static org.apache.flink.streaming.api.utils.PythonOperatorUtils.inBatchExecutionMode;
/** Base class for all Python DataStream operators executed in the embedded Python environment. */
@Internal
@@ -47,23 +47,14 @@ public abstract class AbstractEmbeddedDataStreamPythonFunctionOperator<OUT>
/** The TypeInformation of output data. */
protected final TypeInformation<OUT> outputTypeInfo;
- /** The function urn. */
- final String functionUrn;
-
/** The number of partitions for the partition custom function. */
private Integer numPartitions;
public AbstractEmbeddedDataStreamPythonFunctionOperator(
- String functionUrn,
Configuration config,
DataStreamPythonFunctionInfo pythonFunctionInfo,
TypeInformation<OUT> outputTypeInfo) {
super(config);
- Preconditions.checkArgument(
- STATELESS_FUNCTION_URN.equals(functionUrn)
- || STATEFUL_FUNCTION_URN.equals(functionUrn),
- "The function urn should be `STATELESS_FUNCTION_URN` or `STATEFUL_FUNCTION_URN`.");
- this.functionUrn = functionUrn;
this.pythonFunctionInfo = Preconditions.checkNotNull(pythonFunctionInfo);
this.outputTypeInfo = Preconditions.checkNotNull(outputTypeInfo);
}
@@ -88,6 +79,13 @@ public abstract class AbstractEmbeddedDataStreamPythonFunctionOperator<OUT>
if (numPartitions != null) {
jobParameters.put(NUM_PARTITIONS, String.valueOf(numPartitions));
}
+
+ KeyedStateBackend<Object> keyedStateBackend = getKeyedStateBackend();
+ if (keyedStateBackend != null) {
+ jobParameters.put(
+ "inBatchExecutionMode",
+ String.valueOf(inBatchExecutionMode(keyedStateBackend)));
+ }
return jobParameters;
}
}
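
Since jobParameters is a plain Map<String, String>, the batch-mode flag reaches the embedded interpreter as the string "true" or "false". A Python-side consumer would read it along these lines (illustrative only, not the exact PyFlink code):

    in_batch_execution_mode = (
        job_parameters.get("inBatchExecutionMode", "false") == "true")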
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java
index ec886ef29b5..42b096827f3 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java
@@ -53,19 +53,18 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
private transient PythonTypeUtils.DataConverter<IN, Object> inputDataConverter;
- private transient PythonTypeUtils.DataConverter<OUT, Object> outputDataConverter;
+ transient PythonTypeUtils.DataConverter<OUT, Object> outputDataConverter;
- private transient TimestampedCollector<OUT> collector;
+ protected transient TimestampedCollector<OUT> collector;
protected transient long timestamp;
public AbstractOneInputEmbeddedPythonFunctionOperator(
- String functionUrn,
Configuration config,
DataStreamPythonFunctionInfo pythonFunctionInfo,
TypeInformation<IN> inputTypeInfo,
TypeInformation<OUT> outputTypeInfo) {
- super(functionUrn, config, pythonFunctionInfo, outputTypeInfo);
+ super(config, pythonFunctionInfo, outputTypeInfo);
this.inputTypeInfo = Preconditions.checkNotNull(inputTypeInfo);
}
@@ -84,7 +83,6 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
@Override
public void openPythonInterpreter() {
- // function_urn = ...
// function_protos = ...
// input_coder_info = ...
// output_coder_info = ...
@@ -95,13 +93,11 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
// from pyflink.fn_execution.embedded.operation_utils import
// create_one_input_user_defined_data_stream_function_from_protos
//
- // operation = create_one_input_user_defined_data_stream_function_from_protos(function_urn,
+ // operation = create_one_input_user_defined_data_stream_function_from_protos(
// function_protos, input_coder_info, output_coder_info, runtime_context,
// function_context, job_parameters)
// operation.open()
- interpreter.set("function_urn", functionUrn);
-
interpreter.set(
"function_protos",
createUserDefinedFunctionsProto().stream()
@@ -127,6 +123,7 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
interpreter.set("runtime_context", getRuntimeContext());
interpreter.set("function_context", getFunctionContext());
+ interpreter.set("timer_context", getTimerContext());
interpreter.set("job_parameters", getJobParameters());
interpreter.exec(
@@ -134,12 +131,12 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
interpreter.exec(
"operation = create_one_input_user_defined_data_stream_function_from_protos("
- + "function_urn,"
+ "function_protos,"
+ "input_coder_info,"
+ "output_coder_info,"
+ "runtime_context,"
+ "function_context,"
+ + "timer_context,"
+ "job_parameters)");
interpreter.invokeMethod("operation", "open");
@@ -153,7 +150,7 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
}
@Override
- public void processElement(StreamRecord<IN> element) {
+ public void processElement(StreamRecord<IN> element) throws Exception {
collector.setTimestamp(element);
timestamp = element.getTimestamp();
@@ -169,6 +166,7 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
OUT result = outputDataConverter.toInternal(results.next());
collector.collect(result);
}
+ results.close();
}
TypeInformation<IN> getInputTypeInfo() {
@@ -181,4 +179,7 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
/** Gets the function context. */
public abstract Object getFunctionContext();
+
+ /** Gets the timer context. */
+ public abstract Object getTimerContext();
}
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonKeyedProcessOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonKeyedProcessOperator.java
new file mode 100644
index 00000000000..76fdf58058a
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonKeyedProcessOperator.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators.python.embedded;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.fnexecution.v1.FlinkFnApi;
+import org.apache.flink.python.util.ProtoUtils;
+import org.apache.flink.runtime.state.VoidNamespace;
+import org.apache.flink.runtime.state.VoidNamespaceSerializer;
+import org.apache.flink.streaming.api.SimpleTimerService;
+import org.apache.flink.streaming.api.TimeDomain;
+import org.apache.flink.streaming.api.TimerService;
+import org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
+import org.apache.flink.streaming.api.operators.InternalTimer;
+import org.apache.flink.streaming.api.operators.InternalTimerService;
+import org.apache.flink.streaming.api.operators.Triggerable;
+import org.apache.flink.streaming.api.utils.PythonTypeUtils;
+import org.apache.flink.types.Row;
+
+import pemja.core.object.PyIterator;
+
+import java.util.List;
+
+import static org.apache.flink.python.PythonOptions.MAP_STATE_READ_CACHE_SIZE;
+import static org.apache.flink.python.PythonOptions.MAP_STATE_WRITE_CACHE_SIZE;
+import static org.apache.flink.python.PythonOptions.PYTHON_METRIC_ENABLED;
+import static org.apache.flink.python.PythonOptions.PYTHON_PROFILE_ENABLED;
+import static org.apache.flink.python.PythonOptions.STATE_CACHE_SIZE;
+import static org.apache.flink.streaming.api.utils.PythonOperatorUtils.inBatchExecutionMode;
+
+/**
+ * {@link EmbeddedPythonKeyedProcessOperator} is responsible for executing a user-defined Python
+ * KeyedProcessFunction in the embedded Python environment. It also handles timer and state
+ * requests from the stateful Python user-defined function.
+ */
+@Internal
+public class EmbeddedPythonKeyedProcessOperator<K, IN, OUT>
+        extends AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
+        implements Triggerable<K, VoidNamespace> {
+
+    private static final long serialVersionUID = 1L;
+
+    /** The TypeInformation of the key. */
+    private transient TypeInformation<K> keyTypeInfo;
+
+    private transient ContextImpl context;
+
+    private transient OnTimerContextImpl onTimerContext;
+
+    private transient PythonTypeUtils.DataConverter<K, Object> keyConverter;
+
+    public EmbeddedPythonKeyedProcessOperator(
+            Configuration config,
+            DataStreamPythonFunctionInfo pythonFunctionInfo,
+            TypeInformation<IN> inputTypeInfo,
+            TypeInformation<OUT> outputTypeInfo) {
+        super(config, pythonFunctionInfo, inputTypeInfo, outputTypeInfo);
+    }
+
+    @Override
+    public void open() throws Exception {
+        keyTypeInfo = ((RowTypeInfo) this.getInputTypeInfo()).getTypeAt(0);
+
+        keyConverter = PythonTypeUtils.TypeInfoToDataConverter.typeInfoDataConverter(keyTypeInfo);
+
+        InternalTimerService<VoidNamespace> internalTimerService =
+                getInternalTimerService("user-timers", VoidNamespaceSerializer.INSTANCE, this);
+
+        TimerService timerService = new SimpleTimerService(internalTimerService);
+
+        context = new ContextImpl(timerService);
+
+        onTimerContext = new OnTimerContextImpl(timerService);
+
+        super.open();
+    }
+
+    @Override
+    public List<FlinkFnApi.UserDefinedDataStreamFunction> createUserDefinedFunctionsProto() {
+        return ProtoUtils.createUserDefinedDataStreamStatefulFunctionProtos(
+                getPythonFunctionInfo(),
+                getRuntimeContext(),
+                getJobParameters(),
+                keyTypeInfo,
+                inBatchExecutionMode(getKeyedStateBackend()),
+                config.get(PYTHON_METRIC_ENABLED),
+                config.get(PYTHON_PROFILE_ENABLED),
+                false,
+                config.get(STATE_CACHE_SIZE),
+                config.get(MAP_STATE_READ_CACHE_SIZE),
+                config.get(MAP_STATE_WRITE_CACHE_SIZE));
+    }
+
+    @Override
+    public void onEventTime(InternalTimer<K, VoidNamespace> timer) throws Exception {
+        collector.setAbsoluteTimestamp(timer.getTimestamp());
+        invokeUserFunction(TimeDomain.EVENT_TIME, timer);
+    }
+
+    @Override
+    public void onProcessingTime(InternalTimer<K, VoidNamespace> timer) throws Exception {
+        collector.eraseTimestamp();
+        invokeUserFunction(TimeDomain.PROCESSING_TIME, timer);
+    }
+
+    @Override
+    public Object getFunctionContext() {
+        return context;
+    }
+
+    @Override
+    public Object getTimerContext() {
+        return onTimerContext;
+    }
+
+    @Override
+    public <T> AbstractEmbeddedDataStreamPythonFunctionOperator<T> copy(
+            DataStreamPythonFunctionInfo pythonFunctionInfo, TypeInformation<T> outputTypeInfo) {
+        return new EmbeddedPythonKeyedProcessOperator<>(
+                config, pythonFunctionInfo, getInputTypeInfo(), outputTypeInfo);
+    }
+
+    private void invokeUserFunction(TimeDomain timeDomain, InternalTimer<K, VoidNamespace> timer)
+            throws Exception {
+        onTimerContext.timeDomain = timeDomain;
+        onTimerContext.timer = timer;
+        PyIterator results =
+                (PyIterator)
+                        interpreter.invokeMethod("operation", "on_timer", timer.getTimestamp());
+
+        while (results.hasNext()) {
+            OUT result = outputDataConverter.toInternal(results.next());
+            collector.collect(result);
+        }
+        results.close();
+
+        onTimerContext.timeDomain = null;
+        onTimerContext.timer = null;
+    }
+
+    private class ContextImpl {
+
+        private final TimerService timerService;
+
+        ContextImpl(TimerService timerService) {
+            this.timerService = timerService;
+        }
+
+        public long timestamp() {
+            return timestamp;
+        }
+
+        public TimerService timerService() {
+            return timerService;
+        }
+
+        @SuppressWarnings("unchecked")
+        public Object getCurrentKey() {
+            return keyConverter.toExternal(
+                    (K)
+                            ((Row) EmbeddedPythonKeyedProcessOperator.this.getCurrentKey())
+                                    .getField(0));
+        }
+    }
+
+    private class OnTimerContextImpl {
+
+        private final TimerService timerService;
+
+        private TimeDomain timeDomain;
+
+        private InternalTimer<K, VoidNamespace> timer;
+
+        OnTimerContextImpl(TimerService timerService) {
+            this.timerService = timerService;
+        }
+
+        public long timestamp() {
+            return timer.getTimestamp();
+        }
+
+        public TimerService timerService() {
+            return timerService;
+        }
+
+        public int timeDomain() {
+            return timeDomain.ordinal();
+        }
+
+        @SuppressWarnings("unchecked")
+        public Object getCurrentKey() {
+            return keyConverter.toExternal((K) ((Row) timer.getKey()).getField(0));
+        }
+    }
+}
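
Putting the pieces together, a minimal end-to-end sketch of what this operator enables: a KeyedProcessFunction with keyed state and a processing-time timer running in thread mode. This uses the standard PyFlink API; the function and job names are illustrative, and setting `python.execution-mode` to `thread` is what selects the embedded interpreter (Flink falls back to process mode where thread mode is unsupported):

    from pyflink.common import Configuration
    from pyflink.common.typeinfo import Types
    from pyflink.datastream import StreamExecutionEnvironment
    from pyflink.datastream.functions import KeyedProcessFunction
    from pyflink.datastream.state import ValueStateDescriptor

    class CountWithTimeout(KeyedProcessFunction):

        def open(self, runtime_context):
            self.count = runtime_context.get_state(
                ValueStateDescriptor("count", Types.LONG()))

        def process_element(self, value, ctx):
            current = self.count.value() or 0
            self.count.update(current + 1)
            # Fire one minute after the latest element for this key.
            ctx.timer_service().register_processing_time_timer(
                ctx.timer_service().current_processing_time() + 60 * 1000)

        def on_timer(self, timestamp, ctx):
            yield ctx.get_current_key(), self.count.value()

    config = Configuration()
    config.set_string("python.execution-mode", "thread")
    env = StreamExecutionEnvironment.get_execution_environment(config)

    env.from_collection([("a", 1), ("b", 2), ("a", 3)]) \
        .key_by(lambda e: e[0]) \
        .process(CountWithTimeout(),
                 output_type=Types.TUPLE([Types.STRING(), Types.LONG()])) \
        .print()

    env.execute("keyed-process-thread-mode")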
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonProcessOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonProcessOperator.java
index ad0ceebb98a..50e4feb921c 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonProcessOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonProcessOperator.java
@@ -31,7 +31,6 @@ import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
import java.util.HashMap;
import java.util.List;
-import static org.apache.flink.python.Constants.STATELESS_FUNCTION_URN;
import static org.apache.flink.python.PythonOptions.MAP_STATE_READ_CACHE_SIZE;
import static org.apache.flink.python.PythonOptions.MAP_STATE_WRITE_CACHE_SIZE;
import static org.apache.flink.python.PythonOptions.PYTHON_METRIC_ENABLED;
@@ -57,7 +56,7 @@ public class EmbeddedPythonProcessOperator<IN, OUT>
DataStreamPythonFunctionInfo pythonFunctionInfo,
TypeInformation<IN> inputTypeInfo,
TypeInformation<OUT> outputTypeInfo) {
- super(STATELESS_FUNCTION_URN, config, pythonFunctionInfo, inputTypeInfo, outputTypeInfo);
+ super(config, pythonFunctionInfo, inputTypeInfo, outputTypeInfo);
}
@Override
@@ -87,6 +86,11 @@ public class EmbeddedPythonProcessOperator<IN, OUT>
return context;
}
+ @Override
+ public Object getTimerContext() {
+ return null;
+ }
+
@Override
public <T> AbstractEmbeddedDataStreamPythonFunctionOperator<T> copy(
DataStreamPythonFunctionInfo pythonFunctionInfo, TypeInformation<T> outputTypeInfo) {
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/table/EmbeddedPythonTableFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/table/EmbeddedPythonTableFunctionOperator.java
index 77d5049cfc6..756b1cb4645 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/table/EmbeddedPythonTableFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/table/EmbeddedPythonTableFunctionOperator.java
@@ -138,7 +138,7 @@ public class EmbeddedPythonTableFunctionOperator extends AbstractEmbeddedStatele
@SuppressWarnings("unchecked")
@Override
- public void processElement(StreamRecord<RowData> element) {
+ public void processElement(StreamRecord<RowData> element) throws Exception {
RowData value = element.getValue();
for (int i = 0; i < userDefinedFunctionInputArgs.length; i++) {
@@ -165,5 +165,6 @@ public class EmbeddedPythonTableFunctionOperator extends AbstractEmbeddedStatele
} else if (joinType == FlinkJoinType.LEFT) {
rowDataWrapper.collect(reuseJoinedRow.replace(value, reuseNullResultRowData));
}
+ udtfResults.close();
}
}
diff --git a/flink-python/src/main/resources/META-INF/NOTICE b/flink-python/src/main/resources/META-INF/NOTICE
index 5a0397c922b..2911e2124b2 100644
--- a/flink-python/src/main/resources/META-INF/NOTICE
+++ b/flink-python/src/main/resources/META-INF/NOTICE
@@ -28,7 +28,7 @@ This project bundles the following dependencies under the Apache Software Licens
- org.apache.beam:beam-vendor-sdks-java-extensions-protobuf:2.38.0
- org.apache.beam:beam-vendor-guava-26_0-jre:0.1
- org.apache.beam:beam-vendor-grpc-1_43_2:0.1
-- com.alibaba:pemja:0.2.0
+- com.alibaba:pemja:0.2.2
This project bundles the following dependencies under the BSD license.
See bundled license files for details