You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/08/11 07:30:49 UTC

[airflow] branch main updated: Fix mapped sensor with reschedule mode (#25594)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5f3733ea31 Fix mapped sensor with reschedule mode (#25594)
5f3733ea31 is described below

commit 5f3733ea310b53a0a90c660dc94dd6e1ad5755b7
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Thu Aug 11 08:30:33 2022 +0100

    Fix mapped sensor with reschedule mode (#25594)
    
    There are two issues with mapped sensors in `reschedule` mode. First, the reschedule table is populated with a default map_index of -1 even when the actual map_index is not -1. Second, MappedOperator does not have the `ReadyToReschedule` dependency.
    This PR fixes both issues.
    
    Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
---
 airflow/models/taskinstance.py                     |   1 +
 airflow/models/taskreschedule.py                   |   1 +
 airflow/serialization/serialized_objects.py        |   3 +-
 airflow/ti_deps/deps/ready_to_reschedule.py        |  11 +-
 tests/models/test_taskinstance.py                  | 217 +++++++++++++++++++++
 tests/serialization/test_dag_serialization.py      |  28 +++
 tests/ti_deps/deps/test_ready_to_reschedule_dep.py |  81 +++++++-
 7 files changed, 337 insertions(+), 5 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 7930d91fb8..ebdf07a382 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -1806,6 +1806,7 @@ class TaskInstance(Base, LoggingMixin):
                 actual_start_date,
                 self.end_date,
                 reschedule_exception.reschedule_date,
+                self.map_index,
             )
         )
 
diff --git a/airflow/models/taskreschedule.py b/airflow/models/taskreschedule.py
index 48e9178345..3dc0e6d609 100644
--- a/airflow/models/taskreschedule.py
+++ b/airflow/models/taskreschedule.py
@@ -112,6 +112,7 @@ class TaskReschedule(Base):
             TR.dag_id == task_instance.dag_id,
             TR.task_id == task_instance.task_id,
             TR.run_id == task_instance.run_id,
+            TR.map_index == task_instance.map_index,
             TR.try_number == try_number,
         )
         if descending:
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 9ddec956c5..c5bc9567fe 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -642,8 +642,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
 
     @classmethod
     def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]:
