You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ur...@apache.org on 2021/07/14 07:44:00 UTC
[airflow] branch main updated: Update chain() and cross_downstream() to support XComArgs (#16732)
This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7529546 Update chain() and cross_downstream() to support XComArgs (#16732)
7529546 is described below
commit 7529546939250266ccf404c2eea98b298365ef46
Author: josh-fell <48...@users.noreply.github.com>
AuthorDate: Wed Jul 14 03:43:41 2021 -0400
Update chain() and cross_downstream() to support XComArgs (#16732)
Co-authored-by: Ash Berlin-Taylor <as...@firemirror.com>
---
airflow/models/baseoperator.py | 157 +++++++++++++++++++++++++++++++++-----
docs/spelling_wordlist.txt | 2 +
tests/models/test_baseoperator.py | 47 ++++++++++++
3 files changed, 188 insertions(+), 18 deletions(-)
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 6d478e4..0c98467 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -74,6 +74,7 @@ from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule
if TYPE_CHECKING:
+ from airflow.models.xcom_arg import XComArg
from airflow.utils.task_group import TaskGroup
ScheduleInterval = Union[str, timedelta, relativedelta]
@@ -1545,22 +1546,25 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
return getattr(self, '_is_dummy', False)
-def chain(*tasks: Union[BaseOperator, Sequence[BaseOperator]]):
+def chain(*tasks: Union[BaseOperator, "XComArg", Sequence[Union[BaseOperator, "XComArg"]]]):
r"""
Given a number of tasks, builds a dependency chain.
- Support mix airflow.models.BaseOperator and List[airflow.models.BaseOperator].
- If you want to chain between two List[airflow.models.BaseOperator], have to
- make sure they have same length.
+
+ This function accepts values of BaseOperator (aka tasks), XComArg, or lists containing
+ either type (or a mix of both in the same list). If you want to chain between two lists you must
+ ensure they have the same length.
+
+ Using classic operators/sensors:
.. code-block:: python
- chain(t1, [t2, t3], [t4, t5], t6)
+ chain(t1, [t2, t3], [t4, t5], t6)
is equivalent to::
- / -> t2 -> t4 \
- t1 -> t6
- \ -> t3 -> t5 /
+ / -> t2 -> t4 \
+ t1 -> t6
+ \ -> t3 -> t5 /
.. code-block:: python
@@ -1571,15 +1575,69 @@ def chain(*tasks: Union[BaseOperator, Sequence[BaseOperator]]):
t4.set_downstream(t6)
t5.set_downstream(t6)
- :param tasks: List of tasks or List[airflow.models.BaseOperator] to set dependencies
- :type tasks: List[airflow.models.BaseOperator] or airflow.models.BaseOperator
+ Using task-decorated functions aka XComArgs:
+
+ .. code-block:: python
+
+ chain(x1(), [x2(), x3()], [x4(), x5()], x6())
+
+ is equivalent to::
+
+ / -> x2 -> x4 \
+ x1 -> x6
+ \ -> x3 -> x5 /
+
+ .. code-block:: python
+
+ x1 = x1()
+ x2 = x2()
+ x3 = x3()
+ x4 = x4()
+ x5 = x5()
+ x6 = x6()
+ x1.set_downstream(x2)
+ x1.set_downstream(x3)
+ x2.set_downstream(x4)
+ x3.set_downstream(x5)
+ x4.set_downstream(x6)
+ x5.set_downstream(x6)
+
+
+ It is also possible to mix classic operators/sensors and XComArg tasks:
+
+ .. code-block:: python
+
+ chain(t1, [x1(), x2()], t2, x3())
+
+ is equivalent to::
+
+ / -> x1 \
+ t1 -> t2 -> x3
+ \ -> x2 /
+
+ .. code-block:: python
+
+ x1 = x1()
+ x2 = x2()
+ x3 = x3()
+ t1.set_downstream(x1)
+ t1.set_downstream(x2)
+ x1.set_downstream(t2)
+ x2.set_downstream(t2)
+ t2.set_downstream(x3)
+
+ :param tasks: List of tasks or XComArgs to set dependencies
+ :type tasks: List[airflow.models.BaseOperator], airflow.models.BaseOperator, List[airflow.models.XComArg],
+ or XComArg
"""
+ from airflow.models.xcom_arg import XComArg
+
for index, up_task in enumerate(tasks[:-1]):
down_task = tasks[index + 1]
- if isinstance(up_task, BaseOperator):
+ if isinstance(up_task, (BaseOperator, XComArg)):
up_task.set_downstream(down_task)
continue
- if isinstance(down_task, BaseOperator):
+ if isinstance(down_task, (BaseOperator, XComArg)):
down_task.set_upstream(up_task)
continue
if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
@@ -1600,11 +1658,14 @@ def chain(*tasks: Union[BaseOperator, Sequence[BaseOperator]]):
def cross_downstream(
- from_tasks: Sequence[BaseOperator], to_tasks: Union[BaseOperator, Sequence[BaseOperator]]
+ from_tasks: Sequence[Union[BaseOperator, "XComArg"]],
+ to_tasks: Union[BaseOperator, "XComArg", Sequence[Union[BaseOperator, "XComArg"]]],
):
r"""
Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks.
+ Using classic operators/sensors:
+
.. code-block:: python
cross_downstream(from_tasks=[t1, t2, t3], to_tasks=[t4, t5, t6])
@@ -1617,7 +1678,6 @@ def cross_downstream(
/ \
t3 ---> t6
-
.. code-block:: python
t1.set_downstream(t4)
@@ -1630,10 +1690,71 @@ def cross_downstream(
t3.set_downstream(t5)
t3.set_downstream(t6)
- :param from_tasks: List of tasks to start from.
- :type from_tasks: List[airflow.models.BaseOperator]
- :param to_tasks: List of tasks to set as downstream dependencies.
- :type to_tasks: List[airflow.models.BaseOperator]
+ Using task-decorated functions aka XComArgs:
+
+ .. code-block:: python
+
+ cross_downstream(from_tasks=[x1(), x2(), x3()], to_tasks=[x4(), x5(), x6()])
+
+ is equivalent to::
+
+ x1 ---> x4
+ \ /
+ x2 -X -> x5
+ / \
+ x3 ---> x6
+
+ .. code-block:: python
+
+ x1 = x1()
+ x2 = x2()
+ x3 = x3()
+ x4 = x4()
+ x5 = x5()
+ x6 = x6()
+ x1.set_downstream(x4)
+ x1.set_downstream(x5)
+ x1.set_downstream(x6)
+ x2.set_downstream(x4)
+ x2.set_downstream(x5)
+ x2.set_downstream(x6)
+ x3.set_downstream(x4)
+ x3.set_downstream(x5)
+ x3.set_downstream(x6)
+
+ It is also possible to mix classic operators/sensors and XComArg tasks:
+
+ .. code-block:: python
+
+ cross_downstream(from_tasks=[t1, x2(), t3], to_tasks=[x1(), t2, x3()])
+
+ is equivalent to::
+
+ t1 ---> x1
+ \ /
+ x2 -X -> t2
+ / \
+ t3 ---> x3
+
+ .. code-block:: python
+
+ x1 = x1()
+ x2 = x2()
+ x3 = x3()
+ t1.set_downstream(x1)
+ t1.set_downstream(t2)
+ t1.set_downstream(x3)
+ x2.set_downstream(x1)
+ x2.set_downstream(t2)
+ x2.set_downstream(x3)
+ t3.set_downstream(x1)
+ t3.set_downstream(t2)
+ t3.set_downstream(x3)
+
+ :param from_tasks: List of tasks or XComArgs to start from.
+ :type from_tasks: List[airflow.models.BaseOperator] or List[airflow.models.XComArg]
+ :param to_tasks: Task, XComArg, or list of tasks/XComArgs to set as downstream dependencies.
+ :type to_tasks: airflow.models.BaseOperator, airflow.models.XComArg, or a List of either
"""
for task in from_tasks:
task.set_downstream(to_tasks)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 5c86f14..586094a 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -372,6 +372,8 @@ Webhook
Webserver
Werkzeug
XCom
+XComArg
+XComArgs
XComs
Xcom
Xero
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index 04d3f54..3f5ba12 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -24,6 +24,7 @@ import jinja2
import pytest
from parameterized import parameterized
+from airflow.decorators import task as task_decorator
from airflow.exceptions import AirflowException
from airflow.lineage.entities import File
from airflow.models import DAG
@@ -383,6 +384,22 @@ class TestBaseOperatorMethods(unittest.TestCase):
for start_task in start_tasks:
assert set(start_task.get_direct_relatives(upstream=False)) == set(end_tasks)
+ # Begin test for `XComArgs`
+ xstart_tasks = [
+ task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
+ for i in range(1, 4)
+ ]
+ xend_tasks = [
+ task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
+ for i in range(4, 7)
+ ]
+ cross_downstream(from_tasks=xstart_tasks, to_tasks=xend_tasks)
+
+ for xstart_task in xstart_tasks:
+ assert set(xstart_task.operator.get_direct_relatives(upstream=False)) == {
+ xend_task.operator for xend_task in xend_tasks
+ }
+
def test_chain(self):
dag = DAG(dag_id='test_chain', start_date=datetime.now())
[op1, op2, op3, op4, op5, op6] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 7)]
@@ -393,18 +410,48 @@ class TestBaseOperatorMethods(unittest.TestCase):
assert [op5] == op3.get_direct_relatives(upstream=False)
assert {op4, op5} == set(op6.get_direct_relatives(upstream=True))
+ # Begin test for `XComArgs`
+ [xop1, xop2, xop3, xop4, xop5, xop6] = [
+ task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
+ for i in range(1, 7)
+ ]
+ chain(xop1, [xop2, xop3], [xop4, xop5], xop6)
+
+ assert {xop2.operator, xop3.operator} == set(xop1.operator.get_direct_relatives(upstream=False))
+ assert [xop4.operator] == xop2.operator.get_direct_relatives(upstream=False)
+ assert [xop5.operator] == xop3.operator.get_direct_relatives(upstream=False)
+ assert {xop4.operator, xop5.operator} == set(xop6.operator.get_direct_relatives(upstream=True))
+
def test_chain_not_support_type(self):
dag = DAG(dag_id='test_chain', start_date=datetime.now())
[op1, op2] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 3)]
with pytest.raises(TypeError):
chain([op1, op2], 1)
+ # Begin test for `XComArgs`
+ [xop1, xop2] = [
+ task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
+ for i in range(1, 3)
+ ]
+
+ with pytest.raises(TypeError):
+ chain([xop1, xop2], 1)
+
def test_chain_different_length_iterable(self):
dag = DAG(dag_id='test_chain', start_date=datetime.now())
[op1, op2, op3, op4, op5] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 6)]
with pytest.raises(AirflowException):
chain([op1, op2], [op3, op4, op5])
+ # Begin test for `XComArgs`
+ [xop1, xop2, xop3, xop4, xop5] = [
+ task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
+ for i in range(1, 6)
+ ]
+
+ with pytest.raises(AirflowException):
+ chain([xop1, xop2], [xop3, xop4, xop5])
+
def test_lineage_composition(self):
"""
Test composition with lineage