Posted to commits@beam.apache.org by ro...@apache.org on 2019/01/31 13:42:14 UTC

[beam] branch master updated: [BEAM-6243] Add an experiment to use Python's optimizer on Flink.

This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 06456f8  [BEAM-6243] Add an experiment to use Python's optimizer on Flink.
     new ee85ea1  Merge pull request #7297  [BEAM-6243] Add an experiment to use Python's optimizer on Flink.
06456f8 is described below

commit 06456f8b137d87c07c8b7d26b5d4be448e17affb
Author: Robert Bradshaw <ro...@google.com>
AuthorDate: Thu Dec 13 17:09:08 2018 +0100

    [BEAM-6243] Add an experiment to use Python's optimizer on Flink.
---
 .../FlinkBatchPortablePipelineTranslator.java      |   3 +-
 .../beam/runners/flink/FlinkJobInvocation.java     |   8 +-
 .../jobsubmission/InMemoryJobService.java          |   2 +
 .../python/apache_beam/options/pipeline_options.py |  10 +
 .../runners/portability/flink_runner_test.py       |   7 +-
 .../runners/portability/fn_api_runner.py           |  52 ++--
 .../portability/fn_api_runner_transforms.py        | 296 +++++++++++++++++----
 .../runners/portability/portable_runner.py         |  41 ++-
 sdks/python/build.gradle                           |  11 +-
 9 files changed, 341 insertions(+), 89 deletions(-)
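
In short, this change lets the Python SDK pre-fuse and optimize a portable pipeline before handing it to the Flink runner, gated behind a new pre_optimize experiment. A minimal sketch of how a user might opt in (the job endpoint below is a placeholder; only the --experiments=pre_optimize=... value is introduced by this commit):

    from apache_beam.options.pipeline_options import PipelineOptions

    options = PipelineOptions([
        '--runner=PortableRunner',
        '--job_endpoint=localhost:8099',   # placeholder job server endpoint
        '--experiments=pre_optimize=all',  # run the Python optimizer before submission
    ])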

diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
index 8f9fcb9..496d737 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
@@ -412,7 +412,8 @@ public class FlinkBatchPortablePipelineTranslator
 
     for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
       String collectionId =
-          components
+          stagePayload
+              .getComponents()
               .getTransformsOrThrow(sideInputId.getTransformId())
               .getInputsOrThrow(sideInputId.getLocalName());
       // Register under the global PCollection name. Only ExecutableStageFunction needs to know the
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkJobInvocation.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkJobInvocation.java
index efaa40a..d3e95dd 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkJobInvocation.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkJobInvocation.java
@@ -34,6 +34,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.PipelineOptionsTranslation;
+import org.apache.beam.runners.core.construction.graph.ExecutableStage;
 import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
 import org.apache.beam.runners.fnexecution.jobsubmission.JobInvocation;
 import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
@@ -127,7 +128,12 @@ public class FlinkJobInvocation implements JobInvocation {
                 ImmutableSet.of(PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN)));
 
     // Fused pipeline proto.
-    RunnerApi.Pipeline fusedPipeline = GreedyPipelineFuser.fuse(trimmedPipeline).toPipeline();
+    // TODO: Consider supporting partially-fused graphs.
+    RunnerApi.Pipeline fusedPipeline =
+        trimmedPipeline.getComponents().getTransformsMap().values().stream()
+                .anyMatch(proto -> ExecutableStage.URN.equals(proto.getSpec().getUrn()))
+            ? pipeline
+            : GreedyPipelineFuser.fuse(pipeline).toPipeline();
     JobInfo jobInfo =
         JobInfo.create(
             id,
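
The guard above skips Java-side fusion when the submitted pipeline already contains executable stages, i.e. it was pre-fused by the SDK. A rough Python equivalent of that check, assuming pipeline_proto is the submitted pipeline proto (the URN matches the one produced by the Python optimizer later in this commit):

    already_fused = any(
        t.spec.urn == 'beam:runner:executable_stage:v1'
        for t in pipeline_proto.components.transforms.values())
    # If already_fused, the Flink runner uses the pipeline as-is instead of
    # running GreedyPipelineFuser over it.
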
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/jobsubmission/InMemoryJobService.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/jobsubmission/InMemoryJobService.java
index 545641a..69ce491 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/jobsubmission/InMemoryJobService.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/jobsubmission/InMemoryJobService.java
@@ -169,7 +169,9 @@ public class InMemoryJobService extends JobServiceGrpc.JobServiceImplBase implem
       try {
         PipelineValidator.validate(preparation.pipeline());
       } catch (Exception e) {
+        LOG.warn("Encountered Unexpected Exception during validation", e);
         responseObserver.onError(new StatusRuntimeException(Status.INVALID_ARGUMENT.withCause(e)));
+        return;
       }
 
       // create new invocation
diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py
index ce55d9b..a7c74b4 100644
--- a/sdks/python/apache_beam/options/pipeline_options.py
+++ b/sdks/python/apache_beam/options/pipeline_options.py
@@ -582,6 +582,16 @@ class DebugOptions(PipelineOptions):
          'enabled with this flag. Please sync with the owners of the runner '
          'before enabling any experiments.'))
 
+  def lookup_experiment(self, key, default=None):
+    if not self.experiments:
+      return default
+    elif key in self.experiments:
+      return True
+    for experiment in self.experiments:
+      if experiment.startswith(key + '='):
+        return experiment.split('=', 1)[1]
+    return default
+
 
 class ProfilingOptions(PipelineOptions):
 
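The new lookup_experiment helper treats a bare experiment name as a boolean flag and a key=value entry as a string value. A short sketch of the resulting behaviour, assuming experiments are passed via the usual --experiments flag:

    from apache_beam.options.pipeline_options import DebugOptions

    options = DebugOptions(
        ['--experiments=beam_fn_api', '--experiments=pre_optimize=all'])
    options.lookup_experiment('beam_fn_api')           # True  (bare flag)
    options.lookup_experiment('pre_optimize', 'none')  # 'all' (value after '=')
    options.lookup_experiment('missing', 'default')    # 'default' (not set)
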
diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index 297dfb2..2ae23d4 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -44,6 +44,7 @@ if __name__ == '__main__':
   #     --flink_job_server_jar=/path/to/job_server.jar \
   #     --type=Batch \
   #     --environment_type=docker \
+  #     --extra_experiments=beam_experiments \
   #     [FlinkRunnerTest.test_method, ...]
 
   parser = argparse.ArgumentParser(add_help=True)
@@ -54,6 +55,8 @@ if __name__ == '__main__':
   parser.add_argument('--environment_type', default='docker',
                       help='Environment type. docker or process')
   parser.add_argument('--environment_config', help='Environment config.')
+  parser.add_argument('--extra_experiments', default=[], action='append',
+                      help='Beam experiments config.')
   known_args, args = parser.parse_known_args(sys.argv)
   sys.argv = args
 
@@ -62,6 +65,7 @@ if __name__ == '__main__':
   environment_type = known_args.environment_type.lower()
   environment_config = (
       known_args.environment_config if known_args.environment_config else None)
+  extra_experiments = known_args.extra_experiments
 
   # This is defined here to only be run when we invoke this file explicitly.
   class FlinkRunnerTest(portable_runner_test.PortableRunnerTest):
@@ -127,7 +131,8 @@ if __name__ == '__main__':
 
     def create_options(self):
       options = super(FlinkRunnerTest, self).create_options()
-      options.view_as(DebugOptions).experiments = ['beam_fn_api']
+      options.view_as(DebugOptions).experiments = [
+          'beam_fn_api'] + extra_experiments
       options._all_options['parallelism'] = 1
       options._all_options['shutdown_sources_on_final_watermark'] = True
       options.view_as(PortableOptions).environment_type = (
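
The new --extra_experiments test flag simply extends the experiments that create_options already sets; the resulting options are roughly equivalent to this sketch (the experiment value shown is an example, not a default):

    from apache_beam.options.pipeline_options import DebugOptions, PipelineOptions

    extra_experiments = ['pre_optimize=all']  # what --extra_experiments would carry
    options = PipelineOptions()
    options.view_as(DebugOptions).experiments = ['beam_fn_api'] + extra_experiments
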
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index 1253857..be8799b 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -317,39 +317,27 @@ class FnApiRunner(runner.PipelineRunner):
       yield
 
   def create_stages(self, pipeline_proto):
-
-    pipeline_context = fn_api_runner_transforms.TransformContext(
-        copy.deepcopy(pipeline_proto.components),
+    return fn_api_runner_transforms.create_and_optimize_stages(
+        copy.deepcopy(pipeline_proto),
+        phases=[fn_api_runner_transforms.annotate_downstream_side_inputs,
+                fn_api_runner_transforms.fix_side_input_pcoll_coders,
+                fn_api_runner_transforms.lift_combiners,
+                fn_api_runner_transforms.expand_gbk,
+                fn_api_runner_transforms.sink_flattens,
+                fn_api_runner_transforms.greedily_fuse,
+                fn_api_runner_transforms.read_to_impulse,
+                fn_api_runner_transforms.impulse_to_input,
+                fn_api_runner_transforms.inject_timer_pcollections,
+                fn_api_runner_transforms.sort_stages,
+                fn_api_runner_transforms.window_pcollection_coders],
+        known_runner_urns=frozenset([
+            common_urns.primitives.FLATTEN.urn,
+            common_urns.primitives.GROUP_BY_KEY.urn]),
         use_state_iterables=self._use_state_iterables)
 
-    # Initial set of stages are singleton leaf transforms.
-    stages = list(fn_api_runner_transforms.leaf_transform_stages(
-        pipeline_proto.root_transform_ids,
-        pipeline_proto.components))
-
-    # Apply each phase in order.
-    for phase in [
-        fn_api_runner_transforms.annotate_downstream_side_inputs,
-        fn_api_runner_transforms.fix_side_input_pcoll_coders,
-        fn_api_runner_transforms.lift_combiners,
-        fn_api_runner_transforms.expand_gbk,
-        fn_api_runner_transforms.sink_flattens,
-        fn_api_runner_transforms.greedily_fuse,
-        fn_api_runner_transforms.read_to_impulse,
-        fn_api_runner_transforms.impulse_to_input,
-        fn_api_runner_transforms.inject_timer_pcollections,
-        fn_api_runner_transforms.sort_stages,
-        fn_api_runner_transforms.window_pcollection_coders]:
-      logging.info('%s %s %s', '=' * 20, phase, '=' * 20)
-      stages = list(phase(stages, pipeline_context))
-      logging.debug('Stages: %s', [str(s) for s in stages])
-
-    # Return the (possibly mutated) context and ordered set of stages.
-    return pipeline_context.components, stages, pipeline_context.safe_coders
-
-  def run_stages(self, pipeline_components, stages, safe_coders):
+  def run_stages(self, stage_context, stages):
     worker_handler_manager = WorkerHandlerManager(
-        pipeline_components.environments, self._provision_info)
+        stage_context.components.environments, self._provision_info)
     metrics_by_stage = {}
     monitoring_infos_by_stage = {}
 
@@ -359,10 +347,10 @@ class FnApiRunner(runner.PipelineRunner):
         for stage in stages:
           stage_results = self.run_stage(
               worker_handler_manager.get_worker_handler,
-              pipeline_components,
+              stage_context.components,
               stage,
               pcoll_buffers,
-              safe_coders)
+              stage_context.safe_coders)
           metrics_by_stage[stage.name] = stage_results.process_bundle.metrics
           monitoring_infos_by_stage[stage.name] = (
               stage_results.process_bundle.monitoring_infos)
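
create_stages now delegates to create_and_optimize_stages, which returns a (context, stages) pair, and run_stages reads components and safe_coders off that context instead of taking the old 3-tuple. A rough sketch of the new call shape, with pipeline_proto assumed in scope and an abbreviated phase list:

    import copy
    from apache_beam.runners.portability import fn_api_runner_transforms

    stage_context, stages = fn_api_runner_transforms.create_and_optimize_stages(
        copy.deepcopy(pipeline_proto),
        phases=[fn_api_runner_transforms.lift_combiners,
                fn_api_runner_transforms.sort_stages],
        known_runner_urns=frozenset())
    # stage_context.components and stage_context.safe_coders are what
    # run_stages(stage_context, stages) consumes.
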
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
index 8152c0c..21f8fa2 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
@@ -46,7 +46,7 @@ class Stage(object):
   """A set of Transforms that can be sent to the worker for processing."""
   def __init__(self, name, transforms,
                downstream_side_inputs=None, must_follow=frozenset(),
-               parent=None, environment=None):
+               parent=None, environment=None, forced_root=False):
     self.name = name
     self.transforms = transforms
     self.downstream_side_inputs = downstream_side_inputs
@@ -58,6 +58,7 @@ class Stage(object):
           self._merge_environments,
           (self._extract_environment(t) for t in transforms))
     self.environment = environment
+    self.forced_root = forced_root
 
   def __repr__(self):
     must_follow = ', '.join(prev.name for prev in self.must_follow)
@@ -103,7 +104,7 @@ class Stage(object):
             str(env2).replace('\n', ' ')))
       return env1
 
-  def can_fuse(self, consumer):
+  def can_fuse(self, consumer, context):
     try:
       self._merge_environments(self.environment, consumer.environment)
     except ValueError:
@@ -113,8 +114,10 @@ class Stage(object):
       return not a.intersection(b)
 
     return (
-        not self in consumer.must_follow
-        and not self.is_flatten() and not consumer.is_flatten()
+        not consumer.forced_root
+        and not self in consumer.must_follow
+        and not self.is_runner_urn(context)
+        and not consumer.is_runner_urn(context)
         and no_overlap(self.downstream_side_inputs, consumer.side_inputs()))
 
   def fuse(self, other):
@@ -124,10 +127,12 @@ class Stage(object):
         union(self.downstream_side_inputs, other.downstream_side_inputs),
         union(self.must_follow, other.must_follow),
         environment=self._merge_environments(
-            self.environment, other.environment))
+            self.environment, other.environment),
+        parent=self.parent if self.parent == other.parent else None,
+        forced_root=self.forced_root or other.forced_root)
 
-  def is_flatten(self):
-    return any(transform.spec.urn == common_urns.primitives.FLATTEN.urn
+  def is_runner_urn(self, context):
+    return any(transform.spec.urn in context.known_runner_urns
                for transform in self.transforms)
 
   def side_inputs(self):
@@ -162,6 +167,89 @@ class Stage(object):
       new_transforms.append(transform)
     self.transforms = new_transforms
 
+  def executable_stage_transform(
+      self, known_runner_urns, all_consumers, components):
+    if (len(self.transforms) == 1
+        and self.transforms[0].spec.urn in known_runner_urns):
+      return self.transforms[0]
+
+    else:
+      all_inputs = set(
+          pcoll for t in self.transforms for pcoll in t.inputs.values())
+      all_outputs = set(
+          pcoll for t in self.transforms for pcoll in t.outputs.values())
+      internal_transforms = set(id(t) for t in self.transforms)
+      external_outputs = [pcoll for pcoll in all_outputs
+                          if all_consumers[pcoll] - internal_transforms]
+
+      stage_components = beam_runner_api_pb2.Components()
+      stage_components.CopyFrom(components)
+
+      # Only keep the referenced PCollections.
+      for pcoll_id in stage_components.pcollections.keys():
+        if pcoll_id not in all_inputs and pcoll_id not in all_outputs:
+          del stage_components.pcollections[pcoll_id]
+
+      # Only keep the transforms in this stage.
+      # Also gather up payload data as we iterate over the transforms.
+      stage_components.transforms.clear()
+      main_inputs = set()
+      side_inputs = []
+      user_states = []
+      timers = []
+      for ix, transform in enumerate(self.transforms):
+        transform_id = 'transform_%d' % ix
+        if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
+          payload = proto_utils.parse_Bytes(
+              transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
+          for tag in payload.side_inputs.keys():
+            side_inputs.append(
+                beam_runner_api_pb2.ExecutableStagePayload.SideInputId(
+                    transform_id=transform_id,
+                    local_name=tag))
+          for tag in payload.state_specs.keys():
+            user_states.append(
+                beam_runner_api_pb2.ExecutableStagePayload.UserStateId(
+                    transform_id=transform_id,
+                    local_name=tag))
+          for tag in payload.timer_specs.keys():
+            timers.append(
+                beam_runner_api_pb2.ExecutableStagePayload.TimerId(
+                    transform_id=transform_id,
+                    local_name=tag))
+          main_inputs.update(
+              pcoll_id
+              for tag, pcoll_id in transform.inputs.items()
+              if tag not in payload.side_inputs)
+        else:
+          main_inputs.update(transform.inputs.values())
+        stage_components.transforms[transform_id].CopyFrom(transform)
+
+      main_input_id = only_element(main_inputs - all_outputs)
+      named_inputs = dict({
+          '%s:%s' % (side.transform_id, side.local_name):
+          stage_components.transforms[side.transform_id].inputs[side.local_name]
+          for side in side_inputs
+      }, main_input=main_input_id)
+      payload = beam_runner_api_pb2.ExecutableStagePayload(
+          environment=components.environments[self.environment],
+          input=main_input_id,
+          outputs=external_outputs,
+          transforms=stage_components.transforms.keys(),
+          components=stage_components,
+          side_inputs=side_inputs,
+          user_states=user_states,
+          timers=timers)
+
+      return beam_runner_api_pb2.PTransform(
+          unique_name=unique_name(None, self.name),
+          spec=beam_runner_api_pb2.FunctionSpec(
+              urn='beam:runner:executable_stage:v1',
+              payload=payload.SerializeToString()),
+          inputs=named_inputs,
+          outputs={'output_%d' % ix: pcoll
+                   for ix, pcoll in enumerate(external_outputs)})
+
 
 def memoize_on_instance(f):
   missing = object()
@@ -185,8 +273,9 @@ class TransformContext(object):
   _KNOWN_CODER_URNS = set(
       value.urn for value in common_urns.coders.__dict__.values())
 
-  def __init__(self, components, use_state_iterables=False):
+  def __init__(self, components, known_runner_urns, use_state_iterables=False):
     self.components = components
+    self.known_runner_urns = known_runner_urns
     self.use_state_iterables = use_state_iterables
     self.safe_coders = {}
     self.bytes_coder_id = self.add_or_get_coder_id(
@@ -296,7 +385,8 @@ def leaf_transform_stages(
         yield stage
 
 
-def with_stages(pipeline_proto, stages):
+def pipeline_from_stages(
+    pipeline_proto, stages, known_runner_urns, partial):
 
   # In case it was a generator that mutates components as it
   # produces outputs (as is the case with most transformations).
@@ -307,6 +397,7 @@ def with_stages(pipeline_proto, stages):
   components = new_proto.components
   components.transforms.clear()
 
+  roots = set()
   parents = {
       child: parent
       for parent, proto in pipeline_proto.components.transforms.items()
@@ -314,24 +405,83 @@ def with_stages(pipeline_proto, stages):
   }
 
   def add_parent(child, parent):
-    if parent not in components.transforms:
-      components.transforms[parent].CopyFrom(
-          pipeline_proto.components.transforms[parent])
-      del components.transforms[parent].subtransforms[:]
-      if parent in parents:
-        add_parent(parent, parents[parent])
-    components.transforms[parent].subtransforms.append(child)
-
+    if parent is None:
+      roots.add(child)
+    else:
+      if parent not in components.transforms:
+        components.transforms[parent].CopyFrom(
+            pipeline_proto.components.transforms[parent])
+        del components.transforms[parent].subtransforms[:]
+        add_parent(parent, parents.get(parent))
+      components.transforms[parent].subtransforms.append(child)
+
+  all_consumers = collections.defaultdict(set)
   for stage in stages:
     for transform in stage.transforms:
-      id = unique_name(components.transforms, stage.name)
-      components.transforms[id].CopyFrom(transform)
-      if stage.parent:
-        add_parent(id, stage.parent)
+      for pcoll in transform.inputs.values():
+        all_consumers[pcoll].add(id(transform))
+
+  for stage in stages:
+    if partial:
+      transform = only_element(stage.transforms)
+    else:
+      transform = stage.executable_stage_transform(
+          known_runner_urns, all_consumers, components)
+    transform_id = unique_name(components.transforms, stage.name)
+    components.transforms[transform_id].CopyFrom(transform)
+    add_parent(transform_id, stage.parent)
+
+  del new_proto.root_transform_ids[:]
+  new_proto.root_transform_ids.extend(roots)
 
   return new_proto
 
 
+def create_and_optimize_stages(
+    pipeline_proto,
+    phases,
+    known_runner_urns,
+    use_state_iterables=False):
+  pipeline_context = TransformContext(
+      pipeline_proto.components,
+      known_runner_urns,
+      use_state_iterables=use_state_iterables)
+
+  # Initial set of stages are singleton leaf transforms.
+  stages = list(leaf_transform_stages(
+      pipeline_proto.root_transform_ids,
+      pipeline_proto.components,
+      union(known_runner_urns, KNOWN_COMPOSITES)))
+
+  # Apply each phase in order.
+  for phase in phases:
+    logging.info('%s %s %s', '=' * 20, phase, '=' * 20)
+    stages = list(phase(stages, pipeline_context))
+    logging.debug('%s %s' % (len(stages), [len(s.transforms) for s in stages]))
+    logging.debug('Stages: %s', [str(s) for s in stages])
+
+  # Return the (possibly mutated) context and ordered set of stages.
+  return pipeline_context, stages
+
+
+def optimize_pipeline(
+    pipeline_proto,
+    phases,
+    known_runner_urns,
+    partial=False,
+    **kwargs):
+  unused_context, stages = create_and_optimize_stages(
+      pipeline_proto,
+      phases,
+      known_runner_urns,
+      **kwargs)
+  return pipeline_from_stages(
+      pipeline_proto, stages, known_runner_urns, partial)
+
+
+# Optimization stages.
+
+
 def annotate_downstream_side_inputs(stages, pipeline_context):
   """Annotate each stage with fusion-prohibiting information.
 
@@ -379,6 +529,17 @@ def annotate_downstream_side_inputs(stages, pipeline_context):
   return stages
 
 
+def annotate_stateful_dofns_as_roots(stages, pipeline_context):
+  for stage in stages:
+    for transform in stage.transforms:
+      if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
+        pardo_payload = proto_utils.parse_Bytes(
+            transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
+        if pardo_payload.state_specs or pardo_payload.timer_specs:
+          stage.forced_root = True
+    yield stage
+
+
 def fix_side_input_pcoll_coders(stages, pipeline_context):
   """Length prefix side input PCollection coders.
   """
@@ -569,37 +730,30 @@ def expand_gbk(stages, pipeline_context):
       yield stage
 
 
-def sink_flattens(stages, pipeline_context):
-  """Sink flattens and remove them from the graph.
-
-  A flatten that cannot be sunk/fused away becomes multiple writes (to the
-  same logical sink) followed by a read.
+def fix_flatten_coders(stages, pipeline_context):
+  """Ensures that the inputs of Flatten have the same coders as the output.
   """
-  # TODO(robertwb): Actually attempt to sink rather than always materialize.
-  # TODO(robertwb): Possibly fuse this into one of the stages.
   pcollections = pipeline_context.components.pcollections
   for stage in stages:
-    assert len(stage.transforms) == 1
-    transform = stage.transforms[0]
+    transform = only_element(stage.transforms)
     if transform.spec.urn == common_urns.primitives.FLATTEN.urn:
-      # This is used later to correlate the read and writes.
-      buffer_id = create_buffer_id(transform.unique_name)
-      output_pcoll_id, = list(transform.outputs.values())
+      output_pcoll_id = only_element(transform.outputs.values())
       output_coder_id = pcollections[output_pcoll_id].coder_id
-      flatten_writes = []
-      for local_in, pcoll_in in transform.inputs.items():
-
+      for local_in, pcoll_in in list(transform.inputs.items()):
         if pcollections[pcoll_in].coder_id != output_coder_id:
           # Flatten requires that all its inputs be materialized with the
           # same coder as its output.  Add stages to transcode flatten
           # inputs that use different coders.
-          transcoded_pcollection = (
+          transcoded_pcollection = unique_name(
+              pcollections,
               transform.unique_name + '/Transcode/' + local_in + '/out')
+          transcode_name = unique_name(
+              pipeline_context.components.transforms,
+              transform.unique_name + '/Transcode/' + local_in)
           yield Stage(
-              transform.unique_name + '/Transcode/' + local_in,
+              transcode_name,
               [beam_runner_api_pb2.PTransform(
-                  unique_name=
-                  transform.unique_name + '/Transcode/' + local_in,
+                  unique_name=transcode_name,
                   inputs={local_in: pcoll_in},
                   outputs={'out': transcoded_pcollection},
                   spec=beam_runner_api_pb2.FunctionSpec(
@@ -608,15 +762,34 @@ def sink_flattens(stages, pipeline_context):
               must_follow=stage.must_follow)
           pcollections[transcoded_pcollection].CopyFrom(
               pcollections[pcoll_in])
+          pcollections[transcoded_pcollection].unique_name = (
+              transcoded_pcollection)
           pcollections[transcoded_pcollection].coder_id = output_coder_id
-        else:
-          transcoded_pcollection = pcoll_in
+          transform.inputs[local_in] = transcoded_pcollection
+
+    yield stage
+
+
+def sink_flattens(stages, pipeline_context):
+  """Sink flattens and remove them from the graph.
 
+  A flatten that cannot be sunk/fused away becomes multiple writes (to the
+  same logical sink) followed by a read.
+  """
+  # TODO(robertwb): Actually attempt to sink rather than always materialize.
+  # TODO(robertwb): Possibly fuse this into one of the stages.
+  for stage in fix_flatten_coders(stages, pipeline_context):
+    transform = only_element(stage.transforms)
+    if transform.spec.urn == common_urns.primitives.FLATTEN.urn:
+      # This is used later to correlate the read and writes.
+      buffer_id = create_buffer_id(transform.unique_name)
+      flatten_writes = []
+      for local_in, pcoll_in in transform.inputs.items():
         flatten_write = Stage(
             transform.unique_name + '/Write/' + local_in,
             [beam_runner_api_pb2.PTransform(
                 unique_name=transform.unique_name + '/Write/' + local_in,
-                inputs={local_in: transcoded_pcollection},
+                inputs={local_in: pcoll_in},
                 spec=beam_runner_api_pb2.FunctionSpec(
                     urn=bundle_processor.DATA_OUTPUT_URN,
                     payload=buffer_id))],
@@ -684,7 +857,7 @@ def greedily_fuse(stages, pipeline_context):
       # Update consumer.must_follow set, as it's used in can_fuse.
       consumer.must_follow = frozenset(
           replacement(s) for s in consumer.must_follow)
-      if producer.can_fuse(consumer):
+      if producer.can_fuse(consumer, pipeline_context):
         fuse(producer, consumer)
       else:
         # If we can't fuse, do a read + write.
@@ -697,7 +870,8 @@ def greedily_fuse(stages, pipeline_context):
                   inputs={'in': pcoll},
                   spec=beam_runner_api_pb2.FunctionSpec(
                       urn=bundle_processor.DATA_OUTPUT_URN,
-                      payload=buffer_id))])
+                      payload=buffer_id))],
+              downstream_side_inputs=producer.downstream_side_inputs)
           fuse(producer, write_pcoll)
         if consumer.has_as_main_input(pcoll):
           read_pcoll = Stage(
@@ -708,6 +882,7 @@ def greedily_fuse(stages, pipeline_context):
                   spec=beam_runner_api_pb2.FunctionSpec(
                       urn=bundle_processor.DATA_INPUT_URN,
                       payload=buffer_id))],
+              downstream_side_inputs=consumer.downstream_side_inputs,
               must_follow=frozenset([write_pcoll]))
           fuse(read_pcoll, consumer)
         else:
@@ -781,6 +956,34 @@ def impulse_to_input(stages, pipeline_context):
     yield stage
 
 
+def extract_impulse_stages(stages, pipeline_context):
+  """Splits fused Impulse operations into their own stage."""
+  for stage in stages:
+    for transform in list(stage.transforms):
+      if transform.spec.urn == common_urns.primitives.IMPULSE.urn:
+        stage.transforms.remove(transform)
+        yield Stage(
+            transform.unique_name,
+            transforms=[transform],
+            downstream_side_inputs=stage.downstream_side_inputs,
+            must_follow=stage.must_follow,
+            parent=stage.parent)
+
+    if stage.transforms:
+      yield stage
+
+
+def remove_data_plane_ops(stages, pipeline_context):
+  for stage in stages:
+    for transform in list(stage.transforms):
+      if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
+                                bundle_processor.DATA_OUTPUT_URN):
+        stage.transforms.remove(transform)
+
+    if stage.transforms:
+      yield stage
+
+
 def inject_timer_pcollections(stages, pipeline_context):
   """Create PCollections for fired timers and to-be-set timers.
 
@@ -858,12 +1061,15 @@ def inject_timer_pcollections(stages, pipeline_context):
 def sort_stages(stages, pipeline_context):
   """Order stages suitable for sequential execution.
   """
+  all_stages = set(stages)
   seen = set()
   ordered = []
 
   def process(stage):
     if stage not in seen:
       seen.add(stage)
+      if stage not in all_stages:
+        return
       for prev in stage.must_follow:
         process(prev)
       ordered.append(stage)
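
All of the optimization phases above share one contract: a callable taking (stages, pipeline_context) that returns or yields the transformed Stage objects, so additional phases can be slotted into a phases=[...] list. A minimal illustrative pass-through phase (not part of this commit):

    import logging

    def log_stage_names(stages, pipeline_context):
      # Pass-through phase matching the (stages, context) contract used by
      # create_and_optimize_stages; logs each stage name and yields it unchanged.
      for stage in stages:
        logging.debug('stage: %s', stage.name)
        yield stage
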
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py b/sdks/python/apache_beam/runners/portability/portable_runner.py
index 42bca33..87f05d3 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner.py
@@ -28,6 +28,7 @@ from concurrent import futures
 import grpc
 
 from apache_beam import metrics
+from apache_beam.options.pipeline_options import DebugOptions
 from apache_beam.options.pipeline_options import PortableOptions
 from apache_beam.options.pipeline_options import SetupOptions
 from apache_beam.options.pipeline_options import StandardOptions
@@ -161,6 +162,7 @@ class PortableRunner(runner.PipelineRunner):
       portable_options.environment_config, server = (
           BeamFnExternalWorkerPoolServicer.start(
               sdk_worker_main._get_worker_count(options)))
+      globals()['x'] = server
       cleanup_callbacks = [functools.partial(server.stop, 1)]
     else:
       cleanup_callbacks = []
@@ -180,14 +182,39 @@ class PortableRunner(runner.PipelineRunner):
 
     # Preemptively apply combiner lifting, until all runners support it.
     # This optimization is idempotent.
+    pre_optimize = options.view_as(DebugOptions).lookup_experiment(
+        'pre_optimize', 'combine').lower()
     if not options.view_as(StandardOptions).streaming:
-      stages = list(fn_api_runner_transforms.leaf_transform_stages(
-          proto_pipeline.root_transform_ids, proto_pipeline.components))
-      stages = fn_api_runner_transforms.lift_combiners(
-          stages,
-          fn_api_runner_transforms.TransformContext(proto_pipeline.components))
-      proto_pipeline = fn_api_runner_transforms.with_stages(
-          proto_pipeline, stages)
+      flink_known_urns = frozenset([
+          common_urns.composites.RESHUFFLE.urn,
+          common_urns.primitives.IMPULSE.urn,
+          common_urns.primitives.FLATTEN.urn,
+          common_urns.primitives.GROUP_BY_KEY.urn])
+      if pre_optimize == 'combine':
+        proto_pipeline = fn_api_runner_transforms.optimize_pipeline(
+            proto_pipeline,
+            phases=[fn_api_runner_transforms.lift_combiners],
+            known_runner_urns=flink_known_urns,
+            partial=True)
+      elif pre_optimize == 'all':
+        proto_pipeline = fn_api_runner_transforms.optimize_pipeline(
+            proto_pipeline,
+            phases=[fn_api_runner_transforms.annotate_downstream_side_inputs,
+                    fn_api_runner_transforms.annotate_stateful_dofns_as_roots,
+                    fn_api_runner_transforms.fix_side_input_pcoll_coders,
+                    fn_api_runner_transforms.lift_combiners,
+                    fn_api_runner_transforms.fix_flatten_coders,
+                    # fn_api_runner_transforms.sink_flattens,
+                    fn_api_runner_transforms.greedily_fuse,
+                    fn_api_runner_transforms.read_to_impulse,
+                    fn_api_runner_transforms.extract_impulse_stages,
+                    fn_api_runner_transforms.remove_data_plane_ops,
+                    fn_api_runner_transforms.sort_stages],
+            known_runner_urns=flink_known_urns)
+      elif pre_optimize == 'none':
+        pass
+      else:
+        raise ValueError('Unknown value for pre_optimize: %s' % pre_optimize)
 
     if not job_service:
       channel = grpc.insecure_channel(job_endpoint)
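
The pre_optimize experiment read above recognizes three values: 'combine' (the default: only lift combiners, with partial=True so single-transform stages are left unwrapped), 'all' (the full phase list including fusion), and 'none' (submit the pipeline untouched). A small sketch of reading the mode back out of the options:

    from apache_beam.options.pipeline_options import DebugOptions, PipelineOptions

    options = PipelineOptions(['--experiments=pre_optimize=all'])
    mode = options.view_as(DebugOptions).lookup_experiment(
        'pre_optimize', 'combine').lower()
    # mode == 'all' here; 'combine' is used when the experiment is absent.
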
diff --git a/sdks/python/build.gradle b/sdks/python/build.gradle
index aa0b9c7..b74b030 100644
--- a/sdks/python/build.gradle
+++ b/sdks/python/build.gradle
@@ -334,6 +334,9 @@ class CompatibilityMatrixConfig {
   enum SDK_WORKER_TYPE {
     DOCKER, PROCESS, LOOPBACK
   }
+
+  // Whether to pre-optimize the pipeline with the Python optimizer.
+  boolean preOptimize = false
 }
 
 def flinkCompatibilityMatrix = {
@@ -341,7 +344,10 @@ def flinkCompatibilityMatrix = {
   def workerType = config.workerType.name()
   def streaming = config.streaming
   def environment_config = config.workerType == CompatibilityMatrixConfig.SDK_WORKER_TYPE.PROCESS ? "--environment_config='{\"command\": \"${project(":beam-sdks-python:").buildDir.absolutePath}/sdk_worker.sh\"}'" : ""
-  def name = "flinkCompatibilityMatrix${streaming ? 'Streaming' : 'Batch'}${workerType}"
+  def name = "flinkCompatibilityMatrix${streaming ? 'Streaming' : 'Batch'}${config.preOptimize ? 'PreOptimize' : ''}${workerType}"
+  def extra_experiments = []
+  if (config.preOptimize)
+    extra_experiments.add('pre_optimize=all')
   tasks.create(name: name) {
     dependsOn 'setupVirtualenv'
     dependsOn ':beam-runners-flink_2.11-job-server:shadowJar'
@@ -352,7 +358,7 @@ def flinkCompatibilityMatrix = {
     doLast {
       exec {
         executable 'sh'
-        args '-c', ". ${project.ext.envdir}/bin/activate && pip install -e .[test] && python -m apache_beam.runners.portability.flink_runner_test --flink_job_server_jar=${project(":beam-runners-flink_2.11-job-server:").shadowJar.archivePath} --environment_type=${workerType} ${environment_config} ${streaming ? '--streaming' : ''}"
+        args '-c', ". ${project.ext.envdir}/bin/activate && pip install -e .[test] && python -m apache_beam.runners.portability.flink_runner_test --flink_job_server_jar=${project(":beam-runners-flink_2.11-job-server:").shadowJar.archivePath} --environment_type=${workerType} ${environment_config} ${streaming ? '--streaming' : ''} ${extra_experiments ? '--extra_experiments=' + extra_experiments.join(',') : ''}"
       }
     }
   }
@@ -371,6 +377,7 @@ task flinkCompatibilityMatrixProcess() {
 task flinkCompatibilityMatrixLoopback() {
   dependsOn flinkCompatibilityMatrix(streaming: false, workerType: CompatibilityMatrixConfig.SDK_WORKER_TYPE.LOOPBACK)
   dependsOn flinkCompatibilityMatrix(streaming: true, workerType: CompatibilityMatrixConfig.SDK_WORKER_TYPE.LOOPBACK)
+  dependsOn flinkCompatibilityMatrix(streaming: true, workerType: CompatibilityMatrixConfig.SDK_WORKER_TYPE.LOOPBACK, preOptimize: true)
 }
 
 task flinkValidatesRunner() {