You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/06/24 13:50:29 UTC

[airflow] branch main updated: ExternalTaskSensor respects soft_fail if the external task enters a failed_state (#23647)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1b345981f6 ExternalTaskSensor respects soft_fail if the external task enters a failed_state (#23647)
1b345981f6 is described below

commit 1b345981f6e8e910b3542ec53829e39e6c9b6dba
Author: Andrew Gibbs <gi...@andrew.gibbs.io>
AuthorDate: Fri Jun 24 14:50:13 2022 +0100

    ExternalTaskSensor respects soft_fail if the external task enters a failed_state (#23647)
    
    * Respecting soft_fail in ExternalTaskSensor when the upstream tasks are in the failed state (#19754)
    
    - Changed behaviour of sensor as described above to respect soft_fail
    - Added tests of new soft_fail behaviour (#19754)
    - Added newsfragment and improved sensor docstring
---
 airflow/sensors/external_task.py           | 29 ++++++++++++++-
 newsfragments/23647.bugfix.rst             |  1 +
 tests/sensors/test_external_task_sensor.py | 57 +++++++++++++++++++++++++-----
 3 files changed, 78 insertions(+), 9 deletions(-)

diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py
index 327fed6db0..1b85d7a021 100644
--- a/airflow/sensors/external_task.py
+++ b/airflow/sensors/external_task.py
@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Callable, Collection, FrozenSet, Iterable
 import attr
 from sqlalchemy import func
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.models.baseoperator import BaseOperatorLink
 from airflow.models.dag import DagModel
 from airflow.models.dagbag import DagBag
@@ -57,6 +57,24 @@ class ExternalTaskSensor(BaseSensorOperator):
     Waits for a different DAG or a task in a different DAG to complete for a
     specific logical date.
 
+    By default the ExternalTaskSensor will wait for the external task to
+    succeed, at which point it will also succeed. However, by default it will
+    *not* fail if the external task fails, but will continue to check the status
+    until the sensor times out (thus giving you time to retry the external task
+    without also having to clear the sensor).
+
+    It is possible to alter the default behavior by setting states which
+    cause the sensor to fail, e.g. by setting ``allowed_states=[State.FAILED]``
+    and ``failed_states=[State.SUCCESS]`` you will flip the behavior to get a
+    sensor which goes green when the external task *fails* and immediately goes
+    red if the external task *succeeds*!
+
+    Note that ``soft_fail`` is respected when examining ``failed_states``. Thus
+    if the external task enters a failed state and ``soft_fail == True``, the
+    sensor will *skip* rather than fail. For example, setting ``soft_fail=True``
+    and ``failed_states=[State.SKIPPED]`` will cause the sensor to skip if
+    the external task skips.
+
     :param external_dag_id: The dag_id that contains the task you want to
         wait for
     :param external_task_id: The task_id that contains the task you want to
@@ -184,11 +202,20 @@ class ExternalTaskSensor(BaseSensorOperator):
 
         if count_failed == len(dttm_filter):
             if self.external_task_ids:
+                if self.soft_fail:
+                    raise AirflowSkipException(
+                        f'Some of the external tasks {self.external_task_ids} '
+                        f'in DAG {self.external_dag_id} failed. Skipping due to soft_fail.'
+                    )
                 raise AirflowException(
                     f'Some of the external tasks {self.external_task_ids} '
                     f'in DAG {self.external_dag_id} failed.'
                 )
             else:
+                if self.soft_fail:
+                    raise AirflowSkipException(
+                        f'The external DAG {self.external_dag_id} failed. Skipping due to soft_fail.'
+                    )
                 raise AirflowException(f'The external DAG {self.external_dag_id} failed.')
 
         return count_allowed == len(dttm_filter)
diff --git a/newsfragments/23647.bugfix.rst b/newsfragments/23647.bugfix.rst
new file mode 100644
index 0000000000..d12c1d7046
--- /dev/null
+++ b/newsfragments/23647.bugfix.rst
@@ -0,0 +1 @@
+``ExternalTaskSensor`` now supports the ``soft_fail`` flag to skip if the external task or DAG enters a failed state.
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index 15c78083d4..4dbc07ebc2 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -134,6 +134,28 @@ class TestExternalTaskSensor(unittest.TestCase):
                 "unit_test_dag failed."
             )
 
+    def test_external_task_sensor_soft_fail_failed_states_as_skipped(self, session=None):
+        self.test_time_sensor()
+        op = ExternalTaskSensor(
+            task_id='test_external_task_sensor_check',
+            external_dag_id=TEST_DAG_ID,
+            external_task_id=TEST_TASK_ID,
+            allowed_states=[State.FAILED],
+            failed_states=[State.SUCCESS],
+            soft_fail=True,
+            dag=self.dag,
+        )
+
+        # when
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+        # then
+        session = settings.Session()
+        TI = TaskInstance
+        task_instances: list[TI] = session.query(TI).filter(TI.task_id == op.task_id).all()
+        assert len(task_instances) == 1, "Unexpected number of task instances"
+        assert task_instances[0].state == State.SKIPPED, "Unexpected external task state"
+
     def test_external_task_sensor_external_task_id_param(self):
         """Test external_task_ids is set properly when external_task_id is passed as a template"""
         self.test_time_sensor()
@@ -141,10 +163,7 @@ class TestExternalTaskSensor(unittest.TestCase):
             task_id='test_external_task_sensor_check',
             external_dag_id='{{ params.dag_id }}',
             external_task_id='{{ params.task_id }}',
-            params={
-                'dag_id': TEST_DAG_ID,
-                'task_id': TEST_TASK_ID,
-            },
+            params={'dag_id': TEST_DAG_ID, 'task_id': TEST_TASK_ID},
             dag=self.dag,
         )
 
@@ -162,10 +181,7 @@ class TestExternalTaskSensor(unittest.TestCase):
             task_id='test_external_task_sensor_check',
             external_dag_id='{{ params.dag_id }}',
             external_task_ids=['{{ params.task_id }}'],
-            params={
-                'dag_id': TEST_DAG_ID,
-                'task_id': TEST_TASK_ID,
-            },
+            params={'dag_id': TEST_DAG_ID, 'task_id': TEST_TASK_ID},
             dag=self.dag,
         )
 
@@ -214,6 +230,31 @@ class TestExternalTaskSensor(unittest.TestCase):
         )
         op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
+    def test_external_dag_sensor_soft_fail_as_skipped(self):
+        other_dag = DAG('other_dag', default_args=self.args, end_date=DEFAULT_DATE, schedule_interval='@once')
+        other_dag.create_dagrun(
+            run_id='test', start_date=DEFAULT_DATE, execution_date=DEFAULT_DATE, state=State.SUCCESS
+        )
+        op = ExternalTaskSensor(
+            task_id='test_external_dag_sensor_check',
+            external_dag_id='other_dag',
+            external_task_id=None,
+            allowed_states=[State.FAILED],
+            failed_states=[State.SUCCESS],
+            soft_fail=True,
+            dag=self.dag,
+        )
+
+        # when
+        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+        # then
+        session = settings.Session()
+        TI = TaskInstance
+        task_instances: list[TI] = session.query(TI).filter(TI.task_id == op.task_id).all()
+        assert len(task_instances) == 1, "Unexpected number of task instances"
+        assert task_instances[0].state == State.SKIPPED, "Unexpected external task state"
+
     def test_external_task_sensor_fn_multiple_execution_dates(self):
         bash_command_code = """
 {% set s=logical_date.time().second %}