-        serialized_op = cls._serialize_node(op, include_deps=op.deps is MappedOperator.deps_for(BaseOperator))
-
+        serialized_op = cls._serialize_node(op, include_deps=op.deps != MappedOperator.deps_for(BaseOperator))
         # Handle expand_input and op_kwargs_expand_input.
         expansion_kwargs = op._get_specified_expand_input()
         serialized_op[op._expand_input_attr] = {
diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py b/airflow/ti_deps/deps/ready_to_reschedule.py
index 9086822cea..88219bb401 100644
--- a/airflow/ti_deps/deps/ready_to_reschedule.py
+++ b/airflow/ti_deps/deps/ready_to_reschedule.py
@@ -40,7 +40,11 @@ class ReadyToRescheduleDep(BaseTIDep):
         considered as passed. This dependency fails if the latest reschedule
         request's reschedule date is still in future.
         """
-        if not getattr(ti.task, "reschedule", False):
+        is_mapped = ti.task.is_mapped
+        if not is_mapped and not getattr(ti.task, "reschedule", False):
+            # Mapped sensors don't have the reschedule property (it can only
+            # be calculated after unmapping), so we don't check them here.
+            # They are handled below by checking TaskReschedule instead.
             yield self._passing_status(reason="Task is not in reschedule mode.")
             return
 
@@ -62,6 +66,11 @@ class ReadyToRescheduleDep(BaseTIDep):
             .first()
         )
         if not task_reschedule:
+            # Because mapped sensors don't have the reschedule property, here's the last resort
+            # and we need a slightly different passing reason
+            if is_mapped:
+                yield self._passing_status(reason="The task is mapped and not in reschedule mode")
+                return
             yield self._passing_status(reason="There is no reschedule request for this task instance.")
             return
 
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 0ef15d3018..76a9a303f5 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -801,6 +801,174 @@ class TestTaskInstance:
         done, fail = True, False
         run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0)
 
+    def test_mapped_reschedule_handling(self, dag_maker):
+        """
+        Test that mapped task reschedules are handled properly
+        """
+        # Return values of the python sensor callable, modified during tests
+        done = False
+        fail = False
+
+        def func():
+            if fail:
+                raise AirflowException()
+            return done
+
+        with dag_maker(dag_id='test_reschedule_handling') as dag:
+
+            task = PythonSensor.partial(
+                task_id='test_reschedule_handling_sensor',
+                mode='reschedule',
+                python_callable=func,
+                retries=1,
+                retry_delay=datetime.timedelta(seconds=0),
+            ).expand(poke_interval=[0])
+
+        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
+
+        ti.task = task
+        assert ti._try_number == 0
+        assert ti.try_number == 1
+
+        def run_ti_and_assert(
+            run_date,
+            expected_start_date,
+            expected_end_date,
+            expected_duration,
+            expected_state,
+            expected_try_number,
+            expected_task_reschedule_count,
+        ):
+            ti.refresh_from_task(task)
+            with freeze_time(run_date):
+                try:
+                    ti.run()
+                except AirflowException:
+                    if not fail:
+                        raise
+            ti.refresh_from_db()
+            assert ti.state == expected_state
+            assert ti._try_number == expected_try_number
+            assert ti.try_number == expected_try_number + 1
+            assert ti.start_date == expected_start_date
+            assert ti.end_date == expected_end_date
+            assert ti.duration == expected_duration
+            trs = TaskReschedule.find_for_task_instance(ti)
+            assert len(trs) == expected_task_reschedule_count
+
+        date1 = timezone.utcnow()
+        date2 = date1 + datetime.timedelta(minutes=1)
+        date3 = date2 + datetime.timedelta(minutes=1)
+        date4 = date3 + datetime.timedelta(minutes=1)
+
+        # Run with multiple reschedules.
+        # During reschedule the try number remains the same, but each reschedule is recorded.
+        # The start date is expected to remain the initial date, hence the duration increases.
+        # When finished the try number is incremented and there is no reschedule expected
+        # for this try.
+
+        done, fail = False, False
+        run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 0, 1)
+
+        done, fail = False, False
+        run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RESCHEDULE, 0, 2)
+
+        done, fail = False, False
+        run_ti_and_assert(date3, date1, date3, 120, State.UP_FOR_RESCHEDULE, 0, 3)
+
+        done, fail = True, False
+        run_ti_and_assert(date4, date1, date4, 180, State.SUCCESS, 1, 0)
+
+        # Clear the task instance.
+        dag.clear()
+        ti.refresh_from_db()
+        assert ti.state == State.NONE
+        assert ti._try_number == 1
+
+        # Run again after clearing with reschedules and a retry.
+        # The retry increments the try number, and for that try no reschedule is expected.
+        # After the retry the start date is reset, hence the duration is also reset.
+
+        done, fail = False, False
+        run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 1, 1)
+
+        done, fail = False, True
+        run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 2, 0)
+
+        done, fail = False, False
+        run_ti_and_assert(date3, date3, date3, 0, State.UP_FOR_RESCHEDULE, 2, 1)
+
+        done, fail = True, False
+        run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0)
+
+    @pytest.mark.usefixtures('test_pool')
+    def test_mapped_task_reschedule_handling_clear_reschedules(self, dag_maker):
+        """
+        Test that mapped task reschedules clearing are handled properly
+        """
+        # Return values of the python sensor callable, modified during tests
+        done = False
+        fail = False
+
+        def func():
+            if fail:
+                raise AirflowException()
+            return done
+
+        with dag_maker(dag_id='test_reschedule_handling') as dag:
+            task = PythonSensor.partial(
+                task_id='test_reschedule_handling_sensor',
+                mode='reschedule',
+                python_callable=func,
+                retries=1,
+                retry_delay=datetime.timedelta(seconds=0),
+                pool='test_pool',
+            ).expand(poke_interval=[0])
+        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
+        ti.task = task
+        assert ti._try_number == 0
+        assert ti.try_number == 1
+
+        def run_ti_and_assert(
+            run_date,
+            expected_start_date,
+            expected_end_date,
+            expected_duration,
+            expected_state,
+            expected_try_number,
+            expected_task_reschedule_count,
+        ):
+            ti.refresh_from_task(task)
+            with freeze_time(run_date):
+                try:
+                    ti.run()
+                except AirflowException:
+                    if not fail:
+                        raise
+            ti.refresh_from_db()
+            assert ti.state == expected_state
+            assert ti._try_number == expected_try_number
+            assert ti.try_number == expected_try_number + 1
+            assert ti.start_date == expected_start_date
+            assert ti.end_date == expected_end_date
+            assert ti.duration == expected_duration
+            trs = TaskReschedule.find_for_task_instance(ti)
+            assert len(trs) == expected_task_reschedule_count
+
+        date1 = timezone.utcnow()
+
+        done, fail = False, False
+        run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 0, 1)
+
+        # Clear the task instance.
+        dag.clear()
+        ti.refresh_from_db()
+        assert ti.state == State.NONE
+        assert ti._try_number == 0
+        # Check that reschedules for ti have also been cleared.
+        trs = TaskReschedule.find_for_task_instance(ti)
+        assert not trs
+
     @pytest.mark.usefixtures('test_pool')
     def test_reschedule_handling_clear_reschedules(self, dag_maker):
         """
@@ -2541,6 +2709,55 @@ def test_sensor_timeout(mode, retries, dag_maker):
     assert ti.state == State.FAILED
 
 
+@pytest.mark.parametrize("mode", ["poke", "reschedule"])
+@pytest.mark.parametrize("retries", [0, 1])
+def test_mapped_sensor_timeout(mode, retries, dag_maker):
+    """
+    Test that AirflowSensorTimeout does not cause mapped sensor to retry.
+    """
+
+    def timeout():
+        raise AirflowSensorTimeout
+
+    mock_on_failure = mock.MagicMock()
+    with dag_maker(dag_id=f'test_sensor_timeout_{mode}_{retries}'):
+        PythonSensor.partial(
+            task_id='test_raise_sensor_timeout',
+            python_callable=timeout,
+            on_failure_callback=mock_on_failure,
+            retries=retries,
+        ).expand(mode=[mode])
+    ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
+
+    with pytest.raises(AirflowSensorTimeout):
+        ti.run()
+
+    assert mock_on_failure.called
+    assert ti.state == State.FAILED
+
+
+@pytest.mark.parametrize("mode", ["poke", "reschedule"])
+@pytest.mark.parametrize("retries", [0, 1])
+def test_mapped_sensor_works(mode, retries, dag_maker):
+    """
+    Test that mapped sensors reaches success state.
+    """
+
+    def timeout(ti):
+        return 1
+
+    with dag_maker(dag_id=f'test_sensor_timeout_{mode}_{retries}'):
+        PythonSensor.partial(
+            task_id='test_raise_sensor_timeout',
+            python_callable=timeout,
+            retries=retries,
+        ).expand(mode=[mode])
+    ti = dag_maker.create_dagrun().task_instances[0]
+
+    ti.run()
+    assert ti.state == State.SUCCESS
+
+
 class TestTaskInstanceRecordTaskMapXComPush:
     """Test TI.xcom_push() correctly records return values for task-mapping."""
 
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 37da287cf1..ec0996bf43 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -45,6 +45,7 @@ from airflow.models.param import Param, ParamsDict
 from airflow.models.xcom import XCOM_RETURN_KEY, XCom
 from airflow.operators.bash import BashOperator
 from airflow.security import permissions
+from airflow.sensors.bash import BashSensor
 from airflow.serialization.json_schema import load_dag_schema_dict
 from airflow.serialization.serialized_objects import (
     DagDependency,
@@ -1270,6 +1271,7 @@ class TestStringifiedDAGs:
             task1 >> task2
 
         serialize_op = SerializedBaseOperator.serialize_operator(dag.task_dict["task1"])
+
         deps = serialize_op["deps"]
         assert deps == [
             'airflow.ti_deps.deps.not_in_retry_period_dep.NotInRetryPeriodDep',
@@ -1611,6 +1613,21 @@ class TestStringifiedDAGs:
         assert serialized_op.reschedule == (mode == "reschedule")
         assert op.deps == serialized_op.deps
 
+    @pytest.mark.parametrize("mode", ["poke", "reschedule"])
+    def test_serialize_mapped_sensor_has_reschedule_dep(self, mode):
+        from airflow.sensors.base import BaseSensorOperator
+
+        class DummySensor(BaseSensorOperator):
+            def poke(self, context: Context):
+                return False
+
+        op = DummySensor.partial(task_id='dummy', mode=mode).expand(poke_interval=[23])
+
+        blob = SerializedBaseOperator.serialize_mapped_operator(op)
+        assert "deps" in blob
+
+        assert 'airflow.ti_deps.deps.ready_to_reschedule.ReadyToRescheduleDep' in blob['deps']
+
     @pytest.mark.parametrize(
         "passed_success_callback, expected_value",
         [
@@ -1980,6 +1997,17 @@ def test_operator_expand_deserialized_unmap():
     assert deserialize(serialize(mapped)).unmap(None) == deserialize(serialize(normal))
 
 
+def test_sensor_expand_deserialized_unmap():
+    """Unmap a deserialized mapped sensor should be similar to deserializing a non-mapped sensor"""
+    normal = BashSensor(task_id='a', bash_command=[1, 2], mode='reschedule')
+    mapped = BashSensor.partial(task_id='a', mode='reschedule').expand(bash_command=[1, 2])
+
+    serialize = SerializedBaseOperator._serialize
+
+    deserialize = SerializedBaseOperator.deserialize_operator
+    assert deserialize(serialize(mapped)).unmap(None) == deserialize(serialize(normal))
+
+
 def test_task_resources_serde():
     """
     Test task resources serialization/deserialization.
diff --git a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
index 99416bbbc8..2ab8c539dc 100644
--- a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
+++ b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py
@@ -31,12 +31,30 @@ from airflow.utils.timezone import utcnow
 class TestNotInReschedulePeriodDep(unittest.TestCase):
     def _get_task_instance(self, state):
         dag = DAG('test_dag')
-        task = Mock(dag=dag, reschedule=True)
+        task = Mock(dag=dag, reschedule=True, is_mapped=False)
         ti = TaskInstance(task=task, state=state, run_id=None)
         return ti
 
     def _get_task_reschedule(self, reschedule_date):
-        task = Mock(dag_id='test_dag', task_id='test_task')
+        task = Mock(dag_id='test_dag', task_id='test_task', is_mapped=False)
+        reschedule = TaskReschedule(
+            task=task,
+            run_id=None,
+            try_number=None,
+            start_date=reschedule_date,
+            end_date=reschedule_date,
+            reschedule_date=reschedule_date,
+        )
+        return reschedule
+
+    def _get_mapped_task_instance(self, state):
+        dag = DAG('test_dag')
+        task = Mock(dag=dag, reschedule=True, is_mapped=True)
+        ti = TaskInstance(task=task, state=state, run_id=None)
+        return ti
+
+    def _get_mapped_task_reschedule(self, reschedule_date):
+        task = Mock(dag_id='test_dag', task_id='test_task', is_mapped=True)
         reschedule = TaskReschedule(
             task=task,
             run_id=None,
@@ -103,3 +121,62 @@ class TestNotInReschedulePeriodDep(unittest.TestCase):
         ][-1]
         ti = self._get_task_instance(State.UP_FOR_RESCHEDULE)
         assert not ReadyToRescheduleDep().is_met(ti=ti)
+
+    def test_mapped_task_should_pass_if_ignore_in_reschedule_period_is_set(self):
+        ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
+        dep_context = DepContext(ignore_in_reschedule_period=True)
+        assert ReadyToRescheduleDep().is_met(ti=ti, dep_context=dep_context)
+
+    @patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
+    def test_mapped_task_should_pass_if_not_reschedule_mode(self, mock_query_for_task_instance):
+        mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = []
+        ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
+        del ti.task.reschedule
+        assert ReadyToRescheduleDep().is_met(ti=ti)
+
+    def test_mapped_task_should_pass_if_not_in_none_state(self):
+        ti = self._get_mapped_task_instance(State.UP_FOR_RETRY)
+        assert ReadyToRescheduleDep().is_met(ti=ti)
+
+    @patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
+    def test_mapped_should_pass_if_no_reschedule_record_exists(self, mock_query_for_task_instance):
+        mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = []
+        ti = self._get_mapped_task_instance(State.NONE)
+        assert ReadyToRescheduleDep().is_met(ti=ti)
+
+    @patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
+    def test_mapped_should_pass_after_reschedule_date_one(self, mock_query_for_task_instance):
+        mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = (
+            self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=1))
+        )
+        ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
+        assert ReadyToRescheduleDep().is_met(ti=ti)
+
+    @patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
+    def test_mapped_task_should_pass_after_reschedule_date_multiple(self, mock_query_for_task_instance):
+        mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = [
+            self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=21)),
+            self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=11)),
+            self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=1)),
+        ][-1]
+        ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
+        assert ReadyToRescheduleDep().is_met(ti=ti)
+
+    @patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
+    def test_mapped_task_should_fail_before_reschedule_date_one(self, mock_query_for_task_instance):
+        mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = (
+            self._get_mapped_task_reschedule(utcnow() + timedelta(minutes=1))
+        )
+
+        ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
+        assert not ReadyToRescheduleDep().is_met(ti=ti)
+
+    @patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
+    def test_mapped_task_should_fail_before_reschedule_date_multiple(self, mock_query_for_task_instance):
+        mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = [
+            self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=19)),
+            self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=9)),
+            self._get_mapped_task_reschedule(utcnow() + timedelta(minutes=1)),
+        ][-1]
+        ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
+        assert not ReadyToRescheduleDep().is_met(ti=ti)