You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@airflow.apache.org by el...@apache.org on 2023/02/18 17:57:52 UTC

[airflow] branch main updated: Edgemodifier refactoring w/ labels in TaskGroup edge case (#29410)

This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b05468129 Edgemodifier refactoring w/ labels in TaskGroup edge case (#29410)
4b05468129 is described below

commit 4b05468129361946688909943fe332f383302069
Author: Victor Chiapaikeo <vc...@gmail.com>
AuthorDate: Sat Feb 18 12:57:44 2023 -0500

    Edgemodifier refactoring w/ labels in TaskGroup edge case (#29410)
    
    * Edgemodifier refactoring w/ labels in TaskGroup edge case
---
 airflow/models/taskmixin.py      |  14 +-
 airflow/utils/edgemodifier.py    | 145 ++++++-----
 airflow/utils/task_group.py      |  18 +-
 tests/utils/test_edgemodifier.py | 501 +++++++++++++++++++++++++++++++++++++++
 4 files changed, 618 insertions(+), 60 deletions(-)

diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index 211fad6ff9..d93879498b 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -56,16 +56,22 @@ class DependencyMixin:
         raise NotImplementedError()
 
     @abstractmethod
-    def set_upstream(self, other: DependencyMixin | Sequence[DependencyMixin]):
+    def set_upstream(
+        self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None
+    ):
         """Set a task or a task list to be directly upstream from the current task."""
         raise NotImplementedError()
 
     @abstractmethod
-    def set_downstream(self, other: DependencyMixin | Sequence[DependencyMixin]):
+    def set_downstream(
+        self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None
+    ):
         """Set a task or a task list to be directly downstream from the current task."""
         raise NotImplementedError()
 
-    def update_relative(self, other: DependencyMixin, upstream=True) -> None:
+    def update_relative(
+        self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
+    ) -> None:
         """
         Update relationship information about another TaskMixin. Default is no-op.
         Override if necessary.
@@ -172,7 +178,7 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
 
         task_list: list[Operator] = []
         for task_object in task_or_task_list:
-            task_object.update_relative(self, not upstream)
+            task_object.update_relative(self, not upstream, edge_modifier=edge_modifier)
             relatives = task_object.leaves if upstream else task_object.roots
             for task in relatives:
                 if not isinstance(task, (BaseOperator, MappedOperator)):
diff --git a/airflow/utils/edgemodifier.py b/airflow/utils/edgemodifier.py
index d2135cf233..b693e1a1be 100644
--- a/airflow/utils/edgemodifier.py
+++ b/airflow/utils/edgemodifier.py
@@ -16,12 +16,10 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence
+from typing import Sequence
 
-from airflow.models.taskmixin import DependencyMixin
-
-if TYPE_CHECKING:
-    from airflow.models.baseoperator import BaseOperator
+from airflow.models.taskmixin import DAGNode, DependencyMixin
+from airflow.utils.task_group import TaskGroup
 
 
 class EdgeModifier(DependencyMixin):
@@ -44,8 +42,8 @@ class EdgeModifier(DependencyMixin):
 
     def __init__(self, label: str | None = None):
         self.label = label
-        self._upstream: list[BaseOperator] = []
-        self._downstream: list[BaseOperator] = []
+        self._upstream: list[DependencyMixin] = []
+        self._downstream: list[DependencyMixin] = []
 
     @property
     def roots(self):
@@ -55,8 +53,72 @@ class EdgeModifier(DependencyMixin):
     def leaves(self):
         return self._upstream
 
+    @staticmethod
+    def _make_list(item_or_list: DependencyMixin | Sequence[DependencyMixin]) -> Sequence[DependencyMixin]:
+        if not isinstance(item_or_list, Sequence):
+            return [item_or_list]
+        return item_or_list
+
+    def _save_nodes(
+        self,
+        nodes: DependencyMixin | Sequence[DependencyMixin],
+        stream: list[DependencyMixin],
+    ):
+        from airflow.models.xcom_arg import XComArg
+
+        for node in self._make_list(nodes):
+            if isinstance(node, (TaskGroup, XComArg, DAGNode)):
+                stream.append(node)
+            else:
+                raise TypeError(
+                    f"Cannot use edge labels with {type(node).__name__}, "
+                    f"only tasks, XComArg or TaskGroups"
+                )
+
+    def _convert_streams_to_task_groups(self):
+        """
+        Both self._upstream and self._downstream are required to determine if
+        we should convert a node to a TaskGroup or leave it as a DAGNode.
+
+        To do this, we keep a set of group_ids seen among the streams. If we find that
+        the nodes are from the same TaskGroup, we will leave them as DAGNodes and not
+        convert them to TaskGroups
+        """
+        from airflow.models.xcom_arg import XComArg
+
+        group_ids = set()
+        for node in [*self._upstream, *self._downstream]:
+            if isinstance(node, DAGNode) and node.task_group:
+                if node.task_group.is_root:
+                    group_ids.add("root")
+                else:
+                    group_ids.add(node.task_group.group_id)
+            elif isinstance(node, TaskGroup):
+                group_ids.add(node.group_id)
+            elif isinstance(node, XComArg):
+                if isinstance(node.operator, DAGNode) and node.operator.task_group:
+                    if node.operator.task_group.is_root:
+                        group_ids.add("root")
+                    else:
+                        group_ids.add(node.operator.task_group.group_id)
+
+        # If all nodes originate from the same TaskGroup, we will not convert them
+        if len(group_ids) != 1:
+            self._upstream = self._convert_stream_to_task_groups(self._upstream)
+            self._downstream = self._convert_stream_to_task_groups(self._downstream)
+
+    def _convert_stream_to_task_groups(self, stream: Sequence[DependencyMixin]) -> Sequence[DependencyMixin]:
+        return [
+            node.task_group
+            if isinstance(node, DAGNode) and node.task_group and not node.task_group.is_root
+            else node
+            for node in stream
+        ]
+
     def set_upstream(
-        self, task_or_task_list: DependencyMixin | Sequence[DependencyMixin], chain: bool = True
+        self,
+        other: DependencyMixin | Sequence[DependencyMixin],
+        edge_modifier: EdgeModifier | None = None,
     ):
         """
         Sets the given task/list onto the upstream attribute, and then checks if
@@ -64,30 +126,17 @@ class EdgeModifier(DependencyMixin):
 
         Providing this also provides << via DependencyMixin.
         """
-        from airflow.models.baseoperator import BaseOperator
-
-        # Ensure we have a list, even if it's just one item
-        if isinstance(task_or_task_list, DependencyMixin):
-            task_or_task_list = [task_or_task_list]
-        # Unfurl it into actual operators
-        operators: list[BaseOperator] = []
-        for task in task_or_task_list:
-            for root in task.roots:
-                if not isinstance(root, BaseOperator):
-                    raise TypeError(f"Cannot use edge labels with {type(root).__name__}, only operators")
-                operators.append(root)
-        # For each already-declared downstream, pair off with each new upstream
-        # item and store the edge info.
-        for operator in operators:
-            for downstream in self._downstream:
-                self.add_edge_info(operator.dag, operator.task_id, downstream.task_id)
-                if chain:
-                    operator.set_downstream(downstream)
-        # Add the new tasks to our list of ones we've seen
-        self._upstream.extend(operators)
+        self._save_nodes(other, self._upstream)
+        if self._upstream and self._downstream:
+            # Convert _upstream and _downstream to task_groups only after both are set
+            self._convert_streams_to_task_groups()
+        for node in self._downstream:
+            node.set_upstream(other, edge_modifier=self)
 
     def set_downstream(
-        self, task_or_task_list: DependencyMixin | Sequence[DependencyMixin], chain: bool = True
+        self,
+        other: DependencyMixin | Sequence[DependencyMixin],
+        edge_modifier: EdgeModifier | None = None,
     ):
         """
         Sets the given task/list onto the downstream attribute, and then checks if
@@ -95,36 +144,24 @@ class EdgeModifier(DependencyMixin):
 
         Providing this also provides >> via DependencyMixin.
         """
-        from airflow.models.baseoperator import BaseOperator
-
-        # Ensure we have a list, even if it's just one item
-        if isinstance(task_or_task_list, DependencyMixin):
-            task_or_task_list = [task_or_task_list]
-        # Unfurl it into actual operators
-        operators: list[BaseOperator] = []
-        for task in task_or_task_list:
-            for leaf in task.leaves:
-                if not isinstance(leaf, BaseOperator):
-                    raise TypeError(f"Cannot use edge labels with {type(leaf).__name__}, only operators")
-                operators.append(leaf)
-        # Pair them off with existing
-        for operator in operators:
-            for upstream in self._upstream:
-                self.add_edge_info(upstream.dag, upstream.task_id, operator.task_id)
-                if chain:
-                    upstream.set_downstream(operator)
-        # Add the new tasks to our list of ones we've seen
-        self._downstream.extend(operators)
-
-    def update_relative(self, other: DependencyMixin, upstream: bool = True) -> None:
+        self._save_nodes(other, self._downstream)
+        if self._upstream and self._downstream:
+            # Convert _upstream and _downstream to task_groups only after both are set
+            self._convert_streams_to_task_groups()
+        for node in self._upstream:
+            node.set_downstream(other, edge_modifier=self)
+
+    def update_relative(
+        self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
+    ) -> None:
         """
         Called if we're not the "main" side of a relationship; we still run the
         same logic, though.
         """
         if upstream:
-            self.set_upstream(other, chain=False)
+            self.set_upstream(other)
         else:
-            self.set_downstream(other, chain=False)
+            self.set_downstream(other)
 
     def add_edge_info(self, dag, upstream_id: str, downstream_id: str):
         """
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index c9635c40c8..eb63f831c0 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -253,7 +253,9 @@ class TaskGroup(DAGNode):
         """group_id excluding parent's group_id used as the node label in UI."""
         return self._group_id
 
-    def update_relative(self, other: DependencyMixin, upstream=True) -> None:
+    def update_relative(
+        self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
+    ) -> None:
         """
         Overrides TaskMixin.update_relative.
 
@@ -264,8 +266,12 @@ class TaskGroup(DAGNode):
             # Handles setting relationship between a TaskGroup and another TaskGroup
             if upstream:
                 parent, child = (self, other)
+                if edge_modifier:
+                    edge_modifier.add_edge_info(self.dag, other.downstream_join_id, self.upstream_join_id)
             else:
                 parent, child = (other, self)
+                if edge_modifier:
+                    edge_modifier.add_edge_info(self.dag, self.downstream_join_id, other.upstream_join_id)
 
             parent.upstream_group_ids.add(child.group_id)
             child.downstream_group_ids.add(parent.group_id)
@@ -278,10 +284,18 @@ class TaskGroup(DAGNode):
                         f"or operators; received {task.__class__.__name__}"
                     )
 
+                # Do not set a relationship between a TaskGroup and a Label's roots
+                if self == task:
+                    continue
+
                 if upstream:
                     self.upstream_task_ids.add(task.node_id)
+                    if edge_modifier:
+                        edge_modifier.add_edge_info(self.dag, task.node_id, self.upstream_join_id)
                 else:
                     self.downstream_task_ids.add(task.node_id)
+                    if edge_modifier:
+                        edge_modifier.add_edge_info(self.dag, self.downstream_join_id, task.node_id)
 
     def _set_relatives(
         self,
@@ -297,7 +311,7 @@ class TaskGroup(DAGNode):
             task_or_task_list = [task_or_task_list]
 
         for task_like in task_or_task_list:
-            self.update_relative(task_like, upstream)
+            self.update_relative(task_like, upstream, edge_modifier=edge_modifier)
 
         if upstream:
             for task in self.get_roots():
diff --git a/tests/utils/test_edgemodifier.py b/tests/utils/test_edgemodifier.py
index 97097679ce..4d92f075b8 100644
--- a/tests/utils/test_edgemodifier.py
+++ b/tests/utils/test_edgemodifier.py
@@ -22,9 +22,11 @@ import pytest
 
 from airflow import DAG
 from airflow.models.xcom_arg import XComArg
+from airflow.operators.empty import EmptyOperator
 from airflow.operators.python import PythonOperator
 from airflow.utils.edgemodifier import Label
 from airflow.utils.task_group import TaskGroup
+from airflow.www.views import dag_edges
 
 DEFAULT_ARGS = {
     "owner": "test",
@@ -63,6 +65,283 @@ def test_taskgroup_dag():
             return dag, group, (op1, op2, op3, op4)
 
 
+@pytest.fixture
+def test_complex_taskgroup_dag():
+    """Creates a test DAG with many operators and a task group."""
+
+    def f(task_id):
+        return f"OP:{task_id}"
+
+    with DAG(dag_id="test_complex_dag", default_args=DEFAULT_ARGS) as dag:
+        with TaskGroup("group_1") as group:
+            group_emp1 = EmptyOperator(task_id="group_empty1")
+            group_emp2 = EmptyOperator(task_id="group_empty2")
+            group_emp3 = EmptyOperator(task_id="group_empty3")
+        emp_in1 = EmptyOperator(task_id="empty_in1")
+        emp_in2 = EmptyOperator(task_id="empty_in2")
+        emp_in3 = EmptyOperator(task_id="empty_in3")
+        emp_in4 = EmptyOperator(task_id="empty_in4")
+        emp_out1 = EmptyOperator(task_id="empty_out1")
+        emp_out2 = EmptyOperator(task_id="empty_out2")
+        emp_out3 = EmptyOperator(task_id="empty_out3")
+        emp_out4 = EmptyOperator(task_id="empty_out4")
+        op_in1 = PythonOperator(python_callable=f, task_id="op_in1")
+        op_out1 = PythonOperator(python_callable=f, task_id="op_out1")
+
+        return (
+            dag,
+            group,
+            (
+                group_emp1,
+                group_emp2,
+                group_emp3,
+                emp_in1,
+                emp_in2,
+                emp_in3,
+                emp_in4,
+                emp_out1,
+                emp_out2,
+                emp_out3,
+                emp_out4,
+                op_in1,
+                op_out1,
+            ),
+        )
+
+
+@pytest.fixture
+def test_multiple_taskgroups_dag():
+    """Creates a test DAG with many operators and multiple task groups."""
+
+    def f(task_id):
+        return f"OP:{task_id}"
+
+    with DAG(dag_id="test_multiple_task_group_dag", default_args=DEFAULT_ARGS) as dag:
+        with TaskGroup("group1") as group1:
+            group1_emp1 = EmptyOperator(task_id="group1_empty1")
+            group1_emp2 = EmptyOperator(task_id="group1_empty2")
+            group1_emp3 = EmptyOperator(task_id="group1_empty3")
+        with TaskGroup("group2") as group2:
+            group2_emp1 = EmptyOperator(task_id="group2_empty1")
+            group2_emp2 = EmptyOperator(task_id="group2_empty2")
+            group2_emp3 = EmptyOperator(task_id="group2_empty3")
+            group2_op1 = PythonOperator(python_callable=f, task_id="group2_op1")
+            group2_op2 = PythonOperator(python_callable=f, task_id="group2_op2")
+
+            with TaskGroup("group3") as group3:
+                group3_emp1 = EmptyOperator(task_id="group3_empty1")
+                group3_emp2 = EmptyOperator(task_id="group3_empty2")
+                group3_emp3 = EmptyOperator(task_id="group3_empty3")
+        emp_in1 = EmptyOperator(task_id="empty_in1")
+        emp_in2 = EmptyOperator(task_id="empty_in2")
+        emp_in3 = EmptyOperator(task_id="empty_in3")
+        emp_in4 = EmptyOperator(task_id="empty_in4")
+        emp_out1 = EmptyOperator(task_id="empty_out1")
+        emp_out2 = EmptyOperator(task_id="empty_out2")
+        emp_out3 = EmptyOperator(task_id="empty_out3")
+        emp_out4 = EmptyOperator(task_id="empty_out4")
+        op_in1 = PythonOperator(python_callable=f, task_id="op_in1")
+        op_out1 = PythonOperator(python_callable=f, task_id="op_out1")
+
+        return (
+            dag,
+            group1,
+            group2,
+            group3,
+            (
+                group1_emp1,
+                group1_emp2,
+                group1_emp3,
+                group2_emp1,
+                group2_emp2,
+                group2_emp3,
+                group2_op1,
+                group2_op2,
+                group3_emp1,
+                group3_emp2,
+                group3_emp3,
+                emp_in1,
+                emp_in2,
+                emp_in3,
+                emp_in4,
+                emp_out1,
+                emp_out2,
+                emp_out3,
+                emp_out4,
+                op_in1,
+                op_out1,
+            ),
+        )
+
+
+@pytest.fixture
+def simple_dag_expected_edges():
+    return [
+        {"source_id": "group_1.downstream_join_id", "target_id": "test_op_4"},
+        {"source_id": "group_1.test_op_2", "target_id": "group_1.downstream_join_id"},
+        {"source_id": "group_1.test_op_3", "target_id": "group_1.downstream_join_id"},
+        {"source_id": "group_1.upstream_join_id", "target_id": "group_1.test_op_2"},
+        {"source_id": "group_1.upstream_join_id", "target_id": "group_1.test_op_3"},
+        {"label": "Label", "source_id": "test_op_1", "target_id": "group_1.upstream_join_id"},
+    ]
+
+
+@pytest.fixture
+def complex_dag_expected_edges():
+    return [
+        {"source_id": "empty_in1", "target_id": "group_1.upstream_join_id"},
+        {
+            "label": "label emp_in2 <=> group",
+            "source_id": "empty_in2",
+            "target_id": "group_1.upstream_join_id",
+        },
+        {
+            "label": "label emp_in3/emp_in4 <=> group",
+            "source_id": "empty_in3",
+            "target_id": "group_1.upstream_join_id",
+        },
+        {
+            "label": "label emp_in3/emp_in4 <=> group",
+            "source_id": "empty_in4",
+            "target_id": "group_1.upstream_join_id",
+        },
+        {"source_id": "group_1.downstream_join_id", "target_id": "empty_out1"},
+        {
+            "label": "label group <=> emp_out2",
+            "source_id": "group_1.downstream_join_id",
+            "target_id": "empty_out2",
+        },
+        {
+            "label": "label group <=> emp_out3/emp_out4",
+            "source_id": "group_1.downstream_join_id",
+            "target_id": "empty_out3",
+        },
+        {
+            "label": "label group <=> emp_out3/emp_out4",
+            "source_id": "group_1.downstream_join_id",
+            "target_id": "empty_out4",
+        },
+        {
+            "label": "label group <=> op_out1",
+            "source_id": "group_1.downstream_join_id",
+            "target_id": "op_out1",
+        },
+        {"source_id": "group_1.group_empty1", "target_id": "group_1.downstream_join_id"},
+        {"source_id": "group_1.group_empty2", "target_id": "group_1.group_empty1"},
+        {"source_id": "group_1.group_empty3", "target_id": "group_1.group_empty1"},
+        {"source_id": "group_1.upstream_join_id", "target_id": "group_1.group_empty2"},
+        {"source_id": "group_1.upstream_join_id", "target_id": "group_1.group_empty3"},
+        {
+            "label": "label op_in1 <=> group",
+            "source_id": "op_in1",
+            "target_id": "group_1.upstream_join_id",
+        },
+    ]
+
+
+@pytest.fixture
+def multiple_taskgroups_dag_expected_edges():
+    return [
+        {"source_id": "empty_in1", "target_id": "group1.upstream_join_id"},
+        {
+            "label": "label emp_in2 <=> group1",
+            "source_id": "empty_in2",
+            "target_id": "group1.upstream_join_id",
+        },
+        {
+            "label": "label emp_in3/emp_in4 <=> group1",
+            "source_id": "empty_in3",
+            "target_id": "group1.upstream_join_id",
+        },
+        {
+            "label": "label emp_in3/emp_in4 <=> group1",
+            "source_id": "empty_in4",
+            "target_id": "group1.upstream_join_id",
+        },
+        {
+            "label": "label group1 <=> group2",
+            "source_id": "group1.downstream_join_id",
+            "target_id": "group2.upstream_join_id",
+        },
+        {
+            "label": "label group1.group1_emp1 <=> group1.group1_emp2",
+            "source_id": "group1.group1_empty1",
+            "target_id": "group1.group1_empty3",
+        },
+        {"source_id": "group1.group1_empty2", "target_id": "group1.downstream_join_id"},
+        {"source_id": "group1.group1_empty3", "target_id": "group1.downstream_join_id"},
+        {"source_id": "group1.upstream_join_id", "target_id": "group1.group1_empty1"},
+        {"source_id": "group1.upstream_join_id", "target_id": "group1.group1_empty2"},
+        {
+            "label": "label group2.group2_emp1 <=> group2.group2_emp2/group2.group2_emp3",
+            "source_id": "group2.group2_empty1",
+            "target_id": "group2.group2_empty2",
+        },
+        {
+            "label": "label group2.group2_emp1 <=> group2.group2_emp2/group2.group2_emp3",
+            "source_id": "group2.group2_empty1",
+            "target_id": "group2.group2_empty3",
+        },
+        {
+            "label": "label group2.group2_emp1/group2.group2_emp2 <=> group2.group2_emp3",
+            "source_id": "group2.group2_empty2",
+            "target_id": "group2.group2_empty3",
+        },
+        {
+            "label": "label group2.group2_emp3 <=> group3",
+            "source_id": "group2.group2_empty3",
+            "target_id": "group2.group3.upstream_join_id",
+        },
+        {
+            "label": "label group2.group2_op1 <=> group2.group2_op2",
+            "source_id": "group2.group2_op1",
+            "target_id": "group2.group2_op2",
+        },
+        {
+            "label": "label group2.group2_op2 <=> group3",
+            "source_id": "group2.group2_op2",
+            "target_id": "group2.group3.upstream_join_id",
+        },
+        {"source_id": "group2.group3.downstream_join_id", "target_id": "empty_out1"},
+        {
+            "label": "label group3 <=> emp_out2",
+            "source_id": "group2.group3.downstream_join_id",
+            "target_id": "empty_out2",
+        },
+        {
+            "label": "label group3 <=> emp_out3/emp_out4",
+            "source_id": "group2.group3.downstream_join_id",
+            "target_id": "empty_out3",
+        },
+        {
+            "label": "label group3 <=> emp_out3/emp_out4",
+            "source_id": "group2.group3.downstream_join_id",
+            "target_id": "empty_out4",
+        },
+        {
+            "label": "label group3 <=> op_out1",
+            "source_id": "group2.group3.downstream_join_id",
+            "target_id": "op_out1",
+        },
+        {"source_id": "group2.group3.group3_empty1", "target_id": "group2.group3.downstream_join_id"},
+        {"source_id": "group2.group3.group3_empty2", "target_id": "group2.group3.downstream_join_id"},
+        {"source_id": "group2.group3.group3_empty3", "target_id": "group2.group3.downstream_join_id"},
+        {"source_id": "group2.group3.upstream_join_id", "target_id": "group2.group3.group3_empty1"},
+        {"source_id": "group2.group3.upstream_join_id", "target_id": "group2.group3.group3_empty2"},
+        {"source_id": "group2.group3.upstream_join_id", "target_id": "group2.group3.group3_empty3"},
+        {"source_id": "group2.upstream_join_id", "target_id": "group2.group2_empty1"},
+        {"source_id": "group2.upstream_join_id", "target_id": "group2.group2_op1"},
+        {"label": "label op_in1 <=> group1", "source_id": "op_in1", "target_id": "group1.upstream_join_id"},
+    ]
+
+
+def compare_dag_edges(current, expected):
+    assert len(current) == len(expected)
+
+    for i in current:
+        assert current.count(i) == expected.count(i), f"The unexpected DAG edge: {i}"
+
+
 class TestEdgeModifierBuilding:
     """
     Tests that EdgeModifiers work when composed with Tasks (either via >>
@@ -163,3 +442,225 @@ class TestEdgeModifierBuilding:
         assert dag.get_edge_info(op1.task_id, op2.task_id) == {"label": "Group label"}
         assert dag.get_edge_info(op1.task_id, op3.task_id) == {"label": "Group label"}
         assert dag.get_edge_info(op3.task_id, op4.task_id) == {}
+
+    def test_simple_dag(self, test_taskgroup_dag, simple_dag_expected_edges):
+        """Tests the simple dag with a TaskGroup and a Label"""
+        dag, group, (op1, op2, op3, op4) = test_taskgroup_dag
+        op1 >> Label("Label") >> group >> op4
+        compare_dag_edges(dag_edges(dag), simple_dag_expected_edges)
+
+    def test_simple_reversed_dag(self, test_taskgroup_dag, simple_dag_expected_edges):
+        """Tests the simple reversed dag with a TaskGroup and a Label"""
+        dag, group, (op1, op2, op3, op4) = test_taskgroup_dag
+        op4 << group << Label("Label") << op1
+        compare_dag_edges(dag_edges(dag), simple_dag_expected_edges)
+
+    def test_complex_dag(self, test_complex_taskgroup_dag, complex_dag_expected_edges):
+        """Tests the complex dag with a TaskGroup and a Label"""
+        (
+            dag,
+            group,
+            (
+                group_emp1,
+                group_emp2,
+                group_emp3,
+                emp_in1,
+                emp_in2,
+                emp_in3,
+                emp_in4,
+                emp_out1,
+                emp_out2,
+                emp_out3,
+                emp_out4,
+                op_in1,
+                op_out1,
+            ),
+        ) = test_complex_taskgroup_dag
+
+        [group_emp2, group_emp3] >> group_emp1
+
+        emp_in1 >> group
+        emp_in2 >> Label("label emp_in2 <=> group") >> group
+        [emp_in3, emp_in4] >> Label("label emp_in3/emp_in4 <=> group") >> group
+        XComArg(op_in1, "test_key") >> Label("label op_in1 <=> group") >> group
+
+        group >> emp_out1
+        group >> Label("label group <=> emp_out2") >> emp_out2
+        group >> Label("label group <=> emp_out3/emp_out4") >> [emp_out3, emp_out4]
+        group >> Label("label group <=> op_out1") >> XComArg(op_out1, "test_key")
+
+        compare_dag_edges(dag_edges(dag), complex_dag_expected_edges)
+
+    def test_complex_reversed_dag(self, test_complex_taskgroup_dag, complex_dag_expected_edges):
+        """Tests the complex reversed dag with a TaskGroup and a Label"""
+        (
+            dag,
+            group,
+            (
+                group_emp1,
+                group_emp2,
+                group_emp3,
+                emp_in1,
+                emp_in2,
+                emp_in3,
+                emp_in4,
+                emp_out1,
+                emp_out2,
+                emp_out3,
+                emp_out4,
+                op_in1,
+                op_out1,
+            ),
+        ) = test_complex_taskgroup_dag
+
+        group_emp1 << [group_emp2, group_emp3]
+
+        group << emp_in1
+        group << Label("label emp_in2 <=> group") << emp_in2
+        group << Label("label emp_in3/emp_in4 <=> group") << [emp_in3, emp_in4]
+        group << Label("label op_in1 <=> group") << XComArg(op_in1, "test_key")
+
+        emp_out1 << group
+        emp_out2 << Label("label group <=> emp_out2") << group
+        [emp_out3, emp_out4] << Label("label group <=> emp_out3/emp_out4") << group
+        XComArg(op_out1, "test_key") << Label("label group <=> op_out1") << group
+
+        compare_dag_edges(dag_edges(dag), complex_dag_expected_edges)
+
+    def test_multiple_task_groups_dag(
+        self, test_multiple_taskgroups_dag, multiple_taskgroups_dag_expected_edges
+    ):
+        """Tests multiple task groups and labels"""
+        (
+            dag,
+            group1,
+            group2,
+            group3,
+            (
+                group1_emp1,
+                group1_emp2,
+                group1_emp3,
+                group2_emp1,
+                group2_emp2,
+                group2_emp3,
+                group2_op1,
+                group2_op2,
+                group3_emp1,
+                group3_emp2,
+                group3_emp3,
+                emp_in1,
+                emp_in2,
+                emp_in3,
+                emp_in4,
+                emp_out1,
+                emp_out2,
+                emp_out3,
+                emp_out4,
+                op_in1,
+                op_out1,
+            ),
+        ) = test_multiple_taskgroups_dag
+
+        group1_emp1 >> Label("label group1.group1_emp1 <=> group1.group1_emp2") >> group1_emp3
+
+        emp_in1 >> group1
+        emp_in2 >> Label("label emp_in2 <=> group1") >> group1
+        [emp_in3, emp_in4] >> Label("label emp_in3/emp_in4 <=> group1") >> group1
+        XComArg(op_in1, "test_key") >> Label("label op_in1 <=> group1") >> group1
+
+        (
+            [group2_emp1, group2_emp2]
+            >> Label("label group2.group2_emp1/group2.group2_emp2 <=> group2.group2_emp3")
+            >> group2_emp3
+        )
+        (
+            group2_emp1
+            >> Label("label group2.group2_emp1 <=> group2.group2_emp2/group2.group2_emp3")
+            >> [group2_emp2, group2_emp3]
+        )
+        group2_emp3 >> Label("label group2.group2_emp3 <=> group3") >> group3
+
+        (
+            XComArg(group2_op1, "test_key")
+            >> Label("label group2.group2_op1 <=> group2.group2_op2")
+            >> XComArg(group2_op2, "test_key")
+        )
+        XComArg(group2_op2, "test_key") >> Label("label group2.group2_op2 <=> group3") >> group3
+
+        group3 >> emp_out1
+        group3 >> Label("label group3 <=> emp_out2") >> emp_out2
+        group3 >> Label("label group3 <=> emp_out3/emp_out4") >> [emp_out3, emp_out4]
+        group3 >> Label("label group3 <=> op_out1") >> XComArg(op_out1, "test_key")
+
+        group1 >> Label("label group1 <=> group2") >> group2
+
+        compare_dag_edges(dag_edges(dag), multiple_taskgroups_dag_expected_edges)
+
+    def test_multiple_task_groups_reversed_dag(
+        self, test_multiple_taskgroups_dag, multiple_taskgroups_dag_expected_edges
+    ):
+        """Tests multiple task groups and labels"""
+        (
+            dag,
+            group1,
+            group2,
+            group3,
+            (
+                group1_emp1,
+                group1_emp2,
+                group1_emp3,
+                group2_emp1,
+                group2_emp2,
+                group2_emp3,
+                group2_op1,
+                group2_op2,
+                group3_emp1,
+                group3_emp2,
+                group3_emp3,
+                emp_in1,
+                emp_in2,
+                emp_in3,
+                emp_in4,
+                emp_out1,
+                emp_out2,
+                emp_out3,
+                emp_out4,
+                op_in1,
+                op_out1,
+            ),
+        ) = test_multiple_taskgroups_dag
+
+        group1_emp3 << Label("label group1.group1_emp1 <=> group1.group1_emp2") << group1_emp1
+
+        group1 << emp_in1
+        group1 << Label("label emp_in2 <=> group1") << emp_in2
+        group1 << Label("label emp_in3/emp_in4 <=> group1") << [emp_in3, emp_in4]
+        group1 << Label("label op_in1 <=> group1") << XComArg(op_in1, "test_key")
+
+        (
+            group2_emp3
+            << Label("label group2.group2_emp1/group2.group2_emp2 <=> group2.group2_emp3")
+            << [group2_emp1, group2_emp2]
+        )
+        (
+            [group2_emp2, group2_emp3]
+            << Label("label group2.group2_emp1 <=> group2.group2_emp2/group2.group2_emp3")
+            << group2_emp1
+        )
+        group3 << Label("label group2.group2_emp3 <=> group3") << group2_emp3
+
+        (
+            XComArg(group2_op2, "test_key")
+            << Label("label group2.group2_op1 <=> group2.group2_op2")
+            << XComArg(group2_op1, "test_key")
+        )
+        group3 << Label("label group2.group2_op2 <=> group3") << XComArg(group2_op2, "test_key")
+
+        emp_out1 << group3
+        emp_out2 << Label("label group3 <=> emp_out2") << group3
+        [emp_out3, emp_out4] << Label("label group3 <=> emp_out3/emp_out4") << group3
+        XComArg(op_out1, "test_key") << Label("label group3 <=> op_out1") << group3
+
+        group2 << Label("label group1 <=> group2") << group1
+
+        compare_dag_edges(dag_edges(dag), multiple_taskgroups_dag_expected_edges)