You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by bo...@apache.org on 2017/11/02 18:00:08 UTC

incubator-airflow git commit: [AIRFLOW-387] Close SQLAlchemy sessions properly

Repository: incubator-airflow
Updated Branches:
  refs/heads/master 1bde78338 -> 776527847


[AIRFLOW-387] Close SQLAlchemy sessions properly

This commit adopts the `provide_session` helper in
almost the entire
codebase. This ensures sessions are handled and
closed consistently.
In particular, this ensures we don't forget to
close and thus leak
database connections.

As an additional change, the `provide_session`
helper has been extended
to also rollback and close created connections
under error conditions.

As an additional helper, this commit also
introduces a contextmanager
that provides the same functionality as the
`provide_session`
decorator. This is helpful in cases where the
scope of a session should
be smaller than the entire method where it is
being used.

Closes #2739 from StephanErb/session_close


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/77652784
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/77652784
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/77652784

Branch: refs/heads/master
Commit: 7765278479b81a73244ff44e47973bbc5bf9bcef
Parents: 1bde783
Author: Stephan Erb <st...@blue-yonder.com>
Authored: Thu Nov 2 18:59:30 2017 +0100
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Thu Nov 2 18:59:34 2017 +0100

----------------------------------------------------------------------
 .../auth/backends/github_enterprise_auth.py     |  14 +-
 airflow/contrib/auth/backends/google_auth.py    |  14 +-
 airflow/contrib/auth/backends/kerberos_auth.py  |  13 +-
 airflow/contrib/auth/backends/ldap_auth.py      |  15 +-
 airflow/contrib/auth/backends/password_auth.py  |  13 +-
 airflow/default_login.py                        |  13 +-
 airflow/hooks/base_hook.py                      |   6 +-
 airflow/jobs.py                                 |  83 +++++------
 airflow/models.py                               |  68 ++++-----
 airflow/operators/dagrun_operator.py            |  24 +--
 airflow/operators/sensors.py                    |   6 +-
 airflow/utils/db.py                             |  40 +++--
 airflow/www/utils.py                            |   8 +-
 airflow/www/views.py                            | 149 ++++++++-----------
 14 files changed, 214 insertions(+), 252 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/77652784/airflow/contrib/auth/backends/github_enterprise_auth.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/auth/backends/github_enterprise_auth.py b/airflow/contrib/auth/backends/github_enterprise_auth.py
index 28c3cfc..e4665cb 100644
--- a/airflow/contrib/auth/backends/github_enterprise_auth.py
+++ b/airflow/contrib/auth/backends/github_enterprise_auth.py
@@ -27,6 +27,7 @@ from flask_oauthlib.client import OAuth
 
 from airflow import models, configuration, settings
 from airflow.configuration import AirflowConfigException
+from airflow.utils.db import provide_session
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 log = LoggingMixin().log
@@ -174,19 +175,17 @@ class GHEAuthBackend(object):
 
         return False
 
-    def load_user(self, userid):
+    @provide_session
+    def load_user(self, userid, session=None):
         if not userid or userid == 'None':
             return None
 
-        session = settings.Session()
         user = session.query(models.User).filter(
             models.User.id == int(userid)).first()
-        session.expunge_all()
-        session.commit()
-        session.close()
         return GHEUser(user)
 
-    def oauth_callback(self):
+    @provide_session
+    def oauth_callback(self, session=None):
         _log.debug('GHE OAuth callback called')
 
         next_url = request.args.get('next') or url_for('admin.index')
@@ -210,8 +209,6 @@ class GHEAuthBackend(object):
             _log.exception('')
             return redirect(url_for('airflow.noaccess'))
 
-        session = settings.Session()
-
         user = session.query(models.User).filter(
             models.User.username == username).first()
 
@@ -225,7 +222,6 @@ class GHEAuthBackend(object):
         session.commit()
         login_user(GHEUser(user))
         session.commit()
-        session.close()
 
         return redirect(next_url)
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/77652784/airflow/contrib/auth/backends/google_auth.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/auth/backends/google_auth.py b/airflow/contrib/auth/backends/google_auth.py
index e6eab94..7e8be97 100644
--- a/airflow/contrib/auth/backends/google_auth.py
+++ b/airflow/contrib/auth/backends/google_auth.py
@@ -26,6 +26,7 @@ from flask import url_for, redirect, request
 from flask_oauthlib.client import OAuth
 
 from airflow import models, configuration, settings
+from airflow.utils.db import provide_session
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 log = LoggingMixin().log
@@ -127,19 +128,17 @@ class GoogleAuthBackend(object):
             return True
         return False
 
-    def load_user(self, userid):
+    @provide_session
+    def load_user(self, userid, session=None):
         if not userid or userid == 'None':
             return None
 
-        session = settings.Session()
         user = session.query(models.User).filter(
             models.User.id == int(userid)).first()
-        session.expunge_all()
-        session.commit()
-        session.close()
         return GoogleUser(user)
 
-    def oauth_callback(self):
+    @provide_session
+    def oauth_callback(self, session=None):
         log.debug('Google OAuth callback called')
 
         next_url = request.args.get('next') or url_for('admin.index')
@@ -162,8 +161,6 @@ class GoogleAuthBackend(object):
         except AuthenticationError:
             return redirect(url_for('airflow.noaccess'))
 
-        session = settings.Session()
-
         user = session.query(models.User).filter(
             models.User.username == username).first()
 
@@ -177,7 +174,6 @@ class GoogleAuthBackend(object):
         session.commit()
         login_user(GoogleUser(user))
         session.commit()
-        session.close()
 
         return redirect(next_url)
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/77652784/airflow/contrib/auth/backends/kerberos_auth.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/auth/backends/kerberos_auth.py b/airflow/contrib/auth/backends/kerberos_auth.py
index 908ebc9..21e0ffb 100644
--- a/airflow/contrib/auth/backends/kerberos_auth.py
+++ b/airflow/contrib/auth/backends/kerberos_auth.py
@@ -29,6 +29,7 @@ from flask import url_for, redirect
 from airflow import settings
 from airflow import models
 from airflow import configuration
