Posted to commits@flink.apache.org by hx...@apache.org on 2022/07/29 11:27:25 UTC

[flink] branch master updated: [FLINK-28559][python] Support DataStream PythonKeyedProcessOperator in Thread Mode

This is an automated email from the ASF dual-hosted git repository.

hxb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 4cc33f97d13 [FLINK-28559][python] Support DataStream PythonKeyedProcessOperator in Thread Mode
4cc33f97d13 is described below

commit 4cc33f97d1351be4836e973a528042684475a069
Author: huangxingbo <hx...@apache.org>
AuthorDate: Wed Jul 27 16:21:46 2022 +0800

    [FLINK-28559][python] Support DataStream PythonKeyedProcessOperator in Thread Mode
    
    This closes #20375.
---
 flink-python/dev/dev-requirements.txt              |   2 +-
 flink-python/pom.xml                               |   2 +-
 flink-python/pyflink/datastream/data_stream.py     |  60 +--
 flink-python/pyflink/datastream/state.py           |   4 +-
 .../pyflink/datastream/tests/test_data_stream.py   | 420 +++++++++++----------
 .../fn_execution/datastream/embedded/operations.py |  82 ++--
 .../datastream/embedded/process_function.py        |  42 ++-
 .../datastream/embedded/runtime_context.py         |  34 +-
 .../fn_execution/datastream/embedded/state_impl.py | 158 ++++++++
 .../{process_function.py => timerservice_impl.py}  |  30 +-
 .../pyflink/fn_execution/embedded/converters.py    |  56 ++-
 .../pyflink/fn_execution/embedded/java_utils.py    | 204 ++++++++++
 .../fn_execution/embedded/operation_utils.py       |   6 +-
 .../pyflink/fn_execution/embedded/operations.py    |  13 +-
 flink-python/setup.py                              |   2 +-
 .../chain/PythonOperatorChainingOptimizer.java     |   4 +-
 ...ctEmbeddedDataStreamPythonFunctionOperator.java |  20 +-
 ...ractOneInputEmbeddedPythonFunctionOperator.java |  21 +-
 .../EmbeddedPythonKeyedProcessOperator.java        | 214 +++++++++++
 .../embedded/EmbeddedPythonProcessOperator.java    |   8 +-
 .../table/EmbeddedPythonTableFunctionOperator.java |   3 +-
 flink-python/src/main/resources/META-INF/NOTICE    |   2 +-
 22 files changed, 1057 insertions(+), 330 deletions(-)
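
For illustration only (not part of this commit): with this change, a DataStream keyed process function can use keyed state and timers in the embedded "thread" execution mode. A minimal sketch, condensed from the tests added below; it assumes get_j_env_configuration lives in pyflink.util.java_utils, and all other names are illustrative:

    from pyflink.common.typeinfo import Types
    from pyflink.datastream import StreamExecutionEnvironment, KeyedProcessFunction, RuntimeContext
    from pyflink.datastream.state import ReducingStateDescriptor
    from pyflink.util.java_utils import get_j_env_configuration

    env = StreamExecutionEnvironment.get_execution_environment()
    # Switch Python function execution to the embedded (thread) mode, as the new tests do.
    j_conf = get_j_env_configuration(env._j_stream_execution_environment)
    j_conf.setString("python.execution-mode", "thread")

    class RunningSum(KeyedProcessFunction):

        def open(self, runtime_context: RuntimeContext):
            # Keyed state is now also available when running in thread mode.
            self.sum_state = runtime_context.get_reducing_state(
                ReducingStateDescriptor('sum', lambda a, b: a + b, Types.INT()))

        def process_element(self, value, ctx):
            self.sum_state.add(value[0])
            yield self.sum_state.get(), value[1]

    ds = env.from_collection(
        [(1, 'hi'), (2, 'hello'), (3, 'hi')],
        type_info=Types.TUPLE([Types.INT(), Types.STRING()]))
    ds.key_by(lambda v: v[1], key_type=Types.STRING()) \
        .process(RunningSum(), output_type=Types.TUPLE([Types.INT(), Types.STRING()])) \
        .print()
    env.execute('thread_mode_keyed_state_demo')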

diff --git a/flink-python/dev/dev-requirements.txt b/flink-python/dev/dev-requirements.txt
index da91f005bb3..63a5852c946 100755
--- a/flink-python/dev/dev-requirements.txt
+++ b/flink-python/dev/dev-requirements.txt
@@ -31,6 +31,6 @@ numpy>=1.14.3,<1.20; python_version < '3.7'
 fastavro>=1.1.0,<1.4.8
 grpcio>=1.29.0,<1.47
 grpcio-tools>=1.3.5,<=1.14.2
-pemja==0.2.0; python_version >= '3.7' and platform_system != 'Windows'
+pemja==0.2.2; python_version >= '3.7' and platform_system != 'Windows'
 httplib2>=0.19.0,<=0.20.4
 protobuf<3.18
\ No newline at end of file
diff --git a/flink-python/pom.xml b/flink-python/pom.xml
index fa5fe677285..78a08113aa7 100644
--- a/flink-python/pom.xml
+++ b/flink-python/pom.xml
@@ -110,7 +110,7 @@ under the License.
 		<dependency>
 			<groupId>com.alibaba</groupId>
 			<artifactId>pemja</artifactId>
-			<version>0.2.0</version>
+			<version>0.2.2</version>
 		</dependency>
 
 		<!-- Protobuf dependencies -->
diff --git a/flink-python/pyflink/datastream/data_stream.py b/flink-python/pyflink/datastream/data_stream.py
index 15418960590..b893308abc3 100644
--- a/flink-python/pyflink/datastream/data_stream.py
+++ b/flink-python/pyflink/datastream/data_stream.py
@@ -49,8 +49,8 @@ from pyflink.datastream.functions import (_get_python_env, FlatMapFunction, MapF
                                           InternalSingleValueProcessAllWindowFunction)
 from pyflink.datastream.output_tag import OutputTag
 from pyflink.datastream.slot_sharing_group import SlotSharingGroup
-from pyflink.datastream.state import ValueStateDescriptor, ValueState, ListStateDescriptor, \
-    StateDescriptor, ReducingStateDescriptor, AggregatingStateDescriptor, MapStateDescriptor
+from pyflink.datastream.state import (ListStateDescriptor, StateDescriptor, ReducingStateDescriptor,
+                                      AggregatingStateDescriptor, MapStateDescriptor, ReducingState)
 from pyflink.datastream.utils import convert_to_python_obj
 from pyflink.datastream.window import (CountTumblingWindowAssigner, CountSlidingWindowAssigner,
                                        CountWindowSerializer, TimeWindowSerializer, Trigger,
@@ -1202,6 +1202,12 @@ class KeyedStream(DataStream):
 
         output_type = _from_java_type(self._original_data_type_info.get_java_type_info())
 
+        gateway = get_gateway()
+        j_conf = get_j_env_configuration(self._j_data_stream.getExecutionEnvironment())
+        python_execution_mode = (
+            j_conf.getString(
+                gateway.jvm.org.apache.flink.python.PythonOptions.PYTHON_EXECUTION_MODE))
+
         class ReduceProcessKeyedProcessFunctionAdapter(KeyedProcessFunction):
 
             def __init__(self, reduce_function):
@@ -1213,40 +1219,47 @@ class KeyedStream(DataStream):
                     self._open_func = None
                     self._close_func = None
                     self._reduce_function = reduce_function
-                self._reduce_value_state = None  # type: ValueState
+                self._reduce_state = None  # type: ReducingState
+                self._in_batch_execution_mode = True
 
             def open(self, runtime_context: RuntimeContext):
                 if self._open_func:
                     self._open_func(runtime_context)
 
-                self._reduce_value_state = runtime_context.get_state(
-                    ValueStateDescriptor("_reduce_state" + str(uuid.uuid4()), output_type))
-                from pyflink.fn_execution.datastream.process.runtime_context import (
-                    StreamingRuntimeContext)
-                self._in_batch_execution_mode = \
-                    cast(StreamingRuntimeContext, runtime_context)._in_batch_execution_mode
+                self._reduce_state = runtime_context.get_reducing_state(
+                    ReducingStateDescriptor(
+                        "_reduce_state" + str(uuid.uuid4()),
+                        self._reduce_function,
+                        output_type))
+
+                if python_execution_mode == "process":
+                    from pyflink.fn_execution.datastream.process.runtime_context import (
+                        StreamingRuntimeContext)
+                    self._in_batch_execution_mode = (
+                        cast(StreamingRuntimeContext, runtime_context)._in_batch_execution_mode)
+                else:
+                    self._in_batch_execution_mode = runtime_context.get_job_parameter(
+                        "inBatchExecutionMode", "false") == "true"
 
             def close(self):
                 if self._close_func:
                     self._close_func()
 
             def process_element(self, value, ctx: 'KeyedProcessFunction.Context'):
-                reduce_value = self._reduce_value_state.value()
-                if reduce_value is not None:
-                    reduce_value = self._reduce_function(reduce_value, value)
-                else:
-                    # register a timer for emitting the result at the end when this is the
-                    # first input for this key
-                    if self._in_batch_execution_mode:
+                if self._in_batch_execution_mode:
+                    reduce_value = self._reduce_state.get()
+                    if reduce_value is None:
+                        # register a timer for emitting the result at the end when this is the
+                        # first input for this key
                         ctx.timer_service().register_event_time_timer(0x7fffffffffffffff)
-                    reduce_value = value
-                self._reduce_value_state.update(reduce_value)
-                if not self._in_batch_execution_mode:
+                    self._reduce_state.add(value)
+                else:
+                    self._reduce_state.add(value)
                     # only emitting the result when all the data for a key is received
-                    yield reduce_value
+                    yield self._reduce_state.get()
 
             def on_timer(self, timestamp: int, ctx: 'KeyedProcessFunction.OnTimerContext'):
-                current_value = self._reduce_value_state.value()
+                current_value = self._reduce_state.get()
                 if current_value is not None:
                     yield current_value
 
@@ -2708,7 +2721,10 @@ def _get_one_input_stream_operator(data_stream: DataStream,
         else:
             JDataStreamPythonFunctionOperator = gateway.jvm.ExternalPythonProcessOperator
     elif func_type == UserDefinedDataStreamFunction.KEYED_PROCESS:  # type: ignore
-        JDataStreamPythonFunctionOperator = gateway.jvm.ExternalPythonKeyedProcessOperator
+        if python_execution_mode == 'thread':
+            JDataStreamPythonFunctionOperator = gateway.jvm.EmbeddedPythonKeyedProcessOperator
+        else:
+            JDataStreamPythonFunctionOperator = gateway.jvm.ExternalPythonKeyedProcessOperator
     elif func_type == UserDefinedDataStreamFunction.WINDOW:  # type: ignore
         window_serializer = typing.cast(WindowOperationDescriptor, func).window_serializer
         if isinstance(window_serializer, TimeWindowSerializer):
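
For context (not part of the commit): the reduce() adapter above now keeps its running value in a ReducingState instead of a ValueState, so the same code path works in both the process and thread execution modes, and in batch execution the result is only emitted by the end-of-input timer. A minimal sketch of the user-facing API this adapter backs, with illustrative names:

    from pyflink.common.typeinfo import Types
    from pyflink.datastream import StreamExecutionEnvironment, ReduceFunction

    class ConcatReduce(ReduceFunction):

        def reduce(self, value1, value2):
            # Called with the previously reduced value and the next element of the same key.
            return value1[0] + value2[0], value1[1]

    env = StreamExecutionEnvironment.get_execution_environment()
    ds = env.from_collection(
        [('a', 0), ('b', 0), ('c', 1), ('d', 1)],
        type_info=Types.TUPLE([Types.STRING(), Types.INT()]))
    ds.key_by(lambda v: v[1], key_type=Types.INT()) \
        .reduce(ConcatReduce()) \
        .print()
    env.execute('keyed_reduce_demo')
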
diff --git a/flink-python/pyflink/datastream/state.py b/flink-python/pyflink/datastream/state.py
index b60bcdeb6ee..f38b9963bd6 100644
--- a/flink-python/pyflink/datastream/state.py
+++ b/flink-python/pyflink/datastream/state.py
@@ -899,14 +899,14 @@ class StateTtlConfig(object):
             Configuration of cleanup strategy while taking the full snapshot.
             """
 
-            def __init__(self, cleanup_size: int, run_cleanup_for_every_record: int):
+            def __init__(self, cleanup_size: int, run_cleanup_for_every_record: bool):
                 self._cleanup_size = cleanup_size
                 self._run_cleanup_for_every_record = run_cleanup_for_every_record
 
             def get_cleanup_size(self) -> int:
                 return self._cleanup_size
 
-            def run_cleanup_for_every_record(self) -> int:
+            def run_cleanup_for_every_record(self) -> bool:
                 return self._run_cleanup_for_every_record
 
         class RocksdbCompactFilterCleanupStrategy(CleanupStrategy):
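
Aside (not part of the commit): the corrected hint reflects that run_cleanup_for_every_record is a boolean flag of the incremental cleanup strategy. A minimal sketch of where that flag enters from user code, assuming the builder exposes cleanup_incrementally mirroring the Java API:

    from pyflink.common.time import Time
    from pyflink.common.typeinfo import Types
    from pyflink.datastream.state import MapStateDescriptor, StateTtlConfig

    ttl_config = StateTtlConfig \
        .new_builder(Time.seconds(1)) \
        .set_update_type(StateTtlConfig.UpdateType.OnReadAndWrite) \
        .cleanup_incrementally(10, True) \
        .build()  # cleanup_size=10, run_cleanup_for_every_record=True (the corrected bool)

    descriptor = MapStateDescriptor('map_state', Types.INT(), Types.STRING())
    descriptor.enable_time_to_live(ttl_config)
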
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py
index d302e2bf18d..beba02e1720 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -164,6 +164,210 @@ class DataStreamTests(object):
 
         self.env.execute('test_partition_custom')
 
+    def test_keyed_process_function_with_state(self):
+        self.env.get_config().set_auto_watermark_interval(2000)
+        self.env.set_stream_time_characteristic(TimeCharacteristic.EventTime)
+        data_stream = self.env.from_collection([(1, 'hi', '1603708211000'),
+                                                (2, 'hello', '1603708224000'),
+                                                (3, 'hi', '1603708226000'),
+                                                (4, 'hello', '1603708289000'),
+                                                (5, 'hi', '1603708291000'),
+                                                (6, 'hello', '1603708293000')],
+                                               type_info=Types.ROW([Types.INT(), Types.STRING(),
+                                                                    Types.STRING()]))
+
+        class MyTimestampAssigner(TimestampAssigner):
+
+            def extract_timestamp(self, value, record_timestamp) -> int:
+                return int(value[2])
+
+        class MyProcessFunction(KeyedProcessFunction):
+
+            def __init__(self):
+                self.value_state = None
+                self.list_state = None
+                self.map_state = None
+
+            def open(self, runtime_context: RuntimeContext):
+                value_state_descriptor = ValueStateDescriptor('value_state', Types.INT())
+                self.value_state = runtime_context.get_state(value_state_descriptor)
+                list_state_descriptor = ListStateDescriptor('list_state', Types.INT())
+                self.list_state = runtime_context.get_list_state(list_state_descriptor)
+                map_state_descriptor = MapStateDescriptor('map_state', Types.INT(), Types.STRING())
+                state_ttl_config = StateTtlConfig \
+                    .new_builder(Time.seconds(1)) \
+                    .set_update_type(StateTtlConfig.UpdateType.OnReadAndWrite) \
+                    .set_state_visibility(
+                        StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp) \
+                    .disable_cleanup_in_background() \
+                    .build()
+                map_state_descriptor.enable_time_to_live(state_ttl_config)
+                self.map_state = runtime_context.get_map_state(map_state_descriptor)
+
+            def process_element(self, value, ctx):
+                current_value = self.value_state.value()
+                self.value_state.update(value[0])
+                current_list = [_ for _ in self.list_state.get()]
+                self.list_state.add(value[0])
+                map_entries = {k: v for k, v in self.map_state.items()}
+                keys = sorted(map_entries.keys())
+                map_entries_string = [str(k) + ': ' + str(map_entries[k]) for k in keys]
+                map_entries_string = '{' + ', '.join(map_entries_string) + '}'
+                self.map_state.put(value[0], value[1])
+                current_key = ctx.get_current_key()
+                yield "current key: {}, current value state: {}, current list state: {}, " \
+                      "current map state: {}, current value: {}".format(str(current_key),
+                                                                        str(current_value),
+                                                                        str(current_list),
+                                                                        map_entries_string,
+                                                                        str(value))
+
+            def on_timer(self, timestamp, ctx):
+                pass
+
+        watermark_strategy = WatermarkStrategy.for_monotonous_timestamps() \
+            .with_timestamp_assigner(MyTimestampAssigner())
+        data_stream.assign_timestamps_and_watermarks(watermark_strategy) \
+            .key_by(lambda x: x[1], key_type=Types.STRING()) \
+            .process(MyProcessFunction(), output_type=Types.STRING()) \
+            .add_sink(self.test_sink)
+        self.env.execute('test time stamp assigner with keyed process function')
+        results = self.test_sink.get_results()
+        expected = ["current key: hi, current value state: None, current list state: [], "
+                    "current map state: {}, current value: Row(f0=1, f1='hi', "
+                    "f2='1603708211000')",
+                    "current key: hello, current value state: None, "
+                    "current list state: [], current map state: {}, current value: Row(f0=2,"
+                    " f1='hello', f2='1603708224000')",
+                    "current key: hi, current value state: 1, current list state: [1], "
+                    "current map state: {1: hi}, current value: Row(f0=3, f1='hi', "
+                    "f2='1603708226000')",
+                    "current key: hello, current value state: 2, current list state: [2], "
+                    "current map state: {2: hello}, current value: Row(f0=4, f1='hello', "
+                    "f2='1603708289000')",
+                    "current key: hi, current value state: 3, current list state: [1, 3], "
+                    "current map state: {1: hi, 3: hi}, current value: Row(f0=5, f1='hi', "
+                    "f2='1603708291000')",
+                    "current key: hello, current value state: 4, current list state: [2, 4],"
+                    " current map state: {2: hello, 4: hello}, current value: Row(f0=6, "
+                    "f1='hello', f2='1603708293000')"]
+        self.assert_equals_sorted(expected, results)
+
+    def test_reducing_state(self):
+        self.env.set_parallelism(2)
+        data_stream = self.env.from_collection([
+            (1, 'hi'), (2, 'hello'), (3, 'hi'), (4, 'hello'), (5, 'hi'), (6, 'hello')],
+            type_info=Types.TUPLE([Types.INT(), Types.STRING()]))
+
+        class MyProcessFunction(KeyedProcessFunction):
+
+            def __init__(self):
+                self.reducing_state = None  # type: ReducingState
+
+            def open(self, runtime_context: RuntimeContext):
+                self.reducing_state = runtime_context.get_reducing_state(
+                    ReducingStateDescriptor(
+                        'reducing_state', lambda i, i2: i + i2, Types.INT()))
+
+            def process_element(self, value, ctx):
+                self.reducing_state.add(value[0])
+                yield self.reducing_state.get(), value[1]
+
+        data_stream.key_by(lambda x: x[1], key_type=Types.STRING()) \
+            .process(MyProcessFunction(), output_type=Types.TUPLE([Types.INT(), Types.STRING()])) \
+            .add_sink(self.test_sink)
+        self.env.execute('test_reducing_state')
+        result = self.test_sink.get_results()
+        expected_result = ['(1,hi)', '(2,hello)', '(4,hi)', '(6,hello)', '(9,hi)', '(12,hello)']
+        result.sort()
+        expected_result.sort()
+        self.assertEqual(expected_result, result)
+
+    def test_aggregating_state(self):
+        self.env.set_parallelism(2)
+        data_stream = self.env.from_collection([
+            (1, 'hi'), (2, 'hello'), (3, 'hi'), (4, 'hello'), (5, 'hi'), (6, 'hello')],
+            type_info=Types.TUPLE([Types.INT(), Types.STRING()]))
+
+        class MyAggregateFunction(AggregateFunction):
+
+            def create_accumulator(self):
+                return 0
+
+            def add(self, value, accumulator):
+                return value + accumulator
+
+            def get_result(self, accumulator):
+                return accumulator
+
+            def merge(self, acc_a, acc_b):
+                return acc_a + acc_b
+
+        class MyProcessFunction(KeyedProcessFunction):
+
+            def __init__(self):
+                self.aggregating_state = None  # type: AggregatingState
+
+            def open(self, runtime_context: RuntimeContext):
+                descriptor = AggregatingStateDescriptor(
+                    'aggregating_state', MyAggregateFunction(), Types.INT())
+                state_ttl_config = StateTtlConfig \
+                    .new_builder(Time.seconds(1)) \
+                    .set_update_type(StateTtlConfig.UpdateType.OnReadAndWrite) \
+                    .disable_cleanup_in_background() \
+                    .build()
+                descriptor.enable_time_to_live(state_ttl_config)
+                self.aggregating_state = runtime_context.get_aggregating_state(descriptor)
+
+            def process_element(self, value, ctx):
+                self.aggregating_state.add(value[0])
+                yield self.aggregating_state.get(), value[1]
+
+        config = Configuration(
+            j_configuration=get_j_env_configuration(self.env._j_stream_execution_environment))
+        config.set_integer("python.fn-execution.bundle.size", 1)
+        data_stream.key_by(lambda x: x[1], key_type=Types.STRING()) \
+            .process(MyProcessFunction(), output_type=Types.TUPLE([Types.INT(), Types.STRING()])) \
+            .add_sink(self.test_sink)
+        self.env.execute('test_aggregating_state')
+        results = self.test_sink.get_results()
+        expected = ['(1,hi)', '(2,hello)', '(4,hi)', '(6,hello)', '(9,hi)', '(12,hello)']
+        self.assert_equals_sorted(expected, results)
+
+
+class DataStreamStreamingTests(DataStreamTests):
+
+    def test_reduce_with_state(self):
+        ds = self.env.from_collection([('a', 0), ('c', 1), ('d', 1), ('b', 0), ('e', 1)],
+                                      type_info=Types.ROW([Types.STRING(), Types.INT()]))
+        keyed_stream = ds.key_by(MyKeySelector(), key_type=Types.INT())
+
+        with self.assertRaises(Exception):
+            keyed_stream.name("keyed stream")
+
+        keyed_stream.reduce(MyReduceFunction()).add_sink(self.test_sink)
+        self.env.execute('key_by_test')
+        results = self.test_sink.get_results(False)
+        expected = ['+I[a, 0]', '+I[ab, 0]', '+I[c, 1]', '+I[cd, 1]', '+I[cde, 1]']
+        self.assert_equals_sorted(expected, results)
+
+
+class DataStreamBatchTests(DataStreamTests):
+
+    def test_reduce_with_state(self):
+        ds = self.env.from_collection([('a', 0), ('c', 1), ('d', 1), ('b', 0), ('e', 1)],
+                                      type_info=Types.ROW([Types.STRING(), Types.INT()]))
+        keyed_stream = ds.key_by(MyKeySelector(), key_type=Types.INT())
+
+        with self.assertRaises(Exception):
+            keyed_stream.name("keyed stream")
+
+        keyed_stream.reduce(MyReduceFunction()).add_sink(self.test_sink)
+        self.env.execute('key_by_test')
+        results = self.test_sink.get_results(False)
+        expected = ['+I[ab, 0]', '+I[cde, 1]']
+        self.assert_equals_sorted(expected, results)
+
 
 class ProcessDataStreamTests(DataStreamTests):
     """
@@ -640,176 +844,6 @@ class ProcessDataStreamTests(DataStreamTests):
                     "-9223372036854775808, current_value: Row(f0=4, f1='1603708289000')"]
         self.assert_equals_sorted(expected, results)
 
-    def test_keyed_process_function_with_state(self):
-        self.env.get_config().set_auto_watermark_interval(2000)
-        self.env.set_stream_time_characteristic(TimeCharacteristic.EventTime)
-        data_stream = self.env.from_collection([(1, 'hi', '1603708211000'),
-                                                (2, 'hello', '1603708224000'),
-                                                (3, 'hi', '1603708226000'),
-                                                (4, 'hello', '1603708289000'),
-                                                (5, 'hi', '1603708291000'),
-                                                (6, 'hello', '1603708293000')],
-                                               type_info=Types.ROW([Types.INT(), Types.STRING(),
-                                                                    Types.STRING()]))
-
-        class MyTimestampAssigner(TimestampAssigner):
-
-            def extract_timestamp(self, value, record_timestamp) -> int:
-                return int(value[2])
-
-        class MyProcessFunction(KeyedProcessFunction):
-
-            def __init__(self):
-                self.value_state = None
-                self.list_state = None
-                self.map_state = None
-
-            def open(self, runtime_context: RuntimeContext):
-                value_state_descriptor = ValueStateDescriptor('value_state', Types.INT())
-                self.value_state = runtime_context.get_state(value_state_descriptor)
-                list_state_descriptor = ListStateDescriptor('list_state', Types.INT())
-                self.list_state = runtime_context.get_list_state(list_state_descriptor)
-                map_state_descriptor = MapStateDescriptor('map_state', Types.INT(), Types.STRING())
-                state_ttl_config = StateTtlConfig \
-                    .new_builder(Time.seconds(1)) \
-                    .set_update_type(StateTtlConfig.UpdateType.OnReadAndWrite) \
-                    .set_state_visibility(
-                        StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp) \
-                    .disable_cleanup_in_background() \
-                    .build()
-                map_state_descriptor.enable_time_to_live(state_ttl_config)
-                self.map_state = runtime_context.get_map_state(map_state_descriptor)
-
-            def process_element(self, value, ctx):
-                current_value = self.value_state.value()
-                self.value_state.update(value[0])
-                current_list = [_ for _ in self.list_state.get()]
-                self.list_state.add(value[0])
-                map_entries = {k: v for k, v in self.map_state.items()}
-                keys = sorted(map_entries.keys())
-                map_entries_string = [str(k) + ': ' + str(map_entries[k]) for k in keys]
-                map_entries_string = '{' + ', '.join(map_entries_string) + '}'
-                self.map_state.put(value[0], value[1])
-                current_key = ctx.get_current_key()
-                yield "current key: {}, current value state: {}, current list state: {}, " \
-                      "current map state: {}, current value: {}".format(str(current_key),
-                                                                        str(current_value),
-                                                                        str(current_list),
-                                                                        map_entries_string,
-                                                                        str(value))
-
-            def on_timer(self, timestamp, ctx):
-                pass
-
-        watermark_strategy = WatermarkStrategy.for_monotonous_timestamps() \
-            .with_timestamp_assigner(MyTimestampAssigner())
-        data_stream.assign_timestamps_and_watermarks(watermark_strategy) \
-            .key_by(lambda x: x[1], key_type=Types.STRING()) \
-            .process(MyProcessFunction(), output_type=Types.STRING()) \
-            .add_sink(self.test_sink)
-        self.env.execute('test time stamp assigner with keyed process function')
-        results = self.test_sink.get_results()
-        expected = ["current key: hi, current value state: None, current list state: [], "
-                    "current map state: {}, current value: Row(f0=1, f1='hi', "
-                    "f2='1603708211000')",
-                    "current key: hello, current value state: None, "
-                    "current list state: [], current map state: {}, current value: Row(f0=2,"
-                    " f1='hello', f2='1603708224000')",
-                    "current key: hi, current value state: 1, current list state: [1], "
-                    "current map state: {1: hi}, current value: Row(f0=3, f1='hi', "
-                    "f2='1603708226000')",
-                    "current key: hello, current value state: 2, current list state: [2], "
-                    "current map state: {2: hello}, current value: Row(f0=4, f1='hello', "
-                    "f2='1603708289000')",
-                    "current key: hi, current value state: 3, current list state: [1, 3], "
-                    "current map state: {1: hi, 3: hi}, current value: Row(f0=5, f1='hi', "
-                    "f2='1603708291000')",
-                    "current key: hello, current value state: 4, current list state: [2, 4],"
-                    " current map state: {2: hello, 4: hello}, current value: Row(f0=6, "
-                    "f1='hello', f2='1603708293000')"]
-        self.assert_equals_sorted(expected, results)
-
-    def test_reducing_state(self):
-        self.env.set_parallelism(2)
-        data_stream = self.env.from_collection([
-            (1, 'hi'), (2, 'hello'), (3, 'hi'), (4, 'hello'), (5, 'hi'), (6, 'hello')],
-            type_info=Types.TUPLE([Types.INT(), Types.STRING()]))
-
-        class MyProcessFunction(KeyedProcessFunction):
-
-            def __init__(self):
-                self.reducing_state = None  # type: ReducingState
-
-            def open(self, runtime_context: RuntimeContext):
-                self.reducing_state = runtime_context.get_reducing_state(
-                    ReducingStateDescriptor(
-                        'reducing_state', lambda i, i2: i + i2, Types.INT()))
-
-            def process_element(self, value, ctx):
-                self.reducing_state.add(value[0])
-                yield self.reducing_state.get(), value[1]
-
-        data_stream.key_by(lambda x: x[1], key_type=Types.STRING()) \
-            .process(MyProcessFunction(), output_type=Types.TUPLE([Types.INT(), Types.STRING()])) \
-            .add_sink(self.test_sink)
-        self.env.execute('test_reducing_state')
-        result = self.test_sink.get_results()
-        expected_result = ['(1,hi)', '(2,hello)', '(4,hi)', '(6,hello)', '(9,hi)', '(12,hello)']
-        result.sort()
-        expected_result.sort()
-        self.assertEqual(expected_result, result)
-
-    def test_aggregating_state(self):
-        self.env.set_parallelism(2)
-        data_stream = self.env.from_collection([
-            (1, 'hi'), (2, 'hello'), (3, 'hi'), (4, 'hello'), (5, 'hi'), (6, 'hello')],
-            type_info=Types.TUPLE([Types.INT(), Types.STRING()]))
-
-        class MyAggregateFunction(AggregateFunction):
-
-            def create_accumulator(self):
-                return 0
-
-            def add(self, value, accumulator):
-                return value + accumulator
-
-            def get_result(self, accumulator):
-                return accumulator
-
-            def merge(self, acc_a, acc_b):
-                return acc_a + acc_b
-
-        class MyProcessFunction(KeyedProcessFunction):
-
-            def __init__(self):
-                self.aggregating_state = None  # type: AggregatingState
-
-            def open(self, runtime_context: RuntimeContext):
-                descriptor = AggregatingStateDescriptor(
-                    'aggregating_state', MyAggregateFunction(), Types.INT())
-                state_ttl_config = StateTtlConfig \
-                    .new_builder(Time.seconds(1)) \
-                    .set_update_type(StateTtlConfig.UpdateType.OnReadAndWrite) \
-                    .disable_cleanup_in_background() \
-                    .build()
-                descriptor.enable_time_to_live(state_ttl_config)
-                self.aggregating_state = runtime_context.get_aggregating_state(descriptor)
-
-            def process_element(self, value, ctx):
-                self.aggregating_state.add(value[0])
-                yield self.aggregating_state.get(), value[1]
-
-        config = Configuration(
-            j_configuration=get_j_env_configuration(self.env._j_stream_execution_environment))
-        config.set_integer("python.fn-execution.bundle.size", 1)
-        data_stream.key_by(lambda x: x[1], key_type=Types.STRING()) \
-            .process(MyProcessFunction(), output_type=Types.TUPLE([Types.INT(), Types.STRING()])) \
-            .add_sink(self.test_sink)
-        self.env.execute('test_aggregating_state')
-        results = self.test_sink.get_results()
-        expected = ['(1,hi)', '(2,hello)', '(4,hi)', '(6,hello)', '(9,hi)', '(12,hello)']
-        self.assert_equals_sorted(expected, results)
-
     def test_process_side_output(self):
         tag = OutputTag("side", Types.INT())
 
@@ -1122,7 +1156,8 @@ class ProcessDataStreamTests(DataStreamTests):
         self.assert_equals_sorted(expected, self.test_sink.get_results())
 
 
-class StreamingModeDataStreamTests(ProcessDataStreamTests, PyFlinkStreamingTestCase):
+class ProcessDataStreamStreamingTests(DataStreamStreamingTests, ProcessDataStreamTests,
+                                      PyFlinkStreamingTestCase):
     def test_data_stream_name(self):
         ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')])
         test_name = 'test_name'
@@ -1401,20 +1436,6 @@ class StreamingModeDataStreamTests(ProcessDataStreamTests, PyFlinkStreamingTestC
         expected = ["+I[1, a]", "+I[3, a]", "+I[6, a]", "+I[4, b]"]
         self.assert_equals_sorted(expected, results)
 
-    def test_reduce_with_state(self):
-        ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 1)],
-                                      type_info=Types.ROW([Types.STRING(), Types.INT()]))
-        keyed_stream = ds.key_by(MyKeySelector(), key_type=Types.INT())
-
-        with self.assertRaises(Exception):
-            keyed_stream.name("keyed stream")
-
-        keyed_stream.reduce(MyReduceFunction()).add_sink(self.test_sink)
-        self.env.execute('key_by_test')
-        results = self.test_sink.get_results(False)
-        expected = ['+I[a, 0]', '+I[ab, 0]', '+I[c, 1]', '+I[cd, 1]', '+I[cde, 1]']
-        self.assert_equals_sorted(expected, results)
-
     def test_keyed_sum(self):
         self.env.set_parallelism(1)
         ds = self.env.from_collection(
@@ -1529,7 +1550,8 @@ class StreamingModeDataStreamTests(ProcessDataStreamTests, PyFlinkStreamingTestC
         self.assert_equals_sorted(expected, results)
 
 
-class BatchModeDataStreamTests(ProcessDataStreamTests, PyFlinkBatchTestCase):
+class ProcessDataStreamBatchTests(DataStreamBatchTests, ProcessDataStreamTests,
+                                  PyFlinkBatchTestCase):
 
     def test_timestamp_assigner_and_watermark_strategy(self):
         self.env.set_parallelism(1)
@@ -1587,20 +1609,6 @@ class BatchModeDataStreamTests(ProcessDataStreamTests, PyFlinkBatchTestCase):
         expected = ["+I[6, a]", "+I[7, b]"]
         self.assert_equals_sorted(expected, results)
 
-    def test_reduce_with_state(self):
-        ds = self.env.from_collection([('a', 0), ('c', 1), ('d', 1), ('b', 0), ('e', 1)],
-                                      type_info=Types.ROW([Types.STRING(), Types.INT()]))
-        keyed_stream = ds.key_by(MyKeySelector(), key_type=Types.INT())
-
-        with self.assertRaises(Exception):
-            keyed_stream.name("keyed stream")
-
-        keyed_stream.reduce(MyReduceFunction()).add_sink(self.test_sink)
-        self.env.execute('key_by_test')
-        results = self.test_sink.get_results(False)
-        expected = ['+I[ab, 0]', '+I[cde, 1]']
-        self.assert_equals_sorted(expected, results)
-
     def test_keyed_sum(self):
         self.env.set_parallelism(1)
         ds = self.env.from_collection(
@@ -1704,9 +1712,17 @@ class BatchModeDataStreamTests(ProcessDataStreamTests, PyFlinkBatchTestCase):
 
 
 @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7")
-class EmbeddedDataStreamTests(DataStreamTests, PyFlinkStreamingTestCase):
+class EmbeddedDataStreamStreamTests(DataStreamStreamingTests, PyFlinkStreamingTestCase):
+    def setUp(self):
+        super(EmbeddedDataStreamStreamTests, self).setUp()
+        config = get_j_env_configuration(self.env._j_stream_execution_environment)
+        config.setString("python.execution-mode", "thread")
+
+
+@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires python3.7")
+class EmbeddedDataStreamBatchTests(DataStreamBatchTests, PyFlinkBatchTestCase):
     def setUp(self):
-        super(EmbeddedDataStreamTests, self).setUp()
+        super(EmbeddedDataStreamBatchTests, self).setUp()
         config = get_j_env_configuration(self.env._j_stream_execution_environment)
         config.setString("python.execution-mode", "thread")
 
diff --git a/flink-python/pyflink/fn_execution/datastream/embedded/operations.py b/flink-python/pyflink/fn_execution/datastream/embedded/operations.py
index 50be1362075..a109d2fba94 100644
--- a/flink-python/pyflink/fn_execution/datastream/embedded/operations.py
+++ b/flink-python/pyflink/fn_execution/datastream/embedded/operations.py
@@ -18,41 +18,46 @@
 
 from pyflink.fn_execution import pickle
 from pyflink.fn_execution.datastream import operations
-from pyflink.fn_execution.datastream.embedded.process_function import InternalProcessFunctionContext
+from pyflink.fn_execution.datastream.embedded.process_function import (
+    InternalProcessFunctionContext, InternalKeyedProcessFunctionContext,
+    InternalKeyedProcessFunctionOnTimerContext)
 from pyflink.fn_execution.datastream.embedded.runtime_context import StreamingRuntimeContext
-from pyflink.fn_execution.datastream.operations import DATA_STREAM_STATELESS_FUNCTION_URN
 
 
 class OneInputOperation(operations.OneInputOperation):
-    def __init__(self,
-                 function_urn,
-                 serialized_fn,
-                 runtime_context,
-                 function_context,
-                 job_parameters):
-        (self.open_func,
-         self.close_func,
-         self.process_element_func
-         ) = extract_one_input_process_function(
-            function_urn=function_urn,
-            user_defined_function_proto=serialized_fn,
-            runtime_context=StreamingRuntimeContext.of(runtime_context, job_parameters),
-            function_context=function_context)
+    def __init__(self, open_func, close_func, process_element_func, on_timer_func=None):
+        self._open_func = open_func
+        self._close_func = close_func
+        self._process_element_func = process_element_func
+        self._on_timer_func = on_timer_func
 
     def open(self) -> None:
-        self.open_func()
+        self._open_func()
 
     def close(self) -> None:
-        self.close_func()
+        self._close_func()
 
     def process_element(self, value):
-        return self.process_element_func(value)
+        return self._process_element_func(value)
 
+    def on_timer(self, timestamp):
+        if self._on_timer_func:
+            return self._on_timer_func(timestamp)
+
+
+def extract_process_function(
+        user_defined_function_proto, runtime_context, function_context, timer_context,
+        job_parameters):
+    from pyflink.fn_execution import flink_fn_execution_pb2
 
-def extract_one_input_process_function(
-        function_urn, user_defined_function_proto, runtime_context, function_context):
     user_defined_func = pickle.loads(user_defined_function_proto.payload)
 
+    func_type = user_defined_function_proto.function_type
+
+    UserDefinedDataStreamFunction = flink_fn_execution_pb2.UserDefinedDataStreamFunction
+
+    runtime_context = StreamingRuntimeContext.of(runtime_context, job_parameters)
+
     def open_func():
         if hasattr(user_defined_func, "open"):
             user_defined_func.open(runtime_context)
@@ -61,12 +66,35 @@ def extract_one_input_process_function(
         if hasattr(user_defined_func, "close"):
             user_defined_func.close()
 
-    process_element = user_defined_func.process_element
+    if func_type == UserDefinedDataStreamFunction.PROCESS:
+        function_context = InternalProcessFunctionContext(function_context)
+
+        process_element = user_defined_func.process_element
+
+        def process_element_func(value):
+            yield from process_element(value, function_context)
+
+        return OneInputOperation(open_func, close_func, process_element_func)
+
+    elif func_type == UserDefinedDataStreamFunction.KEYED_PROCESS:
+
+        function_context = InternalKeyedProcessFunctionContext(
+            function_context, user_defined_function_proto.key_type_info)
+
+        timer_context = InternalKeyedProcessFunctionOnTimerContext(
+            timer_context, user_defined_function_proto.key_type_info)
+
+        on_timer = user_defined_func.on_timer
+
+        def process_element(value, context):
+            return user_defined_func.process_element(value[1], context)
 
-    if function_urn == DATA_STREAM_STATELESS_FUNCTION_URN:
-        context = InternalProcessFunctionContext(function_context)
+        def on_timer_func(timestamp):
+            yield from on_timer(timestamp, timer_context)
 
-    def process_element_func(value):
-        yield from process_element(value, context)
+        def process_element_func(value):
+            yield from process_element(value, function_context)
 
-    return open_func, close_func, process_element_func
+        return OneInputOperation(open_func, close_func, process_element_func, on_timer_func)
+    else:
+        raise Exception("Unknown function type {0}.".format(func_type))
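
For illustration (not part of the commit): the embedded operation now recognizes KEYED_PROCESS functions and wires their on_timer callback through a dedicated timer context. A sketch of the kind of user function this serves, assuming event-time timestamps are assigned upstream; names are illustrative:

    from pyflink.common.typeinfo import Types
    from pyflink.datastream import KeyedProcessFunction, RuntimeContext
    from pyflink.datastream.state import ValueStateDescriptor

    class LastValuePerKey(KeyedProcessFunction):

        def open(self, runtime_context: RuntimeContext):
            self.last_value = runtime_context.get_state(
                ValueStateDescriptor('last_value', Types.INT()))

        def process_element(self, value, ctx):
            self.last_value.update(value[0])
            # The timer registered here is later dispatched through the new on_timer_func path.
            ctx.timer_service().register_event_time_timer(ctx.timestamp() + 1000)
            yield value

        def on_timer(self, timestamp, ctx):
            # In thread mode this runs with an InternalKeyedProcessFunctionOnTimerContext.
            yield ctx.get_current_key(), self.last_value.value(), timestamp
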
diff --git a/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py b/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py
index 8c7a1608634..eee7b4b01d8 100644
--- a/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py
+++ b/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py
@@ -15,7 +15,9 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 ################################################################################
-from pyflink.datastream import ProcessFunction, TimerService
+from pyflink.datastream import ProcessFunction, TimerService, KeyedProcessFunction, TimeDomain
+from pyflink.fn_execution.datastream.embedded.timerservice_impl import TimerServiceImpl
+from pyflink.fn_execution.embedded.converters import from_type_info
 
 
 class InternalProcessFunctionContext(ProcessFunction.Context, TimerService):
@@ -45,3 +47,41 @@ class InternalProcessFunctionContext(ProcessFunction.Context, TimerService):
 
     def delete_event_time_timer(self, t: int):
         raise Exception("Deleting timers is only supported on a keyed streams.")
+
+
+class InternalKeyedProcessFunctionContext(KeyedProcessFunction.Context):
+
+    def __init__(self, context, key_type_info):
+        self._context = context
+        self._timer_service = TimerServiceImpl(self._context.timerService())
+        self._key_converter = from_type_info(key_type_info)
+
+    def get_current_key(self):
+        return self._key_converter.to_internal(self._context.getCurrentKey())
+
+    def timer_service(self) -> TimerService:
+        return self._timer_service
+
+    def timestamp(self) -> int:
+        return self._context.timestamp()
+
+
+class InternalKeyedProcessFunctionOnTimerContext(KeyedProcessFunction.OnTimerContext,
+                                                 KeyedProcessFunction.Context):
+
+    def __init__(self, context, key_type_info):
+        self._context = context
+        self._timer_service = TimerServiceImpl(self._context.timerService())
+        self._key_converter = from_type_info(key_type_info)
+
+    def timer_service(self) -> TimerService:
+        return self._timer_service
+
+    def timestamp(self) -> int:
+        return self._context.timestamp()
+
+    def time_domain(self) -> TimeDomain:
+        return TimeDomain(self._context.timeDomain())
+
+    def get_current_key(self):
+        return self._key_converter.to_internal(self._context.getCurrentKey())
diff --git a/flink-python/pyflink/fn_execution/datastream/embedded/runtime_context.py b/flink-python/pyflink/fn_execution/datastream/embedded/runtime_context.py
index 64abbf99b11..c7d9508446e 100644
--- a/flink-python/pyflink/fn_execution/datastream/embedded/runtime_context.py
+++ b/flink-python/pyflink/fn_execution/datastream/embedded/runtime_context.py
@@ -16,9 +16,15 @@
 # limitations under the License.
 ################################################################################
 from pyflink.datastream import RuntimeContext
-from pyflink.datastream.state import AggregatingStateDescriptor, AggregatingState, \
-    ReducingStateDescriptor, ReducingState, MapStateDescriptor, MapState, ListStateDescriptor, \
-    ListState, ValueStateDescriptor, ValueState
+from pyflink.datastream.state import (AggregatingStateDescriptor, AggregatingState,
+                                      ReducingStateDescriptor, ReducingState, MapStateDescriptor,
+                                      MapState, ListStateDescriptor, ListState,
+                                      ValueStateDescriptor, ValueState)
+from pyflink.fn_execution.datastream.embedded.state_impl import (ValueStateImpl, ListStateImpl,
+                                                                 MapStateImpl, ReducingStateImpl,
+                                                                 AggregatingStateImpl)
+from pyflink.fn_execution.embedded.converters import from_type_info
+from pyflink.fn_execution.embedded.java_utils import to_java_state_descriptor
 
 
 class StreamingRuntimeContext(RuntimeContext):
@@ -76,20 +82,32 @@ class StreamingRuntimeContext(RuntimeContext):
         return self._runtime_context.getMetricGroup()
 
     def get_state(self, state_descriptor: ValueStateDescriptor) -> ValueState:
-        pass
+        return ValueStateImpl(
+            self._runtime_context.getState(to_java_state_descriptor(state_descriptor)),
+            from_type_info(state_descriptor.type_info))
 
     def get_list_state(self, state_descriptor: ListStateDescriptor) -> ListState:
-        pass
+        return ListStateImpl(
+            self._runtime_context.getListState(to_java_state_descriptor(state_descriptor)),
+            from_type_info(state_descriptor.type_info))
 
     def get_map_state(self, state_descriptor: MapStateDescriptor) -> MapState:
-        pass
+        return MapStateImpl(
+            self._runtime_context.getMapState(to_java_state_descriptor(state_descriptor)),
+            from_type_info(state_descriptor.type_info))
 
     def get_reducing_state(self, state_descriptor: ReducingStateDescriptor) -> ReducingState:
-        pass
+        return ReducingStateImpl(
+            self._runtime_context.getState(to_java_state_descriptor(state_descriptor)),
+            from_type_info(state_descriptor.type_info),
+            state_descriptor.get_reduce_function())
 
     def get_aggregating_state(self,
                               state_descriptor: AggregatingStateDescriptor) -> AggregatingState:
-        pass
+        return AggregatingStateImpl(
+            self._runtime_context.getState(to_java_state_descriptor(state_descriptor)),
+            from_type_info(state_descriptor.type_info),
+            state_descriptor.get_agg_function())
 
     @staticmethod
     def of(runtime_context, job_parameters):
diff --git a/flink-python/pyflink/fn_execution/datastream/embedded/state_impl.py b/flink-python/pyflink/fn_execution/datastream/embedded/state_impl.py
new file mode 100644
index 00000000000..e88c6abc304
--- /dev/null
+++ b/flink-python/pyflink/fn_execution/datastream/embedded/state_impl.py
@@ -0,0 +1,158 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from typing import List, Iterable, Tuple, Dict
+
+from pyflink.datastream import ReduceFunction, AggregateFunction
+from pyflink.datastream.state import (ValueState, T, State, ListState, IN, OUT, ReducingState,
+                                      AggregatingState, MapState, V, K)
+from pyflink.fn_execution.embedded.converters import (DataConverter, DictDataConverter,
+                                                      ListDataConverter)
+
+
+class StateImpl(State):
+    def __init__(self, state, value_converter: DataConverter):
+        self._state = state
+        self._value_converter = value_converter
+
+    def clear(self):
+        self._state.clear()
+
+
+class ValueStateImpl(StateImpl, ValueState):
+    def __init__(self, value_state, value_converter: DataConverter):
+        super(ValueStateImpl, self).__init__(value_state, value_converter)
+
+    def value(self) -> T:
+        return self._value_converter.to_internal(self._state.value())
+
+    def update(self, value: T) -> None:
+        self._state.update(self._value_converter.to_external(value))
+
+
+class ListStateImpl(StateImpl, ListState):
+
+    def __init__(self, list_state, value_converter: ListDataConverter):
+        super(ListStateImpl, self).__init__(list_state, value_converter)
+        self._element_converter = value_converter._field_converter
+
+    def update(self, values: List[T]) -> None:
+        self._state.update(self._value_converter.to_external(values))
+
+    def add_all(self, values: List[T]) -> None:
+        self._state.addAll(self._value_converter.to_external(values))
+
+    def get(self) -> OUT:
+        return self._value_converter.to_internal(self._state.get())
+
+    def add(self, value: IN) -> None:
+        self._state.add(self._element_converter.to_external(value))
+
+
+class ReducingStateImpl(StateImpl, ReducingState):
+
+    def __init__(self,
+                 value_state,
+                 value_converter: DataConverter,
+                 reduce_function: ReduceFunction):
+        super(ReducingStateImpl, self).__init__(value_state, value_converter)
+        self._reduce_function = reduce_function
+
+    def get(self) -> OUT:
+        return self._value_converter.to_internal(self._state.value())
+
+    def add(self, value: IN) -> None:
+        if value is None:
+            self.clear()
+        else:
+            current_value = self.get()
+
+            if current_value is None:
+                reduce_value = value
+            else:
+                reduce_value = self._reduce_function.reduce(current_value, value)
+
+            self._state.update(self._value_converter.to_external(reduce_value))
+
+
+class AggregatingStateImpl(StateImpl, AggregatingState):
+    def __init__(self,
+                 value_state,
+                 value_converter,
+                 agg_function: AggregateFunction):
+        super(AggregatingStateImpl, self).__init__(value_state, value_converter)
+        self._agg_function = agg_function
+
+    def get(self) -> OUT:
+        accumulator = self._value_converter.to_internal(self._state.value())
+
+        if accumulator is None:
+            return None
+        else:
+            return self._agg_function.get_result(accumulator)
+
+    def add(self, value: IN) -> None:
+        if value is None:
+            self.clear()
+        else:
+            accumulator = self._value_converter.to_internal(self._state.value())
+
+            if accumulator is None:
+                accumulator = self._agg_function.create_accumulator()
+
+            accumulator = self._agg_function.add(value, accumulator)
+            self._state.update(self._value_converter.to_external(accumulator))
+
+
+class MapStateImpl(StateImpl, MapState):
+    def __init__(self, map_state, map_converter: DictDataConverter):
+        super(MapStateImpl, self).__init__(map_state, map_converter)
+        self._k_converter = map_converter._key_converter
+        self._v_converter = map_converter._value_converter
+
+    def get(self, key: K) -> V:
+        return self._value_converter.to_internal(
+            self._state.get(self._k_converter.to_external(key)))
+
+    def put(self, key: K, value: V) -> None:
+        self._state.put(self._k_converter.to_external(key), self._v_converter.to_external(value))
+
+    def put_all(self, dict_value: Dict[K, V]) -> None:
+        self._state.putAll(self._value_converter.to_external(dict_value))
+
+    def remove(self, key: K) -> None:
+        self._state.remove(self._k_converter.to_external(key))
+
+    def contains(self, key: K) -> bool:
+        return self._state.contains(self._k_converter.to_external(key))
+
+    def items(self) -> Iterable[Tuple[K, V]]:
+        entries = self._state.entries()
+        for entry in entries:
+            yield (self._k_converter.to_internal(entry.getKey()),
+                   self._v_converter.to_internal(entry.getValue()))
+
+    def keys(self) -> Iterable[K]:
+        for k in self._state.keys():
+            yield self._k_converter.to_internal(k)
+
+    def values(self) -> Iterable[V]:
+        for v in self._state.values():
+            yield self._v_converter.to_internal(v)
+
+    def is_empty(self) -> bool:
+        return self._state.isEmpty()
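
For context (not part of the commit): these wrappers adapt the Java keyed state objects to the Python state interfaces, converting values with the embedded converters; ReducingStateImpl and AggregatingStateImpl fold each added element into the stored value or accumulator. A sketch of the AggregatingState round trip from user code, with illustrative names and a pickled (sum, count) accumulator:

    from pyflink.common.typeinfo import Types
    from pyflink.datastream import AggregateFunction, KeyedProcessFunction, RuntimeContext
    from pyflink.datastream.state import AggregatingStateDescriptor

    class AverageAggregate(AggregateFunction):
        # The accumulator is a (sum, count) tuple; get_result turns it into the average.

        def create_accumulator(self):
            return 0, 0

        def add(self, value, accumulator):
            return accumulator[0] + value, accumulator[1] + 1

        def get_result(self, accumulator):
            return accumulator[0] / accumulator[1]

        def merge(self, acc_a, acc_b):
            return acc_a[0] + acc_b[0], acc_a[1] + acc_b[1]

    class RunningAverage(KeyedProcessFunction):

        def open(self, runtime_context: RuntimeContext):
            # AggregatingStateImpl.add() folds each element into the stored accumulator;
            # get() applies AggregateFunction.get_result() to whatever is stored.
            self.avg = runtime_context.get_aggregating_state(
                AggregatingStateDescriptor(
                    'avg', AverageAggregate(), Types.PICKLED_BYTE_ARRAY()))

        def process_element(self, value, ctx):
            self.avg.add(value[0])
            yield value[1], self.avg.get()
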
diff --git a/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py b/flink-python/pyflink/fn_execution/datastream/embedded/timerservice_impl.py
similarity index 57%
copy from flink-python/pyflink/fn_execution/datastream/embedded/process_function.py
copy to flink-python/pyflink/fn_execution/datastream/embedded/timerservice_impl.py
index 8c7a1608634..fab9440b230 100644
--- a/flink-python/pyflink/fn_execution/datastream/embedded/process_function.py
+++ b/flink-python/pyflink/fn_execution/datastream/embedded/timerservice_impl.py
@@ -15,33 +15,27 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 ################################################################################
-from pyflink.datastream import ProcessFunction, TimerService
+from pyflink.datastream import TimerService
 
 
-class InternalProcessFunctionContext(ProcessFunction.Context, TimerService):
-    def __init__(self, context):
-        self._context = context
-
-    def timer_service(self) -> TimerService:
-        return self
-
-    def timestamp(self) -> int:
-        return self._context.timestamp()
+class TimerServiceImpl(TimerService):
+    def __init__(self, timer_service):
+        self._timer_service = timer_service
 
     def current_processing_time(self):
-        return self._context.currentProcessingTime()
+        return self._timer_service.currentProcessingTime()
 
     def current_watermark(self):
-        return self._context.currentWatermark()
+        return self._timer_service.currentWatermark()
 
     def register_processing_time_timer(self, timestamp: int):
-        raise Exception("Register timers is only supported on a keyed stream.")
+        self._timer_service.registerProcessingTimeTimer(timestamp)
 
     def register_event_time_timer(self, timestamp: int):
-        raise Exception("Register timers is only supported on a keyed stream.")
+        self._timer_service.registerEventTimeTimer(timestamp)
 
-    def delete_processing_time_timer(self, t: int):
-        raise Exception("Deleting timers is only supported on a keyed streams.")
+    def delete_processing_time_timer(self, timestamp: int):
+        self._timer_service.deleteProcessingTimeTimer(timestamp)
 
-    def delete_event_time_timer(self, t: int):
-        raise Exception("Deleting timers is only supported on a keyed streams.")
+    def delete_event_time_timer(self, timestamp: int):
+        self._timer_service.deleteEventTimeTimer(timestamp)
diff --git a/flink-python/pyflink/fn_execution/embedded/converters.py b/flink-python/pyflink/fn_execution/embedded/converters.py
index 157a082c93a..ee75e6d5beb 100644
--- a/flink-python/pyflink/fn_execution/embedded/converters.py
+++ b/flink-python/pyflink/fn_execution/embedded/converters.py
@@ -20,7 +20,10 @@ from abc import ABC, abstractmethod
 import pickle
 from typing import TypeVar, List, Tuple
 
-from pyflink.common import Row, RowKind
+from pyflink.common import Row, RowKind, TypeInformation
+from pyflink.common.typeinfo import (PickledBytesTypeInfo, PrimitiveArrayTypeInfo,
+                                     BasicArrayTypeInfo, ObjectArrayTypeInfo, RowTypeInfo,
+                                     TupleTypeInfo, MapTypeInfo, ListTypeInfo)
 
 IN = TypeVar('IN')
 OUT = TypeVar('OUT')
@@ -70,13 +73,15 @@ class FlattenRowDataConverter(DataConverter):
         if value is None:
             return None
 
-        return [self._field_data_converters[i].to_internal(item) for i, item in enumerate(value)]
+        return tuple([self._field_data_converters[i].to_internal(item)
+                      for i, item in enumerate(value)])
 
     def to_external(self, value) -> OUT:
         if value is None:
             return None
 
-        return [self._field_data_converters[i].to_external(item) for i, item in enumerate(value)]
+        return tuple([self._field_data_converters[i].to_external(item)
+                      for i, item in enumerate(value)])
 
 
 class RowDataConverter(DataConverter):
@@ -84,8 +89,6 @@ class RowDataConverter(DataConverter):
     def __init__(self, field_data_converters: List[DataConverter], field_names: List[str]):
         self._field_data_converters = field_data_converters
         self._reuse_row = Row()
-        self._reuse_external_row_data = [None for _ in range(len(field_data_converters))]
-        self._reuse_external_row = [None, self._reuse_external_row_data]
         self._reuse_row.set_field_names(field_names)
 
     def to_internal(self, value) -> IN:
@@ -102,11 +105,10 @@ class RowDataConverter(DataConverter):
         if value is None:
             return None
 
-        self._reuse_external_row[0] = value.get_row_kind().value
         values = value._values
-        for i in range(len(values)):
-            self._reuse_external_row_data[i] = self._field_data_converters[i].to_external(values[i])
-        return self._reuse_external_row
+        fields = tuple([self._field_data_converters[i].to_external(values[i])
+                        for i in range(len(values))])
+        return value.get_row_kind().value, fields
 
 
 class TupleDataConverter(DataConverter):
@@ -125,8 +127,8 @@ class TupleDataConverter(DataConverter):
         if value is None:
             return None
 
-        return [self._field_data_converters[i].to_external(item)
-                for i, item in enumerate(value)]
+        return tuple([self._field_data_converters[i].to_external(item)
+                      for i, item in enumerate(value)])
 
 
 class ListDataConverter(DataConverter):
@@ -147,6 +149,17 @@ class ListDataConverter(DataConverter):
         return [self._field_converter.to_external(item) for item in value]
 
 
+class ArrayDataConverter(ListDataConverter):
+    def __init__(self, field_converter: DataConverter):
+        super(ArrayDataConverter, self).__init__(field_converter)
+
+    def to_internal(self, value) -> IN:
+        return tuple(super(ArrayDataConverter, self).to_internal(value))
+
+    def to_external(self, value) -> OUT:
+        return tuple(super(ArrayDataConverter, self).to_external(value))
+
+
 class DictDataConverter(DataConverter):
     def __init__(self, key_converter: DataConverter, value_converter: DataConverter):
         self._key_converter = key_converter
@@ -185,8 +198,9 @@ def from_type_info_proto(type_info):
             [from_type_info_proto(field_type)
              for field_type in type_info.tuple_type_info.field_types])
     elif type_name in (type_info_name.BASIC_ARRAY,
-                       type_info_name.OBJECT_ARRAY,
-                       type_info_name.LIST):
+                       type_info_name.OBJECT_ARRAY):
+        return ArrayDataConverter(from_type_info_proto(type_info.collection_element_type))
+    elif type_name == type_info_name.LIST:
         return ListDataConverter(from_type_info_proto(type_info.collection_element_type))
     elif type_name == type_info_name.MAP:
         return DictDataConverter(from_type_info_proto(type_info.map_type_info.key_type),
@@ -214,9 +228,23 @@ def from_field_type_proto(field_type):
             [from_field_type_proto(f.type) for f in field_type.row_schema.fields],
             [f.name for f in field_type.row_schema.fields])
     elif type_name == schema_type_name.BASIC_ARRAY:
-        return ListDataConverter(from_field_type_proto(field_type.collection_element_type))
+        return ArrayDataConverter(from_field_type_proto(field_type.collection_element_type))
     elif type_name == schema_type_name.MAP:
         return DictDataConverter(from_field_type_proto(field_type.map_info.key_type),
                                  from_field_type_proto(field_type.map_info.value_type))
 
     return IdentityDataConverter()
+
+
+def from_type_info(type_info: TypeInformation):
+    if isinstance(type_info, (PickledBytesTypeInfo, RowTypeInfo, TupleTypeInfo)):
+        return PickleDataConverter()
+    elif isinstance(type_info, (PrimitiveArrayTypeInfo, BasicArrayTypeInfo, ObjectArrayTypeInfo)):
+        return ArrayDataConverter(from_type_info(type_info._element_type))
+    elif isinstance(type_info, ListTypeInfo):
+        return ListDataConverter(from_type_info(type_info.elem_type))
+    elif isinstance(type_info, MapTypeInfo):
+        return DictDataConverter(from_type_info(type_info._key_type_info),
+                                 from_type_info(type_info._value_type_info))
+
+    return IdentityDataConverter()
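
The new from_type_info factory mirrors the proto-based ones above but works directly on PyFlink TypeInformation objects. A rough sketch of the resulting mapping (import path taken from this diff; the expected tuple output follows from ArrayDataConverter above):

    from pyflink.common import Types
    # import path as added in this commit
    from pyflink.fn_execution.embedded.converters import from_type_info

    # arrays map to ArrayDataConverter (tuples on the internal side), maps to
    # DictDataConverter, pickled/row/tuple types to PickleDataConverter
    array_converter = from_type_info(Types.PRIMITIVE_ARRAY(Types.LONG()))
    map_converter = from_type_info(Types.MAP(Types.STRING(), Types.LONG()))

    # e.g. array_converter.to_internal([1, 2, 3]) is expected to yield (1, 2, 3)
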
diff --git a/flink-python/pyflink/fn_execution/embedded/java_utils.py b/flink-python/pyflink/fn_execution/embedded/java_utils.py
new file mode 100644
index 00000000000..099cca8f1f3
--- /dev/null
+++ b/flink-python/pyflink/fn_execution/embedded/java_utils.py
@@ -0,0 +1,204 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from pemja import findClass
+
+from pyflink.common.typeinfo import (TypeInformation, Types, BasicTypeInfo, BasicType,
+                                     PrimitiveArrayTypeInfo, BasicArrayTypeInfo,
+                                     ObjectArrayTypeInfo, MapTypeInfo)
+from pyflink.datastream.state import (StateDescriptor, ValueStateDescriptor,
+                                      ReducingStateDescriptor,
+                                      AggregatingStateDescriptor, ListStateDescriptor,
+                                      MapStateDescriptor, StateTtlConfig)
+
+# Java Types Class
+JTypes = findClass('org.apache.flink.api.common.typeinfo.Types')
+JPrimitiveArrayTypeInfo = findClass('org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo')
+JBasicArrayTypeInfo = findClass('org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo')
+JPickledByteArrayTypeInfo = findClass('org.apache.flink.streaming.api.typeinfo.python.'
+                                      'PickledByteArrayTypeInfo')
+JMapTypeInfo = findClass('org.apache.flink.api.java.typeutils.MapTypeInfo')
+
+# Java State Descriptor Class
+JValueStateDescriptor = findClass('org.apache.flink.api.common.state.ValueStateDescriptor')
+JListStateDescriptor = findClass('org.apache.flink.api.common.state.ListStateDescriptor')
+JMapStateDescriptor = findClass('org.apache.flink.api.common.state.MapStateDescriptor')
+
+# Java StateTtlConfig
+JStateTtlConfig = findClass('org.apache.flink.api.common.state.StateTtlConfig')
+JTime = findClass('org.apache.flink.api.common.time.Time')
+JUpdateType = findClass('org.apache.flink.api.common.state.StateTtlConfig$UpdateType')
+JStateVisibility = findClass('org.apache.flink.api.common.state.StateTtlConfig$StateVisibility')
+
+
+def to_java_typeinfo(type_info: TypeInformation):
+    if isinstance(type_info, BasicTypeInfo):
+        basic_type = type_info._basic_type
+
+        if basic_type == BasicType.STRING:
+            j_typeinfo = JTypes.STRING
+        elif basic_type == BasicType.BYTE:
+            j_typeinfo = JTypes.LONG
+        elif basic_type == BasicType.BOOLEAN:
+            j_typeinfo = JTypes.BOOLEAN
+        elif basic_type == BasicType.SHORT:
+            j_typeinfo = JTypes.LONG
+        elif basic_type == BasicType.INT:
+            j_typeinfo = JTypes.LONG
+        elif basic_type == BasicType.LONG:
+            j_typeinfo = JTypes.LONG
+        elif basic_type == BasicType.FLOAT:
+            j_typeinfo = JTypes.DOUBLE
+        elif basic_type == BasicType.DOUBLE:
+            j_typeinfo = JTypes.DOUBLE
+        elif basic_type == BasicType.CHAR:
+            j_typeinfo = JTypes.STRING
+        elif basic_type == BasicType.BIG_INT:
+            j_typeinfo = JTypes.BIG_INT
+        elif basic_type == BasicType.BIG_DEC:
+            j_typeinfo = JTypes.BIG_DEC
+        elif basic_type == BasicType.INSTANT:
+            j_typeinfo = JTypes.INSTANT
+        else:
+            raise TypeError("Invalid BasicType %s." % basic_type)
+
+    elif isinstance(type_info, PrimitiveArrayTypeInfo):
+        element_type = type_info._element_type
+
+        if element_type == Types.BOOLEAN():
+            j_typeinfo = JPrimitiveArrayTypeInfo.BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO
+        elif element_type == Types.BYTE():
+            j_typeinfo = JPrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO
+        elif element_type == Types.SHORT():
+            j_typeinfo = JPrimitiveArrayTypeInfo.SHORT_PRIMITIVE_ARRAY_TYPE_INFO
+        elif element_type == Types.INT():
+            j_typeinfo = JPrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO
+        elif element_type == Types.LONG():
+            j_typeinfo = JPrimitiveArrayTypeInfo.LONG_PRIMITIVE_ARRAY_TYPE_INFO
+        elif element_type == Types.FLOAT():
+            j_typeinfo = JPrimitiveArrayTypeInfo.FLOAT_PRIMITIVE_ARRAY_TYPE_INFO
+        elif element_type == Types.DOUBLE():
+            j_typeinfo = JPrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO
+        elif element_type == Types.CHAR():
+            j_typeinfo = JPrimitiveArrayTypeInfo.CHAR_PRIMITIVE_ARRAY_TYPE_INFO
+        else:
+            raise TypeError("Invalid element type for a primitive array.")
+
+    elif isinstance(type_info, BasicArrayTypeInfo):
+        element_type = type_info._element_type
+
+        if element_type == Types.BOOLEAN():
+            j_typeinfo = JBasicArrayTypeInfo.BOOLEAN_ARRAY_TYPE_INFO
+        elif element_type == Types.BYTE():
+            j_typeinfo = JBasicArrayTypeInfo.BYTE_ARRAY_TYPE_INFO
+        elif element_type == Types.SHORT():
+            j_typeinfo = JBasicArrayTypeInfo.SHORT_ARRAY_TYPE_INFO
+        elif element_type == Types.INT():
+            j_typeinfo = JBasicArrayTypeInfo.INT_ARRAY_TYPE_INFO
+        elif element_type == Types.LONG():
+            j_typeinfo = JBasicArrayTypeInfo.LONG_ARRAY_TYPE_INFO
+        elif element_type == Types.FLOAT():
+            j_typeinfo = JBasicArrayTypeInfo.FLOAT_ARRAY_TYPE_INFO
+        elif element_type == Types.DOUBLE():
+            j_typeinfo = JBasicArrayTypeInfo.DOUBLE_ARRAY_TYPE_INFO
+        elif element_type == Types.CHAR():
+            j_typeinfo = JBasicArrayTypeInfo.CHAR_ARRAY_TYPE_INFO
+        elif element_type == Types.STRING():
+            j_typeinfo = JBasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO
+        else:
+            raise TypeError("Invalid element type for a basic array.")
+
+    elif isinstance(type_info, ObjectArrayTypeInfo):
+        element_type = type_info._element_type
+
+        j_typeinfo = JTypes.OBJECT_ARRAY(to_java_typeinfo(element_type))
+
+    elif isinstance(type_info, MapTypeInfo):
+        j_key_typeinfo = to_java_typeinfo(type_info._key_type_info)
+        j_value_typeinfo = to_java_typeinfo(type_info._value_type_info)
+
+        j_typeinfo = JMapTypeInfo(j_key_typeinfo, j_value_typeinfo)
+    else:
+        j_typeinfo = JPickledByteArrayTypeInfo.PICKLED_BYTE_ARRAY_TYPE_INFO
+
+    return j_typeinfo
+
+
+def to_java_state_ttl_config(ttl_config: StateTtlConfig):
+    j_ttl_config_builder = JStateTtlConfig.newBuilder(
+        JTime.milliseconds(ttl_config.get_ttl().to_milliseconds()))
+
+    update_type = ttl_config.get_update_type()
+    if update_type == StateTtlConfig.UpdateType.Disabled:
+        j_ttl_config_builder.setUpdateType(JUpdateType.Disabled)
+    elif update_type == StateTtlConfig.UpdateType.OnCreateAndWrite:
+        j_ttl_config_builder.setUpdateType(JUpdateType.OnCreateAndWrite)
+    elif update_type == StateTtlConfig.UpdateType.OnReadAndWrite:
+        j_ttl_config_builder.setUpdateType(JUpdateType.OnReadAndWrite)
+
+    state_visibility = ttl_config.get_state_visibility()
+    if state_visibility == StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp:
+        j_ttl_config_builder.setStateVisibility(JStateVisibility.ReturnExpiredIfNotCleanedUp)
+    elif state_visibility == StateTtlConfig.StateVisibility.NeverReturnExpired:
+        j_ttl_config_builder.setStateVisibility(JStateVisibility.NeverReturnExpired)
+
+    cleanup_strategies = ttl_config.get_cleanup_strategies()
+    if not cleanup_strategies.is_cleanup_in_background():
+        j_ttl_config_builder.disableCleanupInBackground()
+
+    if cleanup_strategies.in_full_snapshot():
+        j_ttl_config_builder.cleanupFullSnapshot()
+
+    incremental_cleanup_strategy = cleanup_strategies.get_incremental_cleanup_strategy()
+    if incremental_cleanup_strategy:
+        j_ttl_config_builder.cleanupIncrementally(
+            incremental_cleanup_strategy.get_cleanup_size(),
+            incremental_cleanup_strategy.run_cleanup_for_every_record())
+
+    rocksdb_compact_filter_cleanup_strategy = \
+        cleanup_strategies.get_rocksdb_compact_filter_cleanup_strategy()
+
+    if rocksdb_compact_filter_cleanup_strategy:
+        j_ttl_config_builder.cleanupInRocksdbCompactFilter(
+            rocksdb_compact_filter_cleanup_strategy.get_query_time_after_num_entries())
+
+    return j_ttl_config_builder.build()
+
+
+def to_java_state_descriptor(state_descriptor: StateDescriptor):
+    if isinstance(state_descriptor,
+                  (ValueStateDescriptor, ReducingStateDescriptor, AggregatingStateDescriptor)):
+        value_type_info = to_java_typeinfo(state_descriptor.type_info)
+        j_state_descriptor = JValueStateDescriptor(state_descriptor.name, value_type_info)
+
+    elif isinstance(state_descriptor, ListStateDescriptor):
+        element_type_info = to_java_typeinfo(state_descriptor.type_info.elem_type)
+        j_state_descriptor = JListStateDescriptor(state_descriptor.name, element_type_info)
+
+    elif isinstance(state_descriptor, MapStateDescriptor):
+        key_type_info = to_java_typeinfo(state_descriptor.type_info._key_type_info)
+        value_type_info = to_java_typeinfo(state_descriptor.type_info._value_type_info)
+        j_state_descriptor = JMapStateDescriptor(
+            state_descriptor.name, key_type_info, value_type_info)
+    else:
+        raise Exception("Unknown supported state_descriptor {0}".format(state_descriptor))
+
+    if state_descriptor._ttl_config:
+        j_state_ttl_config = to_java_state_ttl_config(state_descriptor._ttl_config)
+        j_state_descriptor.enableTimeToLive(j_state_ttl_config)
+
+    return j_state_descriptor
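
to_java_state_descriptor and to_java_state_ttl_config translate Python state descriptors, including TTL settings, into their Java counterparts so the embedded runtime can register them directly with the keyed state backend. On the user side the declaration is unchanged; a hedged sketch with the public PyFlink state API (names and values are illustrative):

    from pyflink.common import Time, Types
    from pyflink.datastream.state import StateTtlConfig, ValueStateDescriptor

    ttl_config = StateTtlConfig \
        .new_builder(Time.seconds(1)) \
        .set_update_type(StateTtlConfig.UpdateType.OnCreateAndWrite) \
        .set_state_visibility(StateTtlConfig.StateVisibility.NeverReturnExpired) \
        .cleanup_incrementally(10, True) \
        .build()

    descriptor = ValueStateDescriptor("count", Types.LONG())
    descriptor.enable_time_to_live(ttl_config)
    # in thread mode, to_java_state_descriptor above turns this into a Java
    # ValueStateDescriptor(LONG) carrying an equivalent StateTtlConfig
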
diff --git a/flink-python/pyflink/fn_execution/embedded/operation_utils.py b/flink-python/pyflink/fn_execution/embedded/operation_utils.py
index 0d666693fbc..c6a96d2ad67 100644
--- a/flink-python/pyflink/fn_execution/embedded/operation_utils.py
+++ b/flink-python/pyflink/fn_execution/embedded/operation_utils.py
@@ -101,8 +101,8 @@ def create_table_operation_from_proto(proto, input_coder_info, output_coder_into
 
 
 def create_one_input_user_defined_data_stream_function_from_protos(
-        function_urn, function_infos, input_coder_info, output_coder_info, runtime_context,
-        function_context, job_parameters):
+        function_infos, input_coder_info, output_coder_info, runtime_context,
+        function_context, timer_context, job_parameters):
     serialized_fns = [pare_user_defined_data_stream_function_proto(proto)
                       for proto in function_infos]
     input_data_converter = (
@@ -111,12 +111,12 @@ def create_one_input_user_defined_data_stream_function_from_protos(
         from_type_info_proto(parse_coder_proto(output_coder_info).raw_type.type_info))
 
     function_operation = OneInputFunctionOperation(
-        function_urn,
         serialized_fns,
         input_data_converter,
         output_data_converter,
         runtime_context,
         function_context,
+        timer_context,
         job_parameters)
 
     return function_operation
diff --git a/flink-python/pyflink/fn_execution/embedded/operations.py b/flink-python/pyflink/fn_execution/embedded/operations.py
index 1ece4cf6e11..919195c16b6 100644
--- a/flink-python/pyflink/fn_execution/embedded/operations.py
+++ b/flink-python/pyflink/fn_execution/embedded/operations.py
@@ -15,7 +15,7 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 ################################################################################
-from pyflink.fn_execution.datastream.embedded.operations import OneInputOperation
+from pyflink.fn_execution.datastream.embedded.operations import extract_process_function
 from pyflink.fn_execution.embedded.converters import DataConverter
 
 
@@ -36,22 +36,27 @@ class FunctionOperation(object):
         for operation in self._operations:
             operation.close()
 
+    def on_timer(self, timestamp):
+        for operation in self._operations:
+            for item in operation.on_timer(timestamp):
+                yield self._output_data_converter.to_external(item)
+
 
 class OneInputFunctionOperation(FunctionOperation):
     def __init__(self,
-                 function_urn,
                  serialized_fns,
                  input_data_converter: DataConverter,
                  output_data_converter: DataConverter,
                  runtime_context,
                  function_context,
+                 timer_context,
                  job_parameters):
         operations = (
-            [OneInputOperation(
-                function_urn,
+            [extract_process_function(
                 serialized_fn,
                 runtime_context,
                 function_context,
+                timer_context,
                 job_parameters)
                 for serialized_fn in serialized_fns])
         super(OneInputFunctionOperation, self).__init__(
diff --git a/flink-python/setup.py b/flink-python/setup.py
index 81c8b463412..acc377bb6d3 100644
--- a/flink-python/setup.py
+++ b/flink-python/setup.py
@@ -309,7 +309,7 @@ try:
                         'cloudpickle==2.1.0', 'avro-python3>=1.8.1,!=1.9.2,<1.10.0',
                         'pytz>=2018.3', 'fastavro>=1.1.0,<1.4.8', 'requests>=2.26.0',
                         'protobuf<3.18',
-                        'pemja==0.2.0;'
+                        'pemja==0.2.2;'
                         'python_full_version >= "3.7" and platform_system != "Windows"',
                         'httplib2>=0.19.0,<=0.20.4', apache_flink_libraries_dependency]
 
diff --git a/flink-python/src/main/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizer.java b/flink-python/src/main/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizer.java
index 272e929c693..613cc6459c5 100644
--- a/flink-python/src/main/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizer.java
+++ b/flink-python/src/main/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizer.java
@@ -33,6 +33,7 @@ import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
 import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
 import org.apache.flink.streaming.api.operators.python.DataStreamPythonFunctionOperator;
 import org.apache.flink.streaming.api.operators.python.embedded.AbstractEmbeddedDataStreamPythonFunctionOperator;
+import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonKeyedProcessOperator;
 import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonProcessOperator;
 import org.apache.flink.streaming.api.operators.python.process.AbstractExternalDataStreamPythonFunctionOperator;
 import org.apache.flink.streaming.api.operators.python.process.ExternalPythonCoProcessOperator;
@@ -423,7 +424,8 @@ public class PythonOperatorChainingOptimizer {
                                 || upOperator instanceof ExternalPythonProcessOperator
                                 || upOperator instanceof ExternalPythonCoProcessOperator))
                 || (downOperator instanceof EmbeddedPythonProcessOperator
-                        && upOperator instanceof EmbeddedPythonProcessOperator);
+                        && (upOperator instanceof EmbeddedPythonProcessOperator
+                                || upOperator instanceof EmbeddedPythonKeyedProcessOperator));
     }
 
     private static boolean arePythonOperatorsInSameExecutionEnvironment(
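
With the extra clause above, an EmbeddedPythonProcessOperator can now also be chained after an upstream EmbeddedPythonKeyedProcessOperator. In DataStream terms this covers pipelines like the following sketch (both functions are placeholders, not from this commit):

    # keyed process (EmbeddedPythonKeyedProcessOperator) followed by a plain
    # process function (EmbeddedPythonProcessOperator), now chained in one task
    keyed = ds.key_by(lambda x: x[0]).process(MyKeyedProcessFunction())
    result = keyed.process(MyProcessFunction())
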
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractEmbeddedDataStreamPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractEmbeddedDataStreamPythonFunctionOperator.java
index 22df29dec4d..c5205987375 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractEmbeddedDataStreamPythonFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractEmbeddedDataStreamPythonFunctionOperator.java
@@ -21,6 +21,7 @@ package org.apache.flink.streaming.api.operators.python.embedded;
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
 import org.apache.flink.streaming.api.operators.python.DataStreamPythonFunctionOperator;
 import org.apache.flink.util.Preconditions;
@@ -28,8 +29,7 @@ import org.apache.flink.util.Preconditions;
 import java.util.HashMap;
 import java.util.Map;
 
-import static org.apache.flink.python.Constants.STATEFUL_FUNCTION_URN;
-import static org.apache.flink.python.Constants.STATELESS_FUNCTION_URN;
+import static org.apache.flink.streaming.api.utils.PythonOperatorUtils.inBatchExecutionMode;
 
 /** Base class for all Python DataStream operators executed in embedded Python environment. */
 @Internal
@@ -47,23 +47,14 @@ public abstract class AbstractEmbeddedDataStreamPythonFunctionOperator<OUT>
     /** The TypeInformation of output data. */
     protected final TypeInformation<OUT> outputTypeInfo;
 
-    /** The function urn. */
-    final String functionUrn;
-
     /** The number of partitions for the partition custom function. */
     private Integer numPartitions;
 
     public AbstractEmbeddedDataStreamPythonFunctionOperator(
-            String functionUrn,
             Configuration config,
             DataStreamPythonFunctionInfo pythonFunctionInfo,
             TypeInformation<OUT> outputTypeInfo) {
         super(config);
-        Preconditions.checkArgument(
-                STATELESS_FUNCTION_URN.equals(functionUrn)
-                        || STATEFUL_FUNCTION_URN.equals(functionUrn),
-                "The function urn should be `STATELESS_FUNCTION_URN` or `STATEFUL_FUNCTION_URN`.");
-        this.functionUrn = functionUrn;
         this.pythonFunctionInfo = Preconditions.checkNotNull(pythonFunctionInfo);
         this.outputTypeInfo = Preconditions.checkNotNull(outputTypeInfo);
     }
@@ -88,6 +79,13 @@ public abstract class AbstractEmbeddedDataStreamPythonFunctionOperator<OUT>
         if (numPartitions != null) {
             jobParameters.put(NUM_PARTITIONS, String.valueOf(numPartitions));
         }
+
+        KeyedStateBackend<Object> keyedStateBackend = getKeyedStateBackend();
+        if (keyedStateBackend != null) {
+            jobParameters.put(
+                    "inBatchExecutionMode",
+                    String.valueOf(inBatchExecutionMode(keyedStateBackend)));
+        }
         return jobParameters;
     }
 }
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java
index ec886ef29b5..42b096827f3 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java
@@ -53,19 +53,18 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
 
     private transient PythonTypeUtils.DataConverter<IN, Object> inputDataConverter;
 
-    private transient PythonTypeUtils.DataConverter<OUT, Object> outputDataConverter;
+    transient PythonTypeUtils.DataConverter<OUT, Object> outputDataConverter;
 
-    private transient TimestampedCollector<OUT> collector;
+    protected transient TimestampedCollector<OUT> collector;
 
     protected transient long timestamp;
 
     public AbstractOneInputEmbeddedPythonFunctionOperator(
-            String functionUrn,
             Configuration config,
             DataStreamPythonFunctionInfo pythonFunctionInfo,
             TypeInformation<IN> inputTypeInfo,
             TypeInformation<OUT> outputTypeInfo) {
-        super(functionUrn, config, pythonFunctionInfo, outputTypeInfo);
+        super(config, pythonFunctionInfo, outputTypeInfo);
         this.inputTypeInfo = Preconditions.checkNotNull(inputTypeInfo);
     }
 
@@ -84,7 +83,6 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
 
     @Override
     public void openPythonInterpreter() {
-        // function_urn = ...
         // function_protos = ...
         // input_coder_info = ...
         // output_coder_info = ...
@@ -95,13 +93,11 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
         // from pyflink.fn_execution.embedded.operation_utils import
         // create_one_input_user_defined_data_stream_function_from_protos
         //
-        // operation = create_one_input_user_defined_data_stream_function_from_protos(function_urn,
+        // operation = create_one_input_user_defined_data_stream_function_from_protos(
         //     function_protos, input_coder_info, output_coder_info, runtime_context,
         //     function_context, job_parameters)
         // operation.open()
 
-        interpreter.set("function_urn", functionUrn);
-
         interpreter.set(
                 "function_protos",
                 createUserDefinedFunctionsProto().stream()
@@ -127,6 +123,7 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
 
         interpreter.set("runtime_context", getRuntimeContext());
         interpreter.set("function_context", getFunctionContext());
+        interpreter.set("timer_context", getTimerContext());
         interpreter.set("job_parameters", getJobParameters());
 
         interpreter.exec(
@@ -134,12 +131,12 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
 
         interpreter.exec(
                 "operation = create_one_input_user_defined_data_stream_function_from_protos("
-                        + "function_urn,"
                         + "function_protos,"
                         + "input_coder_info,"
                         + "output_coder_info,"
                         + "runtime_context,"
                         + "function_context,"
+                        + "timer_context,"
                         + "job_parameters)");
 
         interpreter.invokeMethod("operation", "open");
@@ -153,7 +150,7 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
     }
 
     @Override
-    public void processElement(StreamRecord<IN> element) {
+    public void processElement(StreamRecord<IN> element) throws Exception {
         collector.setTimestamp(element);
         timestamp = element.getTimestamp();
 
@@ -169,6 +166,7 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
             OUT result = outputDataConverter.toInternal(results.next());
             collector.collect(result);
         }
+        results.close();
     }
 
     TypeInformation<IN> getInputTypeInfo() {
@@ -181,4 +179,7 @@ public abstract class AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
 
     /** Gets The function context. */
     public abstract Object getFunctionContext();
+
+    /** Gets the timer context. */
+    public abstract Object getTimerContext();
 }
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonKeyedProcessOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonKeyedProcessOperator.java
new file mode 100644
index 00000000000..76fdf58058a
--- /dev/null
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonKeyedProcessOperator.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators.python.embedded;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.fnexecution.v1.FlinkFnApi;
+import org.apache.flink.python.util.ProtoUtils;
+import org.apache.flink.runtime.state.VoidNamespace;
+import org.apache.flink.runtime.state.VoidNamespaceSerializer;
+import org.apache.flink.streaming.api.SimpleTimerService;
+import org.apache.flink.streaming.api.TimeDomain;
+import org.apache.flink.streaming.api.TimerService;
+import org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
+import org.apache.flink.streaming.api.operators.InternalTimer;
+import org.apache.flink.streaming.api.operators.InternalTimerService;
+import org.apache.flink.streaming.api.operators.Triggerable;
+import org.apache.flink.streaming.api.utils.PythonTypeUtils;
+import org.apache.flink.types.Row;
+
+import pemja.core.object.PyIterator;
+
+import java.util.List;
+
+import static org.apache.flink.python.PythonOptions.MAP_STATE_READ_CACHE_SIZE;
+import static org.apache.flink.python.PythonOptions.MAP_STATE_WRITE_CACHE_SIZE;
+import static org.apache.flink.python.PythonOptions.PYTHON_METRIC_ENABLED;
+import static org.apache.flink.python.PythonOptions.PYTHON_PROFILE_ENABLED;
+import static org.apache.flink.python.PythonOptions.STATE_CACHE_SIZE;
+import static org.apache.flink.streaming.api.utils.PythonOperatorUtils.inBatchExecutionMode;
+
+/**
+ * {@link EmbeddedPythonKeyedProcessOperator} is responsible for executing user-defined Python
+ * KeyedProcessFunctions in the embedded Python environment. It also handles timer and state
+ * requests from the stateful Python user-defined function.
+ */
+@Internal
+public class EmbeddedPythonKeyedProcessOperator<K, IN, OUT>
+        extends AbstractOneInputEmbeddedPythonFunctionOperator<IN, OUT>
+        implements Triggerable<K, VoidNamespace> {
+
+    private static final long serialVersionUID = 1L;
+
+    /** The TypeInformation of the key. */
+    private transient TypeInformation<K> keyTypeInfo;
+
+    private transient ContextImpl context;
+
+    private transient OnTimerContextImpl onTimerContext;
+
+    private transient PythonTypeUtils.DataConverter<K, Object> keyConverter;
+
+    public EmbeddedPythonKeyedProcessOperator(
+            Configuration config,
+            DataStreamPythonFunctionInfo pythonFunctionInfo,
+            TypeInformation<IN> inputTypeInfo,
+            TypeInformation<OUT> outputTypeInfo) {
+        super(config, pythonFunctionInfo, inputTypeInfo, outputTypeInfo);
+    }
+
+    @Override
+    public void open() throws Exception {
+        keyTypeInfo = ((RowTypeInfo) this.getInputTypeInfo()).getTypeAt(0);
+
+        keyConverter = PythonTypeUtils.TypeInfoToDataConverter.typeInfoDataConverter(keyTypeInfo);
+
+        InternalTimerService<VoidNamespace> internalTimerService =
+                getInternalTimerService("user-timers", VoidNamespaceSerializer.INSTANCE, this);
+
+        TimerService timerService = new SimpleTimerService(internalTimerService);
+
+        context = new ContextImpl(timerService);
+
+        onTimerContext = new OnTimerContextImpl(timerService);
+
+        super.open();
+    }
+
+    @Override
+    public List<FlinkFnApi.UserDefinedDataStreamFunction> createUserDefinedFunctionsProto() {
+        return ProtoUtils.createUserDefinedDataStreamStatefulFunctionProtos(
+                getPythonFunctionInfo(),
+                getRuntimeContext(),
+                getJobParameters(),
+                keyTypeInfo,
+                inBatchExecutionMode(getKeyedStateBackend()),
+                config.get(PYTHON_METRIC_ENABLED),
+                config.get(PYTHON_PROFILE_ENABLED),
+                false,
+                config.get(STATE_CACHE_SIZE),
+                config.get(MAP_STATE_READ_CACHE_SIZE),
+                config.get(MAP_STATE_WRITE_CACHE_SIZE));
+    }
+
+    @Override
+    public void onEventTime(InternalTimer<K, VoidNamespace> timer) throws Exception {
+        collector.setAbsoluteTimestamp(timer.getTimestamp());
+        invokeUserFunction(TimeDomain.EVENT_TIME, timer);
+    }
+
+    @Override
+    public void onProcessingTime(InternalTimer<K, VoidNamespace> timer) throws Exception {
+        collector.eraseTimestamp();
+        invokeUserFunction(TimeDomain.PROCESSING_TIME, timer);
+    }
+
+    @Override
+    public Object getFunctionContext() {
+        return context;
+    }
+
+    @Override
+    public Object getTimerContext() {
+        return onTimerContext;
+    }
+
+    @Override
+    public <T> AbstractEmbeddedDataStreamPythonFunctionOperator<T> copy(
+            DataStreamPythonFunctionInfo pythonFunctionInfo, TypeInformation<T> outputTypeInfo) {
+        return new EmbeddedPythonKeyedProcessOperator<>(
+                config, pythonFunctionInfo, getInputTypeInfo(), outputTypeInfo);
+    }
+
+    private void invokeUserFunction(TimeDomain timeDomain, InternalTimer<K, VoidNamespace> timer)
+            throws Exception {
+        onTimerContext.timeDomain = timeDomain;
+        onTimerContext.timer = timer;
+        PyIterator results =
+                (PyIterator)
+                        interpreter.invokeMethod("operation", "on_timer", timer.getTimestamp());
+
+        while (results.hasNext()) {
+            OUT result = outputDataConverter.toInternal(results.next());
+            collector.collect(result);
+        }
+        results.close();
+
+        onTimerContext.timeDomain = null;
+        onTimerContext.timer = null;
+    }
+
+    private class ContextImpl {
+
+        private final TimerService timerService;
+
+        ContextImpl(TimerService timerService) {
+            this.timerService = timerService;
+        }
+
+        public long timestamp() {
+            return timestamp;
+        }
+
+        public TimerService timerService() {
+            return timerService;
+        }
+
+        @SuppressWarnings("unchecked")
+        public Object getCurrentKey() {
+            return keyConverter.toExternal(
+                    (K)
+                            ((Row) EmbeddedPythonKeyedProcessOperator.this.getCurrentKey())
+                                    .getField(0));
+        }
+    }
+
+    private class OnTimerContextImpl {
+
+        private final TimerService timerService;
+
+        private TimeDomain timeDomain;
+
+        private InternalTimer<K, VoidNamespace> timer;
+
+        OnTimerContextImpl(TimerService timerService) {
+            this.timerService = timerService;
+        }
+
+        public long timestamp() {
+            return timer.getTimestamp();
+        }
+
+        public TimerService timerService() {
+            return timerService;
+        }
+
+        public int timeDomain() {
+            return timeDomain.ordinal();
+        }
+
+        @SuppressWarnings("unchecked")
+        public Object getCurrentKey() {
+            return keyConverter.toExternal((K) ((Row) timer.getKey()).getField(0));
+        }
+    }
+}
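
With this operator, keyed process functions can run in thread mode without changing the Python function itself; only the execution mode needs to be switched. A hedged sketch, assuming the python.execution-mode option and the Configuration-aware get_execution_environment available in recent PyFlink releases:

    from pyflink.common import Configuration
    from pyflink.datastream import StreamExecutionEnvironment

    config = Configuration()
    # select the embedded interpreter (thread mode) for Python operators
    config.set_string("python.execution-mode", "thread")
    env = StreamExecutionEnvironment.get_execution_environment(config)
    # key_by(...).process(...) pipelines on this environment now run on
    # EmbeddedPythonKeyedProcessOperator instead of the process-mode operator
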
diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonProcessOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonProcessOperator.java
index ad0ceebb98a..50e4feb921c 100644
--- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonProcessOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonProcessOperator.java
@@ -31,7 +31,6 @@ import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
 import java.util.HashMap;
 import java.util.List;
 
-import static org.apache.flink.python.Constants.STATELESS_FUNCTION_URN;
 import static org.apache.flink.python.PythonOptions.MAP_STATE_READ_CACHE_SIZE;
 import static org.apache.flink.python.PythonOptions.MAP_STATE_WRITE_CACHE_SIZE;
 import static org.apache.flink.python.PythonOptions.PYTHON_METRIC_ENABLED;
@@ -57,7 +56,7 @@ public class EmbeddedPythonProcessOperator<IN, OUT>
             DataStreamPythonFunctionInfo pythonFunctionInfo,
             TypeInformation<IN> inputTypeInfo,
             TypeInformation<OUT> outputTypeInfo) {
-        super(STATELESS_FUNCTION_URN, config, pythonFunctionInfo, inputTypeInfo, outputTypeInfo);
+        super(config, pythonFunctionInfo, inputTypeInfo, outputTypeInfo);
     }
 
     @Override
@@ -87,6 +86,11 @@ public class EmbeddedPythonProcessOperator<IN, OUT>
         return context;
     }
 
+    @Override
+    public Object getTimerContext() {
+        return null;
+    }
+
     @Override
     public <T> AbstractEmbeddedDataStreamPythonFunctionOperator<T> copy(
             DataStreamPythonFunctionInfo pythonFunctionInfo, TypeInformation<T> outputTypeInfo) {
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/table/EmbeddedPythonTableFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/table/EmbeddedPythonTableFunctionOperator.java
index 77d5049cfc6..756b1cb4645 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/table/EmbeddedPythonTableFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/table/EmbeddedPythonTableFunctionOperator.java
@@ -138,7 +138,7 @@ public class EmbeddedPythonTableFunctionOperator extends AbstractEmbeddedStatele
 
     @SuppressWarnings("unchecked")
     @Override
-    public void processElement(StreamRecord<RowData> element) {
+    public void processElement(StreamRecord<RowData> element) throws Exception {
         RowData value = element.getValue();
 
         for (int i = 0; i < userDefinedFunctionInputArgs.length; i++) {
@@ -165,5 +165,6 @@ public class EmbeddedPythonTableFunctionOperator extends AbstractEmbeddedStatele
         } else if (joinType == FlinkJoinType.LEFT) {
             rowDataWrapper.collect(reuseJoinedRow.replace(value, reuseNullResultRowData));
         }
+        udtfResults.close();
     }
 }
diff --git a/flink-python/src/main/resources/META-INF/NOTICE b/flink-python/src/main/resources/META-INF/NOTICE
index 5a0397c922b..2911e2124b2 100644
--- a/flink-python/src/main/resources/META-INF/NOTICE
+++ b/flink-python/src/main/resources/META-INF/NOTICE
@@ -28,7 +28,7 @@ This project bundles the following dependencies under the Apache Software Licens
 - org.apache.beam:beam-vendor-sdks-java-extensions-protobuf:2.38.0
 - org.apache.beam:beam-vendor-guava-26_0-jre:0.1
 - org.apache.beam:beam-vendor-grpc-1_43_2:0.1
-- com.alibaba:pemja:0.2.0
+- com.alibaba:pemja:0.2.2
 
 This project bundles the following dependencies under the BSD license.
 See bundled license files for details