You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2021/08/13 19:49:15 UTC

[airflow] 06/08: Fix race condition with dagrun callbacks (#16741)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit a76a0df969cd6e354768811886754eec6ed0d6ea
Author: Jed Cunningham <66...@users.noreply.github.com>
AuthorDate: Thu Jul 29 11:05:02 2021 -0600

    Fix race condition with dagrun callbacks (#16741)
    
    Instead of immediately sending callbacks to be processed, wait until
    after we commit so the dagrun.end_date is guaranteed to be there when
    the callback runs.
    
    (cherry picked from commit fb3031acf51f95384154143553aac1a40e568ebf)
---
 airflow/jobs/scheduler_job.py          | 18 +++++---
 tests/dag_processing/test_processor.py | 20 +++++----
 tests/jobs/test_scheduler_job.py       | 80 +++++++++++++++++++++++++++++-----
 3 files changed, 94 insertions(+), 24 deletions(-)

diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 7a37b25..18ec981 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -888,6 +888,7 @@ class SchedulerJob(BaseJob):
             # Bulk fetch the currently active dag runs for the dags we are
             # examining, rather than making one query per DagRun
 
+            callback_tuples = []
             for dag_run in dag_runs:
                 # Use try_except to not stop the Scheduler when a Serialized DAG is not found
                 # This takes care of Dynamic DAGs especially
@@ -896,13 +897,18 @@ class SchedulerJob(BaseJob):
                 # But this would take care of the scenario when the Scheduler is restarted after DagRun is
                 # created and the DAG is deleted / renamed
                 try:
-                    self._schedule_dag_run(dag_run, session)
+                    callback_to_run = self._schedule_dag_run(dag_run, session)
+                    callback_tuples.append((dag_run, callback_to_run))
                 except SerializedDagNotFound:
                     self.log.exception("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
                     continue
 
             guard.commit()
 
+            # Send the callbacks after we commit to ensure the context is up to date when they get run
+            for dag_run, callback_to_run in callback_tuples:
+                self._send_dag_callbacks_to_processor(dag_run, callback_to_run)
+
             # Without this, the session has an invalid view of the DB
             session.expunge_all()
             # END: schedule TIs
@@ -1064,12 +1070,12 @@ class SchedulerJob(BaseJob):
         self,
         dag_run: DagRun,
         session: Session,
-    ) -> int:
+    ) -> Optional[DagCallbackRequest]:
         """
         Make scheduling decisions about an individual dag run
 
         :param dag_run: The DagRun to schedule
-        :return: Number of tasks scheduled
+        :return: Callback that needs to be executed
         """
         dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
 
@@ -1116,13 +1122,13 @@ class SchedulerJob(BaseJob):
         # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else?
         schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False)
 
-        self._send_dag_callbacks_to_processor(dag_run, callback_to_run)
-
         # This will do one query per dag run. We "could" build up a complex
         # query to update all the TIs across all the execution dates and dag
         # IDs in a single query, but it turns out that can be _very very slow_
         # see #11147/commit ee90807ac for more details
-        return dag_run.schedule_tis(schedulable_tis, session)
+        dag_run.schedule_tis(schedulable_tis, session)
+
+        return callback_to_run
 
     @provide_session
     def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None):
diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py
index feb3497..b7f8e7d 100644
--- a/tests/dag_processing/test_processor.py
+++ b/tests/dag_processing/test_processor.py
@@ -115,6 +115,10 @@ class TestDagFileProcessor(unittest.TestCase):
         non_serialized_dagbag.sync_to_db()
         cls.dagbag = DagBag(read_dags_from_db=True)
 