+from airflow.utils.db import provide_session
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 login_manager = flask_login.LoginManager()
@@ -86,19 +87,17 @@ class KerberosUser(models.User, LoggingMixin):
 
 
 @login_manager.user_loader
-def load_user(userid):
+@provide_session
+def load_user(userid, session=None):
     if not userid or userid == 'None':
         return None
 
-    session = settings.Session()
     user = session.query(models.User).filter(models.User.id == int(userid)).first()
-    session.expunge_all()
-    session.commit()
-    session.close()
     return KerberosUser(user)
 
 
-def login(self, request):
+@provide_session
+def login(self, request, session=None):
     if current_user.is_authenticated():
         flash("You are already logged in")
         return redirect(url_for('index'))
@@ -120,7 +119,6 @@ def login(self, request):
     try:
         KerberosUser.authenticate(username, password)
 
-        session = settings.Session()
         user = session.query(models.User).filter(
             models.User.username == username).first()
 
@@ -133,7 +131,6 @@ def login(self, request):
         session.commit()
         flask_login.login_user(KerberosUser(user))
         session.commit()
-        session.close()
 
         return redirect(request.args.get("next") or url_for("admin.index"))
     except AuthenticationError:

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/77652784/airflow/contrib/auth/backends/ldap_auth.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/auth/backends/ldap_auth.py b/airflow/contrib/auth/backends/ldap_auth.py
index 2dcacda..98c620e 100644
--- a/airflow/contrib/auth/backends/ldap_auth.py
+++ b/airflow/contrib/auth/backends/ldap_auth.py
@@ -25,10 +25,10 @@ import ssl
 
 from flask import url_for, redirect
 
-from airflow import settings
 from airflow import models
 from airflow import configuration
 from airflow.configuration import AirflowConfigException
+from airflow.utils.db import provide_session
 
 import traceback
 import re
@@ -250,20 +250,17 @@ class LdapUser(models.User):
 
 
 @login_manager.user_loader
-def load_user(userid):
+@provide_session
+def load_user(userid, session=None):
     log.debug("Loading user %s", userid)
     if not userid or userid == 'None':
         return None
 
-    session = settings.Session()
     user = session.query(models.User).filter(models.User.id == int(userid)).first()
-    session.expunge_all()
-    session.commit()
-    session.close()
     return LdapUser(user)
 
-
-def login(self, request):
+@provide_session
+def login(self, request, session=None):
     if current_user.is_authenticated():
         flash("You are already logged in")
         return redirect(url_for('admin.index'))
@@ -286,7 +283,6 @@ def login(self, request):
         LdapUser.try_login(username, password)
         log.info("User %s successfully authenticated", username)
 
-        session = settings.Session()
         user = session.query(models.User).filter(
             models.User.username == username).first()
 
@@ -299,7 +295,6 @@ def login(self, request):
         session.commit()
         flask_login.login_user(LdapUser(user))
         session.commit()
-        session.close()
 
         return redirect(request.args.get("next") or url_for("admin.index"))
     except (LdapException, AuthenticationError) as e:

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/77652784/airflow/contrib/auth/backends/password_auth.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/auth/backends/password_auth.py b/airflow/contrib/auth/backends/password_auth.py
index 8adb1f4..e380ec4 100644
--- a/airflow/contrib/auth/backends/password_auth.py
+++ b/airflow/contrib/auth/backends/password_auth.py
@@ -32,6 +32,7 @@ from sqlalchemy.ext.hybrid import hybrid_property
 
 from airflow import settings
 from airflow import models
+from airflow.utils.db import provide_session
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 login_manager = flask_login.LoginManager()
@@ -91,20 +92,18 @@ class PasswordUser(models.User):
 
 
 @login_manager.user_loader
-def load_user(userid):
+@provide_session
+def load_user(userid, session=None):
     log.debug("Loading user %s", userid)
     if not userid or userid == 'None':
         return None
 
-    session = settings.Session()
     user = session.query(models.User).filter(models.User.id == int(userid)).first()
-    session.expunge_all()
-    session.commit()
-    session.close()
     return PasswordUser(user)
 
 
-def login(self, request):
+@provide_session
+def login(self, request, session=None):
     if current_user.is_authenticated():
         flash("You are already logged in")
         return redirect(url_for('admin.index'))
@@ -124,7 +123,6 @@ def login(self, request):
                            form=form)
 
     try:
-        session = settings.Session()
         user = session.query(PasswordUser).filter(
             PasswordUser.username == username).first()
 
@@ -139,7 +137,6 @@ def login(self, request):
 
         flask_login.login_user(user)
         session.commit()
-        session.close()
 
         return redirect(request.args.get("next") or url_for("admin.index"))
     except AuthenticationError:

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/77652784/airflow/default_login.py
----------------------------------------------------------------------
diff --git a/airflow/default_login.py b/airflow/default_login.py
index 653a9c0..7184f42 100644
--- a/airflow/default_login.py
+++ b/airflow/default_login.py
@@ -26,6 +26,7 @@ from flask import url_for, redirect
 
 from airflow import settings
 from airflow import models
+from airflow.utils.db import provide_session
 
 DEFAULT_USERNAME = 'airflow'
 
@@ -63,17 +64,14 @@ class DefaultUser(object):
 
 
 @login_manager.user_loader
-def load_user(userid):
-    session = settings.Session()
+@provide_session
+def load_user(userid, session=None):
     user = session.query(models.User).filter(models.User.id == userid).first()
-    session.expunge_all()
-    session.commit()
-    session.close()
     return DefaultUser(user)
 
 
-def login(self, request):
-    session = settings.Session()
+@provide_session
+def login(self, request, session=None):
     user = session.query(models.User).filter(
         models.User.username == DEFAULT_USERNAME).first()
     if not user:
@@ -84,5 +82,4 @@ def login(self, request):
     session.commit()
     flask_login.login_user(DefaultUser(user))
     session.commit()
