You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ro...@apache.org on 2020/12/15 00:25:05 UTC

[beam] branch master updated: [BEAM-11196] Fix `None` parent when fusing >2 stages (#13549)

This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 52a2270  [BEAM-11196] Fix `None` parent when fusing >2 stages (#13549)
52a2270 is described below

commit 52a227020d289f7e091b3c42fdc7cf1485fbf547
Author: Yifan Mai <yi...@google.com>
AuthorDate: Mon Dec 14 16:24:12 2020 -0800

    [BEAM-11196] Fix `None` parent when fusing >2 stages (#13549)
---
 .../portability/fn_api_runner/translations.py      | 34 +++++++++++++---------
 .../portability/fn_api_runner/translations_test.py |  8 ++++-
 2 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
index 3126822..6e8f95e 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
@@ -174,7 +174,7 @@ class Stage(object):
         union(self.must_follow, other.must_follow),
         environment=self._merge_environments(
             self.environment, other.environment),
-        parent=_parent_for_fused_stages([self.name, other.name], context),
+        parent=_parent_for_fused_stages([self, other], context),
         forced_root=self.forced_root or other.forced_root)
 
   def is_runner_urn(self, context):
@@ -799,7 +799,7 @@ def eliminate_common_key_with_none(stages, context):
         only_element(stage.transforms[0].outputs.values())
         for stage in sibling_stages
     ]
-    parent = _parent_for_fused_stages([s.name for s in sibling_stages], context)
+    parent = _parent_for_fused_stages(sibling_stages, context)
     for to_delete_pcoll_id in output_pcoll_ids[1:]:
       pcoll_id_remap[to_delete_pcoll_id] = output_pcoll_ids[0]
       del context.components.pcollections[to_delete_pcoll_id]
@@ -1224,18 +1224,16 @@ def lift_combiners(stages, context):
       yield stage
 
 
-def _lowest_common_ancestor(a, b, context):
-  # type: (str, str, TransformContext) -> Optional[str]
+def _lowest_common_ancestor(a, b, parents):
+  # type: (str, str, Dict[str, str]) -> Optional[str]
 
   '''Returns the name of the lowest common ancestor of the two named stages.
 
-  The provided context is used to compute ancestors of stages. Note that stages
-  are considered to be ancestors of themselves.
+  The map of stage names to their parents' stage names should be provided
+  in parents. Note that stages are considered to be ancestors of themselves.
   '''
   assert a != b
 
-  parents = context.parents_map()
-
   def get_ancestors(name):
     ancestor = name
     while ancestor is not None:
@@ -1250,7 +1248,7 @@ def _lowest_common_ancestor(a, b, context):
 
 
 def _parent_for_fused_stages(stages, context):
-  # type: (Iterable[Optional[str]], TransformContext) -> Optional[str]
+  # type: (Iterable[Stage], TransformContext) -> Optional[str]
 
   '''Returns the name of the new parent for the fused stages.
 
@@ -1258,15 +1256,25 @@ def _parent_for_fused_stages(stages, context):
   contained in the set of stages to be fused. The provided context is used to
   compute ancestors of stages.
   '''
+
+  parents = context.parents_map()
+  # If an input stage was produced by fusion or an optimizer phase, or had
+  # its parent modified by an optimizer phase, its parent will not yet be
+  # reflected in context.parents_map(), so we need to add it to the
+  # parents map.
+  for stage in stages:
+    parents[stage.name] = stage.parent
+
   def reduce_fn(a, b):
     # type: (Optional[str], Optional[str]) -> Optional[str]
     if a is None or b is None:
       return None
-    return _lowest_common_ancestor(a, b, context)
+    return _lowest_common_ancestor(a, b, parents)
 
-  result = functools.reduce(reduce_fn, stages)
-  if result in stages:
-    result = context.parents_map().get(result)
+  stage_names = [stage.name for stage in stages]  # type: List[Optional[str]]
+  result = functools.reduce(reduce_fn, stage_names)
+  if result in stage_names:
+    result = parents.get(result)
   return result
 
 
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations_test.py
index 4c7643b..97bbfc2 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations_test.py
@@ -38,6 +38,7 @@ class TranslationsTest(unittest.TestCase):
       def expand(self, pcoll):
         _ = pcoll | 'key-with-none-a' >> beam.ParDo(core._KeyWithNone())
         _ = pcoll | 'key-with-none-b' >> beam.ParDo(core._KeyWithNone())
+        _ = pcoll | 'key-with-none-c' >> beam.ParDo(core._KeyWithNone())
 
     pipeline = beam.Pipeline()
     _ = pipeline | beam.Create(
@@ -57,6 +58,7 @@ class TranslationsTest(unittest.TestCase):
       def expand(self, pcoll):
         _ = pcoll | 'mean-perkey' >> combiners.Mean.PerKey()
         _ = pcoll | 'count-perkey' >> combiners.Count.PerKey()
+        _ = pcoll | 'largest-perkey' >> core.CombinePerKey(combiners.Largest(1))
 
     pipeline = beam.Pipeline()
     vals = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
@@ -84,6 +86,7 @@ class TranslationsTest(unittest.TestCase):
       def expand(self, pcoll):
         _ = pcoll | 'mean-perkey' >> combiners.Mean.PerKey()
         _ = pcoll | 'count-perkey' >> combiners.Count.PerKey()
+        _ = pcoll | 'largest-perkey' >> core.CombinePerKey(combiners.Largest(1))
 
     pipeline = beam.Pipeline()
     vals = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
@@ -100,7 +103,7 @@ class TranslationsTest(unittest.TestCase):
           combine_per_key_stages.append(stage)
     # Combiner packing should be skipped because the environment is missing
     # the beam:combinefn:packed_python:v1 capability.
-    self.assertEqual(len(combine_per_key_stages), 2)
+    self.assertEqual(len(combine_per_key_stages), 3)
     for combine_per_key_stage in combine_per_key_stages:
       self.assertNotIn('Packed', combine_per_key_stage.name)
       self.assertNotIn(
@@ -111,6 +114,8 @@ class TranslationsTest(unittest.TestCase):
       def expand(self, pcoll):
         _ = pcoll | 'mean-globally' >> combiners.Mean.Globally()
         _ = pcoll | 'count-globally' >> combiners.Count.Globally()
+        _ = pcoll | 'largest-globally' >> core.CombineGlobally(
+            combiners.Largest(1))
 
     pipeline = beam.Pipeline()
     vals = [6, 3, 1, 1, 9, 1, 5, 2, 0, 6]
@@ -176,6 +181,7 @@ class TranslationsTest(unittest.TestCase):
     pcoll = pipeline | Create(vals)
     _ = pcoll | 'mean-globally' >> combiners.Mean.Globally()
     _ = pcoll | 'count-globally' >> combiners.Count.Globally()
+    _ = pcoll | 'largest-globally' >> core.CombineGlobally(combiners.Largest(1))
     pipeline_proto = pipeline.to_runner_api()
     optimized_pipeline_proto = translations.optimize_pipeline(
         pipeline_proto,