+    @staticmethod
+    def assert_scheduled_ti_count(session, count):
+        assert count == session.query(TaskInstance).filter_by(state=State.SCHEDULED).count()
+
     def test_dag_file_processor_sla_miss_callback(self):
         """
         Test that the dag file processor calls the sla miss callback
@@ -387,8 +391,8 @@ class TestDagFileProcessor(unittest.TestCase):
             ti.start_date = start_date
             ti.end_date = end_date
 
-            count = self.scheduler_job._schedule_dag_run(dr, session)
-            assert count == 1
+            self.scheduler_job._schedule_dag_run(dr, session)
+            self.assert_scheduled_ti_count(session, 1)
 
             session.refresh(ti)
             assert ti.state == State.SCHEDULED
@@ -444,8 +448,8 @@ class TestDagFileProcessor(unittest.TestCase):
             ti.start_date = start_date
             ti.end_date = end_date
 
-            count = self.scheduler_job._schedule_dag_run(dr, session)
-            assert count == 1
+            self.scheduler_job._schedule_dag_run(dr, session)
+            self.assert_scheduled_ti_count(session, 1)
 
             session.refresh(ti)
             assert ti.state == State.SCHEDULED
@@ -504,8 +508,8 @@ class TestDagFileProcessor(unittest.TestCase):
                 ti.start_date = start_date
                 ti.end_date = end_date
 
-            count = self.scheduler_job._schedule_dag_run(dr, session)
-            assert count == 2
+            self.scheduler_job._schedule_dag_run(dr, session)
+            self.assert_scheduled_ti_count(session, 2)
 
             session.refresh(tis[0])
             session.refresh(tis[1])
@@ -547,9 +551,9 @@ class TestDagFileProcessor(unittest.TestCase):
         BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test')
         SerializedDagModel.write_dag(dag=dag)
 
-        scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session)
+        self.scheduler_job._schedule_dag_run(dr, session)
+        self.assert_scheduled_ti_count(session, 2)
         session.flush()
-        assert scheduled_tis == 2
 
         drs = DagRun.find(dag_id=dag.dag_id, session=session)
         assert len(drs) == 1
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 0ee6f5f..5de365e 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -1710,10 +1710,11 @@ class TestSchedulerJob(unittest.TestCase):
         ti = dr.get_task_instance('dummy')
         ti.set_state(state, session)
 
-        self.scheduler_job._schedule_dag_run(dr, session)
+        with mock.patch.object(settings, "USE_JOB_SCHEDULE", False):
+            self.scheduler_job._do_scheduling(session)
 
         expected_callback = DagCallbackRequest(
-            full_filepath=dr.dag.fileloc,
+            full_filepath=dag.fileloc,
             dag_id=dr.dag_id,
             is_failure_callback=bool(state == State.FAILED),
             execution_date=dr.execution_date,
@@ -1729,6 +1730,64 @@ class TestSchedulerJob(unittest.TestCase):
         session.rollback()
         session.close()
 
+    def test_dagrun_callbacks_commited_before_sent(self):
+        """
+        Tests that before any callbacks are sent to the processor, the session is committed. This ensures
+        that the dagrun details are up to date when the callbacks are run.
+        """
+        dag = DAG(dag_id='test_dagrun_callbacks_commited_before_sent', start_date=DEFAULT_DATE)
+        DummyOperator(task_id='dummy', dag=dag, owner='airflow')
+
+        self.scheduler_job = SchedulerJob(subdir=os.devnull)
+        self.scheduler_job.processor_agent = mock.Mock()
+        self.scheduler_job._send_dag_callbacks_to_processor = mock.Mock()
+        self.scheduler_job._schedule_dag_run = mock.Mock()
+
+        # Sync DAG into DB
+        with mock.patch.object(settings, "STORE_DAG_CODE", False):
+            self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+            self.scheduler_job.dagbag.sync_to_db()
+
+        session = settings.Session()
+        orm_dag = session.query(DagModel).get(dag.dag_id)
+        assert orm_dag is not None
+
+        # Create DagRun
+        self.scheduler_job._create_dag_runs([orm_dag], session)
+
+        drs = DagRun.find(dag_id=dag.dag_id, session=session)
+        assert len(drs) == 1
+        dr = drs[0]
+
+        ti = dr.get_task_instance('dummy')
+        ti.set_state(State.SUCCESS, session)
+
+        with mock.patch.object(settings, "USE_JOB_SCHEDULE", False), mock.patch(
+            "airflow.jobs.scheduler_job.prohibit_commit"
+        ) as mock_gaurd:
+            mock_gaurd.return_value.__enter__.return_value.commit.side_effect = session.commit
+
+            def mock_schedule_dag_run(*args, **kwargs):
+                mock_gaurd.reset_mock()
+                return None
+
+            def mock_send_dag_callbacks_to_processor(*args, **kwargs):
+                mock_gaurd.return_value.__enter__.return_value.commit.assert_called_once()
+
+            self.scheduler_job._send_dag_callbacks_to_processor.side_effect = (
+                mock_send_dag_callbacks_to_processor
+            )
+            self.scheduler_job._schedule_dag_run.side_effect = mock_schedule_dag_run
+
+            self.scheduler_job._do_scheduling(session)
+
+        # Verify the dag callback request is sent to the file processor exactly once
+        self.scheduler_job._send_dag_callbacks_to_processor.assert_called_once()
+        # and mock_send_dag_callbacks_to_processor has asserted the callback was sent after a commit
+
+        session.rollback()
+        session.close()
+
     @parameterized.expand([(State.SUCCESS,), (State.FAILED,)])
     def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, state):
         """
@@ -1765,10 +1824,15 @@ class TestSchedulerJob(unittest.TestCase):
         ti = dr.get_task_instance('test_task')
         ti.set_state(state, session)
 
-        self.scheduler_job._schedule_dag_run(dr, session)
+        with mock.patch.object(settings, "USE_JOB_SCHEDULE", False):
+            self.scheduler_job._do_scheduling(session)
 
         # Verify Callback is not set (i.e is None) when no callbacks are set on DAG
-        self.scheduler_job._send_dag_callbacks_to_processor.assert_called_once_with(dr, None)
+        self.scheduler_job._send_dag_callbacks_to_processor.assert_called_once()
+        call_args = self.scheduler_job._send_dag_callbacks_to_processor.call_args[0]
+        assert call_args[0].dag_id == dr.dag_id
+        assert call_args[0].execution_date == dr.execution_date
+        assert call_args[1] is None
 
         session.rollback()
         session.close()
@@ -2411,12 +2475,10 @@ class TestSchedulerJob(unittest.TestCase):
 
         # Verify that DagRun.verify_integrity is not called
         with mock.patch('airflow.jobs.scheduler_job.DagRun.verify_integrity') as mock_verify_integrity:
-            scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session)
+            self.scheduler_job._schedule_dag_run(dr, session)
             mock_verify_integrity.assert_not_called()
         session.flush()
 
-        assert scheduled_tis == 1
-
         tis_count = (
             session.query(func.count(TaskInstance.task_id))
             .filter(
@@ -2474,11 +2536,9 @@ class TestSchedulerJob(unittest.TestCase):
         dag_version_2 = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session)
         assert dag_version_2 != dag_version_1
 
-        scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session)
+        self.scheduler_job._schedule_dag_run(dr, session)
         session.flush()
 
-        assert scheduled_tis == 2
-
         drs = DagRun.find(dag_id=dag.dag_id, session=session)
         assert len(drs) == 1
         dr = drs[0]