Posted to commits@beam.apache.org by pa...@apache.org on 2020/04/28 00:27:51 UTC

[beam] 01/01: Revert "[BEAM-9639][BEAM-9608] Improvements for FnApiRunner"

This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch revert-11270-fn-ref-more
in repository https://gitbox.apache.org/repos/asf/beam.git

commit e5e52694bcaf9a4cdc9fd4a130f8cca4dcc6fe6a
Author: Pablo <pa...@users.noreply.github.com>
AuthorDate: Mon Apr 27 17:27:34 2020 -0700

    Revert "[BEAM-9639][BEAM-9608] Improvements for FnApiRunner"
---
 .../runners/portability/fn_api_runner/execution.py | 179 +--------
 .../runners/portability/fn_api_runner/fn_runner.py | 406 +++++++++++++--------
 .../portability/fn_api_runner/fn_runner_test.py    |  24 --
 .../portability/fn_api_runner/translations.py      |  11 -
 4 files changed, 255 insertions(+), 365 deletions(-)
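
Both sides of this revert merge per-bundle results the same way: monitoring
infos from successive InstructionResponses are consolidated, and errors are
propagated with `or`. A self-contained sketch of that merge, lifted from the
hunks below (the helper name `merge_bundle_results` is illustrative, not
Beam's):

    import itertools

    from apache_beam.metrics import monitoring_infos
    from apache_beam.portability.api import beam_fn_api_pb2


    def merge_bundle_results(first, last):
        # Combine two bundle results: consolidate the metrics of both,
        # and keep whichever error (if any) appeared first.
        return beam_fn_api_pb2.InstructionResponse(
            process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
                monitoring_infos=monitoring_infos.consolidate(
                    itertools.chain(
                        first.process_bundle.monitoring_infos,
                        last.process_bundle.monitoring_infos))),
            error=first.error or last.error)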

diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
index 2f29515..e62d8a8 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
@@ -22,29 +22,17 @@ from __future__ import absolute_import
 import collections
 import itertools
 from typing import TYPE_CHECKING
-from typing import Any
-from typing import DefaultDict
-from typing import Dict
-from typing import Iterator
-from typing import List
 from typing import MutableMapping
-from typing import Optional
-from typing import Tuple
 
 from typing_extensions import Protocol
 
 from apache_beam import coders
-from apache_beam.coders import BytesCoder
 from apache_beam.coders.coder_impl import create_InputStream
 from apache_beam.coders.coder_impl import create_OutputStream
-from apache_beam.coders.coders import GlobalWindowCoder
-from apache_beam.coders.coders import WindowedValueCoder
 from apache_beam.portability import common_urns
 from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.runners import pipeline_context
-from apache_beam.runners.portability.fn_api_runner import translations
-from apache_beam.runners.portability.fn_api_runner.translations import create_buffer_id
 from apache_beam.runners.portability.fn_api_runner.translations import only_element
 from apache_beam.runners.portability.fn_api_runner.translations import split_buffer_id
 from apache_beam.runners.portability.fn_api_runner.translations import unique_name
@@ -57,13 +45,8 @@ from apache_beam.utils import windowed_value
 
 if TYPE_CHECKING:
   from apache_beam.coders.coder_impl import CoderImpl
+  from apache_beam.runners.portability.fn_api_runner import translations
   from apache_beam.runners.portability.fn_api_runner import worker_handlers
-  from apache_beam.runners.portability.fn_api_runner.translations import DataSideInput
-  from apache_beam.transforms.window import BoundedWindow
-
-ENCODED_IMPULSE_VALUE = WindowedValueCoder(
-    BytesCoder(), GlobalWindowCoder()).get_impl().encode_nested(
-        GlobalWindows.windowed_value(b''))
 
 
 class Buffer(Protocol):
@@ -221,7 +204,7 @@ class WindowGroupingBuffer(object):
   def __init__(
       self,
       access_pattern,
-      coder  # type: WindowedValueCoder
+      coder  # type: coders.WindowedValueCoder
   ):
     # type: (...) -> None
     # Here's where we would use a different type of partitioning
@@ -268,12 +251,11 @@ class WindowGroupingBuffer(object):
 
 class FnApiRunnerExecutionContext(object):
   """
- :var pcoll_buffers: (dict): Mapping of
+ :var pcoll_buffers: (dict of bytes: PartitionableBuffer): Mapping of
        PCollection IDs to list that functions as buffer for the
        ``beam.PCollection``.
  """
   def __init__(self,
-      stages,  # type: List[translations.Stage]
       worker_handler_manager,  # type: worker_handlers.WorkerHandlerManager
       pipeline_components,  # type: beam_runner_api_pb2.Components
       safe_coders,
@@ -286,9 +268,6 @@ class FnApiRunnerExecutionContext(object):
     :param safe_coders:
     :param data_channel_coders:
     """
-    self.stages = stages
-    self.side_input_descriptors_by_stage = (
-        self._build_data_side_inputs_map(stages))
     self.pcoll_buffers = {}  # type: MutableMapping[bytes, PartitionableBuffer]
     self.timer_buffers = {}  # type: MutableMapping[bytes, ListBuffer]
     self.worker_handler_manager = worker_handler_manager
@@ -301,63 +280,6 @@ class FnApiRunnerExecutionContext(object):
         iterable_state_write=self._iterable_state_write)
     self._last_uid = -1
 
-  @staticmethod
-  def _build_data_side_inputs_map(stages):
-    # type: (Iterable[translations.Stage]) -> MutableMapping[str, DataSideInput]
-
-    """Builds an index mapping stages to side input descriptors.
-
-    A side input descriptor is a map of side input IDs to side input access
-    patterns for all of the outputs of a stage that will be consumed as a
-    side input.
-    """
-    transform_consumers = collections.defaultdict(
-        list)  # type: DefaultDict[str, List[beam_runner_api_pb2.PTransform]]
-    stage_consumers = collections.defaultdict(
-        list)  # type: DefaultDict[str, List[translations.Stage]]
-
-    def get_all_side_inputs():
-      # type: () -> Set[str]
-      all_side_inputs = set()  # type: Set[str]
-      for stage in stages:
-        for transform in stage.transforms:
-          for input in transform.inputs.values():
-            transform_consumers[input].append(transform)
-            stage_consumers[input].append(stage)
-        for si in stage.side_inputs():
-          all_side_inputs.add(si)
-      return all_side_inputs
-
-    all_side_inputs = frozenset(get_all_side_inputs())
-    data_side_inputs_by_producing_stage = {}
-
-    producing_stages_by_pcoll = {}
-
-    for s in stages:
-      data_side_inputs_by_producing_stage[s.name] = {}
-      for transform in s.transforms:
-        for o in transform.outputs.values():
-          if o in s.side_inputs():
-            continue
-          producing_stages_by_pcoll[o] = s
-
-    for side_pc in all_side_inputs:
-      for consuming_transform in transform_consumers[side_pc]:
-        if consuming_transform.spec.urn not in translations.PAR_DO_URNS:
-          continue
-        producing_stage = producing_stages_by_pcoll[side_pc]
-        payload = proto_utils.parse_Bytes(
-            consuming_transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
-        for si_tag in payload.side_inputs:
-          if consuming_transform.inputs[si_tag] == side_pc:
-            side_input_id = (consuming_transform.unique_name, si_tag)
-            data_side_inputs_by_producing_stage[
-                producing_stage.name][side_input_id] = (
-                    translations.create_buffer_id(side_pc),
-                    payload.side_inputs[si_tag].access_pattern)
-
-    return data_side_inputs_by_producing_stage
-
   @property
   def state_servicer(self):
     # TODO(BEAM-9625): Ensure FnApiRunnerExecutionContext owns StateServicer
@@ -379,43 +301,6 @@ class FnApiRunnerExecutionContext(object):
         out.get())
     return token
 
-  def commit_side_inputs_to_state(
-      self,
-      data_side_input,  # type: DataSideInput
-  ):
-    # type: (...) -> None
-    for (consuming_transform_id, tag), (buffer_id,
-                                        func_spec) in data_side_input.items():
-      _, pcoll_id = split_buffer_id(buffer_id)
-      value_coder = self.pipeline_context.coders[self.safe_coders[
-          self.data_channel_coders[pcoll_id]]]
-      elements_by_window = WindowGroupingBuffer(func_spec, value_coder)
-      if buffer_id not in self.pcoll_buffers:
-        self.pcoll_buffers[buffer_id] = ListBuffer(
-            coder_impl=value_coder.get_impl())
-      for element_data in self.pcoll_buffers[buffer_id]:
-        elements_by_window.append(element_data)
-
-      if func_spec.urn == common_urns.side_inputs.ITERABLE.urn:
-        for _, window, elements_data in elements_by_window.encoded_items():
-          state_key = beam_fn_api_pb2.StateKey(
-              iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
-                  transform_id=consuming_transform_id,
-                  side_input_id=tag,
-                  window=window))
-          self.state_servicer.append_raw(state_key, elements_data)
-      elif func_spec.urn == common_urns.side_inputs.MULTIMAP.urn:
-        for key, window, elements_data in elements_by_window.encoded_items():
-          state_key = beam_fn_api_pb2.StateKey(
-              multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
-                  transform_id=consuming_transform_id,
-                  side_input_id=tag,
-                  window=window,
-                  key=key))
-          self.state_servicer.append_raw(state_key, elements_data)
-      else:
-        raise ValueError("Unknown access pattern: '%s'" % func_spec.urn)
-
 
 class BundleContextManager(object):
 
@@ -482,64 +367,6 @@ class BundleContextManager(object):
         state_api_service_descriptor=self.state_api_service_descriptor(),
         timer_api_service_descriptor=self.data_api_service_descriptor())
 
-  def extract_bundle_inputs_and_outputs(self):
-    # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[Tuple[str, str], str]]
-
-    """Returns maps of transform names to PCollection identifiers.
-
-    Also mutates IO stages to point to the data ApiServiceDescriptor.
-
-    Returns:
-      A tuple of (data_input, data_output, expected_timer_output) dictionaries.
-        `data_input` is a dictionary mapping (transform_name, output_name) to a
-        PCollection buffer; `data_output` is a dictionary mapping
-        (transform_name, output_name) to a PCollection ID.
-        `expected_timer_output` is a dictionary mapping transform_id and
-        timer family ID to a buffer id for timers.
-    """
-    data_input = {}  # type: Dict[str, PartitionableBuffer]
-    data_output = {}  # type: DataOutput
-    # A mapping of {(transform_id, timer_family_id) : buffer_id}
-    expected_timer_output = {}  # type: Dict[Tuple[str, str], str]
-    for transform in self.stage.transforms:
-      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
-                                bundle_processor.DATA_OUTPUT_URN):
-        pcoll_id = transform.spec.payload
-        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
-          coder_id = self.execution_context.data_channel_coders[only_element(
-              transform.outputs.values())]
-          coder = self.execution_context.pipeline_context.coders[
-              self.execution_context.safe_coders.get(coder_id, coder_id)]
-          if pcoll_id == translations.IMPULSE_BUFFER:
-            data_input[transform.unique_name] = ListBuffer(
-                coder_impl=coder.get_impl())
-            data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE)
-          else:
-            if pcoll_id not in self.execution_context.pcoll_buffers:
-              self.execution_context.pcoll_buffers[pcoll_id] = ListBuffer(
-                  coder_impl=coder.get_impl())
-            data_input[transform.unique_name] = (
-                self.execution_context.pcoll_buffers[pcoll_id])
-        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
-          data_output[transform.unique_name] = pcoll_id
-          coder_id = self.execution_context.data_channel_coders[only_element(
-              transform.inputs.values())]
-        else:
-          raise NotImplementedError
-        data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
-        data_api_service_descriptor = self.data_api_service_descriptor()
-        if data_api_service_descriptor:
-          data_spec.api_service_descriptor.url = (
-              data_api_service_descriptor.url)
-        transform.spec.payload = data_spec.SerializeToString()
-      elif transform.spec.urn in translations.PAR_DO_URNS:
-        payload = proto_utils.parse_Bytes(
-            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
-        for timer_family_id in payload.timer_family_specs.keys():
-          expected_timer_output[(transform.unique_name, timer_family_id)] = (
-              create_buffer_id(timer_family_id, 'timers'))
-    return data_input, data_output, expected_timer_output
-
   def get_input_coder_impl(self, transform_id):
     # type: (str) -> CoderImpl
     coder_id = beam_fn_api_pb2.RemoteGrpcPort.FromString(
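
This revert moves ENCODED_IMPULSE_VALUE from execution.py back into
fn_runner.py. For reference, a minimal standalone sketch of what that
constant holds, assuming only that apache_beam is importable: a single empty
bytes element in the global window, encoded with nested length prefixes so
it can be sent over the data channel.

    from apache_beam.coders import BytesCoder
    from apache_beam.coders.coders import GlobalWindowCoder
    from apache_beam.coders.coders import WindowedValueCoder
    from apache_beam.transforms.window import GlobalWindows

    # The impulse element: b'' wrapped in the global window.
    ENCODED_IMPULSE_VALUE = WindowedValueCoder(
        BytesCoder(), GlobalWindowCoder()).get_impl().encode_nested(
            GlobalWindows.windowed_value(b''))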
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
index 1f162ac..8e7ba2d 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
@@ -40,9 +40,11 @@ from typing import List
 from typing import Mapping
 from typing import MutableMapping
 from typing import Optional
+from typing import Sequence
 from typing import Tuple
 from typing import TypeVar
 
+import apache_beam as beam  # pylint: disable=ungrouped-imports
 from apache_beam.coders.coder_impl import create_OutputStream
 from apache_beam.metrics import metric
 from apache_beam.metrics import monitoring_infos
@@ -58,9 +60,12 @@ from apache_beam.runners.portability import portable_metrics
 from apache_beam.runners.portability.fn_api_runner import execution
 from apache_beam.runners.portability.fn_api_runner import translations
 from apache_beam.runners.portability.fn_api_runner.execution import ListBuffer
+from apache_beam.runners.portability.fn_api_runner.execution import WindowGroupingBuffer
 from apache_beam.runners.portability.fn_api_runner.translations import create_buffer_id
 from apache_beam.runners.portability.fn_api_runner.translations import only_element
+from apache_beam.runners.portability.fn_api_runner.translations import split_buffer_id
 from apache_beam.runners.portability.fn_api_runner.worker_handlers import WorkerHandlerManager
+from apache_beam.runners.worker import bundle_processor
 from apache_beam.transforms import environments
 from apache_beam.utils import profiler
 from apache_beam.utils import proto_utils
@@ -68,6 +73,7 @@ from apache_beam.utils.thread_pool_executor import UnboundedThreadPoolExecutor
 
 if TYPE_CHECKING:
   from apache_beam.pipeline import Pipeline
+  from apache_beam.coders.coder_impl import CoderImpl
   from apache_beam.portability.api import metrics_pb2
 
 _LOGGER = logging.getLogger(__name__)
@@ -82,6 +88,11 @@ BundleProcessResult = Tuple[beam_fn_api_pb2.InstructionResponse,
 
 # This module is experimental. No backwards-compatibility guarantees.
 
+ENCODED_IMPULSE_VALUE = beam.coders.WindowedValueCoder(
+    beam.coders.BytesCoder(),
+    beam.coders.coders.GlobalWindowCoder()).get_impl().encode_nested(
+        beam.transforms.window.GlobalWindows.windowed_value(b''))
+
 
 class FnApiRunner(runner.PipelineRunner):
 
@@ -314,7 +325,6 @@ class FnApiRunner(runner.PipelineRunner):
     monitoring_infos_by_stage = {}
 
     runner_execution_context = execution.FnApiRunnerExecutionContext(
-        stages,
         worker_handler_manager,
         stage_context.components,
         stage_context.safe_coders,
@@ -325,7 +335,6 @@ class FnApiRunner(runner.PipelineRunner):
         for stage in stages:
           bundle_context_manager = execution.BundleContextManager(
               runner_execution_context, stage, self._num_workers)
-
           stage_results = self._run_stage(
               runner_execution_context,
               bundle_context_manager,
@@ -336,14 +345,54 @@ class FnApiRunner(runner.PipelineRunner):
       worker_handler_manager.close_all()
     return RunnerResult(runner.PipelineState.DONE, monitoring_infos_by_stage)
 
+  def _store_side_inputs_in_state(self,
+                                  runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
+                                  data_side_input,  # type: DataSideInput
+                                 ):
+    # type: (...) -> None
+    for (transform_id, tag), (buffer_id, si) in data_side_input.items():
+      _, pcoll_id = split_buffer_id(buffer_id)
+      value_coder = runner_execution_context.pipeline_context.coders[
+          runner_execution_context.safe_coders[
+              runner_execution_context.data_channel_coders[pcoll_id]]]
+      elements_by_window = WindowGroupingBuffer(si, value_coder)
+      if buffer_id not in runner_execution_context.pcoll_buffers:
+        runner_execution_context.pcoll_buffers[buffer_id] = ListBuffer(
+            coder_impl=value_coder.get_impl())
+      for element_data in runner_execution_context.pcoll_buffers[buffer_id]:
+        elements_by_window.append(element_data)
+
+      if si.urn == common_urns.side_inputs.ITERABLE.urn:
+        for _, window, elements_data in elements_by_window.encoded_items():
+          state_key = beam_fn_api_pb2.StateKey(
+              iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
+                  transform_id=transform_id, side_input_id=tag, window=window))
+          (
+              runner_execution_context.worker_handler_manager.state_servicer.
+              append_raw(state_key, elements_data))
+      elif si.urn == common_urns.side_inputs.MULTIMAP.urn:
+        for key, window, elements_data in elements_by_window.encoded_items():
+          state_key = beam_fn_api_pb2.StateKey(
+              multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
+                  transform_id=transform_id,
+                  side_input_id=tag,
+                  window=window,
+                  key=key))
+          (
+              runner_execution_context.worker_handler_manager.state_servicer.
+              append_raw(state_key, elements_data))
+      else:
+        raise ValueError("Unknown access pattern: '%s'" % si.urn)
+
   def _run_bundle_multiple_times_for_testing(
       self,
       runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
-      bundle_manager,  # type: BundleManager
+      bundle_context_manager,  # type: execution.BundleContextManager
       data_input,
       data_output,  # type: DataOutput
       fired_timers,
       expected_output_timers,
+      cache_token_generator
   ):
     # type: (...) -> None
 
@@ -354,12 +403,18 @@ class FnApiRunner(runner.PipelineRunner):
     for _ in range(self._bundle_repeat):
       try:
         runner_execution_context.state_servicer.checkpoint()
-        bundle_manager.process_bundle(
-            data_input,
-            data_output,
-            fired_timers,
-            expected_output_timers,
-            dry_run=True)
+        testing_bundle_manager = ParallelBundleManager(
+            bundle_context_manager.worker_handlers,
+            lambda pcoll_id, transform_id: ListBuffer(
+                coder_impl=bundle_context_manager.get_input_coder_impl(
+                    transform_id)),
+            bundle_context_manager.get_input_coder_impl,
+            bundle_context_manager.process_bundle_descriptor,
+            self._progress_frequency,
+            num_workers=self._num_workers,
+            cache_token_generator=cache_token_generator)
+        testing_bundle_manager.process_bundle(
+            data_input, data_output, fired_timers, expected_output_timers)
       finally:
         runner_execution_context.state_servicer.restore()
 
@@ -389,17 +444,6 @@ class FnApiRunner(runner.PipelineRunner):
         fired_timers[(transform_id, timer_family_id)].append(out.get())
         written_timers.clear()
 
-  def _add_sdk_delayed_applications_to_deferred_inputs(
-      self, bundle_context_manager, bundle_result, deferred_inputs):
-    for delayed_application in bundle_result.process_bundle.residual_roots:
-      name = bundle_context_manager.input_for(
-          delayed_application.application.transform_id,
-          delayed_application.application.input_id)
-      if name not in deferred_inputs:
-        deferred_inputs[name] = ListBuffer(
-            coder_impl=bundle_context_manager.get_input_coder_impl(name))
-      deferred_inputs[name].append(delayed_application.application.element)
-
   def _add_residuals_and_channel_splits_to_deferred_inputs(
       self,
       splits,  # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse]
@@ -463,115 +507,164 @@ class FnApiRunner(runner.PipelineRunner):
     Args:
       runner_execution_context (execution.FnApiRunnerExecutionContext): An
         object containing execution information for the pipeline.
-      bundle_context_manager (execution.BundleContextManager): A description of
-        the stage to execute, and its context.
+      bundle_context_manager (execution.BundleContextManager): A description of the stage to execute, and its context.
     """
-    data_input, data_output, expected_timer_output = (
-        bundle_context_manager.extract_bundle_inputs_and_outputs())
-    input_timers = {}
-
+    worker_handler_list = bundle_context_manager.worker_handlers
     worker_handler_manager = runner_execution_context.worker_handler_manager
     _LOGGER.info('Running %s', bundle_context_manager.stage.name)
+    (data_input, data_side_input, data_output,
+     expected_timer_output) = self._extract_endpoints(
+         bundle_context_manager, runner_execution_context)
     worker_handler_manager.register_process_bundle_descriptor(
         bundle_context_manager.process_bundle_descriptor)
 
-    # We create the bundle manager here, as it can be reused for bundles of the
-    # same stage, but it may have to be created by-bundle later on.
+    # Store the required side inputs into state so they are accessible to the
+    # worker when it runs this bundle.
+    self._store_side_inputs_in_state(runner_execution_context, data_side_input)
+
+    # Change cache token across bundle repeats
     cache_token_generator = FnApiRunner.get_cache_token_generator(static=False)
-    bundle_manager = ParallelBundleManager(
+
+    self._run_bundle_multiple_times_for_testing(
+        runner_execution_context,
         bundle_context_manager,
+        data_input,
+        data_output, {},
+        expected_timer_output,
+        cache_token_generator=cache_token_generator)
+
+    bundle_manager = ParallelBundleManager(
+        worker_handler_list,
+        bundle_context_manager.get_buffer,
+        bundle_context_manager.get_input_coder_impl,
+        bundle_context_manager.process_bundle_descriptor,
         self._progress_frequency,
+        num_workers=self._num_workers,
         cache_token_generator=cache_token_generator)
 
-    final_result = None
+    # On the first round of processing, we don't have fired timers as inputs.
+    result, splits = bundle_manager.process_bundle(data_input,
+                                                   data_output,
+                                                   {},
+                                                   expected_timer_output)
 
-    def merge_results(last_result):
-      """ Merge the latest result with other accumulated results. """
-      return (
-          last_result
-          if final_result is None else beam_fn_api_pb2.InstructionResponse(
-              process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
-                  monitoring_infos=monitoring_infos.consolidate(
-                      itertools.chain(
-                          final_result.process_bundle.monitoring_infos,
-                          last_result.process_bundle.monitoring_infos))),
-              error=final_result.error or last_result.error))
+    last_result = result
+    last_sent = data_input
 
+    # We cannot split deferred_inputs until residual_roots are included in
+    # the merged results; without them the pipeline would stop early and
+    # could miss some data.
+    # We also don't partition fired timer inputs for the same reason.
+    bundle_manager._num_workers = 1
     while True:
-      last_result, deferred_inputs, fired_timers = self._run_bundle(
-              runner_execution_context,
-              bundle_context_manager,
-              data_input,
-              data_output,
-              input_timers,
-              expected_timer_output,
-              bundle_manager)
-
-      final_result = merge_results(last_result)
-      if not deferred_inputs and not fired_timers:
-        break
+      deferred_inputs = {}  # type: Dict[str, PartitionableBuffer]
+      fired_timers = {}
+
+      self._collect_written_timers_and_add_to_fired_timers(
+          bundle_context_manager, fired_timers)
+      # Queue any process-initiated delayed bundle applications.
+      for delayed_application in last_result.process_bundle.residual_roots:
+        name = bundle_context_manager.input_for(
+            delayed_application.application.transform_id,
+            delayed_application.application.input_id)
+        if name not in deferred_inputs:
+          deferred_inputs[name] = ListBuffer(
+              coder_impl=bundle_context_manager.get_input_coder_impl(name))
+        deferred_inputs[name].append(delayed_application.application.element)
+      # Queue any runner-initiated delayed bundle applications.
+      self._add_residuals_and_channel_splits_to_deferred_inputs(
+          splits, bundle_context_manager, last_sent, deferred_inputs)
+
+      if deferred_inputs or fired_timers:
+        # The worker will be waiting on these inputs as well.
+        for other_input in data_input:
+          if other_input not in deferred_inputs:
+            deferred_inputs[other_input] = ListBuffer(
+                coder_impl=bundle_context_manager.get_input_coder_impl(
+                    other_input))
+        # TODO(robertwb): merge results
+        last_result, splits = bundle_manager.process_bundle(
+            deferred_inputs, data_output, fired_timers, expected_timer_output)
+        last_sent = deferred_inputs
+        result = beam_fn_api_pb2.InstructionResponse(
+            process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
+                monitoring_infos=monitoring_infos.consolidate(
+                    itertools.chain(
+                        result.process_bundle.monitoring_infos,
+                        last_result.process_bundle.monitoring_infos))),
+            error=result.error or last_result.error)
       else:
-        data_input = deferred_inputs
-        input_timers = fired_timers
-        bundle_manager._registered = True
+        break
 
-    # Store the required downstream side inputs into state so it is accessible
-    # for the worker when it runs bundles that consume this stage's output.
-    data_side_input = (
-        runner_execution_context.side_input_descriptors_by_stage.get(
-            bundle_context_manager.stage.name, {}))
-    runner_execution_context.commit_side_inputs_to_state(data_side_input)
+    return result
 
-    return final_result
+  @staticmethod
+  def _extract_endpoints(bundle_context_manager,  # type: execution.BundleContextManager
+                         runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
+                         ):
+    # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataSideInput, DataOutput, Dict[Tuple[str, str], str]]
 
-  def _run_bundle(
-      self,
-      runner_execution_context,
-      bundle_context_manager,
-      data_input,
-      data_output,
-      input_timers,
-      expected_timer_output,
-      bundle_manager):
-    """Execute a bundle, and return a result object, and deferred inputs."""
-    self._run_bundle_multiple_times_for_testing(
-        runner_execution_context,
-        bundle_manager,
-        data_input,
-        data_output,
-        input_timers,
-        expected_timer_output)
-
-    result, splits = bundle_manager.process_bundle(
-        data_input, data_output, input_timers, expected_timer_output)
-    # Now we collect all the deferred inputs remaining from bundle execution.
-    # Deferred inputs can be:
-    # - timers
-    # - SDK-initiated deferred applications of root elements
-    # - Runner-initiated deferred applications of root elements
-    deferred_inputs = {}  # type: Dict[str, execution.PartitionableBuffer]
-    fired_timers = {}
-
-    self._collect_written_timers_and_add_to_fired_timers(
-        bundle_context_manager, fired_timers)
-
-    self._add_sdk_delayed_applications_to_deferred_inputs(
-        bundle_context_manager, result, deferred_inputs)
-
-    self._add_residuals_and_channel_splits_to_deferred_inputs(
-        splits, bundle_context_manager, data_input, deferred_inputs)
-
-    # After collecting deferred inputs, we 'pad' the structure with empty
-    # buffers for other expected inputs.
-    if deferred_inputs or fired_timers:
-      # The worker will be waiting on these inputs as well.
-      for other_input in data_input:
-        if other_input not in deferred_inputs:
-          deferred_inputs[other_input] = ListBuffer(
-              coder_impl=bundle_context_manager.get_input_coder_impl(
-                  other_input))
-
-    return result, deferred_inputs, fired_timers
+    """Returns maps of transform names to PCollection identifiers.
+
+    Also mutates IO stages to point to the data ApiServiceDescriptor.
+
+    Args:
+      bundle_context_manager (execution.BundleContextManager): The stage
+        context to extract endpoints for.
+      runner_execution_context: Execution information for the pipeline.
+    Returns:
+      A tuple of (data_input, data_side_input, data_output,
+        expected_timer_output). `data_input` maps transform names to input
+        buffers; `data_output` maps transform names to PCollection IDs;
+        `expected_timer_output` maps (transform_id, timer_family_id) to buffer IDs.
+    """
+    data_input = {}  # type: Dict[str, PartitionableBuffer]
+    data_side_input = {}  # type: DataSideInput
+    data_output = {}  # type: DataOutput
+    # A mapping of {(transform_id, timer_family_id) : buffer_id}
+    expected_timer_output = {}  # type: Dict[Tuple[str, str], str]
+    for transform in bundle_context_manager.stage.transforms:
+      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
+                                bundle_processor.DATA_OUTPUT_URN):
+        pcoll_id = transform.spec.payload
+        if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
+          coder_id = runner_execution_context.data_channel_coders[only_element(
+              transform.outputs.values())]
+          coder = runner_execution_context.pipeline_context.coders[
+              runner_execution_context.safe_coders.get(coder_id, coder_id)]
+          if pcoll_id == translations.IMPULSE_BUFFER:
+            data_input[transform.unique_name] = ListBuffer(
+                coder_impl=coder.get_impl())
+            data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE)
+          else:
+            if pcoll_id not in runner_execution_context.pcoll_buffers:
+              runner_execution_context.pcoll_buffers[pcoll_id] = ListBuffer(
+                  coder_impl=coder.get_impl())
+            data_input[transform.unique_name] = (
+                runner_execution_context.pcoll_buffers[pcoll_id])
+        elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
+          data_output[transform.unique_name] = pcoll_id
+          coder_id = runner_execution_context.data_channel_coders[only_element(
+              transform.inputs.values())]
+        else:
+          raise NotImplementedError
+        data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
+        data_api_service_descriptor = (
+            bundle_context_manager.data_api_service_descriptor())
+        if data_api_service_descriptor:
+          data_spec.api_service_descriptor.url = (
+              data_api_service_descriptor.url)
+        transform.spec.payload = data_spec.SerializeToString()
+      elif transform.spec.urn in translations.PAR_DO_URNS:
+        payload = proto_utils.parse_Bytes(
+            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
+        for tag, si in payload.side_inputs.items():
+          data_side_input[transform.unique_name, tag] = (
+              create_buffer_id(transform.inputs[tag]), si.access_pattern)
+        for timer_family_id in payload.timer_family_specs.keys():
+          expected_timer_output[(transform.unique_name, timer_family_id)] = (
+              create_buffer_id(timer_family_id, 'timers'))
+    return data_input, data_side_input, data_output, expected_timer_output
 
   @staticmethod
   def get_cache_token_generator(static=True):
@@ -672,18 +765,28 @@ class BundleManager(object):
   _lock = threading.Lock()
 
   def __init__(self,
-               bundle_context_manager,  # type: execution.BundleContextManager
+               worker_handler_list,  # type: Sequence[WorkerHandler]
+               get_buffer,  # type: Callable[[bytes, str], PartitionableBuffer]
+               get_input_coder_impl,  # type: Callable[[str], CoderImpl]
+               bundle_descriptor,  # type: beam_fn_api_pb2.ProcessBundleDescriptor
                progress_frequency=None,
                cache_token_generator=FnApiRunner.get_cache_token_generator()
               ):
     """Set up a bundle manager.
 
     Args:
+      worker_handler_list: WorkerHandlers that bundles may be sent to.
+      get_buffer (Callable[[bytes, str], PartitionableBuffer])
+      get_input_coder_impl (Callable[[str], CoderImpl])
+      bundle_descriptor (beam_fn_api_pb2.ProcessBundleDescriptor)
       progress_frequency
     """
-    self.bundle_context_manager = bundle_context_manager  # type: execution.BundleContextManager
+    self._worker_handler_list = worker_handler_list
+    self._get_buffer = get_buffer
+    self._get_input_coder_impl = get_input_coder_impl
+    self._bundle_descriptor = bundle_descriptor
     self._progress_frequency = progress_frequency
-    self._worker_handler = None  # type: Optional[execution.WorkerHandler]
+    self._worker_handler = None  # type: Optional[WorkerHandler]
     self._cache_token_generator = cache_token_generator
 
   def _send_input_to_worker(self,
@@ -711,8 +814,7 @@ class BundleManager(object):
   def _select_split_manager(self):
     """TODO(pabloem) WHAT DOES THIS DO"""
     unique_names = set(
-        t.unique_name for t in self.bundle_context_manager.
-        process_bundle_descriptor.transforms.values())
+        t.unique_name for t in self._bundle_descriptor.transforms.values())
     for stage_name, candidate in reversed(_split_managers):
       if (stage_name in unique_names or
           (stage_name + '/Process') in unique_names):
@@ -733,8 +835,8 @@ class BundleManager(object):
     byte_stream = b''.join(buffer_data)
     num_elements = len(
         list(
-            self.bundle_context_manager.get_input_coder_impl(
-                read_transform_id).decode_all(byte_stream)))
+            self._get_input_coder_impl(read_transform_id).decode_all(
+                byte_stream)))
 
     # Start the split manager in case it wants to set any breakpoints.
     split_manager_generator = split_manager(num_elements)
@@ -787,20 +889,18 @@ class BundleManager(object):
     return split_results
 
   def process_bundle(self,
-                     inputs,  # type: Mapping[str, execution.PartitionableBuffer]
+                     inputs,  # type: Mapping[str, PartitionableBuffer]
                      expected_outputs,  # type: DataOutput
-                     fired_timers,  # type: Mapping[Tuple[str, str], execution.PartitionableBuffer]
-                     expected_output_timers,  # type: Dict[Tuple[str, str], str]
-                     dry_run=False,
+                     fired_timers,  # type: Mapping[Tuple[str, str], PartitionableBuffer]
+                     expected_output_timers  # type: Dict[Tuple[str, str], str]
                     ):
     # type: (...) -> BundleProcessResult
     # Unique id for the instruction processing this bundle.
     with BundleManager._lock:
       BundleManager._uid_counter += 1
       process_bundle_id = 'bundle_%s' % BundleManager._uid_counter
-      self._worker_handler = self.bundle_context_manager.worker_handlers[
-          BundleManager._uid_counter %
-          len(self.bundle_context_manager.worker_handlers)]
+      self._worker_handler = self._worker_handler_list[
+          BundleManager._uid_counter % len(self._worker_handler_list)]
 
     split_manager = self._select_split_manager()
     if not split_manager:
@@ -820,8 +920,7 @@ class BundleManager(object):
     process_bundle_req = beam_fn_api_pb2.InstructionRequest(
         instruction_id=process_bundle_id,
         process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
-            process_bundle_descriptor_id=self.bundle_context_manager.
-            process_bundle_descriptor.id,
+            process_bundle_descriptor_id=self._bundle_descriptor.id,
             cache_tokens=[next(self._cache_token_generator)]))
     result_future = self._worker_handler.control_conn.push(process_bundle_req)
 
@@ -843,15 +942,15 @@ class BundleManager(object):
           expect_reads,
           abort_callback=lambda:
           (result_future.is_done() and result_future.get().error)):
-        if isinstance(output, beam_fn_api_pb2.Elements.Timers) and not dry_run:
+        if isinstance(output, beam_fn_api_pb2.Elements.Timers):
           with BundleManager._lock:
-            self.bundle_context_manager.get_buffer(
+            self._get_buffer(
                 expected_output_timers[(
                     output.transform_id, output.timer_family_id)],
                 output.transform_id).append(output.timers)
-        if isinstance(output, beam_fn_api_pb2.Elements.Data) and not dry_run:
+        if isinstance(output, beam_fn_api_pb2.Elements.Data):
           with BundleManager._lock:
-            self.bundle_context_manager.get_buffer(
+            self._get_buffer(
                 expected_outputs[output.transform_id],
                 output.transform_id).append(output.data)
 
@@ -874,32 +973,32 @@ class ParallelBundleManager(BundleManager):
 
   def __init__(
       self,
-      bundle_context_manager,  # type: execution.BundleContextManager
+      worker_handler_list,  # type: Sequence[WorkerHandler]
+      get_buffer,  # type: Callable[[bytes, str], PartitionableBuffer]
+      get_input_coder_impl,  # type: Callable[[str], CoderImpl]
+      bundle_descriptor,  # type: beam_fn_api_pb2.ProcessBundleDescriptor
       progress_frequency=None,
       cache_token_generator=None,
       **kwargs):
     # type: (...) -> None
     super(ParallelBundleManager, self).__init__(
-        bundle_context_manager,
+        worker_handler_list,
+        get_buffer,
+        get_input_coder_impl,
+        bundle_descriptor,
         progress_frequency,
         cache_token_generator=cache_token_generator)
-    self._num_workers = bundle_context_manager.num_workers
+    self._num_workers = kwargs.pop('num_workers', 1)
 
   def process_bundle(self,
-                     inputs,  # type: Mapping[str, execution.PartitionableBuffer]
+                     inputs,  # type: Mapping[str, PartitionableBuffer]
                      expected_outputs,  # type: DataOutput
-                     fired_timers,  # type: Mapping[Tuple[str, str], execution.PartitionableBuffer]
-                     expected_output_timers,  # type: Dict[Tuple[str, str], str]
-                     dry_run=False,
-                    ):
+                     fired_timers,  # type: Mapping[Tuple[str, str], PartitionableBuffer]
+                     expected_output_timers  # type: Dict[Tuple[str, str], str]
+                     ):
     # type: (...) -> BundleProcessResult
     part_inputs = [{} for _ in range(self._num_workers)
                    ]  # type: List[Dict[str, List[bytes]]]
-    # Timers are only executed on the first worker
-    # TODO(BEAM-9741): Split timers to multiple workers
-    timer_inputs = [
-        fired_timers if i == 0 else {} for i in range(self._num_workers)
-    ]
     for name, input in inputs.items():
       for ix, part in enumerate(input.partition(self._num_workers)):
         part_inputs[ix][name] = part
@@ -908,23 +1007,21 @@ class ParallelBundleManager(BundleManager):
     split_result_list = [
     ]  # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse]
 
-    def execute(part_map_input_timers):
+    def execute(part_map):
       # type: (...) -> BundleProcessResult
-      part_map, input_timers = part_map_input_timers
       bundle_manager = BundleManager(
-          self.bundle_context_manager,
+          self._worker_handler_list,
+          self._get_buffer,
+          self._get_input_coder_impl,
+          self._bundle_descriptor,
           self._progress_frequency,
           cache_token_generator=self._cache_token_generator)
       return bundle_manager.process_bundle(
-          part_map,
-          expected_outputs,
-          input_timers,
-          expected_output_timers,
-          dry_run)
+          part_map, expected_outputs, fired_timers, expected_output_timers)
 
     with UnboundedThreadPoolExecutor() as executor:
-      for result, split_result in executor.map(execute, zip(part_inputs,  # pylint: disable=zip-builtin-not-iterating
-                                                            timer_inputs)):
+      for result, split_result in executor.map(execute, part_inputs):
+
         split_result_list += split_result
         if merged_result is None:
           merged_result = result
@@ -937,6 +1034,7 @@ class ParallelBundleManager(BundleManager):
                           merged_result.process_bundle.monitoring_infos))),
               error=result.error or merged_result.error)
     assert merged_result is not None
+
     return merged_result, split_result_list
 
 
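The restored _store_side_inputs_in_state, like the removed
commit_side_inputs_to_state above, writes side-input data into runner state
under one of two key shapes chosen by the access-pattern URN. A hedged
sketch of those keys; the transform id, tag, and encoded window/key bytes
below are placeholder values:

    from apache_beam.portability.api import beam_fn_api_pb2

    # Iterable side input: keyed by consuming transform, tag, and window.
    iterable_key = beam_fn_api_pb2.StateKey(
        iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
            transform_id='ParDo(ConsumerDoFn)',  # placeholder
            side_input_id='side0',               # placeholder tag
            window=b'<encoded window>'))         # placeholder bytes

    # Multimap side input: additionally keyed by the encoded element key.
    multimap_key = beam_fn_api_pb2.StateKey(
        multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
            transform_id='ParDo(ConsumerDoFn)',
            side_input_id='side0',
            window=b'<encoded window>',
            key=b'<encoded key>'))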
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
index ce99d87..0ff0853 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
@@ -240,30 +240,6 @@ class FnApiRunnerTest(unittest.TestCase):
               lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)),
           equal_to([('a', [1, 3]), ('b', [2])]))
 
-  def test_multimap_multiside_input(self):
-    # A test where two transforms in the same stage consume the same PCollection
-    # twice as side input.
-    with self.create_pipeline() as p:
-      main = p | 'main' >> beam.Create(['a', 'b'])
-      side = (
-          p | 'side' >> beam.Create([('a', 1), ('b', 2), ('a', 3)])
-          # TODO(BEAM-4782): Obviate the need for this map.
-          | beam.Map(lambda kv: (kv[0], kv[1])))
-      assert_that(
-          main | 'first map' >> beam.Map(
-              lambda k,
-              d,
-              l: (k, sorted(d[k]), sorted([e[1] for e in l])),
-              beam.pvalue.AsMultiMap(side),
-              beam.pvalue.AsList(side))
-          | 'second map' >> beam.Map(
-              lambda k,
-              d,
-              l: (k[0], sorted(d[k[0]]), sorted([e[1] for e in l])),
-              beam.pvalue.AsMultiMap(side),
-              beam.pvalue.AsList(side)),
-          equal_to([('a', [1, 3], [1, 2, 3]), ('b', [2], [1, 2, 3])]))
-
   def test_multimap_side_input_type_coercion(self):
     with self.create_pipeline() as p:
       main = p | 'main' >> beam.Create(['a', 'b'])
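
The removed test_multimap_multiside_input covered the case of one
PCollection consumed twice as a side input within a single stage. For
orientation, here is a self-contained version of the simpler
single-consumption pattern that the surviving tests exercise (runnable on
the default runner):

    import apache_beam as beam
    from apache_beam.testing.util import assert_that
    from apache_beam.testing.util import equal_to

    with beam.Pipeline() as p:
        main = p | 'main' >> beam.Create(['a', 'b'])
        side = p | 'side' >> beam.Create([('a', 1), ('b', 2), ('a', 3)])
        # Each main element looks itself up in the side-input multimap.
        assert_that(
            main | beam.Map(
                lambda k, d: (k, sorted(d[k])),
                beam.pvalue.AsMultiMap(side)),
            equal_to([('a', [1, 3]), ('b', [2])]))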
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
index 235aec8..5d18d29 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
@@ -75,17 +75,6 @@ PAR_DO_URNS = frozenset([
 
 IMPULSE_BUFFER = b'impulse'
 
-# SideInputId is identified by a consumer ParDo + tag.
-SideInputId = Tuple[str, str]
-SideInputAccessPattern = beam_runner_api_pb2.FunctionSpec
-
-DataOutput = Dict[str, bytes]
-
-# DataSideInput maps SideInputIds to a tuple of the encoded bytes of the side
-# input content, and a payload specification regarding the type of side input
-# (MultiMap / Iterable).
-DataSideInput = Dict[SideInputId, Tuple[bytes, SideInputAccessPattern]]
-
 
 class Stage(object):
   """A set of Transforms that can be sent to the worker for processing."""