-    session.close()
     return redirect(request.args.get("next") or url_for("index"))

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/77652784/airflow/hooks/base_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/base_hook.py b/airflow/hooks/base_hook.py
index 92313ca..b8da61e 100644
--- a/airflow/hooks/base_hook.py
+++ b/airflow/hooks/base_hook.py
@@ -23,6 +23,7 @@ import random
 from airflow import settings
 from airflow.models import Connection
 from airflow.exceptions import AirflowException
+from airflow.utils.db import provide_session
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 CONN_ENV_PREFIX = 'AIRFLOW_CONN_'
@@ -41,15 +42,14 @@ class BaseHook(LoggingMixin):
 
 
     @classmethod
-    def _get_connections_from_db(cls, conn_id):
-        session = settings.Session()
+    @provide_session
+    def _get_connections_from_db(cls, conn_id, session=None):
         db = (
             session.query(Connection)
             .filter(Connection.conn_id == conn_id)
             .all()
         )
         session.expunge_all()
-        session.close()
         if not db:
             raise AirflowException(
                 "The conn_id `{0}` isn't defined".format(conn_id))

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/77652784/airflow/jobs.py
----------------------------------------------------------------------
diff --git a/airflow/jobs.py b/airflow/jobs.py
index c971c1c..664fab5 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -52,7 +52,8 @@ from airflow.utils.dag_processing import (AbstractDagFileProcessor,
                                           SimpleDag,
                                           SimpleDagBag,
                                           list_py_file_paths)
-from airflow.utils.db import provide_session, pessimistic_connection_handling
+from airflow.utils.db import (
+    create_session, provide_session, pessimistic_connection_handling)
 from airflow.utils.email import send_email
 from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter
 from airflow.utils.state import State
@@ -111,8 +112,8 @@ class BaseJob(Base, LoggingMixin):
             (conf.getint('scheduler', 'JOB_HEARTBEAT_SEC') * 2.1)
         )
 
-    def kill(self):
-        session = settings.Session()
+    @provide_session
+    def kill(self, session=None):
         job = session.query(BaseJob).filter(BaseJob.id == self.id).first()
         job.end_date = datetime.utcnow()
         try:
@@ -121,7 +122,6 @@ class BaseJob(Base, LoggingMixin):
             self.log.error('on_kill() method failed')
         session.merge(job)
         session.commit()
-        session.close()
         raise AirflowException("Job shut down externally.")
 
     def on_kill(self):
