Posted to commits@airflow.apache.org by sa...@apache.org on 2017/07/14 23:38:29 UTC

incubator-airflow git commit: [AIRFLOW-1345] Dont expire TIs on each scheduler loop

Repository: incubator-airflow
Updated Branches:
  refs/heads/master e05d3b4df -> 0dd00291d


[AIRFLOW-1345] Dont expire TIs on each scheduler loop

TIs get expired on commit, so any access to
their properties issues a new query to the DB,
causing an n+1 query issue even when the TI is
not scheduled. This change batches all queries,
which will make scheduling substantially faster.

Closes #2397 from saguziel/aguziel-commit-last
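
For context on the n+1 issue: with SQLAlchemy's default
expire_on_commit=True, commit() expires every instance loaded in the
session, and each later attribute access silently refreshes that one
row. A minimal standalone sketch of the behavior (the Task model and
names here are illustrative, not Airflow's):

    # Sketch of expire-on-commit causing n+1 queries; illustrative model.
    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker

    Base = declarative_base()

    class Task(Base):
        __tablename__ = 'task'
        id = Column(Integer, primary_key=True)
        state = Column(String(20))

    engine = create_engine('sqlite://', echo=True)  # echo shows each query
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()
    session.add_all([Task(state='scheduled') for _ in range(3)])
    session.commit()

    tasks = session.query(Task).all()  # one batched SELECT
    session.commit()                   # expires all three instances

    for t in tasks:
        # each access below emits its own SELECT: 1 + N queries total;
        # the patch avoids this by committing once, then re-querying
        # the whole batch in a single query
        print(t.state)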


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/0dd00291
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/0dd00291
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/0dd00291

Branch: refs/heads/master
Commit: 0dd00291d74e10a30ed328c8542416b78e24bc06
Parents: e05d3b4
Author: Alex Guziel <al...@airbnb.com>
Authored: Fri Jul 14 16:38:25 2017 -0700
Committer: Alex Guziel <al...@airbnb.com>
Committed: Fri Jul 14 16:38:25 2017 -0700

----------------------------------------------------------------------
 airflow/jobs.py | 243 ++++++++++++++++++++++++++++++++-------------
 tests/jobs.py   | 271 +++++++++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 441 insertions(+), 73 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0dd00291/airflow/jobs.py
----------------------------------------------------------------------
diff --git a/airflow/jobs.py b/airflow/jobs.py
index e8431b7..6b63df0 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -1016,13 +1016,10 @@ class SchedulerJob(BaseJob):
                 tis_changed, new_state))
 
     @provide_session
-    def _execute_task_instances(self,
-                                simple_dag_bag,
-                                states,
-                                session=None):
+    def _find_executable_task_instances(self, simple_dag_bag, states, session=None):
         """
-        Fetches task instances from ORM in the specified states, figures
-        out pool limits, and sends them to the executor for execution.
+        Finds TIs that are ready for execution with respect to pool limits,
+        dag concurrency, executor state, and priority.
 
         :param simple_dag_bag: TaskInstances associated with DAGs in the
         simple_dag_bag will be fetched from the DB and executed
@@ -1031,18 +1028,20 @@ class SchedulerJob(BaseJob):
         :type executor: BaseExecutor
         :param states: Execute TaskInstances in these states
         :type states: Tuple[State]
-        :return: None
+        :return: List[TaskInstance]
         """
+        executable_tis = []
+
         # Get all the queued task instances associated with scheduled
-        # DagRuns.
+        # DagRuns which are not backfilled, in the given states,
+        # and whose dag is not paused
         TI = models.TaskInstance
         DR = models.DagRun
         DM = models.DagModel
-        task_instances_to_examine = (
+        ti_query = (
             session
             .query(TI)
             .filter(TI.dag_id.in_(simple_dag_bag.dag_ids))
-            .filter(TI.state.in_(states))
             .outerjoin(DR,
                 and_(DR.dag_id == TI.dag_id,
                      DR.execution_date == TI.execution_date))
@@ -1051,14 +1050,19 @@ class SchedulerJob(BaseJob):
             .outerjoin(DM, DM.dag_id==TI.dag_id)
             .filter(or_(DM.dag_id == None,
                     not_(DM.is_paused)))
-            .all()
         )
+        if None in states:
+            ti_query = ti_query.filter(or_(TI.state == None, TI.state.in_(states)))
+        else:
+            ti_query = ti_query.filter(TI.state.in_(states))
+
+        task_instances_to_examine = ti_query.all()
 
-        # Put one task instance on each line
         if len(task_instances_to_examine) == 0:
-            self.logger.info("No tasks to send to the executor")
-            return
+            self.logger.info("No tasks to consider for execution.")
+            return executable_tis
 
+        # Put one task instance on each line
         task_instance_str = "\n\t".join(
             ["{}".format(x) for x in task_instances_to_examine])
         self.logger.info("Tasks up for execution:\n\t{}".format(task_instance_str))
@@ -1130,63 +1134,170 @@ class SchedulerJob(BaseJob):
 
 
                 if self.executor.has_task(task_instance):
-                    self.logger.debug("Not handling task {} as the executor reports it is running"
+                    self.logger.debug(("Not handling task {} as the executor " +
+                                      "reports it is running")
                                       .format(task_instance.key))
                     continue
+                executable_tis.append(task_instance)
+                open_slots -= 1
+                dag_id_to_possibly_running_task_count[dag_id] += 1
 
-                command = " ".join(TI.generate_command(
-                    task_instance.dag_id,
-                    task_instance.task_id,
-                    task_instance.execution_date,
-                    local=True,
-                    mark_success=False,
-                    ignore_all_deps=False,
-                    ignore_depends_on_past=False,
-                    ignore_task_deps=False,
-                    ignore_ti_state=False,
-                    pool=task_instance.pool,
-                    file_path=simple_dag_bag.get_dag(task_instance.dag_id).full_filepath,
-                    pickle_id=simple_dag_bag.get_dag(task_instance.dag_id).pickle_id))
-
-                priority = task_instance.priority_weight
-                queue = task_instance.queue
-                self.logger.info("Sending to executor {} with priority {} and queue {}"
-                                 .format(task_instance.key, priority, queue))
-
-                # Set the state to queued
-                task_instance.refresh_from_db(lock_for_update=True, session=session)
-                if task_instance.state not in states:
-                    self.logger.info("Task {} was set to {} outside this scheduler."
-                                     .format(task_instance.key, task_instance.state))
-                    session.commit()
-                    continue
+        task_instance_str = "\n\t".join(
+            ["{}".format(x) for x in executable_tis])
+        self.logger.info("Setting the follow tasks to queued state:\n\t{}"
+                         .format(task_instance_str))
+        return executable_tis
 
-                self.logger.info("Setting state of {} to {}".format(
-                    task_instance.key, State.QUEUED))
-                task_instance.state = State.QUEUED
-                task_instance.queued_dttm = (datetime.now()
-                                             if not task_instance.queued_dttm
-                                             else task_instance.queued_dttm)
-                session.merge(task_instance)
-                session.commit()
-
-                # These attributes will be lost after the object expires, so save them.
-                task_id_ = task_instance.task_id
-                dag_id_ = task_instance.dag_id
-                execution_date_ = task_instance.execution_date
-                make_transient(task_instance)
-                task_instance.task_id = task_id_
-                task_instance.dag_id = dag_id_
-                task_instance.execution_date = execution_date_
-
-                self.executor.queue_command(
-                    task_instance,
-                    command,
-                    priority=priority,
-                    queue=queue)
+    @provide_session
+    def _change_state_for_executable_task_instances(self, task_instances,
+                                                    acceptable_states, session=None):
+        """
+        Changes the state of task instances in the list with one of the given states
+        to QUEUED atomically, and returns the TIs changed.
+
+        :param task_instances: TaskInstances to change the state of
+        :type task_instances: List[TaskInstance]
+        :param acceptable_states: Filters the TaskInstances updated to be in these states
+        :type acceptable_states: Iterable[State]
+        :return: List[TaskInstance]
+        """
+        if len(task_instances) == 0:
+            session.commit()
+            return []
 
-                open_slots -= 1
-                dag_id_to_possibly_running_task_count[dag_id] += 1
+        TI = models.TaskInstance
+        filter_for_ti_state_change = (
+            [and_(
+                TI.dag_id == ti.dag_id,
+                TI.task_id == ti.task_id,
+                TI.execution_date == ti.execution_date)
+                for ti in task_instances])
+        ti_query = (
+            session
+            .query(TI)
+            .filter(or_(*filter_for_ti_state_change)))
+
+        if None in acceptable_states:
+            ti_query = ti_query.filter(or_(TI.state == None, TI.state.in_(acceptable_states)))
+        else:
+            ti_query = ti_query.filter(TI.state.in_(acceptable_states))
+
+        tis_to_set_to_queued = (
+            ti_query
+            .with_for_update()
+            .all())
+        if len(tis_to_set_to_queued) == 0:
+            self.logger.info("No tasks were able to have their state changed to queued.")
+            session.commit()
+            return []
+
+        # set TIs to queued state
+        for task_instance in tis_to_set_to_queued:
+            task_instance.state = State.QUEUED
+            task_instance.queued_dttm = (datetime.now()
+                                         if not task_instance.queued_dttm
+                                         else task_instance.queued_dttm)
+            session.merge(task_instance)
+
+        # save which TIs we set before session expires them
+        filter_for_ti_enqueue = ([and_(TI.dag_id == ti.dag_id,
+                                       TI.task_id == ti.task_id,
+                                       TI.execution_date == ti.execution_date)
+                                  for ti in tis_to_set_to_queued])
+        session.commit()
+
+        # re-query in batch since the objects above were expired by the commit
+        tis_to_be_queued = (
+            session
+            .query(TI)
+            .filter(or_(*filter_for_ti_enqueue))
+            .all())
+
+        task_instance_str = "\n\t".join(
+            ["{}".format(x) for x in tis_to_be_queued])
+        self.logger.info("Setting the follow tasks to queued state:\n\t{}"
+                         .format(task_instance_str))
+        return tis_to_be_queued
+
+    def _enqueue_task_instances_with_queued_state(self, simple_dag_bag, task_instances):
+        """
+        Takes task_instances, which should have been set to queued, and enqueues them
+        with the executor.
+
+        :param task_instances: TaskInstances to enqueue
+        :type task_instances: List[TaskInstance]
+        :param simple_dag_bag: Should contain all of the task_instances' dags
+        :type simple_dag_bag: SimpleDagBag
+        """
+        TI = models.TaskInstance
+        # actually enqueue them
+        for task_instance in task_instances:
+            command = " ".join(TI.generate_command(
+                task_instance.dag_id,
+                task_instance.task_id,
+                task_instance.execution_date,
+                local=True,
+                mark_success=False,
+                ignore_all_deps=False,
+                ignore_depends_on_past=False,
+                ignore_task_deps=False,
+                ignore_ti_state=False,
+                pool=task_instance.pool,
+                file_path=simple_dag_bag.get_dag(task_instance.dag_id).full_filepath,
+                pickle_id=simple_dag_bag.get_dag(task_instance.dag_id).pickle_id))
+
+            priority = task_instance.priority_weight
+            queue = task_instance.queue
+            self.logger.info("Sending {} to executor with priority {} and queue {}"
+                             .format(task_instance.key, priority, queue))
+
+            # save attributes so sqlalchemy doesn't expire them
+            copy_dag_id = task_instance.dag_id
+            copy_task_id = task_instance.task_id
+            copy_execution_date = task_instance.execution_date
+            make_transient(task_instance)
+            task_instance.dag_id = copy_dag_id
+            task_instance.task_id = copy_task_id
+            task_instance.execution_date = copy_execution_date
+
+            self.executor.queue_command(
+                task_instance,
+                command,
+                priority=priority,
+                queue=queue)
+
+    @provide_session
+    def _execute_task_instances(self,
+                                simple_dag_bag,
+                                states,
+                                session=None):
+        """
+        Attempts to execute TaskInstances that should be executed by the scheduler.
+
+        There are three steps:
+        1. Pick TIs by priority with the constraint that they are in the expected states
+        and that we do not exceed max_active_runs or pool limits.
+        2. Change the state for the TIs above atomically.
+        3. Enqueue the TIs in the executor.
+
+        :param simple_dag_bag: TaskInstances associated with DAGs in the
+        simple_dag_bag will be fetched from the DB and executed
+        :type simple_dag_bag: SimpleDagBag
+        :param states: Execute TaskInstances in these states
+        :type states: Tuple[State]
+        :return: None
+        """
+        executable_tis = self._find_executable_task_instances(simple_dag_bag, states,
+                                                              session=session)
+        tis_with_state_changed = self._change_state_for_executable_task_instances(
+            executable_tis,
+            states,
+            session=session)
+        self._enqueue_task_instances_with_queued_state(
+            simple_dag_bag,
+            tis_with_state_changed)
+        session.commit()
+        return len(tis_with_state_changed)
 
     def _process_dags(self, dagbag, dags, tis_out):
         """

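The atomic hand-off in _change_state_for_executable_task_instances above
boils down to one batched SELECT ... FOR UPDATE, one commit, and one
batched re-query. A simplified sketch of that pattern, assuming a generic
SQLAlchemy-mapped Task with a single-column key (the real code filters on
the (dag_id, task_id, execution_date) triple instead):

    # Simplified sketch of the locked batch-update; Task, id, and state
    # are illustrative stand-ins, not Airflow's actual schema.
    def change_state_batch(session, Task, candidate_ids,
                           acceptable_states, new_state):
        locked = (session.query(Task)
                  .filter(Task.id.in_(candidate_ids))
                  .filter(Task.state.in_(acceptable_states))
                  # SELECT ... FOR UPDATE: a concurrent scheduler cannot
                  # grab the same rows until this transaction commits
                  .with_for_update()
                  .all())
        for row in locked:
            row.state = new_state
        # save the keys before commit() expires the loaded instances
        changed_ids = [row.id for row in locked]
        session.commit()
        # one batched re-query instead of N lazy per-attribute refreshes
        return (session.query(Task)
                .filter(Task.id.in_(changed_ids))
                .all())
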
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0dd00291/tests/jobs.py
----------------------------------------------------------------------
diff --git a/tests/jobs.py b/tests/jobs.py
index 13bd9f5..e987e0c 100644
--- a/tests/jobs.py
+++ b/tests/jobs.py
@@ -28,7 +28,7 @@ from tempfile import mkdtemp
 
 from airflow import AirflowException, settings, models
 from airflow.bin import cli
-from airflow.executors import SequentialExecutor
+from airflow.executors import BaseExecutor, SequentialExecutor
 from airflow.jobs import BackfillJob, SchedulerJob, LocalTaskJob
 from airflow.models import DAG, DagModel, DagBag, DagRun, Pool, TaskInstance as TI
 from airflow.operators.dummy_operator import DummyOperator
@@ -38,7 +38,7 @@ from airflow.utils.state import State
 from airflow.utils.timeout import timeout
 from airflow.utils.dag_processing import SimpleDagBag, list_py_file_paths
 
-from mock import patch
+from mock import Mock, patch
 from sqlalchemy.orm.session import make_transient
 from tests.executors.test_executor import TestExecutor
 
@@ -718,8 +718,266 @@ class SchedulerJobTest(unittest.TestCase):
         ti1.refresh_from_db()
         self.assertEquals(State.SCHEDULED, ti1.state)
 
-    def test_concurrency(self):
-        dag_id = 'SchedulerJobTest.test_concurrency'
+    def test_find_executable_task_instances_backfill_nodagrun(self):
+        dag_id = 'SchedulerJobTest.test_find_executable_task_instances_backfill_nodagrun'
+        task_id_1 = 'dummy'
+        dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
+        task1 = DummyOperator(dag=dag, task_id=task_id_1)
+        dagbag = SimpleDagBag([dag])
+
+        scheduler = SchedulerJob(**self.default_scheduler_args)
+        session = settings.Session()
+
+        dr1 = scheduler.create_dag_run(dag)
+        dr2 = scheduler.create_dag_run(dag)
+        dr2.run_id = BackfillJob.ID_PREFIX + 'asdf'
+
+        ti_no_dagrun = TI(task1, DEFAULT_DATE - datetime.timedelta(days=1))
+        ti_backfill = TI(task1, dr2.execution_date)
+        ti_with_dagrun = TI(task1, dr1.execution_date)
+        # ti_with_paused
+        ti_no_dagrun.state = State.SCHEDULED
+        ti_backfill.state = State.SCHEDULED
+        ti_with_dagrun.state = State.SCHEDULED
+
+        session.merge(dr2)
+        session.merge(ti_no_dagrun)
+        session.merge(ti_backfill)
+        session.merge(ti_with_dagrun)
+        session.commit()
+
+        res = scheduler._find_executable_task_instances(
+            dagbag,
+            states=[State.SCHEDULED],
+            session=session)
+
+        self.assertEqual(2, len(res))
+        res_keys = map(lambda x: x.key, res)
+        self.assertIn(ti_no_dagrun.key, res_keys)
+        self.assertIn(ti_with_dagrun.key, res_keys)
+
+    def test_find_executable_task_instances_pool(self):
+        dag_id = 'SchedulerJobTest.test_find_executable_task_instances_pool'
+        task_id_1 = 'dummy'
+        task_id_2 = 'dummydummy'
+        dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
+        task1 = DummyOperator(dag=dag, task_id=task_id_1, pool='a')
+        task2 = DummyOperator(dag=dag, task_id=task_id_2, pool='b')
+        dagbag = SimpleDagBag([dag])
+
+        scheduler = SchedulerJob(**self.default_scheduler_args)
+        session = settings.Session()
+
+        dr1 = scheduler.create_dag_run(dag)
+        dr2 = scheduler.create_dag_run(dag)
+
+        tis = ([
+            TI(task1, dr1.execution_date),
+            TI(task2, dr1.execution_date),
+            TI(task1, dr2.execution_date),
+            TI(task2, dr2.execution_date)
+            ])
+        for ti in tis:
+            ti.state = State.SCHEDULED
+            session.merge(ti)
+        pool = models.Pool(pool='a', slots=1, description='haha')
+        pool2 = models.Pool(pool='b', slots=100, description='haha')
+        session.add(pool)
+        session.add(pool2)
+        session.commit()
+
+        res = scheduler._find_executable_task_instances(
+            dagbag,
+            states=[State.SCHEDULED],
+            session=session)
+        session.commit()
+        self.assertEqual(3, len(res))
+        res_keys = []
+        for ti in res:
+            res_keys.append(ti.key)
+        self.assertIn(tis[0].key, res_keys)
+        self.assertIn(tis[1].key, res_keys)
+        self.assertIn(tis[3].key, res_keys)
+
+    def test_find_executable_task_instances_none(self):
+        dag_id = 'SchedulerJobTest.test_find_executable_task_instances_none'
+        task_id_1 = 'dummy'
+        dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
+        task1 = DummyOperator(dag=dag, task_id=task_id_1)
+        dagbag = SimpleDagBag([dag])
+
+        scheduler = SchedulerJob(**self.default_scheduler_args)
+        session = settings.Session()
+
+        dr1 = scheduler.create_dag_run(dag)
+        session.commit()
+
+        self.assertEqual(0, len(scheduler._find_executable_task_instances(
+            dagbag,
+            states=[State.SCHEDULED],
+            session=session)))
+
+    def test_find_executable_task_instances_concurrency(self):
+        dag_id = 'SchedulerJobTest.test_find_executable_task_instances_concurrency'
+        task_id_1 = 'dummy'
+        dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2)
+        task1 = DummyOperator(dag=dag, task_id=task_id_1)
+        dagbag = SimpleDagBag([dag])
+
+        scheduler = SchedulerJob(**self.default_scheduler_args)
+        session = settings.Session()
+
+        dr1 = scheduler.create_dag_run(dag)
+        dr2 = scheduler.create_dag_run(dag)
+        dr3 = scheduler.create_dag_run(dag)
+
+        ti1 = TI(task1, dr1.execution_date)
+        ti2 = TI(task1, dr2.execution_date)
+        ti3 = TI(task1, dr3.execution_date)
+        ti1.state = State.RUNNING
+        ti2.state = State.SCHEDULED
+        ti3.state = State.SCHEDULED
+        session.merge(ti1)
+        session.merge(ti2)
+        session.merge(ti3)
+
+        session.commit()
+
+        res = scheduler._find_executable_task_instances(
+            dagbag,
+            states=[State.SCHEDULED],
+            session=session)
+
+        self.assertEqual(1, len(res))
+        res_keys = map(lambda x: x.key, res)
+        self.assertIn(ti2.key, res_keys)
+
+        ti2.state = State.RUNNING
+        session.merge(ti2)
+        session.commit()
+
+        res = scheduler._find_executable_task_instances(
+            dagbag,
+            states=[State.SCHEDULED],
+            session=session)
+
+        self.assertEqual(0, len(res))
+
+    def test_change_state_for_executable_task_instances_no_tis(self):
+        scheduler = SchedulerJob(**self.default_scheduler_args)
+        session = settings.Session()
+        res = scheduler._change_state_for_executable_task_instances([], [State.NONE], session)
+        self.assertEqual(0, len(res))
+
+    def test_change_state_for_executable_task_instances_no_tis_with_state(self):
+        dag_id = 'SchedulerJobTest.test_change_state_for__no_tis_with_state'
+        task_id_1 = 'dummy'
+        dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2)
+        task1 = DummyOperator(dag=dag, task_id=task_id_1)
+        dagbag = SimpleDagBag([dag])
+
+        scheduler = SchedulerJob(**self.default_scheduler_args)
+        session = settings.Session()
+
+        dr1 = scheduler.create_dag_run(dag)
+        dr2 = scheduler.create_dag_run(dag)
+        dr3 = scheduler.create_dag_run(dag)
+
+        ti1 = TI(task1, dr1.execution_date)
+        ti2 = TI(task1, dr2.execution_date)
+        ti3 = TI(task1, dr3.execution_date)
+        ti1.state = State.SCHEDULED
+        ti2.state = State.SCHEDULED
+        ti3.state = State.SCHEDULED
+        session.merge(ti1)
+        session.merge(ti2)
+        session.merge(ti3)
+
+        session.commit()
+
+        res = scheduler._change_state_for_executable_task_instances(
+            [ti1, ti2, ti3],
+            [State.RUNNING],
+            session)
+        self.assertEqual(0, len(res))
+
+    def test_change_state_for_executable_task_instances_none_state(self):
+        dag_id = 'SchedulerJobTest.test_change_state_for__none_state'
+        task_id_1 = 'dummy'
+        dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2)
+        task1 = DummyOperator(dag=dag, task_id=task_id_1)
+        dagbag = SimpleDagBag([dag])
+
+        scheduler = SchedulerJob(**self.default_scheduler_args)
+        session = settings.Session()
+
+        dr1 = scheduler.create_dag_run(dag)
+        dr2 = scheduler.create_dag_run(dag)
+        dr3 = scheduler.create_dag_run(dag)
+
+        ti1 = TI(task1, dr1.execution_date)
+        ti2 = TI(task1, dr2.execution_date)
+        ti3 = TI(task1, dr3.execution_date)
+        ti1.state = State.SCHEDULED
+        ti2.state = State.QUEUED
+        ti3.state = State.NONE
+        session.merge(ti1)
+        session.merge(ti2)
+        session.merge(ti3)
+
+        session.commit()
+
+        res = scheduler._change_state_for_executable_task_instances(
+            [ti1, ti2, ti3],
+            [State.NONE, State.SCHEDULED],
+            session)
+        self.assertEqual(2, len(res))
+        ti1.refresh_from_db()
+        ti3.refresh_from_db()
+        self.assertEqual(State.QUEUED, ti1.state)
+        self.assertEqual(State.QUEUED, ti3.state)
+
+    def test_enqueue_task_instances_with_queued_state(self):
+        dag_id = 'SchedulerJobTest.test_enqueue_task_instances_with_queued_state'
+        task_id_1 = 'dummy'
+        dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)
+        task1 = DummyOperator(dag=dag, task_id=task_id_1)
+        dagbag = SimpleDagBag([dag])
+
+        scheduler = SchedulerJob(**self.default_scheduler_args)
+        session = settings.Session()
+
+        dr1 = scheduler.create_dag_run(dag)
+
+        ti1 = TI(task1, dr1.execution_date)
+        session.merge(ti1)
+        session.commit()
+
+        with patch.object(BaseExecutor, 'queue_command') as mock_queue_command:
+            scheduler._enqueue_task_instances_with_queued_state(dagbag, [ti1])
+
+        mock_queue_command.assert_called()
+
+    def test_execute_task_instances_nothing(self):
+        dag_id = 'SchedulerJobTest.test_execute_task_instances_nothing'
+        task_id_1 = 'dummy'
+        dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2)
+        task1 = DummyOperator(dag=dag, task_id=task_id_1)
+        dagbag = SimpleDagBag([])
+
+        scheduler = SchedulerJob(**self.default_scheduler_args)
+        session = settings.Session()
+
+        dr1 = scheduler.create_dag_run(dag)
+        ti1 = TI(task1, dr1.execution_date)
+        ti1.state = State.SCHEDULED
+        session.merge(ti1)
+        session.commit()
+
+        self.assertEqual(0, scheduler._execute_task_instances(dagbag, states=[State.SCHEDULED]))
+
+    def test_execute_task_instances(self):
+        dag_id = 'SchedulerJobTest.test_execute_task_instances'
         task_id_1 = 'dummy_task'
         task_id_2 = 'dummy_task_nonexistent_queue'
         # important that len(tasks) is less than concurrency
@@ -765,7 +1023,7 @@ class SchedulerJobTest(unittest.TestCase):
 
         self.assertEqual(State.RUNNING, dr2.state)
 
-        scheduler._execute_task_instances(dagbag, [State.SCHEDULED])
+        res = scheduler._execute_task_instances(dagbag, [State.SCHEDULED])
 
         # check that concurrency is respected
         ti1.refresh_from_db()
@@ -777,8 +1035,7 @@ class SchedulerJobTest(unittest.TestCase):
         self.assertEqual(State.RUNNING, ti1.state)
         self.assertEqual(State.RUNNING, ti2.state)
         six.assertCountEqual(self, [State.QUEUED, State.SCHEDULED], [ti3.state, ti4.state])
-
-        session.close()
+        self.assertEqual(1, res)
 
     def test_change_state_for_tis_without_dagrun(self):
         dag = DAG(
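
A closing note on the make_transient() dance in
_enqueue_task_instances_with_queued_state: once detached, an instance's
expired attributes can no longer refresh from the DB, which is why the
key fields are copied out and reassigned. A minimal sketch (again with
an illustrative model, not Airflow's):

    # Sketch of keeping a detached instance readable after commit.
    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker
    from sqlalchemy.orm.session import make_transient

    Base = declarative_base()

    class Task(Base):
        __tablename__ = 'task'
        id = Column(Integer, primary_key=True)
        state = Column(String(20))

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()
    session.add(Task(id=1, state='queued'))
    session.commit()                 # persist one row

    ti = session.query(Task).one()
    saved_id = ti.id                 # copy while it can still refresh
    make_transient(ti)               # detach; no more lazy loads
    ti.id = saved_id                 # restore the copy by hand
    session.close()
    print(ti.id)                     # readable with no session, no query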