You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/07/25 16:20:24 UTC

[airflow] branch main updated: Adding EdgeModifier support for chain() (#17099)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 29d8e7f  Adding EdgeModifier support for chain() (#17099)
29d8e7f is described below

commit 29d8e7f50b6e946a6b6561cad99620e00a2c8360
Author: josh-fell <48...@users.noreply.github.com>
AuthorDate: Sun Jul 25 12:20:09 2021 -0400

    Adding EdgeModifier support for chain() (#17099)
---
 airflow/models/baseoperator.py    | 54 +++++++++++++++++++++++++++------------
 docs/spelling_wordlist.txt        |  2 ++
 tests/models/test_baseoperator.py | 35 +++++++++++++++++++++++--
 3 files changed, 72 insertions(+), 19 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 8cbaad7..5018a76 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1546,13 +1546,16 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
         return getattr(self, '_is_dummy', False)
 
 
-def chain(*tasks: Union[BaseOperator, "XComArg", Sequence[Union[BaseOperator, "XComArg"]]]):
+Chainable = Union[BaseOperator, "XComArg", EdgeModifier]
+
+
+def chain(*tasks: Union[Chainable, Sequence[Chainable]]) -> None:
     r"""
     Given a number of tasks, builds a dependency chain.
 
-    This function accepts values of BaseOperator (aka tasks), XComArg, or lists containing
-    either type (or a mix of both in the same list). If you want to chain between two lists you must
-    ensure they have the same length.
+    This function accepts values of BaseOperator (aka tasks), EdgeModifiers (aka Labels), XComArg, or lists
+    containing any of these types (mixing types within the same list is allowed). If you want to chain
+    between two lists you must ensure they have the same length.
 
     Using classic operators/sensors:
 
@@ -1603,41 +1606,58 @@ def chain(*tasks: Union[BaseOperator, "XComArg", Sequence[Union[BaseOperator, "X
         x5.set_downstream(x6)
 
 
-    It is also possible to mix between classic operator/sensor and XComArg tasks:
+    It is also possible to mix classic operators/sensors, EdgeModifiers, and XComArg tasks:
 
     .. code-block:: python
 
-        chain(t1, [x1(), x2()], t2, x3())
+        chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], t2, x3())
 
     is equivalent to::
 
-          / -> x1 \
-        t1         -> t2 -> x3
-          \ -> x2 /
+          / "branch one" -> x1 \
+        t1                      -> t2 -> x3
+          \ "branch two" -> x2 /
 
     .. code-block:: python
 
         x1 = x1()
         x2 = x2()
         x3 = x3()
-        t1.set_downstream(x1)
-        t1.set_downstream(x2)
+        label1 = Label("branch one")
+        label2 = Label("branch two")
+        t1.set_downstream(label1)
+        label1.set_downstream(x1)
+        t1.set_downstream(label2)
+        label2.set_downstream(x2)
         x1.set_downstream(t2)
         x2.set_downstream(t2)
         t2.set_downstream(x3)
 
-    :param tasks: List of tasks or XComArgs to set dependencies
-    :type tasks: List[airflow.models.BaseOperator], airflow.models.BaseOperator, List[airflow.models.XComArg],
+        # or
+
+        x1 = x1()
+        x2 = x2()
+        x3 = x3()
+        t1.set_downstream(x1, edge_modifier=Label("branch one"))
+        t1.set_downstream(x2, edge_modifier=Label("branch two"))
+        x1.set_downstream(t2)
+        x2.set_downstream(t2)
+        t2.set_downstream(x3)
+
+
+    :param tasks: Individual and/or list of tasks, EdgeModifiers, or XComArgs to set dependencies
+    :type tasks: List[airflow.models.BaseOperator], airflow.models.BaseOperator, List[airflow.models.XComArg],
+        List[airflow.utils.edgemodifier.EdgeModifier], airflow.utils.edgemodifier.EdgeModifier,
         or XComArg
     """
     from airflow.models.xcom_arg import XComArg
 
     for index, up_task in enumerate(tasks[:-1]):
         down_task = tasks[index + 1]
-        if isinstance(up_task, (BaseOperator, XComArg)):
+        if isinstance(up_task, (BaseOperator, XComArg, EdgeModifier)):
             up_task.set_downstream(down_task)
             continue
-        if isinstance(down_task, (BaseOperator, XComArg)):
+        if isinstance(down_task, (BaseOperator, XComArg, EdgeModifier)):
             down_task.set_upstream(up_task)
             continue
         if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
@@ -1650,8 +1670,8 @@ def chain(*tasks: Union[BaseOperator, "XComArg", Sequence[Union[BaseOperator, "X
         down_task_list = down_task
         if len(up_task_list) != len(down_task_list):
             raise AirflowException(
-                f'Chain not supported different length Iterable '
-                f'but get {len(up_task_list)} and {len(down_task_list)}'
+                f'Chain not supported for different length Iterable. '
+                f'Got {len(up_task_list)} and {len(down_task_list)}.'
             )
         for up_t, down_t in zip(up_task_list, down_task_list):
             up_t.set_downstream(down_t)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 586094a..403750f 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -116,6 +116,8 @@ Dsn
 Dynamodb
 EDITMSG
 ETag
+EdgeModifier
+EdgeModifiers
 Eg
 EmrAddSteps
 EmrCreateJobFlow
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index f2ca59e..becb970 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -30,6 +30,7 @@ from airflow.lineage.entities import File
 from airflow.models import DAG
 from airflow.models.baseoperator import BaseOperatorMeta, chain, cross_downstream
 from airflow.operators.dummy import DummyOperator
+from airflow.utils.edgemodifier import Label
 from tests.models import DEFAULT_DATE
 from tests.test_utils.mock_operators import DeprecatedOperator, MockNamedTuple, MockOperator
 
@@ -402,26 +403,44 @@ class TestBaseOperatorMethods(unittest.TestCase):
 
     def test_chain(self):
         dag = DAG(dag_id='test_chain', start_date=datetime.now())
+
+        # Begin test for classic operators
+        [label1, label2] = [Label(label=f"label{i}") for i in range(1, 3)]
         [op1, op2, op3, op4, op5, op6] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 7)]
-        chain(op1, [op2, op3], [op4, op5], op6)
+        chain(op1, [label1, label2], [op2, op3], [op4, op5], op6)
 
         assert {op2, op3} == set(op1.get_direct_relatives(upstream=False))
         assert [op4] == op2.get_direct_relatives(upstream=False)
         assert [op5] == op3.get_direct_relatives(upstream=False)
         assert {op4, op5} == set(op6.get_direct_relatives(upstream=True))
 
+        assert {"label": "label1"} == dag.get_edge_info(
+            upstream_task_id=op1.task_id, downstream_task_id=op2.task_id
+        )
+        assert {"label": "label2"} == dag.get_edge_info(
+            upstream_task_id=op1.task_id, downstream_task_id=op3.task_id
+        )
+
         # Begin test for `XComArgs`
+        [xlabel1, xlabel2] = [Label(label=f"xcomarg_label{i}") for i in range(1, 3)]
         [xop1, xop2, xop3, xop4, xop5, xop6] = [
             task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
             for i in range(1, 7)
         ]
-        chain(xop1, [xop2, xop3], [xop4, xop5], xop6)
+        chain(xop1, [xlabel1, xlabel2], [xop2, xop3], [xop4, xop5], xop6)
 
         assert {xop2.operator, xop3.operator} == set(xop1.operator.get_direct_relatives(upstream=False))
         assert [xop4.operator] == xop2.operator.get_direct_relatives(upstream=False)
         assert [xop5.operator] == xop3.operator.get_direct_relatives(upstream=False)
         assert {xop4.operator, xop5.operator} == set(xop6.operator.get_direct_relatives(upstream=True))
 
+        assert {"label": "xcomarg_label1"} == dag.get_edge_info(
+            upstream_task_id=xop1.operator.task_id, downstream_task_id=xop2.operator.task_id
+        )
+        assert {"label": "xcomarg_label2"} == dag.get_edge_info(
+            upstream_task_id=xop1.operator.task_id, downstream_task_id=xop3.operator.task_id
+        )
+
     def test_chain_not_support_type(self):
         dag = DAG(dag_id='test_chain', start_date=datetime.now())
         [op1, op2] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 3)]
@@ -437,13 +456,22 @@ class TestBaseOperatorMethods(unittest.TestCase):
         with pytest.raises(TypeError):
             chain([xop1, xop2], 1)
 
+        with pytest.raises(TypeError):
+            chain([Label("labe1"), Label("label2")], 1)
+
     def test_chain_different_length_iterable(self):
         dag = DAG(dag_id='test_chain', start_date=datetime.now())
+        [label1, label2] = [Label(label=f"label{i}") for i in range(1, 3)]
         [op1, op2, op3, op4, op5] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 6)]
+
         with pytest.raises(AirflowException):
             chain([op1, op2], [op3, op4, op5])
 
+        with pytest.raises(AirflowException):
+            chain([op1, op2, op3], [label1, label2])
+
         # Begin test for `XComArgs`
+        [label3, label4] = [Label(label=f"xcomarg_label{i}") for i in range(1, 3)]
         [xop1, xop2, xop3, xop4, xop5] = [
             task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
             for i in range(1, 6)
@@ -452,6 +480,9 @@ class TestBaseOperatorMethods(unittest.TestCase):
         with pytest.raises(AirflowException):
             chain([xop1, xop2], [xop3, xop4, xop5])
 
+        with pytest.raises(AirflowException):
+            chain([xop1, xop2, xop3], [label1, label2])
+
     def test_lineage_composition(self):
         """
         Test composition with lineage