You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ur...@apache.org on 2021/07/14 07:44:00 UTC

[airflow] branch main updated: Update chain() and cross_downstream() to support XComArgs (#16732)

This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7529546  Update chain() and cross_downstream() to support XComArgs (#16732)
7529546 is described below

commit 7529546939250266ccf404c2eea98b298365ef46
Author: josh-fell <48...@users.noreply.github.com>
AuthorDate: Wed Jul 14 03:43:41 2021 -0400

    Update chain() and cross_downstream() to support XComArgs (#16732)
    
    Co-authored-by: Ash Berlin-Taylor <as...@firemirror.com>
---
 airflow/models/baseoperator.py    | 157 +++++++++++++++++++++++++++++++++-----
 docs/spelling_wordlist.txt        |   2 +
 tests/models/test_baseoperator.py |  47 ++++++++++++
 3 files changed, 188 insertions(+), 18 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 6d478e4..0c98467 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -74,6 +74,7 @@ from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.weight_rule import WeightRule
 
 if TYPE_CHECKING:
+    from airflow.models.xcom_arg import XComArg
     from airflow.utils.task_group import TaskGroup
 
 ScheduleInterval = Union[str, timedelta, relativedelta]
@@ -1545,22 +1546,25 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
         return getattr(self, '_is_dummy', False)
 
 
-def chain(*tasks: Union[BaseOperator, Sequence[BaseOperator]]):
+def chain(*tasks: Union[BaseOperator, "XComArg", Sequence[Union[BaseOperator, "XComArg"]]]):
     r"""
     Given a number of tasks, builds a dependency chain.
-    Support mix airflow.models.BaseOperator and List[airflow.models.BaseOperator].
-    If you want to chain between two List[airflow.models.BaseOperator], have to
-    make sure they have same length.
+
+    This function accepts values of BaseOperator (aka tasks), XComArg, or lists containing
+    either type (or a mix of both in the same list). If you want to chain between two lists you must
+    ensure they have the same length.
+
+    Using classic operators/sensors:
 
     .. code-block:: python
 
-         chain(t1, [t2, t3], [t4, t5], t6)
+        chain(t1, [t2, t3], [t4, t5], t6)
 
     is equivalent to::
 
-         / -> t2 -> t4 \
-       t1               -> t6
-         \ -> t3 -> t5 /
+          / -> t2 -> t4 \
+        t1               -> t6
+          \ -> t3 -> t5 /
 
     .. code-block:: python
 
@@ -1571,15 +1575,69 @@ def chain(*tasks: Union[BaseOperator, Sequence[BaseOperator]]):
         t4.set_downstream(t6)
         t5.set_downstream(t6)
 
-    :param tasks: List of tasks or List[airflow.models.BaseOperator] to set dependencies
-    :type tasks: List[airflow.models.BaseOperator] or airflow.models.BaseOperator
+    Using task-decorated functions aka XComArgs:
+
+    .. code-block:: python
+
+        chain(x1(), [x2(), x3()], [x4(), x5()], x6())
+
+    is equivalent to::
+
+          / -> x2 -> x4 \
+        x1               -> x6
+          \ -> x3 -> x5 /
+
+    .. code-block:: python
+
+        x1 = x1()
+        x2 = x2()
+        x3 = x3()
+        x4 = x4()
+        x5 = x5()
+        x6 = x6()
+        x1.set_downstream(x2)
+        x1.set_downstream(x3)
+        x2.set_downstream(x4)
+        x3.set_downstream(x5)
+        x4.set_downstream(x6)
+        x5.set_downstream(x6)
+
+
+    It is also possible to mix classic operators/sensors and XComArg tasks:
+
+    .. code-block:: python
+
+        chain(t1, [x1(), x2()], t2, x3())
+
+    is equivalent to::
+
+          / -> x1 \
+        t1         -> t2 -> x3
+          \ -> x2 /
+
+    .. code-block:: python
+
+        x1 = x1()
+        x2 = x2()
+        x3 = x3()
+        t1.set_downstream(x1)
+        t1.set_downstream(x2)
+        x1.set_downstream(t2)
+        x2.set_downstream(t2)
+        t2.set_downstream(x3)
+
+    :param tasks: List of tasks or XComArgs to set dependencies
+    :type tasks: List[airflow.models.BaseOperator], airflow.models.BaseOperator, List[airflow.models.XComArg],
+        or airflow.models.XComArg
     """
+    from airflow.models.xcom_arg import XComArg
+
     for index, up_task in enumerate(tasks[:-1]):
         down_task = tasks[index + 1]
-        if isinstance(up_task, BaseOperator):
+        if isinstance(up_task, (BaseOperator, XComArg)):
             up_task.set_downstream(down_task)
             continue
-        if isinstance(down_task, BaseOperator):
+        if isinstance(down_task, (BaseOperator, XComArg)):
             down_task.set_upstream(up_task)
             continue
         if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
@@ -1600,11 +1658,14 @@ def chain(*tasks: Union[BaseOperator, Sequence[BaseOperator]]):
 
 
 def cross_downstream(
-    from_tasks: Sequence[BaseOperator], to_tasks: Union[BaseOperator, Sequence[BaseOperator]]
+    from_tasks: Sequence[Union[BaseOperator, "XComArg"]],
+    to_tasks: Union[BaseOperator, "XComArg", Sequence[Union[BaseOperator, "XComArg"]]],
 ):
     r"""
     Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks.
 
+    Using classic operators/sensors:
+
     .. code-block:: python
 
         cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6])
@@ -1617,7 +1678,6 @@ def cross_downstream(
            / \
         t3 ---> t6
 
-
     .. code-block:: python
 
         t1.set_downstream(t4)
@@ -1630,10 +1690,71 @@ def cross_downstream(
         t3.set_downstream(t5)
         t3.set_downstream(t6)
 
-    :param from_tasks: List of tasks to start from.
-    :type from_tasks: List[airflow.models.BaseOperator]
-    :param to_tasks: List of tasks to set as downstream dependencies.
-    :type to_tasks: List[airflow.models.BaseOperator]
+    Using task-decorated functions aka XComArgs:
+
+    .. code-block:: python
+
+        cross_downstream(from_tasks=[x1(), x2(), x3()], to_tasks=[x4(), x5(), x6()])
+
+    is equivalent to::
+
+        x1 ---> x4
+           \ /
+        x2 -X -> x5
+           / \
+        x3 ---> x6
+
+    .. code-block:: python
+
+        x1 = x1()
+        x2 = x2()
+        x3 = x3()
+        x4 = x4()
+        x5 = x5()
+        x6 = x6()
+        x1.set_downstream(x4)
+        x1.set_downstream(x5)
+        x1.set_downstream(x6)
+        x2.set_downstream(x4)
+        x2.set_downstream(x5)
+        x2.set_downstream(x6)
+        x3.set_downstream(x4)
+        x3.set_downstream(x5)
+        x3.set_downstream(x6)
+
+    It is also possible to mix classic operators/sensors and XComArg tasks:
+
+    .. code-block:: python
+
+        cross_downstream(from_tasks=[t1, x2(), t3], to_tasks=[x1(), t2, x3()])
+
+    is equivalent to::
+
+        t1 ---> x1
+           \ /
+        x2 -X -> t2
+           / \
+        t3 ---> x3
+
+    .. code-block:: python
+
+        x1 = x1()
+        x2 = x2()
+        x3 = x3()
+        t1.set_downstream(x1)
+        t1.set_downstream(t2)
+        t1.set_downstream(x3)
+        x2.set_downstream(x1)
+        x2.set_downstream(t2)
+        x2.set_downstream(x3)
+        t3.set_downstream(x1)
+        t3.set_downstream(t2)
+        t3.set_downstream(x3)
+
+    :param from_tasks: List of tasks or XComArgs to start from.
+    :type from_tasks: List[airflow.models.BaseOperator] or List[airflow.models.XComArg]
+    :param to_tasks: Task, XComArg, or list of tasks or XComArgs to set as downstream dependencies.
+    :type to_tasks: airflow.models.BaseOperator, airflow.models.XComArg, or a List of either
     """
     for task in from_tasks:
         task.set_downstream(to_tasks)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 5c86f14..586094a 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -372,6 +372,8 @@ Webhook
 Webserver
 Werkzeug
 XCom
+XComArg
+XComArgs
 XComs
 Xcom
 Xero
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index 04d3f54..3f5ba12 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -24,6 +24,7 @@ import jinja2
 import pytest
 from parameterized import parameterized
 
+from airflow.decorators import task as task_decorator
 from airflow.exceptions import AirflowException
 from airflow.lineage.entities import File
 from airflow.models import DAG
@@ -383,6 +384,22 @@ class TestBaseOperatorMethods(unittest.TestCase):
         for start_task in start_tasks:
             assert set(start_task.get_direct_relatives(upstream=False)) == set(end_tasks)
 
+        # Begin test for `XComArgs`
+        xstart_tasks = [
+            task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
+            for i in range(1, 4)
+        ]
+        xend_tasks = [
+            task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
+            for i in range(4, 7)
+        ]
+        cross_downstream(from_tasks=xstart_tasks, to_tasks=xend_tasks)
+
+        for xstart_task in xstart_tasks:
+            assert set(xstart_task.operator.get_direct_relatives(upstream=False)) == {
+                xend_task.operator for xend_task in xend_tasks
+            }
+
     def test_chain(self):
         dag = DAG(dag_id='test_chain', start_date=datetime.now())
         [op1, op2, op3, op4, op5, op6] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 7)]
@@ -393,18 +410,48 @@ class TestBaseOperatorMethods(unittest.TestCase):
         assert [op5] == op3.get_direct_relatives(upstream=False)
         assert {op4, op5} == set(op6.get_direct_relatives(upstream=True))
 
+        # Begin test for `XComArgs`
+        [xop1, xop2, xop3, xop4, xop5, xop6] = [
+            task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
+            for i in range(1, 7)
+        ]
+        chain(xop1, [xop2, xop3], [xop4, xop5], xop6)
+
+        assert {xop2.operator, xop3.operator} == set(xop1.operator.get_direct_relatives(upstream=False))
+        assert [xop4.operator] == xop2.operator.get_direct_relatives(upstream=False)
+        assert [xop5.operator] == xop3.operator.get_direct_relatives(upstream=False)
+        assert {xop4.operator, xop5.operator} == set(xop6.operator.get_direct_relatives(upstream=True))
+
     def test_chain_not_support_type(self):
         dag = DAG(dag_id='test_chain', start_date=datetime.now())
         [op1, op2] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 3)]
         with pytest.raises(TypeError):
             chain([op1, op2], 1)
 
+        # Begin test for `XComArgs`
+        [xop1, xop2] = [
+            task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
+            for i in range(1, 3)
+        ]
+
+        with pytest.raises(TypeError):
+            chain([xop1, xop2], 1)
+
     def test_chain_different_length_iterable(self):
         dag = DAG(dag_id='test_chain', start_date=datetime.now())
         [op1, op2, op3, op4, op5] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 6)]
         with pytest.raises(AirflowException):
             chain([op1, op2], [op3, op4, op5])
 
+        # Begin test for `XComArgs`
+        [xop1, xop2, xop3, xop4, xop5] = [
+            task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
+            for i in range(1, 6)
+        ]
+
+        with pytest.raises(AirflowException):
+            chain([xop1, xop2], [xop3, xop4, xop5])
+
     def test_lineage_composition(self):
         """
         Test composition with lineage