Posted to commits@airflow.apache.org by ka...@apache.org on 2021/08/14 01:14:56 UTC

[airflow] branch v2-1-test updated (f1268d2 -> 744ba52)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a change to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git.


 discard f1268d2  Update documentation regarding Python 3.9 support (#17611)
 discard 32d4bab  Bump version to 2.1.3
 discard 6ab83b5  Fix task retries when they receive sigkill and have retries and properly handle sigterm (#16301)
 discard a59a567  Add Pytest fixture to create dag and dagrun and use it on local task job tests (#16889)
 discard cbe87e0  Fix race condition with dagrun callbacks (#16741)
 discard 9d77deb  Add 'queued' to DagRunState (#16854)
 discard 059f5ab  Add 'queued' state to DagRun (#16401)
 discard 004a414  Move DagFileProcessor and DagFileProcessorProcess out of scheduler_job.py (#16581)
 discard a4a67cb  Run mini scheduler in LocalTaskJob during task exit (#16289)
     new e8e8b19  Run mini scheduler in LocalTaskJob during task exit (#16289)
     new 2256716  Move DagFileProcessor and DagFileProcessorProcess out of scheduler_job.py (#16581)
     new ee67daa  Add 'queued' state to DagRun (#16401)
     new 5c3e373  Add 'queued' to DagRunState (#16854)
     new 7cc6002  Fix race condition with dagrun callbacks (#16741)
     new 343beb6  Add Pytest fixture to create dag and dagrun and use it on local task job tests (#16889)
     new 12f2467  Fix task retries when they receive sigkill and have retries and properly handle sigterm (#16301)
     new eadea47  Bump version to 2.1.3
     new 744ba52  Update documentation regarding Python 3.9 support (#17611)

This update added new revisions after undoing existing revisions.
That is to say, some revisions that were in the old version of the
branch are not in the new version.  This situation occurs
when a user --force pushes a change and generates a repository
containing something like this:

 * -- * -- B -- O -- O -- O   (f1268d2)
            \
             N -- N -- N   refs/heads/v2-1-test (744ba52)

You should already have received notification emails for all of the O
revisions, and so the following emails describe only the N revisions
from the common base, B.

Any revisions marked "omit" are not gone; other references still
refer to them.  Any revisions marked "discard" are gone forever.

The 9 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 airflow/jobs/local_task_job.py | 1 +
 1 file changed, 1 insertion(+)

[airflow] 04/09: Add 'queued' to DagRunState (#16854)

Posted by ka...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 5c3e373ccf03efc1b988831d6fc6e5109f3e3197
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Wed Jul 7 21:37:22 2021 +0100

    Add 'queued' to DagRunState (#16854)
    
    This change adds 'queued' to DagRunState and improves the typing for DagRun state
    
    Co-authored-by: Kaxil Naik <ka...@gmail.com>
    (cherry picked from commit 5a5f30f9133a6c5f0c41886ff9ae80ea53c73989)
---
 airflow/jobs/scheduler_job.py  |  8 ++++----
 airflow/models/dag.py          | 10 +++++-----
 airflow/models/dagrun.py       | 16 +++++++++-------
 airflow/models/taskinstance.py |  4 ++--
 airflow/utils/state.py         |  2 ++
 5 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index b564717..7a37b25 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -54,7 +54,7 @@ from airflow.utils.event_scheduler import EventScheduler
 from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries
 from airflow.utils.session import create_session, provide_session
 from airflow.utils.sqlalchemy import is_lock_not_available_error, prohibit_commit, skip_locked, with_row_locks
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 
 TI = models.TaskInstance
@@ -192,7 +192,7 @@ class SchedulerJob(BaseJob):
 
     @provide_session
     def _change_state_for_tis_without_dagrun(
-        self, old_states: List[str], new_state: str, session: Session = None
+        self, old_states: List[TaskInstanceState], new_state: TaskInstanceState, session: Session = None
     ) -> None:
         """
         For all DAG IDs in the DagBag, look for task instances in the
@@ -266,7 +266,7 @@ class SchedulerJob(BaseJob):
 
     @provide_session
     def __get_concurrency_maps(
-        self, states: List[str], session: Session = None
+        self, states: List[TaskInstanceState], session: Session = None
     ) -> Tuple[DefaultDict[str, int], DefaultDict[Tuple[str, str], int]]:
         """
         Get the concurrency maps.
@@ -936,7 +936,7 @@ class SchedulerJob(BaseJob):
             return num_queued_tis
 
     @retry_db_transaction
-    def _get_next_dagruns_to_examine(self, state, session):
+    def _get_next_dagruns_to_examine(self, state: DagRunState, session: Session):
         """Get Next DagRuns to Examine with retries"""
         return DagRun.next_dagruns_to_examine(state, session)
 
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index a3d06db..2e66b40 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -75,7 +75,7 @@ from airflow.utils.helpers import validate_key
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import provide_session
 from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, with_row_locks
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State
 from airflow.utils.types import DagRunType, EdgeInfoType
 
 if TYPE_CHECKING:
@@ -1153,7 +1153,7 @@ class DAG(LoggingMixin):
         confirm_prompt=False,
         include_subdags=True,
         include_parentdag=True,
-        dag_run_state: str = State.QUEUED,
+        dag_run_state: DagRunState = DagRunState.QUEUED,
         dry_run=False,
         session=None,
         get_tis=False,
@@ -1369,7 +1369,7 @@ class DAG(LoggingMixin):
         confirm_prompt=False,
         include_subdags=True,
         include_parentdag=False,
-        dag_run_state=State.QUEUED,
+        dag_run_state=DagRunState.QUEUED,
         dry_run=False,
     ):
         all_tis = []
@@ -1731,7 +1731,7 @@ class DAG(LoggingMixin):
     @provide_session
     def create_dagrun(
         self,
-        state: State,
+        state: DagRunState,
         execution_date: Optional[datetime] = None,
         run_id: Optional[str] = None,
         start_date: Optional[datetime] = None,
@@ -1753,7 +1753,7 @@ class DAG(LoggingMixin):
         :param execution_date: the execution date of this dag run
         :type execution_date: datetime.datetime
         :param state: the state of the dag run
-        :type state: airflow.utils.state.State
+        :type state: airflow.utils.state.DagRunState
         :param start_date: the date this dag run should be evaluated
         :type start_date: datetime
         :param external_trigger: whether this dag run is externally triggered
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index a061dcc..6f47077 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -36,7 +36,7 @@ from airflow.utils import callback_requests, timezone
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import provide_session
 from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, skip_locked, with_row_locks
-from airflow.utils.state import State, TaskInstanceState
+from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 
 if TYPE_CHECKING:
@@ -110,7 +110,7 @@ class DagRun(Base, LoggingMixin):
         start_date: Optional[datetime] = None,
         external_trigger: Optional[bool] = None,
         conf: Optional[Any] = None,
-        state: Optional[str] = None,
+        state: Optional[DagRunState] = None,
         run_type: Optional[str] = None,
         dag_hash: Optional[str] = None,
         creating_job_id: Optional[int] = None,
@@ -144,7 +144,7 @@ class DagRun(Base, LoggingMixin):
     def get_state(self):
         return self._state
 
-    def set_state(self, state):
+    def set_state(self, state: DagRunState):
         if self._state != state:
             self._state = state
             self.end_date = timezone.utcnow() if self._state in State.finished else None
@@ -170,7 +170,7 @@ class DagRun(Base, LoggingMixin):
     @classmethod
     def next_dagruns_to_examine(
         cls,
-        state: str,
+        state: DagRunState,
         session: Session,
         max_number: Optional[int] = None,
     ):
@@ -219,7 +219,7 @@ class DagRun(Base, LoggingMixin):
         dag_id: Optional[Union[str, List[str]]] = None,
         run_id: Optional[str] = None,
         execution_date: Optional[datetime] = None,
-        state: Optional[str] = None,
+        state: Optional[DagRunState] = None,
         external_trigger: Optional[bool] = None,
         no_backfills: bool = False,
         run_type: Optional[DagRunType] = None,
@@ -239,7 +239,7 @@ class DagRun(Base, LoggingMixin):
         :param execution_date: the execution date
         :type execution_date: datetime.datetime or list[datetime.datetime]
         :param state: the state of the dag run
-        :type state: str
+        :type state: DagRunState
         :param external_trigger: whether this dag run is externally triggered
         :type external_trigger: bool
         :param no_backfills: return no backfills (True), return all (False).
@@ -343,7 +343,9 @@ class DagRun(Base, LoggingMixin):
         return self.dag
 
     @provide_session
-    def get_previous_dagrun(self, state: Optional[str] = None, session: Session = None) -> Optional['DagRun']:
+    def get_previous_dagrun(
+        self, state: Optional[DagRunState] = None, session: Session = None
+    ) -> Optional['DagRun']:
         """The previous DagRun, if there is one"""
         filters = [
             DagRun.dag_id == self.dag_id,
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 0e10567..c715f22 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -70,7 +70,7 @@ from airflow.utils.operator_helpers import context_to_airflow_vars
 from airflow.utils.platform import getuser
 from airflow.utils.session import provide_session
 from airflow.utils.sqlalchemy import UtcDateTime
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State
 from airflow.utils.timeout import timeout
 
 try:
@@ -137,7 +137,7 @@ def clear_task_instances(
     session,
     activate_dag_runs=None,
     dag=None,
-    dag_run_state: Union[str, Literal[False]] = State.QUEUED,
+    dag_run_state: Union[DagRunState, Literal[False]] = DagRunState.QUEUED,
 ):
     """
     Clears a set of task instances, but makes sure the running ones
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index b1b27d0..e95b409 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -60,6 +60,7 @@ class DagRunState(str, Enum):
     same name in TaskInstanceState.
     """
 
+    QUEUED = "queued"
     RUNNING = "running"
     SUCCESS = "success"
     FAILED = "failed"
@@ -92,6 +93,7 @@ class State:
     task_states: Tuple[Optional[TaskInstanceState], ...] = (None,) + tuple(TaskInstanceState)
 
     dag_states: Tuple[DagRunState, ...] = (
+        DagRunState.QUEUED,
         DagRunState.SUCCESS,
         DagRunState.RUNNING,
         DagRunState.FAILED,
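
The state.py hunk above adds QUEUED to the DagRunState enum, and the other hunks in
this commit switch scheduler and model signatures from plain strings to that enum. As a
minimal sketch of the typed usage (illustrative only; it assumes the airflow.utils.state
module from this branch and nothing else):

    from airflow.utils.state import DagRunState, State

    # DagRunState is a str-backed Enum, so members compare equal to the raw strings
    # stored in the dag_run.state column.
    assert DagRunState.QUEUED == "queued"
    assert DagRunState.QUEUED in State.dag_states

    # A signature typed the way the new scheduler helpers are (hypothetical function).
    def describe(state: DagRunState) -> str:
        return f"examining dag runs in state {state.value}"

    print(describe(DagRunState.QUEUED))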

[airflow] 08/09: Bump version to 2.1.3

Posted by ka...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit eadea473c30411ccb3d3cf9366316f8391cce7ce
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Wed Jul 28 16:00:25 2021 +0100

    Bump version to 2.1.3
---
 README.md | 16 ++++++++--------
 setup.py  |  2 +-
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index 8ab1d2b..722fad2 100644
--- a/README.md
+++ b/README.md
@@ -82,7 +82,7 @@ Airflow is not a streaming solution, but it is often used to process real-time d
 
 Apache Airflow is tested with:
 
-|                      | Main version (dev)        | Stable version (2.1.2)   |
+|                      | Main version (dev)        | Stable version (2.1.3)   |
 | -------------------- | ------------------------- | ------------------------ |
 | Python               | 3.6, 3.7, 3.8, 3.9        | 3.6, 3.7, 3.8, 3.9       |
 | Kubernetes           | 1.20, 1.19, 1.18          | 1.20, 1.19, 1.18         |
@@ -142,15 +142,15 @@ them to appropriate format and workflow that your tool requires.
 
 
 ```bash
-pip install apache-airflow==2.1.2 \
- --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.1.2/constraints-3.7.txt"
+pip install apache-airflow==2.1.3 \
+ --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.1.3/constraints-3.7.txt"
 ```
 
 2. Installing with extras (for example postgres,google)
 
 ```bash
-pip install apache-airflow[postgres,google]==2.1.2 \
- --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.1.2/constraints-3.7.txt"
+pip install apache-airflow[postgres,google]==2.1.3 \
+ --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.1.3/constraints-3.7.txt"
 ```
 
 For information on installing provider packages check
@@ -231,7 +231,7 @@ packages:
 * **Airflow Providers**: SemVer rules apply to changes in the particular provider's code only.
   SemVer MAJOR and MINOR versions for the packages are independent from Airflow version.
   For example `google 4.1.0` and `amazon 3.0.3` providers can happily be installed
-  with `Airflow 2.1.2`. If there are limits of cross-dependencies between providers and Airflow packages,
+  with `Airflow 2.1.3`. If there are limits of cross-dependencies between providers and Airflow packages,
   they are present in providers as `install_requires` limitations. We aim to keep backwards
   compatibility of providers with all previously released Airflow 2 versions but
   there will be sometimes breaking changes that might make some, or all
@@ -254,7 +254,7 @@ Apache Airflow version life cycle:
 
 | Version | Current Patch/Minor | State     | First Release | Limited Support | EOL/Terminated |
 |---------|---------------------|-----------|---------------|-----------------|----------------|
-| 2       | 2.1.2               | Supported | Dec 17, 2020  | Dec 2021        | TBD            |
+| 2       | 2.1.3               | Supported | Dec 17, 2020  | Dec 2021        | TBD            |
 | 1.10    | 1.10.15             | EOL       | Aug 27, 2018  | Dec 17, 2020    | June 17, 2021  |
 | 1.9     | 1.9.0               | EOL       | Jan 03, 2018  | Aug 27, 2018    | Aug 27, 2018   |
 | 1.8     | 1.8.2               | EOL       | Mar 19, 2017  | Jan 03, 2018    | Jan 03, 2018   |
@@ -280,7 +280,7 @@ They are based on the official release schedule of Python and Kubernetes, nicely
 
 2. The "oldest" supported version of Python/Kubernetes is the default one. "Default" is only meaningful
    in terms of "smoke tests" in CI PRs which are run using this default version and default reference
-   image available. Currently ``apache/airflow:latest`` and ``apache/airflow:2.1.2` images
+   image available. Currently ``apache/airflow:latest`` and ``apache/airflow:2.1.3` images
    are both Python 3.6 images, however the first MINOR/MAJOR release of Airflow release after 23.12.2021 will
    become Python 3.7 images.
 
diff --git a/setup.py b/setup.py
index c74808a..2027b61 100644
--- a/setup.py
+++ b/setup.py
@@ -41,7 +41,7 @@ PY39 = sys.version_info >= (3, 9)
 
 logger = logging.getLogger(__name__)
 
-version = '2.1.2'
+version = '2.1.3'
 
 my_dir = dirname(__file__)
 

[airflow] 09/09: Update documentation regarding Python 3.9 support (#17611)

Posted by ka...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 744ba522e0090beca933719729fbf7d184a69153
Author: Gabe Flores <ga...@match.com>
AuthorDate: Fri Aug 13 19:15:44 2021 -0500

    Update documentation regarding Python 3.9 support (#17611)
    
    https://github.com/apache/airflow#requirements
    
    (cherry picked from commit 721d4e7c60cbccfd064572f16c3941f41ff8ab3a)
---
 docs/apache-airflow/installation.rst | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/docs/apache-airflow/installation.rst b/docs/apache-airflow/installation.rst
index 574d7ec..571a395 100644
--- a/docs/apache-airflow/installation.rst
+++ b/docs/apache-airflow/installation.rst
@@ -34,7 +34,7 @@ Prerequisites
 
 Airflow is tested with:
 
-* Python: 3.6, 3.7, 3.8
+* Python: 3.6, 3.7, 3.8, 3.9
 
 * Databases:
 
@@ -50,8 +50,7 @@ running multiple schedulers -- please see: :doc:`/concepts/scheduler`. MariaDB i
 **Note:** SQLite is used in Airflow tests. Do not use it in production. We recommend
 using the latest stable version of SQLite for local development.
 
-Please note that with respect to Python 3 support, Airflow 2.0.0 has been
-tested with Python 3.6, 3.7, and 3.8, but does not yet support Python 3.9.
+Starting with Airflow 2.1.2, Airflow is tested with Python 3.6, 3.7, 3.8, and 3.9.
 
 Installation tools
 ''''''''''''''''''
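
The doc change above only states which interpreters are tested; purely as an illustration
(not part of this commit), a runtime guard mirroring the PY39 flag that setup.py already
defines could look like this:

    import sys

    # Airflow 2.1.3 is tested with CPython 3.6 through 3.9 (per the docs change above).
    if not ((3, 6) <= sys.version_info[:2] <= (3, 9)):
        raise RuntimeError(
            f"Python {sys.version_info.major}.{sys.version_info.minor} "
            "is outside the tested 3.6-3.9 range"
        )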

[airflow] 03/09: Add 'queued' state to DagRun (#16401)

Posted by ka...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit ee67daa2bcb02dc55ce482b9bd53b8465f8b97ef
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Tue Jul 6 15:03:27 2021 +0100

    Add 'queued' state to DagRun (#16401)
    
    This change adds a queued state to DagRun. Newly created DagRuns
    start in the queued state and are moved to the running state once the
    DAG's max_active_runs limit allows it. If the DAG doesn't set
    max_active_runs, the DagRuns are moved to the running state immediately.

    Clearing a DagRun sets its state to queued.
    
    Closes: #9975, #16366
    (cherry picked from commit 6611ffd399dce0474d8329720de7e83f568df598)
---
 airflow/api_connexion/openapi/v1.yaml              |   4 +-
 airflow/jobs/scheduler_job.py                      | 171 +++++++----------
 ...93827b8_add_queued_at_column_to_dagrun_table.py |  49 +++++
 airflow/models/dag.py                              |   4 +-
 airflow/models/dagrun.py                           |  17 +-
 airflow/models/taskinstance.py                     |   6 +-
 airflow/www/static/js/tree.js                      |   4 +-
 airflow/www/views.py                               |   5 +-
 docs/apache-airflow/migrations-ref.rst             |   4 +-
 tests/api/common/experimental/test_mark_tasks.py   |   4 +-
 .../endpoints/test_dag_run_endpoint.py             |  14 +-
 tests/api_connexion/schemas/test_dag_run_schema.py |   3 +
 tests/dag_processing/test_manager.py               |  10 +-
 tests/dag_processing/test_processor.py             |  65 ++++---
 tests/jobs/test_scheduler_job.py                   | 206 +++++++++------------
 tests/models/test_cleartasks.py                    |  37 ++++
 tests/models/test_dagrun.py                        |  25 ++-
 tests/sensors/test_external_task_sensor.py         |   8 +-
 18 files changed, 338 insertions(+), 298 deletions(-)
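
The commit message above describes the new lifecycle: runs are created queued, the
scheduler promotes them to running while respecting max_active_runs, and clearing a run
puts it back into the queued state. A minimal sketch of that flow against the APIs this
commit introduces (the DAG id is hypothetical, and it assumes a configured metadata
database with the DAG already serialized and unpaused):

    from airflow import settings
    from airflow.jobs.scheduler_job import SchedulerJob
    from airflow.models import DagBag
    from airflow.utils import timezone
    from airflow.utils.state import State
    from airflow.utils.types import DagRunType

    session = settings.Session()
    dag = DagBag(read_dags_from_db=True).get_dag("example_dag")  # hypothetical DAG id

    # New runs now start in the QUEUED state with no start_date set yet.
    run = dag.create_dagrun(
        run_type=DagRunType.MANUAL,
        execution_date=timezone.utcnow(),
        state=State.QUEUED,
        external_trigger=True,
        session=session,
    )

    # The scheduler moves queued runs to RUNNING, honouring dag.max_active_runs;
    # dag.clear() would send existing runs back to QUEUED.
    scheduler = SchedulerJob(subdir="/dev/null")
    scheduler._start_queued_dagruns(session)
    session.commit()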

diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml
index 182f356..47fa3dc 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -1853,6 +1853,7 @@ components:
           description: |
             The start time. The time when DAG run was actually created.
           readOnly: true
+          nullable: true
         end_date:
           type: string
           format: date-time
@@ -3025,8 +3026,9 @@ components:
       description: DAG State.
       type: string
       enum:
-        - success
+        - queued
         - running
+        - success
         - failed
 
     TriggerRule:
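
Because the DAGState enum in the REST API now includes queued, clients see the new state
as soon as a run is triggered. A small sketch with the requests library (host, credentials
and DAG id are assumptions; it presumes a reachable API server with basic auth enabled):

    import requests

    resp = requests.post(
        "http://localhost:8080/api/v1/dags/example_dag/dagRuns",  # hypothetical host and DAG id
        json={"execution_date": "2021-08-14T00:00:00+00:00"},
        auth=("admin", "admin"),  # assumes the basic-auth API backend
    )
    resp.raise_for_status()
    print(resp.json()["state"])  # newly triggered runs now report "queued", not "running"
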
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 5b24e00..b564717 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -26,7 +26,7 @@ import sys
 import time
 from collections import defaultdict
 from datetime import timedelta
-from typing import DefaultDict, Dict, Iterable, List, Optional, Set, Tuple
+from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple
 
 from sqlalchemy import and_, func, not_, or_, tuple_
 from sqlalchemy.exc import OperationalError
@@ -197,7 +197,7 @@ class SchedulerJob(BaseJob):
         """
         For all DAG IDs in the DagBag, look for task instances in the
         old_states and set them to new_state if the corresponding DagRun
-        does not exist or exists but is not in the running state. This
+        does not exist or exists but is not in the running or queued state. This
         normally should not happen, but it can if the state of DagRuns are
         changed manually.
 
@@ -214,7 +214,7 @@ class SchedulerJob(BaseJob):
             .filter(models.TaskInstance.state.in_(old_states))
             .filter(
                 or_(
-                    models.DagRun.state != State.RUNNING,
+                    models.DagRun.state.notin_([State.RUNNING, State.QUEUED]),
                     models.DagRun.state.is_(None),
                 )
             )
@@ -882,39 +882,12 @@ class SchedulerJob(BaseJob):
             if settings.USE_JOB_SCHEDULE:
                 self._create_dagruns_for_dags(guard, session)
 
-            dag_runs = self._get_next_dagruns_to_examine(session)
+            self._start_queued_dagruns(session)
+            guard.commit()
+            dag_runs = self._get_next_dagruns_to_examine(State.RUNNING, session)
             # Bulk fetch the currently active dag runs for the dags we are
             # examining, rather than making one query per DagRun
 
-            # TODO: This query is probably horribly inefficient (though there is an
-            # index on (dag_id,state)). It is to deal with the case when a user
-            # clears more than max_active_runs older tasks -- we don't want the
-            # scheduler to suddenly go and start running tasks from all of the
-            # runs. (AIRFLOW-137/GH #1442)
-            #
-            # The longer term fix would be to have `clear` do this, and put DagRuns
-            # in to the queued state, then take DRs out of queued before creating
-            # any new ones
-
-            # Build up a set of execution_dates that are "active" for a given
-            # dag_id -- only tasks from those runs will be scheduled.
-            active_runs_by_dag_id = defaultdict(set)
-
-            query = (
-                session.query(
-                    TI.dag_id,
-                    TI.execution_date,
-                )
-                .filter(
-                    TI.dag_id.in_(list({dag_run.dag_id for dag_run in dag_runs})),
-                    TI.state.notin_(list(State.finished) + [State.REMOVED]),
-                )
-                .group_by(TI.dag_id, TI.execution_date)
-            )
-
-            for dag_id, execution_date in query:
-                active_runs_by_dag_id[dag_id].add(execution_date)
-
             for dag_run in dag_runs:
                 # Use try_except to not stop the Scheduler when a Serialized DAG is not found
                 # This takes care of Dynamic DAGs especially
@@ -923,7 +896,7 @@ class SchedulerJob(BaseJob):
                 # But this would take care of the scenario when the Scheduler is restarted after DagRun is
                 # created and the DAG is deleted / renamed
                 try:
-                    self._schedule_dag_run(dag_run, active_runs_by_dag_id.get(dag_run.dag_id, set()), session)
+                    self._schedule_dag_run(dag_run, session)
                 except SerializedDagNotFound:
                     self.log.exception("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
                     continue
@@ -963,9 +936,9 @@ class SchedulerJob(BaseJob):
             return num_queued_tis
 
     @retry_db_transaction
-    def _get_next_dagruns_to_examine(self, session):
+    def _get_next_dagruns_to_examine(self, state, session):
         """Get Next DagRuns to Examine with retries"""
-        return DagRun.next_dagruns_to_examine(session)
+        return DagRun.next_dagruns_to_examine(state, session)
 
     @retry_db_transaction
     def _create_dagruns_for_dags(self, guard, session):
@@ -986,14 +959,24 @@ class SchedulerJob(BaseJob):
         # as DagModel.dag_id and DagModel.next_dagrun
         # This list is used to verify if the DagRun already exist so that we don't attempt to create
         # duplicate dag runs
-        active_dagruns = (
-            session.query(DagRun.dag_id, DagRun.execution_date)
-            .filter(
-                tuple_(DagRun.dag_id, DagRun.execution_date).in_(
-                    [(dm.dag_id, dm.next_dagrun) for dm in dag_models]
+
+        if session.bind.dialect.name == 'mssql':
+            existing_dagruns_filter = or_(
+                *(
+                    and_(
+                        DagRun.dag_id == dm.dag_id,
+                        DagRun.execution_date == dm.next_dagrun,
+                    )
+                    for dm in dag_models
                 )
             )
-            .all()
+        else:
+            existing_dagruns_filter = tuple_(DagRun.dag_id, DagRun.execution_date).in_(
+                [(dm.dag_id, dm.next_dagrun) for dm in dag_models]
+            )
+
+        existing_dagruns = (
+            session.query(DagRun.dag_id, DagRun.execution_date).filter(existing_dagruns_filter).all()
         )
 
         for dag_model in dag_models:
@@ -1009,89 +992,83 @@ class SchedulerJob(BaseJob):
             # are not updated.
             # We opted to check DagRun existence instead
             # of catching an Integrity error and rolling back the session i.e
-            # we need to run self._update_dag_next_dagruns if the Dag Run already exists or if we
+            # we need to set dag.next_dagrun_info if the Dag Run already exists or if we
             # create a new one. This is so that in the next Scheduling loop we try to create new runs
             # instead of falling in a loop of Integrity Error.
-            if (dag.dag_id, dag_model.next_dagrun) not in active_dagruns:
-                run = dag.create_dagrun(
+            if (dag.dag_id, dag_model.next_dagrun) not in existing_dagruns:
+                dag.create_dagrun(
                     run_type=DagRunType.SCHEDULED,
                     execution_date=dag_model.next_dagrun,
-                    start_date=timezone.utcnow(),
-                    state=State.RUNNING,
+                    state=State.QUEUED,
                     external_trigger=False,
                     session=session,
                     dag_hash=dag_hash,
                     creating_job_id=self.id,
                 )
-
-                expected_start_date = dag.following_schedule(run.execution_date)
-                if expected_start_date:
-                    schedule_delay = run.start_date - expected_start_date
-                    Stats.timing(
-                        f'dagrun.schedule_delay.{dag.dag_id}',
-                        schedule_delay,
-                    )
-
-        self._update_dag_next_dagruns(dag_models, session)
+            dag_model.next_dagrun, dag_model.next_dagrun_create_after = dag.next_dagrun_info(
+                dag_model.next_dagrun
+            )
 
         # TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in
         # memory for larger dags? or expunge_all()
 
-    def _update_dag_next_dagruns(self, dag_models: Iterable[DagModel], session: Session) -> None:
-        """
-        Bulk update the next_dagrun and next_dagrun_create_after for all the dags.
+    def _start_queued_dagruns(
+        self,
+        session: Session,
+    ) -> int:
+        """Find DagRuns in queued state and decide moving them to running state"""
+        dag_runs = self._get_next_dagruns_to_examine(State.QUEUED, session)
 
-        We batch the select queries to get info about all the dags at once
-        """
-        # Check max_active_runs, to see if we are _now_ at the limit for any of
-        # these dag? (we've just created a DagRun for them after all)
-        active_runs_of_dags = dict(
+        active_runs_of_dags = defaultdict(
+            lambda: 0,
             session.query(DagRun.dag_id, func.count('*'))
-            .filter(
-                DagRun.dag_id.in_([o.dag_id for o in dag_models]),
+            .filter(  # We use `list` here because SQLA doesn't accept a set
+                # We use set to avoid duplicate dag_ids
+                DagRun.dag_id.in_(list({dr.dag_id for dr in dag_runs})),
                 DagRun.state == State.RUNNING,
-                DagRun.external_trigger.is_(False),
             )
             .group_by(DagRun.dag_id)
-            .all()
+            .all(),
         )
 
-        for dag_model in dag_models:
-            # Get the DAG in a try_except to not stop the Scheduler when a Serialized DAG is not found
-            # This takes care of Dynamic DAGs especially
+        def _update_state(dag_run):
+            dag_run.state = State.RUNNING
+            dag_run.start_date = timezone.utcnow()
+            expected_start_date = dag.following_schedule(dag_run.execution_date)
+            if expected_start_date:
+                schedule_delay = dag_run.start_date - expected_start_date
+                Stats.timing(
+                    f'dagrun.schedule_delay.{dag.dag_id}',
+                    schedule_delay,
+                )
+
+        for dag_run in dag_runs:
             try:
-                dag = self.dagbag.get_dag(dag_model.dag_id, session=session)
+                dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
             except SerializedDagNotFound:
-                self.log.exception("DAG '%s' not found in serialized_dag table", dag_model.dag_id)
+                self.log.exception("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
                 continue
-            active_runs_of_dag = active_runs_of_dags.get(dag.dag_id, 0)
-            if dag.max_active_runs and active_runs_of_dag >= dag.max_active_runs:
-                self.log.info(
-                    "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs",
+            active_runs = active_runs_of_dags[dag_run.dag_id]
+            if dag.max_active_runs and active_runs >= dag.max_active_runs:
+                self.log.debug(
+                    "DAG %s already has %d active runs, not moving any more runs to RUNNING state %s",
                     dag.dag_id,
-                    active_runs_of_dag,
-                    dag.max_active_runs,
+                    active_runs,
+                    dag_run.execution_date,
                 )
-                dag_model.next_dagrun_create_after = None
             else:
-                dag_model.next_dagrun, dag_model.next_dagrun_create_after = dag.next_dagrun_info(
-                    dag_model.next_dagrun
-                )
+                active_runs_of_dags[dag_run.dag_id] += 1
+                _update_state(dag_run)
 
     def _schedule_dag_run(
         self,
         dag_run: DagRun,
-        currently_active_runs: Set[datetime.datetime],
         session: Session,
     ) -> int:
         """
         Make scheduling decisions about an individual dag run
 
-        ``currently_active_runs`` is passed in so that a batch query can be
-        used to ask this for all dag runs in the batch, to avoid an n+1 query.
-
         :param dag_run: The DagRun to schedule
-        :param currently_active_runs: Number of currently active runs of this DAG
         :return: Number of tasks scheduled
         """
         dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
@@ -1118,9 +1095,6 @@ class SchedulerJob(BaseJob):
             session.flush()
             self.log.info("Run %s of %s has timed-out", dag_run.run_id, dag_run.dag_id)
 
-            # Work out if we should allow creating a new DagRun now?
-            self._update_dag_next_dagruns([session.query(DagModel).get(dag_run.dag_id)], session)
-
             callback_to_execute = DagCallbackRequest(
                 full_filepath=dag.fileloc,
                 dag_id=dag.dag_id,
@@ -1138,19 +1112,6 @@ class SchedulerJob(BaseJob):
             self.log.error("Execution date is in future: %s", dag_run.execution_date)
             return 0
 
-        if dag.max_active_runs:
-            if (
-                len(currently_active_runs) >= dag.max_active_runs
-                and dag_run.execution_date not in currently_active_runs
-            ):
-                self.log.info(
-                    "DAG %s already has %d active runs, not queuing any tasks for run %s",
-                    dag.dag_id,
-                    len(currently_active_runs),
-                    dag_run.execution_date,
-                )
-                return 0
-
         self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session)
         # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else?
         schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False)
diff --git a/airflow/migrations/versions/97cdd93827b8_add_queued_at_column_to_dagrun_table.py b/airflow/migrations/versions/97cdd93827b8_add_queued_at_column_to_dagrun_table.py
new file mode 100644
index 0000000..03caebc
--- /dev/null
+++ b/airflow/migrations/versions/97cdd93827b8_add_queued_at_column_to_dagrun_table.py
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Add queued_at column to dagrun table
+
+Revision ID: 97cdd93827b8
+Revises: a13f7613ad25
+Create Date: 2021-06-29 21:53:48.059438
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy.dialects import mssql
+
+# revision identifiers, used by Alembic.
+revision = '97cdd93827b8'
+down_revision = 'a13f7613ad25'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    """Apply Add queued_at column to dagrun table"""
+    conn = op.get_bind()
+    if conn.dialect.name == "mssql":
+        op.add_column('dag_run', sa.Column('queued_at', mssql.DATETIME2(precision=6), nullable=True))
+    else:
+        op.add_column('dag_run', sa.Column('queued_at', sa.DateTime(), nullable=True))
+
+
+def downgrade():
+    """Unapply Add queued_at column to dagrun table"""
+    op.drop_column('dag_run', "queued_at")
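
The migration above adds a nullable queued_at column to dag_run, which the model hunks
below populate whenever a run enters the QUEUED state. An illustrative query against the
new column (assumes an initialized metadata database):

    from airflow import settings
    from airflow.models import DagRun
    from airflow.utils.state import State

    session = settings.Session()
    # Oldest-queued first, using the queued_at column added by this migration.
    for dr in (
        session.query(DagRun)
        .filter(DagRun.state == State.QUEUED)
        .order_by(DagRun.queued_at)
        .all()
    ):
        print(dr.dag_id, dr.run_id, dr.queued_at)
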
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 13d69c2..a3d06db 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1153,7 +1153,7 @@ class DAG(LoggingMixin):
         confirm_prompt=False,
         include_subdags=True,
         include_parentdag=True,
-        dag_run_state: str = State.RUNNING,
+        dag_run_state: str = State.QUEUED,
         dry_run=False,
         session=None,
         get_tis=False,
@@ -1369,7 +1369,7 @@ class DAG(LoggingMixin):
         confirm_prompt=False,
         include_subdags=True,
         include_parentdag=False,
-        dag_run_state=State.RUNNING,
+        dag_run_state=State.QUEUED,
         dry_run=False,
     ):
         all_tis = []
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 07f309d..a061dcc 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -61,12 +61,15 @@ class DagRun(Base, LoggingMixin):
 
     __tablename__ = "dag_run"
 
+    __NO_VALUE = object()
+
     id = Column(Integer, primary_key=True)
     dag_id = Column(String(ID_LEN))
+    queued_at = Column(UtcDateTime)
     execution_date = Column(UtcDateTime, default=timezone.utcnow)
-    start_date = Column(UtcDateTime, default=timezone.utcnow)
+    start_date = Column(UtcDateTime)
     end_date = Column(UtcDateTime)
-    _state = Column('state', String(50), default=State.RUNNING)
+    _state = Column('state', String(50), default=State.QUEUED)
     run_id = Column(String(ID_LEN))
     creating_job_id = Column(Integer)
     external_trigger = Column(Boolean, default=True)
@@ -102,6 +105,7 @@ class DagRun(Base, LoggingMixin):
         self,
         dag_id: Optional[str] = None,
         run_id: Optional[str] = None,
+        queued_at: Optional[datetime] = __NO_VALUE,
         execution_date: Optional[datetime] = None,
         start_date: Optional[datetime] = None,
         external_trigger: Optional[bool] = None,
@@ -118,6 +122,10 @@ class DagRun(Base, LoggingMixin):
         self.external_trigger = external_trigger
         self.conf = conf or {}
         self.state = state
+        if queued_at is self.__NO_VALUE:
+            self.queued_at = timezone.utcnow() if state == State.QUEUED else None
+        else:
+            self.queued_at = queued_at
         self.run_type = run_type
         self.dag_hash = dag_hash
         self.creating_job_id = creating_job_id
@@ -140,6 +148,8 @@ class DagRun(Base, LoggingMixin):
         if self._state != state:
             self._state = state
             self.end_date = timezone.utcnow() if self._state in State.finished else None
+            if state == State.QUEUED:
+                self.queued_at = timezone.utcnow()
 
     @declared_attr
     def state(self):
@@ -160,6 +170,7 @@ class DagRun(Base, LoggingMixin):
     @classmethod
     def next_dagruns_to_examine(
         cls,
+        state: str,
         session: Session,
         max_number: Optional[int] = None,
     ):
@@ -180,7 +191,7 @@ class DagRun(Base, LoggingMixin):
         # TODO: Bake this query, it is run _A lot_
         query = (
             session.query(cls)
-            .filter(cls.state == State.RUNNING, cls.run_type != DagRunType.BACKFILL_JOB)
+            .filter(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
             .join(
                 DagModel,
                 DagModel.dag_id == cls.dag_id,
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 5fb8155..0e10567 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -137,7 +137,7 @@ def clear_task_instances(
     session,
     activate_dag_runs=None,
     dag=None,
-    dag_run_state: Union[str, Literal[False]] = State.RUNNING,
+    dag_run_state: Union[str, Literal[False]] = State.QUEUED,
 ):
     """
     Clears a set of task instances, but makes sure the running ones
@@ -239,7 +239,9 @@ def clear_task_instances(
         )
         for dr in drs:
             dr.state = dag_run_state
-            dr.start_date = timezone.utcnow()
+            dr.start_date = None
+            if dag_run_state == State.QUEUED:
+                dr.last_scheduling_decision = None
 
 
 class TaskInstanceKey(NamedTuple):
diff --git a/airflow/www/static/js/tree.js b/airflow/www/static/js/tree.js
index 4bf366a..d45c880 100644
--- a/airflow/www/static/js/tree.js
+++ b/airflow/www/static/js/tree.js
@@ -58,7 +58,9 @@ document.addEventListener('DOMContentLoaded', () => {
   const tree = d3.layout.tree().nodeSize([0, 25]);
   let nodes = tree.nodes(data);
   const nodeobj = {};
-  const getActiveRuns = () => data.instances.filter((run) => run.state === 'running').length > 0;
+  const runActiveStates = ['queued', 'running'];
+  const getActiveRuns = () => data.instances
+    .filter((run) => runActiveStates.includes(run.state)).length > 0;
 
   const now = Date.now() / 1000;
   const devicePixelRatio = window.devicePixelRatio || 1;
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 150ecc9..830047b 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1526,7 +1526,7 @@ class Airflow(AirflowBaseView):
         dag.create_dagrun(
             run_type=DagRunType.MANUAL,
             execution_date=execution_date,
-            state=State.RUNNING,
+            state=State.QUEUED,
             conf=run_conf,
             external_trigger=True,
             dag_hash=current_app.dag_bag.dags_hash.get(dag_id),
@@ -3454,6 +3454,7 @@ class DagRunModelView(AirflowModelView):
         'execution_date',
         'run_id',
         'run_type',
+        'queued_at',
         'start_date',
         'end_date',
         'external_trigger',
@@ -3789,7 +3790,7 @@ class TaskInstanceModelView(AirflowModelView):
         lazy_gettext('Clear'),
         lazy_gettext(
             'Are you sure you want to clear the state of the selected task'
-            ' instance(s) and set their dagruns to the running state?'
+            ' instance(s) and set their dagruns to the QUEUED state?'
         ),
         single=False,
     )
diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst
index 0c663da..c5dbd03 100644
--- a/docs/apache-airflow/migrations-ref.rst
+++ b/docs/apache-airflow/migrations-ref.rst
@@ -23,9 +23,7 @@ Here's the list of all the Database Migrations that are executed via when you ru
 +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
 | Revision ID                    | Revises ID       | Airflow Version | Description                                                                           |
 +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
-| ``e9304a3141f0`` (head)        | ``83f031fd9f1c`` |                 | Make XCom primary key columns non-nullable                                            |
-+--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
-| ``83f031fd9f1c``               | ``a13f7613ad25`` |                 | Improve MSSQL compatibility                                                           |
+| ``97cdd93827b8`` (head)        | ``a13f7613ad25`` |                 | Add ``queued_at`` column in ``dag_run`` table                                         |
 +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
 | ``a13f7613ad25``               | ``e165e7455d70`` | ``2.1.0``       | Resource based permissions for default ``Flask-AppBuilder`` views                     |
 +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
diff --git a/tests/api/common/experimental/test_mark_tasks.py b/tests/api/common/experimental/test_mark_tasks.py
index 4dab57e..49008d3 100644
--- a/tests/api/common/experimental/test_mark_tasks.py
+++ b/tests/api/common/experimental/test_mark_tasks.py
@@ -414,7 +414,9 @@ class TestMarkDAGRun(unittest.TestCase):
             assert ti.state == state
 
     def _create_test_dag_run(self, state, date):
-        return self.dag1.create_dagrun(run_type=DagRunType.MANUAL, state=state, execution_date=date)
+        return self.dag1.create_dagrun(
+            run_type=DagRunType.MANUAL, state=state, start_date=date, execution_date=date
+        )
 
     def _verify_dag_run_state(self, dag, date, state):
         drs = models.DagRun.find(dag_id=dag.dag_id, execution_date=date)
diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
index e51eca8..0aa13b2 100644
--- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
@@ -90,7 +90,7 @@ class TestDagRunEndpoint:
 
     def teardown_method(self) -> None:
         clear_db_runs()
-        # clear_db_dags()
+        clear_db_dags()
 
     def _create_dag(self, dag_id):
         dag_instance = DagModel(dag_id=dag_id)
@@ -118,6 +118,7 @@ class TestDagRunEndpoint:
             execution_date=timezone.parse(self.default_time_2),
             start_date=timezone.parse(self.default_time),
             external_trigger=True,
+            state=state,
         )
         dag_runs.append(dagrun_model_2)
         if extra_dag:
@@ -131,6 +132,7 @@ class TestDagRunEndpoint:
                         execution_date=timezone.parse(self.default_time_2),
                         start_date=timezone.parse(self.default_time),
                         external_trigger=True,
+                        state=state,
                     )
                 )
         if commit:
@@ -193,6 +195,7 @@ class TestGetDagRun(TestDagRunEndpoint):
             execution_date=timezone.parse(self.default_time),
             start_date=timezone.parse(self.default_time),
             external_trigger=True,
+            state='running',
         )
         session.add(dagrun_model)
         session.commit()
@@ -532,7 +535,7 @@ class TestGetDagRunsEndDateFilters(TestDagRunEndpoint):
             (
                 f"api/v1/dags/TEST_DAG_ID/dagRuns?end_date_lte="
                 f"{(timezone.utcnow() + timedelta(days=1)).isoformat()}",
-                ["TEST_DAG_RUN_ID_1"],
+                ["TEST_DAG_RUN_ID_1", "TEST_DAG_RUN_ID_2"],
             ),
         ]
     )
@@ -750,6 +753,7 @@ class TestGetDagRunBatchPagination(TestDagRunEndpoint):
             DagRun(
                 dag_id="TEST_DAG_ID",
                 run_id="TEST_DAG_RUN_ID" + str(i),
+                state='running',
                 run_type=DagRunType.MANUAL,
                 execution_date=timezone.parse(self.default_time) + timedelta(minutes=i),
                 start_date=timezone.parse(self.default_time),
@@ -884,7 +888,7 @@ class TestGetDagRunBatchDateFilters(TestDagRunEndpoint):
             ),
             (
                 {"end_date_lte": f"{(timezone.utcnow() + timedelta(days=1)).isoformat()}"},
-                ["TEST_DAG_RUN_ID_1"],
+                ["TEST_DAG_RUN_ID_1", "TEST_DAG_RUN_ID_2"],
             ),
         ]
     )
@@ -927,8 +931,8 @@ class TestPostDagRun(TestDagRunEndpoint):
             "end_date": None,
             "execution_date": response.json["execution_date"],
             "external_trigger": True,
-            "start_date": response.json["start_date"],
-            "state": "running",
+            "start_date": None,
+            "state": "queued",
         } == response.json
 
     @parameterized.expand(
diff --git a/tests/api_connexion/schemas/test_dag_run_schema.py b/tests/api_connexion/schemas/test_dag_run_schema.py
index 3e6bf2e..9e4a9e8 100644
--- a/tests/api_connexion/schemas/test_dag_run_schema.py
+++ b/tests/api_connexion/schemas/test_dag_run_schema.py
@@ -49,6 +49,7 @@ class TestDAGRunSchema(TestDAGRunBase):
     def test_serialize(self, session):
         dagrun_model = DagRun(
             run_id="my-dag-run",
+            state='running',
             run_type=DagRunType.MANUAL.value,
             execution_date=timezone.parse(self.default_time),
             start_date=timezone.parse(self.default_time),
@@ -124,6 +125,7 @@ class TestDagRunCollection(TestDAGRunBase):
     def test_serialize(self, session):
         dagrun_model_1 = DagRun(
             run_id="my-dag-run",
+            state='running',
             execution_date=timezone.parse(self.default_time),
             run_type=DagRunType.MANUAL.value,
             start_date=timezone.parse(self.default_time),
@@ -131,6 +133,7 @@ class TestDagRunCollection(TestDAGRunBase):
         )
         dagrun_model_2 = DagRun(
             run_id="my-dag-run-2",
+            state='running',
             execution_date=timezone.parse(self.default_time),
             start_date=timezone.parse(self.default_time),
             run_type=DagRunType.MANUAL.value,
diff --git a/tests/dag_processing/test_manager.py b/tests/dag_processing/test_manager.py
index 0ab7f2b..02613ec 100644
--- a/tests/dag_processing/test_manager.py
+++ b/tests/dag_processing/test_manager.py
@@ -43,6 +43,7 @@ from airflow.dag_processing.manager import (
 )
 from airflow.dag_processing.processor import DagFileProcessorProcess
 from airflow.jobs.local_task_job import LocalTaskJob as LJ
+from airflow.jobs.scheduler_job import SchedulerJob
 from airflow.models import DagBag, DagModel, TaskInstance as TI
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskinstance import SimpleTaskInstance
@@ -508,8 +509,8 @@ class TestDagFileProcessorManager(unittest.TestCase):
             child_pipe.close()
             parent_pipe.close()
 
-    @mock.patch("airflow.jobs.scheduler_job.DagFileProcessorProcess.pid", new_callable=PropertyMock)
-    @mock.patch("airflow.jobs.scheduler_job.DagFileProcessorProcess.kill")
+    @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess.pid", new_callable=PropertyMock)
+    @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess.kill")
     def test_kill_timed_out_processors_kill(self, mock_kill, mock_pid):
         mock_pid.return_value = 1234
         manager = DagFileProcessorManager(
@@ -529,8 +530,8 @@ class TestDagFileProcessorManager(unittest.TestCase):
         manager._kill_timed_out_processors()
         mock_kill.assert_called_once_with()
 
-    @mock.patch("airflow.jobs.scheduler_job.DagFileProcessorProcess.pid", new_callable=PropertyMock)
-    @mock.patch("airflow.jobs.scheduler_job.DagFileProcessorProcess")
+    @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess.pid", new_callable=PropertyMock)
+    @mock.patch("airflow.dag_processing.processor.DagFileProcessorProcess")
     def test_kill_timed_out_processors_no_kill(self, mock_dag_file_processor, mock_pid):
         mock_pid.return_value = 1234
         manager = DagFileProcessorManager(
@@ -560,7 +561,6 @@ class TestDagFileProcessorManager(unittest.TestCase):
         # We need to _actually_ parse the files here to test the behaviour.
         # Right now the parsing code lives in SchedulerJob, even though it's
         # called via utils.dag_processing.
-        from airflow.jobs.scheduler_job import SchedulerJob
 
         dag_id = 'exit_test_dag'
         dag_directory = TEST_DAG_FOLDER.parent / 'dags_with_system_exit'
diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py
index 5953517..feb3497 100644
--- a/tests/dag_processing/test_processor.py
+++ b/tests/dag_processing/test_processor.py
@@ -277,7 +277,7 @@ class TestDagFileProcessor(unittest.TestCase):
         assert email1 in send_email_to
         assert email2 not in send_email_to
 
-    @mock.patch('airflow.jobs.scheduler_job.Stats.incr')
+    @mock.patch('airflow.dag_processing.processor.Stats.incr')
     @mock.patch("airflow.utils.email.send_email")
     def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock_stats_incr):
         """
@@ -387,7 +387,7 @@ class TestDagFileProcessor(unittest.TestCase):
             ti.start_date = start_date
             ti.end_date = end_date
 
-            count = self.scheduler_job._schedule_dag_run(dr, set(), session)
+            count = self.scheduler_job._schedule_dag_run(dr, session)
             assert count == 1
 
             session.refresh(ti)
@@ -444,7 +444,7 @@ class TestDagFileProcessor(unittest.TestCase):
             ti.start_date = start_date
             ti.end_date = end_date
 
-            count = self.scheduler_job._schedule_dag_run(dr, set(), session)
+            count = self.scheduler_job._schedule_dag_run(dr, session)
             assert count == 1
 
             session.refresh(ti)
@@ -504,7 +504,7 @@ class TestDagFileProcessor(unittest.TestCase):
                 ti.start_date = start_date
                 ti.end_date = end_date
 
-            count = self.scheduler_job._schedule_dag_run(dr, set(), session)
+            count = self.scheduler_job._schedule_dag_run(dr, session)
             assert count == 2
 
             session.refresh(tis[0])
@@ -547,7 +547,7 @@ class TestDagFileProcessor(unittest.TestCase):
         BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test')
         SerializedDagModel.write_dag(dag=dag)
 
-        scheduled_tis = self.scheduler_job._schedule_dag_run(dr, set(), session)
+        scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session)
         session.flush()
         assert scheduled_tis == 2
 
@@ -560,11 +560,10 @@ class TestDagFileProcessor(unittest.TestCase):
 
     def test_runs_respected_after_clear(self):
         """
-        Test if _process_task_instances only schedules ti's up to max_active_runs
-        (related to issue AIRFLOW-137)
+        Test dag after dag.clear, max_active_runs is respected
         """
         dag = DAG(dag_id='test_scheduler_max_active_runs_respected_after_clear', start_date=DEFAULT_DATE)
-        dag.max_active_runs = 3
+        dag.max_active_runs = 1
 
         BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo Hi')
 
@@ -575,48 +574,46 @@ class TestDagFileProcessor(unittest.TestCase):
         session.close()
         dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
 
+        # Write Dag to DB
+        dagbag = DagBag(dag_folder="/dev/null", include_examples=False, read_dags_from_db=False)
+        dagbag.bag_dag(dag, root_dag=dag)
+        dagbag.sync_to_db()
+
+        dag = DagBag(read_dags_from_db=True, include_examples=False).get_dag(dag.dag_id)
+
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         self.scheduler_job.processor_agent = mock.MagicMock()
         self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
-        dag.clear()
 
         date = DEFAULT_DATE
-        dr1 = dag.create_dagrun(
+        dag.create_dagrun(
             run_type=DagRunType.SCHEDULED,
             execution_date=date,
-            state=State.RUNNING,
+            state=State.QUEUED,
         )
         date = dag.following_schedule(date)
-        dr2 = dag.create_dagrun(
+        dag.create_dagrun(
             run_type=DagRunType.SCHEDULED,
             execution_date=date,
-            state=State.RUNNING,
+            state=State.QUEUED,
         )
         date = dag.following_schedule(date)
-        dr3 = dag.create_dagrun(
+        dag.create_dagrun(
             run_type=DagRunType.SCHEDULED,
             execution_date=date,
-            state=State.RUNNING,
+            state=State.QUEUED,
         )
+        dag.clear()
 
-        # First create up to 3 dagruns in RUNNING state.
-        assert dr1 is not None
-        assert dr2 is not None
-        assert dr3 is not None
-        assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 3
-
-        # Reduce max_active_runs to 1
-        dag.max_active_runs = 1
+        assert len(DagRun.find(dag_id=dag.dag_id, state=State.QUEUED, session=session)) == 3
 
-        # and schedule them in, so we can check how many
-        # tasks are put on the task_instances_list (should be one, not 3)
-        with create_session() as session:
-            num_scheduled = self.scheduler_job._schedule_dag_run(dr1, set(), session)
-            assert num_scheduled == 1
-            num_scheduled = self.scheduler_job._schedule_dag_run(dr2, {dr1.execution_date}, session)
-            assert num_scheduled == 0
-            num_scheduled = self.scheduler_job._schedule_dag_run(dr3, {dr1.execution_date}, session)
-            assert num_scheduled == 0
+        session = settings.Session()
+        self.scheduler_job._start_queued_dagruns(session)
+        session.commit()
+        # Assert that only 1 dagrun is active
+        assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 1
+        # Assert that the other two are queued
+        assert len(DagRun.find(dag_id=dag.dag_id, state=State.QUEUED, session=session)) == 2
 
     @patch.object(TaskInstance, 'handle_failure_with_callback')
     def test_execute_on_failure_callbacks(self, mock_ti_handle_failure):
@@ -698,7 +695,7 @@ class TestDagFileProcessor(unittest.TestCase):
         dr = drs[0]
 
         # Schedule TaskInstances
-        self.scheduler_job_job._schedule_dag_run(dr, {}, session)
+        self.scheduler_job_job._schedule_dag_run(dr, session)
         with create_session() as session:
             tis = session.query(TaskInstance).all()
 
@@ -724,7 +721,7 @@ class TestDagFileProcessor(unittest.TestCase):
                 assert end_date is None
                 assert duration is None
 
-        self.scheduler_job_job._schedule_dag_run(dr, {}, session)
+        self.scheduler_job_job._schedule_dag_run(dr, session)
         with create_session() as session:
             tis = session.query(TaskInstance).all()
 
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 9fe8517..0ee6f5f 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -23,7 +23,6 @@ import shutil
 import unittest
 from datetime import timedelta
 from tempfile import mkdtemp
-from time import sleep
 from unittest import mock
 from unittest.mock import MagicMock, patch
 from zipfile import ZipFile
@@ -430,7 +429,6 @@ class TestSchedulerJob(unittest.TestCase):
         session.flush()
 
         res = self.scheduler_job._executable_task_instances_to_queued(max_tis=32, session=session)
-
         assert 2 == len(res)
         res_keys = map(lambda x: x.key, res)
         assert ti_no_dagrun.key in res_keys
@@ -1575,15 +1573,16 @@ class TestSchedulerJob(unittest.TestCase):
 
         dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
 
-        self.scheduler_job = SchedulerJob(subdir=os.devnull)
         self.scheduler_job._create_dag_runs([orm_dag], session)
+        self.scheduler_job._start_queued_dagruns(session)
 
         drs = DagRun.find(dag_id=dag.dag_id, session=session)
         assert len(drs) == 1
         dr = drs[0]
 
-        # Should not be able to create a new dag run, as we are at max active runs
-        assert orm_dag.next_dagrun_create_after is None
+        # This should have a value since we control max_active_runs
+        # by DagRun State.
+        assert orm_dag.next_dagrun_create_after
         # But we should record the date of _what run_ it would be
         assert isinstance(orm_dag.next_dagrun, datetime.datetime)
 
@@ -1595,7 +1594,7 @@ class TestSchedulerJob(unittest.TestCase):
         self.scheduler_job.processor_agent = mock.Mock()
         self.scheduler_job.processor_agent.send_callback_to_execute = mock.Mock()
 
-        self.scheduler_job._schedule_dag_run(dr, {}, session)
+        self.scheduler_job._schedule_dag_run(dr, session)
         session.flush()
 
         session.refresh(dr)
@@ -1652,7 +1651,7 @@ class TestSchedulerJob(unittest.TestCase):
         self.scheduler_job.processor_agent = mock.Mock()
         self.scheduler_job.processor_agent.send_callback_to_execute = mock.Mock()
 
-        self.scheduler_job._schedule_dag_run(dr, {}, session)
+        self.scheduler_job._schedule_dag_run(dr, session)
         session.flush()
 
         session.refresh(dr)
@@ -1711,7 +1710,7 @@ class TestSchedulerJob(unittest.TestCase):
         ti = dr.get_task_instance('dummy')
         ti.set_state(state, session)
 
-        self.scheduler_job._schedule_dag_run(dr, {}, session)
+        self.scheduler_job._schedule_dag_run(dr, session)
 
         expected_callback = DagCallbackRequest(
             full_filepath=dr.dag.fileloc,
@@ -1766,7 +1765,7 @@ class TestSchedulerJob(unittest.TestCase):
         ti = dr.get_task_instance('test_task')
         ti.set_state(state, session)
 
-        self.scheduler_job._schedule_dag_run(dr, set(), session)
+        self.scheduler_job._schedule_dag_run(dr, session)
 
         # Verify Callback is not set (i.e is None) when no callbacks are set on DAG
         self.scheduler_job._send_dag_callbacks_to_processor.assert_called_once_with(dr, None)
@@ -2146,13 +2145,13 @@ class TestSchedulerJob(unittest.TestCase):
             execution_date=DEFAULT_DATE,
             state=State.RUNNING,
         )
-        self.scheduler_job._schedule_dag_run(dr, {}, session)
+        self.scheduler_job._schedule_dag_run(dr, session)
         dr = dag.create_dagrun(
             run_type=DagRunType.SCHEDULED,
             execution_date=dag.following_schedule(dr.execution_date),
             state=State.RUNNING,
         )
-        self.scheduler_job._schedule_dag_run(dr, {}, session)
+        self.scheduler_job._schedule_dag_run(dr, session)
         task_instances_list = self.scheduler_job._executable_task_instances_to_queued(
             max_tis=32, session=session
         )
@@ -2203,7 +2202,7 @@ class TestSchedulerJob(unittest.TestCase):
                 execution_date=date,
                 state=State.RUNNING,
             )
-            self.scheduler_job._schedule_dag_run(dr, {}, session)
+            self.scheduler_job._schedule_dag_run(dr, session)
             date = dag.following_schedule(date)
 
         task_instances_list = self.scheduler_job._executable_task_instances_to_queued(
@@ -2266,7 +2265,7 @@ class TestSchedulerJob(unittest.TestCase):
                 execution_date=date,
                 state=State.RUNNING,
             )
-            scheduler._schedule_dag_run(dr, {}, session)
+            scheduler._schedule_dag_run(dr, session)
             date = dag_d1.following_schedule(date)
 
         date = DEFAULT_DATE
@@ -2276,7 +2275,7 @@ class TestSchedulerJob(unittest.TestCase):
                 execution_date=date,
                 state=State.RUNNING,
             )
-            scheduler._schedule_dag_run(dr, {}, session)
+            scheduler._schedule_dag_run(dr, session)
             date = dag_d2.following_schedule(date)
 
         scheduler._executable_task_instances_to_queued(max_tis=2, session=session)
@@ -2353,7 +2352,7 @@ class TestSchedulerJob(unittest.TestCase):
             execution_date=DEFAULT_DATE,
             state=State.RUNNING,
         )
-        self.scheduler_job._schedule_dag_run(dr, {}, session)
+        self.scheduler_job._schedule_dag_run(dr, session)
 
         task_instances_list = self.scheduler_job._executable_task_instances_to_queued(
             max_tis=32, session=session
@@ -2412,7 +2411,7 @@ class TestSchedulerJob(unittest.TestCase):
 
         # Verify that DagRun.verify_integrity is not called
         with mock.patch('airflow.jobs.scheduler_job.DagRun.verify_integrity') as mock_verify_integrity:
-            scheduled_tis = self.scheduler_job._schedule_dag_run(dr, {}, session)
+            scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session)
             mock_verify_integrity.assert_not_called()
         session.flush()
 
@@ -2475,7 +2474,7 @@ class TestSchedulerJob(unittest.TestCase):
         dag_version_2 = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session)
         assert dag_version_2 != dag_version_1
 
-        scheduled_tis = self.scheduler_job._schedule_dag_run(dr, {}, session)
+        scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session)
         session.flush()
 
         assert scheduled_tis == 2
@@ -3187,14 +3186,13 @@ class TestSchedulerJob(unittest.TestCase):
                 full_filepath=dag.fileloc, dag_id=dag_id
             )
 
-    @freeze_time(DEFAULT_DATE + datetime.timedelta(days=1, seconds=9))
-    @mock.patch('airflow.jobs.scheduler_job.Stats.timing')
-    def test_create_dag_runs(self, stats_timing):
+    def test_create_dag_runs(self):
         """
         Test various invariants of _create_dag_runs.
 
         - That the run created has the creating_job_id set
-        - That we emit the right DagRun metrics
+        - That the run created is on QUEUED State
+        - That dag_model has next_dagrun
         """
         dag = DAG(dag_id='test_create_dag_runs', start_date=DEFAULT_DATE)
 
@@ -3218,8 +3216,51 @@ class TestSchedulerJob(unittest.TestCase):
         with create_session() as session:
             self.scheduler_job._create_dag_runs([dag_model], session)
 
+        dr = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).first()
+        # Assert dr state is queued
+        assert dr.state == State.QUEUED
+        assert dr.start_date is None
+
+        assert dag.get_last_dagrun().creating_job_id == self.scheduler_job.id
+
+    @freeze_time(DEFAULT_DATE + datetime.timedelta(days=1, seconds=9))
+    @mock.patch('airflow.jobs.scheduler_job.Stats.timing')
+    def test_start_dagruns(self, stats_timing):
+        """
+        Test that _start_dagrun:
+
+        - moves runs to RUNNING State
+        - emit the right DagRun metrics
+        """
+        dag = DAG(dag_id='test_start_dag_runs', start_date=DEFAULT_DATE)
+
+        DummyOperator(
+            task_id='dummy',
+            dag=dag,
+        )
+
+        dagbag = DagBag(
+            dag_folder=os.devnull,
+            include_examples=False,
+            read_dags_from_db=True,
+        )
+        dagbag.bag_dag(dag=dag, root_dag=dag)
+        dagbag.sync_to_db()
+        dag_model = DagModel.get_dagmodel(dag.dag_id)
+
+        self.scheduler_job = SchedulerJob(executor=self.null_exec)
+        self.scheduler_job.processor_agent = mock.MagicMock()
+
+        with create_session() as session:
+            self.scheduler_job._create_dag_runs([dag_model], session)
+            self.scheduler_job._start_queued_dagruns(session)
+
+        dr = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).first()
+        # Assert dr state is running
+        assert dr.state == State.RUNNING
+
         stats_timing.assert_called_once_with(
-            "dagrun.schedule_delay.test_create_dag_runs", datetime.timedelta(seconds=9)
+            "dagrun.schedule_delay.test_start_dag_runs", datetime.timedelta(seconds=9)
         )
 
         assert dag.get_last_dagrun().creating_job_id == self.scheduler_job.id
@@ -3418,61 +3459,7 @@ class TestSchedulerJob(unittest.TestCase):
         assert dag_model.next_dagrun == DEFAULT_DATE + timedelta(days=1)
         session.rollback()
 
-    def test_do_schedule_max_active_runs_upstream_failed(self):
-        """
-        Test that tasks in upstream failed don't count as actively running.
-
-        This test can be removed when adding a queued state to DagRuns.
-        """
-
-        with DAG(
-            dag_id='test_max_active_run_with_upstream_failed',
-            start_date=DEFAULT_DATE,
-            schedule_interval='@once',
-            max_active_runs=1,
-        ) as dag:
-            # Can't use DummyOperator as that goes straight to success
-            task1 = BashOperator(task_id='dummy1', bash_command='true')
-
-        session = settings.Session()
-        dagbag = DagBag(
-            dag_folder=os.devnull,
-            include_examples=False,
-            read_dags_from_db=True,
-        )
-
-        dagbag.bag_dag(dag=dag, root_dag=dag)
-        dagbag.sync_to_db(session=session)
-
-        run1 = dag.create_dagrun(
-            run_type=DagRunType.SCHEDULED,
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-            session=session,
-        )
-
-        ti = run1.get_task_instance(task1.task_id, session)
-        ti.state = State.UPSTREAM_FAILED
-
-        run2 = dag.create_dagrun(
-            run_type=DagRunType.SCHEDULED,
-            execution_date=DEFAULT_DATE + timedelta(hours=1),
-            state=State.RUNNING,
-            session=session,
-        )
-
-        dag.sync_to_db(session=session)  # Update the date fields
-
-        self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        self.scheduler_job.executor = MockExecutor(do_update=False)
-        self.scheduler_job.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent)
-
-        num_queued = self.scheduler_job._do_scheduling(session)
-
-        assert num_queued == 1
-        ti = run2.get_task_instance(task1.task_id, session)
-        assert ti.state == State.QUEUED
-
+    @conf_vars({('scheduler', 'use_job_schedule'): "false"})
     def test_do_schedule_max_active_runs_dag_timed_out(self):
         """Test that tasks are set to a finished state when their DAG times out"""
 
@@ -3505,33 +3492,36 @@ class TestSchedulerJob(unittest.TestCase):
             run_type=DagRunType.SCHEDULED,
             execution_date=DEFAULT_DATE,
             state=State.RUNNING,
+            start_date=timezone.utcnow() - timedelta(seconds=2),
             session=session,
         )
+
         run1_ti = run1.get_task_instance(task1.task_id, session)
         run1_ti.state = State.RUNNING
 
-        sleep(1)
-
         run2 = dag.create_dagrun(
             run_type=DagRunType.SCHEDULED,
             execution_date=DEFAULT_DATE + timedelta(seconds=10),
-            state=State.RUNNING,
+            state=State.QUEUED,
             session=session,
         )
 
         dag.sync_to_db(session=session)
-
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         self.scheduler_job.executor = MockExecutor()
         self.scheduler_job.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent)
 
-        _ = self.scheduler_job._do_scheduling(session)
-
+        self.scheduler_job._do_scheduling(session)
+        session.add(run1)
+        session.refresh(run1)
         assert run1.state == State.FAILED
         assert run1_ti.state == State.SKIPPED
-        assert run2.state == State.RUNNING
 
-        _ = self.scheduler_job._do_scheduling(session)
+        # Run scheduling again to assert run2 has started
+        self.scheduler_job._do_scheduling(session)
+        session.add(run2)
+        session.refresh(run2)
+        assert run2.state == State.RUNNING
         run2_ti = run2.get_task_instance(task1.task_id, session)
         assert run2_ti.state == State.QUEUED
 
@@ -3581,8 +3571,8 @@ class TestSchedulerJob(unittest.TestCase):
 
     def test_do_schedule_max_active_runs_and_manual_trigger(self):
         """
-        Make sure that when a DAG is already at max_active_runs, that manually triggering a run doesn't cause
-        the dag to "stall".
+        Make sure that when a DAG is already at max_active_runs, that manually triggered
+        dagruns don't start running.
         """
 
         with DAG(
@@ -3597,7 +3587,7 @@ class TestSchedulerJob(unittest.TestCase):
 
             task1 >> task2
 
-            task3 = BashOperator(task_id='dummy3', bash_command='true')
+            BashOperator(task_id='dummy3', bash_command='true')
 
         session = settings.Session()
         dagbag = DagBag(
@@ -3612,7 +3602,7 @@ class TestSchedulerJob(unittest.TestCase):
         dag_run = dag.create_dagrun(
             run_type=DagRunType.SCHEDULED,
             execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
+            state=State.QUEUED,
             session=session,
         )
 
@@ -3630,47 +3620,23 @@ class TestSchedulerJob(unittest.TestCase):
 
         assert num_queued == 2
         assert dag_run.state == State.RUNNING
-        ti1 = dag_run.get_task_instance(task1.task_id, session)
-        assert ti1.state == State.QUEUED
-
-        # Set task1 to success (so task2 can run) but keep task3 as "running"
-        ti1.state = State.SUCCESS
-
-        ti3 = dag_run.get_task_instance(task3.task_id, session)
-        ti3.state = State.RUNNING
-
-        session.flush()
-
-        # At this point, ti2 and ti3 of the scheduled dag run should be running
-        num_queued = self.scheduler_job._do_scheduling(session)
-
-        assert num_queued == 1
-        # Should have queued task2
-        ti2 = dag_run.get_task_instance(task2.task_id, session)
-        assert ti2.state == State.QUEUED
-
-        ti2.state = None
-        session.flush()
 
         # Now that this one is running, manually trigger a dag.
 
-        manual_run = dag.create_dagrun(
+        dag.create_dagrun(
             run_type=DagRunType.MANUAL,
             execution_date=DEFAULT_DATE + timedelta(hours=1),
-            state=State.RUNNING,
+            state=State.QUEUED,
             session=session,
         )
         session.flush()
 
-        num_queued = self.scheduler_job._do_scheduling(session)
+        self.scheduler_job._do_scheduling(session)
 
-        assert num_queued == 1
-        # Should have queued task2 again.
-        ti2 = dag_run.get_task_instance(task2.task_id, session)
-        assert ti2.state == State.QUEUED
-        # Manual run shouldn't have been started, because we're at max_active_runs with DR1
-        ti1 = manual_run.get_task_instance(task1.task_id, session)
-        assert ti1.state is None
+        # Assert that only 1 dagrun is active
+        assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 1
+        # Assert that the other one is queued
+        assert len(DagRun.find(dag_id=dag.dag_id, state=State.QUEUED, session=session)) == 1
 
 
 @pytest.mark.xfail(reason="Work out where this goes")
diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py
index 9b8fbd0..4f64347 100644
--- a/tests/models/test_cleartasks.py
+++ b/tests/models/test_cleartasks.py
@@ -19,6 +19,8 @@
 import datetime
 import unittest
 
+from parameterized import parameterized
+
 from airflow import settings
 from airflow.models import DAG, TaskInstance as TI, TaskReschedule, clear_task_instances
 from airflow.operators.dummy import DummyOperator
@@ -92,6 +94,41 @@ class TestClearTasks(unittest.TestCase):
             assert ti0.state is None
             assert ti0.external_executor_id is None
 
+    @parameterized.expand([(State.QUEUED, None), (State.RUNNING, DEFAULT_DATE)])
+    def test_clear_task_instances_dr_state(self, state, last_scheduling):
+        """Test that DR state is set to None after clear.
+        And that DR.last_scheduling_decision is handled OK.
+        start_date is also set to None
+        """
+        dag = DAG(
+            'test_clear_task_instances',
+            start_date=DEFAULT_DATE,
+            end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+        )
+        task0 = DummyOperator(task_id='0', owner='test', dag=dag)
+        task1 = DummyOperator(task_id='1', owner='test', dag=dag, retries=2)
+        ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
+        ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
+        session = settings.Session()
+        dr = dag.create_dagrun(
+            execution_date=ti0.execution_date,
+            state=State.RUNNING,
+            run_type=DagRunType.SCHEDULED,
+        )
+        dr.last_scheduling_decision = DEFAULT_DATE
+        session.add(dr)
+        session.commit()
+
+        ti0.run()
+        ti1.run()
+        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
+        clear_task_instances(qry, session, dag_run_state=state, dag=dag)
+
+        dr = ti0.get_dagrun()
+        assert dr.state == state
+        assert dr.start_date is None
+        assert dr.last_scheduling_decision == last_scheduling
+
     def test_clear_task_instances_without_task(self):
         dag = DAG(
             'test_clear_task_instances_without_task',
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 7899199..fac38bf 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -39,7 +39,7 @@ from airflow.utils.state import State
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import DagRunType
 from tests.models import DEFAULT_DATE
-from tests.test_utils.db import clear_db_jobs, clear_db_pools, clear_db_runs
+from tests.test_utils.db import clear_db_dags, clear_db_jobs, clear_db_pools, clear_db_runs
 
 
 class TestDagRun(unittest.TestCase):
@@ -50,6 +50,12 @@ class TestDagRun(unittest.TestCase):
     def setUp(self):
         clear_db_runs()
         clear_db_pools()
+        clear_db_dags()
+
+    def tearDown(self) -> None:
+        clear_db_runs()
+        clear_db_pools()
+        clear_db_dags()
 
     def create_dag_run(
         self,
@@ -102,7 +108,7 @@ class TestDagRun(unittest.TestCase):
         session.commit()
         ti0.refresh_from_db()
         dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.execution_date == now).first()
-        assert dr0.state == State.RUNNING
+        assert dr0.state == State.QUEUED
 
     def test_dagrun_find(self):
         session = settings.Session()
@@ -692,9 +698,11 @@ class TestDagRun(unittest.TestCase):
         ti.run()
         assert (ti.state == State.SUCCESS) == is_ti_success
 
-    def test_next_dagruns_to_examine_only_unpaused(self):
+    @parameterized.expand([(State.QUEUED,), (State.RUNNING,)])
+    def test_next_dagruns_to_examine_only_unpaused(self, state):
         """
         Check that "next_dagruns_to_examine" ignores runs from paused/inactive DAGs
+        and gets running/queued dagruns
         """
 
         dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE)
@@ -712,25 +720,22 @@ class TestDagRun(unittest.TestCase):
         session.flush()
         dr = dag.create_dagrun(
             run_type=DagRunType.SCHEDULED,
-            state=State.RUNNING,
+            state=state,
             execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
+            start_date=DEFAULT_DATE if state == State.RUNNING else None,
             session=session,
         )
 
-        runs = DagRun.next_dagruns_to_examine(session).all()
+        runs = DagRun.next_dagruns_to_examine(state, session).all()
 
         assert runs == [dr]
 
         orm_dag.is_paused = True
         session.flush()
 
-        runs = DagRun.next_dagruns_to_examine(session).all()
+        runs = DagRun.next_dagruns_to_examine(state, session).all()
         assert runs == []
 
-        session.rollback()
-        session.close()
-
     @mock.patch.object(Stats, 'timing')
     def test_no_scheduling_delay_for_nonscheduled_runs(self, stats_mock):
         """
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index 187fdb2..274766c 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -545,16 +545,16 @@ def test_external_task_marker_clear_activate(dag_bag_parent_child):
     task_0 = dag_0.get_task("task_0")
     clear_tasks(dag_bag, dag_0, task_0, start_date=day_1, end_date=day_2)
 
-    # Assert that dagruns of all the affected dags are set to RUNNING after tasks are cleared.
+    # Assert that dagruns of all the affected dags are set to QUEUED after tasks are cleared.
     # Unaffected dagruns should be left as SUCCESS.
     dagrun_0_1 = dag_bag.get_dag('parent_dag_0').get_dagrun(execution_date=day_1)
     dagrun_0_2 = dag_bag.get_dag('parent_dag_0').get_dagrun(execution_date=day_2)
     dagrun_1_1 = dag_bag.get_dag('child_dag_1').get_dagrun(execution_date=day_1)
     dagrun_1_2 = dag_bag.get_dag('child_dag_1').get_dagrun(execution_date=day_2)
 
-    assert dagrun_0_1.state == State.RUNNING
-    assert dagrun_0_2.state == State.RUNNING
-    assert dagrun_1_1.state == State.RUNNING
+    assert dagrun_0_1.state == State.QUEUED
+    assert dagrun_0_2.state == State.QUEUED
+    assert dagrun_1_1.state == State.QUEUED
     assert dagrun_1_2.state == State.SUCCESS
 
 

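As an aside, a minimal self-contained sketch (illustrative only, not Airflow's implementation) of the behaviour the tests in this commit assert: DAG runs are now created in a QUEUED state and are only promoted to RUNNING while the count of running runs for that DAG stays below max_active_runs.

from collections import Counter
from typing import Dict, List


def start_queued_runs(runs: List[Dict[str, str]], max_active_runs: int) -> None:
    """Promote queued runs to running, per DAG, up to max_active_runs."""
    active = Counter(r["dag_id"] for r in runs if r["state"] == "running")
    for run in runs:
        if run["state"] != "queued":
            continue
        if active[run["dag_id"]] >= max_active_runs:
            # Leave it queued; a later scheduler loop can promote it.
            continue
        run["state"] = "running"
        active[run["dag_id"]] += 1


runs = [{"dag_id": "example", "state": "queued"} for _ in range(3)]
start_queued_runs(runs, max_active_runs=1)
# Only one run is promoted; the other two stay queued.
assert [r["state"] for r in runs] == ["running", "queued", "queued"]
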
[airflow] 06/09: Add Pytest fixture to create dag and dagrun and use it on local task job tests (#16889)

Posted by ka...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 343beb65685adc4b87107a18a43c509731985499
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Tue Jul 20 18:48:35 2021 +0100

    Add Pytest fixture to create dag and dagrun and use it on local task job tests (#16889)
    
    This change adds a pytest fixture that creates a DAG and DagRun, and uses it in the local task job tests
    
    Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
    (cherry picked from commit 7c0d8a2f83cc6db25bdddcf6cecb6fb56f05f02f)
---
 tests/conftest.py                 |  50 +++++++++++
 tests/jobs/test_local_task_job.py | 178 +++++++++++++-------------------------
 2 files changed, 111 insertions(+), 117 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 55e1593..f2c5345 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -425,3 +425,53 @@ def app():
     from airflow.www import app
 
     return app.create_app(testing=True)
+
+
+@pytest.fixture
+def dag_maker(request):
+    from airflow.models import DAG
+    from airflow.utils import timezone
+    from airflow.utils.state import State
+
+    DEFAULT_DATE = timezone.datetime(2016, 1, 1)
+
+    class DagFactory:
+        def __enter__(self):
+            self.dag.__enter__()
+            return self.dag
+
+        def __exit__(self, type, value, traceback):
+            dag = self.dag
+            dag.__exit__(type, value, traceback)
+            if type is None:
+                dag.clear()
+                self.dag_run = dag.create_dagrun(
+                    run_id=self.kwargs.get("run_id", "test"),
+                    state=self.kwargs.get('state', State.RUNNING),
+                    execution_date=self.kwargs.get('execution_date', self.kwargs['start_date']),
+                    start_date=self.kwargs['start_date'],
+                )
+
+        def __call__(self, dag_id='test_dag', **kwargs):
+            self.kwargs = kwargs
+            if "start_date" not in kwargs:
+                if hasattr(request.module, 'DEFAULT_DATE'):
+                    kwargs['start_date'] = getattr(request.module, 'DEFAULT_DATE')
+                else:
+                    kwargs['start_date'] = DEFAULT_DATE
+            dagrun_fields_not_in_dag = [
+                'state',
+                'execution_date',
+                'run_type',
+                'queued_at',
+                "run_id",
+                "creating_job_id",
+                "external_trigger",
+                "last_scheduling_decision",
+                "dag_hash",
+            ]
+            kwargs = {k: v for k, v in kwargs.items() if k not in dagrun_fields_not_in_dag}
+            self.dag = DAG(dag_id, **kwargs)
+            return self
+
+    return DagFactory()
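For orientation, a minimal usage sketch of the dag_maker fixture added above (the test and task names are illustrative, and it assumes the usual Airflow test database setup):

from airflow.operators.dummy import DummyOperator
from airflow.utils.state import State


def test_dag_maker_example(dag_maker):
    # Entering the context manager builds the DAG; on exit the fixture clears it
    # and creates a DagRun using the run_id/state/execution_date passed here.
    with dag_maker("example_dag", run_id="test", state=State.RUNNING):
        task = DummyOperator(task_id="op1")

    dr = dag_maker.dag_run
    ti = dr.get_task_instance(task_id=task.task_id)
    assert ti is not None
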
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index 11e9adf..d9f1398 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -20,7 +20,6 @@ import multiprocessing
 import os
 import signal
 import time
-import unittest
 import uuid
 from multiprocessing import Lock, Value
 from unittest import mock
@@ -57,21 +56,30 @@ DEFAULT_DATE = timezone.datetime(2016, 1, 1)
 TEST_DAG_FOLDER = os.environ['AIRFLOW__CORE__DAGS_FOLDER']
 
 
-class TestLocalTaskJob(unittest.TestCase):
-    def setUp(self):
-        db.clear_db_dags()
-        db.clear_db_jobs()
-        db.clear_db_runs()
-        db.clear_db_task_fail()
-        patcher = patch('airflow.jobs.base_job.sleep')
-        self.addCleanup(patcher.stop)
-        self.mock_base_job_sleep = patcher.start()
+@pytest.fixture
+def clear_db():
+    db.clear_db_dags()
+    db.clear_db_jobs()
+    db.clear_db_runs()
+    db.clear_db_task_fail()
+    yield
+
+
+@pytest.fixture(scope='class')
+def clear_db_class():
+    yield
+    db.clear_db_dags()
+    db.clear_db_jobs()
+    db.clear_db_runs()
+    db.clear_db_task_fail()
+
 
-    def tearDown(self) -> None:
-        db.clear_db_dags()
-        db.clear_db_jobs()
-        db.clear_db_runs()
-        db.clear_db_task_fail()
+@pytest.mark.usefixtures('clear_db_class', 'clear_db')
+class TestLocalTaskJob:
+    @pytest.fixture(autouse=True)
+    def set_instance_attrs(self):
+        with patch('airflow.jobs.base_job.sleep') as self.mock_base_job_sleep:
+            yield
 
     def validate_ti_states(self, dag_run, ti_state_mapping, error_message):
         for task_id, expected_state in ti_state_mapping.items():
@@ -79,23 +87,19 @@ class TestLocalTaskJob(unittest.TestCase):
             task_instance.refresh_from_db()
             assert task_instance.state == expected_state, error_message
 
-    def test_localtaskjob_essential_attr(self):
+    def test_localtaskjob_essential_attr(self, dag_maker):
         """
         Check whether essential attributes
         of LocalTaskJob can be assigned with
         proper values without intervention
         """
-        dag = DAG(
+        with dag_maker(
             'test_localtaskjob_essential_attr', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}
-        )
-
-        with dag:
+        ):
             op1 = DummyOperator(task_id='op1')
 
-        dag.clear()
-        dr = dag.create_dagrun(
-            run_id="test", state=State.SUCCESS, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE
-        )
+        dr = dag_maker.dag_run
+
         ti = dr.get_task_instance(task_id=op1.task_id)
 
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
@@ -108,21 +112,12 @@ class TestLocalTaskJob(unittest.TestCase):
         check_result_2 = [getattr(job1, attr) is not None for attr in essential_attr]
         assert all(check_result_2)
 
-    def test_localtaskjob_heartbeat(self):
+    def test_localtaskjob_heartbeat(self, dag_maker):
         session = settings.Session()
-        dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
-        with dag:
+        with dag_maker('test_localtaskjob_heartbeat'):
             op1 = DummyOperator(task_id='op1')
 
-        dag.clear()
-        dr = dag.create_dagrun(
-            run_id="test",
-            state=State.SUCCESS,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            session=session,
-        )
+        dr = dag_maker.dag_run
         ti = dr.get_task_instance(task_id=op1.task_id, session=session)
         ti.state = State.RUNNING
         ti.hostname = "blablabla"
@@ -150,22 +145,11 @@ class TestLocalTaskJob(unittest.TestCase):
             job1.heartbeat_callback()
 
     @mock.patch('airflow.jobs.local_task_job.psutil')
-    def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock):
+    def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock, dag_maker):
         session = settings.Session()
-        dag = DAG('test_localtaskjob_heartbeat', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
-        with dag:
+        with dag_maker('test_localtaskjob_heartbeat'):
             op1 = DummyOperator(task_id='op1', run_as_user='myuser')
-
-        dag.clear()
-        dr = dag.create_dagrun(
-            run_id="test",
-            state=State.SUCCESS,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            session=session,
-        )
-
+        dr = dag_maker.dag_run
         ti = dr.get_task_instance(task_id=op1.task_id, session=session)
         ti.state = State.RUNNING
         ti.pid = 2
@@ -248,7 +232,8 @@ class TestLocalTaskJob(unittest.TestCase):
         Test that task heartbeat will sleep when it fails fast
         """
         self.mock_base_job_sleep.side_effect = time.sleep
-
+        dag_id = 'test_heartbeat_failed_fast'
+        task_id = 'test_heartbeat_failed_fast_op'
         with create_session() as session:
             dagbag = DagBag(
                 dag_folder=TEST_DAG_FOLDER,
@@ -266,6 +251,7 @@ class TestLocalTaskJob(unittest.TestCase):
                 start_date=DEFAULT_DATE,
                 session=session,
             )
+
             ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
             ti.refresh_from_db()
             ti.state = State.RUNNING
@@ -331,6 +317,7 @@ class TestLocalTaskJob(unittest.TestCase):
         assert State.SUCCESS == ti.state
 
     def test_localtaskjob_double_trigger(self):
+
         dagbag = DagBag(
             dag_folder=TEST_DAG_FOLDER,
             include_examples=False,
@@ -348,6 +335,7 @@ class TestLocalTaskJob(unittest.TestCase):
             start_date=DEFAULT_DATE,
             session=session,
         )
+
         ti = dr.get_task_instance(task_id=task.task_id, session=session)
         ti.state = State.RUNNING
         ti.hostname = get_hostname()
@@ -418,7 +406,7 @@ class TestLocalTaskJob(unittest.TestCase):
         assert time_end - time_start < job1.heartrate
         session.close()
 
-    def test_mark_failure_on_failure_callback(self):
+    def test_mark_failure_on_failure_callback(self, dag_maker):
         """
         Test that ensures that mark_failure in the UI fails
         the task, and executes on_failure_callback
@@ -447,22 +435,12 @@ class TestLocalTaskJob(unittest.TestCase):
             with task_terminated_externally.get_lock():
                 task_terminated_externally.value = 0
 
-        with DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE) as dag:
+        with dag_maker("test_mark_failure", start_date=DEFAULT_DATE):
             task = PythonOperator(
                 task_id='test_state_succeeded1',
                 python_callable=task_function,
                 on_failure_callback=check_failure,
             )
-
-        dag.clear()
-        with create_session() as session:
-            dag.create_dagrun(
-                run_id="test",
-                state=State.RUNNING,
-                execution_date=DEFAULT_DATE,
-                start_date=DEFAULT_DATE,
-                session=session,
-            )
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
 
@@ -479,7 +457,7 @@ class TestLocalTaskJob(unittest.TestCase):
 
     @patch('airflow.utils.process_utils.subprocess.check_call')
     @patch.object(StandardTaskRunner, 'return_code')
-    def test_failure_callback_only_called_once(self, mock_return_code, _check_call):
+    def test_failure_callback_only_called_once(self, mock_return_code, _check_call, dag_maker):
         """
         Test that ensures that when a task exits with failure by itself,
         failure callback is only called once
@@ -498,22 +476,11 @@ class TestLocalTaskJob(unittest.TestCase):
         def task_function(ti):
             raise AirflowFailException()
 
-        dag = DAG(dag_id='test_failure_callback_race', start_date=DEFAULT_DATE)
-        task = PythonOperator(
-            task_id='test_exit_on_failure',
-            python_callable=task_function,
-            on_failure_callback=failure_callback,
-            dag=dag,
-        )
-
-        dag.clear()
-        with create_session() as session:
-            dag.create_dagrun(
-                run_id="test",
-                state=State.RUNNING,
-                execution_date=DEFAULT_DATE,
-                start_date=DEFAULT_DATE,
-                session=session,
+        with dag_maker("test_failure_callback_race"):
+            task = PythonOperator(
+                task_id='test_exit_on_failure',
+                python_callable=task_function,
+                on_failure_callback=failure_callback,
             )
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
@@ -544,7 +511,7 @@ class TestLocalTaskJob(unittest.TestCase):
         assert failure_callback_called.value == 1
 
     @pytest.mark.quarantined
-    def test_mark_success_on_success_callback(self):
+    def test_mark_success_on_success_callback(self, dag_maker):
         """
         Test that ensures that where a task is marked success in the UI
         on_success_callback gets executed
@@ -560,8 +527,6 @@ class TestLocalTaskJob(unittest.TestCase):
                 success_callback_called.value += 1
             assert context['dag_run'].dag_id == 'test_mark_success'
 
-        dag = DAG(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
         def task_function(ti):
 
             time.sleep(60)
@@ -569,23 +534,15 @@ class TestLocalTaskJob(unittest.TestCase):
             with shared_mem_lock:
                 task_terminated_externally.value = 0
 
-        task = PythonOperator(
-            task_id='test_state_succeeded1',
-            python_callable=task_function,
-            on_success_callback=success_callback,
-            dag=dag,
-        )
+        with dag_maker(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}):
+            task = PythonOperator(
+                task_id='test_state_succeeded1',
+                python_callable=task_function,
+                on_success_callback=success_callback,
+            )
 
         session = settings.Session()
 
-        dag.clear()
-        dag.create_dagrun(
-            run_id="test",
-            state=State.RUNNING,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            session=session,
-        )
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
@@ -616,7 +573,7 @@ class TestLocalTaskJob(unittest.TestCase):
             (signal.SIGKILL,),
         ]
     )
-    def test_process_kill_calls_on_failure_callback(self, signal_type):
+    def test_process_kill_calls_on_failure_callback(self, signal_type, dag_maker):
         """
         Test that ensures that when a task is killed with sigterm or sigkill
         on_failure_callback gets executed
@@ -632,8 +589,6 @@ class TestLocalTaskJob(unittest.TestCase):
                 failure_callback_called.value += 1
             assert context['dag_run'].dag_id == 'test_mark_failure'
 
-        dag = DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
-
         def task_function(ti):
 
             time.sleep(60)
@@ -641,23 +596,12 @@ class TestLocalTaskJob(unittest.TestCase):
             with shared_mem_lock:
                 task_terminated_externally.value = 0
 
-        task = PythonOperator(
-            task_id='test_on_failure',
-            python_callable=task_function,
-            on_failure_callback=failure_callback,
-            dag=dag,
-        )
-
-        session = settings.Session()
-
-        dag.clear()
-        dag.create_dagrun(
-            run_id="test",
-            state=State.RUNNING,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            session=session,
-        )
+        with dag_maker(dag_id='test_mark_failure', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}):
+            task = PythonOperator(
+                task_id='test_on_failure',
+                python_callable=task_function,
+                on_failure_callback=failure_callback,
+            )
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())

[airflow] 07/09: Fix task retries when they receive sigkill and have retries and properly handle sigterm (#16301)

Posted by ka...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 12f2467181c4c521a0072315c8ab9c66e3bf553a
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Wed Jul 28 15:57:35 2021 +0100

    Fix task retries when they receive sigkill and have retries and properly handle sigterm (#16301)
    
    Currently, tasks are not retried when they receive SIGKILL or SIGTERM, even if the task has retries. This change fixes that
    and adds tests for both SIGTERM and SIGKILL so we don't regress.
    
    Also, SIGTERM sets the task to failed and raises AirflowException, which the heartbeat sometimes sees as an externally set
    failure and therefore skips the failure callbacks. This commit also fixes that by calling handle_task_exit when a task receives SIGTERM.
    
    Co-authored-by: Ash Berlin-Taylor <as...@firemirror.com>
    (cherry picked from commit 4e2a94c6d1bde5ddf2aa0251190c318ac22f3b17)
---
 airflow/jobs/local_task_job.py    |  24 +++---
 tests/jobs/test_local_task_job.py | 166 +++++++++++++++++++++++++++++++++-----
 tests/models/test_taskinstance.py |  32 ++++++++
 3 files changed, 189 insertions(+), 33 deletions(-)

diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index cce4e64..70ef60c 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -78,12 +78,9 @@ class LocalTaskJob(BaseJob):
         def signal_handler(signum, frame):
             """Setting kill signal handler"""
             self.log.error("Received SIGTERM. Terminating subprocesses")
-            self.on_kill()
-            self.task_instance.refresh_from_db()
-            if self.task_instance.state not in State.finished:
-                self.task_instance.set_state(State.FAILED)
-            self.task_instance._run_finished_callback(error="task received sigterm")
-            raise AirflowException("LocalTaskJob received SIGTERM signal")
+            self.task_runner.terminate()
+            self.handle_task_exit(128 + signum)
+            return
 
         signal.signal(signal.SIGTERM, signal_handler)
 
@@ -148,16 +145,19 @@ class LocalTaskJob(BaseJob):
             self.on_kill()
 
     def handle_task_exit(self, return_code: int) -> None:
-        """Handle case where self.task_runner exits by itself"""
+        """Handle case where self.task_runner exits by itself or is externally killed"""
+        # Without setting this, heartbeat may get us
+        self.terminating = True
         self.log.info("Task exited with return code %s", return_code)
         self.task_instance.refresh_from_db()
-        # task exited by itself, so we need to check for error file
+
+        if self.task_instance.state == State.RUNNING:
+            # This is for a case where the task received a SIGKILL
+            # while running or the task runner received a sigterm
+            self.task_instance.handle_failure(error=None)
+        # We need to check for error file
         # in case it failed due to runtime exception/error
         error = None
-        if self.task_instance.state == State.RUNNING:
-            # This is for a case where the task received a sigkill
-            # while running
-            self.task_instance.set_state(State.FAILED)
         if self.task_instance.state != State.SUCCESS:
             error = self.task_runner.deserialize_run_error()
         self.task_instance._run_finished_callback(error=error)  # pylint: disable=protected-access
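
A minimal standalone sketch (names are illustrative, not Airflow's API) of the pattern the hunk above adopts: the SIGTERM handler terminates the child runner and routes into the same exit handler used for a normal task exit, encoding the signal in the conventional 128 + signum return code.

import signal


class MiniTaskJob:
    """Toy stand-in for a job supervising a task runner subprocess."""

    def __init__(self, runner):
        self.runner = runner          # assumed to expose .terminate()
        self.terminating = False

    def handle_task_exit(self, return_code: int) -> None:
        # Mark ourselves as terminating first so a concurrent heartbeat
        # does not race this exit path.
        self.terminating = True
        print(f"Task exited with return code {return_code}")
        # ... decide success / failure / retry based on the task state ...

    def register_signals(self) -> None:
        def signal_handler(signum, frame):
            self.runner.terminate()
            # Same convention as shells: a process killed by signal N
            # is reported as exit code 128 + N.
            self.handle_task_exit(128 + signum)

        signal.signal(signal.SIGTERM, signal_handler)
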
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index d9f1398..94f894d 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -21,6 +21,7 @@ import os
 import signal
 import time
 import uuid
+from datetime import timedelta
 from multiprocessing import Lock, Value
 from unittest import mock
 from unittest.mock import patch
@@ -272,7 +273,6 @@ class TestLocalTaskJob:
                 delta = (time2 - time1).total_seconds()
                 assert abs(delta - job.heartrate) < 0.5
 
-    @pytest.mark.quarantined
     def test_mark_success_no_kill(self):
         """
         Test that ensures that mark_success in the UI doesn't cause
@@ -300,7 +300,6 @@ class TestLocalTaskJob:
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
         process = multiprocessing.Process(target=job1.run)
         process.start()
-        ti.refresh_from_db()
         for _ in range(0, 50):
             if ti.state == State.RUNNING:
                 break
@@ -510,7 +509,6 @@ class TestLocalTaskJob:
         assert ti.state == State.FAILED  # task exits with failure state
         assert failure_callback_called.value == 1
 
-    @pytest.mark.quarantined
     def test_mark_success_on_success_callback(self, dag_maker):
         """
         Test that ensures that where a task is marked success in the UI
@@ -567,15 +565,9 @@ class TestLocalTaskJob:
         assert task_terminated_externally.value == 1
         assert not process.is_alive()
 
-    @parameterized.expand(
-        [
-            (signal.SIGTERM,),
-            (signal.SIGKILL,),
-        ]
-    )
-    def test_process_kill_calls_on_failure_callback(self, signal_type, dag_maker):
+    def test_task_sigkill_calls_on_failure_callback(self, dag_maker):
         """
-        Test that ensures that when a task is killed with sigterm or sigkill
+        Test that ensures that when a task is killed with sigkill
         on_failure_callback gets executed
         """
         # use shared memory value so we can properly track value change even if
@@ -587,10 +579,50 @@ class TestLocalTaskJob:
         def failure_callback(context):
             with shared_mem_lock:
                 failure_callback_called.value += 1
-            assert context['dag_run'].dag_id == 'test_mark_failure'
+            assert context['dag_run'].dag_id == 'test_send_sigkill'
 
         def task_function(ti):
+            os.kill(os.getpid(), signal.SIGKILL)
+            # This should not happen -- the state change should be noticed and the task should get killed
+            with shared_mem_lock:
+                task_terminated_externally.value = 0
+
+        with dag_maker(dag_id='test_send_sigkill'):
+            task = PythonOperator(
+                task_id='test_on_failure',
+                python_callable=task_function,
+                on_failure_callback=failure_callback,
+            )
+
+        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+        ti.refresh_from_db()
+        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
+        settings.engine.dispose()
+        process = multiprocessing.Process(target=job1.run)
+        process.start()
+        time.sleep(0.3)
+        process.join(timeout=10)
+        assert failure_callback_called.value == 1
+        assert task_terminated_externally.value == 1
+        assert not process.is_alive()
+
+    def test_process_sigterm_calls_on_failure_callback(self, dag_maker):
+        """
+        Test that ensures that when a task runner is killed with sigterm
+        on_failure_callback gets executed
+        """
+        # use shared memory value so we can properly track value change even if
+        # it's been updated across processes.
+        failure_callback_called = Value('i', 0)
+        task_terminated_externally = Value('i', 1)
+        shared_mem_lock = Lock()
 
+        def failure_callback(context):
+            with shared_mem_lock:
+                failure_callback_called.value += 1
+            assert context['dag_run'].dag_id == 'test_mark_failure'
+
+        def task_function(ti):
             time.sleep(60)
             # This should not happen -- the state change should be noticed and the task should get killed
             with shared_mem_lock:
@@ -605,20 +637,16 @@ class TestLocalTaskJob:
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
-        job1.task_runner = StandardTaskRunner(job1)
-
         settings.engine.dispose()
         process = multiprocessing.Process(target=job1.run)
         process.start()
-
-        for _ in range(0, 20):
+        for _ in range(0, 25):
             ti.refresh_from_db()
-            if ti.state == State.RUNNING and ti.pid is not None:
+            if ti.state == State.RUNNING:
                 break
             time.sleep(0.2)
-        assert ti.pid is not None
-        assert ti.state == State.RUNNING
-        os.kill(ti.pid, signal_type)
+        os.kill(process.pid, signal.SIGTERM)
+        ti.refresh_from_db()
         process.join(timeout=10)
         assert failure_callback_called.value == 1
         assert task_terminated_externally.value == 1
@@ -726,6 +754,102 @@ class TestLocalTaskJob:
             if scheduler_job.processor_agent:
                 scheduler_job.processor_agent.end()
 
+    def test_task_sigkill_works_with_retries(self, dag_maker):
+        """
+        Test that ensures that tasks are retried when they receive sigkill
+        """
+        # use shared memory value so we can properly track value change even if
+        # it's been updated across processes.
+        retry_callback_called = Value('i', 0)
+        task_terminated_externally = Value('i', 1)
+        shared_mem_lock = Lock()
+
+        def retry_callback(context):
+            with shared_mem_lock:
+                retry_callback_called.value += 1
+            assert context['dag_run'].dag_id == 'test_mark_failure_2'
+
+        def task_function(ti):
+            os.kill(os.getpid(), signal.SIGKILL)
+            # This should not happen -- the state change should be noticed and the task should get killed
+            with shared_mem_lock:
+                task_terminated_externally.value = 0
+
+        with dag_maker(
+            dag_id='test_mark_failure_2', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}
+        ):
+            task = PythonOperator(
+                task_id='test_on_failure',
+                python_callable=task_function,
+                retries=1,
+                retry_delay=timedelta(seconds=2),
+                on_retry_callback=retry_callback,
+            )
+        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+        ti.refresh_from_db()
+        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
+        job1.task_runner = StandardTaskRunner(job1)
+        job1.task_runner.start()
+        settings.engine.dispose()
+        process = multiprocessing.Process(target=job1.run)
+        process.start()
+        time.sleep(0.4)
+        process.join()
+        ti.refresh_from_db()
+        assert ti.state == State.UP_FOR_RETRY
+        assert retry_callback_called.value == 1
+        assert task_terminated_externally.value == 1
+
+    def test_process_sigterm_works_with_retries(self, dag_maker):
+        """
+        Test that ensures that task runner sets tasks to retry when they(task runner)
+         receive sigterm
+        """
+        # use shared memory value so we can properly track value change even if
+        # it's been updated across processes.
+        retry_callback_called = Value('i', 0)
+        task_terminated_externally = Value('i', 1)
+        shared_mem_lock = Lock()
+
+        def retry_callback(context):
+            with shared_mem_lock:
+                retry_callback_called.value += 1
+            assert context['dag_run'].dag_id == 'test_mark_failure_2'
+
+        def task_function(ti):
+            time.sleep(60)
+            # This should not happen -- the state change should be noticed and the task should get killed
+            with shared_mem_lock:
+                task_terminated_externally.value = 0
+
+        with dag_maker(dag_id='test_mark_failure_2'):
+            task = PythonOperator(
+                task_id='test_on_failure',
+                python_callable=task_function,
+                retries=1,
+                retry_delay=timedelta(seconds=2),
+                on_retry_callback=retry_callback,
+            )
+        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+        ti.refresh_from_db()
+        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
+        job1.task_runner = StandardTaskRunner(job1)
+        job1.task_runner.start()
+        settings.engine.dispose()
+        process = multiprocessing.Process(target=job1.run)
+        process.start()
+        for _ in range(0, 25):
+            ti.refresh_from_db()
+            if ti.state == State.RUNNING and ti.pid is not None:
+                break
+            time.sleep(0.2)
+        os.kill(process.pid, signal.SIGTERM)
+        process.join()
+        ti.refresh_from_db()
+        assert ti.state == State.UP_FOR_RETRY
+        assert retry_callback_called.value == 1
+        assert task_terminated_externally.value == 1
+
     def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self):
         """Test that with DAG paused, DagRun state will update when the tasks finishes the run"""
         dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE)
@@ -788,5 +912,5 @@ class TestLocalTaskJobPerformance:
         mock_get_task_runner.return_value.return_code.side_effects = return_codes
 
         job = LocalTaskJob(task_instance=ti, executor=MockExecutor())
-        with assert_queries_count(16):
+        with assert_queries_count(18):
             job.run()
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index c1882e1..db23271 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -18,6 +18,7 @@
 
 import datetime
 import os
+import signal
 import time
 import unittest
 import urllib
@@ -522,6 +523,37 @@ class TestTaskInstance(unittest.TestCase):
         ti.run()
         assert State.SKIPPED == ti.state
 
+    def test_task_sigterm_works_with_retries(self):
+        """
+        Test that ensures that tasks are retried when they receive sigterm
+        """
+        dag = DAG(dag_id='test_mark_failure_2', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
+
+        def task_function(ti):
+            # pylint: disable=unused-argument
+            os.kill(ti.pid, signal.SIGTERM)
+
+        task = PythonOperator(
+            task_id='test_on_failure',
+            python_callable=task_function,
+            retries=1,
+            retry_delay=datetime.timedelta(seconds=2),
+            dag=dag,
+        )
+
+        dag.create_dagrun(
+            run_id="test",
+            state=State.RUNNING,
+            execution_date=DEFAULT_DATE,
+            start_date=DEFAULT_DATE,
+        )
+        ti = TI(task=task, execution_date=DEFAULT_DATE)
+        ti.refresh_from_db()
+        with self.assertRaises(AirflowException):
+            ti.run()
+        ti.refresh_from_db()
+        assert ti.state == State.UP_FOR_RETRY
+
     def test_retry_delay(self):
         """
         Test that retry delays are respected

[airflow] 05/09: Fix race condition with dagrun callbacks (#16741)

Posted by ka...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 7cc60024eb20998efbe1c1246e8cc8f7064e78d2
Author: Jed Cunningham <66...@users.noreply.github.com>
AuthorDate: Thu Jul 29 11:05:02 2021 -0600

    Fix race condition with dagrun callbacks (#16741)
    
    Instead of immediately sending callbacks to be processed, wait until
    after we commit so the dagrun.end_date is guaranteed to be there when
    the callback runs.
    
    (cherry picked from commit fb3031acf51f95384154143553aac1a40e568ebf)
---
 airflow/jobs/scheduler_job.py          | 18 +++++---
 tests/dag_processing/test_processor.py | 20 +++++----
 tests/jobs/test_scheduler_job.py       | 80 +++++++++++++++++++++++++++++-----
 3 files changed, 94 insertions(+), 24 deletions(-)

diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 7a37b25..18ec981 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -888,6 +888,7 @@ class SchedulerJob(BaseJob):
             # Bulk fetch the currently active dag runs for the dags we are
             # examining, rather than making one query per DagRun
 
+            callback_tuples = []
             for dag_run in dag_runs:
                 # Use try_except to not stop the Scheduler when a Serialized DAG is not found
                 # This takes care of Dynamic DAGs especially
@@ -896,13 +897,18 @@ class SchedulerJob(BaseJob):
                 # But this would take care of the scenario when the Scheduler is restarted after DagRun is
                 # created and the DAG is deleted / renamed
                 try:
-                    self._schedule_dag_run(dag_run, session)
+                    callback_to_run = self._schedule_dag_run(dag_run, session)
+                    callback_tuples.append((dag_run, callback_to_run))
                 except SerializedDagNotFound:
                     self.log.exception("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
                     continue
 
             guard.commit()
 
+            # Send the callbacks after we commit to ensure the context is up to date when it gets run
+            for dag_run, callback_to_run in callback_tuples:
+                self._send_dag_callbacks_to_processor(dag_run, callback_to_run)
+
             # Without this, the session has an invalid view of the DB
             session.expunge_all()
             # END: schedule TIs
@@ -1064,12 +1070,12 @@ class SchedulerJob(BaseJob):
         self,
         dag_run: DagRun,
         session: Session,
-    ) -> int:
+    ) -> Optional[DagCallbackRequest]:
         """
         Make scheduling decisions about an individual dag run
 
         :param dag_run: The DagRun to schedule
-        :return: Number of tasks scheduled
+        :return: Callback that needs to be executed
         """
         dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
 
@@ -1116,13 +1122,13 @@ class SchedulerJob(BaseJob):
         # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else?
         schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False)
 
-        self._send_dag_callbacks_to_processor(dag_run, callback_to_run)
-
         # This will do one query per dag run. We "could" build up a complex
         # query to update all the TIs across all the execution dates and dag
         # IDs in a single query, but it turns out that can be _very very slow_
         # see #11147/commit ee90807ac for more details
-        return dag_run.schedule_tis(schedulable_tis, session)
+        dag_run.schedule_tis(schedulable_tis, session)
+
+        return callback_to_run
 
     @provide_session
     def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None):
diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py
index feb3497..b7f8e7d 100644
--- a/tests/dag_processing/test_processor.py
+++ b/tests/dag_processing/test_processor.py
@@ -115,6 +115,10 @@ class TestDagFileProcessor(unittest.TestCase):
         non_serialized_dagbag.sync_to_db()
         cls.dagbag = DagBag(read_dags_from_db=True)
 
+    @staticmethod
+    def assert_scheduled_ti_count(session, count):
+        assert count == session.query(TaskInstance).filter_by(state=State.SCHEDULED).count()
+
     def test_dag_file_processor_sla_miss_callback(self):
         """
         Test that the dag file processor calls the sla miss callback
@@ -387,8 +391,8 @@ class TestDagFileProcessor(unittest.TestCase):
             ti.start_date = start_date
             ti.end_date = end_date
 
-            count = self.scheduler_job._schedule_dag_run(dr, session)
-            assert count == 1
+            self.scheduler_job._schedule_dag_run(dr, session)
+            self.assert_scheduled_ti_count(session, 1)
 
             session.refresh(ti)
             assert ti.state == State.SCHEDULED
@@ -444,8 +448,8 @@ class TestDagFileProcessor(unittest.TestCase):
             ti.start_date = start_date
             ti.end_date = end_date
 
-            count = self.scheduler_job._schedule_dag_run(dr, session)
-            assert count == 1
+            self.scheduler_job._schedule_dag_run(dr, session)
+            self.assert_scheduled_ti_count(session, 1)
 
             session.refresh(ti)
             assert ti.state == State.SCHEDULED
@@ -504,8 +508,8 @@ class TestDagFileProcessor(unittest.TestCase):
                 ti.start_date = start_date
                 ti.end_date = end_date
 
-            count = self.scheduler_job._schedule_dag_run(dr, session)
-            assert count == 2
+            self.scheduler_job._schedule_dag_run(dr, session)
+            self.assert_scheduled_ti_count(session, 2)
 
             session.refresh(tis[0])
             session.refresh(tis[1])
@@ -547,9 +551,9 @@ class TestDagFileProcessor(unittest.TestCase):
         BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test')
         SerializedDagModel.write_dag(dag=dag)
 
-        scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session)
+        self.scheduler_job._schedule_dag_run(dr, session)
+        self.assert_scheduled_ti_count(session, 2)
         session.flush()
-        assert scheduled_tis == 2
 
         drs = DagRun.find(dag_id=dag.dag_id, session=session)
         assert len(drs) == 1
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 0ee6f5f..5de365e 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -1710,10 +1710,11 @@ class TestSchedulerJob(unittest.TestCase):
         ti = dr.get_task_instance('dummy')
         ti.set_state(state, session)
 
-        self.scheduler_job._schedule_dag_run(dr, session)
+        with mock.patch.object(settings, "USE_JOB_SCHEDULE", False):
+            self.scheduler_job._do_scheduling(session)
 
         expected_callback = DagCallbackRequest(
-            full_filepath=dr.dag.fileloc,
+            full_filepath=dag.fileloc,
             dag_id=dr.dag_id,
             is_failure_callback=bool(state == State.FAILED),
             execution_date=dr.execution_date,
@@ -1729,6 +1730,64 @@ class TestSchedulerJob(unittest.TestCase):
         session.rollback()
         session.close()
 
+    def test_dagrun_callbacks_commited_before_sent(self):
+        """
+        Tests that before any callbacks are sent to the processor, the session is committed. This ensures
+        that the dagrun details are up to date when the callbacks are run.
+        """
+        dag = DAG(dag_id='test_dagrun_callbacks_commited_before_sent', start_date=DEFAULT_DATE)
+        DummyOperator(task_id='dummy', dag=dag, owner='airflow')
+
+        self.scheduler_job = SchedulerJob(subdir=os.devnull)
+        self.scheduler_job.processor_agent = mock.Mock()
+        self.scheduler_job._send_dag_callbacks_to_processor = mock.Mock()
+        self.scheduler_job._schedule_dag_run = mock.Mock()
+
+        # Sync DAG into DB
+        with mock.patch.object(settings, "STORE_DAG_CODE", False):
+            self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+            self.scheduler_job.dagbag.sync_to_db()
+
+        session = settings.Session()
+        orm_dag = session.query(DagModel).get(dag.dag_id)
+        assert orm_dag is not None
+
+        # Create DagRun
+        self.scheduler_job._create_dag_runs([orm_dag], session)
+
+        drs = DagRun.find(dag_id=dag.dag_id, session=session)
+        assert len(drs) == 1
+        dr = drs[0]
+
+        ti = dr.get_task_instance('dummy')
+        ti.set_state(State.SUCCESS, session)
+
+        with mock.patch.object(settings, "USE_JOB_SCHEDULE", False), mock.patch(
+            "airflow.jobs.scheduler_job.prohibit_commit"
+        ) as mock_gaurd:
+            mock_gaurd.return_value.__enter__.return_value.commit.side_effect = session.commit
+
+            def mock_schedule_dag_run(*args, **kwargs):
+                mock_gaurd.reset_mock()
+                return None
+
+            def mock_send_dag_callbacks_to_processor(*args, **kwargs):
+                mock_gaurd.return_value.__enter__.return_value.commit.assert_called_once()
+
+            self.scheduler_job._send_dag_callbacks_to_processor.side_effect = (
+                mock_send_dag_callbacks_to_processor
+            )
+            self.scheduler_job._schedule_dag_run.side_effect = mock_schedule_dag_run
+
+            self.scheduler_job._do_scheduling(session)
+
+        # Verify dag failure callback request is sent to file processor
+        self.scheduler_job._send_dag_callbacks_to_processor.assert_called_once()
+        # and mock_send_dag_callbacks_to_processor has asserted the callback was sent after a commit
+
+        session.rollback()
+        session.close()
+
     @parameterized.expand([(State.SUCCESS,), (State.FAILED,)])
     def test_dagrun_callbacks_are_not_added_when_callbacks_are_not_defined(self, state):
         """
@@ -1765,10 +1824,15 @@ class TestSchedulerJob(unittest.TestCase):
         ti = dr.get_task_instance('test_task')
         ti.set_state(state, session)
 
-        self.scheduler_job._schedule_dag_run(dr, session)
+        with mock.patch.object(settings, "USE_JOB_SCHEDULE", False):
+            self.scheduler_job._do_scheduling(session)
 
         # Verify Callback is not set (i.e is None) when no callbacks are set on DAG
-        self.scheduler_job._send_dag_callbacks_to_processor.assert_called_once_with(dr, None)
+        self.scheduler_job._send_dag_callbacks_to_processor.assert_called_once()
+        call_args = self.scheduler_job._send_dag_callbacks_to_processor.call_args[0]
+        assert call_args[0].dag_id == dr.dag_id
+        assert call_args[0].execution_date == dr.execution_date
+        assert call_args[1] is None
 
         session.rollback()
         session.close()
@@ -2411,12 +2475,10 @@ class TestSchedulerJob(unittest.TestCase):
 
         # Verify that DagRun.verify_integrity is not called
         with mock.patch('airflow.jobs.scheduler_job.DagRun.verify_integrity') as mock_verify_integrity:
-            scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session)
+            self.scheduler_job._schedule_dag_run(dr, session)
             mock_verify_integrity.assert_not_called()
         session.flush()
 
-        assert scheduled_tis == 1
-
         tis_count = (
             session.query(func.count(TaskInstance.task_id))
             .filter(
@@ -2474,11 +2536,9 @@ class TestSchedulerJob(unittest.TestCase):
         dag_version_2 = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session)
         assert dag_version_2 != dag_version_1
 
-        scheduled_tis = self.scheduler_job._schedule_dag_run(dr, session)
+        self.scheduler_job._schedule_dag_run(dr, session)
         session.flush()
 
-        assert scheduled_tis == 2
-
         drs = DagRun.find(dag_id=dag.dag_id, session=session)
         assert len(drs) == 1
         dr = drs[0]

[airflow] 02/09: Move DagFileProcessor and DagFileProcessorProcess out of scheduler_job.py (#16581)

Posted by ka...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 22567165712289861521389dc5b4aee874c4d3d5
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Fri Jun 25 05:36:56 2021 +0100

    Move DagFileProcessor and DagFileProcessorProcess out of scheduler_job.py (#16581)
    
    This change moves DagFileProcessor and DagFileProcessorProcess out of scheduler_job.py.
    
    Also, dag_processing.py was moved out of airflow/utils.
    
    (cherry picked from commit 88ee2aa7ddf91799f25add9c57e1ea128de2b7aa)
---
 .github/boring-cyborg.yml                          |   2 +-
 airflow/dag_processing/__init__.py                 |  16 +
 .../manager.py}                                    |   0
 airflow/dag_processing/processor.py                | 650 ++++++++++++++++++
 airflow/jobs/scheduler_job.py                      | 619 +----------------
 tests/dag_processing/__init__.py                   |  16 +
 .../test_manager.py}                               |  16 +-
 tests/dag_processing/test_processor.py             | 749 +++++++++++++++++++++
 tests/jobs/test_scheduler_job.py                   | 700 +------------------
 tests/test_utils/perf/perf_kit/python.py           |   2 +-
 tests/test_utils/perf/perf_kit/sqlalchemy.py       |   2 +-
 11 files changed, 1456 insertions(+), 1316 deletions(-)
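
For callers that imported these classes directly, the change is the import path;
a short before/after sketch, assuming an Airflow installation that includes this
backport (the removed paths can be seen in the scheduler_job.py hunk further down):

    # Before this change:
    #   from airflow.jobs.scheduler_job import DagFileProcessor, DagFileProcessorProcess
    #   from airflow.utils.dag_processing import DagFileProcessorAgent

    # After this change:
    from airflow.dag_processing.processor import DagFileProcessor, DagFileProcessorProcess
    from airflow.dag_processing.manager import DagFileProcessorAgent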

diff --git a/.github/boring-cyborg.yml b/.github/boring-cyborg.yml
index d5f7632..8ae0532 100644
--- a/.github/boring-cyborg.yml
+++ b/.github/boring-cyborg.yml
@@ -157,7 +157,7 @@ labelPRBasedOnFilePath:
     - airflow/executors/**/*
     - airflow/jobs/**/*
     - airflow/task/task_runner/**/*
-    - airflow/utils/dag_processing.py
+    - airflow/dag_processing/**/*
     - docs/apache-airflow/executor/**/*
     - docs/apache-airflow/scheduler.rst
     - tests/executors/**/*
diff --git a/airflow/dag_processing/__init__.py b/airflow/dag_processing/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/airflow/dag_processing/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/utils/dag_processing.py b/airflow/dag_processing/manager.py
similarity index 100%
rename from airflow/utils/dag_processing.py
rename to airflow/dag_processing/manager.py
diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
new file mode 100644
index 0000000..44dc5f2
--- /dev/null
+++ b/airflow/dag_processing/processor.py
@@ -0,0 +1,650 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import logging
+import multiprocessing
+import os
+import signal
+import threading
+from contextlib import redirect_stderr, redirect_stdout, suppress
+from datetime import timedelta
+from multiprocessing.connection import Connection as MultiprocessingConnection
+from typing import List, Optional, Set, Tuple
+
+from setproctitle import setproctitle  # pylint: disable=no-name-in-module
+from sqlalchemy import func, or_
+from sqlalchemy.orm.session import Session
+
+from airflow import models, settings
+from airflow.configuration import conf
+from airflow.dag_processing.manager import AbstractDagFileProcessorProcess
+from airflow.exceptions import AirflowException, TaskNotFound
+from airflow.models import DAG, DagModel, SlaMiss, errors
+from airflow.models.dagbag import DagBag
+from airflow.stats import Stats
+from airflow.utils import timezone
+from airflow.utils.callback_requests import (
+    CallbackRequest,
+    DagCallbackRequest,
+    SlaCallbackRequest,
+    TaskCallbackRequest,
+)
+from airflow.utils.email import get_email_address_list, send_email
+from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context
+from airflow.utils.mixins import MultiprocessingStartMethodMixin
+from airflow.utils.session import provide_session
+from airflow.utils.state import State
+
+TI = models.TaskInstance
+
+
+class DagFileProcessorProcess(AbstractDagFileProcessorProcess, LoggingMixin, MultiprocessingStartMethodMixin):
+    """Runs DAG processing in a separate process using DagFileProcessor
+
+    :param file_path: a Python file containing Airflow DAG definitions
+    :type file_path: str
+    :param pickle_dags: whether to serialize the DAG objects to the DB
+    :type pickle_dags: bool
+    :param dag_ids: If specified, only look at these DAG ID's
+    :type dag_ids: List[str]
+    :param callback_requests: failure callback to execute
+    :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest]
+    """
+
+    # Counter that increments every time an instance of this class is created
+    class_creation_counter = 0
+
+    def __init__(
+        self,
+        file_path: str,
+        pickle_dags: bool,
+        dag_ids: Optional[List[str]],
+        callback_requests: List[CallbackRequest],
+    ):
+        super().__init__()
+        self._file_path = file_path
+        self._pickle_dags = pickle_dags
+        self._dag_ids = dag_ids
+        self._callback_requests = callback_requests
+
+        # The process that was launched to process the given file.
+        self._process: Optional[multiprocessing.process.BaseProcess] = None
+        # The result of DagFileProcessor.process_file(file_path).
+        self._result: Optional[Tuple[int, int]] = None
+        # Whether the process is done running.
+        self._done = False
+        # When the process started.
+        self._start_time: Optional[datetime.datetime] = None
+        # This ID is used to uniquely name the process / thread that's launched
+        # by this processor instance
+        self._instance_id = DagFileProcessorProcess.class_creation_counter
+
+        self._parent_channel: Optional[MultiprocessingConnection] = None
+        DagFileProcessorProcess.class_creation_counter += 1
+
+    @property
+    def file_path(self) -> str:
+        return self._file_path
+
+    @staticmethod
+    def _run_file_processor(
+        result_channel: MultiprocessingConnection,
+        parent_channel: MultiprocessingConnection,
+        file_path: str,
+        pickle_dags: bool,
+        dag_ids: Optional[List[str]],
+        thread_name: str,
+        callback_requests: List[CallbackRequest],
+    ) -> None:
+        """
+        Process the given file.
+
+        :param result_channel: the connection to use for passing back the result
+        :type result_channel: multiprocessing.Connection
+        :param parent_channel: the parent end of the channel to close in the child
+        :type parent_channel: multiprocessing.Connection
+        :param file_path: the file to process
+        :type file_path: str
+        :param pickle_dags: whether to pickle the DAGs found in the file and
+            save them to the DB
+        :type pickle_dags: bool
+        :param dag_ids: if specified, only examine DAG ID's that are
+            in this list
+        :type dag_ids: list[str]
+        :param thread_name: the name to use for the process that is launched
+        :type thread_name: str
+        :param callback_requests: failure callback to execute
+        :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest]
+        :return: the process that was launched
+        :rtype: multiprocessing.Process
+        """
+        # This helper runs in the newly created process
+        log: logging.Logger = logging.getLogger("airflow.processor")
+
+        # Since we share all open FDs from the parent, we need to close the parent side of the pipe here in
+        # the child, else it won't get closed properly until we exit.
+        log.info("Closing parent pipe")
+
+        parent_channel.close()
+        del parent_channel
+
+        set_context(log, file_path)
+        setproctitle(f"airflow scheduler - DagFileProcessor {file_path}")
+
+        try:
+            # redirect stdout/stderr to log
+            with redirect_stdout(StreamLogWriter(log, logging.INFO)), redirect_stderr(
+                StreamLogWriter(log, logging.WARN)
+            ), Stats.timer() as timer:
+                # Re-configure the ORM engine as there are issues with multiple processes
+                settings.configure_orm()
+
+                # Change the thread name to differentiate log lines. This is
+                # really a separate process, but changing the name of the
+                # process doesn't work, so changing the thread name instead.
+                threading.current_thread().name = thread_name
+
+                log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path)
+                dag_file_processor = DagFileProcessor(dag_ids=dag_ids, log=log)
+                result: Tuple[int, int] = dag_file_processor.process_file(
+                    file_path=file_path,
+                    pickle_dags=pickle_dags,
+                    callback_requests=callback_requests,
+                )
+                result_channel.send(result)
+            log.info("Processing %s took %.3f seconds", file_path, timer.duration)
+        except Exception:  # pylint: disable=broad-except
+            # Log exceptions through the logging framework.
+            log.exception("Got an exception! Propagating...")
+            raise
+        finally:
+            # We re-initialized the ORM within this Process above so we need to
+            # tear it down manually here
+            settings.dispose_orm()
+
+            result_channel.close()
+
+    def start(self) -> None:
+        """Launch the process and start processing the DAG."""
+        start_method = self._get_multiprocessing_start_method()
+        context = multiprocessing.get_context(start_method)
+
+        _parent_channel, _child_channel = context.Pipe(duplex=False)
+        process = context.Process(
+            target=type(self)._run_file_processor,
+            args=(
+                _child_channel,
+                _parent_channel,
+                self.file_path,
+                self._pickle_dags,
+                self._dag_ids,
+                f"DagFileProcessor{self._instance_id}",
+                self._callback_requests,
+            ),
+            name=f"DagFileProcessor{self._instance_id}-Process",
+        )
+        self._process = process
+        self._start_time = timezone.utcnow()
+        process.start()
+
+        # Close the child side of the pipe now the subprocess has started -- otherwise this would prevent it
+        # from closing in some cases
+        _child_channel.close()
+        del _child_channel
+
+        # Don't store it on self until after we've started the child process - we don't want to keep it from
+        # getting GCd/closed
+        self._parent_channel = _parent_channel
+
+    def kill(self) -> None:
+        """Kill the process launched to process the file, and ensure consistent state."""
+        if self._process is None:
+            raise AirflowException("Tried to kill before starting!")
+        self._kill_process()
+
+    def terminate(self, sigkill: bool = False) -> None:
+        """
+        Terminate (and then kill) the process launched to process the file.
+
+        :param sigkill: whether to issue a SIGKILL if SIGTERM doesn't work.
+        :type sigkill: bool
+        """
+        if self._process is None or self._parent_channel is None:
+            raise AirflowException("Tried to call terminate before starting!")
+
+        self._process.terminate()
+        # Arbitrarily wait 5s for the process to die
+        with suppress(TimeoutError):
+            self._process._popen.wait(5)  # type: ignore  # pylint: disable=protected-access
+        if sigkill:
+            self._kill_process()
+        self._parent_channel.close()
+
+    def _kill_process(self) -> None:
+        if self._process is None:
+            raise AirflowException("Tried to kill process before starting!")
+
+        if self._process.is_alive() and self._process.pid:
+            self.log.warning("Killing DAGFileProcessorProcess (PID=%d)", self._process.pid)
+            os.kill(self._process.pid, signal.SIGKILL)
+        if self._parent_channel:
+            self._parent_channel.close()
+
+    @property
+    def pid(self) -> int:
+        """
+        :return: the PID of the process launched to process the given file
+        :rtype: int
+        """
+        if self._process is None or self._process.pid is None:
+            raise AirflowException("Tried to get PID before starting!")
+        return self._process.pid
+
+    @property
+    def exit_code(self) -> Optional[int]:
+        """
+        After the process is finished, this can be called to get the return code
+
+        :return: the exit code of the process
+        :rtype: int
+        """
+        if self._process is None:
+            raise AirflowException("Tried to get exit code before starting!")
+        if not self._done:
+            raise AirflowException("Tried to call retcode before process was finished!")
+        return self._process.exitcode
+
+    @property
+    def done(self) -> bool:
+        """
+        Check if the process launched to process this file is done.
+
+        :return: whether the process is finished running
+        :rtype: bool
+        """
+        if self._process is None or self._parent_channel is None:
+            raise AirflowException("Tried to see if it's done before starting!")
+
+        if self._done:
+            return True
+
+        if self._parent_channel.poll():
+            try:
+                self._result = self._parent_channel.recv()
+                self._done = True
+                self.log.debug("Waiting for %s", self._process)
+                self._process.join()
+                self._parent_channel.close()
+                return True
+            except EOFError:
+                # If we get an EOFError, it means the child end of the pipe has been closed. This only happens
+                # in the finally block. But due to a possible race condition, the process may have not yet
+                # terminated (it could be doing cleanup/python shutdown still). So we kill it here after a
+                # "suitable" timeout.
+                self._done = True
+                # Arbitrary timeout -- error/race condition only, so this doesn't need to be tunable.
+                self._process.join(timeout=5)
+                if self._process.is_alive():
+                    # Didn't shut down cleanly - kill it
+                    self._kill_process()
+
+        if not self._process.is_alive():
+            self._done = True
+            self.log.debug("Waiting for %s", self._process)
+            self._process.join()
+            self._parent_channel.close()
+            return True
+
+        return False
+
+    @property
+    def result(self) -> Optional[Tuple[int, int]]:
+        """
+        :return: result of running DagFileProcessor.process_file()
+        :rtype: tuple[int, int] or None
+        """
+        if not self.done:
+            raise AirflowException("Tried to get the result before it's done!")
+        return self._result
+
+    @property
+    def start_time(self) -> datetime.datetime:
+        """
+        :return: when this started to process the file
+        :rtype: datetime
+        """
+        if self._start_time is None:
+            raise AirflowException("Tried to get start time before it started!")
+        return self._start_time
+
+    @property
+    def waitable_handle(self):
+        return self._process.sentinel
+
+
+class DagFileProcessor(LoggingMixin):
+    """
+    Process a Python file containing Airflow DAGs.
+
+    This includes:
+
+    1. Execute the file and look for DAG objects in the namespace.
+    2. Execute any Callbacks if passed to DagFileProcessor.process_file
+    3. Serialize the DAGs and save it to DB (or update existing record in the DB).
+    4. Pickle the DAG and save it to the DB (if necessary).
+    5. Record any errors importing the file into ORM
+
+    Returns a tuple of 'number of dags found' and 'the count of import errors'
+
+    :param dag_ids: If specified, only look at these DAG ID's
+    :type dag_ids: List[str]
+    :param log: Logger to save the processing process
+    :type log: logging.Logger
+    """
+
+    UNIT_TEST_MODE: bool = conf.getboolean('core', 'UNIT_TEST_MODE')
+
+    def __init__(self, dag_ids: Optional[List[str]], log: logging.Logger):
+        super().__init__()
+        self.dag_ids = dag_ids
+        self._log = log
+
+    @provide_session
+    def manage_slas(self, dag: DAG, session: Session = None) -> None:
+        """
+        Finding all tasks that have SLAs defined, and sending alert emails
+        where needed. New SLA misses are also recorded in the database.
+
+        We are assuming that the scheduler runs often, so we only check for
+        tasks that should have succeeded in the past hour.
+        """
+        self.log.info("Running SLA Checks for %s", dag.dag_id)
+        if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks):
+            self.log.info("Skipping SLA check for %s because no tasks in DAG have SLAs", dag)
+            return
+
+        qry = (
+            session.query(TI.task_id, func.max(TI.execution_date).label('max_ti'))
+            .with_hint(TI, 'USE INDEX (PRIMARY)', dialect_name='mysql')
+            .filter(TI.dag_id == dag.dag_id)
+            .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED))
+            .filter(TI.task_id.in_(dag.task_ids))
+            .group_by(TI.task_id)
+            .subquery('sq')
+        )
+
+        max_tis: List[TI] = (
+            session.query(TI)
+            .filter(
+                TI.dag_id == dag.dag_id,
+                TI.task_id == qry.c.task_id,
+                TI.execution_date == qry.c.max_ti,
+            )
+            .all()
+        )
+
+        ts = timezone.utcnow()
+        for ti in max_tis:
+            task = dag.get_task(ti.task_id)
+            if task.sla and not isinstance(task.sla, timedelta):
+                raise TypeError(
+                    f"SLA is expected to be timedelta object, got "
+                    f"{type(task.sla)} in {task.dag_id}:{task.task_id}"
+                )
+
+            dttm = dag.following_schedule(ti.execution_date)
+            while dttm < timezone.utcnow():
+                following_schedule = dag.following_schedule(dttm)
+                if following_schedule + task.sla < timezone.utcnow():
+                    session.merge(
+                        SlaMiss(task_id=ti.task_id, dag_id=ti.dag_id, execution_date=dttm, timestamp=ts)
+                    )
+                dttm = dag.following_schedule(dttm)
+        session.commit()
+
+        # pylint: disable=singleton-comparison
+        slas: List[SlaMiss] = (
+            session.query(SlaMiss)
+            .filter(SlaMiss.notification_sent == False, SlaMiss.dag_id == dag.dag_id)  # noqa
+            .all()
+        )
+        # pylint: enable=singleton-comparison
+
+        if slas:  # pylint: disable=too-many-nested-blocks
+            sla_dates: List[datetime.datetime] = [sla.execution_date for sla in slas]
+            fetched_tis: List[TI] = (
+                session.query(TI)
+                .filter(TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id)
+                .all()
+            )
+            blocking_tis: List[TI] = []
+            for ti in fetched_tis:
+                if ti.task_id in dag.task_ids:
+                    ti.task = dag.get_task(ti.task_id)
+                    blocking_tis.append(ti)
+                else:
+                    session.delete(ti)
+                    session.commit()
+
+            task_list = "\n".join(sla.task_id + ' on ' + sla.execution_date.isoformat() for sla in slas)
+            blocking_task_list = "\n".join(
+                ti.task_id + ' on ' + ti.execution_date.isoformat() for ti in blocking_tis
+            )
+            # Track whether email or any alert notification sent
+            # We consider email or the alert callback as notifications
+            email_sent = False
+            notification_sent = False
+            if dag.sla_miss_callback:
+                # Execute the alert callback
+                self.log.info('Calling SLA miss callback')
+                try:
+                    dag.sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis)
+                    notification_sent = True
+                except Exception:  # pylint: disable=broad-except
+                    self.log.exception("Could not call sla_miss_callback for DAG %s", dag.dag_id)
+            email_content = f"""\
+            Here's a list of tasks that missed their SLAs:
+            <pre><code>{task_list}\n<code></pre>
+            Blocking tasks:
+            <pre><code>{blocking_task_list}<code></pre>
+            Airflow Webserver URL: {conf.get(section='webserver', key='base_url')}
+            """
+
+            tasks_missed_sla = []
+            for sla in slas:
+                try:
+                    task = dag.get_task(sla.task_id)
+                except TaskNotFound:
+                    # task already deleted from DAG, skip it
+                    self.log.warning(
+                        "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.", sla.task_id
+                    )
+                    continue
+                tasks_missed_sla.append(task)
+
+            emails: Set[str] = set()
+            for task in tasks_missed_sla:
+                if task.email:
+                    if isinstance(task.email, str):
+                        emails |= set(get_email_address_list(task.email))
+                    elif isinstance(task.email, (list, tuple)):
+                        emails |= set(task.email)
+            if emails:
+                try:
+                    send_email(emails, f"[airflow] SLA miss on DAG={dag.dag_id}", email_content)
+                    email_sent = True
+                    notification_sent = True
+                except Exception:  # pylint: disable=broad-except
+                    Stats.incr('sla_email_notification_failure')
+                    self.log.exception("Could not send SLA Miss email notification for DAG %s", dag.dag_id)
+            # If we sent any notification, update the sla_miss table
+            if notification_sent:
+                for sla in slas:
+                    sla.email_sent = email_sent
+                    sla.notification_sent = True
+                    session.merge(sla)
+            session.commit()
+
+    @staticmethod
+    def update_import_errors(session: Session, dagbag: DagBag) -> None:
+        """
+        For the DAGs in the given DagBag, record any associated import errors and clear
+        errors for files that no longer have them. These are usually displayed through the
+        Airflow UI so that users know that there are issues parsing DAGs.
+
+        :param session: session for ORM operations
+        :type session: sqlalchemy.orm.session.Session
+        :param dagbag: DagBag containing DAGs with import errors
+        :type dagbag: airflow.DagBag
+        """
+        # Clear the errors of the processed files
+        for dagbag_file in dagbag.file_last_changed:
+            session.query(errors.ImportError).filter(errors.ImportError.filename == dagbag_file).delete()
+
+        # Add the errors of the processed files
+        for filename, stacktrace in dagbag.import_errors.items():
+            session.add(
+                errors.ImportError(filename=filename, timestamp=timezone.utcnow(), stacktrace=stacktrace)
+            )
+        session.commit()
+
+    @provide_session
+    def execute_callbacks(
+        self, dagbag: DagBag, callback_requests: List[CallbackRequest], session: Session = None
+    ) -> None:
+        """
+        Execute on failure callbacks. These objects can come from SchedulerJob or from
+        DagFileProcessorManager.
+
+        :param dagbag: Dag Bag of dags
+        :param callback_requests: failure callbacks to execute
+        :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest]
+        :param session: DB session.
+        """
+        for request in callback_requests:
+            self.log.debug("Processing Callback Request: %s", request)
+            try:
+                if isinstance(request, TaskCallbackRequest):
+                    self._execute_task_callbacks(dagbag, request)
+                elif isinstance(request, SlaCallbackRequest):
+                    self.manage_slas(dagbag.dags.get(request.dag_id))
+                elif isinstance(request, DagCallbackRequest):
+                    self._execute_dag_callbacks(dagbag, request, session)
+            except Exception:  # pylint: disable=broad-except
+                self.log.exception(
+                    "Error executing %s callback for file: %s",
+                    request.__class__.__name__,
+                    request.full_filepath,
+                )
+
+        session.commit()
+
+    @provide_session
+    def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, session: Session):
+        dag = dagbag.dags[request.dag_id]
+        dag_run = dag.get_dagrun(execution_date=request.execution_date, session=session)
+        dag.handle_callback(
+            dagrun=dag_run, success=not request.is_failure_callback, reason=request.msg, session=session
+        )
+
+    def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest):
+        simple_ti = request.simple_task_instance
+        if simple_ti.dag_id in dagbag.dags:
+            dag = dagbag.dags[simple_ti.dag_id]
+            if simple_ti.task_id in dag.task_ids:
+                task = dag.get_task(simple_ti.task_id)
+                ti = TI(task, simple_ti.execution_date)
+                # Get properties needed for failure handling from SimpleTaskInstance.
+                ti.start_date = simple_ti.start_date
+                ti.end_date = simple_ti.end_date
+                ti.try_number = simple_ti.try_number
+                ti.state = simple_ti.state
+                ti.test_mode = self.UNIT_TEST_MODE
+                if request.is_failure_callback:
+                    ti.handle_failure_with_callback(error=request.msg, test_mode=ti.test_mode)
+                    self.log.info('Executed failure callback for %s in state %s', ti, ti.state)
+
+    @provide_session
+    def process_file(
+        self,
+        file_path: str,
+        callback_requests: List[CallbackRequest],
+        pickle_dags: bool = False,
+        session: Session = None,
+    ) -> Tuple[int, int]:
+        """
+        Process a Python file containing Airflow DAGs.
+
+        This includes:
+
+        1. Execute the file and look for DAG objects in the namespace.
+        2. Execute any Callbacks if passed to this method.
+        3. Serialize the DAGs and save it to DB (or update existing record in the DB).
+        4. Pickle the DAG and save it to the DB (if necessary).
+        5. Record any errors importing the file into ORM
+
+        :param file_path: the path to the Python file that should be executed
+        :type file_path: str
+        :param callback_requests: failure callback to execute
+        :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest]
+        :param pickle_dags: whether to serialize the DAGs found in the file and
+            save them to the db
+        :type pickle_dags: bool
+        :param session: Sqlalchemy ORM Session
+        :type session: Session
+        :return: number of dags found, count of import errors
+        :rtype: Tuple[int, int]
+        """
+        self.log.info("Processing file %s for tasks to queue", file_path)
+
+        try:
+            dagbag = DagBag(file_path, include_examples=False, include_smart_sensor=False)
+        except Exception:  # pylint: disable=broad-except
+            self.log.exception("Failed at reloading the DAG file %s", file_path)
+            Stats.incr('dag_file_refresh_error', 1, 1)
+            return 0, 0
+
+        if len(dagbag.dags) > 0:
+            self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(), file_path)
+        else:
+            self.log.warning("No viable dags retrieved from %s", file_path)
+            self.update_import_errors(session, dagbag)
+            return 0, len(dagbag.import_errors)
+
+        self.execute_callbacks(dagbag, callback_requests)
+
+        # Save individual DAGs in the ORM
+        dagbag.sync_to_db()
+
+        if pickle_dags:
+            paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids)
+
+            unpaused_dags: List[DAG] = [
+                dag for dag_id, dag in dagbag.dags.items() if dag_id not in paused_dag_ids
+            ]
+
+            for dag in unpaused_dags:
+                dag.pickle(session)
+
+        # Record import errors into the ORM
+        try:
+            self.update_import_errors(session, dagbag)
+        except Exception:  # pylint: disable=broad-except
+            self.log.exception("Error logging import errors!")
+
+        return len(dagbag.dags), len(dagbag.import_errors)
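
The docstrings above spell out the contract of the relocated DagFileProcessor. As a
rough usage sketch, assuming a fully configured Airflow environment (metadata DB
initialised) and an illustrative DAG file path that is not part of the commit:

    import logging

    from airflow.dag_processing.processor import DagFileProcessor

    log = logging.getLogger("airflow.processor")
    processor = DagFileProcessor(dag_ids=None, log=log)

    # process_file() parses the file, runs any callbacks passed in, syncs the
    # serialized DAGs to the DB, and returns (dags found, import error count).
    dags_found, import_errors = processor.process_file(
        file_path="/opt/airflow/dags/example_dag.py",  # illustrative path
        callback_requests=[],
    )
    log.info("Found %d dag(s), %d import error(s)", dags_found, import_errors)
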
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index fe8e0b0..5b24e00 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -23,15 +23,11 @@ import multiprocessing
 import os
 import signal
 import sys
-import threading
 import time
 from collections import defaultdict
-from contextlib import redirect_stderr, redirect_stdout, suppress
 from datetime import timedelta
-from multiprocessing.connection import Connection as MultiprocessingConnection
 from typing import DefaultDict, Dict, Iterable, List, Optional, Set, Tuple
 
-from setproctitle import setproctitle
 from sqlalchemy import and_, func, not_, or_, tuple_
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import load_only, selectinload
@@ -39,10 +35,13 @@ from sqlalchemy.orm.session import Session, make_transient
 
 from airflow import models, settings
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, SerializedDagNotFound, TaskNotFound
+from airflow.dag_processing.manager import DagFileProcessorAgent
+from airflow.dag_processing.processor import DagFileProcessorProcess
+from airflow.exceptions import SerializedDagNotFound
 from airflow.executors.executor_loader import UNPICKLEABLE_EXECUTORS
 from airflow.jobs.base_job import BaseJob
-from airflow.models import DAG, DagModel, SlaMiss, errors
+from airflow.models import DAG
+from airflow.models.dag import DagModel
 from airflow.models.dagbag import DagBag
 from airflow.models.dagrun import DagRun
 from airflow.models.serialized_dag import SerializedDagModel
@@ -50,17 +49,8 @@ from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey
 from airflow.stats import Stats
 from airflow.ti_deps.dependencies_states import EXECUTION_STATES
 from airflow.utils import timezone
-from airflow.utils.callback_requests import (
-    CallbackRequest,
-    DagCallbackRequest,
-    SlaCallbackRequest,
-    TaskCallbackRequest,
-)
-from airflow.utils.dag_processing import AbstractDagFileProcessorProcess, DagFileProcessorAgent
-from airflow.utils.email import get_email_address_list, send_email
+from airflow.utils.callback_requests import CallbackRequest, DagCallbackRequest, TaskCallbackRequest
 from airflow.utils.event_scheduler import EventScheduler
-from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context
-from airflow.utils.mixins import MultiprocessingStartMethodMixin
 from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries
 from airflow.utils.session import create_session, provide_session
 from airflow.utils.sqlalchemy import is_lock_not_available_error, prohibit_commit, skip_locked, with_row_locks
@@ -72,603 +62,6 @@ DR = models.DagRun
 DM = models.DagModel
 
 
-class DagFileProcessorProcess(AbstractDagFileProcessorProcess, LoggingMixin, MultiprocessingStartMethodMixin):
-    """Runs DAG processing in a separate process using DagFileProcessor
-
-    :param file_path: a Python file containing Airflow DAG definitions
-    :type file_path: str
-    :param pickle_dags: whether to serialize the DAG objects to the DB
-    :type pickle_dags: bool
-    :param dag_ids: If specified, only look at these DAG ID's
-    :type dag_ids: List[str]
-    :param callback_requests: failure callback to execute
-    :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest]
-    """
-
-    # Counter that increments every time an instance of this class is created
-    class_creation_counter = 0
-
-    def __init__(
-        self,
-        file_path: str,
-        pickle_dags: bool,
-        dag_ids: Optional[List[str]],
-        callback_requests: List[CallbackRequest],
-    ):
-        super().__init__()
-        self._file_path = file_path
-        self._pickle_dags = pickle_dags
-        self._dag_ids = dag_ids
-        self._callback_requests = callback_requests
-
-        # The process that was launched to process the given .
-        self._process: Optional[multiprocessing.process.BaseProcess] = None
-        # The result of DagFileProcessor.process_file(file_path).
-        self._result: Optional[Tuple[int, int]] = None
-        # Whether the process is done running.
-        self._done = False
-        # When the process started.
-        self._start_time: Optional[datetime.datetime] = None
-        # This ID is use to uniquely name the process / thread that's launched
-        # by this processor instance
-        self._instance_id = DagFileProcessorProcess.class_creation_counter
-
-        self._parent_channel: Optional[MultiprocessingConnection] = None
-        DagFileProcessorProcess.class_creation_counter += 1
-
-    @property
-    def file_path(self) -> str:
-        return self._file_path
-
-    @staticmethod
-    def _run_file_processor(
-        result_channel: MultiprocessingConnection,
-        parent_channel: MultiprocessingConnection,
-        file_path: str,
-        pickle_dags: bool,
-        dag_ids: Optional[List[str]],
-        thread_name: str,
-        callback_requests: List[CallbackRequest],
-    ) -> None:
-        """
-        Process the given file.
-
-        :param result_channel: the connection to use for passing back the result
-        :type result_channel: multiprocessing.Connection
-        :param parent_channel: the parent end of the channel to close in the child
-        :type parent_channel: multiprocessing.Connection
-        :param file_path: the file to process
-        :type file_path: str
-        :param pickle_dags: whether to pickle the DAGs found in the file and
-            save them to the DB
-        :type pickle_dags: bool
-        :param dag_ids: if specified, only examine DAG ID's that are
-            in this list
-        :type dag_ids: list[str]
-        :param thread_name: the name to use for the process that is launched
-        :type thread_name: str
-        :param callback_requests: failure callback to execute
-        :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest]
-        :return: the process that was launched
-        :rtype: multiprocessing.Process
-        """
-        # This helper runs in the newly created process
-        log: logging.Logger = logging.getLogger("airflow.processor")
-
-        # Since we share all open FDs from the parent, we need to close the parent side of the pipe here in
-        # the child, else it won't get closed properly until we exit.
-        log.info("Closing parent pipe")
-
-        parent_channel.close()
-        del parent_channel
-
-        set_context(log, file_path)
-        setproctitle(f"airflow scheduler - DagFileProcessor {file_path}")
-
-        try:
-            # redirect stdout/stderr to log
-            with redirect_stdout(StreamLogWriter(log, logging.INFO)), redirect_stderr(
-                StreamLogWriter(log, logging.WARN)
-            ), Stats.timer() as timer:
-                # Re-configure the ORM engine as there are issues with multiple processes
-                settings.configure_orm()
-
-                # Change the thread name to differentiate log lines. This is
-                # really a separate process, but changing the name of the
-                # process doesn't work, so changing the thread name instead.
-                threading.current_thread().name = thread_name
-
-                log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path)
-                dag_file_processor = DagFileProcessor(dag_ids=dag_ids, log=log)
-                result: Tuple[int, int] = dag_file_processor.process_file(
-                    file_path=file_path,
-                    pickle_dags=pickle_dags,
-                    callback_requests=callback_requests,
-                )
-                result_channel.send(result)
-            log.info("Processing %s took %.3f seconds", file_path, timer.duration)
-        except Exception:  # pylint: disable=broad-except
-            # Log exceptions through the logging framework.
-            log.exception("Got an exception! Propagating...")
-            raise
-        finally:
-            # We re-initialized the ORM within this Process above so we need to
-            # tear it down manually here
-            settings.dispose_orm()
-
-            result_channel.close()
-
-    def start(self) -> None:
-        """Launch the process and start processing the DAG."""
-        start_method = self._get_multiprocessing_start_method()
-        context = multiprocessing.get_context(start_method)
-
-        _parent_channel, _child_channel = context.Pipe(duplex=False)
-        process = context.Process(
-            target=type(self)._run_file_processor,
-            args=(
-                _child_channel,
-                _parent_channel,
-                self.file_path,
-                self._pickle_dags,
-                self._dag_ids,
-                f"DagFileProcessor{self._instance_id}",
-                self._callback_requests,
-            ),
-            name=f"DagFileProcessor{self._instance_id}-Process",
-        )
-        self._process = process
-        self._start_time = timezone.utcnow()
-        process.start()
-
-        # Close the child side of the pipe now the subprocess has started -- otherwise this would prevent it
-        # from closing in some cases
-        _child_channel.close()
-        del _child_channel
-
-        # Don't store it on self until after we've started the child process - we don't want to keep it from
-        # getting GCd/closed
-        self._parent_channel = _parent_channel
-
-    def kill(self) -> None:
-        """Kill the process launched to process the file, and ensure consistent state."""
-        if self._process is None:
-            raise AirflowException("Tried to kill before starting!")
-        self._kill_process()
-
-    def terminate(self, sigkill: bool = False) -> None:
-        """
-        Terminate (and then kill) the process launched to process the file.
-
-        :param sigkill: whether to issue a SIGKILL if SIGTERM doesn't work.
-        :type sigkill: bool
-        """
-        if self._process is None or self._parent_channel is None:
-            raise AirflowException("Tried to call terminate before starting!")
-
-        self._process.terminate()
-        # Arbitrarily wait 5s for the process to die
-        with suppress(TimeoutError):
-            self._process._popen.wait(5)  # type: ignore  # pylint: disable=protected-access
-        if sigkill:
-            self._kill_process()
-        self._parent_channel.close()
-
-    def _kill_process(self) -> None:
-        if self._process is None:
-            raise AirflowException("Tried to kill process before starting!")
-
-        if self._process.is_alive() and self._process.pid:
-            self.log.warning("Killing DAGFileProcessorProcess (PID=%d)", self._process.pid)
-            os.kill(self._process.pid, signal.SIGKILL)
-        if self._parent_channel:
-            self._parent_channel.close()
-
-    @property
-    def pid(self) -> int:
-        """
-        :return: the PID of the process launched to process the given file
-        :rtype: int
-        """
-        if self._process is None or self._process.pid is None:
-            raise AirflowException("Tried to get PID before starting!")
-        return self._process.pid
-
-    @property
-    def exit_code(self) -> Optional[int]:
-        """
-        After the process is finished, this can be called to get the return code
-
-        :return: the exit code of the process
-        :rtype: int
-        """
-        if self._process is None:
-            raise AirflowException("Tried to get exit code before starting!")
-        if not self._done:
-            raise AirflowException("Tried to call retcode before process was finished!")
-        return self._process.exitcode
-
-    @property
-    def done(self) -> bool:
-        """
-        Check if the process launched to process this file is done.
-
-        :return: whether the process is finished running
-        :rtype: bool
-        """
-        if self._process is None or self._parent_channel is None:
-            raise AirflowException("Tried to see if it's done before starting!")
-
-        if self._done:
-            return True
-
-        if self._parent_channel.poll():
-            try:
-                self._result = self._parent_channel.recv()
-                self._done = True
-                self.log.debug("Waiting for %s", self._process)
-                self._process.join()
-                self._parent_channel.close()
-                return True
-            except EOFError:
-                # If we get an EOFError, it means the child end of the pipe has been closed. This only happens
-                # in the finally block. But due to a possible race condition, the process may have not yet
-                # terminated (it could be doing cleanup/python shutdown still). So we kill it here after a
-                # "suitable" timeout.
-                self._done = True
-                # Arbitrary timeout -- error/race condition only, so this doesn't need to be tunable.
-                self._process.join(timeout=5)
-                if self._process.is_alive():
-                    # Didn't shut down cleanly - kill it
-                    self._kill_process()
-
-        if not self._process.is_alive():
-            self._done = True
-            self.log.debug("Waiting for %s", self._process)
-            self._process.join()
-            self._parent_channel.close()
-            return True
-
-        return False
-
-    @property
-    def result(self) -> Optional[Tuple[int, int]]:
-        """
-        :return: result of running DagFileProcessor.process_file()
-        :rtype: tuple[int, int] or None
-        """
-        if not self.done:
-            raise AirflowException("Tried to get the result before it's done!")
-        return self._result
-
-    @property
-    def start_time(self) -> datetime.datetime:
-        """
-        :return: when this started to process the file
-        :rtype: datetime
-        """
-        if self._start_time is None:
-            raise AirflowException("Tried to get start time before it started!")
-        return self._start_time
-
-    @property
-    def waitable_handle(self):
-        return self._process.sentinel
-
-
-class DagFileProcessor(LoggingMixin):
-    """
-    Process a Python file containing Airflow DAGs.
-
-    This includes:
-
-    1. Execute the file and look for DAG objects in the namespace.
-    2. Execute any Callbacks if passed to DagFileProcessor.process_file
-    3. Serialize the DAGs and save it to DB (or update existing record in the DB).
-    4. Pickle the DAG and save it to the DB (if necessary).
-    5. Record any errors importing the file into ORM
-
-    Returns a tuple of 'number of dags found' and 'the count of import errors'
-
-    :param dag_ids: If specified, only look at these DAG ID's
-    :type dag_ids: List[str]
-    :param log: Logger to save the processing process
-    :type log: logging.Logger
-    """
-
-    UNIT_TEST_MODE: bool = conf.getboolean('core', 'UNIT_TEST_MODE')
-
-    def __init__(self, dag_ids: Optional[List[str]], log: logging.Logger):
-        super().__init__()
-        self.dag_ids = dag_ids
-        self._log = log
-
-    @provide_session
-    def manage_slas(self, dag: DAG, session: Session = None) -> None:
-        """
-        Finding all tasks that have SLAs defined, and sending alert emails
-        where needed. New SLA misses are also recorded in the database.
-
-        We are assuming that the scheduler runs often, so we only check for
-        tasks that should have succeeded in the past hour.
-        """
-        self.log.info("Running SLA Checks for %s", dag.dag_id)
-        if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks):
-            self.log.info("Skipping SLA check for %s because no tasks in DAG have SLAs", dag)
-            return
-
-        qry = (
-            session.query(TI.task_id, func.max(TI.execution_date).label('max_ti'))
-            .with_hint(TI, 'USE INDEX (PRIMARY)', dialect_name='mysql')
-            .filter(TI.dag_id == dag.dag_id)
-            .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED))
-            .filter(TI.task_id.in_(dag.task_ids))
-            .group_by(TI.task_id)
-            .subquery('sq')
-        )
-
-        max_tis: List[TI] = (
-            session.query(TI)
-            .filter(
-                TI.dag_id == dag.dag_id,
-                TI.task_id == qry.c.task_id,
-                TI.execution_date == qry.c.max_ti,
-            )
-            .all()
-        )
-
-        ts = timezone.utcnow()
-        for ti in max_tis:
-            task = dag.get_task(ti.task_id)
-            if task.sla and not isinstance(task.sla, timedelta):
-                raise TypeError(
-                    f"SLA is expected to be timedelta object, got "
-                    f"{type(task.sla)} in {task.dag_id}:{task.task_id}"
-                )
-
-            dttm = dag.following_schedule(ti.execution_date)
-            while dttm < timezone.utcnow():
-                following_schedule = dag.following_schedule(dttm)
-                if following_schedule + task.sla < timezone.utcnow():
-                    session.merge(
-                        SlaMiss(task_id=ti.task_id, dag_id=ti.dag_id, execution_date=dttm, timestamp=ts)
-                    )
-                dttm = dag.following_schedule(dttm)
-        session.commit()
-
-        # pylint: disable=singleton-comparison
-        slas: List[SlaMiss] = (
-            session.query(SlaMiss)
-            .filter(SlaMiss.notification_sent == False, SlaMiss.dag_id == dag.dag_id)  # noqa
-            .all()
-        )
-        # pylint: enable=singleton-comparison
-
-        if slas:  # pylint: disable=too-many-nested-blocks
-            sla_dates: List[datetime.datetime] = [sla.execution_date for sla in slas]
-            fetched_tis: List[TI] = (
-                session.query(TI)
-                .filter(TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id)
-                .all()
-            )
-            blocking_tis: List[TI] = []
-            for ti in fetched_tis:
-                if ti.task_id in dag.task_ids:
-                    ti.task = dag.get_task(ti.task_id)
-                    blocking_tis.append(ti)
-                else:
-                    session.delete(ti)
-                    session.commit()
-
-            task_list = "\n".join(sla.task_id + ' on ' + sla.execution_date.isoformat() for sla in slas)
-            blocking_task_list = "\n".join(
-                ti.task_id + ' on ' + ti.execution_date.isoformat() for ti in blocking_tis
-            )
-            # Track whether email or any alert notification sent
-            # We consider email or the alert callback as notifications
-            email_sent = False
-            notification_sent = False
-            if dag.sla_miss_callback:
-                # Execute the alert callback
-                self.log.info('Calling SLA miss callback')
-                try:
-                    dag.sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis)
-                    notification_sent = True
-                except Exception:  # pylint: disable=broad-except
-                    self.log.exception("Could not call sla_miss_callback for DAG %s", dag.dag_id)
-            email_content = f"""\
-            Here's a list of tasks that missed their SLAs:
-            <pre><code>{task_list}\n</code></pre>
-            Blocking tasks:
-            <pre><code>{blocking_task_list}</code></pre>
-            Airflow Webserver URL: {conf.get(section='webserver', key='base_url')}
-            """
-
-            tasks_missed_sla = []
-            for sla in slas:
-                try:
-                    task = dag.get_task(sla.task_id)
-                except TaskNotFound:
-                    # task already deleted from DAG, skip it
-                    self.log.warning(
-                        "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.", sla.task_id
-                    )
-                    continue
-                tasks_missed_sla.append(task)
-
-            emails: Set[str] = set()
-            for task in tasks_missed_sla:
-                if task.email:
-                    if isinstance(task.email, str):
-                        emails |= set(get_email_address_list(task.email))
-                    elif isinstance(task.email, (list, tuple)):
-                        emails |= set(task.email)
-            if emails:
-                try:
-                    send_email(emails, f"[airflow] SLA miss on DAG={dag.dag_id}", email_content)
-                    email_sent = True
-                    notification_sent = True
-                except Exception:  # pylint: disable=broad-except
-                    Stats.incr('sla_email_notification_failure')
-                    self.log.exception("Could not send SLA Miss email notification for DAG %s", dag.dag_id)
-            # If we sent any notification, update the sla_miss table
-            if notification_sent:
-                for sla in slas:
-                    sla.email_sent = email_sent
-                    sla.notification_sent = True
-                    session.merge(sla)
-            session.commit()
-
-    @staticmethod
-    def update_import_errors(session: Session, dagbag: DagBag) -> None:
-        """
-        For the DAGs in the given DagBag, record any associated import errors and clear
-        errors for files that no longer have them. These are usually displayed through the
-        Airflow UI so that users know that there are issues parsing DAGs.
-
-        :param session: session for ORM operations
-        :type session: sqlalchemy.orm.session.Session
-        :param dagbag: DagBag containing DAGs with import errors
-        :type dagbag: airflow.DagBag
-        """
-        # Clear the errors of the processed files
-        for dagbag_file in dagbag.file_last_changed:
-            session.query(errors.ImportError).filter(errors.ImportError.filename == dagbag_file).delete()
-
-        # Add the errors of the processed files
-        for filename, stacktrace in dagbag.import_errors.items():
-            session.add(
-                errors.ImportError(filename=filename, timestamp=timezone.utcnow(), stacktrace=stacktrace)
-            )
-        session.commit()
-
-    @provide_session
-    def execute_callbacks(
-        self, dagbag: DagBag, callback_requests: List[CallbackRequest], session: Session = None
-    ) -> None:
-        """
-        Execute on-failure callbacks. These callback requests can come from the SchedulerJob
-        or from the DagFileProcessorManager.
-
-        :param dagbag: Dag Bag of dags
-        :param callback_requests: failure callbacks to execute
-        :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest]
-        :param session: DB session.
-        """
-        for request in callback_requests:
-            self.log.debug("Processing Callback Request: %s", request)
-            try:
-                if isinstance(request, TaskCallbackRequest):
-                    self._execute_task_callbacks(dagbag, request)
-                elif isinstance(request, SlaCallbackRequest):
-                    self.manage_slas(dagbag.dags.get(request.dag_id))
-                elif isinstance(request, DagCallbackRequest):
-                    self._execute_dag_callbacks(dagbag, request, session)
-            except Exception:  # pylint: disable=broad-except
-                self.log.exception(
-                    "Error executing %s callback for file: %s",
-                    request.__class__.__name__,
-                    request.full_filepath,
-                )
-
-        session.commit()
-
-    @provide_session
-    def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, session: Session):
-        dag = dagbag.dags[request.dag_id]
-        dag_run = dag.get_dagrun(execution_date=request.execution_date, session=session)
-        dag.handle_callback(
-            dagrun=dag_run, success=not request.is_failure_callback, reason=request.msg, session=session
-        )
-
-    def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest):
-        simple_ti = request.simple_task_instance
-        if simple_ti.dag_id in dagbag.dags:
-            dag = dagbag.dags[simple_ti.dag_id]
-            if simple_ti.task_id in dag.task_ids:
-                task = dag.get_task(simple_ti.task_id)
-                ti = TI(task, simple_ti.execution_date)
-                # Get properties needed for failure handling from SimpleTaskInstance.
-                ti.start_date = simple_ti.start_date
-                ti.end_date = simple_ti.end_date
-                ti.try_number = simple_ti.try_number
-                ti.state = simple_ti.state
-                ti.test_mode = self.UNIT_TEST_MODE
-                if request.is_failure_callback:
-                    ti.handle_failure_with_callback(error=request.msg, test_mode=ti.test_mode)
-                    self.log.info('Executed failure callback for %s in state %s', ti, ti.state)
-
-    @provide_session
-    def process_file(
-        self,
-        file_path: str,
-        callback_requests: List[CallbackRequest],
-        pickle_dags: bool = False,
-        session: Session = None,
-    ) -> Tuple[int, int]:
-        """
-        Process a Python file containing Airflow DAGs.
-
-        This includes:
-
-        1. Execute the file and look for DAG objects in the namespace.
-        2. Execute any Callbacks if passed to this method.
-        3. Serialize the DAGs and save them to the DB (or update the existing records in the DB).
-        4. Pickle the DAGs and save them to the DB (if necessary).
-        5. Record any errors importing the file into the ORM.
-
-        :param file_path: the path to the Python file that should be executed
-        :type file_path: str
-        :param callback_requests: failure callbacks to execute
-        :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest]
-        :param pickle_dags: whether to serialize the DAGs found in the file and
-            save them to the DB
-        :type pickle_dags: bool
-        :param session: Sqlalchemy ORM Session
-        :type session: Session
-        :return: number of dags found, count of import errors
-        :rtype: Tuple[int, int]
-        """
-        self.log.info("Processing file %s for tasks to queue", file_path)
-
-        try:
-            dagbag = DagBag(file_path, include_examples=False, include_smart_sensor=False)
-        except Exception:  # pylint: disable=broad-except
-            self.log.exception("Failed at reloading the DAG file %s", file_path)
-            Stats.incr('dag_file_refresh_error', 1, 1)
-            return 0, 0
-
-        if len(dagbag.dags) > 0:
-            self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(), file_path)
-        else:
-            self.log.warning("No viable dags retrieved from %s", file_path)
-            self.update_import_errors(session, dagbag)
-            return 0, len(dagbag.import_errors)
-
-        self.execute_callbacks(dagbag, callback_requests)
-
-        # Save individual DAGs in the ORM
-        dagbag.sync_to_db()
-
-        if pickle_dags:
-            paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids)
-
-            unpaused_dags: List[DAG] = [
-                dag for dag_id, dag in dagbag.dags.items() if dag_id not in paused_dag_ids
-            ]
-
-            for dag in unpaused_dags:
-                dag.pickle(session)
-
-        # Record import errors into the ORM
-        try:
-            self.update_import_errors(session, dagbag)
-        except Exception:  # pylint: disable=broad-except
-            self.log.exception("Error logging import errors!")
-
-        return len(dagbag.dags), len(dagbag.import_errors)
-
-
 def _is_parent_process():
     """
     Returns True if the current process is the parent process. False if the current process is a child
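The DagFileProcessorProcess and DagFileProcessor classes removed from airflow/jobs/scheduler_job.py
above now live in the airflow.dag_processing.processor module (the relocated tests further down in
this diff import them from there). As a minimal, hedged sketch of driving the moved class directly,
assuming only the module path changed and using a placeholder DAG file path: the per-file entry point
is still process_file(), and its (dag count, import error count) return value is what
DagFileProcessorProcess.result later exposes over the parent pipe.

    import logging

    from airflow.dag_processing.processor import DagFileProcessor

    # Placeholder path for illustration; any file defining DAG objects works.
    dag_file = "/path/to/dags/example_dag.py"

    # An empty dag_ids list mirrors how the tests below construct the processor.
    processor = DagFileProcessor(dag_ids=[], log=logging.getLogger("dag_processing"))

    # No callbacks to execute and no pickling in this sketch; a DB session is
    # supplied automatically by the @provide_session decorator on process_file.
    dag_count, import_error_count = processor.process_file(
        file_path=dag_file,
        callback_requests=[],
        pickle_dags=False,
    )
    print(dag_count, import_error_count)
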
diff --git a/tests/dag_processing/__init__.py b/tests/dag_processing/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/tests/dag_processing/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/utils/test_dag_processing.py b/tests/dag_processing/test_manager.py
similarity index 99%
rename from tests/utils/test_dag_processing.py
rename to tests/dag_processing/test_manager.py
index 58ad010..0ab7f2b 100644
--- a/tests/utils/test_dag_processing.py
+++ b/tests/dag_processing/test_manager.py
@@ -34,20 +34,20 @@ import pytest
 from freezegun import freeze_time
 
 from airflow.configuration import conf
-from airflow.jobs.local_task_job import LocalTaskJob as LJ
-from airflow.jobs.scheduler_job import DagFileProcessorProcess
-from airflow.models import DagBag, DagModel, TaskInstance as TI
-from airflow.models.serialized_dag import SerializedDagModel
-from airflow.models.taskinstance import SimpleTaskInstance
-from airflow.utils import timezone
-from airflow.utils.callback_requests import CallbackRequest, TaskCallbackRequest
-from airflow.utils.dag_processing import (
+from airflow.dag_processing.manager import (
     DagFileProcessorAgent,
     DagFileProcessorManager,
     DagFileStat,
     DagParsingSignal,
     DagParsingStat,
 )
+from airflow.dag_processing.processor import DagFileProcessorProcess
+from airflow.jobs.local_task_job import LocalTaskJob as LJ
+from airflow.models import DagBag, DagModel, TaskInstance as TI
+from airflow.models.serialized_dag import SerializedDagModel
+from airflow.models.taskinstance import SimpleTaskInstance
+from airflow.utils import timezone
+from airflow.utils.callback_requests import CallbackRequest, TaskCallbackRequest
 from airflow.utils.net import get_hostname
 from airflow.utils.session import create_session
 from airflow.utils.state import State
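For code that imported these classes from their old locations, the rename above amounts to a
module-path change only; a small sketch of the before and after imports, assuming the class names
themselves are unchanged (as the hunk above shows):

    # Old locations (among the imports removed in the hunk above):
    #   from airflow.jobs.scheduler_job import DagFileProcessorProcess
    #   from airflow.utils.dag_processing import DagFileProcessorAgent, DagFileProcessorManager

    # New locations used by tests/dag_processing/test_manager.py:
    from airflow.dag_processing.manager import DagFileProcessorAgent, DagFileProcessorManager
    from airflow.dag_processing.processor import DagFileProcessorProcess
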
diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py
new file mode 100644
index 0000000..5953517
--- /dev/null
+++ b/tests/dag_processing/test_processor.py
@@ -0,0 +1,749 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# pylint: disable=attribute-defined-outside-init
+import datetime
+import os
+import unittest
+from datetime import timedelta
+from tempfile import NamedTemporaryFile
+from unittest import mock
+from unittest.mock import MagicMock, patch
+
+import pytest
+from parameterized import parameterized
+
+from airflow import settings
+from airflow.configuration import conf
+from airflow.dag_processing.processor import DagFileProcessor
+from airflow.jobs.scheduler_job import SchedulerJob
+from airflow.models import DAG, DagBag, DagModel, SlaMiss, TaskInstance
+from airflow.models.dagrun import DagRun
+from airflow.models.serialized_dag import SerializedDagModel
+from airflow.models.taskinstance import SimpleTaskInstance
+from airflow.operators.bash import BashOperator
+from airflow.operators.dummy import DummyOperator
+from airflow.serialization.serialized_objects import SerializedDAG
+from airflow.utils import timezone
+from airflow.utils.callback_requests import TaskCallbackRequest
+from airflow.utils.dates import days_ago
+from airflow.utils.session import create_session
+from airflow.utils.state import State
+from airflow.utils.types import DagRunType
+from tests.test_utils.config import conf_vars, env_vars
+from tests.test_utils.db import (
+    clear_db_dags,
+    clear_db_import_errors,
+    clear_db_jobs,
+    clear_db_pools,
+    clear_db_runs,
+    clear_db_serialized_dags,
+    clear_db_sla_miss,
+)
+from tests.test_utils.mock_executor import MockExecutor
+
+DEFAULT_DATE = timezone.datetime(2016, 1, 1)
+
+
+@pytest.fixture(scope="class")
+def disable_load_example():
+    with conf_vars({('core', 'load_examples'): 'false'}):
+        with env_vars({('core', 'load_examples'): 'false'}):
+            yield
+
+
+@pytest.mark.usefixtures("disable_load_example")
+class TestDagFileProcessor(unittest.TestCase):
+    @staticmethod
+    def clean_db():
+        clear_db_runs()
+        clear_db_pools()
+        clear_db_dags()
+        clear_db_sla_miss()
+        clear_db_import_errors()
+        clear_db_jobs()
+        clear_db_serialized_dags()
+
+    def setUp(self):
+        self.clean_db()
+
+        # Speed up some tests by not running the tasks, just look at what we
+        # enqueue!
+        self.null_exec = MockExecutor()
+        self.scheduler_job = None
+
+    def tearDown(self) -> None:
+        if self.scheduler_job and self.scheduler_job.processor_agent:
+            self.scheduler_job.processor_agent.end()
+            self.scheduler_job = None
+        self.clean_db()
+
+    def create_test_dag(self, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(hours=1), **kwargs):
+        dag = DAG(
+            dag_id='test_scheduler_reschedule',
+            start_date=start_date,
+            # Make sure it only creates a single DAG Run
+            end_date=end_date,
+        )
+        dag.clear()
+        dag.is_subdag = False
+        with create_session() as session:
+            orm_dag = DagModel(dag_id=dag.dag_id, is_paused=False)
+            session.merge(orm_dag)
+            session.commit()
+        return dag
+
+    @classmethod
+    def setUpClass(cls):
+        # Ensure the DAGs we are looking at from the DB are up-to-date
+        non_serialized_dagbag = DagBag(read_dags_from_db=False, include_examples=False)
+        non_serialized_dagbag.sync_to_db()
+        cls.dagbag = DagBag(read_dags_from_db=True)
+
+    def test_dag_file_processor_sla_miss_callback(self):
+        """
+        Test that the dag file processor calls the sla miss callback
+        """
+        session = settings.Session()
+
+        sla_callback = MagicMock()
+
+        # Create dag with a start of 1 day ago, but an sla of 0
+        # so we'll already have an sla_miss on the books.
+        test_start_date = days_ago(1)
+        dag = DAG(
+            dag_id='test_sla_miss',
+            sla_miss_callback=sla_callback,
+            default_args={'start_date': test_start_date, 'sla': datetime.timedelta()},
+        )
+
+        task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
+
+        session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
+
+        session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
+
+        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+        dag_file_processor.manage_slas(dag=dag, session=session)
+
+        assert sla_callback.called
+
+    def test_dag_file_processor_sla_miss_callback_invalid_sla(self):
+        """
+        Test that the dag file processor does not call the sla miss callback when
+        given an invalid sla
+        """
+        session = settings.Session()
+
+        sla_callback = MagicMock()
+
+        # Create dag with a start of 1 day ago, but an sla of 0
+        # so we'll already have an sla_miss on the books.
+        # Pass anything besides a timedelta object to the sla argument.
+        test_start_date = days_ago(1)
+        dag = DAG(
+            dag_id='test_sla_miss',
+            sla_miss_callback=sla_callback,
+            default_args={'start_date': test_start_date, 'sla': None},
+        )
+
+        task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
+
+        session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
+
+        session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
+
+        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+        dag_file_processor.manage_slas(dag=dag, session=session)
+        sla_callback.assert_not_called()
+
+    def test_dag_file_processor_sla_miss_callback_sent_notification(self):
+        """
+        Test that the dag file processor does not call the sla_miss_callback when a
+        notification has already been sent
+        """
+        session = settings.Session()
+
+        # Mock the callback function so we can verify that it was not called
+        sla_callback = MagicMock()
+
+        # Create dag with a start of 2 days ago, but an sla of 1 day
+        # ago so we'll already have an sla_miss on the books
+        test_start_date = days_ago(2)
+        dag = DAG(
+            dag_id='test_sla_miss',
+            sla_miss_callback=sla_callback,
+            default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
+        )
+
+        task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
+
+        # Create a TaskInstance for two days ago
+        session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
+
+        # Create an SlaMiss where notification was sent, but email was not
+        session.merge(
+            SlaMiss(
+                task_id='dummy',
+                dag_id='test_sla_miss',
+                execution_date=test_start_date,
+                email_sent=False,
+                notification_sent=True,
+            )
+        )
+
+        # Now call manage_slas and see if the sla_miss callback gets called
+        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+        dag_file_processor.manage_slas(dag=dag, session=session)
+
+        sla_callback.assert_not_called()
+
+    def test_dag_file_processor_sla_miss_callback_exception(self):
+        """
+        Test that the dag file processor gracefully logs an exception if there is a problem
+        calling the sla_miss_callback
+        """
+        session = settings.Session()
+
+        sla_callback = MagicMock(side_effect=RuntimeError('Could not call function'))
+
+        test_start_date = days_ago(2)
+        dag = DAG(
+            dag_id='test_sla_miss',
+            sla_miss_callback=sla_callback,
+            default_args={'start_date': test_start_date},
+        )
+
+        task = DummyOperator(task_id='dummy', dag=dag, owner='airflow', sla=datetime.timedelta(hours=1))
+
+        session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
+
+        # Create an SlaMiss for the task so manage_slas has something to act on
+        session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
+
+        # Now call manage_slas and see if the sla_miss callback gets called
+        mock_log = mock.MagicMock()
+        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
+        dag_file_processor.manage_slas(dag=dag, session=session)
+        assert sla_callback.called
+        mock_log.exception.assert_called_once_with(
+            'Could not call sla_miss_callback for DAG %s', 'test_sla_miss'
+        )
+
+    @mock.patch('airflow.dag_processing.processor.send_email')
+    def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(self, mock_send_email):
+        session = settings.Session()
+
+        test_start_date = days_ago(2)
+        dag = DAG(
+            dag_id='test_sla_miss',
+            default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
+        )
+
+        email1 = 'test1@test.com'
+        task = DummyOperator(
+            task_id='sla_missed', dag=dag, owner='airflow', email=email1, sla=datetime.timedelta(hours=1)
+        )
+
+        session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
+
+        email2 = 'test2@test.com'
+        DummyOperator(task_id='sla_not_missed', dag=dag, owner='airflow', email=email2)
+
+        session.merge(SlaMiss(task_id='sla_missed', dag_id='test_sla_miss', execution_date=test_start_date))
+
+        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+
+        dag_file_processor.manage_slas(dag=dag, session=session)
+
+        assert len(mock_send_email.call_args_list) == 1
+
+        send_email_to = mock_send_email.call_args_list[0][0][0]
+        assert email1 in send_email_to
+        assert email2 not in send_email_to
+
+    @mock.patch('airflow.jobs.scheduler_job.Stats.incr')
+    @mock.patch("airflow.utils.email.send_email")
+    def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock_stats_incr):
+        """
+        Test that the dag file processor gracefully logs an exception if there is a problem
+        sending an email
+        """
+        session = settings.Session()
+
+        # Mock the callback function so we can verify that it was not called
+        mock_send_email.side_effect = RuntimeError('Could not send an email')
+
+        test_start_date = days_ago(2)
+        dag = DAG(
+            dag_id='test_sla_miss',
+            default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
+        )
+
+        task = DummyOperator(
+            task_id='dummy', dag=dag, owner='airflow', email='test@test.com', sla=datetime.timedelta(hours=1)
+        )
+
+        session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
+
+        # Create an SlaMiss for the task so manage_slas has something to act on
+        session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
+
+        mock_log = mock.MagicMock()
+        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
+
+        dag_file_processor.manage_slas(dag=dag, session=session)
+        mock_log.exception.assert_called_once_with(
+            'Could not send SLA Miss email notification for DAG %s', 'test_sla_miss'
+        )
+        mock_stats_incr.assert_called_once_with('sla_email_notification_failure')
+
+    def test_dag_file_processor_sla_miss_deleted_task(self):
+        """
+        Test that the dag file processor will not crash when trying to send
+        sla miss notification for a deleted task
+        """
+        session = settings.Session()
+
+        test_start_date = days_ago(2)
+        dag = DAG(
+            dag_id='test_sla_miss',
+            default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
+        )
+
+        task = DummyOperator(
+            task_id='dummy', dag=dag, owner='airflow', email='test@test.com', sla=datetime.timedelta(hours=1)
+        )
+
+        session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
+
+        # Create an SlaMiss for a task that no longer exists in the DAG
+        session.merge(
+            SlaMiss(task_id='dummy_deleted', dag_id='test_sla_miss', execution_date=test_start_date)
+        )
+
+        mock_log = mock.MagicMock()
+        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
+        dag_file_processor.manage_slas(dag=dag, session=session)
+
+    @parameterized.expand(
+        [
+            [State.NONE, None, None],
+            [
+                State.UP_FOR_RETRY,
+                timezone.utcnow() - datetime.timedelta(minutes=30),
+                timezone.utcnow() - datetime.timedelta(minutes=15),
+            ],
+            [
+                State.UP_FOR_RESCHEDULE,
+                timezone.utcnow() - datetime.timedelta(minutes=30),
+                timezone.utcnow() - datetime.timedelta(minutes=15),
+            ],
+        ]
+    )
+    def test_dag_file_processor_process_task_instances(self, state, start_date, end_date):
+        """
+        Test that _schedule_dag_run schedules the expected task instances for
+        the given DAG run.
+        """
+        dag = DAG(dag_id='test_scheduler_process_execute_task', start_date=DEFAULT_DATE)
+        BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo hi')
+
+        with create_session() as session:
+            orm_dag = DagModel(dag_id=dag.dag_id)
+            session.merge(orm_dag)
+
+        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+        self.scheduler_job = SchedulerJob(subdir=os.devnull)
+        self.scheduler_job.processor_agent = mock.MagicMock()
+        self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+        dag.clear()
+        dr = dag.create_dagrun(
+            run_type=DagRunType.SCHEDULED,
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+        assert dr is not None
+
+        with create_session() as session:
+            ti = dr.get_task_instances(session=session)[0]
+            ti.state = state
+            ti.start_date = start_date
+            ti.end_date = end_date
+
+            count = self.scheduler_job._schedule_dag_run(dr, set(), session)
+            assert count == 1
+
+            session.refresh(ti)
+            assert ti.state == State.SCHEDULED
+
+    @parameterized.expand(
+        [
+            [State.NONE, None, None],
+            [
+                State.UP_FOR_RETRY,
+                timezone.utcnow() - datetime.timedelta(minutes=30),
+                timezone.utcnow() - datetime.timedelta(minutes=15),
+            ],
+            [
+                State.UP_FOR_RESCHEDULE,
+                timezone.utcnow() - datetime.timedelta(minutes=30),
+                timezone.utcnow() - datetime.timedelta(minutes=15),
+            ],
+        ]
+    )
+    def test_dag_file_processor_process_task_instances_with_task_concurrency(
+        self,
+        state,
+        start_date,
+        end_date,
+    ):
+        """
+        Test that _schedule_dag_run schedules the expected task instances for
+        the given DAG run.
+        """
+        dag = DAG(dag_id='test_scheduler_process_execute_task_with_task_concurrency', start_date=DEFAULT_DATE)
+        BashOperator(task_id='dummy', task_concurrency=2, dag=dag, owner='airflow', bash_command='echo Hi')
+
+        with create_session() as session:
+            orm_dag = DagModel(dag_id=dag.dag_id)
+            session.merge(orm_dag)
+
+        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+        self.scheduler_job = SchedulerJob(subdir=os.devnull)
+        self.scheduler_job.processor_agent = mock.MagicMock()
+        self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+        dag.clear()
+        dr = dag.create_dagrun(
+            run_type=DagRunType.SCHEDULED,
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+        assert dr is not None
+
+        with create_session() as session:
+            ti = dr.get_task_instances(session=session)[0]
+            ti.state = state
+            ti.start_date = start_date
+            ti.end_date = end_date
+
+            count = self.scheduler_job._schedule_dag_run(dr, set(), session)
+            assert count == 1
+
+            session.refresh(ti)
+            assert ti.state == State.SCHEDULED
+
+    @parameterized.expand(
+        [
+            [State.NONE, None, None],
+            [
+                State.UP_FOR_RETRY,
+                timezone.utcnow() - datetime.timedelta(minutes=30),
+                timezone.utcnow() - datetime.timedelta(minutes=15),
+            ],
+            [
+                State.UP_FOR_RESCHEDULE,
+                timezone.utcnow() - datetime.timedelta(minutes=30),
+                timezone.utcnow() - datetime.timedelta(minutes=15),
+            ],
+        ]
+    )
+    def test_dag_file_processor_process_task_instances_depends_on_past(self, state, start_date, end_date):
+        """
+        Test that _schedule_dag_run schedules the expected task instances for
+        the given DAG run.
+        """
+        dag = DAG(
+            dag_id='test_scheduler_process_execute_task_depends_on_past',
+            start_date=DEFAULT_DATE,
+            default_args={
+                'depends_on_past': True,
+            },
+        )
+        BashOperator(task_id='dummy1', dag=dag, owner='airflow', bash_command='echo hi')
+        BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo hi')
+
+        with create_session() as session:
+            orm_dag = DagModel(dag_id=dag.dag_id)
+            session.merge(orm_dag)
+
+        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+        self.scheduler_job = SchedulerJob(subdir=os.devnull)
+        self.scheduler_job.processor_agent = mock.MagicMock()
+        self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+        dag.clear()
+        dr = dag.create_dagrun(
+            run_type=DagRunType.SCHEDULED,
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+        assert dr is not None
+
+        with create_session() as session:
+            tis = dr.get_task_instances(session=session)
+            for ti in tis:
+                ti.state = state
+                ti.start_date = start_date
+                ti.end_date = end_date
+
+            count = self.scheduler_job._schedule_dag_run(dr, set(), session)
+            assert count == 2
+
+            session.refresh(tis[0])
+            session.refresh(tis[1])
+            assert tis[0].state == State.SCHEDULED
+            assert tis[1].state == State.SCHEDULED
+
+    def test_scheduler_job_add_new_task(self):
+        """
+        Test if a task instance will be added if the dag is updated
+        """
+        dag = DAG(dag_id='test_scheduler_add_new_task', start_date=DEFAULT_DATE)
+        BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo test')
+
+        self.scheduler_job = SchedulerJob(subdir=os.devnull)
+        self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+
+        # Since we don't want to store the code for the DAG defined in this file
+        with mock.patch.object(settings, "STORE_DAG_CODE", False):
+            self.scheduler_job.dagbag.sync_to_db()
+
+        session = settings.Session()
+        orm_dag = session.query(DagModel).get(dag.dag_id)
+        assert orm_dag is not None
+
+        if self.scheduler_job.processor_agent:
+            self.scheduler_job.processor_agent.end()
+        self.scheduler_job = SchedulerJob(subdir=os.devnull)
+        self.scheduler_job.processor_agent = mock.MagicMock()
+        dag = self.scheduler_job.dagbag.get_dag('test_scheduler_add_new_task', session=session)
+        self.scheduler_job._create_dag_runs([orm_dag], session)
+
+        drs = DagRun.find(dag_id=dag.dag_id, session=session)
+        assert len(drs) == 1
+        dr = drs[0]
+
+        tis = dr.get_task_instances()
+        assert len(tis) == 1
+
+        BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test')
+        SerializedDagModel.write_dag(dag=dag)
+
+        scheduled_tis = self.scheduler_job._schedule_dag_run(dr, set(), session)
+        session.flush()
+        assert scheduled_tis == 2
+
+        drs = DagRun.find(dag_id=dag.dag_id, session=session)
+        assert len(drs) == 1
+        dr = drs[0]
+
+        tis = dr.get_task_instances()
+        assert len(tis) == 2
+
+    def test_runs_respected_after_clear(self):
+        """
+        Test that _schedule_dag_run only schedules task instances up to max_active_runs
+        (related to issue AIRFLOW-137)
+        """
+        dag = DAG(dag_id='test_scheduler_max_active_runs_respected_after_clear', start_date=DEFAULT_DATE)
+        dag.max_active_runs = 3
+
+        BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo Hi')
+
+        session = settings.Session()
+        orm_dag = DagModel(dag_id=dag.dag_id)
+        session.merge(orm_dag)
+        session.commit()
+        session.close()
+        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+        self.scheduler_job = SchedulerJob(subdir=os.devnull)
+        self.scheduler_job.processor_agent = mock.MagicMock()
+        self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+        dag.clear()
+
+        date = DEFAULT_DATE
+        dr1 = dag.create_dagrun(
+            run_type=DagRunType.SCHEDULED,
+            execution_date=date,
+            state=State.RUNNING,
+        )
+        date = dag.following_schedule(date)
+        dr2 = dag.create_dagrun(
+            run_type=DagRunType.SCHEDULED,
+            execution_date=date,
+            state=State.RUNNING,
+        )
+        date = dag.following_schedule(date)
+        dr3 = dag.create_dagrun(
+            run_type=DagRunType.SCHEDULED,
+            execution_date=date,
+            state=State.RUNNING,
+        )
+
+        # First create up to 3 dagruns in RUNNING state.
+        assert dr1 is not None
+        assert dr2 is not None
+        assert dr3 is not None
+        assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 3
+
+        # Reduce max_active_runs to 1
+        dag.max_active_runs = 1
+
+        # and schedule them in, so we can check how many
+        # tasks are put on the task_instances_list (should be one, not 3)
+        with create_session() as session:
+            num_scheduled = self.scheduler_job._schedule_dag_run(dr1, set(), session)
+            assert num_scheduled == 1
+            num_scheduled = self.scheduler_job._schedule_dag_run(dr2, {dr1.execution_date}, session)
+            assert num_scheduled == 0
+            num_scheduled = self.scheduler_job._schedule_dag_run(dr3, {dr1.execution_date}, session)
+            assert num_scheduled == 0
+
+    @patch.object(TaskInstance, 'handle_failure_with_callback')
+    def test_execute_on_failure_callbacks(self, mock_ti_handle_failure):
+        dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
+        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+        with create_session() as session:
+            session.query(TaskInstance).delete()
+            dag = dagbag.get_dag('example_branch_operator')
+            task = dag.get_task(task_id='run_this_first')
+
+            ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING)
+
+            session.add(ti)
+            session.commit()
+
+            requests = [
+                TaskCallbackRequest(
+                    full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message"
+                )
+            ]
+            dag_file_processor.execute_callbacks(dagbag, requests)
+            mock_ti_handle_failure.assert_called_once_with(
+                error="Message",
+                test_mode=conf.getboolean('core', 'unit_test_mode'),
+            )
+
+    def test_process_file_should_failure_callback(self):
+        dag_file = os.path.join(
+            os.path.dirname(os.path.realpath(__file__)), '../dags/test_on_failure_callback.py'
+        )
+        dagbag = DagBag(dag_folder=dag_file, include_examples=False)
+        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
+        with create_session() as session, NamedTemporaryFile(delete=False) as callback_file:
+            session.query(TaskInstance).delete()
+            dag = dagbag.get_dag('test_om_failure_callback_dag')
+            task = dag.get_task(task_id='test_om_failure_callback_task')
+
+            ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING)
+
+            session.add(ti)
+            session.commit()
+
+            requests = [
+                TaskCallbackRequest(
+                    full_filepath=dag.full_filepath,
+                    simple_task_instance=SimpleTaskInstance(ti),
+                    msg="Message",
+                )
+            ]
+            callback_file.close()
+
+            with mock.patch.dict("os.environ", {"AIRFLOW_CALLBACK_FILE": callback_file.name}):
+                dag_file_processor.process_file(dag_file, requests)
+            with open(callback_file.name) as callback_file2:
+                content = callback_file2.read()
+            assert "Callback fired" == content
+            os.remove(callback_file.name)
+
+    def test_should_mark_dummy_task_as_success(self):
+        dag_file = os.path.join(
+            os.path.dirname(os.path.realpath(__file__)), '../dags/test_only_dummy_tasks.py'
+        )
+
+        # Write DAGs to dag and serialized_dag table
+        dagbag = DagBag(dag_folder=dag_file, include_examples=False, read_dags_from_db=False)
+        dagbag.sync_to_db()
+
+        self.scheduler_job_job = SchedulerJob(subdir=os.devnull)
+        self.scheduler_job_job.processor_agent = mock.MagicMock()
+        dag = self.scheduler_job_job.dagbag.get_dag("test_only_dummy_tasks")
+
+        # Create DagRun
+        session = settings.Session()
+        orm_dag = session.query(DagModel).get(dag.dag_id)
+        self.scheduler_job_job._create_dag_runs([orm_dag], session)
+
+        drs = DagRun.find(dag_id=dag.dag_id, session=session)
+        assert len(drs) == 1
+        dr = drs[0]
+
+        # Schedule TaskInstances
+        self.scheduler_job_job._schedule_dag_run(dr, {}, session)
+        with create_session() as session:
+            tis = session.query(TaskInstance).all()
+
+        dags = self.scheduler_job_job.dagbag.dags.values()
+        assert ['test_only_dummy_tasks'] == [dag.dag_id for dag in dags]
+        assert 5 == len(tis)
+        assert {
+            ('test_task_a', 'success'),
+            ('test_task_b', None),
+            ('test_task_c', 'success'),
+            ('test_task_on_execute', 'scheduled'),
+            ('test_task_on_success', 'scheduled'),
+        } == {(ti.task_id, ti.state) for ti in tis}
+        for state, start_date, end_date, duration in [
+            (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis
+        ]:
+            if state == 'success':
+                assert start_date is not None
+                assert end_date is not None
+                assert 0.0 == duration
+            else:
+                assert start_date is None
+                assert end_date is None
+                assert duration is None
+
+        self.scheduler_job_job._schedule_dag_run(dr, {}, session)
+        with create_session() as session:
+            tis = session.query(TaskInstance).all()
+
+        assert 5 == len(tis)
+        assert {
+            ('test_task_a', 'success'),
+            ('test_task_b', 'success'),
+            ('test_task_c', 'success'),
+            ('test_task_on_execute', 'scheduled'),
+            ('test_task_on_success', 'scheduled'),
+        } == {(ti.task_id, ti.state) for ti in tis}
+        for state, start_date, end_date, duration in [
+            (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis
+        ]:
+            if state == 'success':
+                assert start_date is not None
+                assert end_date is not None
+                assert 0.0 == duration
+            else:
+                assert start_date is None
+                assert end_date is None
+                assert duration is None
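The manage_slas tests added above exercise the SLA-miss machinery that is driven by task-level sla
values and a DAG-level sla_miss_callback. As a hedged illustration of the DAG-author side that feeds
this machinery (the DAG id, schedule and task below are hypothetical, not taken from the tests):

    import datetime

    from airflow import DAG
    from airflow.operators.dummy import DummyOperator
    from airflow.utils.dates import days_ago


    def on_sla_miss(dag, task_list, blocking_task_list, slas, blocking_tis):
        # Same argument list DagFileProcessor.manage_slas uses when it invokes
        # dag.sla_miss_callback (see the hunk removed from scheduler_job.py above).
        print("SLA missed for:\n" + task_list)


    with DAG(
        dag_id="sla_example",  # hypothetical DAG id
        start_date=days_ago(2),
        schedule_interval="@hourly",
        sla_miss_callback=on_sla_miss,
    ):
        # The task must finish within one hour of its scheduled time; otherwise
        # an SlaMiss row is recorded and the callback above fires.
        DummyOperator(task_id="slow_task", sla=datetime.timedelta(hours=1))
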
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index fe0b257..9fe8517 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -22,7 +22,7 @@ import os
 import shutil
 import unittest
 from datetime import timedelta
-from tempfile import NamedTemporaryFile, mkdtemp
+from tempfile import mkdtemp
 from time import sleep
 from unittest import mock
 from unittest.mock import MagicMock, patch
@@ -37,22 +37,20 @@ from sqlalchemy import func
 import airflow.example_dags
 import airflow.smart_sensor_dags
 from airflow import settings
-from airflow.configuration import conf
+from airflow.dag_processing.manager import DagFileProcessorAgent
 from airflow.exceptions import AirflowException
 from airflow.executors.base_executor import BaseExecutor
 from airflow.jobs.backfill_job import BackfillJob
-from airflow.jobs.scheduler_job import DagFileProcessor, SchedulerJob
-from airflow.models import DAG, DagBag, DagModel, Pool, SlaMiss, TaskInstance, errors
+from airflow.jobs.scheduler_job import SchedulerJob
+from airflow.models import DAG, DagBag, DagModel, Pool, TaskInstance, errors
 from airflow.models.dagrun import DagRun
 from airflow.models.serialized_dag import SerializedDagModel
-from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey
+from airflow.models.taskinstance import TaskInstanceKey
 from airflow.operators.bash import BashOperator
 from airflow.operators.dummy import DummyOperator
 from airflow.serialization.serialized_objects import SerializedDAG
 from airflow.utils import timezone
-from airflow.utils.callback_requests import DagCallbackRequest, TaskCallbackRequest
-from airflow.utils.dag_processing import DagFileProcessorAgent
-from airflow.utils.dates import days_ago
+from airflow.utils.callback_requests import DagCallbackRequest
 from airflow.utils.file import list_py_file_paths
 from airflow.utils.session import create_session, provide_session
 from airflow.utils.state import State
@@ -101,688 +99,6 @@ def disable_load_example():
 
 
 @pytest.mark.usefixtures("disable_load_example")
-class TestDagFileProcessor(unittest.TestCase):
-    @staticmethod
-    def clean_db():
-        clear_db_runs()
-        clear_db_pools()
-        clear_db_dags()
-        clear_db_sla_miss()
-        clear_db_import_errors()
-        clear_db_jobs()
-        clear_db_serialized_dags()
-
-    def setUp(self):
-        self.clean_db()
-
-        # Speed up some tests by not running the tasks, just look at what we
-        # enqueue!
-        self.null_exec = MockExecutor()
-        self.scheduler_job = None
-
-    def tearDown(self) -> None:
-        if self.scheduler_job and self.scheduler_job.processor_agent:
-            self.scheduler_job.processor_agent.end()
-            self.scheduler_job = None
-        self.clean_db()
-
-    def create_test_dag(self, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(hours=1), **kwargs):
-        dag = DAG(
-            dag_id='test_scheduler_reschedule',
-            start_date=start_date,
-            # Make sure it only creates a single DAG Run
-            end_date=end_date,
-        )
-        dag.clear()
-        dag.is_subdag = False
-        with create_session() as session:
-            orm_dag = DagModel(dag_id=dag.dag_id, is_paused=False)
-            session.merge(orm_dag)
-            session.commit()
-        return dag
-
-    @classmethod
-    def setUpClass(cls):
-        # Ensure the DAGs we are looking at from the DB are up-to-date
-        non_serialized_dagbag = DagBag(read_dags_from_db=False, include_examples=False)
-        non_serialized_dagbag.sync_to_db()
-        cls.dagbag = DagBag(read_dags_from_db=True)
-
-    def test_dag_file_processor_sla_miss_callback(self):
-        """
-        Test that the dag file processor calls the sla miss callback
-        """
-        session = settings.Session()
-
-        sla_callback = MagicMock()
-
-        # Create dag with a start of 1 day ago, but an sla of 0
-        # so we'll already have an sla_miss on the books.
-        test_start_date = days_ago(1)
-        dag = DAG(
-            dag_id='test_sla_miss',
-            sla_miss_callback=sla_callback,
-            default_args={'start_date': test_start_date, 'sla': datetime.timedelta()},
-        )
-
-        task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
-
-        session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
-
-        session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
-
-        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
-        dag_file_processor.manage_slas(dag=dag, session=session)
-
-        assert sla_callback.called
-
-    def test_dag_file_processor_sla_miss_callback_invalid_sla(self):
-        """
-        Test that the dag file processor does not call the sla miss callback when
-        given an invalid sla
-        """
-        session = settings.Session()
-
-        sla_callback = MagicMock()
-
-        # Create dag with a start of 1 day ago, but an sla of 0
-        # so we'll already have an sla_miss on the books.
-        # Pass anything besides a timedelta object to the sla argument.
-        test_start_date = days_ago(1)
-        dag = DAG(
-            dag_id='test_sla_miss',
-            sla_miss_callback=sla_callback,
-            default_args={'start_date': test_start_date, 'sla': None},
-        )
-
-        task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
-
-        session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
-
-        session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
-
-        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
-        dag_file_processor.manage_slas(dag=dag, session=session)
-        sla_callback.assert_not_called()
-
-    def test_dag_file_processor_sla_miss_callback_sent_notification(self):
-        """
-        Test that the dag file processor does not call the sla_miss_callback when a
-        notification has already been sent
-        """
-        session = settings.Session()
-
-        # Mock the callback function so we can verify that it was not called
-        sla_callback = MagicMock()
-
-        # Create dag with a start of 2 days ago, but an sla of 1 day
-        # ago so we'll already have an sla_miss on the books
-        test_start_date = days_ago(2)
-        dag = DAG(
-            dag_id='test_sla_miss',
-            sla_miss_callback=sla_callback,
-            default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
-        )
-
-        task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
-
-        # Create a TaskInstance for two days ago
-        session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
-
-        # Create an SlaMiss where notification was sent, but email was not
-        session.merge(
-            SlaMiss(
-                task_id='dummy',
-                dag_id='test_sla_miss',
-                execution_date=test_start_date,
-                email_sent=False,
-                notification_sent=True,
-            )
-        )
-
-        # Now call manage_slas and see if the sla_miss callback gets called
-        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
-        dag_file_processor.manage_slas(dag=dag, session=session)
-
-        sla_callback.assert_not_called()
-
-    def test_dag_file_processor_sla_miss_callback_exception(self):
-        """
-        Test that the dag file processor gracefully logs an exception if there is a problem
-        calling the sla_miss_callback
-        """
-        session = settings.Session()
-
-        sla_callback = MagicMock(side_effect=RuntimeError('Could not call function'))
-
-        test_start_date = days_ago(2)
-        dag = DAG(
-            dag_id='test_sla_miss',
-            sla_miss_callback=sla_callback,
-            default_args={'start_date': test_start_date},
-        )
-
-        task = DummyOperator(task_id='dummy', dag=dag, owner='airflow', sla=datetime.timedelta(hours=1))
-
-        session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
-
-        # Create an SlaMiss where notification was sent, but email was not
-        session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
-
-        # Now call manage_slas and see if the sla_miss callback gets called
-        mock_log = mock.MagicMock()
-        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
-        dag_file_processor.manage_slas(dag=dag, session=session)
-        assert sla_callback.called
-        mock_log.exception.assert_called_once_with(
-            'Could not call sla_miss_callback for DAG %s', 'test_sla_miss'
-        )
-
-    @mock.patch('airflow.jobs.scheduler_job.send_email')
-    def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(self, mock_send_email):
-        session = settings.Session()
-
-        test_start_date = days_ago(2)
-        dag = DAG(
-            dag_id='test_sla_miss',
-            default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
-        )
-
-        email1 = 'test1@test.com'
-        task = DummyOperator(
-            task_id='sla_missed', dag=dag, owner='airflow', email=email1, sla=datetime.timedelta(hours=1)
-        )
-
-        session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
-
-        email2 = 'test2@test.com'
-        DummyOperator(task_id='sla_not_missed', dag=dag, owner='airflow', email=email2)
-
-        session.merge(SlaMiss(task_id='sla_missed', dag_id='test_sla_miss', execution_date=test_start_date))
-
-        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
-
-        dag_file_processor.manage_slas(dag=dag, session=session)
-
-        assert len(mock_send_email.call_args_list) == 1
-
-        send_email_to = mock_send_email.call_args_list[0][0][0]
-        assert email1 in send_email_to
-        assert email2 not in send_email_to
-
-    @mock.patch('airflow.jobs.scheduler_job.Stats.incr')
-    @mock.patch("airflow.utils.email.send_email")
-    def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock_stats_incr):
-        """
-        Test that the dag file processor gracefully logs an exception if there is a problem
-        sending an email
-        """
-        session = settings.Session()
-
-        # Mock the callback function so we can verify that it was not called
-        mock_send_email.side_effect = RuntimeError('Could not send an email')
-
-        test_start_date = days_ago(2)
-        dag = DAG(
-            dag_id='test_sla_miss',
-            default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
-        )
-
-        task = DummyOperator(
-            task_id='dummy', dag=dag, owner='airflow', email='test@test.com', sla=datetime.timedelta(hours=1)
-        )
-
-        session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
-
-        # Create an SlaMiss where notification was sent, but email was not
-        session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
-
-        mock_log = mock.MagicMock()
-        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
-
-        dag_file_processor.manage_slas(dag=dag, session=session)
-        mock_log.exception.assert_called_once_with(
-            'Could not send SLA Miss email notification for DAG %s', 'test_sla_miss'
-        )
-        mock_stats_incr.assert_called_once_with('sla_email_notification_failure')
-
-    def test_dag_file_processor_sla_miss_deleted_task(self):
-        """
-        Test that the dag file processor will not crash when trying to send
-        sla miss notification for a deleted task
-        """
-        session = settings.Session()
-
-        test_start_date = days_ago(2)
-        dag = DAG(
-            dag_id='test_sla_miss',
-            default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
-        )
-
-        task = DummyOperator(
-            task_id='dummy', dag=dag, owner='airflow', email='test@test.com', sla=datetime.timedelta(hours=1)
-        )
-
-        session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
-
-        # Create an SlaMiss where notification was sent, but email was not
-        session.merge(
-            SlaMiss(task_id='dummy_deleted', dag_id='test_sla_miss', execution_date=test_start_date)
-        )
-
-        mock_log = mock.MagicMock()
-        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log)
-        dag_file_processor.manage_slas(dag=dag, session=session)
-
-    @parameterized.expand(
-        [
-            [State.NONE, None, None],
-            [
-                State.UP_FOR_RETRY,
-                timezone.utcnow() - datetime.timedelta(minutes=30),
-                timezone.utcnow() - datetime.timedelta(minutes=15),
-            ],
-            [
-                State.UP_FOR_RESCHEDULE,
-                timezone.utcnow() - datetime.timedelta(minutes=30),
-                timezone.utcnow() - datetime.timedelta(minutes=15),
-            ],
-        ]
-    )
-    def test_dag_file_processor_process_task_instances(self, state, start_date, end_date):
-        """
-        Test if _process_task_instances puts the right task instances into the
-        mock_list.
-        """
-        dag = DAG(dag_id='test_scheduler_process_execute_task', start_date=DEFAULT_DATE)
-        BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo hi')
-
-        with create_session() as session:
-            orm_dag = DagModel(dag_id=dag.dag_id)
-            session.merge(orm_dag)
-
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
-        self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        self.scheduler_job.processor_agent = mock.MagicMock()
-        self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
-        dag.clear()
-        dr = dag.create_dagrun(
-            run_type=DagRunType.SCHEDULED,
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-        assert dr is not None
-
-        with create_session() as session:
-            ti = dr.get_task_instances(session=session)[0]
-            ti.state = state
-            ti.start_date = start_date
-            ti.end_date = end_date
-
-            count = self.scheduler_job._schedule_dag_run(dr, set(), session)
-            assert count == 1
-
-            session.refresh(ti)
-            assert ti.state == State.SCHEDULED
-
-    @parameterized.expand(
-        [
-            [State.NONE, None, None],
-            [
-                State.UP_FOR_RETRY,
-                timezone.utcnow() - datetime.timedelta(minutes=30),
-                timezone.utcnow() - datetime.timedelta(minutes=15),
-            ],
-            [
-                State.UP_FOR_RESCHEDULE,
-                timezone.utcnow() - datetime.timedelta(minutes=30),
-                timezone.utcnow() - datetime.timedelta(minutes=15),
-            ],
-        ]
-    )
-    def test_dag_file_processor_process_task_instances_with_task_concurrency(
-        self,
-        state,
-        start_date,
-        end_date,
-    ):
-        """
-        Test if _process_task_instances puts the right task instances into the
-        mock_list.
-        """
-        dag = DAG(dag_id='test_scheduler_process_execute_task_with_task_concurrency', start_date=DEFAULT_DATE)
-        BashOperator(task_id='dummy', task_concurrency=2, dag=dag, owner='airflow', bash_command='echo Hi')
-
-        with create_session() as session:
-            orm_dag = DagModel(dag_id=dag.dag_id)
-            session.merge(orm_dag)
-
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
-        self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        self.scheduler_job.processor_agent = mock.MagicMock()
-        self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
-        dag.clear()
-        dr = dag.create_dagrun(
-            run_type=DagRunType.SCHEDULED,
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-        assert dr is not None
-
-        with create_session() as session:
-            ti = dr.get_task_instances(session=session)[0]
-            ti.state = state
-            ti.start_date = start_date
-            ti.end_date = end_date
-
-            count = self.scheduler_job._schedule_dag_run(dr, set(), session)
-            assert count == 1
-
-            session.refresh(ti)
-            assert ti.state == State.SCHEDULED
-
-    @parameterized.expand(
-        [
-            [State.NONE, None, None],
-            [
-                State.UP_FOR_RETRY,
-                timezone.utcnow() - datetime.timedelta(minutes=30),
-                timezone.utcnow() - datetime.timedelta(minutes=15),
-            ],
-            [
-                State.UP_FOR_RESCHEDULE,
-                timezone.utcnow() - datetime.timedelta(minutes=30),
-                timezone.utcnow() - datetime.timedelta(minutes=15),
-            ],
-        ]
-    )
-    def test_dag_file_processor_process_task_instances_depends_on_past(self, state, start_date, end_date):
-        """
-        Test if _process_task_instances puts the right task instances into the
-        mock_list.
-        """
-        dag = DAG(
-            dag_id='test_scheduler_process_execute_task_depends_on_past',
-            start_date=DEFAULT_DATE,
-            default_args={
-                'depends_on_past': True,
-            },
-        )
-        BashOperator(task_id='dummy1', dag=dag, owner='airflow', bash_command='echo hi')
-        BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo hi')
-
-        with create_session() as session:
-            orm_dag = DagModel(dag_id=dag.dag_id)
-            session.merge(orm_dag)
-
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
-        self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        self.scheduler_job.processor_agent = mock.MagicMock()
-        self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
-        dag.clear()
-        dr = dag.create_dagrun(
-            run_type=DagRunType.SCHEDULED,
-            execution_date=DEFAULT_DATE,
-            state=State.RUNNING,
-        )
-        assert dr is not None
-
-        with create_session() as session:
-            tis = dr.get_task_instances(session=session)
-            for ti in tis:
-                ti.state = state
-                ti.start_date = start_date
-                ti.end_date = end_date
-
-            count = self.scheduler_job._schedule_dag_run(dr, set(), session)
-            assert count == 2
-
-            session.refresh(tis[0])
-            session.refresh(tis[1])
-            assert tis[0].state == State.SCHEDULED
-            assert tis[1].state == State.SCHEDULED
-
-    def test_scheduler_job_add_new_task(self):
-        """
-        Test if a task instance will be added if the dag is updated
-        """
-        dag = DAG(dag_id='test_scheduler_add_new_task', start_date=DEFAULT_DATE)
-        BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo test')
-
-        self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
-
-        # Since we don't want to store the code for the DAG defined in this file
-        with mock.patch.object(settings, "STORE_DAG_CODE", False):
-            self.scheduler_job.dagbag.sync_to_db()
-
-        session = settings.Session()
-        orm_dag = session.query(DagModel).get(dag.dag_id)
-        assert orm_dag is not None
-
-        if self.scheduler_job.processor_agent:
-            self.scheduler_job.processor_agent.end()
-        self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        self.scheduler_job.processor_agent = mock.MagicMock()
-        dag = self.scheduler_job.dagbag.get_dag('test_scheduler_add_new_task', session=session)
-        self.scheduler_job._create_dag_runs([orm_dag], session)
-
-        drs = DagRun.find(dag_id=dag.dag_id, session=session)
-        assert len(drs) == 1
-        dr = drs[0]
-
-        tis = dr.get_task_instances()
-        assert len(tis) == 1
-
-        BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test')
-        SerializedDagModel.write_dag(dag=dag)
-
-        scheduled_tis = self.scheduler_job._schedule_dag_run(dr, set(), session)
-        session.flush()
-        assert scheduled_tis == 2
-
-        drs = DagRun.find(dag_id=dag.dag_id, session=session)
-        assert len(drs) == 1
-        dr = drs[0]
-
-        tis = dr.get_task_instances()
-        assert len(tis) == 2
-
-    def test_runs_respected_after_clear(self):
-        """
-        Test if _process_task_instances only schedules ti's up to max_active_runs
-        (related to issue AIRFLOW-137)
-        """
-        dag = DAG(dag_id='test_scheduler_max_active_runs_respected_after_clear', start_date=DEFAULT_DATE)
-        dag.max_active_runs = 3
-
-        BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo Hi')
-
-        session = settings.Session()
-        orm_dag = DagModel(dag_id=dag.dag_id)
-        session.merge(orm_dag)
-        session.commit()
-        session.close()
-        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
-        self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        self.scheduler_job.processor_agent = mock.MagicMock()
-        self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
-        dag.clear()
-
-        date = DEFAULT_DATE
-        dr1 = dag.create_dagrun(
-            run_type=DagRunType.SCHEDULED,
-            execution_date=date,
-            state=State.RUNNING,
-        )
-        date = dag.following_schedule(date)
-        dr2 = dag.create_dagrun(
-            run_type=DagRunType.SCHEDULED,
-            execution_date=date,
-            state=State.RUNNING,
-        )
-        date = dag.following_schedule(date)
-        dr3 = dag.create_dagrun(
-            run_type=DagRunType.SCHEDULED,
-            execution_date=date,
-            state=State.RUNNING,
-        )
-
-        # First create up to 3 dagruns in RUNNING state.
-        assert dr1 is not None
-        assert dr2 is not None
-        assert dr3 is not None
-        assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 3
-
-        # Reduce max_active_runs to 1
-        dag.max_active_runs = 1
-
-        # and schedule them in, so we can check how many
-        # tasks are put on the task_instances_list (should be one, not 3)
-        with create_session() as session:
-            num_scheduled = self.scheduler_job._schedule_dag_run(dr1, set(), session)
-            assert num_scheduled == 1
-            num_scheduled = self.scheduler_job._schedule_dag_run(dr2, {dr1.execution_date}, session)
-            assert num_scheduled == 0
-            num_scheduled = self.scheduler_job._schedule_dag_run(dr3, {dr1.execution_date}, session)
-            assert num_scheduled == 0
-
-    @patch.object(TaskInstance, 'handle_failure_with_callback')
-    def test_execute_on_failure_callbacks(self, mock_ti_handle_failure):
-        dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
-        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
-        with create_session() as session:
-            session.query(TaskInstance).delete()
-            dag = dagbag.get_dag('example_branch_operator')
-            task = dag.get_task(task_id='run_this_first')
-
-            ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING)
-
-            session.add(ti)
-            session.commit()
-
-            requests = [
-                TaskCallbackRequest(
-                    full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message"
-                )
-            ]
-            dag_file_processor.execute_callbacks(dagbag, requests)
-            mock_ti_handle_failure.assert_called_once_with(
-                error="Message",
-                test_mode=conf.getboolean('core', 'unit_test_mode'),
-            )
-
-    def test_process_file_should_failure_callback(self):
-        dag_file = os.path.join(
-            os.path.dirname(os.path.realpath(__file__)), '../dags/test_on_failure_callback.py'
-        )
-        dagbag = DagBag(dag_folder=dag_file, include_examples=False)
-        dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
-        with create_session() as session, NamedTemporaryFile(delete=False) as callback_file:
-            session.query(TaskInstance).delete()
-            dag = dagbag.get_dag('test_om_failure_callback_dag')
-            task = dag.get_task(task_id='test_om_failure_callback_task')
-
-            ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING)
-
-            session.add(ti)
-            session.commit()
-
-            requests = [
-                TaskCallbackRequest(
-                    full_filepath=dag.full_filepath,
-                    simple_task_instance=SimpleTaskInstance(ti),
-                    msg="Message",
-                )
-            ]
-            callback_file.close()
-
-            with mock.patch.dict("os.environ", {"AIRFLOW_CALLBACK_FILE": callback_file.name}):
-                dag_file_processor.process_file(dag_file, requests)
-            with open(callback_file.name) as callback_file2:
-                content = callback_file2.read()
-            assert "Callback fired" == content
-            os.remove(callback_file.name)
-
-    def test_should_mark_dummy_task_as_success(self):
-        dag_file = os.path.join(
-            os.path.dirname(os.path.realpath(__file__)), '../dags/test_only_dummy_tasks.py'
-        )
-
-        # Write DAGs to dag and serialized_dag table
-        dagbag = DagBag(dag_folder=dag_file, include_examples=False, read_dags_from_db=False)
-        dagbag.sync_to_db()
-
-        self.scheduler_job_job = SchedulerJob(subdir=os.devnull)
-        self.scheduler_job_job.processor_agent = mock.MagicMock()
-        dag = self.scheduler_job_job.dagbag.get_dag("test_only_dummy_tasks")
-
-        # Create DagRun
-        session = settings.Session()
-        orm_dag = session.query(DagModel).get(dag.dag_id)
-        self.scheduler_job_job._create_dag_runs([orm_dag], session)
-
-        drs = DagRun.find(dag_id=dag.dag_id, session=session)
-        assert len(drs) == 1
-        dr = drs[0]
-
-        # Schedule TaskInstances
-        self.scheduler_job_job._schedule_dag_run(dr, {}, session)
-        with create_session() as session:
-            tis = session.query(TaskInstance).all()
-
-        dags = self.scheduler_job_job.dagbag.dags.values()
-        assert ['test_only_dummy_tasks'] == [dag.dag_id for dag in dags]
-        assert 5 == len(tis)
-        assert {
-            ('test_task_a', 'success'),
-            ('test_task_b', None),
-            ('test_task_c', 'success'),
-            ('test_task_on_execute', 'scheduled'),
-            ('test_task_on_success', 'scheduled'),
-        } == {(ti.task_id, ti.state) for ti in tis}
-        for state, start_date, end_date, duration in [
-            (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis
-        ]:
-            if state == 'success':
-                assert start_date is not None
-                assert end_date is not None
-                assert 0.0 == duration
-            else:
-                assert start_date is None
-                assert end_date is None
-                assert duration is None
-
-        self.scheduler_job_job._schedule_dag_run(dr, {}, session)
-        with create_session() as session:
-            tis = session.query(TaskInstance).all()
-
-        assert 5 == len(tis)
-        assert {
-            ('test_task_a', 'success'),
-            ('test_task_b', 'success'),
-            ('test_task_c', 'success'),
-            ('test_task_on_execute', 'scheduled'),
-            ('test_task_on_success', 'scheduled'),
-        } == {(ti.task_id, ti.state) for ti in tis}
-        for state, start_date, end_date, duration in [
-            (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis
-        ]:
-            if state == 'success':
-                assert start_date is not None
-                assert end_date is not None
-                assert 0.0 == duration
-            else:
-                assert start_date is None
-                assert end_date is None
-                assert duration is None
-
-
-@pytest.mark.usefixtures("disable_load_example")
 class TestSchedulerJob(unittest.TestCase):
     @staticmethod
     def clean_db():
@@ -802,7 +118,7 @@ class TestSchedulerJob(unittest.TestCase):
         # enqueue!
         self.null_exec = MockExecutor()
 
-        self.patcher = patch('airflow.utils.dag_processing.SerializedDagModel.remove_deleted_dags')
+        self.patcher = patch('airflow.dag_processing.manager.SerializedDagModel.remove_deleted_dags')
         # Since we don't want to store the code for the DAG defined in this file
         self.patcher_dag_code = patch.object(settings, "STORE_DAG_CODE", False)
         self.patcher.start()
@@ -3213,7 +2529,7 @@ class TestSchedulerJob(unittest.TestCase):
         dagbag.bag_dag(dag=dag, root_dag=dag)
         dagbag.sync_to_db()
 
-        @mock.patch('airflow.jobs.scheduler_job.DagBag', return_value=dagbag)
+        @mock.patch('airflow.dag_processing.processor.DagBag', return_value=dagbag)
         def do_schedule(mock_dagbag):
             # Use a empty file since the above mock will return the
             # expected DAGs. Also specify only a single file so that it doesn't
diff --git a/tests/test_utils/perf/perf_kit/python.py b/tests/test_utils/perf/perf_kit/python.py
index 7d92a49..596f4f6 100644
--- a/tests/test_utils/perf/perf_kit/python.py
+++ b/tests/test_utils/perf/perf_kit/python.py
@@ -91,7 +91,7 @@ if __name__ == "__main__":
         import logging
 
         import airflow
-        from airflow.jobs.scheduler_job import DagFileProcessor
+        from airflow.dag_processing.processor import DagFileProcessor
 
         log = logging.getLogger(__name__)
         processor = DagFileProcessor(dag_ids=[], log=log)
diff --git a/tests/test_utils/perf/perf_kit/sqlalchemy.py b/tests/test_utils/perf/perf_kit/sqlalchemy.py
index e60ad51..37cf0fe 100644
--- a/tests/test_utils/perf/perf_kit/sqlalchemy.py
+++ b/tests/test_utils/perf/perf_kit/sqlalchemy.py
@@ -218,7 +218,7 @@ if __name__ == "__main__":
         import logging
         from unittest import mock
 
-        from airflow.jobs.scheduler_job import DagFileProcessor
+        from airflow.dag_processing.processor import DagFileProcessor
 
         with mock.patch.dict(
             "os.environ",

[airflow] 01/09: Run mini scheduler in LocalTaskJob during task exit (#16289)

Posted by ka...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit e8e8b19a218ae921da0a65ef626c8c2c57582b2b
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Thu Jun 10 14:29:30 2021 +0100

    Run mini scheduler in LocalTaskJob during task exit (#16289)
    
    Currently, the chances of tasks being killed by the LocalTaskJob heartbeat are high.
    
    This is because, after a task is marked successful/failed in taskinstance.py and the mini scheduler is enabled,
    the mini scheduler starts running. Whenever the mini scheduling takes long enough to overlap the next job heartbeat,
    the heartbeat sees that the task has finished but has no return code, because LocalTaskJob.handle_task_exit
    was not called after the task succeeded. Hence, the heartbeat assumes the task was externally marked failed/successful.
    
    This change resolves the issue by moving the mini scheduler into LocalTaskJob's handle_task_exit method, ensuring
    that the task will no longer be killed by the next heartbeat.
    
    (cherry picked from commit 408bd26c22913af93d05aa70abc3c66c52cd4588)
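    
    A minimal standalone sketch of the control flow this commit introduces (the class,
    method and config names below mirror the diff that follows, but the bodies are
    simplified placeholders, not the Airflow implementation): the mini scheduler now
    runs from handle_task_exit, after the task's return code is known, so a slow
    mini-scheduler run can no longer be mistaken for an externally killed task.
    
        # Sketch only -- simplified stand-ins for the real Airflow classes.
        class LocalTaskJobSketch:
            def __init__(self, schedule_after_task_execution: bool):
                # Mirrors conf.getboolean('scheduler', 'schedule_after_task_execution')
                self.schedule_after_task_execution = schedule_after_task_execution
    
            def _run_mini_scheduler_on_child_tasks(self) -> None:
                # In Airflow this re-selects the DagRun with a row lock and schedules
                # eligible downstream task instances; here it is just a placeholder.
                print("scheduling eligible downstream tasks")
    
            def handle_task_exit(self, return_code: int, test_mode: bool = False) -> None:
                # The task process has already exited and its return code is known,
                # so the next heartbeat cannot conclude the task was killed externally
                # while the follow-on scheduling below is still running.
                print(f"task exited with return code {return_code}")
                if not test_mode and self.schedule_after_task_execution:
                    self._run_mini_scheduler_on_child_tasks()
    
        if __name__ == "__main__":
            LocalTaskJobSketch(schedule_after_task_execution=True).handle_task_exit(0)
    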
---
 airflow/jobs/local_task_job.py          |  60 ++++++++++++++-
 airflow/models/taskinstance.py          |  60 +--------------
 tests/cli/commands/test_task_command.py |   4 +-
 tests/jobs/test_local_task_job.py       | 130 ++++++++++++++++++++++++++++++--
 tests/models/test_taskinstance.py       | 103 -------------------------
 5 files changed, 183 insertions(+), 174 deletions(-)

diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index 3afc801..cce4e64 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -16,21 +16,24 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-
 import signal
 from typing import Optional
 
 import psutil
+from sqlalchemy.exc import OperationalError
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.jobs.base_job import BaseJob
+from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
+from airflow.sentry import Sentry
 from airflow.stats import Stats
 from airflow.task.task_runner import get_task_runner
 from airflow.utils import timezone
 from airflow.utils.net import get_hostname
 from airflow.utils.session import provide_session
+from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.state import State
 
 
@@ -157,8 +160,10 @@ class LocalTaskJob(BaseJob):
             self.task_instance.set_state(State.FAILED)
         if self.task_instance.state != State.SUCCESS:
             error = self.task_runner.deserialize_run_error()
-        self.task_instance._run_finished_callback(error=error)
+        self.task_instance._run_finished_callback(error=error)  # pylint: disable=protected-access
         if not self.task_instance.test_mode:
+            if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True):
+                self._run_mini_scheduler_on_child_tasks()
             self._update_dagrun_state_for_paused_dag()
 
     def on_kill(self):
@@ -215,6 +220,57 @@ class LocalTaskJob(BaseJob):
             self.terminating = True
 
     @provide_session
+    @Sentry.enrich_errors
+    def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
+        try:
+            # Re-select the row with a lock
+            dag_run = with_row_locks(
+                session.query(DagRun).filter_by(
+                    dag_id=self.dag_id,
+                    execution_date=self.task_instance.execution_date,
+                ),
+                session=session,
+            ).one()
+
+            # Get a partial dag with just the specific tasks we want to
+            # examine. In order for dep checks to work correctly, we
+            # include ourself (so TriggerRuleDep can check the state of the
+            # task we just executed)
+            task = self.task_instance.task
+
+            partial_dag = task.dag.partial_subset(
+                task.downstream_task_ids,
+                include_downstream=False,
+                include_upstream=False,
+                include_direct_upstream=True,
+            )
+
+            dag_run.dag = partial_dag
+            info = dag_run.task_instance_scheduling_decisions(session)
+
+            skippable_task_ids = {
+                task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
+            }
+
+            schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
+            for schedulable_ti in schedulable_tis:
+                if not hasattr(schedulable_ti, "task"):
+                    schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)
+
+            num = dag_run.schedule_tis(schedulable_tis)
+            self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)
+
+            session.commit()
+        except OperationalError as e:
+            # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
+            self.log.info(
+                "Skipping mini scheduling run due to exception: %s",
+                e.statement,
+                exc_info=True,
+            )
+            session.rollback()
+
+    @provide_session
     def _update_dagrun_state_for_paused_dag(self, session=None):
         """
         Checks for paused dags with DagRuns in the running state and
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index b99fa34..5fb8155 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -35,7 +35,6 @@ import lazy_object_proxy
 import pendulum
 from jinja2 import TemplateAssertionError, UndefinedError
 from sqlalchemy import Column, Float, Index, Integer, PickleType, String, and_, func, or_
-from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import reconstructor, relationship
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.elements import BooleanClauseList
@@ -70,7 +69,7 @@ from airflow.utils.net import get_hostname
 from airflow.utils.operator_helpers import context_to_airflow_vars
 from airflow.utils.platform import getuser
 from airflow.utils.session import provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
+from airflow.utils.sqlalchemy import UtcDateTime
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
 
@@ -1223,62 +1222,6 @@ class TaskInstance(Base, LoggingMixin):
 
         session.commit()
 
-        if not test_mode:
-            self._run_mini_scheduler_on_child_tasks(session)
-
-    @provide_session
-    @Sentry.enrich_errors
-    def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
-        if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True):
-            from airflow.models.dagrun import DagRun  # Avoid circular import
-
-            try:
-                # Re-select the row with a lock
-                dag_run = with_row_locks(
-                    session.query(DagRun).filter_by(
-                        dag_id=self.dag_id,
-                        execution_date=self.execution_date,
-                    ),
-                    session=session,
-                ).one()
-
-                # Get a partial dag with just the specific tasks we want to
-                # examine. In order for dep checks to work correctly, we
-                # include ourself (so TriggerRuleDep can check the state of the
-                # task we just executed)
-                partial_dag = self.task.dag.partial_subset(
-                    self.task.downstream_task_ids,
-                    include_downstream=False,
-                    include_upstream=False,
-                    include_direct_upstream=True,
-                )
-
-                dag_run.dag = partial_dag
-                info = dag_run.task_instance_scheduling_decisions(session)
-
-                skippable_task_ids = {
-                    task_id
-                    for task_id in partial_dag.task_ids
-                    if task_id not in self.task.downstream_task_ids
-                }
-
-                schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
-                for schedulable_ti in schedulable_tis:
-                    if not hasattr(schedulable_ti, "task"):
-                        schedulable_ti.task = self.task.dag.get_task(schedulable_ti.task_id)
-
-                num = dag_run.schedule_tis(schedulable_tis)
-                self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)
-
-                session.commit()
-            except OperationalError as e:
-                # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
-                self.log.info(
-                    f"Skipping mini scheduling run due to exception: {e.statement}",
-                    exc_info=True,
-                )
-                session.rollback()
-
     def _prepare_and_execute_task_with_callbacks(self, context, task):
         """Prepare Task for Execution"""
         from airflow.models.renderedtifields import RenderedTaskInstanceFields
@@ -1440,6 +1383,7 @@ class TaskInstance(Base, LoggingMixin):
             session=session,
         )
         if not res:
+            self.log.info("CHECK AND CHANGE")
             return
 
         try:
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index f50ddbc..2b93e6d 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -71,8 +71,7 @@ class TestCliTasks(unittest.TestCase):
         args = self.parser.parse_args(['tasks', 'list', 'example_bash_operator', '--tree'])
         task_command.task_list(args)
 
-    @mock.patch("airflow.models.taskinstance.TaskInstance._run_mini_scheduler_on_child_tasks")
-    def test_test(self, mock_run_mini_scheduler):
+    def test_test(self):
         """Test the `airflow test` command"""
         args = self.parser.parse_args(
             ["tasks", "test", "example_python_operator", 'print_the_context', '2018-01-01']
@@ -81,7 +80,6 @@ class TestCliTasks(unittest.TestCase):
         with redirect_stdout(io.StringIO()) as stdout:
             task_command.task_test(args)
 
-        mock_run_mini_scheduler.assert_not_called()
         # Check that prints, and log messages, are shown
         assert "'example_python_operator__print_the_context__20180101'" in stdout.getvalue()
 
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index ed43198..11e9adf 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -33,6 +33,7 @@ from airflow import settings
 from airflow.exceptions import AirflowException, AirflowFailException
 from airflow.executors.sequential_executor import SequentialExecutor
 from airflow.jobs.local_task_job import LocalTaskJob
+from airflow.jobs.scheduler_job import SchedulerJob
 from airflow.models.dag import DAG, DagModel
 from airflow.models.dagbag import DagBag
 from airflow.models.taskinstance import TaskInstance
@@ -45,8 +46,9 @@ from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
 from airflow.utils.types import DagRunType
+from tests.test_utils import db
 from tests.test_utils.asserts import assert_queries_count
-from tests.test_utils.db import clear_db_jobs, clear_db_runs
+from tests.test_utils.config import conf_vars
 from tests.test_utils.mock_executor import MockExecutor
 
 # pylint: skip-file
@@ -57,15 +59,25 @@ TEST_DAG_FOLDER = os.environ['AIRFLOW__CORE__DAGS_FOLDER']
 
 class TestLocalTaskJob(unittest.TestCase):
     def setUp(self):
-        clear_db_jobs()
-        clear_db_runs()
+        db.clear_db_dags()
+        db.clear_db_jobs()
+        db.clear_db_runs()
+        db.clear_db_task_fail()
         patcher = patch('airflow.jobs.base_job.sleep')
         self.addCleanup(patcher.stop)
         self.mock_base_job_sleep = patcher.start()
 
     def tearDown(self) -> None:
-        clear_db_jobs()
-        clear_db_runs()
+        db.clear_db_dags()
+        db.clear_db_jobs()
+        db.clear_db_runs()
+        db.clear_db_task_fail()
+
+    def validate_ti_states(self, dag_run, ti_state_mapping, error_message):
+        for task_id, expected_state in ti_state_mapping.items():
+            task_instance = dag_run.get_task_instance(task_id=task_id)
+            task_instance.refresh_from_db()
+            assert task_instance.state == expected_state, error_message
 
     def test_localtaskjob_essential_attr(self):
         """
@@ -660,14 +672,116 @@ class TestLocalTaskJob(unittest.TestCase):
             if ti.state == State.RUNNING and ti.pid is not None:
                 break
             time.sleep(0.2)
-        assert ti.state == State.RUNNING
         assert ti.pid is not None
+        assert ti.state == State.RUNNING
         os.kill(ti.pid, signal_type)
         process.join(timeout=10)
         assert failure_callback_called.value == 1
         assert task_terminated_externally.value == 1
         assert not process.is_alive()
 
+    @parameterized.expand(
+        [
+            (
+                {('scheduler', 'schedule_after_task_execution'): 'True'},
+                {'A': 'B', 'B': 'C'},
+                {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
+                {'A': State.SUCCESS, 'B': State.SCHEDULED, 'C': State.NONE},
+                {'A': State.SUCCESS, 'B': State.SUCCESS, 'C': State.SCHEDULED},
+                "A -> B -> C, with fast-follow ON when A runs, B should be QUEUED. Same for B and C.",
+            ),
+            (
+                {('scheduler', 'schedule_after_task_execution'): 'False'},
+                {'A': 'B', 'B': 'C'},
+                {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
+                {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE},
+                None,
+                "A -> B -> C, with fast-follow OFF, when A runs, B shouldn't be QUEUED.",
+            ),
+            (
+                {('scheduler', 'schedule_after_task_execution'): 'True'},
+                {'A': 'B', 'C': 'B', 'D': 'C'},
+                {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
+                {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
+                None,
+                "D -> C -> B & A -> B, when A runs but C isn't QUEUED yet, B shouldn't be QUEUED.",
+            ),
+            (
+                {('scheduler', 'schedule_after_task_execution'): 'True'},
+                {'A': 'C', 'B': 'C'},
+                {'A': State.QUEUED, 'B': State.FAILED, 'C': State.NONE},
+                {'A': State.SUCCESS, 'B': State.FAILED, 'C': State.UPSTREAM_FAILED},
+                None,
+                "A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED.",
+            ),
+        ]
+    )
+    def test_fast_follow(
+        self, conf, dependencies, init_state, first_run_state, second_run_state, error_message
+    ):
+        # pylint: disable=too-many-locals
+        with conf_vars(conf):
+            session = settings.Session()
+
+            dag = DAG('test_dagrun_fast_follow', start_date=DEFAULT_DATE)
+
+            dag_model = DagModel(
+                dag_id=dag.dag_id,
+                next_dagrun=dag.start_date,
+                is_active=True,
+            )
+            session.add(dag_model)
+            session.flush()
+
+            python_callable = lambda: True
+            with dag:
+                task_a = PythonOperator(task_id='A', python_callable=python_callable)
+                task_b = PythonOperator(task_id='B', python_callable=python_callable)
+                task_c = PythonOperator(task_id='C', python_callable=python_callable)
+                if 'D' in init_state:
+                    task_d = PythonOperator(task_id='D', python_callable=python_callable)
+                for upstream, downstream in dependencies.items():
+                    dag.set_dependency(upstream, downstream)
+
+            scheduler_job = SchedulerJob(subdir=os.devnull)
+            scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+
+            dag_run = dag.create_dagrun(run_id='test_dagrun_fast_follow', state=State.RUNNING)
+
+            task_instance_a = TaskInstance(task_a, dag_run.execution_date, init_state['A'])
+
+            task_instance_b = TaskInstance(task_b, dag_run.execution_date, init_state['B'])
+
+            task_instance_c = TaskInstance(task_c, dag_run.execution_date, init_state['C'])
+
+            if 'D' in init_state:
+                task_instance_d = TaskInstance(task_d, dag_run.execution_date, init_state['D'])
+                session.merge(task_instance_d)
+
+            session.merge(task_instance_a)
+            session.merge(task_instance_b)
+            session.merge(task_instance_c)
+            session.flush()
+
+            job1 = LocalTaskJob(
+                task_instance=task_instance_a, ignore_ti_state=True, executor=SequentialExecutor()
+            )
+            job1.task_runner = StandardTaskRunner(job1)
+
+            job2 = LocalTaskJob(
+                task_instance=task_instance_b, ignore_ti_state=True, executor=SequentialExecutor()
+            )
+            job2.task_runner = StandardTaskRunner(job2)
+
+            settings.engine.dispose()
+            job1.run()
+            self.validate_ti_states(dag_run, first_run_state, error_message)
+            if second_run_state:
+                job2.run()
+                self.validate_ti_states(dag_run, second_run_state, error_message)
+            if scheduler_job.processor_agent:
+                scheduler_job.processor_agent.end()
+
     def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self):
         """Test that with DAG paused, DagRun state will update when the tasks finishes the run"""
         dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE)
@@ -709,8 +823,8 @@ class TestLocalTaskJob(unittest.TestCase):
 @pytest.fixture()
 def clean_db_helper():
     yield
-    clear_db_jobs()
-    clear_db_runs()
+    db.clear_db_jobs()
+    db.clear_db_runs()
 
 
 @pytest.mark.usefixtures("clean_db_helper")
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 021809b..c1882e1 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -34,10 +34,8 @@ from sqlalchemy.orm.session import Session
 
 from airflow import models, settings
 from airflow.exceptions import AirflowException, AirflowFailException, AirflowSkipException
-from airflow.jobs.scheduler_job import SchedulerJob
 from airflow.models import (
     DAG,
-    DagModel,
     DagRun,
     Pool,
     RenderedTaskInstanceFields,
@@ -1963,107 +1961,6 @@ class TestTaskInstance(unittest.TestCase):
         with create_session() as session:
             session.query(RenderedTaskInstanceFields).delete()
 
-    def validate_ti_states(self, dag_run, ti_state_mapping, error_message):
-        for task_id, expected_state in ti_state_mapping.items():
-            task_instance = dag_run.get_task_instance(task_id=task_id)
-            assert task_instance.state == expected_state, error_message
-
-    @parameterized.expand(
-        [
-            (
-                {('scheduler', 'schedule_after_task_execution'): 'True'},
-                {'A': 'B', 'B': 'C'},
-                {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
-                {'A': State.SUCCESS, 'B': State.SCHEDULED, 'C': State.NONE},
-                {'A': State.SUCCESS, 'B': State.SUCCESS, 'C': State.SCHEDULED},
-                "A -> B -> C, with fast-follow ON when A runs, B should be QUEUED. Same for B and C.",
-            ),
-            (
-                {('scheduler', 'schedule_after_task_execution'): 'False'},
-                {'A': 'B', 'B': 'C'},
-                {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
-                {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE},
-                None,
-                "A -> B -> C, with fast-follow OFF, when A runs, B shouldn't be QUEUED.",
-            ),
-            (
-                {('scheduler', 'schedule_after_task_execution'): 'True'},
-                {'A': 'B', 'C': 'B', 'D': 'C'},
-                {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
-                {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
-                None,
-                "D -> C -> B & A -> B, when A runs but C isn't QUEUED yet, B shouldn't be QUEUED.",
-            ),
-            (
-                {('scheduler', 'schedule_after_task_execution'): 'True'},
-                {'A': 'C', 'B': 'C'},
-                {'A': State.QUEUED, 'B': State.FAILED, 'C': State.NONE},
-                {'A': State.SUCCESS, 'B': State.FAILED, 'C': State.UPSTREAM_FAILED},
-                None,
-                "A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED.",
-            ),
-        ]
-    )
-    def test_fast_follow(
-        self, conf, dependencies, init_state, first_run_state, second_run_state, error_message
-    ):
-        with conf_vars(conf):
-            session = settings.Session()
-
-            dag = DAG('test_dagrun_fast_follow', start_date=DEFAULT_DATE)
-
-            dag_model = DagModel(
-                dag_id=dag.dag_id,
-                next_dagrun=dag.start_date,
-                is_active=True,
-            )
-            session.add(dag_model)
-            session.flush()
-
-            python_callable = lambda: True
-            with dag:
-                task_a = PythonOperator(task_id='A', python_callable=python_callable)
-                task_b = PythonOperator(task_id='B', python_callable=python_callable)
-                task_c = PythonOperator(task_id='C', python_callable=python_callable)
-                if 'D' in init_state:
-                    task_d = PythonOperator(task_id='D', python_callable=python_callable)
-                for upstream, downstream in dependencies.items():
-                    dag.set_dependency(upstream, downstream)
-
-            scheduler_job = SchedulerJob(subdir=os.devnull)
-            scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
-
-            dag_run = dag.create_dagrun(run_id='test_dagrun_fast_follow', state=State.RUNNING)
-
-            task_instance_a = dag_run.get_task_instance(task_id=task_a.task_id)
-            task_instance_a.task = task_a
-            task_instance_a.set_state(init_state['A'])
-
-            task_instance_b = dag_run.get_task_instance(task_id=task_b.task_id)
-            task_instance_b.task = task_b
-            task_instance_b.set_state(init_state['B'])
-
-            task_instance_c = dag_run.get_task_instance(task_id=task_c.task_id)
-            task_instance_c.task = task_c
-            task_instance_c.set_state(init_state['C'])
-
-            if 'D' in init_state:
-                task_instance_d = dag_run.get_task_instance(task_id=task_d.task_id)
-                task_instance_d.task = task_d
-                task_instance_d.state = init_state['D']
-
-            session.commit()
-            task_instance_a.run()
-
-            self.validate_ti_states(dag_run, first_run_state, error_message)
-
-            if second_run_state:
-                scheduler_job._critical_section_execute_task_instances(session=session)
-                task_instance_b.run()
-                self.validate_ti_states(dag_run, second_run_state, error_message)
-            if scheduler_job.processor_agent:
-                scheduler_job.processor_agent.end()
-
     def test_set_state_up_for_retry(self):
         dag = DAG('dag', start_date=DEFAULT_DATE)
         op1 = DummyOperator(task_id='op_1', owner='test', dag=dag)