You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by pa...@apache.org on 2018/08/17 23:16:58 UTC

[beam] branch master updated: Pipeline Graph from Interactive Beam -- made faster

This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 0e14965  Pipeline Graph from Interactive Beam -- made faster
0e14965 is described below

commit 0e14965707b5d48a3de7fa69f09d88ef0aa48c09
Author: Sindy Li <qi...@umich.edu>
AuthorDate: Tue Aug 14 16:31:05 2018 -0700

    Pipeline Graph from Interactive Beam -- made faster
    
    * Optimization
    Replaced string-based filtering of top-level PTransforms with a direct
    lookup: top-level PTransforms are now found by iterating over the
    subtransforms of the root PTransforms. This makes PipelineGraph faster.
    
    * Generalization
        * Moved the display_graph() method to PipelineGraph
        * PipelineGraph now accepts either a Pipeline object or a Pipeline proto
---
 .../interactive/interactive_pipeline_graph.py      | 34 +++--------
 .../runners/interactive/pipeline_graph.py          | 68 +++++++++++++++++-----
 2 files changed, 60 insertions(+), 42 deletions(-)

diff --git a/sdks/python/apache_beam/runners/interactive/interactive_pipeline_graph.py b/sdks/python/apache_beam/runners/interactive/interactive_pipeline_graph.py
index 229848c..2ad7c1b 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_pipeline_graph.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_pipeline_graph.py
@@ -53,33 +53,27 @@ class InteractivePipelineGraph(pipeline_graph.PipelineGraph):
   """Creates the DOT representation of an interactive pipeline. Thread-safe."""
 
   def __init__(self,
-               pipeline_proto,
+               pipeline,
                required_transforms=None,
                referenced_pcollections=None,
                cached_pcollections=None):
     """Constructor of PipelineGraph.
 
-    Examples:
-      pipeline_graph = PipelineGraph(pipeline_proto)
-      print(pipeline_graph.get_dot())
-      pipeline_graph.display_graph()
-
     Args:
-      pipeline_proto: (Pipeline proto) Pipeline to be rendered.
+      pipeline: (Pipeline proto) or (Pipeline) pipeline to be rendered.
       required_transforms: (dict from str to PTransform proto) Mapping from
           transform ID to transforms that leads to visible results.
       referenced_pcollections: (dict from str to PCollection proto) PCollection
           ID mapped to PCollection referenced during pipeline execution.
-      cached_pcollections: (set of str) A set of PCollection IDs of those whose
+      cached_pcollections: (set of str) a set of PCollection IDs of those whose
           cached results are used in the execution.
     """
-    self._pipeline_proto = pipeline_proto
     self._required_transforms = required_transforms or {}
     self._referenced_pcollections = referenced_pcollections or {}
     self._cached_pcollections = cached_pcollections or set()
 
     super(InteractivePipelineGraph, self).__init__(
-        pipeline_proto=pipeline_proto,
+        pipeline=pipeline,
         default_vertex_attrs={'color': 'gray', 'fontcolor': 'gray'},
         default_edge_attrs={'color': 'gray'}
     )
@@ -87,14 +81,6 @@ class InteractivePipelineGraph(pipeline_graph.PipelineGraph):
     transform_updates, pcollection_updates = self._generate_graph_update_dicts()
     self._update_graph(transform_updates, pcollection_updates)
 
-  def display_graph(self):
-    """Displays graph via IPython or prints DOT if not possible."""
-    try:
-      from IPython.core import display  # pylint: disable=import-error
-      display.display(display.HTML(self._get_graph().create_svg()))  # pylint: disable=protected-access
-    except ImportError:
-      print(str(self._get_graph()))
-
   def update_pcollection_stats(self, pcollection_stats):
     """Updates PCollection stats.
 
@@ -123,21 +109,15 @@ class InteractivePipelineGraph(pipeline_graph.PipelineGraph):
       vertex_dict: (Dict[str, Dict[str, str]]) maps vertex name to attributes
       edge_dict: (Dict[str, Dict[str, str]]) maps vertex name to attributes
     """
-    transforms = self._pipeline_proto.components.transforms
-
     transform_dict = {}  # maps PTransform IDs to properties
     pcoll_dict = {}  # maps PCollection IDs to properties
 
