You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2022/04/05 12:59:17 UTC
[airflow] branch main updated: Show tasks in grid view based on topological sort. (#22741)
This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 34154803ac Show tasks in grid view based on topological sort. (#22741)
34154803ac is described below
commit 34154803ac73d62d3e969e480405df3073032622
Author: Ash Berlin-Taylor <as...@apache.org>
AuthorDate: Tue Apr 5 13:59:09 2022 +0100
Show tasks in grid view based on topological sort. (#22741)
This takes the topological sort that was implemented on DAG and moves
it down to TaskGroup.
In order to do this (and avoid duplicating the sort logic) the existing sort on
DAG is re-implemented on top of the new method.
This also surfaced a tiny bug in deserialize_task_group where the
SerializedTaskGroup did not have `dag` set -- it didn't cause any
problems until now, but `dag` must be set in order to call
`upstream_list` on a SerializedTaskGroup object.
---
airflow/models/dag.py | 57 ++---------
airflow/models/taskmixin.py | 7 ++
airflow/serialization/serialized_objects.py | 15 ++-
airflow/utils/task_group.py | 61 +++++++++++-
airflow/www/views.py | 2 +-
tests/models/test_dag.py | 63 -------------
tests/serialization/test_dag_serialization.py | 5 +-
tests/utils/test_task_group.py | 131 ++++++++++++++++++++++++++
8 files changed, 218 insertions(+), 123 deletions(-)
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 8c5b65b45e..2871088392 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -26,7 +26,6 @@ import re
import sys
import traceback
import warnings
-from collections import OrderedDict
from datetime import datetime, timedelta
from inspect import signature
from typing import (
@@ -1718,54 +1717,18 @@ class DAG(LoggingMixin):
Sorts tasks in topographical order, such that a task comes after any of its
upstream dependencies.
- Heavily inspired by:
- http://blog.jupo.org/2012/04/06/topological-sorting-acyclic-directed-graphs/
-
- :param include_subdag_tasks: whether to include tasks in subdags, default to False
- :return: list of tasks in topological order
- """
- from airflow.operators.subdag import SubDagOperator # Avoid circular import
-
- # convert into an OrderedDict to speedup lookup while keeping order the same
- graph_unsorted = OrderedDict((task.task_id, task) for task in self.tasks)
-
- graph_sorted: List[Operator] = []
-
- # special case
- if len(self.tasks) == 0:
- return tuple(graph_sorted)
-
- # Run until the unsorted graph is empty.
- while graph_unsorted:
- # Go through each of the node/edges pairs in the unsorted
- # graph. If a set of edges doesn't contain any nodes that
- # haven't been resolved, that is, that are still in the
- # unsorted graph, remove the pair from the unsorted graph,
- # and append it to the sorted graph. Note here that by using
- # using the items() method for iterating, a copy of the
- # unsorted graph is used, allowing us to modify the unsorted
- # graph as we move through it. We also keep a flag for
- # checking that graph is acyclic, which is true if any
- # nodes are resolved during each pass through the graph. If
- # not, we need to exit as the graph therefore can't be
- # sorted.
- acyclic = False
- for node in list(graph_unsorted.values()):
- for edge in node.upstream_list:
- if edge.node_id in graph_unsorted:
- break
- # no edges in upstream tasks
- else:
- acyclic = True
- del graph_unsorted[node.task_id]
- graph_sorted.append(node)
- if include_subdag_tasks and isinstance(node, SubDagOperator):
- graph_sorted.extend(node.subdag.topological_sort(include_subdag_tasks=True))
+ Deprecated in favour of ``task_group.topological_sort``
+ """
+ from airflow.utils.task_group import TaskGroup
- if not acyclic:
- raise AirflowException(f"A cyclic dependency occurred in dag: {self.dag_id}")
+ def nested_topo(group):
+ for node in group.topological_sort(_include_subdag_tasks=include_subdag_tasks):
+ if isinstance(node, TaskGroup):
+ yield from nested_topo(node)
+ else:
+ yield node
- return tuple(graph_sorted)
+ return tuple(nested_topo(self.task_group))
@provide_session
def set_dag_runs_state(
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index ad6404869e..118d8c9c23 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -146,6 +146,13 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
def has_dag(self) -> bool:
return self.dag is not None
+ @property
+ def dag_id(self) -> str:
+ """Returns dag id if it has one or an adhoc/meaningless ID"""
+ if self.dag:
+ return self.dag.dag_id
+ return "_in_memory_dag_"
+
@property
def log(self) -> "Logger":
raise NotImplementedError()
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 02d9cf6c22..d2294e79fc 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -1031,10 +1031,9 @@ class SerializedDAG(DAG, BaseSerialization):
dag.timetable = create_timetable(dag.schedule_interval, dag.timezone)
# Set _task_group
-
if "_task_group" in encoded_dag:
- dag._task_group = SerializedTaskGroup.deserialize_task_group( # type: ignore
- encoded_dag["_task_group"], None, dag.task_dict
+ dag._task_group = SerializedTaskGroup.deserialize_task_group(
+ encoded_dag["_task_group"], None, dag.task_dict, dag
)
else:
# This must be old data that had no task_group. Create a root TaskGroup and add
@@ -1138,17 +1137,15 @@ class SerializedTaskGroup(TaskGroup, BaseSerialization):
encoded_group: Dict[str, Any],
parent_group: Optional[TaskGroup],
task_dict: Dict[str, Operator],
- ) -> Optional[TaskGroup]:
+ dag: SerializedDAG,
+ ) -> TaskGroup:
"""Deserializes a TaskGroup from a JSON object."""
- if not encoded_group:
- return None
-
group_id = cls._deserialize(encoded_group["_group_id"])
kwargs = {
key: cls._deserialize(encoded_group[key])
for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"]
}
- group = SerializedTaskGroup(group_id=group_id, parent_group=parent_group, **kwargs)
+ group = SerializedTaskGroup(group_id=group_id, parent_group=parent_group, dag=dag, **kwargs)
def set_ref(task: Operator) -> Operator:
task.task_group = weakref.proxy(group)
@@ -1157,7 +1154,7 @@ class SerializedTaskGroup(TaskGroup, BaseSerialization):
group.children = {
label: set_ref(task_dict[val]) # type: ignore
if _type == DAT.OP # type: ignore
- else SerializedTaskGroup.deserialize_task_group(val, group, task_dict)
+ else SerializedTaskGroup.deserialize_task_group(val, group, task_dict, dag=dag)
for label, (_type, val) in encoded_group["children"].items()
}
group.upstream_group_ids.update(cls._deserialize(encoded_group["upstream_group_ids"]))
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index cba816e1d5..8076f0d252 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -24,7 +24,7 @@ import re
import weakref
from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, List, Optional, Sequence, Set, Tuple, Union
-from airflow.exceptions import AirflowException, DuplicateTaskIdFound
+from airflow.exceptions import AirflowDagCycleException, AirflowException, DuplicateTaskIdFound
from airflow.models.taskmixin import DAGNode, DependencyMixin
from airflow.serialization.enums import DagAttributeTypes
from airflow.utils.helpers import validate_group_key
@@ -401,6 +401,65 @@ class TaskGroup(DAGNode):
self.task_group._remove(self)
return MappedTaskGroup(group_id=self._group_id, dag=self.dag, mapped_arg=arg)
+ def topological_sort(self, _include_subdag_tasks: bool = False):
+ """
+ Sorts children in topological order, such that a task comes after any of its
+ upstream dependencies.
+
+ :return: list of tasks in topological order
+ """
+ # This uses a modified version of Kahn's Topological Sort algorithm to
+ # not have to pre-compute the "in-degree" of the nodes.
+ from airflow.operators.subdag import SubDagOperator # Avoid circular import
+
+ graph_unsorted = copy.copy(self.children)
+
+ graph_sorted: List[DAGNode] = []
+
+ # special case
+ if len(self.children) == 0:
+ return graph_sorted
+
+ # Run until the unsorted graph is empty.
+ while graph_unsorted:
+ # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain
+ # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the
+ # pair from the unsorted graph, and append it to the sorted graph. Note here that by iterating over
+ # list(graph_unsorted.values()), a copy of the unsorted graph's nodes is used, allowing us to modify
+ # the unsorted graph as we move through it.
+ #
+ # We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved
+ # during each pass through the graph. If not, we need to exit as the graph therefore can't be
+ # sorted.
+ acyclic = False
+ for node in list(graph_unsorted.values()):
+ for edge in node.upstream_list:
+ if edge.node_id in graph_unsorted:
+ break
+ # Check whether the task's group is a child (or grandchild) of this TG,
+ tg = edge.task_group
+ while tg:
+ if tg.node_id in graph_unsorted:
+ break
+ tg = tg.task_group
+
+ if tg:
+ # We are already going to visit that TG
+ break
+ else:
+ acyclic = True
+ del graph_unsorted[node.node_id]
+ graph_sorted.append(node)
+ if _include_subdag_tasks and isinstance(node, SubDagOperator):
+ graph_sorted.extend(
+ node.subdag.task_group.topological_sort(_include_subdag_tasks=True)
+ )
+
+ if not acyclic:
+ raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")
+
+ return graph_sorted
+
class MappedTaskGroup(TaskGroup):
"""
diff --git a/airflow/www/views.py b/airflow/www/views.py
index f90463651b..03cd5e6e02 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -251,7 +251,7 @@ def task_group_to_tree(task_item_or_group, dag, dag_runs, tis, session):
task_group = task_item_or_group
children = [
- task_group_to_tree(child, dag, dag_runs, tis, session) for child in task_group.children.values()
+ task_group_to_tree(child, dag, dag_runs, tis, session) for child in task_group.topological_sort()
]
def get_summary(dag_run, children):
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index a882547a92..e3fc9b431a 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -255,69 +255,6 @@ class TestDag(unittest.TestCase):
assert self._occur_before('a_child', 'b_parent', topological_list)
assert self._occur_before('b_child', 'b_parent', topological_list)
- def test_dag_topological_sort1(self):
- dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
- # A -> B
- # A -> C -> D
- # ordered: B, D, C, A or D, B, C, A or D, C, B, A
- with dag:
- op1 = DummyOperator(task_id='A')
- op2 = DummyOperator(task_id='B')
- op3 = DummyOperator(task_id='C')
- op4 = DummyOperator(task_id='D')
- op1.set_upstream([op2, op3])
- op3.set_upstream(op4)
-
- topological_list = dag.topological_sort()
- logging.info(topological_list)
-
- tasks = [op2, op3, op4]
- assert topological_list[0] in tasks
- tasks.remove(topological_list[0])
- assert topological_list[1] in tasks
- tasks.remove(topological_list[1])
- assert topological_list[2] in tasks
- tasks.remove(topological_list[2])
- assert topological_list[3] == op1
-
- def test_dag_topological_sort2(self):
- dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
- # C -> (A u B) -> D
- # C -> E
- # ordered: E | D, A | B, C
- with dag:
- op1 = DummyOperator(task_id='A')
- op2 = DummyOperator(task_id='B')
- op3 = DummyOperator(task_id='C')
- op4 = DummyOperator(task_id='D')
- op5 = DummyOperator(task_id='E')
- op1.set_downstream(op3)
- op2.set_downstream(op3)
- op1.set_upstream(op4)
- op2.set_upstream(op4)
- op5.set_downstream(op3)
-
- topological_list = dag.topological_sort()
- logging.info(topological_list)
-
- set1 = [op4, op5]
- assert topological_list[0] in set1
- set1.remove(topological_list[0])
-
- set2 = [op1, op2]
- set2.extend(set1)
- assert topological_list[1] in set2
- set2.remove(topological_list[1])
-
- assert topological_list[2] in set2
- set2.remove(topological_list[2])
-
- assert topological_list[3] in set2
-
- assert topological_list[4] == op3
-
def test_dag_topological_sort_dag_without_tasks(self):
dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index a0bcf0d3ec..02026e96f8 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1196,6 +1196,7 @@ class TestStringifiedDAGs:
assert serialized_dag.task_group.children.keys() == dag.task_group.children.keys()
def check_task_group(node):
+ assert node.dag is serialized_dag
try:
children = node.children.values()
except AttributeError:
@@ -1770,5 +1771,5 @@ def test_mapped_task_group_serde():
],
}
- with DAG("test", start_date=execution_date):
- SerializedTaskGroup.deserialize_task_group(serialized, None, dag.task_dict)
+ with DAG("test", start_date=execution_date) as new_dag:
+ SerializedTaskGroup.deserialize_task_group(serialized, None, dag.task_dict, new_dag)
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py
index 21ec0b907b..b944f636cc 100644
--- a/tests/utils/test_task_group.py
+++ b/tests/utils/test_task_group.py
@@ -1154,3 +1154,134 @@ def test_decorator_map():
tg = dag.task_group.get_child_by_label("my_task_group")
assert isinstance(tg, MappedTaskGroup)
assert "my_arg_1" in tg.mapped_kwargs
+
+
+def test_topological_sort1():
+ dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
+
+ # A -> B
+ # A -> C -> D
+ # ordered: B, D, C, A or D, B, C, A or D, C, B, A
+ with dag:
+ op1 = DummyOperator(task_id='A')
+ op2 = DummyOperator(task_id='B')
+ op3 = DummyOperator(task_id='C')
+ op4 = DummyOperator(task_id='D')
+ [op2, op3] >> op1
+ op3 >> op4
+
+ topological_list = dag.task_group.topological_sort()
+
+ tasks = [op2, op3, op4]
+ assert topological_list[0] in tasks
+ tasks.remove(topological_list[0])
+ assert topological_list[1] in tasks
+ tasks.remove(topological_list[1])
+ assert topological_list[2] in tasks
+ tasks.remove(topological_list[2])
+ assert topological_list[3] == op1
+
+
+def test_topological_sort2():
+ dag = DAG('dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
+
+ # C -> (A u B) -> D
+ # C -> E
+ # ordered: E | D, A | B, C
+ with dag:
+ op1 = DummyOperator(task_id='A')
+ op2 = DummyOperator(task_id='B')
+ op3 = DummyOperator(task_id='C')
+ op4 = DummyOperator(task_id='D')
+ op5 = DummyOperator(task_id='E')
+ op3 << [op1, op2]
+ op4 >> [op1, op2]
+ op5 >> op3
+
+ topological_list = dag.task_group.topological_sort()
+
+ set1 = [op4, op5]
+ assert topological_list[0] in set1
+ set1.remove(topological_list[0])
+
+ set2 = [op1, op2]
+ set2.extend(set1)
+ assert topological_list[1] in set2
+ set2.remove(topological_list[1])
+
+ assert topological_list[2] in set2
+ set2.remove(topological_list[2])
+
+ assert topological_list[3] in set2
+
+ assert topological_list[4] == op3
+
+
+def test_topological_nested_groups():
+ execution_date = pendulum.parse("20200101")
+ with DAG("test_dag_edges", start_date=execution_date) as dag:
+ task1 = DummyOperator(task_id="task1")
+ task5 = DummyOperator(task_id="task5")
+ with TaskGroup("group_a") as group_a:
+ with TaskGroup("group_b"):
+ task2 = DummyOperator(task_id="task2")
+ task3 = DummyOperator(task_id="task3")
+ task4 = DummyOperator(task_id="task4")
+ task2 >> [task3, task4]
+
+ task1 >> group_a
+ group_a >> task5
+
+ def nested_topo(group):
+ return [
+ nested_topo(node) if isinstance(node, TaskGroup) else node for node in group.topological_sort()
+ ]
+
+ topological_list = nested_topo(dag.task_group)
+
+ assert topological_list == [
+ task1,
+ [
+ [
+ task2,
+ task3,
+ task4,
+ ],
+ ],
+ task5,
+ ]
+
+
+def test_topological_group_dep():
+ execution_date = pendulum.parse("20200101")
+ with DAG("test_dag_edges", start_date=execution_date) as dag:
+ task1 = DummyOperator(task_id="task1")
+ task6 = DummyOperator(task_id="task6")
+ with TaskGroup("group_a") as group_a:
+ task2 = DummyOperator(task_id="task2")
+ task3 = DummyOperator(task_id="task3")
+ with TaskGroup("group_b") as group_b:
+ task4 = DummyOperator(task_id="task4")
+ task5 = DummyOperator(task_id="task5")
+
+ task1 >> group_a >> group_b >> task6
+
+ def nested_topo(group):
+ return [
+ nested_topo(node) if isinstance(node, TaskGroup) else node for node in group.topological_sort()
+ ]
+
+ topological_list = nested_topo(dag.task_group)
+
+ assert topological_list == [
+ task1,
+ [
+ task2,
+ task3,
+ ],
+ [
+ task4,
+ task5,
+ ],
+ task6,
+ ]