You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by el...@apache.org on 2023/02/18 17:57:52 UTC
[airflow] branch main updated: Edgemodifier refactoring w/ labels in TaskGroup edge case (#29410)
This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4b05468129 Edgemodifier refactoring w/ labels in TaskGroup edge case (#29410)
4b05468129 is described below
commit 4b05468129361946688909943fe332f383302069
Author: Victor Chiapaikeo <vc...@gmail.com>
AuthorDate: Sat Feb 18 12:57:44 2023 -0500
Edgemodifier refactoring w/ labels in TaskGroup edge case (#29410)
* Edgemodifier refactoring w/ labels in TaskGroup edge case
---
airflow/models/taskmixin.py | 14 +-
airflow/utils/edgemodifier.py | 145 ++++++-----
airflow/utils/task_group.py | 18 +-
tests/utils/test_edgemodifier.py | 501 +++++++++++++++++++++++++++++++++++++++
4 files changed, 618 insertions(+), 60 deletions(-)
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index 211fad6ff9..d93879498b 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -56,16 +56,22 @@ class DependencyMixin:
raise NotImplementedError()
@abstractmethod
- def set_upstream(self, other: DependencyMixin | Sequence[DependencyMixin]):
+ def set_upstream(
+ self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None
+ ):
"""Set a task or a task list to be directly upstream from the current task."""
raise NotImplementedError()
@abstractmethod
- def set_downstream(self, other: DependencyMixin | Sequence[DependencyMixin]):
+ def set_downstream(
+ self, other: DependencyMixin | Sequence[DependencyMixin], edge_modifier: EdgeModifier | None = None
+ ):
"""Set a task or a task list to be directly downstream from the current task."""
raise NotImplementedError()
- def update_relative(self, other: DependencyMixin, upstream=True) -> None:
+ def update_relative(
+ self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
+ ) -> None:
"""
Update relationship information about another TaskMixin. Default is no-op.
Override if necessary.
@@ -172,7 +178,7 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
task_list: list[Operator] = []
for task_object in task_or_task_list:
- task_object.update_relative(self, not upstream)
+ task_object.update_relative(self, not upstream, edge_modifier=edge_modifier)
relatives = task_object.leaves if upstream else task_object.roots
for task in relatives:
if not isinstance(task, (BaseOperator, MappedOperator)):
diff --git a/airflow/utils/edgemodifier.py b/airflow/utils/edgemodifier.py
index d2135cf233..b693e1a1be 100644
--- a/airflow/utils/edgemodifier.py
+++ b/airflow/utils/edgemodifier.py
@@ -16,12 +16,10 @@
# under the License.
from __future__ import annotations
-from typing import TYPE_CHECKING, Sequence
+from typing import Sequence
-from airflow.models.taskmixin import DependencyMixin
-
-if TYPE_CHECKING:
- from airflow.models.baseoperator import BaseOperator
+from airflow.models.taskmixin import DAGNode, DependencyMixin
+from airflow.utils.task_group import TaskGroup
class EdgeModifier(DependencyMixin):
@@ -44,8 +42,8 @@ class EdgeModifier(DependencyMixin):
def __init__(self, label: str | None = None):
self.label = label
- self._upstream: list[BaseOperator] = []
- self._downstream: list[BaseOperator] = []
+ self._upstream: list[DependencyMixin] = []
+ self._downstream: list[DependencyMixin] = []
@property
def roots(self):
@@ -55,8 +53,72 @@ class EdgeModifier(DependencyMixin):
def leaves(self):
return self._upstream
+ @staticmethod
+ def _make_list(item_or_list: DependencyMixin | Sequence[DependencyMixin]) -> Sequence[DependencyMixin]:
+ if not isinstance(item_or_list, Sequence):
+ return [item_or_list]
+ return item_or_list
+
+ def _save_nodes(
+ self,
+ nodes: DependencyMixin | Sequence[DependencyMixin],
+ stream: list[DependencyMixin],
+ ):
+ from airflow.models.xcom_arg import XComArg
+
+ for node in self._make_list(nodes):
+ if isinstance(node, (TaskGroup, XComArg, DAGNode)):
+ stream.append(node)
+ else:
+ raise TypeError(
+ f"Cannot use edge labels with {type(node).__name__}, "
+ f"only tasks, XComArg or TaskGroups"
+ )
+
+ def _convert_streams_to_task_groups(self):
+ """
+ Both self._upstream and self._downstream are required to determine if
+ we should convert a node to a TaskGroup or leave it as a DAGNode.
+
+ To do this, we keep a set of group_ids seen among the streams. If we find that
+ the nodes are from the same TaskGroup, we will leave them as DAGNodes and not
+ convert them to TaskGroups.
+ """
+ from airflow.models.xcom_arg import XComArg
+
+ group_ids = set()
+ for node in [*self._upstream, *self._downstream]:
+ if isinstance(node, DAGNode) and node.task_group:
+ if node.task_group.is_root:
+ group_ids.add("root")
+ else:
+ group_ids.add(node.task_group.group_id)
+ elif isinstance(node, TaskGroup):
+ group_ids.add(node.group_id)
+ elif isinstance(node, XComArg):
+ if isinstance(node.operator, DAGNode) and node.operator.task_group:
+ if node.operator.task_group.is_root:
+ group_ids.add("root")
+ else:
+ group_ids.add(node.operator.task_group.group_id)
+
+ # If all nodes originate from the same TaskGroup, we will not convert them
+ if len(group_ids) != 1:
+ self._upstream = self._convert_stream_to_task_groups(self._upstream)
+ self._downstream = self._convert_stream_to_task_groups(self._downstream)
+
+ def _convert_stream_to_task_groups(self, stream: Sequence[DependencyMixin]) -> Sequence[DependencyMixin]:
+ return [
+ node.task_group
+ if isinstance(node, DAGNode) and node.task_group and not node.task_group.is_root
+ else node
+ for node in stream
+ ]
+
def set_upstream(
- self, task_or_task_list: DependencyMixin | Sequence[DependencyMixin], chain: bool = True
+ self,
+ other: DependencyMixin | Sequence[DependencyMixin],
+ edge_modifier: EdgeModifier | None = None,
):
"""
Sets the given task/list onto the upstream attribute, and then checks if
@@ -64,30 +126,17 @@ class EdgeModifier(DependencyMixin):
Providing this also provides << via DependencyMixin.
"""
- from airflow.models.baseoperator import BaseOperator
-
- # Ensure we have a list, even if it's just one item
- if isinstance(task_or_task_list, DependencyMixin):
- task_or_task_list = [task_or_task_list]
- # Unfurl it into actual operators
- operators: list[BaseOperator] = []
- for task in task_or_task_list:
- for root in task.roots:
- if not isinstance(root, BaseOperator):
- raise TypeError(f"Cannot use edge labels with {type(root).__name__}, only operators")
- operators.append(root)
- # For each already-declared downstream, pair off with each new upstream
- # item and store the edge info.
- for operator in operators:
- for downstream in self._downstream:
- self.add_edge_info(operator.dag, operator.task_id, downstream.task_id)
- if chain:
- operator.set_downstream(downstream)
- # Add the new tasks to our list of ones we've seen
- self._upstream.extend(operators)
+ self._save_nodes(other, self._upstream)
+ if self._upstream and self._downstream:
+ # Convert _upstream and _downstream to task_groups only after both are set
+ self._convert_streams_to_task_groups()
+ for node in self._downstream:
+ node.set_upstream(other, edge_modifier=self)
def set_downstream(
- self, task_or_task_list: DependencyMixin | Sequence[DependencyMixin], chain: bool = True
+ self,
+ other: DependencyMixin | Sequence[DependencyMixin],
+ edge_modifier: EdgeModifier | None = None,
):
"""
Sets the given task/list onto the downstream attribute, and then checks if
@@ -95,36 +144,24 @@ class EdgeModifier(DependencyMixin):
Providing this also provides >> via DependencyMixin.
"""
- from airflow.models.baseoperator import BaseOperator
-
- # Ensure we have a list, even if it's just one item
- if isinstance(task_or_task_list, DependencyMixin):
- task_or_task_list = [task_or_task_list]
- # Unfurl it into actual operators
- operators: list[BaseOperator] = []
- for task in task_or_task_list:
- for leaf in task.leaves:
- if not isinstance(leaf, BaseOperator):
- raise TypeError(f"Cannot use edge labels with {type(leaf).__name__}, only operators")
- operators.append(leaf)
- # Pair them off with existing
- for operator in operators:
- for upstream in self._upstream:
- self.add_edge_info(upstream.dag, upstream.task_id, operator.task_id)
- if chain:
- upstream.set_downstream(operator)
- # Add the new tasks to our list of ones we've seen
- self._downstream.extend(operators)
-
- def update_relative(self, other: DependencyMixin, upstream: bool = True) -> None:
+ self._save_nodes(other, self._downstream)
+ if self._upstream and self._downstream:
+ # Convert _upstream and _downstream to task_groups only after both are set
+ self._convert_streams_to_task_groups()
+ for node in self._upstream:
+ node.set_downstream(other, edge_modifier=self)
+
+ def update_relative(
+ self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
+ ) -> None:
"""
Called if we're not the "main" side of a relationship; we still run the
same logic, though.
"""
if upstream:
- self.set_upstream(other, chain=False)
+ self.set_upstream(other)
else:
- self.set_downstream(other, chain=False)
+ self.set_downstream(other)
def add_edge_info(self, dag, upstream_id: str, downstream_id: str):
"""
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index c9635c40c8..eb63f831c0 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -253,7 +253,9 @@ class TaskGroup(DAGNode):
"""group_id excluding parent's group_id used as the node label in UI."""
return self._group_id
- def update_relative(self, other: DependencyMixin, upstream=True) -> None:
+ def update_relative(
+ self, other: DependencyMixin, upstream: bool = True, edge_modifier: EdgeModifier | None = None
+ ) -> None:
"""
Overrides TaskMixin.update_relative.
@@ -264,8 +266,12 @@ class TaskGroup(DAGNode):
# Handles setting relationship between a TaskGroup and another TaskGroup
if upstream:
parent, child = (self, other)
+ if edge_modifier:
+ edge_modifier.add_edge_info(self.dag, other.downstream_join_id, self.upstream_join_id)
else:
parent, child = (other, self)
+ if edge_modifier:
+ edge_modifier.add_edge_info(self.dag, self.downstream_join_id, other.upstream_join_id)
parent.upstream_group_ids.add(child.group_id)
child.downstream_group_ids.add(parent.group_id)
@@ -278,10 +284,18 @@ class TaskGroup(DAGNode):
f"or operators; received {task.__class__.__name__}"
)
+ # Do not set a relationship between a TaskGroup and a Label's roots
+ if self == task:
+ continue
+
if upstream:
self.upstream_task_ids.add(task.node_id)
+ if edge_modifier:
+ edge_modifier.add_edge_info(self.dag, task.node_id, self.upstream_join_id)
else:
self.downstream_task_ids.add(task.node_id)
+ if edge_modifier:
+ edge_modifier.add_edge_info(self.dag, self.downstream_join_id, task.node_id)
def _set_relatives(
self,
@@ -297,7 +311,7 @@ class TaskGroup(DAGNode):
task_or_task_list = [task_or_task_list]
for task_like in task_or_task_list:
- self.update_relative(task_like, upstream)
+ self.update_relative(task_like, upstream, edge_modifier=edge_modifier)
if upstream:
for task in self.get_roots():
diff --git a/tests/utils/test_edgemodifier.py b/tests/utils/test_edgemodifier.py
index 97097679ce..4d92f075b8 100644
--- a/tests/utils/test_edgemodifier.py
+++ b/tests/utils/test_edgemodifier.py
@@ -22,9 +22,11 @@ import pytest
from airflow import DAG
from airflow.models.xcom_arg import XComArg
+from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
from airflow.utils.edgemodifier import Label
from airflow.utils.task_group import TaskGroup
+from airflow.www.views import dag_edges
DEFAULT_ARGS = {
"owner": "test",
@@ -63,6 +65,283 @@ def test_taskgroup_dag():
return dag, group, (op1, op2, op3, op4)
+@pytest.fixture
+def test_complex_taskgroup_dag():
+ """Creates a test DAG with many operators and a task group."""
+
+ def f(task_id):
+ return f"OP:{task_id}"
+
+ with DAG(dag_id="test_complex_dag", default_args=DEFAULT_ARGS) as dag:
+ with TaskGroup("group_1") as group:
+ group_emp1 = EmptyOperator(task_id="group_empty1")
+ group_emp2 = EmptyOperator(task_id="group_empty2")
+ group_emp3 = EmptyOperator(task_id="group_empty3")
+ emp_in1 = EmptyOperator(task_id="empty_in1")
+ emp_in2 = EmptyOperator(task_id="empty_in2")
+ emp_in3 = EmptyOperator(task_id="empty_in3")
+ emp_in4 = EmptyOperator(task_id="empty_in4")
+ emp_out1 = EmptyOperator(task_id="empty_out1")
+ emp_out2 = EmptyOperator(task_id="empty_out2")
+ emp_out3 = EmptyOperator(task_id="empty_out3")
+ emp_out4 = EmptyOperator(task_id="empty_out4")
+ op_in1 = PythonOperator(python_callable=f, task_id="op_in1")
+ op_out1 = PythonOperator(python_callable=f, task_id="op_out1")
+
+ return (
+ dag,
+ group,
+ (
+ group_emp1,
+ group_emp2,
+ group_emp3,
+ emp_in1,
+ emp_in2,
+ emp_in3,
+ emp_in4,
+ emp_out1,
+ emp_out2,
+ emp_out3,
+ emp_out4,
+ op_in1,
+ op_out1,
+ ),
+ )
+
+
+@pytest.fixture
+def test_multiple_taskgroups_dag():
+ """Creates a test DAG with many operators and multiple task groups."""
+
+ def f(task_id):
+ return f"OP:{task_id}"
+
+ with DAG(dag_id="test_multiple_task_group_dag", default_args=DEFAULT_ARGS) as dag:
+ with TaskGroup("group1") as group1:
+ group1_emp1 = EmptyOperator(task_id="group1_empty1")
+ group1_emp2 = EmptyOperator(task_id="group1_empty2")
+ group1_emp3 = EmptyOperator(task_id="group1_empty3")
+ with TaskGroup("group2") as group2:
+ group2_emp1 = EmptyOperator(task_id="group2_empty1")
+ group2_emp2 = EmptyOperator(task_id="group2_empty2")
+ group2_emp3 = EmptyOperator(task_id="group2_empty3")
+ group2_op1 = PythonOperator(python_callable=f, task_id="group2_op1")
+ group2_op2 = PythonOperator(python_callable=f, task_id="group2_op2")
+
+ with TaskGroup("group3") as group3:
+ group3_emp1 = EmptyOperator(task_id="group3_empty1")
+ group3_emp2 = EmptyOperator(task_id="group3_empty2")
+ group3_emp3 = EmptyOperator(task_id="group3_empty3")
+ emp_in1 = EmptyOperator(task_id="empty_in1")
+ emp_in2 = EmptyOperator(task_id="empty_in2")
+ emp_in3 = EmptyOperator(task_id="empty_in3")
+ emp_in4 = EmptyOperator(task_id="empty_in4")
+ emp_out1 = EmptyOperator(task_id="empty_out1")
+ emp_out2 = EmptyOperator(task_id="empty_out2")
+ emp_out3 = EmptyOperator(task_id="empty_out3")
+ emp_out4 = EmptyOperator(task_id="empty_out4")
+ op_in1 = PythonOperator(python_callable=f, task_id="op_in1")
+ op_out1 = PythonOperator(python_callable=f, task_id="op_out1")
+
+ return (
+ dag,
+ group1,
+ group2,
+ group3,
+ (
+ group1_emp1,
+ group1_emp2,
+ group1_emp3,
+ group2_emp1,
+ group2_emp2,
+ group2_emp3,
+ group2_op1,
+ group2_op2,
+ group3_emp1,
+ group3_emp2,
+ group3_emp3,
+ emp_in1,
+ emp_in2,
+ emp_in3,
+ emp_in4,
+ emp_out1,
+ emp_out2,
+ emp_out3,
+ emp_out4,
+ op_in1,
+ op_out1,
+ ),
+ )
+
+
+@pytest.fixture
+def simple_dag_expected_edges():
+ return [
+ {"source_id": "group_1.downstream_join_id", "target_id": "test_op_4"},
+ {"source_id": "group_1.test_op_2", "target_id": "group_1.downstream_join_id"},
+ {"source_id": "group_1.test_op_3", "target_id": "group_1.downstream_join_id"},
+ {"source_id": "group_1.upstream_join_id", "target_id": "group_1.test_op_2"},
+ {"source_id": "group_1.upstream_join_id", "target_id": "group_1.test_op_3"},
+ {"label": "Label", "source_id": "test_op_1", "target_id": "group_1.upstream_join_id"},
+ ]
+
+
+@pytest.fixture
+def complex_dag_expected_edges():
+ return [
+ {"source_id": "empty_in1", "target_id": "group_1.upstream_join_id"},
+ {
+ "label": "label emp_in2 <=> group",
+ "source_id": "empty_in2",
+ "target_id": "group_1.upstream_join_id",
+ },
+ {
+ "label": "label emp_in3/emp_in4 <=> group",
+ "source_id": "empty_in3",
+ "target_id": "group_1.upstream_join_id",
+ },
+ {
+ "label": "label emp_in3/emp_in4 <=> group",
+ "source_id": "empty_in4",
+ "target_id": "group_1.upstream_join_id",
+ },
+ {"source_id": "group_1.downstream_join_id", "target_id": "empty_out1"},
+ {
+ "label": "label group <=> emp_out2",
+ "source_id": "group_1.downstream_join_id",
+ "target_id": "empty_out2",
+ },
+ {
+ "label": "label group <=> emp_out3/emp_out4",
+ "source_id": "group_1.downstream_join_id",
+ "target_id": "empty_out3",
+ },
+ {
+ "label": "label group <=> emp_out3/emp_out4",
+ "source_id": "group_1.downstream_join_id",
+ "target_id": "empty_out4",
+ },
+ {
+ "label": "label group <=> op_out1",
+ "source_id": "group_1.downstream_join_id",
+ "target_id": "op_out1",
+ },
+ {"source_id": "group_1.group_empty1", "target_id": "group_1.downstream_join_id"},
+ {"source_id": "group_1.group_empty2", "target_id": "group_1.group_empty1"},
+ {"source_id": "group_1.group_empty3", "target_id": "group_1.group_empty1"},
+ {"source_id": "group_1.upstream_join_id", "target_id": "group_1.group_empty2"},
+ {"source_id": "group_1.upstream_join_id", "target_id": "group_1.group_empty3"},
+ {
+ "label": "label op_in1 <=> group",
+ "source_id": "op_in1",
+ "target_id": "group_1.upstream_join_id",
+ },
+ ]
+
+
+@pytest.fixture
+def multiple_taskgroups_dag_expected_edges():
+ return [
+ {"source_id": "empty_in1", "target_id": "group1.upstream_join_id"},
+ {
+ "label": "label emp_in2 <=> group1",
+ "source_id": "empty_in2",
+ "target_id": "group1.upstream_join_id",
+ },
+ {
+ "label": "label emp_in3/emp_in4 <=> group1",
+ "source_id": "empty_in3",
+ "target_id": "group1.upstream_join_id",
+ },
+ {
+ "label": "label emp_in3/emp_in4 <=> group1",
+ "source_id": "empty_in4",
+ "target_id": "group1.upstream_join_id",
+ },
+ {
+ "label": "label group1 <=> group2",
+ "source_id": "group1.downstream_join_id",
+ "target_id": "group2.upstream_join_id",
+ },
+ {
+ "label": "label group1.group1_emp1 <=> group1.group1_emp2",
+ "source_id": "group1.group1_empty1",
+ "target_id": "group1.group1_empty3",
+ },
+ {"source_id": "group1.group1_empty2", "target_id": "group1.downstream_join_id"},
+ {"source_id": "group1.group1_empty3", "target_id": "group1.downstream_join_id"},
+ {"source_id": "group1.upstream_join_id", "target_id": "group1.group1_empty1"},
+ {"source_id": "group1.upstream_join_id", "target_id": "group1.group1_empty2"},
+ {
+ "label": "label group2.group2_emp1 <=> group2.group2_emp2/group2.group2_emp3",
+ "source_id": "group2.group2_empty1",
+ "target_id": "group2.group2_empty2",
+ },
+ {
+ "label": "label group2.group2_emp1 <=> group2.group2_emp2/group2.group2_emp3",
+ "source_id": "group2.group2_empty1",
+ "target_id": "group2.group2_empty3",
+ },
+ {
+ "label": "label group2.group2_emp1/group2.group2_emp2 <=> group2.group2_emp3",
+ "source_id": "group2.group2_empty2",
+ "target_id": "group2.group2_empty3",
+ },
+ {
+ "label": "label group2.group2_emp3 <=> group3",
+ "source_id": "group2.group2_empty3",
+ "target_id": "group2.group3.upstream_join_id",
+ },
+ {
+ "label": "label group2.group2_op1 <=> group2.group2_op2",
+ "source_id": "group2.group2_op1",
+ "target_id": "group2.group2_op2",
+ },
+ {
+ "label": "label group2.group2_op2 <=> group3",
+ "source_id": "group2.group2_op2",
+ "target_id": "group2.group3.upstream_join_id",
+ },
+ {"source_id": "group2.group3.downstream_join_id", "target_id": "empty_out1"},
+ {
+ "label": "label group3 <=> emp_out2",
+ "source_id": "group2.group3.downstream_join_id",
+ "target_id": "empty_out2",
+ },
+ {
+ "label": "label group3 <=> emp_out3/emp_out4",
+ "source_id": "group2.group3.downstream_join_id",
+ "target_id": "empty_out3",
+ },
+ {
+ "label": "label group3 <=> emp_out3/emp_out4",
+ "source_id": "group2.group3.downstream_join_id",
+ "target_id": "empty_out4",
+ },
+ {
+ "label": "label group3 <=> op_out1",
+ "source_id": "group2.group3.downstream_join_id",
+ "target_id": "op_out1",
+ },
+ {"source_id": "group2.group3.group3_empty1", "target_id": "group2.group3.downstream_join_id"},
+ {"source_id": "group2.group3.group3_empty2", "target_id": "group2.group3.downstream_join_id"},
+ {"source_id": "group2.group3.group3_empty3", "target_id": "group2.group3.downstream_join_id"},
+ {"source_id": "group2.group3.upstream_join_id", "target_id": "group2.group3.group3_empty1"},
+ {"source_id": "group2.group3.upstream_join_id", "target_id": "group2.group3.group3_empty2"},
+ {"source_id": "group2.group3.upstream_join_id", "target_id": "group2.group3.group3_empty3"},
+ {"source_id": "group2.upstream_join_id", "target_id": "group2.group2_empty1"},
+ {"source_id": "group2.upstream_join_id", "target_id": "group2.group2_op1"},
+ {"label": "label op_in1 <=> group1", "source_id": "op_in1", "target_id": "group1.upstream_join_id"},
+ ]
+
+
+def compare_dag_edges(current, expected):
+ assert len(current) == len(expected)
+
+ for i in current:
+ assert current.count(i) == expected.count(i), f"The unexpected DAG edge: {i}"
+
+
class TestEdgeModifierBuilding:
"""
Tests that EdgeModifiers work when composed with Tasks (either via >>
@@ -163,3 +442,225 @@ class TestEdgeModifierBuilding:
assert dag.get_edge_info(op1.task_id, op2.task_id) == {"label": "Group label"}
assert dag.get_edge_info(op1.task_id, op3.task_id) == {"label": "Group label"}
assert dag.get_edge_info(op3.task_id, op4.task_id) == {}
+
+ def test_simple_dag(self, test_taskgroup_dag, simple_dag_expected_edges):
+ """Tests the simple dag with a TaskGroup and a Label"""
+ dag, group, (op1, op2, op3, op4) = test_taskgroup_dag
+ op1 >> Label("Label") >> group >> op4
+ compare_dag_edges(dag_edges(dag), simple_dag_expected_edges)
+
+ def test_simple_reversed_dag(self, test_taskgroup_dag, simple_dag_expected_edges):
+ """Tests the simple reversed dag with a TaskGroup and a Label"""
+ dag, group, (op1, op2, op3, op4) = test_taskgroup_dag
+ op4 << group << Label("Label") << op1
+ compare_dag_edges(dag_edges(dag), simple_dag_expected_edges)
+
+ def test_complex_dag(self, test_complex_taskgroup_dag, complex_dag_expected_edges):
+ """Tests the complex dag with a TaskGroup and a Label"""
+ (
+ dag,
+ group,
+ (
+ group_emp1,
+ group_emp2,
+ group_emp3,
+ emp_in1,
+ emp_in2,
+ emp_in3,
+ emp_in4,
+ emp_out1,
+ emp_out2,
+ emp_out3,
+ emp_out4,
+ op_in1,
+ op_out1,
+ ),
+ ) = test_complex_taskgroup_dag
+
+ [group_emp2, group_emp3] >> group_emp1
+
+ emp_in1 >> group
+ emp_in2 >> Label("label emp_in2 <=> group") >> group
+ [emp_in3, emp_in4] >> Label("label emp_in3/emp_in4 <=> group") >> group
+ XComArg(op_in1, "test_key") >> Label("label op_in1 <=> group") >> group
+
+ group >> emp_out1
+ group >> Label("label group <=> emp_out2") >> emp_out2
+ group >> Label("label group <=> emp_out3/emp_out4") >> [emp_out3, emp_out4]
+ group >> Label("label group <=> op_out1") >> XComArg(op_out1, "test_key")
+
+ compare_dag_edges(dag_edges(dag), complex_dag_expected_edges)
+
+ def test_complex_reversed_dag(self, test_complex_taskgroup_dag, complex_dag_expected_edges):
+ """Tests the complex reversed dag with a TaskGroup and a Label"""
+ (
+ dag,
+ group,
+ (
+ group_emp1,
+ group_emp2,
+ group_emp3,
+ emp_in1,
+ emp_in2,
+ emp_in3,
+ emp_in4,
+ emp_out1,
+ emp_out2,
+ emp_out3,
+ emp_out4,
+ op_in1,
+ op_out1,
+ ),
+ ) = test_complex_taskgroup_dag
+
+ group_emp1 << [group_emp2, group_emp3]
+
+ group << emp_in1
+ group << Label("label emp_in2 <=> group") << emp_in2
+ group << Label("label emp_in3/emp_in4 <=> group") << [emp_in3, emp_in4]
+ group << Label("label op_in1 <=> group") << XComArg(op_in1, "test_key")
+
+ emp_out1 << group
+ emp_out2 << Label("label group <=> emp_out2") << group
+ [emp_out3, emp_out4] << Label("label group <=> emp_out3/emp_out4") << group
+ XComArg(op_out1, "test_key") << Label("label group <=> op_out1") << group
+
+ compare_dag_edges(dag_edges(dag), complex_dag_expected_edges)
+
+ def test_multiple_task_groups_dag(
+ self, test_multiple_taskgroups_dag, multiple_taskgroups_dag_expected_edges
+ ):
+ """Tests multiple task groups and labels"""
+ (
+ dag,
+ group1,
+ group2,
+ group3,
+ (
+ group1_emp1,
+ group1_emp2,
+ group1_emp3,
+ group2_emp1,
+ group2_emp2,
+ group2_emp3,
+ group2_op1,
+ group2_op2,
+ group3_emp1,
+ group3_emp2,
+ group3_emp3,
+ emp_in1,
+ emp_in2,
+ emp_in3,
+ emp_in4,
+ emp_out1,
+ emp_out2,
+ emp_out3,
+ emp_out4,
+ op_in1,
+ op_out1,
+ ),
+ ) = test_multiple_taskgroups_dag
+
+ group1_emp1 >> Label("label group1.group1_emp1 <=> group1.group1_emp2") >> group1_emp3
+
+ emp_in1 >> group1
+ emp_in2 >> Label("label emp_in2 <=> group1") >> group1
+ [emp_in3, emp_in4] >> Label("label emp_in3/emp_in4 <=> group1") >> group1
+ XComArg(op_in1, "test_key") >> Label("label op_in1 <=> group1") >> group1
+
+ (
+ [group2_emp1, group2_emp2]
+ >> Label("label group2.group2_emp1/group2.group2_emp2 <=> group2.group2_emp3")
+ >> group2_emp3
+ )
+ (
+ group2_emp1
+ >> Label("label group2.group2_emp1 <=> group2.group2_emp2/group2.group2_emp3")
+ >> [group2_emp2, group2_emp3]
+ )
+ group2_emp3 >> Label("label group2.group2_emp3 <=> group3") >> group3
+
+ (
+ XComArg(group2_op1, "test_key")
+ >> Label("label group2.group2_op1 <=> group2.group2_op2")
+ >> XComArg(group2_op2, "test_key")
+ )
+ XComArg(group2_op2, "test_key") >> Label("label group2.group2_op2 <=> group3") >> group3
+
+ group3 >> emp_out1
+ group3 >> Label("label group3 <=> emp_out2") >> emp_out2
+ group3 >> Label("label group3 <=> emp_out3/emp_out4") >> [emp_out3, emp_out4]
+ group3 >> Label("label group3 <=> op_out1") >> XComArg(op_out1, "test_key")
+
+ group1 >> Label("label group1 <=> group2") >> group2
+
+ compare_dag_edges(dag_edges(dag), multiple_taskgroups_dag_expected_edges)
+
+ def test_multiple_task_groups_reversed_dag(
+ self, test_multiple_taskgroups_dag, multiple_taskgroups_dag_expected_edges
+ ):
+ """Tests multiple task groups and labels, with dependencies declared in reverse via <<"""
+ (
+ dag,
+ group1,
+ group2,
+ group3,
+ (
+ group1_emp1,
+ group1_emp2,
+ group1_emp3,
+ group2_emp1,
+ group2_emp2,
+ group2_emp3,
+ group2_op1,
+ group2_op2,
+ group3_emp1,
+ group3_emp2,
+ group3_emp3,
+ emp_in1,
+ emp_in2,
+ emp_in3,
+ emp_in4,
+ emp_out1,
+ emp_out2,
+ emp_out3,
+ emp_out4,
+ op_in1,
+ op_out1,
+ ),
+ ) = test_multiple_taskgroups_dag
+
+ group1_emp3 << Label("label group1.group1_emp1 <=> group1.group1_emp2") << group1_emp1
+
+ group1 << emp_in1
+ group1 << Label("label emp_in2 <=> group1") << emp_in2
+ group1 << Label("label emp_in3/emp_in4 <=> group1") << [emp_in3, emp_in4]
+ group1 << Label("label op_in1 <=> group1") << XComArg(op_in1, "test_key")
+
+ (
+ group2_emp3
+ << Label("label group2.group2_emp1/group2.group2_emp2 <=> group2.group2_emp3")
+ << [group2_emp1, group2_emp2]
+ )
+ (
+ [group2_emp2, group2_emp3]
+ << Label("label group2.group2_emp1 <=> group2.group2_emp2/group2.group2_emp3")
+ << group2_emp1
+ )
+ group3 << Label("label group2.group2_emp3 <=> group3") << group2_emp3
+
+ (
+ XComArg(group2_op2, "test_key")
+ << Label("label group2.group2_op1 <=> group2.group2_op2")
+ << XComArg(group2_op1, "test_key")
+ )
+ group3 << Label("label group2.group2_op2 <=> group3") << XComArg(group2_op2, "test_key")
+
+ emp_out1 << group3
+ emp_out2 << Label("label group3 <=> emp_out2") << group3
+ [emp_out3, emp_out4] << Label("label group3 <=> emp_out3/emp_out4") << group3
+ XComArg(op_out1, "test_key") << Label("label group3 <=> op_out1") << group3
+
+ group2 << Label("label group1 <=> group2") << group1
+
+ compare_dag_edges(dag_edges(dag), multiple_taskgroups_dag_expected_edges)