You are viewing a plain text version of this content. The canonical link for it was not preserved in this copy.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/10/08 21:26:31 UTC
[airflow] branch v1-10-test updated: SkipMixin: Handle empty branches (#11120)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v1-10-test by this push:
new d355a3c SkipMixin: Handle empty branches (#11120)
d355a3c is described below
commit d355a3c425dc57ae9b827128d7406e3a54ff2004
Author: yuqian90 <yu...@gmail.com>
AuthorDate: Fri Oct 9 05:25:07 2020 +0800
SkipMixin: Handle empty branches (#11120)
---
airflow/models/skipmixin.py | 29 +++++++++++++++---------
tests/operators/test_python_operator.py | 40 +++++++++++++++++++++++++++++++++
2 files changed, 59 insertions(+), 10 deletions(-)
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 3b4531f..f45cac6 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -24,7 +24,6 @@ from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State
import six
-from typing import Set
# The key used by SkipMixin to store XCom data.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"
@@ -122,7 +121,8 @@ class SkipMixin(LoggingMixin):
"""
self.log.info("Following branch %s", branch_task_ids)
if isinstance(branch_task_ids, six.string_types):
- branch_task_ids = [branch_task_ids]
+ branch_task_ids = {branch_task_ids}
+ branch_task_ids = set(branch_task_ids)
dag_run = ti.get_dagrun()
task = ti.task
@@ -131,20 +131,29 @@ class SkipMixin(LoggingMixin):
downstream_tasks = task.downstream_list
if downstream_tasks:
- # Also check downstream tasks of the branch task. In case the task to skip
- # is also a downstream task of the branch task, we exclude it from skipping.
- branch_downstream_task_ids = set() # type: Set[str]
- for b in branch_task_ids:
- branch_downstream_task_ids.update(
- dag.get_task(b).get_flat_relative_ids(upstream=False)
+ # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
+ # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
+ # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
+ # we need a special case here for such empty branches: Check downstream tasks of branch_task_ids.
+ # In case the task to skip is also downstream of branch_task_ids, we add it to branch_task_ids and
+ # exclude it from skipping.
+ #
+ # branch -----> join
+ # \ ^
+ # v /
+ # task1
+ #
+ for branch_task_id in list(branch_task_ids):
+ branch_task_ids.update(
+ dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False)
)
skip_tasks = [
t
for t in downstream_tasks
if t.task_id not in branch_task_ids
- and t.task_id not in branch_downstream_task_ids
]
+ follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_ids]
self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
with create_session() as session:
@@ -152,5 +161,5 @@ class SkipMixin(LoggingMixin):
dag_run, ti.execution_date, skip_tasks, session=session
)
ti.xcom_push(
- key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: branch_task_ids}
+ key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids}
)
diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py
index 13a33b2..81eaa60 100644
--- a/tests/operators/test_python_operator.py
+++ b/tests/operators/test_python_operator.py
@@ -22,6 +22,7 @@ from __future__ import print_function, unicode_literals
import copy
import logging
import os
+import pytest
import unittest
@@ -846,3 +847,42 @@ class ShortCircuitOperatorTest(unittest.TestCase):
self.assertEqual(ti.state, State.SKIPPED)
else:
raise
+
+
+@pytest.mark.parametrize(
+ "choice,expected_states",
+ [
+ ("task1", [State.SUCCESS, State.SUCCESS, State.SUCCESS]),
+ ("join", [State.SUCCESS, State.SKIPPED, State.SUCCESS]),
+ ]
+)
+def test_empty_branch(choice, expected_states):
+ """
+ Tests that BranchPythonOperator handles empty branches, i.e. a task ("join") that is both directly downstream of the branch and downstream of another branch target ("task1").
+ """
+ with DAG(
+ 'test_empty_branch',
+ start_date=DEFAULT_DATE,
+ ) as dag:
+ branch = BranchPythonOperator(task_id='branch', python_callable=lambda: choice)
+ task1 = DummyOperator(task_id='task1')
+ join = DummyOperator(task_id='join', trigger_rule="none_failed_or_skipped")
+
+ branch >> [task1, join]
+ task1 >> join
+
+ dag.clear(start_date=DEFAULT_DATE)
+
+ task_ids = ["branch", "task1", "join"]
+
+ tis = {}
+ for task_id in task_ids:
+ task_instance = TI(dag.get_task(task_id), execution_date=DEFAULT_DATE)
+ tis[task_id] = task_instance
+ task_instance.run()
+
+ def get_state(ti):
+ ti.refresh_from_db()
+ return ti.state
+
+ assert [get_state(tis[task_id]) for task_id in task_ids] == expected_states