You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/10/08 21:26:31 UTC

[airflow] branch v1-10-test updated: SkipMixin: Handle empty branches (#11120)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v1-10-test by this push:
     new d355a3c  SkipMixin: Handle empty branches (#11120)
d355a3c is described below

commit d355a3c425dc57ae9b827128d7406e3a54ff2004
Author: yuqian90 <yu...@gmail.com>
AuthorDate: Fri Oct 9 05:25:07 2020 +0800

    SkipMixin: Handle empty branches (#11120)
---
 airflow/models/skipmixin.py             | 29 +++++++++++++++---------
 tests/operators/test_python_operator.py | 40 +++++++++++++++++++++++++++++++++
 2 files changed, 59 insertions(+), 10 deletions(-)

diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 3b4531f..f45cac6 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -24,7 +24,6 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.state import State
 
 import six
-from typing import Set
 
 # The key used by SkipMixin to store XCom data.
 XCOM_SKIPMIXIN_KEY = "skipmixin_key"
@@ -122,7 +121,8 @@ class SkipMixin(LoggingMixin):
         """
         self.log.info("Following branch %s", branch_task_ids)
         if isinstance(branch_task_ids, six.string_types):
-            branch_task_ids = [branch_task_ids]
+            branch_task_ids = {branch_task_ids}
+        branch_task_ids = set(branch_task_ids)
 
         dag_run = ti.get_dagrun()
         task = ti.task
@@ -131,20 +131,29 @@ class SkipMixin(LoggingMixin):
         downstream_tasks = task.downstream_list
 
         if downstream_tasks:
-            # Also check downstream tasks of the branch task. In case the task to skip
-            # is also a downstream task of the branch task, we exclude it from skipping.
-            branch_downstream_task_ids = set()  # type: Set[str]
-            for b in branch_task_ids:
-                branch_downstream_task_ids.update(
-                    dag.get_task(b).get_flat_relative_ids(upstream=False)
+            # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
+            # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
+            # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
+            # we need a special case here for such empty branches: Check downstream tasks of branch_task_ids.
+            # In case the task to skip is also downstream of branch_task_ids, we add it to branch_task_ids and
+            # exclude it from skipping.
+            #
+            # branch  ----->  join
+            #   \            ^
+            #     v        /
+            #       task1
+            #
+            for branch_task_id in list(branch_task_ids):
+                branch_task_ids.update(
+                    dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False)
                 )
 
             skip_tasks = [
                 t
                 for t in downstream_tasks
                 if t.task_id not in branch_task_ids
-                and t.task_id not in branch_downstream_task_ids
             ]
+            follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_ids]
 
             self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
             with create_session() as session:
@@ -152,5 +161,5 @@ class SkipMixin(LoggingMixin):
                     dag_run, ti.execution_date, skip_tasks, session=session
                 )
                 ti.xcom_push(
-                    key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: branch_task_ids}
+                    key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids}
                 )
diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py
index 13a33b2..81eaa60 100644
--- a/tests/operators/test_python_operator.py
+++ b/tests/operators/test_python_operator.py
@@ -22,6 +22,7 @@ from __future__ import print_function, unicode_literals
 import copy
 import logging
 import os
+import pytest
 
 import unittest
 
@@ -846,3 +847,42 @@ class ShortCircuitOperatorTest(unittest.TestCase):
                 self.assertEqual(ti.state, State.SKIPPED)
             else:
                 raise
+
+
+@pytest.mark.parametrize(
+    "choice,expected_states",
+    [
+        ("task1", [State.SUCCESS, State.SUCCESS, State.SUCCESS]),
+        ("join", [State.SUCCESS, State.SKIPPED, State.SUCCESS]),
+    ]
+)
+def test_empty_branch(choice, expected_states):
+    """
+    Tests that BranchPythonOperator handles empty branches properly.
+    """
+    with DAG(
+        'test_empty_branch',
+        start_date=DEFAULT_DATE,
+    ) as dag:
+        branch = BranchPythonOperator(task_id='branch', python_callable=lambda: choice)
+        task1 = DummyOperator(task_id='task1')
+        join = DummyOperator(task_id='join', trigger_rule="none_failed_or_skipped")
+
+        branch >> [task1, join]
+        task1 >> join
+
+    dag.clear(start_date=DEFAULT_DATE)
+
+    task_ids = ["branch", "task1", "join"]
+
+    tis = {}
+    for task_id in task_ids:
+        task_instance = TI(dag.get_task(task_id), execution_date=DEFAULT_DATE)
+        tis[task_id] = task_instance
+        task_instance.run()
+
+    def get_state(ti):
+        ti.refresh_from_db()
+        return ti.state
+
+    assert [get_state(tis[task_id]) for task_id in task_ids] == expected_states