You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2022/04/05 12:59:17 UTC

[airflow] branch main updated: Show tasks in grid view based on topological sort. (#22741)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 34154803ac Show tasks in grid view based on topological sort. (#22741)
34154803ac is described below

commit 34154803ac73d62d3e969e480405df3073032622
Author: Ash Berlin-Taylor <as...@apache.org>
AuthorDate: Tue Apr 5 13:59:09 2022 +0100

    Show tasks in grid view based on topological sort. (#22741)
    
    This takes the existing topological sort that existed on a DAG and moves
    it down to TaskGroup.
    
    In order to do this (and not have duplicated sort) the existing sort on
    DAG is re-implemented on top of the new method.
    
    This also surfaced a tiny bug in deserialize_task_group where the
    SerializedTaskGroup did not have `dag` set -- it didn't cause any
    problems until now but was needed to call `upstream_list` on a
    SerializedTaskGroup object.
---
 airflow/models/dag.py                         |  57 ++---------
 airflow/models/taskmixin.py                   |   7 ++
 airflow/serialization/serialized_objects.py   |  15 ++-
 airflow/utils/task_group.py                   |  61 +++++++++++-
 airflow/www/views.py                          |   2 +-
 tests/models/test_dag.py                      |  63 -------------
 tests/serialization/test_dag_serialization.py |   5 +-
 tests/utils/test_task_group.py                | 131 ++++++++++++++++++++++++++
 8 files changed, 218 insertions(+), 123 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 8c5b65b45e..2871088392 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -26,7 +26,6 @@ import re
 import sys
 import traceback
 import warnings
-from collections import OrderedDict
 from datetime import datetime, timedelta
 from inspect import signature
 from typing import (
@@ -1718,54 +1717,18 @@ class DAG(LoggingMixin):
         Sorts tasks in topographical order, such that a task comes after any of its
         upstream dependencies.
 
-        Heavily inspired by:
-        http://blog.jupo.org/2012/04/06/topological-sorting-acyclic-directed-graphs/
-
-        :param include_subdag_tasks: whether to include tasks in subdags, default to False
-        :return: list of tasks in topological order
-        """
-        from airflow.operators.subdag import SubDagOperator  # Avoid circular import
-
-        # convert into an OrderedDict to speedup lookup while keeping order the same
-        graph_unsorted = OrderedDict((task.task_id, task) for task in self.tasks)
-
-        graph_sorted: List[Operator] = []
-
-        # special case
-        if len(self.tasks) == 0:
-            return tuple(graph_sorted)
-
-        # Run until the unsorted graph is empty.
-        while graph_unsorted:
-            # Go through each of the node/edges pairs in the unsorted
-            # graph. If a set of edges doesn't contain any nodes that
-            # haven't been resolved, that is, that are still in the
-            # unsorted graph, remove the pair from the unsorted graph,
-            # and append it to the sorted graph. Note here that by using
-            # using the items() method for iterating, a copy of the
-            # unsorted graph is used, allowing us to modify the unsorted
-            # graph as we move through it. We also keep a flag for
-            # checking that graph is acyclic, which is true if any
-            # nodes are resolved during each pass through the graph. If
-            # not, we need to exit as the graph therefore can't be
-            # sorted.
-            acyclic = False
-            for node in list(graph_unsorted.values()):
-                for edge in node.upstream_list:
-                    if edge.node_id in graph_unsorted:
-                        break
-                # no edges in upstream tasks
-                else:
-                    acyclic = True
-                    del graph_unsorted[node.task_id]
-                    graph_sorted.append(node)
-                    if include_subdag_tasks and isinstance(node, SubDagOperator):
-                        graph_sorted.extend(node.subdag.topological_sort(include_subdag_tasks=True))
+        Deprecated in favour of ``task_group.topological_sort``.
+        """
+        from airflow.utils.task_group import TaskGroup
 
-            if not acyclic:
-                raise AirflowException(f"A cyclic dependency occurred in dag: {self.dag_id}")
+        def nested_topo(group):
+            for node in group.topological_sort(_include_subdag_tasks=include_subdag_tasks):
+                if isinstance(node, TaskGroup):
+                    yield from nested_topo(node)
+                else:
+                    yield node
 
-        return tuple(graph_sorted)
+        return tuple(nested_topo(self.task_group))
 
     @provide_session
     def set_dag_runs_state(
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index ad6404869e..118d8c9c23 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -146,6 +146,13 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
     def has_dag(self) -> bool:
         return self.dag is not None
 
+    @property
+    def dag_id(self) -> str:
+        """Returns dag id if it has one or an adhoc/meaningless ID"""
+        if self.dag:
+            return self.dag.dag_id
+        return "_in_memory_dag_"
+
     @property
     def log(self) -> "Logger":
         raise NotImplementedError()
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 02d9cf6c22..d2294e79fc 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -1031,10 +1031,9 @@ class SerializedDAG(DAG, BaseSerialization):
             dag.timetable = create_timetable(dag.schedule_interval, dag.timezone)
 
         # Set _task_group
-
         if "_task_group" in encoded_dag:
-            dag._task_group = SerializedTaskGroup.deserialize_task_group(  # type: ignore
-                encoded_dag["_task_group"], None, dag.task_dict
+            dag._task_group = SerializedTaskGroup.deserialize_task_group(
+                encoded_dag["_task_group"], None, dag.task_dict, dag
             )
         else:
             # This must be old data that had no task_group. Create a root TaskGroup and add
@@ -1138,17 +1137,15 @@ class SerializedTaskGroup(TaskGroup, BaseSerialization):
         encoded_group: Dict[str, Any],
         parent_group: Optional[TaskGroup],
         task_dict: Dict[str, Operator],
-    ) -> Optional[TaskGroup]:
+        dag: SerializedDAG,
+    ) -> TaskGroup:
         """Deserializes a TaskGroup from a JSON object."""
-        if not encoded_group:
-            return None
-
         group_id = cls._deserialize(encoded_group["_group_id"])
         kwargs = {
             key: cls._deserialize(encoded_group[key])
             for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"]
         }
-        group = SerializedTaskGroup(group_id=group_id, parent_group=parent_group, **kwargs)
+        group = SerializedTaskGroup(group_id=group_id, parent_group=parent_group, dag=dag, **kwargs)
 
         def set_ref(task: Operator) -> Operator:
             task.task_group = weakref.proxy(group)
@@ -1157,7 +1154,7 @@ class SerializedTaskGroup(TaskGroup, BaseSerialization):
         group.children = {
             label: set_ref(task_dict[val])  # type: ignore
             if _type == DAT.OP  # type: ignore
-            else SerializedTaskGroup.deserialize_task_group(val, group, task_dict)
+            else SerializedTaskGroup.deserialize_task_group(val, group, task_dict, dag=dag)
             for label, (_type, val) in encoded_group["children"].items()
         }
         group.upstream_group_ids.update(cls._deserialize(encoded_group["upstream_group_ids"]))
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index cba816e1d5..8076f0d252 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -24,7 +24,7 @@ import re
 import weakref
 from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, List, Optional, Sequence, Set, Tuple, Union
 
-from airflow.exceptions import AirflowException, DuplicateTaskIdFound
+from airflow.exceptions import AirflowDagCycleException, AirflowException, DuplicateTaskIdFound
 from airflow.models.taskmixin import DAGNode, DependencyMixin
 from airflow.serialization.enums import DagAttributeTypes
 from airflow.utils.helpers import validate_group_key
@@ -401,6 +401,65 @@ class TaskGroup(DAGNode):
             self.task_group._remove(self)
         return MappedTaskGroup(group_id=self._group_id, dag=self.dag, mapped_arg=arg)
 
+    def topological_sort(self, _include_subdag_tasks: bool = False):
+        """
+        Sorts children in topological order, such that a task comes after any of its
+        upstream dependencies.
+
+        :return: list of child nodes (tasks and task groups) in topological order
+        """
+        # This uses a modified version of Kahn's Topological Sort algorithm to
+        # not have to pre-compute the "in-degree" of the nodes.
+        from airflow.operators.subdag import SubDagOperator  # Avoid circular import
+
+        graph_unsorted = copy.copy(self.children)
+
+        graph_sorted: List[DAGNode] = []
+
+        # special case
+        if len(self.children) == 0:
+            return graph_sorted
+
+        # Run until the unsorted graph is empty.
+        while graph_unsorted:
+            # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain
+            # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the
+            # pair from the unsorted graph, and append it to the sorted graph. Note here that by iterating
+            # over list(graph_unsorted.values()), a copy of the values is used, allowing us to modify the
+            # unsorted graph as we move through it.
+            #
+            # We also keep a flag for checking that the graph is acyclic, which is true if any nodes are
+            # resolved during each pass through the graph. If not, we need to exit as the graph therefore
+            # can't be sorted.
+            acyclic = False
+            for node in list(graph_unsorted.values()):
+                for edge in node.upstream_list:
+                    if edge.node_id in graph_unsorted:
+                        break
+                    # Check if the edge's group (or an ancestor group) is an unsorted child of this TG.
+                    tg = edge.task_group
+                    while tg:
+                        if tg.node_id in graph_unsorted:
+                            break
+                        tg = tg.task_group
+
+                    if tg:
+                        # We are already going to visit that TG
+                        break
+                else:
+                    acyclic = True
+                    del graph_unsorted[node.node_id]
+                    graph_sorted.append(node)
+                    if _include_subdag_tasks and isinstance(node, SubDagOperator):
+                        graph_sorted.extend(
+                            node.subdag.task_group.topological_sort(_include_subdag_tasks=True)
+                        )
+
+            if not acyclic:
+                raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")
+
+        return graph_sorted
+
 
 class MappedTaskGroup(TaskGroup):
     """
diff --git a/airflow/www/views.py b/airflow/www/views.py
index f90463651b..03cd5e6e02 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -251,7 +251,7 @@ def task_group_to_tree(task_item_or_group, dag, dag_runs, tis, session):
     task_group = task_item_or_group
 
     children = [
-        task_group_to_tree(child, dag, dag_runs, tis, session) for child in task_group.children.values()
+        task_group_to_tree(child, dag, dag_runs, tis, session) for child in task_group.topological_sort()
     ]
 
     def get_summary(dag_run, children):
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index a882547a92..e3fc9b431a 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -255,69 +255,6 @@ class TestDag(unittest.TestCase):
         assert self._occur_before('a_child', 'b_parent', topological_list)
         assert self._occur_before('b_child', 'b_parent', topological_list)
 
-    def test_dag_topological_sort1(self):
-        dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
-        # A -> B
-        # A -> C -> D
-        # ordered: B, D, C, A or D, B, C, A or D, C, B, A
-        with dag:
-            op1 = DummyOperator(task_id='A')
-            op2 = DummyOperator(task_id='B')
-            op3 = DummyOperator(task_id='C')
-            op4 = DummyOperator(task_id='D')
-            op1.set_upstream([op2, op3])
-            op3.set_upstream(op4)
-
-        topological_list = dag.topological_sort()
-        logging.info(topological_list)
-
-        tasks = [op2, op3, op4]
-        assert topological_list[0] in tasks
-        tasks.remove(topological_list[0])
-        assert topological_list[1] in tasks
-        tasks.remove(topological_list[1])
-        assert topological_list[2] in tasks
-        tasks.remove(topological_list[2])
-        assert topological_list[3] == op1
-
-    def test_dag_topological_sort2(self):
-        dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
-        # C -> (A u B) -> D
-        # C -> E
-        # ordered: E | D, A | B, C
-        with dag:
-            op1 = DummyOperator(task_id='A')
-            op2 = DummyOperator(task_id='B')
-            op3 = DummyOperator(task_id='C')
-            op4 = DummyOperator(task_id='D')
-            op5 = DummyOperator(task_id='E')
-            op1.set_downstream(op3)
-            op2.set_downstream(op3)
-            op1.set_upstream(op4)
-            op2.set_upstream(op4)
-            op5.set_downstream(op3)
-
-        topological_list = dag.topological_sort()
-        logging.info(topological_list)
-
-        set1 = [op4, op5]
-        assert topological_list[0] in set1
-        set1.remove(topological_list[0])
-
-        set2 = [op1, op2]
-        set2.extend(set1)
-        assert topological_list[1] in set2
-        set2.remove(topological_list[1])
-
-        assert topological_list[2] in set2
-        set2.remove(topological_list[2])
-
-        assert topological_list[3] in set2
-
-        assert topological_list[4] == op3
-
     def test_dag_topological_sort_dag_without_tasks(self):
         dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
 
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index a0bcf0d3ec..02026e96f8 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1196,6 +1196,7 @@ class TestStringifiedDAGs:
         assert serialized_dag.task_group.children.keys() == dag.task_group.children.keys()
 
         def check_task_group(node):
+            assert node.dag is serialized_dag
             try:
                 children = node.children.values()
             except AttributeError:
@@ -1770,5 +1771,5 @@ def test_mapped_task_group_serde():
         ],
     }
 
-    with DAG("test", start_date=execution_date):
-        SerializedTaskGroup.deserialize_task_group(serialized, None, dag.task_dict)
+    with DAG("test", start_date=execution_date) as new_dag:
+        SerializedTaskGroup.deserialize_task_group(serialized, None, dag.task_dict, new_dag)
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py
index 21ec0b907b..b944f636cc 100644
--- a/tests/utils/test_task_group.py
+++ b/tests/utils/test_task_group.py
@@ -1154,3 +1154,134 @@ def test_decorator_map():
     tg = dag.task_group.get_child_by_label("my_task_group")
     assert isinstance(tg, MappedTaskGroup)
     assert "my_arg_1" in tg.mapped_kwargs
+
+
+def test_topological_sort1():
+    dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
+
+    # A -> B
+    # A -> C <- D   (arrows point at the upstream task)
+    # ordered: B, C, D, A or C, B, D, A or C, D, B, A
+    with dag:
+        op1 = DummyOperator(task_id='A')
+        op2 = DummyOperator(task_id='B')
+        op3 = DummyOperator(task_id='C')
+        op4 = DummyOperator(task_id='D')
+        [op2, op3] >> op1
+        op3 >> op4
+
+    topological_list = dag.task_group.topological_sort()
+
+    tasks = [op2, op3, op4]
+    assert topological_list[0] in tasks
+    tasks.remove(topological_list[0])
+    assert topological_list[1] in tasks
+    tasks.remove(topological_list[1])
+    assert topological_list[2] in tasks
+    tasks.remove(topological_list[2])
+    assert topological_list[3] == op1
+
+
+def test_topological_sort2():
+    dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
+
+    # C -> (A u B) -> D
+    # C -> E
+    # ordered: E | D, A | B, C
+    with dag:
+        op1 = DummyOperator(task_id='A')
+        op2 = DummyOperator(task_id='B')
+        op3 = DummyOperator(task_id='C')
+        op4 = DummyOperator(task_id='D')
+        op5 = DummyOperator(task_id='E')
+        op3 << [op1, op2]
+        op4 >> [op1, op2]
+        op5 >> op3
+
+    topological_list = dag.task_group.topological_sort()
+
+    set1 = [op4, op5]
+    assert topological_list[0] in set1
+    set1.remove(topological_list[0])
+
+    set2 = [op1, op2]
+    set2.extend(set1)
+    assert topological_list[1] in set2
+    set2.remove(topological_list[1])
+
+    assert topological_list[2] in set2
+    set2.remove(topological_list[2])
+
+    assert topological_list[3] in set2
+
+    assert topological_list[4] == op3
+
+
+def test_topological_nested_groups():
+    execution_date = pendulum.parse("20200101")
+    with DAG("test_dag_edges", start_date=execution_date) as dag:
+        task1 = DummyOperator(task_id="task1")
+        task5 = DummyOperator(task_id="task5")
+        with TaskGroup("group_a") as group_a:
+            with TaskGroup("group_b"):
+                task2 = DummyOperator(task_id="task2")
+                task3 = DummyOperator(task_id="task3")
+                task4 = DummyOperator(task_id="task4")
+                task2 >> [task3, task4]
+
+        task1 >> group_a
+        group_a >> task5
+
+    def nested_topo(group):
+        return [
+            nested_topo(node) if isinstance(node, TaskGroup) else node for node in group.topological_sort()
+        ]
+
+    topological_list = nested_topo(dag.task_group)
+
+    assert topological_list == [
+        task1,
+        [
+            [
+                task2,
+                task3,
+                task4,
+            ],
+        ],
+        task5,
+    ]
+
+
+def test_topological_group_dep():
+    execution_date = pendulum.parse("20200101")
+    with DAG("test_dag_edges", start_date=execution_date) as dag:
+        task1 = DummyOperator(task_id="task1")
+        task6 = DummyOperator(task_id="task6")
+        with TaskGroup("group_a") as group_a:
+            task2 = DummyOperator(task_id="task2")
+            task3 = DummyOperator(task_id="task3")
+        with TaskGroup("group_b") as group_b:
+            task4 = DummyOperator(task_id="task4")
+            task5 = DummyOperator(task_id="task5")
+
+        task1 >> group_a >> group_b >> task6
+
+    def nested_topo(group):
+        return [
+            nested_topo(node) if isinstance(node, TaskGroup) else node for node in group.topological_sort()
+        ]
+
+    topological_list = nested_topo(dag.task_group)
+
+    assert topological_list == [
+        task1,
+        [
+            task2,
+            task3,
+        ],
+        [
+            task4,
+            task5,
+        ],
+        task6,
+    ]