Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2019/01/08 10:45:42 UTC

[GitHub] kaxil closed pull request #4298: [AIRFLOW-3478] Make sure that the session is closed

kaxil closed pull request #4298: [AIRFLOW-3478] Make sure that the session is closed
URL: https://github.com/apache/airflow/pull/4298

This is a PR merged from a forked repository. Because GitHub hides the
original diff once a forked pull request is merged, the full diff is
reproduced below for the sake of provenance:

diff --git a/airflow/api/common/experimental/delete_dag.py b/airflow/api/common/experimental/delete_dag.py
index ebb622bd12..6df6e0fe09 100644
--- a/airflow/api/common/experimental/delete_dag.py
+++ b/airflow/api/common/experimental/delete_dag.py
@@ -21,11 +21,13 @@
 
 from sqlalchemy import or_
 
-from airflow import models, settings
+from airflow import models
+from airflow.utils.db import provide_session
 from airflow.exceptions import DagNotFound, DagFileExists
 
 
-def delete_dag(dag_id, keep_records_in_log=True):
+@provide_session
+def delete_dag(dag_id, keep_records_in_log=True, session=None):
     """
     :param dag_id: the dag_id of the DAG to delete
     :type dag_id: str
@@ -34,8 +36,6 @@ def delete_dag(dag_id, keep_records_in_log=True):
         The default value is True.
     :type keep_records_in_log: bool
     """
-    session = settings.Session()
-
     DM = models.DagModel
     dag = session.query(DM).filter(DM.dag_id == dag_id).first()
     if dag is None:
@@ -60,6 +60,4 @@ def delete_dag(dag_id, keep_records_in_log=True):
         for m in models.DagRun, models.TaskFail, models.TaskInstance:
             count += session.query(m).filter(m.dag_id == p, m.task_id == c).delete()
 
-    session.commit()
-
     return count
diff --git a/airflow/api/common/experimental/mark_tasks.py b/airflow/api/common/experimental/mark_tasks.py
index 2fac1254cd..21b4463da1 100644
--- a/airflow/api/common/experimental/mark_tasks.py
+++ b/airflow/api/common/experimental/mark_tasks.py
@@ -22,7 +22,6 @@
 from airflow.jobs import BackfillJob
 from airflow.models import DagRun, TaskInstance
 from airflow.operators.subdag_operator import SubDagOperator
-from airflow.settings import Session
 from airflow.utils import timezone
 from airflow.utils.db import provide_session
 from airflow.utils.state import State
@@ -54,8 +53,9 @@ def _create_dagruns(dag, execution_dates, state, run_id_template):
     return drs
 
 
+@provide_session
 def set_state(task, execution_date, upstream=False, downstream=False,
-              future=False, past=False, state=State.SUCCESS, commit=False):
+              future=False, past=False, state=State.SUCCESS, commit=False, session=None):
     """
     Set the state of a task instance and if needed its relatives. Can set state
     for future tasks (calculated from execution_date) and retroactively
@@ -71,6 +71,7 @@ def set_state(task, execution_date, upstream=False, downstream=False,
     :param past: Retroactively mark all tasks starting from start_date of the DAG
     :param state: State to which the tasks need to be set
     :param commit: Commit tasks to be altered to the database
+    :param session: database session
     :return: list of tasks that have been created and updated
     """
     assert timezone.is_localized(execution_date)
@@ -124,7 +125,6 @@ def set_state(task, execution_date, upstream=False, downstream=False,
     # go through subdagoperators and create dag runs. We will only work
     # within the scope of the subdag. We wont propagate to the parent dag,
     # but we will propagate from parent to subdag.
-    session = Session()
     dags = [dag]
     sub_dag_ids = []
     while len(dags) > 0:
@@ -180,18 +180,15 @@ def set_state(task, execution_date, upstream=False, downstream=False,
             tis_altered += qry_sub_dag.with_for_update().all()
         for ti in tis_altered:
             ti.state = state
-        session.commit()
     else:
         tis_altered = qry_dag.all()
         if len(sub_dag_ids) > 0:
             tis_altered += qry_sub_dag.all()
 
-    session.expunge_all()
-    session.close()
-
     return tis_altered
 
 
+@provide_session
 def _set_dag_run_state(dag_id, execution_date, state, session=None):
     """
     Helper method that set dag run state in the DB.
@@ -211,12 +208,11 @@ def _set_dag_run_state(dag_id, execution_date, state, session=None):
         dr.end_date = None
     else:
         dr.end_date = timezone.utcnow()
-    session.commit()
+    session.merge(dr)
 
 
 @provide_session
-def set_dag_run_state_to_success(dag, execution_date, commit=False,
-                                 session=None):
+def set_dag_run_state_to_success(dag, execution_date, commit=False, session=None):
     """
     Set the dag run for a specific execution date and its task instances
     to success.
@@ -248,8 +244,7 @@ def set_dag_run_state_to_success(dag, execution_date, commit=False,
 
 
 @provide_session
-def set_dag_run_state_to_failed(dag, execution_date, commit=False,
-                                session=None):
+def set_dag_run_state_to_failed(dag, execution_date, commit=False, session=None):
     """
     Set the dag run for a specific execution date and its running task instances
     to failed.
@@ -290,8 +285,7 @@ def set_dag_run_state_to_failed(dag, execution_date, commit=False,
 
 
 @provide_session
-def set_dag_run_state_to_running(dag, execution_date, commit=False,
-                                 session=None):
+def set_dag_run_state_to_running(dag, execution_date, commit=False, session=None):
     """
     Set the dag run for a specific execution date to running.
     :param dag: the DAG of which to alter state
diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py
index 877bf34e20..8d73671969 100644
--- a/airflow/bin/cli.py
+++ b/airflow/bin/cli.py
@@ -59,8 +59,7 @@
 from airflow.models.connection import Connection
 from airflow.models.dagpickle import DagPickle
 from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS)
-from airflow.utils import cli as cli_utils
-from airflow.utils import db as db_utils
+from airflow.utils import cli as cli_utils, db
 from airflow.utils.net import get_hostname
 from airflow.utils.log.logging_mixin import (LoggingMixin, redirect_stderr,
                                              redirect_stdout)
@@ -343,10 +342,8 @@ def variables(args):
         except ValueError as e:
             print(e)
     if args.delete:
-        session = settings.Session()
-        session.query(Variable).filter_by(key=args.delete).delete()
-        session.commit()
-        session.close()
+        with db.create_session() as session:
+            session.query(Variable).filter_by(key=args.delete).delete()
     if args.set:
         Variable.set(args.set[0], args.set[1])
     # Work around 'import' as a reserved keyword
@@ -360,10 +357,10 @@ def variables(args):
         export_helper(args.export)
     if not (args.set or args.get or imp or args.export or args.delete):
         # list all variables
-        session = settings.Session()
-        vars = session.query(Variable)
-        msg = "\n".join(var.key for var in vars)
-        print(msg)
+        with db.create_session() as session:
+            vars = session.query(Variable)
+            msg = "\n".join(var.key for var in vars)
+            print(msg)
 
 
 def import_helper(filepath):
@@ -390,19 +387,17 @@ def import_helper(filepath):
 
 
 def export_helper(filepath):
-    session = settings.Session()
-    qry = session.query(Variable).all()
-    session.close()
-
     var_dict = {}
-    d = json.JSONDecoder()
-    for var in qry:
-        val = None
-        try:
-            val = d.decode(var.val)
-        except Exception:
-            val = var.val
-        var_dict[var.key] = val
+    with db.create_session() as session:
+        qry = session.query(Variable).all()
+
+        d = json.JSONDecoder()
+        for var in qry:
+            try:
+                val = d.decode(var.val)
+            except Exception:
+                val = var.val
+            var_dict[var.key] = val
 
     with open(filepath, 'w') as varfile:
         varfile.write(json.dumps(var_dict, sort_keys=True, indent=4))
@@ -422,14 +417,12 @@ def unpause(args, dag=None):
 def set_is_paused(is_paused, args, dag=None):
     dag = dag or get_dag(args)
 
-    session = settings.Session()
-    dm = session.query(DagModel).filter(
-        DagModel.dag_id == dag.dag_id).first()
-    dm.is_paused = is_paused
-    session.commit()
+    with db.create_session() as session:
+        dm = session.query(DagModel).filter(DagModel.dag_id == dag.dag_id).first()
+        dm.is_paused = is_paused
+        session.commit()
 
-    msg = "Dag: {}, paused: {}".format(dag, str(dag.is_paused))
-    print(msg)
+    print("Dag: {}, paused: {}".format(dag, str(dag.is_paused)))
 
 
 def _run(args, dag, ti):
@@ -455,14 +448,12 @@ def _run(args, dag, ti):
         if args.ship_dag:
             try:
                 # Running remotely, so pickling the DAG
-                session = settings.Session()
-                pickle = DagPickle(dag)
-                session.add(pickle)
-                session.commit()
-                pickle_id = pickle.id
-                # TODO: This should be written to a log
-                print('Pickled dag {dag} as pickle_id:{pickle_id}'
-                      .format(**locals()))
+                with db.create_session() as session:
+                    pickle = DagPickle(dag)
+                    session.add(pickle)
+                    pickle_id = pickle.id
+                    # TODO: This should be written to a log
+                    print('Pickled dag {dag} as pickle_id:{pickle_id}'.format(**locals()))
             except Exception as e:
                 print('Could not pickle the DAG')
                 print(e)
@@ -511,13 +502,12 @@ def run(args, dag=None):
     if not args.pickle and not dag:
         dag = get_dag(args)
     elif not dag:
-        session = settings.Session()
-        log.info('Loading pickle id {args.pickle}'.format(args=args))
-        dag_pickle = session.query(
-            DagPickle).filter(DagPickle.id == args.pickle).first()
-        if not dag_pickle:
-            raise AirflowException("Who hid the pickle!? [missing pickle]")
-        dag = dag_pickle.pickle
+        with db.create_session() as session:
+            log.info('Loading pickle id {args.pickle}'.format(args=args))
+            dag_pickle = session.query(DagPickle).filter(DagPickle.id == args.pickle).first()
+            if not dag_pickle:
+                raise AirflowException("Who hid the pickle!? [missing pickle]")
+            dag = dag_pickle.pickle
 
     task = dag.get_task(task_id=args.task_id)
     ti = TaskInstance(task, args.execution_date)
@@ -1089,7 +1079,7 @@ def worker(args):
 
 def initdb(args):
     print("DB: " + repr(settings.engine.url))
-    db_utils.initdb(settings.RBAC)
+    db.initdb(settings.RBAC)
     print("Done.")
 
 
@@ -1098,7 +1088,7 @@ def resetdb(args):
     if args.yes or input("This will drop existing tables "
                          "if they exist. Proceed? "
                          "(y/n)").upper() == "Y":
-        db_utils.resetdb(settings.RBAC)
+        db.resetdb(settings.RBAC)
     else:
         print("Bail.")
 
@@ -1106,7 +1096,7 @@ def resetdb(args):
 @cli_utils.action_logging
 def upgradedb(args):
     print("DB: " + repr(settings.engine.url))
-    db_utils.upgradedb()
+    db.upgradedb()
 
 
 @cli_utils.action_logging
@@ -1133,20 +1123,20 @@ def connections(args):
             print(msg)
             return
 
-        session = settings.Session()
-        conns = session.query(Connection.conn_id, Connection.conn_type,
-                              Connection.host, Connection.port,
-                              Connection.is_encrypted,
-                              Connection.is_extra_encrypted,
-                              Connection.extra).all()
-        conns = [map(reprlib.repr, conn) for conn in conns]
-        msg = tabulate(conns, ['Conn Id', 'Conn Type', 'Host', 'Port',
-                               'Is Encrypted', 'Is Extra Encrypted', 'Extra'],
-                       tablefmt="fancy_grid")
-        if sys.version_info[0] < 3:
-            msg = msg.encode('utf-8')
-        print(msg)
-        return
+        with db.create_session() as session:
+            conns = session.query(Connection.conn_id, Connection.conn_type,
+                                  Connection.host, Connection.port,
+                                  Connection.is_encrypted,
+                                  Connection.is_extra_encrypted,
+                                  Connection.extra).all()
+            conns = [map(reprlib.repr, conn) for conn in conns]
+            msg = tabulate(conns, ['Conn Id', 'Conn Type', 'Host', 'Port',
+                                   'Is Encrypted', 'Is Extra Encrypted', 'Extra'],
+                           tablefmt="fancy_grid")
+            if sys.version_info[0] < 3:
+                msg = msg.encode('utf-8')
+            print(msg)
+            return
 
     if args.delete:
         # Check that only the `conn_id` arg was passed to the command
@@ -1166,31 +1156,30 @@ def connections(args):
                   'the --conn_id flag.\n')
             return
 
-        session = settings.Session()
-        try:
-            to_delete = (session
-                         .query(Connection)
-                         .filter(Connection.conn_id == args.conn_id)
-                         .one())
-        except exc.NoResultFound:
-            msg = '\n\tDid not find a connection with `conn_id`={conn_id}\n'
-            msg = msg.format(conn_id=args.conn_id)
-            print(msg)
-            return
-        except exc.MultipleResultsFound:
-            msg = ('\n\tFound more than one connection with ' +
-                   '`conn_id`={conn_id}\n')
-            msg = msg.format(conn_id=args.conn_id)
-            print(msg)
+        with db.create_session() as session:
+            try:
+                to_delete = (session
+                             .query(Connection)
+                             .filter(Connection.conn_id == args.conn_id)
+                             .one())
+            except exc.NoResultFound:
+                msg = '\n\tDid not find a connection with `conn_id`={conn_id}\n'
+                msg = msg.format(conn_id=args.conn_id)
+                print(msg)
+                return
+            except exc.MultipleResultsFound:
+                msg = ('\n\tFound more than one connection with ' +
+                       '`conn_id`={conn_id}\n')
+                msg = msg.format(conn_id=args.conn_id)
+                print(msg)
+                return
+            else:
+                deleted_conn_id = to_delete.conn_id
+                session.delete(to_delete)
+                msg = '\n\tSuccessfully deleted `conn_id`={conn_id}\n'
+                msg = msg.format(conn_id=deleted_conn_id)
+                print(msg)
             return
-        else:
-            deleted_conn_id = to_delete.conn_id
-            session.delete(to_delete)
-            session.commit()
-            msg = '\n\tSuccessfully deleted `conn_id`={conn_id}\n'
-            msg = msg.format(conn_id=deleted_conn_id)
-            print(msg)
-        return
 
     if args.add:
         # Check that the conn_id and conn_uri args were passed to the command:
@@ -1229,26 +1218,25 @@ def connections(args):
         if args.conn_extra is not None:
             new_conn.set_extra(args.conn_extra)
 
-        session = settings.Session()
-        if not (session.query(Connection)
-                       .filter(Connection.conn_id == new_conn.conn_id).first()):
-            session.add(new_conn)
-            session.commit()
-            msg = '\n\tSuccessfully added `conn_id`={conn_id} : {uri}\n'
-            msg = msg.format(conn_id=new_conn.conn_id,
-                             uri=args.conn_uri or
-                             urlunparse((args.conn_type,
-                                        '{login}:{password}@{host}:{port}'
-                                         .format(login=args.conn_login or '',
-                                                 password=args.conn_password or '',
-                                                 host=args.conn_host or '',
-                                                 port=args.conn_port or ''),
-                                         args.conn_schema or '', '', '', '')))
-            print(msg)
-        else:
-            msg = '\n\tA connection with `conn_id`={conn_id} already exists\n'
-            msg = msg.format(conn_id=new_conn.conn_id)
-            print(msg)
+        with db.create_session() as session:
+            if not (session.query(Connection)
+                           .filter(Connection.conn_id == new_conn.conn_id).first()):
+                session.add(new_conn)
+                msg = '\n\tSuccessfully added `conn_id`={conn_id} : {uri}\n'
+                msg = msg.format(conn_id=new_conn.conn_id,
+                                 uri=args.conn_uri or
+                                 urlunparse((args.conn_type,
+                                            '{login}:{password}@{host}:{port}'
+                                             .format(login=args.conn_login or '',
+                                                     password=args.conn_password or '',
+                                                     host=args.conn_host or '',
+                                                     port=args.conn_port or ''),
+                                             args.conn_schema or '', '', '', '')))
+                print(msg)
+            else:
+                msg = '\n\tA connection with `conn_id`={conn_id} already exists\n'
+                msg = msg.format(conn_id=new_conn.conn_id)
+                print(msg)
 
         return
 
diff --git a/airflow/contrib/auth/backends/password_auth.py b/airflow/contrib/auth/backends/password_auth.py
index 0bc04fcea8..9d6a3ccbe6 100644
--- a/airflow/contrib/auth/backends/password_auth.py
+++ b/airflow/contrib/auth/backends/password_auth.py
@@ -35,9 +35,8 @@
 from sqlalchemy import Column, String
 from sqlalchemy.ext.hybrid import hybrid_property
 
-from airflow import settings
 from airflow import models
-from airflow.utils.db import provide_session
+from airflow.utils.db import provide_session, create_session
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 login_manager = flask_login.LoginManager()
@@ -165,9 +164,6 @@ def login(self, request, session=None):
         return self.render('airflow/login.html',
                            title="Airflow - Login",
                            form=form)
-    finally:
-        session.commit()
-        session.close()
 
 
 class LoginForm(Form):
@@ -201,19 +197,16 @@ def decorated(*args, **kwargs):
             userpass = ''.join(header.split()[1:])
             username, password = base64.b64decode(userpass).decode("utf-8").split(":", 1)
 
-            session = settings.Session()
-            try:
-                authenticate(session, username, password)
+            with create_session() as session:
+                try:
+                    authenticate(session, username, password)
 
-                response = function(*args, **kwargs)
-                response = make_response(response)
-                return response
+                    response = function(*args, **kwargs)
+                    response = make_response(response)
+                    return response
 
-            except AuthenticationError:
-                return _forbidden()
+                except AuthenticationError:
+                    return _forbidden()
 
-            finally:
-                session.commit()
-                session.close()
         return _unauthorized()
     return decorated
diff --git a/airflow/contrib/executors/kubernetes_executor.py b/airflow/contrib/executors/kubernetes_executor.py
index fa81cf3203..aaf85881a9 100644
--- a/airflow/contrib/executors/kubernetes_executor.py
+++ b/airflow/contrib/executors/kubernetes_executor.py
@@ -32,7 +32,8 @@
 from airflow.executors import Executors
 from airflow.models import TaskInstance, KubeResourceVersion, KubeWorkerIdentifier
 from airflow.utils.state import State
-from airflow import configuration, settings
+from airflow.utils.db import provide_session, create_session
+from airflow import configuration
 from airflow.exceptions import AirflowConfigException, AirflowException
 from airflow.utils.log.logging_mixin import LoggingMixin
 
@@ -337,8 +338,7 @@ def process_status(self, pod_id, status, labels, resource_version):
 
 
 class AirflowKubernetesScheduler(LoggingMixin):
-    def __init__(self, kube_config, task_queue, result_queue, session,
-                 kube_client, worker_uuid):
+    def __init__(self, kube_config, task_queue, result_queue, kube_client, worker_uuid):
         self.log.debug("Creating Kubernetes executor")
         self.kube_config = kube_config
         self.task_queue = task_queue
@@ -349,12 +349,11 @@ def __init__(self, kube_config, task_queue, result_queue, session,
         self.launcher = PodLauncher(kube_client=self.kube_client)
         self.worker_configuration = WorkerConfiguration(kube_config=self.kube_config)
         self.watcher_queue = multiprocessing.Queue()
-        self._session = session
         self.worker_uuid = worker_uuid
         self.kube_watcher = self._make_kube_watcher()
 
     def _make_kube_watcher(self):
-        resource_version = KubeResourceVersion.get_current_resource_version(self._session)
+        resource_version = KubeResourceVersion.get_current_resource_version()
         watcher = KubernetesJobWatcher(self.namespace, self.watcher_queue,
                                        resource_version, self.worker_uuid)
         watcher.start()
@@ -514,14 +513,14 @@ class KubernetesExecutor(BaseExecutor, LoggingMixin):
     def __init__(self):
         self.kube_config = KubeConfig()
         self.task_queue = None
-        self._session = None
         self.result_queue = None
         self.kube_scheduler = None
         self.kube_client = None
         self.worker_uuid = None
         super(KubernetesExecutor, self).__init__(parallelism=self.kube_config.parallelism)
 
-    def clear_not_launched_queued_tasks(self):
+    @provide_session
+    def clear_not_launched_queued_tasks(self, session=None):
         """
         If the airflow scheduler restarts with pending "Queued" tasks, the tasks may or
         may not
@@ -537,8 +536,9 @@ def clear_not_launched_queued_tasks(self):
         proper support
         for State.LAUNCHED
         """
-        queued_tasks = self._session.query(
-            TaskInstance).filter(TaskInstance.state == State.QUEUED).all()
+        queued_tasks = session\
+            .query(TaskInstance)\
+            .filter(TaskInstance.state == State.QUEUED).all()
         self.log.info(
             'When executor started up, found %s queued task instances',
             len(queued_tasks)
@@ -557,14 +557,12 @@ def clear_not_launched_queued_tasks(self):
                     'TaskInstance: %s found in queued state but was not launched, '
                     'rescheduling', task
                 )
-                self._session.query(TaskInstance).filter(
+                session.query(TaskInstance).filter(
                     TaskInstance.dag_id == task.dag_id,
                     TaskInstance.task_id == task.task_id,
                     TaskInstance.execution_date == task.execution_date
                 ).update({TaskInstance.state: State.NONE})
 
-        self._session.commit()
-
     def _inject_secrets(self):
         def _create_or_update_secret(secret_name, secret_path):
             try:
@@ -601,20 +599,18 @@ def _create_or_update_secret(secret_name, secret_path):
 
     def start(self):
         self.log.info('Start Kubernetes executor')
-        self._session = settings.Session()
-        self.worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(
-            self._session)
+        self.worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid()
         self.log.debug('Start with worker_uuid: %s', self.worker_uuid)
         # always need to reset resource version since we don't know
         # when we last started, note for behavior below
         # https://github.com/kubernetes-client/python/blob/master/kubernetes/docs
         # /CoreV1Api.md#list_namespaced_pod
-        KubeResourceVersion.reset_resource_version(self._session)
+        KubeResourceVersion.reset_resource_version()
         self.task_queue = Queue()
         self.result_queue = Queue()
         self.kube_client = get_kube_client()
         self.kube_scheduler = AirflowKubernetesScheduler(
-            self.kube_config, self.task_queue, self.result_queue, self._session,
+            self.kube_config, self.task_queue, self.result_queue,
             self.kube_client, self.worker_uuid
         )
         self._inject_secrets()
@@ -643,8 +639,7 @@ def sync(self):
             self.log.info('Changing state of %s to %s', results, state)
             self._change_state(key, state, pod_id)
 
-        KubeResourceVersion.checkpoint_resource_version(
-            last_resource_version, session=self._session)
+        KubeResourceVersion.checkpoint_resource_version(last_resource_version)
 
         if not self.task_queue.empty():
             task = self.task_queue.get()
@@ -667,15 +662,15 @@ def _change_state(self, key, state, pod_id):
                 pass
         self.event_buffer[key] = state
         (dag_id, task_id, ex_time, try_number) = key
-        item = self._session.query(TaskInstance).filter_by(
-            dag_id=dag_id,
-            task_id=task_id,
-            execution_date=ex_time
-        ).one()
-        if state:
-            item.state = state
-            self._session.add(item)
-            self._session.commit()
+        with create_session() as session:
+            item = session.query(TaskInstance).filter_by(
+                dag_id=dag_id,
+                task_id=task_id,
+                execution_date=ex_time
+            ).one()
+            if state:
+                item.state = state
+                session.add(item)
 
     def end(self):
         self.log.info('Shutting down Kubernetes executor')
diff --git a/airflow/settings.py b/airflow/settings.py
index 8691fe4e75..8ff9a34263 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -172,8 +172,8 @@ def configure_orm(disable_connection_pool=False):
         except conf.AirflowConfigException:
             pool_recycle = 1800
 
-        log.info("setting.configure_orm(): Using pool settings. pool_size={}, "
-                 "pool_recycle={}".format(pool_size, pool_recycle))
+        log.info("settings.configure_orm(): Using pool settings. pool_size={}, "
+                 "pool_recycle={}, pid={}".format(pool_size, pool_recycle, os.getpid()))
         engine_args['pool_size'] = pool_size
         engine_args['pool_recycle'] = pool_recycle
 
diff --git a/airflow/utils/cli_action_loggers.py b/airflow/utils/cli_action_loggers.py
index 8b8cda538f..21304936f3 100644
--- a/airflow/utils/cli_action_loggers.py
+++ b/airflow/utils/cli_action_loggers.py
@@ -25,7 +25,7 @@
 
 import logging
 
-import airflow.settings
+from airflow.utils.db import create_session
 
 
 def register_pre_exec_callback(action_logger):
@@ -94,9 +94,8 @@ def default_action_log(log, **_):
     :param **_: other keyword arguments that is not being used by this function
     :return: None
     """
-    session = airflow.settings.Session()
-    session.add(log)
-    session.commit()
+    with create_session() as session:
+        session.add(log)
 
 
 __pre_exec_callbacks = []
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 021f0583d1..e864c6f26e 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -41,7 +41,6 @@ def create_session():
     session = settings.Session()
     try:
         yield session
-        session.expunge_all()
         session.commit()
     except Exception:
         session.rollback()
diff --git a/airflow/www_rbac/decorators.py b/airflow/www_rbac/decorators.py
index 8f962ef840..4ec2c7e8e9 100644
--- a/airflow/www_rbac/decorators.py
+++ b/airflow/www_rbac/decorators.py
@@ -22,7 +22,8 @@
 import pendulum
 from io import BytesIO as IO
 from flask import after_this_request, redirect, request, url_for, g
-from airflow import models, settings
+from airflow import models
+from airflow.utils.db import create_session
 
 
 def action_logging(f):
@@ -31,26 +32,26 @@ def action_logging(f):
     """
     @functools.wraps(f)
     def wrapper(*args, **kwargs):
-        session = settings.Session()
-        if g.user.is_anonymous:
-            user = 'anonymous'
-        else:
-            user = g.user.username
-
-        log = models.Log(
-            event=f.__name__,
-            task_instance=None,
-            owner=user,
-            extra=str(list(request.args.items())),
-            task_id=request.args.get('task_id'),
-            dag_id=request.args.get('dag_id'))
-
-        if 'execution_date' in request.args:
-            log.execution_date = pendulum.parse(
-                request.args.get('execution_date'))
-
-        session.add(log)
-        session.commit()
+
+        with create_session() as session:
+            if g.user.is_anonymous:
+                user = 'anonymous'
+            else:
+                user = g.user.username
+
+            log = models.Log(
+                event=f.__name__,
+                task_instance=None,
+                owner=user,
+                extra=str(list(request.args.items())),
+                task_id=request.args.get('task_id'),
+                dag_id=request.args.get('dag_id'))
+
+            if 'execution_date' in request.args:
+                log.execution_date = pendulum.parse(
+                    request.args.get('execution_date'))
+
+            session.add(log)
 
         return f(*args, **kwargs)
 
diff --git a/airflow/www_rbac/security.py b/airflow/www_rbac/security.py
index 369ec71ca7..6bc627a561 100644
--- a/airflow/www_rbac/security.py
+++ b/airflow/www_rbac/security.py
@@ -24,8 +24,9 @@
 from flask_appbuilder.security.sqla.manager import SecurityManager
 from sqlalchemy import or_
 
-from airflow import models, settings
+from airflow import models
 from airflow.www_rbac.app import appbuilder
+from airflow.utils.db import provide_session
 
 ###########################################################################
 #                               VIEW MENUS
@@ -329,7 +330,8 @@ def _merge_perm(self, permission_name, view_menu_name):
         if not pv and permission_name and view_menu_name:
             self.add_permission_view_menu(permission_name, view_menu_name)
 
-    def create_custom_dag_permission_view(self):
+    @provide_session
+    def create_custom_dag_permission_view(self, session=None):
         """
         Workflow:
         1. when scheduler found a new dag, we will create an entry in ab_view_menu
@@ -359,7 +361,7 @@ def merge_pv(perm, view_menu):
                 merge_pv(perm, dag)
 
         # Get all the active / paused dags and insert them into a set
-        all_dags_models = settings.Session.query(models.DagModel)\
+        all_dags_models = session.query(models.DagModel)\
             .filter(or_(models.DagModel.is_active, models.DagModel.is_paused))\
             .filter(~models.DagModel.is_subdag).all()
 
@@ -382,7 +384,7 @@ def merge_pv(perm, view_menu):
         view_menu = self.viewmenu_model
 
         # todo(tao) comment on the query
-        all_perm_view_by_user = settings.Session.query(ab_perm_view_role)\
+        all_perm_view_by_user = session.query(ab_perm_view_role)\
             .join(perm_view, perm_view.id == ab_perm_view_role
                   .columns.permission_view_id)\
             .filter(ab_perm_view_role.columns.role_id == user_role.id)\
diff --git a/airflow/www_rbac/utils.py b/airflow/www_rbac/utils.py
index b25e1541ab..7c983279af 100644
--- a/airflow/www_rbac/utils.py
+++ b/airflow/www_rbac/utils.py
@@ -40,7 +40,7 @@
 from flask_appbuilder.models.sqla.interface import SQLAInterface
 import flask_appbuilder.models.sqla.filters as fab_sqlafilters
 import sqlalchemy as sqla
-from airflow import configuration, settings
+from airflow import configuration
 from airflow.models import BaseOperator
 from airflow.operators.subdag_operator import SubDagOperator
 from airflow.utils import timezone
@@ -427,8 +427,6 @@ class CustomSQLAInterface(SQLAInterface):
     def __init__(self, obj):
         super(CustomSQLAInterface, self).__init__(obj)
 
-        self.session = settings.Session()
-
         def clean_column_names():
             if self.list_properties:
                 self.list_properties = dict(
diff --git a/tests/api/common/experimental/test_mark_tasks.py b/tests/api/common/experimental/test_mark_tasks.py
index 0cb8a4e888..28df2f6a46 100644
--- a/tests/api/common/experimental/test_mark_tasks.py
+++ b/tests/api/common/experimental/test_mark_tasks.py
@@ -18,16 +18,18 @@
 # under the License.
 
 import unittest
+import time
 from datetime import datetime
 
 from airflow import configuration, models
 from airflow.api.common.experimental.mark_tasks import (
     set_state, _create_dagruns, set_dag_run_state_to_success, set_dag_run_state_to_failed,
     set_dag_run_state_to_running)
-from airflow.settings import Session
 from airflow.utils import timezone
+from airflow.utils.db import create_session, provide_session
 from airflow.utils.dates import days_ago
 from airflow.utils.state import State
+from airflow.models import DagRun
 
 DEV_NULL = "/dev/null"
 
@@ -59,33 +61,29 @@ def setUp(self):
             dr.dag = self.dag2
             dr.verify_integrity()
 
-        self.session = Session()
-
     def tearDown(self):
         self.dag1.clear()
         self.dag2.clear()
 
         # just to make sure we are fully cleaned up
-        self.session.query(models.DagRun).delete()
-        self.session.query(models.TaskInstance).delete()
-        self.session.commit()
-        self.session.close()
+        with create_session() as session:
+            session.query(models.DagRun).delete()
+            session.query(models.TaskInstance).delete()
 
-    def snapshot_state(self, dag, execution_dates):
+    @staticmethod
+    def snapshot_state(dag, execution_dates):
         TI = models.TaskInstance
-        tis = self.session.query(TI).filter(
-            TI.dag_id == dag.dag_id,
-            TI.execution_date.in_(execution_dates)
-        ).all()
-
-        self.session.expunge_all()
-
-        return tis
-
-    def verify_state(self, dag, task_ids, execution_dates, state, old_tis):
+        with create_session() as session:
+            return session.query(TI).filter(
+                TI.dag_id == dag.dag_id,
+                TI.execution_date.in_(execution_dates)
+            ).all()
+
+    @provide_session
+    def verify_state(self, dag, task_ids, execution_dates, state, old_tis, session=None):
         TI = models.TaskInstance
 
-        tis = self.session.query(TI).filter(
+        tis = session.query(TI).filter(
             TI.dag_id == dag.dag_id,
             TI.execution_date.in_(execution_dates)
         ).all()
@@ -102,7 +100,7 @@ def verify_state(self, dag, task_ids, execution_dates, state, old_tis):
 
     def test_mark_tasks_now(self):
         # set one task to success but do not commit
-        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+        snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates)
         task = self.dag1.get_task("runme_1")
         altered = set_state(task=task, execution_date=self.execution_dates[0],
                             upstream=False, downstream=False, future=False,
@@ -136,7 +134,7 @@ def test_mark_tasks_now(self):
                           State.FAILED, snapshot)
 
         # dont alter other tasks
-        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+        snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates)
         task = self.dag1.get_task("runme_0")
         altered = set_state(task=task, execution_date=self.execution_dates[0],
                             upstream=False, downstream=False, future=False,
@@ -147,7 +145,7 @@ def test_mark_tasks_now(self):
 
     def test_mark_downstream(self):
         # test downstream
-        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+        snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates)
         task = self.dag1.get_task("runme_1")
         relatives = task.get_flat_relatives(upstream=False)
         task_ids = [t.task_id for t in relatives]
@@ -157,12 +155,11 @@ def test_mark_downstream(self):
                             upstream=False, downstream=True, future=False,
                             past=False, state=State.SUCCESS, commit=True)
         self.assertEqual(len(altered), 3)
-        self.verify_state(self.dag1, task_ids, [self.execution_dates[0]],
-                          State.SUCCESS, snapshot)
+        self.verify_state(self.dag1, task_ids, [self.execution_dates[0]], State.SUCCESS, snapshot)
 
     def test_mark_upstream(self):
         # test upstream
-        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+        snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates)
         task = self.dag1.get_task("run_after_loop")
         relatives = task.get_flat_relatives(upstream=True)
         task_ids = [t.task_id for t in relatives]
@@ -177,31 +174,28 @@ def test_mark_upstream(self):
 
     def test_mark_tasks_future(self):
         # set one task to success towards end of scheduled dag runs
-        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+        snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates)
         task = self.dag1.get_task("runme_1")
         altered = set_state(task=task, execution_date=self.execution_dates[0],
                             upstream=False, downstream=False, future=True,
                             past=False, state=State.SUCCESS, commit=True)
         self.assertEqual(len(altered), 2)
-        self.verify_state(self.dag1, [task.task_id], self.execution_dates,
-                          State.SUCCESS, snapshot)
+        self.verify_state(self.dag1, [task.task_id], self.execution_dates, State.SUCCESS, snapshot)
 
     def test_mark_tasks_past(self):
         # set one task to success towards end of scheduled dag runs
-        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+        snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates)
         task = self.dag1.get_task("runme_1")
         altered = set_state(task=task, execution_date=self.execution_dates[1],
                             upstream=False, downstream=False, future=False,
                             past=True, state=State.SUCCESS, commit=True)
         self.assertEqual(len(altered), 2)
-        self.verify_state(self.dag1, [task.task_id], self.execution_dates,
-                          State.SUCCESS, snapshot)
+        self.verify_state(self.dag1, [task.task_id], self.execution_dates, State.SUCCESS, snapshot)
 
     # TODO: this skipIf should be removed once a fixing solution is found later
     #       We skip it here because this test case is working with Postgres & SQLite
     #       but not with MySQL
-    @unittest.skipIf('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'),
-                     "Flaky with MySQL")
+    @unittest.skipIf('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), "Flaky with MySQL")
     def test_mark_tasks_subdag(self):
         # set one task to success towards end of scheduled dag runs
         task = self.dag2.get_task("section-1")
@@ -226,39 +220,35 @@ def setUp(self):
         self.dagbag = models.DagBag(include_examples=True)
         self.dag1 = self.dagbag.dags['example_bash_operator']
         self.dag2 = self.dagbag.dags['example_subdag_operator']
-
         self.execution_dates = [days_ago(2), days_ago(1), days_ago(0)]
 
-        self.session = Session()
-
     def _set_default_task_instance_states(self, dr):
-        if dr.dag_id != 'example_bash_operator':
-            return
         # success task
-        dr.get_task_instance('runme_0').set_state(State.SUCCESS, self.session)
+        dr.get_task_instance('runme_0').set_state(State.SUCCESS)
         # skipped task
-        dr.get_task_instance('runme_1').set_state(State.SKIPPED, self.session)
+        dr.get_task_instance('runme_1').set_state(State.SKIPPED)
         # retry task
-        dr.get_task_instance('runme_2').set_state(State.UP_FOR_RETRY, self.session)
+        dr.get_task_instance('runme_2').set_state(State.UP_FOR_RETRY)
         # queued task
-        dr.get_task_instance('also_run_this').set_state(State.QUEUED, self.session)
+        dr.get_task_instance('also_run_this').set_state(State.QUEUED)
         # running task
-        dr.get_task_instance('run_after_loop').set_state(State.RUNNING, self.session)
+        dr.get_task_instance('run_after_loop').set_state(State.RUNNING)
         # failed task
-        dr.get_task_instance('run_this_last').set_state(State.FAILED, self.session)
+        dr.get_task_instance('run_this_last').set_state(State.FAILED)
 
     def _verify_task_instance_states_remain_default(self, dr):
         self.assertEqual(dr.get_task_instance('runme_0').state, State.SUCCESS)
         self.assertEqual(dr.get_task_instance('runme_1').state, State.SKIPPED)
         self.assertEqual(dr.get_task_instance('runme_2').state, State.UP_FOR_RETRY)
-        self.assertEqual(dr.get_task_instance('also_run_this').state, State.QUEUED, )
+        self.assertEqual(dr.get_task_instance('also_run_this').state, State.QUEUED)
         self.assertEqual(dr.get_task_instance('run_after_loop').state, State.RUNNING)
         self.assertEqual(dr.get_task_instance('run_this_last').state, State.FAILED)
 
-    def _verify_task_instance_states(self, dag, date, state):
+    @provide_session
+    def _verify_task_instance_states(self, dag, date, state, session):
         TI = models.TaskInstance
-        tis = self.session.query(TI).filter(TI.dag_id == dag.dag_id,
-                                            TI.execution_date == date)
+        tis = session.query(TI)\
+            .filter(TI.dag_id == dag.dag_id, TI.execution_date == date)
         for ti in tis:
             self.assertEqual(ti.state, state)
 
@@ -266,8 +256,7 @@ def _create_test_dag_run(self, state, date):
         return self.dag1.create_dagrun(
             run_id='manual__' + datetime.now().isoformat(),
             state=state,
-            execution_date=date,
-            session=self.session
+            execution_date=date
         )
 
     def _verify_dag_run_state(self, dag, date, state):
@@ -276,15 +265,22 @@ def _verify_dag_run_state(self, dag, date, state):
 
         self.assertEqual(dr.get_state(), state)
 
-    def _verify_dag_run_dates(self, dag, date, state, middle_time):
+    @provide_session
+    def _verify_dag_run_dates(self, dag, date, state, middle_time, session=None):
         # When target state is RUNNING, we should set start_date,
         # otherwise we should set end_date.
-        drs = models.DagRun.find(dag_id=dag.dag_id, execution_date=date)
-        dr = drs[0]
+        DR = DagRun
+        dr = session.query(DR).filter(
+            DR.dag_id == dag.dag_id,
+            DR.execution_date == date
+        ).one()
         if state == State.RUNNING:
+            # Since the DAG is running, the start_date must be updated after creation
             self.assertGreater(dr.start_date, middle_time)
+            # If the dag is still running, we don't have an end date
             self.assertIsNone(dr.end_date)
         else:
+            # If the dag is not running, there must be an end time
             self.assertLess(dr.start_date, middle_time)
             self.assertGreater(dr.end_date, middle_time)
 
@@ -324,7 +320,7 @@ def test_set_running_dag_run_to_running(self):
 
         altered = set_dag_run_state_to_running(self.dag1, date, commit=True)
 
-        # None of the tasks should be altered.
+        # None of the tasks should be altered, only the dag itself
         self.assertEqual(len(altered), 0)
         self._verify_dag_run_state(self.dag1, date, State.RUNNING)
         self._verify_task_instance_states_remain_default(dr)
@@ -366,7 +362,7 @@ def test_set_success_dag_run_to_running(self):
 
         altered = set_dag_run_state_to_running(self.dag1, date, commit=True)
 
-        # None of the tasks should be altered.
+        # None of the tasks should be altered, but only the dag object should be changed
         self.assertEqual(len(altered), 0)
         self._verify_dag_run_state(self.dag1, date, State.RUNNING)
         self._verify_task_instance_states_remain_default(dr)
@@ -406,9 +402,11 @@ def test_set_failed_dag_run_to_running(self):
         middle_time = timezone.utcnow()
         self._set_default_task_instance_states(dr)
 
+        time.sleep(2)
+
         altered = set_dag_run_state_to_running(self.dag1, date, commit=True)
 
-        # None of the tasks should be altered.
+        # None of the tasks should be altered, since we've only altered the DAG itself
         self.assertEqual(len(altered), 0)
         self._verify_dag_run_state(self.dag1, date, State.RUNNING)
         self._verify_task_instance_states_remain_default(dr)
@@ -440,28 +438,28 @@ def test_set_state_without_commit(self):
         self._verify_dag_run_state(self.dag1, date, State.RUNNING)
         self._verify_task_instance_states_remain_default(dr)
 
-    def test_set_state_with_multiple_dagruns(self):
+    @provide_session
+    def test_set_state_with_multiple_dagruns(self, session=None):
         self.dag2.create_dagrun(
             run_id='manual__' + datetime.now().isoformat(),
             state=State.FAILED,
             execution_date=self.execution_dates[0],
-            session=self.session
+            session=session
         )
         self.dag2.create_dagrun(
             run_id='manual__' + datetime.now().isoformat(),
             state=State.FAILED,
             execution_date=self.execution_dates[1],
-            session=self.session
+            session=session
         )
         self.dag2.create_dagrun(
             run_id='manual__' + datetime.now().isoformat(),
             state=State.RUNNING,
             execution_date=self.execution_dates[2],
-            session=self.session
+            session=session
         )
 
-        altered = set_dag_run_state_to_success(self.dag2, self.execution_dates[1],
-                                               commit=True)
+        altered = set_dag_run_state_to_success(self.dag2, self.execution_dates[1], commit=True)
 
         # Recursively count number of tasks in the dag
         def count_dag_tasks(dag):
@@ -513,10 +511,9 @@ def tearDown(self):
         self.dag1.clear()
         self.dag2.clear()
 
-        self.session.query(models.DagRun).delete()
-        self.session.query(models.TaskInstance).delete()
-        self.session.commit()
-        self.session.close()
+        with create_session() as session:
+            session.query(models.DagRun).delete()
+            session.query(models.TaskInstance).delete()
 
 
 if __name__ == '__main__':
diff --git a/tests/core.py b/tests/core.py
index 4d7b4af1ff..7b3f0e4d92 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -1164,13 +1164,13 @@ def test_cli_list_tasks(self):
             'list_tasks', 'example_bash_operator', '--tree'])
         cli.list_tasks(args)
 
-    @mock.patch("airflow.bin.cli.db_utils.initdb")
+    @mock.patch("airflow.bin.cli.db.initdb")
     def test_cli_initdb(self, initdb_mock):
         cli.initdb(self.parser.parse_args(['initdb']))
 
         initdb_mock.assert_called_once_with(False)
 
-    @mock.patch("airflow.bin.cli.db_utils.resetdb")
+    @mock.patch("airflow.bin.cli.db.resetdb")
     def test_cli_resetdb(self, resetdb_mock):
         cli.resetdb(self.parser.parse_args(['resetdb', '--yes']))
 

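For readers skimming the archive, the two helpers this PR standardises on are
airflow.utils.db.create_session (a context manager that commits on success,
rolls back on error, and always closes the session) and provide_session (a
decorator that injects a session keyword argument when the caller does not
supply one). The following standalone sketch is a simplified approximation of
that pattern, not the Airflow source; DummySession is a hypothetical stand-in
for a SQLAlchemy session so the snippet runs without a database:

    import functools
    from contextlib import contextmanager


    class DummySession:
        """Hypothetical stand-in for a SQLAlchemy session."""
        def commit(self):
            print("commit")

        def rollback(self):
            print("rollback")

        def close(self):
            print("close")


    @contextmanager
    def create_session():
        """Yield a session, commit on success, roll back on error, always close."""
        session = DummySession()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()


    def provide_session(func):
        """Inject a fresh session when the caller did not pass one explicitly."""
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if kwargs.get("session") is not None:
                return func(*args, **kwargs)
            with create_session() as session:
                kwargs["session"] = session
                return func(*args, **kwargs)
        return wrapper


    @provide_session
    def delete_dag(dag_id, session=None):
        # The decorator guarantees the session is committed and closed,
        # which is the leak this PR fixes for the real delete_dag().
        print("deleting", dag_id, "with", session)


    if __name__ == "__main__":
        delete_dag("example_dag")                     # session injected automatically
        with create_session() as session:             # or managed explicitly
            delete_dag("example_dag", session=session)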

 
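As a usage note, the call sites converted in this PR end up in one of the two
shapes below. This is a hedged sketch assuming an Airflow ~1.10 installation
with an initialised metadata database (the queries mirror the variables
--delete and listing branches of cli.py above); it is not part of the PR
itself:

    from airflow.models import Variable
    from airflow.utils.db import create_session, provide_session


    def delete_variable(key):
        # Explicit scope: commit/rollback/close handled by the context manager.
        with create_session() as session:
            return session.query(Variable).filter_by(key=key).delete()


    @provide_session
    def list_variable_keys(session=None):
        # Injected scope: the decorator supplies (and later closes) the session
        # unless the caller passes one, e.g. list_variable_keys(session=session).
        return [var.key for var in session.query(Variable)]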

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services