You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/03/19 15:06:17 UTC

[airflow] 15/42: Make airflow dags show command display TaskGroup (#14269)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 3f36fa9692cbcb13caca5b4f23c78f1967171e50
Author: yuqian90 <yu...@gmail.com>
AuthorDate: Thu Feb 25 23:23:15 2021 +0800

    Make airflow dags show command display TaskGroup (#14269)
    
    closes: #13053
    
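    Make `airflow dags show` display TaskGroups: each non-root group is drawn as a Graphviz cluster, with join nodes for its upstream and downstream dependencies.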
    
    (cherry picked from commit c71f707d24a9196d33b91a7a2a9e3384698e5193)
---
 airflow/utils/dot_renderer.py    | 120 ++++++++++++++++++++++++++++++---------
 tests/utils/test_dot_renderer.py | 101 +++++++++++++++++++++++++++++++-
 2 files changed, 191 insertions(+), 30 deletions(-)

diff --git a/airflow/utils/dot_renderer.py b/airflow/utils/dot_renderer.py
index 990c7a7..4123f99 100644
--- a/airflow/utils/dot_renderer.py
+++ b/airflow/utils/dot_renderer.py
@@ -17,13 +17,17 @@
 # specific language governing permissions and limitations
 # under the License.
 """Renderer DAG (tasks and dependencies) to the graphviz object."""
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 import graphviz
 
 from airflow.models import TaskInstance
+from airflow.models.baseoperator import BaseOperator
 from airflow.models.dag import DAG
+from airflow.models.taskmixin import TaskMixin
 from airflow.utils.state import State
+from airflow.utils.task_group import TaskGroup
+from airflow.www.views import dag_edges
 
 
 def _refine_color(color: str):
@@ -42,6 +46,88 @@ def _refine_color(color: str):
     return color
 
 
+def _draw_task(
+    task: BaseOperator, parent_graph: graphviz.Digraph, states_by_task_id: Optional[Dict[str, str]]
+) -> None:
+    """Draw one task node, coloured by its TI state when states_by_task_id is given."""
+    if states_by_task_id is not None:
+        # An empty mapping still means "colour by state": tasks without a TI render as State.NONE.
+        state = states_by_task_id.get(task.task_id, State.NONE)
+        color, fill_color = State.color_fg(state), State.color(state)
+    else:
+        color, fill_color = task.ui_fgcolor, task.ui_color
+    parent_graph.node(
+        task.task_id,
+        _attributes={
+            "label": task.label,
+            "shape": "rectangle",
+            "style": "filled,rounded",
+            "color": _refine_color(color),
+            "fillcolor": _refine_color(fill_color),
+        },
+    )
+
+
+def _draw_task_group(
+    task_group: TaskGroup, parent_graph: graphviz.Digraph, states_by_task_id: Optional[Dict[str, str]]
+) -> None:
+    """
+    Draw the join nodes of task_group, then recursively draw its children.
+
+    Join nodes are small circles that edges entering or leaving the group attach to;
+    each one is drawn only on a side where the group actually has dependencies.
+
+    :param task_group: the TaskGroup whose contents are drawn
+    :param parent_graph: graph (or cluster subgraph) that receives the nodes
+    :param states_by_task_id: task-state lookup forwarded to child tasks; may be None
+    """
+
+    def _draw_join(join_id: str) -> None:
+        # Both joins share the group's colours, so build the node in one place.
+        parent_graph.node(
+            join_id,
+            _attributes={
+                "label": "",
+                "shape": "circle",
+                "style": "filled,rounded",
+                "color": _refine_color(task_group.ui_fgcolor),
+                "fillcolor": _refine_color(task_group.ui_color),
+                "width": "0.2",
+                "height": "0.2",
+            },
+        )
+
+    if task_group.upstream_group_ids or task_group.upstream_task_ids:
+        _draw_join(task_group.upstream_join_id)
+    if task_group.downstream_group_ids or task_group.downstream_task_ids:
+        _draw_join(task_group.downstream_join_id)
+    # Children are drawn sorted by label rather than in definition order.
+    for child in sorted(task_group.children.values(), key=lambda t: t.label):
+        _draw_nodes(child, parent_graph, states_by_task_id)
+
+
+def _draw_nodes(
+    node: TaskMixin, parent_graph: graphviz.Digraph, states_by_task_id: Optional[Dict[str, str]]
+) -> None:
+    """Recursively draw node (a task or a TaskGroup) and everything inside it on parent_graph."""
+    if isinstance(node, BaseOperator):
+        _draw_task(node, parent_graph, states_by_task_id)
+    elif node.is_root:
+        # The root TaskGroup is drawn straight onto parent_graph, with no background box.
+        _draw_task_group(node, parent_graph, states_by_task_id)
+    else:
+        # Graphviz renders subgraphs whose name starts with "cluster" as a boxed cluster.
+        with parent_graph.subgraph(name=f"cluster_{node.group_id}") as cluster:
+            cluster.attr(
+                shape="rectangle",
+                style="filled",
+                color=_refine_color(node.ui_fgcolor),
+                fillcolor="#6495ed7f",  # CornflowerBlue (#6495ed) at roughly 50% alpha
+                label=node.label,
+            )
+            _draw_task_group(node, cluster, states_by_task_id)
+
+
 def render_dag(dag: DAG, tis: Optional[List[TaskInstance]] = None) -> graphviz.Digraph:
     """
     Renders the DAG object to the DOT object.
@@ -66,30 +152,10 @@ def render_dag(dag: DAG, tis: Optional[List[TaskInstance]] = None) -> graphviz.D
     states_by_task_id = None
     if tis is not None:
         states_by_task_id = {ti.task_id: ti.state for ti in tis}
-    for task in dag.tasks:
-        node_attrs = {
-            "shape": "rectangle",
-            "style": "filled,rounded",
-        }
-        if states_by_task_id is None:
-            node_attrs.update(
-                {
-                    "color": _refine_color(task.ui_fgcolor),
-                    "fillcolor": _refine_color(task.ui_color),
-                }
-            )
-        else:
-            state = states_by_task_id.get(task.task_id, State.NONE)
-            node_attrs.update(
-                {
-                    "color": State.color_fg(state),
-                    "fillcolor": State.color(state),
-                }
-            )
-        dot.node(
-            task.task_id,
-            _attributes=node_attrs,
-        )
-        for downstream_task_id in task.downstream_task_ids:
-            dot.edge(task.task_id, downstream_task_id)
+
+    _draw_nodes(dag.task_group, dot, states_by_task_id)
+
+    for edge in dag_edges(dag):
+        dot.edge(edge["source_id"], edge["target_id"])
+
     return dot
diff --git a/tests/utils/test_dot_renderer.py b/tests/utils/test_dot_renderer.py
index b030623..ca3ea01 100644
--- a/tests/utils/test_dot_renderer.py
+++ b/tests/utils/test_dot_renderer.py
@@ -23,9 +23,11 @@ from unittest import mock
 from airflow.models import TaskInstance
 from airflow.models.dag import DAG
 from airflow.operators.bash import BashOperator
+from airflow.operators.dummy import DummyOperator
 from airflow.operators.python import PythonOperator
 from airflow.utils import dot_renderer
 from airflow.utils.state import State
+from airflow.utils.task_group import TaskGroup
 
 START_DATE = datetime.datetime.now()
 
@@ -72,9 +74,16 @@ class TestDotRenderer(unittest.TestCase):
         source = dot.source
         # Should render DAG title
         assert "label=DAG_ID" in source
-        assert 'first [color=black fillcolor=tan shape=rectangle style="filled,rounded"]' in source
-        assert 'second [color=white fillcolor=green shape=rectangle style="filled,rounded"]' in source
-        assert 'third [color=black fillcolor=lime shape=rectangle style="filled,rounded"]' in source
+        assert (
+            'first [color=black fillcolor=tan label=first shape=rectangle style="filled,rounded"]' in source
+        )
+        assert (
+            'second [color=white fillcolor=green label=second shape=rectangle style="filled,rounded"]'
+            in source
+        )
+        assert (
+            'third [color=black fillcolor=lime label=third shape=rectangle style="filled,rounded"]' in source
+        )
 
     def test_should_render_dag_orientation(self):
         orientation = "TB"
@@ -105,3 +114,89 @@ class TestDotRenderer(unittest.TestCase):
         # Should render DAG title with orientation
         assert "label=DAG_ID" in source
         assert f'label=DAG_ID labelloc=t rankdir={orientation}' in source
+
+    def test_render_task_group(self):
+        with DAG(dag_id="example_task_group", start_date=START_DATE) as dag:
+            start = DummyOperator(task_id="start")
+
+            with TaskGroup("section_1", tooltip="Tasks for section_1") as section_1:
+                task_1 = DummyOperator(task_id="task_1")
+                task_2 = BashOperator(task_id="task_2", bash_command='echo 1')
+                task_3 = DummyOperator(task_id="task_3")
+
+                task_1 >> [task_2, task_3]
+
+            with TaskGroup("section_2", tooltip="Tasks for section_2") as section_2:
+                task_1 = DummyOperator(task_id="task_1")
+
+                with TaskGroup("inner_section_2", tooltip="Tasks for inner_section2"):
+                    task_2 = BashOperator(task_id="task_2", bash_command='echo 1')
+                    task_3 = DummyOperator(task_id="task_3")
+                    task_4 = DummyOperator(task_id="task_4")
+
+                    [task_2, task_3] >> task_4
+
+            end = DummyOperator(task_id='end')
+
+            start >> section_1 >> section_2 >> end  # group-level deps make the groups draw join nodes
+
+        dot = dot_renderer.render_dag(dag)  # no TIs given, so tasks use their operators' ui colours
+
+        assert dot.source == '\n'.join(  # nodes sorted by label; a group's joins precede its children
+            [
+                'digraph example_task_group {',
+                '\tgraph [label=example_task_group labelloc=t rankdir=LR]',
+                '\tend [color="#000000" fillcolor="#e8f7e4" label=end shape=rectangle '
+                'style="filled,rounded"]',
+                '\tsubgraph cluster_section_1 {',
+                '\t\tcolor="#000000" fillcolor="#6495ed7f" label=section_1 shape=rectangle style=filled',
+                '\t\t"section_1.upstream_join_id" [color="#000000" fillcolor=CornflowerBlue height=0.2 '
+                'label="" shape=circle style="filled,rounded" width=0.2]',
+                '\t\t"section_1.downstream_join_id" [color="#000000" fillcolor=CornflowerBlue height=0.2 '
+                'label="" shape=circle style="filled,rounded" width=0.2]',
+                '\t\t"section_1.task_1" [color="#000000" fillcolor="#e8f7e4" label=task_1 shape=rectangle '
+                'style="filled,rounded"]',
+                '\t\t"section_1.task_2" [color="#000000" fillcolor="#f0ede4" label=task_2 shape=rectangle '
+                'style="filled,rounded"]',
+                '\t\t"section_1.task_3" [color="#000000" fillcolor="#e8f7e4" label=task_3 shape=rectangle '
+                'style="filled,rounded"]',
+                '\t}',
+                '\tsubgraph cluster_section_2 {',
+                '\t\tcolor="#000000" fillcolor="#6495ed7f" label=section_2 shape=rectangle style=filled',
+                '\t\t"section_2.upstream_join_id" [color="#000000" fillcolor=CornflowerBlue height=0.2 '
+                'label="" shape=circle style="filled,rounded" width=0.2]',
+                '\t\t"section_2.downstream_join_id" [color="#000000" fillcolor=CornflowerBlue height=0.2 '
+                'label="" shape=circle style="filled,rounded" width=0.2]',
+                '\t\tsubgraph "cluster_section_2.inner_section_2" {',
+                '\t\t\tcolor="#000000" fillcolor="#6495ed7f" label=inner_section_2 shape=rectangle '
+                'style=filled',
+                '\t\t\t"section_2.inner_section_2.task_2" [color="#000000" fillcolor="#f0ede4" label=task_2 '
+                'shape=rectangle style="filled,rounded"]',
+                '\t\t\t"section_2.inner_section_2.task_3" [color="#000000" fillcolor="#e8f7e4" label=task_3 '
+                'shape=rectangle style="filled,rounded"]',
+                '\t\t\t"section_2.inner_section_2.task_4" [color="#000000" fillcolor="#e8f7e4" label=task_4 '
+                'shape=rectangle style="filled,rounded"]',
+                '\t\t}',
+                '\t\t"section_2.task_1" [color="#000000" fillcolor="#e8f7e4" label=task_1 shape=rectangle '
+                'style="filled,rounded"]',
+                '\t}',
+                '\tstart [color="#000000" fillcolor="#e8f7e4" label=start shape=rectangle '
+                'style="filled,rounded"]',
+                '\t"section_1.downstream_join_id" -> "section_2.upstream_join_id"',
+                '\t"section_1.task_1" -> "section_1.task_2"',
+                '\t"section_1.task_1" -> "section_1.task_3"',
+                '\t"section_1.task_2" -> "section_1.downstream_join_id"',
+                '\t"section_1.task_3" -> "section_1.downstream_join_id"',
+                '\t"section_1.upstream_join_id" -> "section_1.task_1"',
+                '\t"section_2.downstream_join_id" -> end',
+                '\t"section_2.inner_section_2.task_2" -> "section_2.inner_section_2.task_4"',
+                '\t"section_2.inner_section_2.task_3" -> "section_2.inner_section_2.task_4"',
+                '\t"section_2.inner_section_2.task_4" -> "section_2.downstream_join_id"',
+                '\t"section_2.task_1" -> "section_2.downstream_join_id"',
+                '\t"section_2.upstream_join_id" -> "section_2.inner_section_2.task_2"',
+                '\t"section_2.upstream_join_id" -> "section_2.inner_section_2.task_3"',
+                '\t"section_2.upstream_join_id" -> "section_2.task_1"',
+                '\tstart -> "section_1.upstream_join_id"',
+                '}',
+            ]
+        )