-    for transform_id, transform in transforms.items():
-      if not super(
-          InteractivePipelineGraph, self)._is_top_level_transform(transform):
-        continue
-
-      transform_dict[transform.unique_name] = {
+    for transform_id, transform_proto in self._top_level_transforms():
+      transform_dict[transform_proto.unique_name] = {
           'required': transform_id in self._required_transforms
       }
 
-      for pcoll_id in transform.outputs.values():
+      for pcoll_id in transform_proto.outputs.values():
         pcoll_dict[pcoll_id] = {
             'cached': pcoll_id in self._cached_pcollections,
             'referenced': pcoll_id in self._referenced_pcollections
diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_graph.py b/sdks/python/apache_beam/runners/interactive/pipeline_graph.py
index a100a75..737c89d 100644
--- a/sdks/python/apache_beam/runners/interactive/pipeline_graph.py
+++ b/sdks/python/apache_beam/runners/interactive/pipeline_graph.py
@@ -29,38 +29,63 @@ import threading
 
 import pydot
 
+import apache_beam as beam
+from apache_beam.portability.api import beam_runner_api_pb2
+
 
 class PipelineGraph(object):
   """Creates a DOT representation of the pipeline. Thread-safe."""
 
   def __init__(self,
-               pipeline_proto,
+               pipeline,
                default_vertex_attrs=None,
                default_edge_attrs=None):
     """Constructor of PipelineGraph.
 
+    Examples:
+      graph = pipeline_graph.PipelineGraph(pipeline_proto)
+      graph.display_graph()
+
+      or
+
+      graph = pipeline_graph.PipelineGraph(pipeline)
+      graph.display_graph()
+
     Args:
-      pipeline_proto: (Pipeline proto)
+      pipeline: (Pipeline proto) or (Pipeline) pipeline to be rendered.
       default_vertex_attrs: (Dict[str, str]) a dict of default vertex attributes
       default_edge_attrs: (Dict[str, str]) a dict of default edge attributes
     """
     self._lock = threading.Lock()
     self._graph = None
 
+    if isinstance(pipeline, beam_runner_api_pb2.Pipeline):
+      self._pipeline_proto = pipeline
+    elif isinstance(pipeline, beam.Pipeline):
+      self._pipeline_proto = pipeline.to_runner_api()
+    else:
+      raise TypeError('pipeline should either be a %s or %s, while %s is given'
+                      % (beam_runner_api_pb2.Pipeline, beam.Pipeline,
+                         type(pipeline)))
+
     # A dict from PCollection ID to a list of its consuming Transform IDs
     self._consumers = collections.defaultdict(list)
     # A dict from PCollection ID to its producing Transform ID
     self._producers = {}
 
-    transforms = pipeline_proto.components.transforms
-    for transform_id, transform in transforms.items():
-      if not self._is_top_level_transform(transform):
-        continue
-      for pcoll_id in transform.inputs.values():
+    for transform_id, transform_proto in self._top_level_transforms():
+      for pcoll_id in transform_proto.inputs.values():
         self._consumers[pcoll_id].append(transform_id)
-      for pcoll_id in transform.outputs.values():
+      for pcoll_id in transform_proto.outputs.values():
         self._producers[pcoll_id] = transform_id
 
+    # Set the default vertex color to blue.
+    default_vertex_attrs = default_vertex_attrs or {}
+    if 'color' not in default_vertex_attrs:
+      default_vertex_attrs['color'] = 'blue'
+    if 'fontcolor' not in default_vertex_attrs:
+      default_vertex_attrs['fontcolor'] = 'blue'
+
     vertex_dict, edge_dict = self._generate_graph_dicts()
     self._construct_graph(vertex_dict,
                           edge_dict,
@@ -70,9 +95,25 @@ class PipelineGraph(object):
   def get_dot(self):
     return str(self._get_graph())
 
-  def _is_top_level_transform(self, transform):
-    return transform.unique_name and '/' not in transform.unique_name \
-        and not transform.unique_name.startswith('ref_')
+  def display_graph(self):
+    """Displays graph via IPython or prints DOT if not possible."""
+    try:
+      from IPython.core import display  # pylint: disable=import-error
+      display.display(display.HTML(self._get_graph().create_svg()))  # pylint: disable=protected-access
+    except ImportError:
+      print(str(self._get_graph()))
+
+  def _top_level_transforms(self):
+    """Yields all top-level PTransforms (subtransforms of the root PTransforms).
+
+    Yields: (str, PTransform proto) ID, proto pair of top level PTransforms.
+    """
+    transforms = self._pipeline_proto.components.transforms
+    for root_transform_id in self._pipeline_proto.root_transform_ids:
+      root_transform_proto = transforms[root_transform_id]
+      for top_level_transform_id in root_transform_proto.subtransforms:
+        top_level_transform_proto = transforms[top_level_transform_id]
+        yield top_level_transform_id, top_level_transform_proto
 
   def _generate_graph_dicts(self):
     """From pipeline_proto and other info, generate the graph.
@@ -92,10 +133,7 @@ class PipelineGraph(object):
 
     self._edge_to_vertex_pairs = collections.defaultdict(list)
 
-    for _, transform in transforms.items():
-      if not self._is_top_level_transform(transform):
-        continue
-
+    for _, transform in self._top_level_transforms():
       vertex_dict[transform.unique_name] = {}
 
       for pcoll_id in transform.outputs.values():