@@ -152,11 +152,10 @@ class BaseJob(Base, LoggingMixin):
         heart rate. If you go over 60 seconds before calling it, it won't
         sleep at all.
         '''
-        session = settings.Session()
-        job = session.query(BaseJob).filter_by(id=self.id).one()
-        make_transient(job)
-        session.commit()
-        session.close()
+        with create_session() as session:
+            job = session.query(BaseJob).filter_by(id=self.id).one()
+            make_transient(job)
+            session.commit()
 
         if job.state == State.SHUTDOWN:
             self.kill()
@@ -168,41 +167,37 @@ class BaseJob(Base, LoggingMixin):
                 0,
                 self.heartrate - (datetime.utcnow() - job.latest_heartbeat).total_seconds())
 
-        # Don't keep session open while sleeping as it leaves a connection open
-        session.close()
         sleep(sleep_for)
 
         # Update last heartbeat time
-        session = settings.Session()
-        job = session.query(BaseJob).filter(BaseJob.id == self.id).first()
-        job.latest_heartbeat = datetime.utcnow()
-        session.merge(job)
-        session.commit()
+        with create_session() as session:
+            job = session.query(BaseJob).filter(BaseJob.id == self.id).first()
+            job.latest_heartbeat = datetime.utcnow()
+            session.merge(job)
+            session.commit()
 
-        self.heartbeat_callback(session=session)
-        session.close()
-        self.log.debug('[heartbeat]')
+            self.heartbeat_callback(session=session)
+            self.log.debug('[heartbeat]')
 
     def run(self):
         Stats.incr(self.__class__.__name__.lower() + '_start', 1, 1)
         # Adding an entry in the DB
-        session = settings.Session()
-        self.state = State.RUNNING
-        session.add(self)
-        session.commit()
-        id_ = self.id
-        make_transient(self)
-        self.id = id_
+        with create_session() as session:
+            self.state = State.RUNNING
+            session.add(self)
+            session.commit()
+            id_ = self.id
+            make_transient(self)
+            self.id = id_
 
-        # Run
-        self._execute()
+            # Run
+            self._execute()
 
-        # Marking the success in the DB
-        self.end_date = datetime.utcnow()
-        self.state = State.SUCCESS
-        session.merge(self)
-        session.commit()
-        session.close()
+            # Marking the success in the DB
+            self.end_date = datetime.utcnow()
+            self.state = State.SUCCESS
+            session.merge(self)
+            session.commit()
 
         Stats.incr(self.__class__.__name__.lower() + '_end', 1, 1)
 
@@ -711,7 +706,6 @@ class SchedulerJob(BaseJob):
                     sla.notification_sent = True
                     session.merge(sla)
             session.commit()
-            session.close()
 
     @staticmethod
     @provide_session
@@ -886,13 +880,13 @@ class SchedulerJob(BaseJob):
                 )
                 return next_run
 
-    def _process_task_instances(self, dag, queue):
+    @provide_session
+    def _process_task_instances(self, dag, queue, session=None):
         """
         This method schedules the tasks for a single DAG by looking at the
         active DAG runs and adding task instances that should run to the
         queue.
         """
-        session = settings.Session()
 
         # update the state of the previously active dag runs
         dag_runs = DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)
@@ -949,8 +943,6 @@ class SchedulerJob(BaseJob):
                     self.log.debug('Queuing task: %s', ti)
                     queue.append(ti.key)
 
-        session.close()
-
     @provide_session
     def _change_state_for_tis_without_dagrun(self,
                                              simple_dag_bag,
@@ -1590,10 +1582,8 @@ class SchedulerJob(BaseJob):
         """
         self.executor.start()
 
-        session = settings.Session()
         self.log.info("Resetting orphaned tasks for active dag runs")
-        self.reset_state_for_orphaned_tasks(session=session)
-        session.close()
+        self.reset_state_for_orphaned_tasks()
 
         execute_start_time = datetime.utcnow()
 
@@ -1982,9 +1972,7 @@ class BackfillJob(BaseJob):
                     "reaching concurrency limits. Re-adding task to queue.",
                     ti
                 )
-                session = settings.Session()
-                ti.set_state(State.SCHEDULED, session=session)
-                session.close()
+                ti.set_state(State.SCHEDULED)
                 ti_status.started.pop(key)
                 ti_status.to_run[key] = ti
 
@@ -2395,12 +2383,12 @@ class BackfillJob(BaseJob):
 
         ti_status.executed_dag_run_dates.update(processed_dag_run_dates)
 
-    def _execute(self):
+    @provide_session
+    def _execute(self, session=None):
         """
         Initializes all components required to run a dag for a specified date range and
         calls helper method to execute the tasks.
         """
-        session = settings.Session()
         ti_status = BackfillJob._DagRunTaskStatus()
 
         start_date = self.bf_start_date
@@ -2456,7 +2444,6 @@ class BackfillJob(BaseJob):
         finally:
             executor.end()
             session.commit()
-            session.close()
 
         self.log.info("Backfill done. Exiting.")
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/77652784/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index 3bdd68f..f2f87ac 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -471,22 +471,19 @@ class DagBag(BaseDagBag, LoggingMixin):
             table=pprinttable(stats),
         )
 
-    def deactivate_inactive_dags(self):
+    @provide_session
+    def deactivate_inactive_dags(self, session=None):
         active_dag_ids = [dag.dag_id for dag in list(self.dags.values())]
-        session = settings.Session()
         for dag in session.query(
                 DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all():
             dag.is_active = False
             session.merge(dag)
         session.commit()
-        session.close()
 
-    def paused_dags(self):
-        session = settings.Session()
+    @provide_session
+    def paused_dags(self, session=None):
         dag_ids = [dp.dag_id for dp in session.query(DagModel).filter(
             DagModel.is_paused.__eq__(True))]
-        session.commit()
-        session.close()
         return dag_ids
 
 
@@ -1061,7 +1058,8 @@ class TaskInstance(Base, LoggingMixin):
         """
         return self.dag_id, self.task_id, self.execution_date
 
-    def set_state(self, state, session):
+    @provide_session
+    def set_state(self, state, session=None):
         self.state = state
         self.start_date = datetime.utcnow()
         self.end_date = datetime.utcnow()
@@ -1563,10 +1561,10 @@ class TaskInstance(Base, LoggingMixin):
         self.render_templates()
         task_copy.dry_run()
 
-    def handle_failure(self, error, test_mode=False, context=None):
+    @provide_session
+    def handle_failure(self, error, test_mode=False, context=None, session=None):
         self.log.exception(error)
         task = self.task
-        session = settings.Session()
         self.end_date = datetime.utcnow()
         self.set_duration()
         Stats.incr('operator_failures_{}'.format(task.__class__.__name__), 1, 1)
@@ -1913,20 +1911,21 @@ class Log(Base):
 
 
 class SkipMixin(LoggingMixin):
-    def skip(self, dag_run, execution_date, tasks):
+    @provide_session
+    def skip(self, dag_run, execution_date, tasks, session=None):
         """
         Sets tasks instances to skipped from the same dag run.
 
         :param dag_run: the DagRun for which to set the tasks to skipped
         :param execution_date: execution_date
         :param tasks: tasks to skip (not task_ids)
+        :param session: db session to use
         """
         if not tasks:
             return
 
         task_ids = [d.task_id for d in tasks]
         now = datetime.utcnow()
-        session = settings.Session()
 
         if dag_run:
             session.query(TaskInstance).filter(
@@ -1951,7 +1950,6 @@ class SkipMixin(LoggingMixin):
                 session.merge(ti)
 
             session.commit()
-        session.close()
 
 
 @functools.total_ordering
@@ -2496,13 +2494,18 @@ class BaseOperator(LoggingMixin):
     def downstream_task_ids(self):
         return self._downstream_task_ids
 
-    def clear(self, start_date=None, end_date=None, upstream=False, downstream=False):
+    @provide_session
+    def clear(
+              self,
+              start_date=None,
+              end_date=None,
+              upstream=False,
+              downstream=False,
+              session=None):
         """
         Clears the state of task instances associated with the task, following
         the parameters specified.
         """
-        session = settings.Session()
-
         TI = TaskInstance
         qry = session.query(TI).filter(TI.dag_id == self.dag_id)
 
@@ -2528,7 +2531,7 @@ class BaseOperator(LoggingMixin):
         clear_task_instances(qry.all(), session, dag=self.dag)
 
         session.commit()
-        session.close()
+
         return count
 
     def get_task_instances(self, session, start_date=None, end_date=None):
@@ -2751,13 +2754,9 @@ class DagModel(Base):
         return "<DAG: {self.dag_id}>".format(self=self)
 
     @classmethod
-    def get_current(cls, dag_id):
-        session = settings.Session()
-        obj = session.query(cls).filter(cls.dag_id == dag_id).first()
-        session.expunge_all()
-        session.commit()
-        session.close()
-        return obj
+    @provide_session
+    def get_current(cls, dag_id, session=None):
+        return session.query(cls).filter(cls.dag_id == dag_id).first()
 
 
 @functools.total_ordering
@@ -3208,16 +3207,14 @@ class DAG(BaseDag, LoggingMixin):
         return dagrun
 
     @property
-    def latest_execution_date(self):
+    @provide_session
+    def latest_execution_date(self, session=None):
         """
         Returns the latest date for which at least one dag run exists
         """
-        session = settings.Session()
         execution_date = session.query(func.max(DagRun.execution_date)).filter(
             DagRun.dag_id == self.dag_id
         ).scalar()
-        session.commit()
-        session.close()
         return execution_date
 
     @property
@@ -3352,6 +3349,7 @@ class DAG(BaseDag, LoggingMixin):
             dirty_ids.append(dr.dag_id)
         DagStat.update(dirty_ids, session=session)
 
+    @provide_session
     def clear(
             self, start_date=None, end_date=None,
             only_failed=False,
@@ -3359,12 +3357,12 @@ class DAG(BaseDag, LoggingMixin):
             confirm_prompt=False,
             include_subdags=True,
             reset_dag_runs=True,
-            dry_run=False):
+            dry_run=False,
+            session=None):
         """
         Clears a set of task instances associated with the current dag for
         a specified date range.
         """
-        session = settings.Session()
         TI = TaskInstance
         tis = session.query(TI)
         if include_subdags:
@@ -3415,7 +3413,6 @@ class DAG(BaseDag, LoggingMixin):
             print("Bail. Nothing was cleared.")
 
         session.commit()
-        session.close()
         return count
 
     @classmethod
@@ -3625,9 +3622,9 @@ class DAG(BaseDag, LoggingMixin):
         for task in tasks:
             self.add_task(task)
 
-    def db_merge(self):
+    @provide_session
+    def db_merge(self, session=None):
         BO = BaseOperator
-        session = settings.Session()
         tasks = session.query(BO).filter(BO.dag_id == self.dag_id).all()
         for t in tasks:
             session.delete(t)
@@ -4380,8 +4377,9 @@ class DagRun(Base, LoggingMixin):
         if self._state != state:
             self._state = state
             if self.dag_id is not None:
-                # something really weird goes on here: if you try to close the session
-                # dag runs will end up detached
+                # FIXME: Due to the scoped_session factory we don't get a clean
+                # session here, so something really weird goes on:
+                # if you try to close the session dag runs will end up detached
                 session = settings.Session()
                 DagStat.set_dirty(self.dag_id, session=session)
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/77652784/airflow/operators/dagrun_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/dagrun_operator.py b/airflow/operators/dagrun_operator.py
index 9a13f90..923b8a4 100644
--- a/airflow/operators/dagrun_operator.py
+++ b/airflow/operators/dagrun_operator.py
@@ -15,6 +15,7 @@
 from datetime import datetime
 
 from airflow.models import BaseOperator, DagBag
+from airflow.utils.db import create_session
 from airflow.utils.decorators import apply_defaults
 from airflow.utils.state import State
 from airflow import settings
@@ -61,17 +62,16 @@ class TriggerDagRunOperator(BaseOperator):
         dro = DagRunOrder(run_id='trig__' + datetime.utcnow().isoformat())
         dro = self.python_callable(context, dro)
         if dro:
-            session = settings.Session()
-            dbag = DagBag(settings.DAGS_FOLDER)
-            trigger_dag = dbag.get_dag(self.trigger_dag_id)
-            dr = trigger_dag.create_dagrun(
-                run_id=dro.run_id,
-                state=State.RUNNING,
-                conf=dro.payload,
-                external_trigger=True)
-            self.log.info("Creating DagRun %s", dr)
-            session.add(dr)
-            session.commit()
-            session.close()
+            with create_session() as session:
+                dbag = DagBag(settings.DAGS_FOLDER)
+                trigger_dag = dbag.get_dag(self.trigger_dag_id)
+                dr = trigger_dag.create_dagrun(
+                    run_id=dro.run_id,
+                    state=State.RUNNING,
+                    conf=dro.payload,
+                    external_trigger=True)
+                self.log.info("Creating DagRun %s", dr)
+                session.add(dr)
+                session.commit()
         else:
             self.log.info("Criteria not met, moving on")

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/77652784/airflow/operators/sensors.py
----------------------------------------------------------------------
diff --git a/airflow/operators/sensors.py b/airflow/operators/sensors.py
index 14daa6d..da7a62f 100644
--- a/airflow/operators/sensors.py
+++ b/airflow/operators/sensors.py
@@ -34,6 +34,7 @@ from airflow.hooks.base_hook import BaseHook
 from airflow.hooks.hdfs_hook import HDFSHook
 from airflow.hooks.http_hook import HttpHook
 from airflow.utils.state import State
+from airflow.utils.db import provide_session
 from airflow.utils.decorators import apply_defaults
 
 
@@ -224,7 +225,8 @@ class ExternalTaskSensor(BaseSensorOperator):
         self.external_dag_id = external_dag_id
         self.external_task_id = external_task_id
 
-    def poke(self, context):
+    @provide_session
+    def poke(self, context, session=None):
         if self.execution_delta:
             dttm = context['execution_date'] - self.execution_delta
         elif self.execution_date_fn:
@@ -243,7 +245,6 @@ class ExternalTaskSensor(BaseSensorOperator):
             '{} ... '.format(serialized_dttm_filter, **locals()))
         TI = TaskInstance
 
-        session = settings.Session()
         count = session.query(TI).filter(
             TI.dag_id == self.external_dag_id,
             TI.task_id == self.external_task_id,
@@ -251,7 +252,6 @@ class ExternalTaskSensor(BaseSensorOperator):
             TI.execution_date.in_(dttm_filter),
         ).count()
         session.commit()
-        session.close()
         return count == len(dttm_filter)
 
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/77652784/airflow/utils/db.py
----------------------------------------------------------------------
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 1d086be..9c924d1 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -20,6 +20,7 @@ from __future__ import unicode_literals
 from functools import wraps
 
 import os
+import contextlib
 
 from sqlalchemy import event, exc
 from sqlalchemy.pool import Pool
@@ -29,6 +30,24 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 
 log = LoggingMixin().log
 
+
+@contextlib.contextmanager
+def create_session():
+    """
+    Context manager that creates and tears down a session.
+    """
+    session = settings.Session()
+    try:
+        yield session
+        session.expunge_all()
+        session.commit()
+    except:
+        session.rollback()
+        raise
+    finally:
+        session.close()
+
+
 def provide_session(func):
     """
     Function decorator that provides a session if it isn't provided.
@@ -38,21 +57,20 @@ def provide_session(func):
     """
     @wraps(func)
     def wrapper(*args, **kwargs):
-        needs_session = False
         arg_session = 'session'
+
         func_params = func.__code__.co_varnames
         session_in_args = arg_session in func_params and \
             func_params.index(arg_session) < len(args)
-        if not (arg_session in kwargs or session_in_args):
-            needs_session = True
-            session = settings.Session()
-            kwargs[arg_session] = session
-        result = func(*args, **kwargs)
-        if needs_session:
-            session.expunge_all()
-            session.commit()
-            session.close()
-        return result
+        session_in_kwargs = arg_session in kwargs
+
+        if session_in_kwargs or session_in_args:
+            return func(*args, **kwargs)
+        else:
+            with create_session() as session:
+                kwargs[arg_session] = session
+                return func(*args, **kwargs)
+
     return wrapper
 
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/77652784/airflow/www/utils.py
----------------------------------------------------------------------
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 96293a2..52b22fc 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -31,6 +31,7 @@ import wtforms
 from wtforms.compat import text_type
 
 from airflow import configuration, models, settings
+from airflow.utils.db import create_session
 from airflow.utils.json import AirflowJsonEncoder
 
 AUTHENTICATE = configuration.getboolean('webserver', 'AUTHENTICATE')
@@ -237,8 +238,6 @@ def action_logging(f):
     '''
     @functools.wraps(f)
     def wrapper(*args, **kwargs):
