You are viewing a plain-text version of this content. (The link to the canonical version is not preserved in this plain-text rendering.)
Posted to commits@beam.apache.org by pa...@apache.org on 2018/08/17 23:16:58 UTC
[beam] branch master updated: Pipeline Graph from Interactive Beam
-- made faster
This is an automated email from the ASF dual-hosted git repository.
pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 0e14965 Pipeline Graph from Interactive Beam -- made faster
0e14965 is described below
commit 0e14965707b5d48a3de7fa69f09d88ef0aa48c09
Author: Sindy Li <qi...@umich.edu>
AuthorDate: Tue Aug 14 16:31:05 2018 -0700
Pipeline Graph from Interactive Beam -- made faster
* Optimization
Replaced the string-manipulation filter for top-level PTransforms with a
direct lookup of the subtransforms of the root PTransforms. This makes
PipelineGraph faster.
* Generalization
* Moved display_graph() method to PipelineGraph
* PipelineGraph now accepts either a Pipeline object or a Pipeline proto
---
.../interactive/interactive_pipeline_graph.py | 34 +++--------
.../runners/interactive/pipeline_graph.py | 68 +++++++++++++++++-----
2 files changed, 60 insertions(+), 42 deletions(-)
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_pipeline_graph.py b/sdks/python/apache_beam/runners/interactive/interactive_pipeline_graph.py
index 229848c..2ad7c1b 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_pipeline_graph.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_pipeline_graph.py
@@ -53,33 +53,27 @@ class InteractivePipelineGraph(pipeline_graph.PipelineGraph):
"""Creates the DOT representation of an interactive pipeline. Thread-safe."""
def __init__(self,
- pipeline_proto,
+ pipeline,
required_transforms=None,
referenced_pcollections=None,
cached_pcollections=None):
"""Constructor of PipelineGraph.
- Examples:
- pipeline_graph = PipelineGraph(pipeline_proto)
- print(pipeline_graph.get_dot())
- pipeline_graph.display_graph()
-
Args:
- pipeline_proto: (Pipeline proto) Pipeline to be rendered.
+ pipeline: (Pipeline proto) or (Pipeline) pipeline to be rendered.
required_transforms: (dict from str to PTransform proto) Mapping from
transform ID to transforms that leads to visible results.
referenced_pcollections: (dict from str to PCollection proto) PCollection
ID mapped to PCollection referenced during pipeline execution.
- cached_pcollections: (set of str) A set of PCollection IDs of those whose
+ cached_pcollections: (set of str) a set of PCollection IDs of those whose
cached results are used in the execution.
"""
- self._pipeline_proto = pipeline_proto
self._required_transforms = required_transforms or {}
self._referenced_pcollections = referenced_pcollections or {}
self._cached_pcollections = cached_pcollections or set()
super(InteractivePipelineGraph, self).__init__(
- pipeline_proto=pipeline_proto,
+ pipeline=pipeline,
default_vertex_attrs={'color': 'gray', 'fontcolor': 'gray'},
default_edge_attrs={'color': 'gray'}
)
@@ -87,14 +81,6 @@ class InteractivePipelineGraph(pipeline_graph.PipelineGraph):
transform_updates, pcollection_updates = self._generate_graph_update_dicts()
self._update_graph(transform_updates, pcollection_updates)
- def display_graph(self):
- """Displays graph via IPython or prints DOT if not possible."""
- try:
- from IPython.core import display # pylint: disable=import-error
- display.display(display.HTML(self._get_graph().create_svg())) # pylint: disable=protected-access
- except ImportError:
- print(str(self._get_graph()))
-
def update_pcollection_stats(self, pcollection_stats):
"""Updates PCollection stats.
@@ -123,21 +109,15 @@ class InteractivePipelineGraph(pipeline_graph.PipelineGraph):
vertex_dict: (Dict[str, Dict[str, str]]) maps vertex name to attributes
edge_dict: (Dict[str, Dict[str, str]]) maps vertex name to attributes
"""
- transforms = self._pipeline_proto.components.transforms
-
transform_dict = {} # maps PTransform IDs to properties
pcoll_dict = {} # maps PCollection IDs to properties
- for transform_id, transform in transforms.items():
- if not super(
- InteractivePipelineGraph, self)._is_top_level_transform(transform):
- continue
-
- transform_dict[transform.unique_name] = {
+ for transform_id, transform_proto in self._top_level_transforms():
+ transform_dict[transform_proto.unique_name] = {
'required': transform_id in self._required_transforms
}
- for pcoll_id in transform.outputs.values():
+ for pcoll_id in transform_proto.outputs.values():
pcoll_dict[pcoll_id] = {
'cached': pcoll_id in self._cached_pcollections,
'referenced': pcoll_id in self._referenced_pcollections
diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_graph.py b/sdks/python/apache_beam/runners/interactive/pipeline_graph.py
index a100a75..737c89d 100644
--- a/sdks/python/apache_beam/runners/interactive/pipeline_graph.py
+++ b/sdks/python/apache_beam/runners/interactive/pipeline_graph.py
@@ -29,38 +29,63 @@ import threading
import pydot
+import apache_beam as beam
+from apache_beam.portability.api import beam_runner_api_pb2
+
class PipelineGraph(object):
"""Creates a DOT representation of the pipeline. Thread-safe."""
def __init__(self,
- pipeline_proto,
+ pipeline,
default_vertex_attrs=None,
default_edge_attrs=None):
"""Constructor of PipelineGraph.
+ Examples:
+ graph = pipeline_graph.PipelineGraph(pipeline_proto)
+ graph.display_graph()
+
+ or
+
+ graph = pipeline_graph.PipelineGraph(pipeline)
+ graph.display_graph()
+
Args:
- pipeline_proto: (Pipeline proto)
+ pipeline: (Pipeline proto) or (Pipeline) pipeline to be rendered.
default_vertex_attrs: (Dict[str, str]) a dict of default vertex attributes
default_edge_attrs: (Dict[str, str]) a dict of default edge attributes
"""
self._lock = threading.Lock()
self._graph = None
+ if isinstance(pipeline, beam_runner_api_pb2.Pipeline):
+ self._pipeline_proto = pipeline
+ elif isinstance(pipeline, beam.Pipeline):
+ self._pipeline_proto = pipeline.to_runner_api()
+ else:
+ raise TypeError('pipeline should either be a %s or %s, while %s is given'
+ % (beam_runner_api_pb2.Pipeline, beam.Pipeline,
+ type(pipeline)))
+
# A dict from PCollection ID to a list of its consuming Transform IDs
self._consumers = collections.defaultdict(list)
# A dict from PCollection ID to its producing Transform ID
self._producers = {}
- transforms = pipeline_proto.components.transforms
- for transform_id, transform in transforms.items():
- if not self._is_top_level_transform(transform):
- continue
- for pcoll_id in transform.inputs.values():
+ for transform_id, transform_proto in self._top_level_transforms():
+ for pcoll_id in transform_proto.inputs.values():
self._consumers[pcoll_id].append(transform_id)
- for pcoll_id in transform.outputs.values():
+ for pcoll_id in transform_proto.outputs.values():
self._producers[pcoll_id] = transform_id
+ # Set the default vertex color to blue.
+ default_vertex_attrs = default_vertex_attrs or {}
+ if 'color' not in default_vertex_attrs:
+ default_vertex_attrs['color'] = 'blue'
+ if 'fontcolor' not in default_vertex_attrs:
+ default_vertex_attrs['fontcolor'] = 'blue'
+
vertex_dict, edge_dict = self._generate_graph_dicts()
self._construct_graph(vertex_dict,
edge_dict,
@@ -70,9 +95,25 @@ class PipelineGraph(object):
def get_dot(self):
return str(self._get_graph())
- def _is_top_level_transform(self, transform):
- return transform.unique_name and '/' not in transform.unique_name \
- and not transform.unique_name.startswith('ref_')
+ def display_graph(self):
+ """Displays graph via IPython or prints DOT if not possible."""
+ try:
+ from IPython.core import display # pylint: disable=import-error
+ display.display(display.HTML(self._get_graph().create_svg())) # pylint: disable=protected-access
+ except ImportError:
+ print(str(self._get_graph()))
+
+ def _top_level_transforms(self):
+ """Yields all top level PTransforms (subtransforms of the root PTransform).
+
+ Yields: (str, PTransform proto) ID, proto pair of top level PTransforms.
+ """
+ transforms = self._pipeline_proto.components.transforms
+ for root_transform_id in self._pipeline_proto.root_transform_ids:
+ root_transform_proto = transforms[root_transform_id]
+ for top_level_transform_id in root_transform_proto.subtransforms:
+ top_level_transform_proto = transforms[top_level_transform_id]
+ yield top_level_transform_id, top_level_transform_proto
def _generate_graph_dicts(self):
"""From pipeline_proto and other info, generate the graph.
@@ -92,10 +133,7 @@ class PipelineGraph(object):
self._edge_to_vertex_pairs = collections.defaultdict(list)
- for _, transform in transforms.items():
- if not self._is_top_level_transform(transform):
- continue
-
+ for _, transform in self._top_level_transforms():
vertex_dict[transform.unique_name] = {}
for pcoll_id in transform.outputs.values():