You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/07/25 16:20:24 UTC
[airflow] branch main updated: Adding EdgeModifier support for chain() (#17099)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 29d8e7f Adding EdgeModifier support for chain() (#17099)
29d8e7f is described below
commit 29d8e7f50b6e946a6b6561cad99620e00a2c8360
Author: josh-fell <48...@users.noreply.github.com>
AuthorDate: Sun Jul 25 12:20:09 2021 -0400
Adding EdgeModifier support for chain() (#17099)
---
airflow/models/baseoperator.py | 54 +++++++++++++++++++++++++++------------
docs/spelling_wordlist.txt | 2 ++
tests/models/test_baseoperator.py | 35 +++++++++++++++++++++++--
3 files changed, 72 insertions(+), 19 deletions(-)
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 8cbaad7..5018a76 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1546,13 +1546,16 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
return getattr(self, '_is_dummy', False)
-def chain(*tasks: Union[BaseOperator, "XComArg", Sequence[Union[BaseOperator, "XComArg"]]]):
+Chainable = Union[BaseOperator, "XComArg", EdgeModifier]
+
+
+def chain(*tasks: Union[Chainable, Sequence[Chainable]]) -> None:
r"""
Given a number of tasks, builds a dependency chain.
- This function accepts values of BaseOperator (aka tasks), XComArg, or lists containing
- either type (or a mix of both in the same list). If you want to chain between two lists you must
- ensure they have the same length.
+ This function accepts values of BaseOperator (aka tasks), EdgeModifiers (aka Labels), XComArg, or lists
+ containing any of these types (or a mix of them in the same list). If you want to chain between two lists
+ you must ensure they have the same length.
Using classic operators/sensors:
@@ -1603,41 +1606,58 @@ def chain(*tasks: Union[BaseOperator, "XComArg", Sequence[Union[BaseOperator, "X
x5.set_downstream(x6)
- It is also possible to mix between classic operator/sensor and XComArg tasks:
+ It is also possible to mix between classic operator/sensor, EdgeModifiers, and XComArg tasks:
.. code-block:: python
- chain(t1, [x1(), x2()], t2, x3())
+ chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], t2, x3())
is equivalent to::
- / -> x1 \
- t1 -> t2 -> x3
- \ -> x2 /
+ / "branch one" -> x1 \
+ t1 -> t2 -> x3
+ \ "branch two" -> x2 /
.. code-block:: python
x1 = x1()
x2 = x2()
x3 = x3()
- t1.set_downstream(x1)
- t1.set_downstream(x2)
+ label1 = Label("branch one")
+ label2 = Label("branch two")
+ t1.set_downstream(label1)
+ label1.set_downstream(x1)
+ t1.set_downstream(label2)
+ label2.set_downstream(x2)
x1.set_downstream(t2)
x2.set_downstream(t2)
t2.set_downstream(x3)
- :param tasks: List of tasks or XComArgs to set dependencies
- :type tasks: List[airflow.models.BaseOperator], airflow.models.BaseOperator, List[airflow.models.XComArg],
+ # or
+
+ x1 = x1()
+ x2 = x2()
+ x3 = x3()
+ t1.set_downstream(x1, edge_modifier=Label("branch one"))
+ t1.set_downstream(x2, edge_modifier=Label("branch two"))
+ x1.set_downstream(t2)
+ x2.set_downstream(t2)
+ t2.set_downstream(x3)
+
+
+ :param tasks: Individual and/or list of tasks, EdgeModifiers, or XComArgs to set dependencies
+ :type tasks: List[airflow.models.BaseOperator], airflow.models.BaseOperator,
+ List[airflow.utils.EdgeModifier], airflow.utils.EdgeModifier, List[airflow.models.XComArg],
or XComArg
"""
from airflow.models.xcom_arg import XComArg
for index, up_task in enumerate(tasks[:-1]):
down_task = tasks[index + 1]
- if isinstance(up_task, (BaseOperator, XComArg)):
+ if isinstance(up_task, (BaseOperator, XComArg, EdgeModifier)):
up_task.set_downstream(down_task)
continue
- if isinstance(down_task, (BaseOperator, XComArg)):
+ if isinstance(down_task, (BaseOperator, XComArg, EdgeModifier)):
down_task.set_upstream(up_task)
continue
if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
@@ -1650,8 +1670,8 @@ def chain(*tasks: Union[BaseOperator, "XComArg", Sequence[Union[BaseOperator, "X
down_task_list = down_task
if len(up_task_list) != len(down_task_list):
raise AirflowException(
- f'Chain not supported different length Iterable '
- f'but get {len(up_task_list)} and {len(down_task_list)}'
+ f'Chain not supported for different length Iterable. '
+ f'Got {len(up_task_list)} and {len(down_task_list)}.'
)
for up_t, down_t in zip(up_task_list, down_task_list):
up_t.set_downstream(down_t)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 586094a..403750f 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -116,6 +116,8 @@ Dsn
Dynamodb
EDITMSG
ETag
+EdgeModifier
+EdgeModifiers
Eg
EmrAddSteps
EmrCreateJobFlow
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index f2ca59e..becb970 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -30,6 +30,7 @@ from airflow.lineage.entities import File
from airflow.models import DAG
from airflow.models.baseoperator import BaseOperatorMeta, chain, cross_downstream
from airflow.operators.dummy import DummyOperator
+from airflow.utils.edgemodifier import Label
from tests.models import DEFAULT_DATE
from tests.test_utils.mock_operators import DeprecatedOperator, MockNamedTuple, MockOperator
@@ -402,26 +403,44 @@ class TestBaseOperatorMethods(unittest.TestCase):
def test_chain(self):
dag = DAG(dag_id='test_chain', start_date=datetime.now())
+
+ # Begin test for classic operators
+ [label1, label2] = [Label(label=f"label{i}") for i in range(1, 3)]
[op1, op2, op3, op4, op5, op6] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 7)]
- chain(op1, [op2, op3], [op4, op5], op6)
+ chain(op1, [label1, label2], [op2, op3], [op4, op5], op6)
assert {op2, op3} == set(op1.get_direct_relatives(upstream=False))
assert [op4] == op2.get_direct_relatives(upstream=False)
assert [op5] == op3.get_direct_relatives(upstream=False)
assert {op4, op5} == set(op6.get_direct_relatives(upstream=True))
+ assert {"label": "label1"} == dag.get_edge_info(
+ upstream_task_id=op1.task_id, downstream_task_id=op2.task_id
+ )
+ assert {"label": "label2"} == dag.get_edge_info(
+ upstream_task_id=op1.task_id, downstream_task_id=op3.task_id
+ )
+
# Begin test for `XComArgs`
+ [xlabel1, xlabel2] = [Label(label=f"xcomarg_label{i}") for i in range(1, 3)]
[xop1, xop2, xop3, xop4, xop5, xop6] = [
task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
for i in range(1, 7)
]
- chain(xop1, [xop2, xop3], [xop4, xop5], xop6)
+ chain(xop1, [xlabel1, xlabel2], [xop2, xop3], [xop4, xop5], xop6)
assert {xop2.operator, xop3.operator} == set(xop1.operator.get_direct_relatives(upstream=False))
assert [xop4.operator] == xop2.operator.get_direct_relatives(upstream=False)
assert [xop5.operator] == xop3.operator.get_direct_relatives(upstream=False)
assert {xop4.operator, xop5.operator} == set(xop6.operator.get_direct_relatives(upstream=True))
+ assert {"label": "xcomarg_label1"} == dag.get_edge_info(
+ upstream_task_id=xop1.operator.task_id, downstream_task_id=xop2.operator.task_id
+ )
+ assert {"label": "xcomarg_label2"} == dag.get_edge_info(
+ upstream_task_id=xop1.operator.task_id, downstream_task_id=xop3.operator.task_id
+ )
+
def test_chain_not_support_type(self):
dag = DAG(dag_id='test_chain', start_date=datetime.now())
[op1, op2] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 3)]
@@ -437,13 +456,22 @@ class TestBaseOperatorMethods(unittest.TestCase):
with pytest.raises(TypeError):
chain([xop1, xop2], 1)
+ with pytest.raises(TypeError):
+ chain([Label("labe1"), Label("label2")], 1)
+
def test_chain_different_length_iterable(self):
dag = DAG(dag_id='test_chain', start_date=datetime.now())
+ [label1, label2] = [Label(label=f"label{i}") for i in range(1, 3)]
[op1, op2, op3, op4, op5] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 6)]
+
with pytest.raises(AirflowException):
chain([op1, op2], [op3, op4, op5])
+ with pytest.raises(AirflowException):
+ chain([op1, op2, op3], [label1, label2])
+
# Begin test for `XComArgs`
+ [label3, label4] = [Label(label=f"xcomarg_label{i}") for i in range(1, 3)]
[xop1, xop2, xop3, xop4, xop5] = [
task_decorator(task_id=f"xcomarg_task{i}", python_callable=lambda: None, dag=dag)()
for i in range(1, 6)
@@ -452,6 +480,9 @@ class TestBaseOperatorMethods(unittest.TestCase):
with pytest.raises(AirflowException):
chain([xop1, xop2], [xop3, xop4, xop5])
+ with pytest.raises(AirflowException):
+ chain([xop1, xop2, xop3], [label1, label2])
+
def test_lineage_composition(self):
"""
Test composition with lineage