-        session = settings.Session()
-
         if current_user and hasattr(current_user, 'username'):
             user = current_user.username
         else:
@@ -256,8 +255,9 @@ def action_logging(f):
             log.execution_date = dateparser.parse(
                 request.args.get('execution_date'))
 
-        session.add(log)
-        session.commit()
+        with create_session() as session:
+            session.add(log)
+            session.commit()
 
         return f(*args, **kwargs)
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/77652784/airflow/www/views.py
----------------------------------------------------------------------
diff --git a/airflow/www/views.py b/airflow/www/views.py
index f9b1116..81c44b6 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -74,7 +74,7 @@ from airflow.operators.subdag_operator import SubDagOperator
 
 from airflow.utils.json import json_ser
 from airflow.utils.state import State
-from airflow.utils.db import provide_session
+from airflow.utils.db import create_session, provide_session
 from airflow.utils.helpers import alchemy_to_dict
 from airflow.utils.dates import infer_time_unit, scale_time_units
 from airflow.www import utils as wwwutils
@@ -303,15 +303,12 @@ class Airflow(BaseView):
         if conf.getboolean('core', 'secure_mode'):
             abort(404)
 
-        session = settings.Session()
-        chart_id = request.args.get('chart_id')
-        csv = request.args.get('csv') == "true"
-        chart = session.query(models.Chart).filter_by(id=chart_id).first()
-        db = session.query(
-            models.Connection).filter_by(conn_id=chart.conn_id).first()
-        session.expunge_all()
-        session.commit()
-        session.close()
+        with create_session() as session:
+            chart_id = request.args.get('chart_id')
+            csv = request.args.get('csv') == "true"
+            chart = session.query(models.Chart).filter_by(id=chart_id).first()
+            db = session.query(
+                models.Connection).filter_by(conn_id=chart.conn_id).first()
 
         payload = {
             "state": "ERROR",
@@ -444,13 +441,10 @@ class Airflow(BaseView):
         if conf.getboolean('core', 'secure_mode'):
             abort(404)
 
-        session = settings.Session()
-        chart_id = request.args.get('chart_id')
-        embed = request.args.get('embed')
-        chart = session.query(models.Chart).filter_by(id=chart_id).first()
-        session.expunge_all()
-        session.commit()
-        session.close()
+        with create_session() as session:
+            chart_id = request.args.get('chart_id')
+            embed = request.args.get('embed')
+            chart = session.query(models.Chart).filter_by(id=chart_id).first()
 
         NVd3ChartClass = chart_mapping.get(chart.chart_type)
         if not NVd3ChartClass:
@@ -477,9 +471,9 @@ class Airflow(BaseView):
 
     @expose('/dag_stats')
     @login_required
-    def dag_stats(self):
+    @provide_session
+    def dag_stats(self, session=None):
         ds = models.DagStat
-        session = Session()
 
         ds.update()
 
@@ -512,11 +506,11 @@ class Airflow(BaseView):
 
     @expose('/task_stats')
     @login_required
-    def task_stats(self):
+    @provide_session
+    def task_stats(self, session=None):
         TI = models.TaskInstance
         DagRun = models.DagRun
         Dag = models.DagModel
-        session = Session()
 
         LastDagRun = (
             session.query(DagRun.dag_id, sqla.func.max(DagRun.execution_date).label('execution_date'))
@@ -561,7 +555,6 @@ class Airflow(BaseView):
                 data[dag_id] = {}
             data[dag_id][state] = count
         session.commit()
-        session.close()
 
         payload = {}
         for dag in dagbag.dags.values():
@@ -601,12 +594,12 @@ class Airflow(BaseView):
 
     @expose('/dag_details')
     @login_required
-    def dag_details(self):
+    @provide_session
+    def dag_details(self, session=None):
         dag_id = request.args.get('dag_id')
         dag = dagbag.get_dag(dag_id)
         title = "DAG details"
 
-        session = settings.Session()
         TI = models.TaskInstance
         states = (
             session.query(TI.state, sqla.func.count(TI.dag_id))
@@ -695,14 +688,14 @@ class Airflow(BaseView):
     @expose('/log')
     @login_required
     @wwwutils.action_logging
-    def log(self):
+    @provide_session
+    def log(self, session=None):
         dag_id = request.args.get('dag_id')
         task_id = request.args.get('task_id')
         execution_date = request.args.get('execution_date')
         dttm = dateutil.parser.parse(execution_date)
         form = DateTimeForm(data={'execution_date': dttm})
         dag = dagbag.get_dag(dag_id)
-        session = Session()
         ti = session.query(models.TaskInstance).filter(
             models.TaskInstance.dag_id == dag_id,
             models.TaskInstance.task_id == task_id,
@@ -811,7 +804,8 @@ class Airflow(BaseView):
     @expose('/xcom')
     @login_required
     @wwwutils.action_logging
-    def xcom(self):
+    @provide_session
+    def xcom(self, session=None):
         dag_id = request.args.get('dag_id')
         task_id = request.args.get('task_id')
         # Carrying execution_date through, even though it's irrelevant for
@@ -827,7 +821,6 @@ class Airflow(BaseView):
                 "error")
             return redirect('/admin/')
 
-        session = Session()
         xcomlist = session.query(XCom).filter(
             XCom.dag_id == dag_id, XCom.task_id == task_id,
             XCom.execution_date == dttm).all()
@@ -1022,8 +1015,8 @@ class Airflow(BaseView):
 
     @expose('/blocked')
     @login_required
-    def blocked(self):
-        session = settings.Session()
+    @provide_session
+    def blocked(self, session=None):
         DR = models.DagRun
         dags = (
             session.query(DR.dag_id, sqla.func.count(DR.id))
@@ -1138,7 +1131,8 @@ class Airflow(BaseView):
     @login_required
     @wwwutils.gzipped
     @wwwutils.action_logging
-    def tree(self):
+    @provide_session
+    def tree(self, session=None):
         dag_id = request.args.get('dag_id')
         blur = conf.getboolean('webserver', 'demo_mode')
         dag = dagbag.get_dag(dag_id)
@@ -1149,8 +1143,6 @@ class Airflow(BaseView):
                 include_downstream=False,
                 include_upstream=True)
 
-        session = settings.Session()
-
         base_date = request.args.get('base_date')
         num_runs = request.args.get('num_runs')
         num_runs = int(num_runs) if num_runs else 25
@@ -1247,7 +1239,6 @@ class Airflow(BaseView):
 
         data = json.dumps(data, indent=4, default=json_ser)
         session.commit()
-        session.close()
 
         form = DateTimeWithNumRunsForm(data={'base_date': max_date,
                                              'num_runs': num_runs})
@@ -1265,8 +1256,8 @@ class Airflow(BaseView):
     @login_required
     @wwwutils.gzipped
     @wwwutils.action_logging
-    def graph(self):
-        session = settings.Session()
+    @provide_session
+    def graph(self, session=None):
         dag_id = request.args.get('dag_id')
         blur = conf.getboolean('webserver', 'demo_mode')
         dag = dagbag.get_dag(dag_id)
@@ -1351,7 +1342,6 @@ class Airflow(BaseView):
         if not tasks:
             flash("No tasks found", "error")
         session.commit()
-        session.close()
         doc_md = markdown.markdown(dag.doc_md) if hasattr(dag, 'doc_md') and dag.doc_md else ''
 
         return self.render(
@@ -1378,8 +1368,8 @@ class Airflow(BaseView):
     @expose('/duration')
     @login_required
     @wwwutils.action_logging
-    def duration(self):
-        session = settings.Session()
+    @provide_session
+    def duration(self, session=None):
         dag_id = request.args.get('dag_id')
         dag = dagbag.get_dag(dag_id)
         base_date = request.args.get('base_date')
@@ -1462,7 +1452,6 @@ class Airflow(BaseView):
         max_date = max([ti.execution_date for ti in tis]) if dates else None
 
         session.commit()
-        session.close()
 
         form = DateTimeWithNumRunsForm(data={'base_date': max_date,
                                              'num_runs': num_runs})
@@ -1486,8 +1475,8 @@ class Airflow(BaseView):
     @expose('/tries')
     @login_required
     @wwwutils.action_logging
-    def tries(self):
-        session = settings.Session()
+    @provide_session
+    def tries(self, session=None):
         dag_id = request.args.get('dag_id')
         dag = dagbag.get_dag(dag_id)
         base_date = request.args.get('base_date')
@@ -1531,7 +1520,6 @@ class Airflow(BaseView):
         max_date = max([ti.execution_date for ti in tis]) if tries else None
 
         session.commit()
-        session.close()
 
         form = DateTimeWithNumRunsForm(data={'base_date': max_date,
                                              'num_runs': num_runs})
@@ -1550,8 +1538,8 @@ class Airflow(BaseView):
     @expose('/landing_times')
     @login_required
     @wwwutils.action_logging
-    def landing_times(self):
-        session = settings.Session()
+    @provide_session
+    def landing_times(self, session=None):
         dag_id = request.args.get('dag_id')
         dag = dagbag.get_dag(dag_id)
         base_date = request.args.get('base_date')
@@ -1609,9 +1597,6 @@ class Airflow(BaseView):
         dates = sorted(list({ti.execution_date for ti in tis}))
         max_date = max([ti.execution_date for ti in tis]) if dates else None
 
-        session.commit()
-        session.close()
-
         form = DateTimeWithNumRunsForm(data={'base_date': max_date,
                                              'num_runs': num_runs})
         chart.buildcontent()
@@ -1628,10 +1613,10 @@ class Airflow(BaseView):
     @expose('/paused', methods=['POST'])
     @login_required
     @wwwutils.action_logging
-    def paused(self):
+    @provide_session
+    def paused(self, session=None):
         DagModel = models.DagModel
         dag_id = request.args.get('dag_id')
-        session = settings.Session()
         orm_dag = session.query(
             DagModel).filter(DagModel.dag_id == dag_id).first()
         if request.args.get('is_paused') == 'false':
@@ -1640,7 +1625,6 @@ class Airflow(BaseView):
             orm_dag.is_paused = False
         session.merge(orm_dag)
         session.commit()
-        session.close()
 
         dagbag.get_dag(dag_id)
         return "OK"
@@ -1648,10 +1632,10 @@ class Airflow(BaseView):
     @expose('/refresh')
     @login_required
     @wwwutils.action_logging
-    def refresh(self):
+    @provide_session
+    def refresh(self, session=None):
         DagModel = models.DagModel
         dag_id = request.args.get('dag_id')
-        session = settings.Session()
         orm_dag = session.query(
             DagModel).filter(DagModel.dag_id == dag_id).first()
 
@@ -1659,7 +1643,6 @@ class Airflow(BaseView):
             orm_dag.last_expired = datetime.utcnow()
             session.merge(orm_dag)
         session.commit()
-        session.close()
 
         dagbag.get_dag(dag_id)
         flash("DAG [{}] is now fresh as a daisy".format(dag_id))
@@ -1676,8 +1659,8 @@ class Airflow(BaseView):
     @expose('/gantt')
     @login_required
     @wwwutils.action_logging
-    def gantt(self):
-        session = settings.Session()
+    @provide_session
+    def gantt(self, session=None):
         dag_id = request.args.get('dag_id')
         dag = dagbag.get_dag(dag_id)
         demo_mode = conf.getboolean('webserver', 'demo_mode')
@@ -1724,7 +1707,6 @@ class Airflow(BaseView):
         }
 
         session.commit()
-        session.close()
 
         return self.render(
             'airflow/gantt.html',
@@ -1740,8 +1722,8 @@ class Airflow(BaseView):
     @expose('/object/task_instances')
     @login_required
     @wwwutils.action_logging
-    def task_instances(self):
-        session = settings.Session()
+    @provide_session
+    def task_instances(self, session=None):
         dag_id = request.args.get('dag_id')
         dag = dagbag.get_dag(dag_id)
 
@@ -1765,10 +1747,10 @@ class Airflow(BaseView):
             if request.method == 'POST':
                 data = request.json
                 if data:
-                    session = settings.Session()
-                    var = models.Variable(key=form, val=json.dumps(data))
-                    session.add(var)
-                    session.commit()
+                    with create_session() as session:
+                        var = models.Variable(key=form, val=json.dumps(data))
+                        session.add(var)
+                        session.commit()
                 return ""
             else:
                 return self.render(
@@ -1799,8 +1781,8 @@ class Airflow(BaseView):
 class HomeView(AdminIndexView):
     @expose("/")
     @login_required
-    def index(self):
-        session = Session()
+    @provide_session
+    def index(self, session=None):
         DM = models.DagModel
 
         # restrict the dags shown if filter_by_owner and current user is not superuser
@@ -1862,9 +1844,6 @@ class HomeView(AdminIndexView):
             flash(
                 "Broken DAG: [{ie.filename}] {ie.stacktrace}".format(ie=ie),
                 "error")
-        session.expunge_all()
-        session.commit()
-        session.close()
 
         # get a list of all non-subdag dags visible to everyone
         # optionally filter out "paused" dags
@@ -1954,8 +1933,8 @@ class HomeView(AdminIndexView):
 class QueryView(wwwutils.DataProfilingMixin, BaseView):
     @expose('/', methods=['POST', 'GET'])
     @wwwutils.gzipped
-    def query(self):
-        session = settings.Session()
+    @provide_session
+    def query(self, session=None):
         dbs = session.query(models.Connection).order_by(
             models.Connection.conn_id).all()
         session.expunge_all()
@@ -2010,7 +1989,6 @@ class QueryView(wwwutils.DataProfilingMixin, BaseView):
 
         form = QueryForm(request.form, data=data)
         session.commit()
-        session.close()
         return self.render(
             'airflow/query.html', form=form,
             title="Ad Hoc Query",
@@ -2071,6 +2049,17 @@ class SlaMissModelView(wwwutils.SuperUserMixin, ModelViewOnly):
     }
 
 
+@provide_session
+def _connection_ids(session=None):
+    return [
+            (c.conn_id, c.conn_id)
+            for c in (
+                session.query(models.Connection.conn_id)
+                    .group_by(models.Connection.conn_id)
+            )
+    ]
+
+
 class ChartModelView(wwwutils.DataProfilingMixin, AirflowModelView):
     verbose_name = "chart"
     verbose_name_plural = "charts"
@@ -2162,13 +2151,7 @@ class ChartModelView(wwwutils.DataProfilingMixin, AirflowModelView):
             ('series', 'SELECT series, x, y FROM ...'),
             ('columns', 'SELECT x, y (series 1), y (series 2), ... FROM ...'),
         ],
-        'conn_id': [
-            (c.conn_id, c.conn_id)
-            for c in (
-                Session().query(models.Connection.conn_id)
-                    .group_by(models.Connection.conn_id)
-            )
-        ]
+        'conn_id': _connection_ids()
     }
 
     def on_model_change(self, form, model, is_created=True):
@@ -2312,11 +2295,10 @@ class VariableView(wwwutils.DataProfilingMixin, AirflowModelView):
 
     # Default flask-admin export functionality doesn't handle serialized json
     @action('varexport', 'Export', None)
-    def action_varexport(self, ids):
+    @provide_session
+    def action_varexport(self, ids, session=None):
         V = models.Variable
-        session = settings.Session()
         qry = session.query(V).filter(V.id.in_(ids)).all()
-        session.close()
 
         var_dict = {}
         d = json.JSONDecoder()
@@ -2401,8 +2383,8 @@ class DagRunModelView(ModelViewOnly):
         dag_id=dag_link)
 
     @action('new_delete', "Delete", "Are you sure you want to delete selected records?")
-    def action_new_delete(self, ids):
-        session = settings.Session()
+    @provide_session
+    def action_new_delete(self, ids, session=None):
         deleted = set(session.query(models.DagRun)
                       .filter(models.DagRun.id.in_(ids))
                       .all())
@@ -2414,7 +2396,6 @@ class DagRunModelView(ModelViewOnly):
         for row in deleted:
             dirty_ids.append(row.dag_id)
         models.DagStat.update(dirty_ids, dirty_only=False, session=session)
-        session.close()
 
     @action('set_running', "Set state to 'running'", None)
     def action_set_running(self, ids):