You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/07/05 18:11:06 UTC

[airflow] 01/03: Fix cycle bug with attaching label to task group (#24847)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 7a4869c737e46acb1fe794acef588b7251d2c4a9
Author: Ash Berlin-Taylor <as...@apache.org>
AuthorDate: Tue Jul 5 16:40:00 2022 +0100

    Fix cycle bug with attaching label to task group (#24847)
    
    The problem was specific to EdgeModifiers as they try to be
    "transparent" to upstream/downstream
    
    The fix is to track the upstream/downstream for the task group
    before making any changes to the EdgeModifiers' relations -- otherwise
    the roots of the TG were added as dependencies to themselves!
    
    (cherry picked from commit efc05a5f0b3d261293c2efaf6771e4af9a2f324c)
---
 airflow/utils/task_group.py    | 12 ++++++------
 tests/utils/test_task_group.py | 25 +++++++++++++++++++++++++
 2 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index ed8d380ff0..64c11f79db 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -283,6 +283,12 @@ class TaskGroup(DAGNode):
         Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.
         Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids.
         """
+        if not isinstance(task_or_task_list, Sequence):
+            task_or_task_list = [task_or_task_list]
+
+        for task_like in task_or_task_list:
+            self.update_relative(task_like, upstream)
+
         if upstream:
             for task in self.get_roots():
                 task.set_upstream(task_or_task_list)
@@ -290,12 +296,6 @@ class TaskGroup(DAGNode):
             for task in self.get_leaves():
                 task.set_downstream(task_or_task_list)
 
-        if not isinstance(task_or_task_list, Sequence):
-            task_or_task_list = [task_or_task_list]
-
-        for task_like in task_or_task_list:
-            self.update_relative(task_like, upstream)
-
     def __enter__(self) -> "TaskGroup":
         TaskGroupContext.push_context_managed_task_group(self)
         return self
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py
index 9aacc96b82..864b2fb68a 100644
--- a/tests/utils/test_task_group.py
+++ b/tests/utils/test_task_group.py
@@ -1222,3 +1222,28 @@ def test_add_to_another_group():
             tg.add(task)
 
     assert str(ctx.value) == "cannot add 'section_2.task' to 'section_1' (already in group 'section_2')"
+
+
+def test_task_group_edge_modifier_chain():
+    from airflow.models.baseoperator import chain
+    from airflow.utils.edgemodifier import Label
+
+    with DAG(dag_id="test", start_date=pendulum.DateTime(2022, 5, 20)) as dag:
+        start = EmptyOperator(task_id="sleep_3_seconds")
+
+        with TaskGroup(group_id="group1") as tg:
+            t1 = EmptyOperator(task_id="dummy1")
+            t2 = EmptyOperator(task_id="dummy2")
+
+        t3 = EmptyOperator(task_id="echo_done")
+
+    # The case we are testing for is when a Label is inside a list -- meaning that we do tg.set_upstream
+    # instead of label.set_downstream
+    chain(start, [Label("branch three")], tg, t3)
+
+    assert start.downstream_task_ids == {t1.node_id, t2.node_id}
+    assert t3.upstream_task_ids == {t1.node_id, t2.node_id}
+    assert tg.upstream_task_ids == set()
+    assert tg.downstream_task_ids == {t3.node_id}
+    # Check that we can perform a topological_sort
+    dag.topological_sort()