Posted to commits@superset.apache.org by er...@apache.org on 2020/08/06 22:34:18 UTC

[incubator-superset] branch master updated: Revert "chore: Cleanup database sessions (#10427)" (#10537)

This is an automated email from the ASF dual-hosted git repository.

erikrit pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new fd2d1c5  Revert "chore: Cleanup database sessions (#10427)" (#10537)
fd2d1c5 is described below

commit fd2d1c58c566d9312d6cfc5641a06ac2b03e753a
Author: Erik Ritter <er...@airbnb.com>
AuthorDate: Thu Aug 6 15:33:48 2020 -0700

    Revert "chore: Cleanup database sessions (#10427)" (#10537)
    
    This reverts commit 7645fc85c3d6676a13ae76ca5133f83d8fb54dbe.
---
 superset/cli.py                           |  12 +-
 superset/commands/utils.py                |   6 +-
 superset/common/query_context.py          |   4 +-
 superset/connectors/connector_registry.py |  35 +++---
 superset/connectors/druid/models.py       |  52 ++++----
 superset/connectors/druid/views.py        |   5 +-
 superset/connectors/sqla/models.py        |  29 +++--
 superset/dashboards/dao.py                |   8 +-
 superset/models/dashboard.py              |  36 +++---
 superset/models/helpers.py                |  13 +-
 superset/models/slice.py                  |   8 +-
 superset/models/tags.py                   |  81 ++++++++-----
 superset/security/manager.py              |   6 +-
 superset/sql_lab.py                       |   3 +-
 superset/tasks/cache.py                   |  23 ++--
 superset/tasks/schedules.py               |  13 +-
 superset/utils/dashboard_import_export.py |  11 +-
 superset/utils/dict_import_export.py      |  19 +--
 superset/utils/import_datasource.py       |  22 ++--
 superset/views/base.py                    |   3 +-
 superset/views/chart/views.py             |   4 +-
 superset/views/core.py                    | 129 ++++++++++++--------
 superset/views/datasource.py              |   6 +-
 superset/views/utils.py                   |   7 +-
 tests/access_tests.py                     |  92 +++++++-------
 tests/alerts_tests.py                     | 192 +++++++++++++++---------------
 tests/base_tests.py                       |  34 +++---
 tests/celery_tests.py                     |  10 +-
 tests/charts/api_tests.py                 |   4 +-
 tests/core_tests.py                       |  24 ++--
 tests/database_api_tests.py               |   5 +-
 tests/datasets/api_tests.py               |   7 +-
 tests/dict_import_export_tests.py         |  60 ++++++----
 tests/druid_tests.py                      |  10 +-
 tests/import_export_tests.py              |  23 ++--
 tests/query_context_tests.py              |   1 +
 tests/security_tests.py                   | 132 ++++++++++----------
 tests/sqllab_tests.py                     |   6 +-
 tests/strategy_tests.py                   |   6 +-
 39 files changed, 645 insertions(+), 496 deletions(-)
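
The substance of the revert: the session-cleanup commit (#10427) had removed
explicit SQLAlchemy session arguments in favor of the module-level db.session,
and this revert restores signatures that accept a Session and thread it
through every call site. A minimal sketch of the restored convention, with a
hypothetical `sources` dict standing in for the real ConnectorRegistry.sources:

    from typing import Dict, Type

    from sqlalchemy.orm import Session

    # Hypothetical stand-in for ConnectorRegistry.sources, which maps
    # datasource types ("table", "druid") to their model classes.
    sources: Dict[str, Type] = {}

    def get_datasource(datasource_type: str, datasource_id: int, session: Session):
        # The caller supplies the session explicitly: db.session inside web
        # requests, a worker- or listener-local session elsewhere.
        return (
            session.query(sources[datasource_type])
            .filter_by(id=datasource_id)
            .one()
        )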

diff --git a/superset/cli.py b/superset/cli.py
index ef9ee84..ef17682 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -197,9 +197,10 @@ def set_database_uri(database_name: str, uri: str) -> None:
 )
 def refresh_druid(datasource: str, merge: bool) -> None:
     """Refresh druid datasources"""
+    session = db.session()
     from superset.connectors.druid.models import DruidCluster
 
-    for cluster in db.session.query(DruidCluster).all():
+    for cluster in session.query(DruidCluster).all():
         try:
             cluster.refresh_datasources(datasource_name=datasource, merge_flag=merge)
         except Exception as ex:  # pylint: disable=broad-except
@@ -207,7 +208,7 @@ def refresh_druid(datasource: str, merge: bool) -> None:
             logger.exception(ex)
         cluster.metadata_last_refreshed = datetime.now()
         print("Refreshed metadata from cluster " "[" + cluster.cluster_name + "]")
-    db.session.commit()
+    session.commit()
 
 
 @superset.command()
@@ -249,7 +250,7 @@ def import_dashboards(path: str, recursive: bool, username: str) -> None:
         logger.info("Importing dashboard from file %s", file_)
         try:
             with file_.open() as data_stream:
-                dashboard_import_export.import_dashboards(data_stream)
+                dashboard_import_export.import_dashboards(db.session, data_stream)
         except Exception as ex:  # pylint: disable=broad-except
             logger.error("Error when importing dashboard from file %s", file_)
             logger.error(ex)
@@ -267,7 +268,7 @@ def export_dashboards(dashboard_file: str, print_stdout: bool) -> None:
     """Export dashboards to JSON"""
     from superset.utils import dashboard_import_export
 
-    data = dashboard_import_export.export_dashboards()
+    data = dashboard_import_export.export_dashboards(db.session)
     if print_stdout or not dashboard_file:
         print(data)
     if dashboard_file:
@@ -320,7 +321,7 @@ def import_datasources(path: str, sync: str, recursive: bool) -> None:
         try:
             with file_.open() as data_stream:
                 dict_import_export.import_from_dict(
-                    yaml.safe_load(data_stream), sync=sync_array
+                    db.session, yaml.safe_load(data_stream), sync=sync_array
                 )
         except Exception as ex:  # pylint: disable=broad-except
             logger.error("Error when importing datasources from file %s", file_)
@@ -359,6 +360,7 @@ def export_datasources(
     from superset.utils import dict_import_export
 
     data = dict_import_export.export_to_dict(
+        session=db.session,
         recursive=True,
         back_references=back_references,
         include_defaults=include_defaults,
diff --git a/superset/commands/utils.py b/superset/commands/utils.py
index 66fd543..c0bd8b7 100644
--- a/superset/commands/utils.py
+++ b/superset/commands/utils.py
@@ -25,7 +25,7 @@ from superset.commands.exceptions import (
 )
 from superset.connectors.base.models import BaseDatasource
 from superset.connectors.connector_registry import ConnectorRegistry
-from superset.extensions import security_manager
+from superset.extensions import db, security_manager
 
 
 def populate_owners(user: User, owners_ids: Optional[List[int]] = None) -> List[User]:
@@ -50,6 +50,8 @@ def populate_owners(user: User, owners_ids: Optional[List[int]] = None) -> List[
 
 def get_datasource_by_id(datasource_id: int, datasource_type: str) -> BaseDatasource:
     try:
-        return ConnectorRegistry.get_datasource(datasource_type, datasource_id)
+        return ConnectorRegistry.get_datasource(
+            datasource_type, datasource_id, db.session
+        )
     except (NoResultFound, KeyError):
         raise DatasourceNotFoundValidationError()
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 401f262..e602fbf 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -23,7 +23,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Union
 import numpy as np
 import pandas as pd
 
-from superset import app, cache, security_manager
+from superset import app, cache, db, security_manager
 from superset.common.query_object import QueryObject
 from superset.connectors.base.models import BaseDatasource
 from superset.connectors.connector_registry import ConnectorRegistry
@@ -64,7 +64,7 @@ class QueryContext:
         result_format: Optional[utils.ChartDataResultFormat] = None,
     ) -> None:
         self.datasource = ConnectorRegistry.get_datasource(
-            str(datasource["type"]), int(datasource["id"])
+            str(datasource["type"]), int(datasource["id"]), db.session
         )
         self.queries = [QueryObject(**query_obj) for query_obj in queries]
         self.force = force
diff --git a/superset/connectors/connector_registry.py b/superset/connectors/connector_registry.py
index 7c47bc7..fff2f8e 100644
--- a/superset/connectors/connector_registry.py
+++ b/superset/connectors/connector_registry.py
@@ -17,9 +17,7 @@
 from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING
 
 from sqlalchemy import or_
-from sqlalchemy.orm import subqueryload
-
-from superset.extensions import db
+from sqlalchemy.orm import Session, subqueryload
 
 if TYPE_CHECKING:
     # pylint: disable=unused-import
@@ -45,20 +43,20 @@ class ConnectorRegistry:
 
     @classmethod
     def get_datasource(
-        cls, datasource_type: str, datasource_id: int
+        cls, datasource_type: str, datasource_id: int, session: Session
     ) -> "BaseDatasource":
         return (
-            db.session.query(cls.sources[datasource_type])
+            session.query(cls.sources[datasource_type])
             .filter_by(id=datasource_id)
             .one()
         )
 
     @classmethod
-    def get_all_datasources(cls) -> List["BaseDatasource"]:
+    def get_all_datasources(cls, session: Session) -> List["BaseDatasource"]:
         datasources: List["BaseDatasource"] = []
         for source_type in ConnectorRegistry.sources:
             source_class = ConnectorRegistry.sources[source_type]
-            qry = db.session.query(source_class)
+            qry = session.query(source_class)
             qry = source_class.default_query(qry)
             datasources.extend(qry.all())
         return datasources
@@ -66,6 +64,7 @@ class ConnectorRegistry:
     @classmethod
     def get_datasource_by_name(  # pylint: disable=too-many-arguments
         cls,
+        session: Session,
         datasource_type: str,
         datasource_name: str,
         schema: str,
@@ -73,17 +72,21 @@ class ConnectorRegistry:
     ) -> Optional["BaseDatasource"]:
         datasource_class = ConnectorRegistry.sources[datasource_type]
         return datasource_class.get_datasource_by_name(
-            datasource_name, schema, database_name
+            session, datasource_name, schema, database_name
         )
 
     @classmethod
     def query_datasources_by_permissions(  # pylint: disable=invalid-name
-        cls, database: "Database", permissions: Set[str], schema_perms: Set[str],
+        cls,
+        session: Session,
+        database: "Database",
+        permissions: Set[str],
+        schema_perms: Set[str],
     ) -> List["BaseDatasource"]:
         # TODO(bogdan): add unit test
         datasource_class = ConnectorRegistry.sources[database.type]
         return (
-            db.session.query(datasource_class)
+            session.query(datasource_class)
             .filter_by(database_id=database.id)
             .filter(
                 or_(
@@ -96,12 +99,12 @@ class ConnectorRegistry:
 
     @classmethod
     def get_eager_datasource(
-        cls, datasource_type: str, datasource_id: int
+        cls, session: Session, datasource_type: str, datasource_id: int
     ) -> "BaseDatasource":
         """Returns datasource with columns and metrics."""
         datasource_class = ConnectorRegistry.sources[datasource_type]
         return (
-            db.session.query(datasource_class)
+            session.query(datasource_class)
             .options(
                 subqueryload(datasource_class.columns),
                 subqueryload(datasource_class.metrics),
@@ -112,9 +115,13 @@ class ConnectorRegistry:
 
     @classmethod
     def query_datasources_by_name(
-        cls, database: "Database", datasource_name: str, schema: Optional[str] = None,
+        cls,
+        session: Session,
+        database: "Database",
+        datasource_name: str,
+        schema: Optional[str] = None,
     ) -> List["BaseDatasource"]:
         datasource_class = ConnectorRegistry.sources[database.type]
         return datasource_class.query_datasources_by_name(
-            database, datasource_name, schema=schema
+            session, database, datasource_name, schema=schema
         )
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index 0068f11..162163f 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -45,7 +45,7 @@ from sqlalchemy import (
     UniqueConstraint,
 )
 from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import backref, relationship
+from sqlalchemy.orm import backref, relationship, Session
 from sqlalchemy.sql import expression
 from sqlalchemy_utils import EncryptedType
 
@@ -223,8 +223,9 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
         Fetches metadata for the specified datasources and
         merges to the Superset database
         """
+        session = db.session
         ds_list = (
-            db.session.query(DruidDatasource)
+            session.query(DruidDatasource)
             .filter(DruidDatasource.cluster_id == self.id)
             .filter(DruidDatasource.datasource_name.in_(datasource_names))
         )
@@ -233,8 +234,8 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
             datasource = ds_map.get(ds_name, None)
             if not datasource:
                 datasource = DruidDatasource(datasource_name=ds_name)
-                with db.session.no_autoflush:
-                    db.session.add(datasource)
+                with session.no_autoflush:
+                    session.add(datasource)
                 flasher(_("Adding new datasource [{}]").format(ds_name), "success")
                 ds_map[ds_name] = datasource
             elif refresh_all:
@@ -244,7 +245,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
                 continue
             datasource.cluster = self
             datasource.merge_flag = merge_flag
-        db.session.flush()
+        session.flush()
 
         # Prepare multithreaded execution
         pool = ThreadPool()
@@ -258,7 +259,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
             cols = metadata[i]
             if cols:
                 col_objs_list = (
-                    db.session.query(DruidColumn)
+                    session.query(DruidColumn)
                     .filter(DruidColumn.datasource_id == datasource.id)
                     .filter(DruidColumn.column_name.in_(cols.keys()))
                 )
@@ -271,15 +272,15 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
                         col_obj = DruidColumn(
                             datasource_id=datasource.id, column_name=col
                         )
-                        with db.session.no_autoflush:
-                            db.session.add(col_obj)
+                        with session.no_autoflush:
+                            session.add(col_obj)
                     col_obj.type = cols[col]["type"]
                     col_obj.datasource = datasource
                     if col_obj.type == "STRING":
                         col_obj.groupby = True
                         col_obj.filterable = True
                 datasource.refresh_metrics()
-        db.session.commit()
+        session.commit()
 
     @hybrid_property
     def perm(self) -> str:
@@ -389,7 +390,7 @@ class DruidColumn(Model, BaseColumn):
                 .first()
             )
 
-        return import_datasource.import_simple_obj(i_column, lookup_obj)
+        return import_datasource.import_simple_obj(db.session, i_column, lookup_obj)
 
 
 class DruidMetric(Model, BaseMetric):
@@ -458,7 +459,7 @@ class DruidMetric(Model, BaseMetric):
                 .first()
             )
 
-        return import_datasource.import_simple_obj(i_metric, lookup_obj)
+        return import_datasource.import_simple_obj(db.session, i_metric, lookup_obj)
 
 
 druiddatasource_user = Table(
@@ -634,7 +635,7 @@ class DruidDatasource(Model, BaseDatasource):
             return db.session.query(DruidCluster).filter_by(id=d.cluster_id).first()
 
         return import_datasource.import_datasource(
-            i_datasource, lookup_cluster, lookup_datasource, import_time
+            db.session, i_datasource, lookup_cluster, lookup_datasource, import_time
         )
 
     def latest_metadata(self) -> Optional[Dict[str, Any]]:
@@ -704,10 +705,9 @@ class DruidDatasource(Model, BaseDatasource):
         refresh: bool = True,
     ) -> None:
         """Merges the ds config from druid_config into one stored in the db."""
+        session = db.session
         datasource = (
-            db.session.query(cls)
-            .filter_by(datasource_name=druid_config["name"])
-            .first()
+            session.query(cls).filter_by(datasource_name=druid_config["name"]).first()
         )
         # Create a new datasource.
         if not datasource:
@@ -718,13 +718,13 @@ class DruidDatasource(Model, BaseDatasource):
                 changed_by_fk=user.id,
                 created_by_fk=user.id,
             )
-            db.session.add(datasource)
+            session.add(datasource)
         elif not refresh:
             return
 
         dimensions = druid_config["dimensions"]
         col_objs = (
-            db.session.query(DruidColumn)
+            session.query(DruidColumn)
             .filter(DruidColumn.datasource_id == datasource.id)
             .filter(DruidColumn.column_name.in_(dimensions))
         )
@@ -741,10 +741,10 @@ class DruidDatasource(Model, BaseDatasource):
                     type="STRING",
                     datasource=datasource,
                 )
-                db.session.add(col_obj)
+                session.add(col_obj)
         # Import Druid metrics
         metric_objs = (
-            db.session.query(DruidMetric)
+            session.query(DruidMetric)
             .filter(DruidMetric.datasource_id == datasource.id)
             .filter(
                 DruidMetric.metric_name.in_(
@@ -777,8 +777,8 @@ class DruidDatasource(Model, BaseDatasource):
                         % druid_config["name"]
                     ),
                 )
-                db.session.add(metric_obj)
-        db.session.commit()
+                session.add(metric_obj)
+        session.commit()
 
     @staticmethod
     def time_offset(granularity: Granularity) -> int:
@@ -788,10 +788,10 @@ class DruidDatasource(Model, BaseDatasource):
 
     @classmethod
     def get_datasource_by_name(
-        cls, datasource_name: str, schema: str, database_name: str
+        cls, session: Session, datasource_name: str, schema: str, database_name: str
     ) -> Optional["DruidDatasource"]:
         query = (
-            db.session.query(cls)
+            session.query(cls)
             .join(DruidCluster)
             .filter(cls.datasource_name == datasource_name)
             .filter(DruidCluster.cluster_name == database_name)
@@ -1724,7 +1724,11 @@ class DruidDatasource(Model, BaseDatasource):
 
     @classmethod
     def query_datasources_by_name(
-        cls, database: Database, datasource_name: str, schema: Optional[str] = None,
+        cls,
+        session: Session,
+        database: Database,
+        datasource_name: str,
+        schema: Optional[str] = None,
     ) -> List["DruidDatasource"]:
         return []
 
diff --git a/superset/connectors/druid/views.py b/superset/connectors/druid/views.py
index 4a22bc2..4c2fbf9 100644
--- a/superset/connectors/druid/views.py
+++ b/superset/connectors/druid/views.py
@@ -365,10 +365,11 @@ class Druid(BaseSupersetView):
         self, refresh_all: bool = True
     ) -> FlaskResponse:
         """endpoint that refreshes druid datasources metadata"""
+        session = db.session()
         DruidCluster = ConnectorRegistry.sources[  # pylint: disable=invalid-name
             "druid"
         ].cluster_class
-        for cluster in db.session.query(DruidCluster).all():
+        for cluster in session.query(DruidCluster).all():
             cluster_name = cluster.cluster_name
             valid_cluster = True
             try:
@@ -390,7 +391,7 @@ class Druid(BaseSupersetView):
                     ),
                     "info",
                 )
-        db.session.commit()
+        session.commit()
         return redirect("/druiddatasourcemodelview/list/")
 
     @has_access
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 82f3017..530a2e1 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -41,7 +41,7 @@ from sqlalchemy import (
     Text,
 )
 from sqlalchemy.exc import CompileError
-from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty
+from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
 from sqlalchemy.orm.exc import NoResultFound
 from sqlalchemy.schema import UniqueConstraint
 from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
@@ -255,7 +255,7 @@ class TableColumn(Model, BaseColumn):
                 .first()
             )
 
-        return import_datasource.import_simple_obj(i_column, lookup_obj)
+        return import_datasource.import_simple_obj(db.session, i_column, lookup_obj)
 
     def dttm_sql_literal(
         self,
@@ -375,7 +375,7 @@ class SqlMetric(Model, BaseMetric):
                 .first()
             )
 
-        return import_datasource.import_simple_obj(i_metric, lookup_obj)
+        return import_datasource.import_simple_obj(db.session, i_metric, lookup_obj)
 
 
 sqlatable_user = Table(
@@ -503,11 +503,15 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
 
     @classmethod
     def get_datasource_by_name(
-        cls, datasource_name: str, schema: Optional[str], database_name: str,
+        cls,
+        session: Session,
+        datasource_name: str,
+        schema: Optional[str],
+        database_name: str,
     ) -> Optional["SqlaTable"]:
         schema = schema or None
         query = (
-            db.session.query(cls)
+            session.query(cls)
             .join(Database)
             .filter(cls.table_name == datasource_name)
             .filter(Database.database_name == database_name)
@@ -1292,15 +1296,24 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
                 )
 
         return import_datasource.import_datasource(
-            i_datasource, lookup_database, lookup_sqlatable, import_time, database_id,
+            db.session,
+            i_datasource,
+            lookup_database,
+            lookup_sqlatable,
+            import_time,
+            database_id,
         )
 
     @classmethod
     def query_datasources_by_name(
-        cls, database: Database, datasource_name: str, schema: Optional[str] = None,
+        cls,
+        session: Session,
+        database: Database,
+        datasource_name: str,
+        schema: Optional[str] = None,
     ) -> List["SqlaTable"]:
         query = (
-            db.session.query(cls)
+            session.query(cls)
             .filter_by(database_id=database.id)
             .filter_by(table_name=datasource_name)
         )
diff --git a/superset/dashboards/dao.py b/superset/dashboards/dao.py
index 6345bb7..774e1c8 100644
--- a/superset/dashboards/dao.py
+++ b/superset/dashboards/dao.py
@@ -99,7 +99,9 @@ class DashboardDAO(BaseDAO):
                 except KeyError:
                     pass
 
-        current_slices = db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
+        session = db.session()
+        current_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
+
         dashboard.slices = current_slices
 
         # update slice names. this assumes user has permissions to update the slice
@@ -109,8 +111,8 @@ class DashboardDAO(BaseDAO):
                 new_name = slice_id_to_name[slc.id]
                 if slc.slice_name != new_name:
                     slc.slice_name = new_name
-                    db.session.merge(slc)
-                    db.session.flush()
+                    session.merge(slc)
+                    session.flush()
             except KeyError:
                 pass
 
diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py
index 04e9251..1844543 100644
--- a/superset/models/dashboard.py
+++ b/superset/models/dashboard.py
@@ -37,7 +37,7 @@ from sqlalchemy import (
     UniqueConstraint,
 )
 from sqlalchemy.engine.base import Connection
-from sqlalchemy.orm import relationship, subqueryload
+from sqlalchemy.orm import relationship, sessionmaker, subqueryload
 from sqlalchemy.orm.mapper import Mapper
 
 from superset import app, ConnectorRegistry, db, is_feature_enabled, security_manager
@@ -62,17 +62,18 @@ config = app.config
 logger = logging.getLogger(__name__)
 
 
-def copy_dashboard(  # pylint: disable=unused-argument
-    mapper: Mapper, connection: Connection, target: "Dashboard"
-) -> None:
+def copy_dashboard(mapper: Mapper, connection: Connection, target: "Dashboard") -> None:
+    # pylint: disable=unused-argument
     dashboard_id = config["DASHBOARD_TEMPLATE_ID"]
     if dashboard_id is None:
         return
 
-    new_user = db.session.query(User).filter_by(id=target.id).first()
+    session_class = sessionmaker(autoflush=False)
+    session = session_class(bind=connection)
+    new_user = session.query(User).filter_by(id=target.id).first()
 
     # copy template dashboard to user
-    template = db.session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
+    template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
     dashboard = Dashboard(
         dashboard_title=template.dashboard_title,
         position_json=template.position_json,
@@ -82,15 +83,15 @@ def copy_dashboard(  # pylint: disable=unused-argument
         slices=template.slices,
         owners=[new_user],
     )
-    db.session.add(dashboard)
-    db.session.commit()
+    session.add(dashboard)
+    session.commit()
 
     # set dashboard as the welcome dashboard
     extra_attributes = UserAttribute(
         user_id=target.id, welcome_dashboard_id=dashboard.id
     )
-    db.session.add(extra_attributes)
-    db.session.commit()
+    session.add(extra_attributes)
+    session.commit()
 
 
 sqla.event.listen(User, "after_insert", copy_dashboard)
@@ -306,6 +307,7 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
         logger.info(
             "Started import of the dashboard: %s", dashboard_to_import.to_json()
         )
+        session = db.session
         logger.info("Dashboard has %d slices", len(dashboard_to_import.slices))
         # copy slices object as Slice.import_slice will mutate the slice
         # and will remove the existing dashboard - slice association
@@ -322,7 +324,7 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
         i_params_dict = dashboard_to_import.params_dict
         remote_id_slice_map = {
             slc.params_dict["remote_id"]: slc
-            for slc in db.session.query(Slice).all()
+            for slc in session.query(Slice).all()
             if "remote_id" in slc.params_dict
         }
         for slc in slices:
@@ -373,7 +375,7 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
 
         # override the dashboard
         existing_dashboard = None
-        for dash in db.session.query(Dashboard).all():
+        for dash in session.query(Dashboard).all():
             if (
                 "remote_id" in dash.params_dict
                 and dash.params_dict["remote_id"] == dashboard_to_import.id
@@ -400,7 +402,7 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
             )
 
         new_slices = (
-            db.session.query(Slice)
+            session.query(Slice)
             .filter(Slice.id.in_(old_to_new_slc_id_dict.values()))
             .all()
         )
@@ -408,12 +410,12 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
         if existing_dashboard:
             existing_dashboard.override(dashboard_to_import)
             existing_dashboard.slices = new_slices
-            db.session.flush()
+            session.flush()
             return existing_dashboard.id
 
         dashboard_to_import.slices = new_slices
-        db.session.add(dashboard_to_import)
-        db.session.flush()
+        session.add(dashboard_to_import)
+        session.flush()
         return dashboard_to_import.id  # type: ignore
 
     @classmethod
@@ -455,7 +457,7 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
         eager_datasources = []
         for datasource_id, datasource_type in datasource_ids:
             eager_datasource = ConnectorRegistry.get_eager_datasource(
-                datasource_type, datasource_id
+                db.session, datasource_type, datasource_id
             )
             copied_datasource = eager_datasource.copy()
             copied_datasource.alter_params(
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index c67c6b6..d903d27 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -34,9 +34,9 @@ from flask_appbuilder.models.mixins import AuditMixin
 from flask_appbuilder.security.sqla.models import User
 from sqlalchemy import and_, or_, UniqueConstraint
 from sqlalchemy.ext.declarative import declared_attr
+from sqlalchemy.orm import Session
 from sqlalchemy.orm.exc import MultipleResultsFound
 
-from superset.extensions import db
 from superset.utils.core import QueryStatus
 
 logger = logging.getLogger(__name__)
@@ -127,6 +127,7 @@ class ImportMixin:
     @classmethod
     def import_from_dict(  # pylint: disable=too-many-arguments,too-many-branches,too-many-locals
         cls,
+        session: Session,
         dict_rep: Dict[Any, Any],
         parent: Optional[Any] = None,
         recursive: bool = True,
@@ -177,7 +178,7 @@ class ImportMixin:
 
         # Check if object already exists in DB, break if more than one is found
         try:
-            obj_query = db.session.query(cls).filter(and_(*filters))
+            obj_query = session.query(cls).filter(and_(*filters))
             obj = obj_query.one_or_none()
         except MultipleResultsFound as ex:
             logger.error(
@@ -195,7 +196,7 @@ class ImportMixin:
             logger.info("Importing new %s %s", obj.__tablename__, str(obj))
             if cls.export_parent and parent:
                 setattr(obj, cls.export_parent, parent)
-            db.session.add(obj)
+            session.add(obj)
         else:
             is_new_obj = False
             logger.info("Updating %s %s", obj.__tablename__, str(obj))
@@ -213,7 +214,7 @@ class ImportMixin:
                 for c_obj in new_children.get(child, []):
                     added.append(
                         child_class.import_from_dict(
-                            dict_rep=c_obj, parent=obj, sync=sync
+                            session=session, dict_rep=c_obj, parent=obj, sync=sync
                         )
                     )
                 # If children should get synced, delete the ones that did not
@@ -227,11 +228,11 @@ class ImportMixin:
                         for k in back_refs.keys()
                     ]
                     to_delete = set(
-                        db.session.query(child_class).filter(and_(*delete_filters))
+                        session.query(child_class).filter(and_(*delete_filters))
                     ).difference(set(added))
                     for o in to_delete:
                         logger.info("Deleting %s %s", child, str(obj))
-                        db.session.delete(o)
+                        session.delete(o)
 
         return obj
 
diff --git a/superset/models/slice.py b/superset/models/slice.py
index 0a2e7d5..b7f9e05 100644
--- a/superset/models/slice.py
+++ b/superset/models/slice.py
@@ -300,6 +300,7 @@ class Slice(
         :returns: The resulting id for the imported slice
         :rtype: int
         """
+        session = db.session
         make_transient(slc_to_import)
         slc_to_import.dashboards = []
         slc_to_import.alter_params(remote_id=slc_to_import.id, import_time=import_time)
@@ -308,6 +309,7 @@ class Slice(
         slc_to_import.reset_ownership()
         params = slc_to_import.params_dict
         datasource = ConnectorRegistry.get_datasource_by_name(
+            session,
             slc_to_import.datasource_type,
             params["datasource_name"],
             params["schema"],
@@ -316,11 +318,11 @@ class Slice(
         slc_to_import.datasource_id = datasource.id  # type: ignore
         if slc_to_override:
             slc_to_override.override(slc_to_import)
-            db.session.flush()
+            session.flush()
             return slc_to_override.id
-        db.session.add(slc_to_import)
+        session.add(slc_to_import)
         logger.info("Final slice: %s", str(slc_to_import.to_json()))
-        db.session.flush()
+        session.flush()
         return slc_to_import.id
 
     @property
diff --git a/superset/models/tags.py b/superset/models/tags.py
index 1302ff5..c09bb16 100644
--- a/superset/models/tags.py
+++ b/superset/models/tags.py
@@ -22,11 +22,10 @@ from typing import List, Optional, TYPE_CHECKING, Union
 from flask_appbuilder import Model
 from sqlalchemy import Column, Enum, ForeignKey, Integer, String
 from sqlalchemy.engine.base import Connection
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import relationship, Session, sessionmaker
 from sqlalchemy.orm.exc import NoResultFound
 from sqlalchemy.orm.mapper import Mapper
 
-from superset.extensions import db
 from superset.models.helpers import AuditMixinNullable
 
 if TYPE_CHECKING:
@@ -35,6 +34,8 @@ if TYPE_CHECKING:
     from superset.models.slice import Slice  # pylint: disable=unused-import
     from superset.models.sql_lab import Query  # pylint: disable=unused-import
 
+Session = sessionmaker(autoflush=False)
+
 
 class TagTypes(enum.Enum):
 
@@ -87,13 +88,13 @@ class TaggedObject(Model, AuditMixinNullable):
     tag = relationship("Tag", backref="objects")
 
 
-def get_tag(name: str, type_: TagTypes) -> Tag:
+def get_tag(name: str, session: Session, type_: TagTypes) -> Tag:
     try:
-        tag = db.session.query(Tag).filter_by(name=name, type=type_).one()
+        tag = session.query(Tag).filter_by(name=name, type=type_).one()
     except NoResultFound:
         tag = Tag(name=name, type=type_)
-        db.session.add(tag)
-        db.session.commit()
+        session.add(tag)
+        session.commit()
 
     return tag
 
@@ -121,43 +122,52 @@ class ObjectUpdater:
         raise NotImplementedError("Subclass should implement `get_owners_ids`")
 
     @classmethod
-    def _add_owners(cls, target: Union["Dashboard", "FavStar", "Slice"]) -> None:
+    def _add_owners(
+        cls, session: Session, target: Union["Dashboard", "FavStar", "Slice"]
+    ) -> None:
         for owner_id in cls.get_owners_ids(target):
             name = "owner:{0}".format(owner_id)
-            tag = get_tag(name, TagTypes.owner)
+            tag = get_tag(name, session, TagTypes.owner)
             tagged_object = TaggedObject(
                 tag_id=tag.id, object_id=target.id, object_type=cls.object_type
             )
-            db.session.add(tagged_object)
+            session.add(tagged_object)
 
     @classmethod
-    def after_insert(  # pylint: disable=unused-argument
+    def after_insert(
         cls,
         mapper: Mapper,
         connection: Connection,
         target: Union["Dashboard", "FavStar", "Slice"],
     ) -> None:
+        # pylint: disable=unused-argument
+        session = Session(bind=connection)
+
         # add `owner:` tags
-        cls._add_owners(target)
+        cls._add_owners(session, target)
 
         # add `type:` tags
-        tag = get_tag("type:{0}".format(cls.object_type), TagTypes.type)
+        tag = get_tag("type:{0}".format(cls.object_type), session, TagTypes.type)
         tagged_object = TaggedObject(
             tag_id=tag.id, object_id=target.id, object_type=cls.object_type
         )
-        db.session.add(tagged_object)
-        db.session.commit()
+        session.add(tagged_object)
+
+        session.commit()
 
     @classmethod
-    def after_update(  # pylint: disable=unused-argument
+    def after_update(
         cls,
         mapper: Mapper,
         connection: Connection,
         target: Union["Dashboard", "FavStar", "Slice"],
     ) -> None:
+        # pylint: disable=unused-argument
+        session = Session(bind=connection)
+
         # delete current `owner:` tags
         query = (
-            db.session.query(TaggedObject.id)
+            session.query(TaggedObject.id)
             .join(Tag)
             .filter(
                 TaggedObject.object_type == cls.object_type,
@@ -166,28 +176,32 @@ class ObjectUpdater:
             )
         )
         ids = [row[0] for row in query]
-        db.session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
+        session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
             synchronize_session=False
         )
 
         # add `owner:` tags
-        cls._add_owners(target)
-        db.session.commit()
+        cls._add_owners(session, target)
+
+        session.commit()
 
     @classmethod
-    def after_delete(  # pylint: disable=unused-argument
+    def after_delete(
         cls,
         mapper: Mapper,
         connection: Connection,
         target: Union["Dashboard", "FavStar", "Slice"],
     ) -> None:
+        # pylint: disable=unused-argument
+        session = Session(bind=connection)
+
         # delete row from `tagged_objects`
-        db.session.query(TaggedObject).filter(
+        session.query(TaggedObject).filter(
             TaggedObject.object_type == cls.object_type,
             TaggedObject.object_id == target.id,
         ).delete()
 
-        db.session.commit()
+        session.commit()
 
 
 class ChartUpdater(ObjectUpdater):
@@ -219,26 +233,31 @@ class QueryUpdater(ObjectUpdater):
 
 class FavStarUpdater:
     @classmethod
-    def after_insert(  # pylint: disable=unused-argument
+    def after_insert(
         cls, mapper: Mapper, connection: Connection, target: "FavStar"
     ) -> None:
+        # pylint: disable=unused-argument
+        session = Session(bind=connection)
         name = "favorited_by:{0}".format(target.user_id)
-        tag = get_tag(name, TagTypes.favorited_by)
+        tag = get_tag(name, session, TagTypes.favorited_by)
         tagged_object = TaggedObject(
             tag_id=tag.id,
             object_id=target.obj_id,
             object_type=get_object_type(target.class_name),
         )
-        db.session.add(tagged_object)
-        db.session.commit()
+        session.add(tagged_object)
+
+        session.commit()
 
     @classmethod
-    def after_delete(  # pylint: disable=unused-argument
-        cls, mapper: Mapper, connection: Connection, target: "FavStar",
+    def after_delete(
+        cls, mapper: Mapper, connection: Connection, target: "FavStar"
     ) -> None:
+        # pylint: disable=unused-argument
+        session = Session(bind=connection)
         name = "favorited_by:{0}".format(target.user_id)
         query = (
-            db.session.query(TaggedObject.id)
+            session.query(TaggedObject.id)
             .join(Tag)
             .filter(
                 TaggedObject.object_id == target.obj_id,
@@ -247,8 +266,8 @@ class FavStarUpdater:
             )
         )
         ids = [row[0] for row in query]
-        db.session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
+        session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
             synchronize_session=False
         )
 
-        db.session.commit()
+        session.commit()
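
Note the listener pattern in the tags.py hunks above: mapper events such as
after_insert receive the Connection they run on, and the revert binds a fresh
session to that connection via a module-level sessionmaker(autoflush=False)
instead of reaching for db.session. A self-contained sketch of that shape
(illustrative only; the real listeners create and delete TaggedObject rows):

    from sqlalchemy.engine.base import Connection
    from sqlalchemy.orm import sessionmaker
    from sqlalchemy.orm.mapper import Mapper

    # Module-level factory, as in the reverted-to tags.py.
    Session = sessionmaker(autoflush=False)

    def after_insert(mapper: Mapper, connection: Connection, target) -> None:
        # Bind a short-lived session to the event's connection so writes
        # share its transaction rather than the global db.session.
        session = Session(bind=connection)
        # ... create and session.add() the tag rows here ...
        session.commit()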
diff --git a/superset/security/manager.py b/superset/security/manager.py
index f731c42..da92d16 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -507,7 +507,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         user_perms = self.user_view_menu_names("datasource_access")
         schema_perms = self.user_view_menu_names("schema_access")
         user_datasources = ConnectorRegistry.query_datasources_by_permissions(
-            database, user_perms, schema_perms
+            self.get_session, database, user_perms, schema_perms
         )
         if schema:
             names = {d.table_name for d in user_datasources if d.schema == schema}
@@ -568,7 +568,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
                 self.add_permission_view_menu(view_menu, perm)
 
         logger.info("Creating missing datasource permissions.")
-        datasources = ConnectorRegistry.get_all_datasources()
+        datasources = ConnectorRegistry.get_all_datasources(self.get_session)
         for datasource in datasources:
             merge_pv("datasource_access", datasource.get_perm())
             merge_pv("schema_access", datasource.get_schema_perm())
@@ -901,7 +901,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
 
                 if not (schema_perm and self.can_access("schema_access", schema_perm)):
                     datasources = SqlaTable.query_datasources_by_name(
-                        database, table_.table, schema=table_.schema
+                        self.get_session, database, table_.table, schema=table_.schema
                     )
 
                     # Access to any datasource suffices.
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index d941473..8c3f24f 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -132,7 +132,8 @@ def session_scope(nullpool: bool) -> Iterator[Session]:
         )
     if nullpool:
         engine = sqlalchemy.create_engine(database_uri, poolclass=NullPool)
-        session_class = sessionmaker(bind=engine)
+        session_class = sessionmaker()
+        session_class.configure(bind=engine)
         session = session_class()
     else:
         session = db.session()
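
In sql_lab.py's session_scope, the nullpool branch now builds the sessionmaker
in two steps (construct, then configure with the NullPool engine) rather than
passing bind= at construction; the behavior is equivalent. A hedged sketch of
a worker-side session factory in that style (make_worker_session is a
hypothetical name, not Superset's):

    import sqlalchemy
    from sqlalchemy.orm import sessionmaker
    from sqlalchemy.pool import NullPool

    def make_worker_session(database_uri: str):
        # NullPool gives each checkout its own connection, avoiding pooled
        # connections being shared across forked Celery workers.
        engine = sqlalchemy.create_engine(database_uri, poolclass=NullPool)
        session_class = sessionmaker()
        session_class.configure(bind=engine)
        return session_class()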
diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py
index f4c9e32..54b0dc1 100644
--- a/superset/tasks/cache.py
+++ b/superset/tasks/cache.py
@@ -134,7 +134,8 @@ class DummyStrategy(Strategy):
     name = "dummy"
 
     def get_urls(self) -> List[str]:
-        charts = db.session.query(Slice).all()
+        session = db.create_scoped_session()
+        charts = session.query(Slice).all()
 
         return [get_url(chart) for chart in charts]
 
@@ -166,9 +167,10 @@ class TopNDashboardsStrategy(Strategy):
 
     def get_urls(self) -> List[str]:
         urls = []
+        session = db.create_scoped_session()
 
         records = (
-            db.session.query(Log.dashboard_id, func.count(Log.dashboard_id))
+            session.query(Log.dashboard_id, func.count(Log.dashboard_id))
             .filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
             .group_by(Log.dashboard_id)
             .order_by(func.count(Log.dashboard_id).desc())
@@ -176,9 +178,7 @@ class TopNDashboardsStrategy(Strategy):
             .all()
         )
         dash_ids = [record.dashboard_id for record in records]
-        dashboards = (
-            db.session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
-        )
+        dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
         for dashboard in dashboards:
             for chart in dashboard.slices:
                 form_data_with_filters = get_form_data(chart.id, dashboard)
@@ -211,13 +211,14 @@ class DashboardTagsStrategy(Strategy):
 
     def get_urls(self) -> List[str]:
         urls = []
+        session = db.create_scoped_session()
 
-        tags = db.session.query(Tag).filter(Tag.name.in_(self.tags)).all()
+        tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all()
         tag_ids = [tag.id for tag in tags]
 
         # add dashboards that are tagged
         tagged_objects = (
-            db.session.query(TaggedObject)
+            session.query(TaggedObject)
             .filter(
                 and_(
                     TaggedObject.object_type == "dashboard",
@@ -227,16 +228,14 @@ class DashboardTagsStrategy(Strategy):
             .all()
         )
         dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
-        tagged_dashboards = db.session.query(Dashboard).filter(
-            Dashboard.id.in_(dash_ids)
-        )
+        tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids))
         for dashboard in tagged_dashboards:
             for chart in dashboard.slices:
                 urls.append(get_url(chart))
 
         # add charts that are tagged
         tagged_objects = (
-            db.session.query(TaggedObject)
+            session.query(TaggedObject)
             .filter(
                 and_(
                     TaggedObject.object_type == "chart",
@@ -246,7 +245,7 @@ class DashboardTagsStrategy(Strategy):
             .all()
         )
         chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
-        tagged_charts = db.session.query(Slice).filter(Slice.id.in_(chart_ids))
+        tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
         for chart in tagged_charts:
             urls.append(get_url(chart))
 
diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py
index 0a74ddf..4fd55aa 100644
--- a/superset/tasks/schedules.py
+++ b/superset/tasks/schedules.py
@@ -47,6 +47,7 @@ from flask_login import login_user
 from retry.api import retry_call
 from selenium.common.exceptions import WebDriverException
 from selenium.webdriver import chrome, firefox
+from sqlalchemy.orm import Session
 from werkzeug.http import parse_cookie
 
 from superset import app, db, security_manager, thumbnail_cache
@@ -541,7 +542,8 @@ def schedule_alert_query(  # pylint: disable=unused-argument
     is_test_alert: Optional[bool] = False,
 ) -> None:
     model_cls = get_scheduler_model(report_type)
-    schedule = db.session.query(model_cls).get(schedule_id)
+    dbsession = db.create_scoped_session()
+    schedule = dbsession.query(model_cls).get(schedule_id)
 
     # The user may have disabled the schedule. If so, ignore this
     if not schedule or not schedule.active:
@@ -553,7 +555,7 @@ def schedule_alert_query(  # pylint: disable=unused-argument
             deliver_alert(schedule.id, recipients)
             return
 
-        if run_alert_query(schedule.id):
+        if run_alert_query(schedule.id, dbsession):
             # deliver_dashboard OR deliver_slice
             return
     else:
@@ -616,7 +618,7 @@ def deliver_alert(alert_id: int, recipients: Optional[str] = None) -> None:
     _deliver_email(recipients, deliver_as_group, subject, body, data, images)
 
 
-def run_alert_query(alert_id: int) -> Optional[bool]:
+def run_alert_query(alert_id: int, dbsession: Session) -> Optional[bool]:
     """
     Execute alert.sql and return value if any rows are returned
     """
@@ -670,7 +672,7 @@ def run_alert_query(alert_id: int) -> Optional[bool]:
             state=state,
         )
     )
-    db.session.commit()
+    dbsession.commit()
 
     return None
 
@@ -710,7 +712,8 @@ def schedule_window(
     if not model_cls:
         return None
 
-    schedules = db.session.query(model_cls).filter(model_cls.active.is_(True))
+    dbsession = db.create_scoped_session()
+    schedules = dbsession.query(model_cls).filter(model_cls.active.is_(True))
 
     for schedule in schedules:
         logging.info("Processing schedule %s", schedule)
diff --git a/superset/utils/dashboard_import_export.py b/superset/utils/dashboard_import_export.py
index f8f673d..6ae500b 100644
--- a/superset/utils/dashboard_import_export.py
+++ b/superset/utils/dashboard_import_export.py
@@ -22,10 +22,10 @@ from io import BytesIO
 from typing import Any, Dict, Optional
 
 from flask_babel import lazy_gettext as _
+from sqlalchemy.orm import Session
 
 from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
 from superset.exceptions import DashboardImportException
-from superset.extensions import db
 from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
 
@@ -71,6 +71,7 @@ def decode_dashboards(  # pylint: disable=too-many-return-statements
 
 
 def import_dashboards(
+    session: Session,
     data_stream: BytesIO,
     database_id: Optional[int] = None,
     import_time: Optional[int] = None,
@@ -83,16 +84,16 @@ def import_dashboards(
         raise DashboardImportException(_("No data in file"))
     for table in data["datasources"]:
         type(table).import_obj(table, database_id, import_time=import_time)
-    db.session.commit()
+    session.commit()
     for dashboard in data["dashboards"]:
         Dashboard.import_obj(dashboard, import_time=import_time)
-    db.session.commit()
+    session.commit()
 
 
-def export_dashboards() -> str:
+def export_dashboards(session: Session) -> str:
     """Returns all dashboards metadata as a json dump"""
     logger.info("Starting export")
-    dashboards = db.session.query(Dashboard)
+    dashboards = session.query(Dashboard)
     dashboard_ids = []
     for dashboard in dashboards:
         dashboard_ids.append(dashboard.id)
diff --git a/superset/utils/dict_import_export.py b/superset/utils/dict_import_export.py
index 8edae22..4d9e049 100644
--- a/superset/utils/dict_import_export.py
+++ b/superset/utils/dict_import_export.py
@@ -17,8 +17,9 @@
 import logging
 from typing import Any, Dict, List, Optional
 
+from sqlalchemy.orm import Session
+
 from superset.connectors.druid.models import DruidCluster
-from superset.extensions import db
 from superset.models.core import Database
 
 DATABASES_KEY = "databases"
@@ -43,11 +44,11 @@ def export_schema_to_dict(back_references: bool) -> Dict[str, Any]:
 
 
 def export_to_dict(
-    recursive: bool, back_references: bool, include_defaults: bool
+    session: Session, recursive: bool, back_references: bool, include_defaults: bool
 ) -> Dict[str, Any]:
     """Exports databases and druid clusters to a dictionary"""
     logger.info("Starting export")
-    dbs = db.session.query(Database)
+    dbs = session.query(Database)
     databases = [
         database.export_to_dict(
             recursive=recursive,
@@ -57,7 +58,7 @@ def export_to_dict(
         for database in dbs
     ]
     logger.info("Exported %d %s", len(databases), DATABASES_KEY)
-    cls = db.session.query(DruidCluster)
+    cls = session.query(DruidCluster)
     clusters = [
         cluster.export_to_dict(
             recursive=recursive,
@@ -75,20 +76,22 @@ def export_to_dict(
     return data
 
 
-def import_from_dict(data: Dict[str, Any], sync: Optional[List[str]] = None) -> None:
+def import_from_dict(
+    session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None
+) -> None:
     """Imports databases and druid clusters from dictionary"""
     if not sync:
         sync = []
     if isinstance(data, dict):
         logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
         for database in data.get(DATABASES_KEY, []):
-            Database.import_from_dict(database, sync=sync)
+            Database.import_from_dict(session, database, sync=sync)
 
         logger.info(
             "Importing %d %s", len(data.get(DRUID_CLUSTERS_KEY, [])), DRUID_CLUSTERS_KEY
         )
         for datasource in data.get(DRUID_CLUSTERS_KEY, []):
-            DruidCluster.import_from_dict(datasource, sync=sync)
-        db.session.commit()
+            DruidCluster.import_from_dict(session, datasource, sync=sync)
+        session.commit()
     else:
         logger.info("Supplied object is not a dictionary.")
diff --git a/superset/utils/import_datasource.py b/superset/utils/import_datasource.py
index a59a3d6..25da876 100644
--- a/superset/utils/import_datasource.py
+++ b/superset/utils/import_datasource.py
@@ -18,14 +18,14 @@ import logging
 from typing import Callable, Optional
 
 from flask_appbuilder import Model
+from sqlalchemy.orm import Session
 from sqlalchemy.orm.session import make_transient
 
-from superset.extensions import db
-
 logger = logging.getLogger(__name__)
 
 
 def import_datasource(  # pylint: disable=too-many-arguments
+    session: Session,
     i_datasource: Model,
     lookup_database: Callable[[Model], Model],
     lookup_datasource: Callable[[Model], Model],
@@ -52,11 +52,11 @@ def import_datasource(  # pylint: disable=too-many-arguments
 
     if datasource:
         datasource.override(i_datasource)
-        db.session.flush()
+        session.flush()
     else:
         datasource = i_datasource.copy()
-        db.session.add(datasource)
-        db.session.flush()
+        session.add(datasource)
+        session.flush()
 
     for metric in i_datasource.metrics:
         new_m = metric.copy()
@@ -81,11 +81,13 @@ def import_datasource(  # pylint: disable=too-many-arguments
         imported_c = i_datasource.column_class.import_obj(new_c)
         if imported_c.column_name not in [c.column_name for c in datasource.columns]:
             datasource.columns.append(imported_c)
-    db.session.flush()
+    session.flush()
     return datasource.id
 
 
-def import_simple_obj(i_obj: Model, lookup_obj: Callable[[Model], Model]) -> Model:
+def import_simple_obj(
+    session: Session, i_obj: Model, lookup_obj: Callable[[Model], Model]
+) -> Model:
     make_transient(i_obj)
     i_obj.id = None
     i_obj.table = None
@@ -95,9 +97,9 @@ def import_simple_obj(i_obj: Model, lookup_obj: Callable[[Model], Model]) -> Mod
     i_obj.table = None
     if existing_column:
         existing_column.override(i_obj)
-        db.session.flush()
+        session.flush()
         return existing_column
 
-    db.session.add(i_obj)
-    db.session.flush()
+    session.add(i_obj)
+    session.flush()
     return i_obj
diff --git a/superset/views/base.py b/superset/views/base.py
index 58c4943..7aeae79 100644
--- a/superset/views/base.py
+++ b/superset/views/base.py
@@ -487,7 +487,8 @@ def check_ownership(obj: Any, raise_if_false: bool = True) -> bool:
     roles = [r.name for r in get_user_roles()]
     if "Admin" in roles:
         return True
-    orig_obj = db.session.query(obj.__class__).filter_by(id=obj.id).first()
+    scoped_session = db.create_scoped_session()
+    orig_obj = scoped_session.query(obj.__class__).filter_by(id=obj.id).first()
 
     # Making a list of owners that works across ORM models
     owners: List[User] = []
diff --git a/superset/views/chart/views.py b/superset/views/chart/views.py
index db100a7..0523e33 100644
--- a/superset/views/chart/views.py
+++ b/superset/views/chart/views.py
@@ -20,7 +20,7 @@ from flask_appbuilder import expose, has_access
 from flask_appbuilder.models.sqla.interface import SQLAInterface
 from flask_babel import lazy_gettext as _
 
-from superset import app
+from superset import app, db
 from superset.connectors.connector_registry import ConnectorRegistry
 from superset.constants import RouteMethod
 from superset.models.slice import Slice
@@ -56,7 +56,7 @@ class SliceModelView(
     def add(self) -> FlaskResponse:
         datasources = [
             {"value": str(d.id) + "__" + d.type, "label": repr(d)}
-            for d in ConnectorRegistry.get_all_datasources()
+            for d in ConnectorRegistry.get_all_datasources(db.session)
         ]
         return self.render_template(
             "superset/add_slice.html",
diff --git a/superset/views/core.py b/superset/views/core.py
index bcbce02..f3a2264 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -40,6 +40,7 @@ from sqlalchemy.exc import (
     OperationalError,
     SQLAlchemyError,
 )
+from sqlalchemy.orm.session import Session
 from werkzeug.urls import Href
 
 import superset.models.core as models
@@ -163,7 +164,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             sorted(
                 [
                     datasource.short_data
-                    for datasource in ConnectorRegistry.get_all_datasources()
+                    for datasource in ConnectorRegistry.get_all_datasources(db.session)
                     if datasource.short_data.get("name")
                 ],
                 key=lambda datasource: datasource["name"],
@@ -202,7 +203,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
                     )
                     db_ds_names.add(fullname)
 
-        existing_datasources = ConnectorRegistry.get_all_datasources()
+        existing_datasources = ConnectorRegistry.get_all_datasources(db.session)
         datasources = [d for d in existing_datasources if d.full_name in db_ds_names]
         role = security_manager.find_role(role_name)
         # remove all permissions
@@ -269,15 +270,15 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
     @has_access
     @expose("/approve")
     def approve(self) -> FlaskResponse:  # pylint: disable=too-many-locals,no-self-use
-        def clean_fulfilled_requests() -> None:
-            for dar in db.session.query(DAR).all():
+        def clean_fulfilled_requests(session: Session) -> None:
+            for dar in session.query(DAR).all():
                 datasource = ConnectorRegistry.get_datasource(
-                    dar.datasource_type, dar.datasource_id
+                    dar.datasource_type, dar.datasource_id, session
                 )
                 if not datasource or security_manager.can_access_datasource(datasource):
                     # datasource does not exist anymore
-                    db.session.delete(dar)
-            db.session.commit()
+                    session.delete(dar)
+            session.commit()
 
         datasource_type = request.args["datasource_type"]
         datasource_id = request.args["datasource_id"]
@@ -285,7 +286,10 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         role_to_grant = request.args.get("role_to_grant")
         role_to_extend = request.args.get("role_to_extend")
 
-        datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
+        session = db.session
+        datasource = ConnectorRegistry.get_datasource(
+            datasource_type, datasource_id, session
+        )
 
         if not datasource:
             flash(DATASOURCE_MISSING_ERR, "alert")
@@ -297,7 +301,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             return json_error_response(USER_MISSING_ERR)
 
         requests = (
-            db.session.query(DAR)
+            session.query(DAR)
             .filter(
                 DAR.datasource_id == datasource_id,
                 DAR.datasource_type == datasource_type,
@@ -357,13 +361,13 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
                     app.config,
                 )
                 flash(msg, "info")
-            clean_fulfilled_requests()
+            clean_fulfilled_requests(session)
         else:
             flash(__("You have no permission to approve this request"), "danger")
             return redirect("/accessrequestsmodelview/list/")
         for request_ in requests:
-            db.session.delete(request_)
-        db.session.commit()
+            session.delete(request_)
+        session.commit()
         return redirect("/accessrequestsmodelview/list/")
 
     @has_access
@@ -544,7 +548,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             database_id = request.form.get("db_id")
             try:
                 dashboard_import_export.import_dashboards(
-                    import_file.stream, database_id
+                    db.session, import_file.stream, database_id
                 )
                 success = True
             except DatabaseNotFound as ex:
@@ -626,7 +630,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             return redirect(error_redirect)
 
         datasource = ConnectorRegistry.get_datasource(
-            cast(str, datasource_type), datasource_id
+            cast(str, datasource_type), datasource_id, db.session
         )
         if not datasource:
             flash(DATASOURCE_MISSING_ERR, "danger")
@@ -745,7 +749,9 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         :raises SupersetSecurityException: If the user cannot access the resource
         """
         # TODO: Cache endpoint by user, datasource and column
-        datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
+        datasource = ConnectorRegistry.get_datasource(
+            datasource_type, datasource_id, db.session
+        )
         if not datasource:
             return json_error_response(DATASOURCE_MISSING_ERR)
 
@@ -1009,9 +1015,10 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         self, dashboard_id: int
     ) -> FlaskResponse:
         """Copy dashboard"""
+        session = db.session()
         data = json.loads(request.form["data"])
         dash = models.Dashboard()
-        original_dash = db.session.query(Dashboard).get(dashboard_id)
+        original_dash = session.query(Dashboard).get(dashboard_id)
 
         dash.owners = [g.user] if g.user else []
         dash.dashboard_title = data["dashboard_title"]
@@ -1022,8 +1029,8 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             for slc in original_dash.slices:
                 new_slice = slc.clone()
                 new_slice.owners = [g.user] if g.user else []
-                db.session.add(new_slice)
-                db.session.flush()
+                session.add(new_slice)
+                session.flush()
                 new_slice.dashboards.append(dash)
                 old_to_new_slice_ids[slc.id] = new_slice.id
 
@@ -1039,9 +1046,10 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         dash.params = original_dash.params
 
         DashboardDAO.set_dash_metadata(dash, data, old_to_new_slice_ids)
-        db.session.add(dash)
-        db.session.commit()
+        session.add(dash)
+        session.commit()
         dash_json = json.dumps(dash.data)
+        session.close()
         return json_success(dash_json)
 
     @api
@@ -1051,12 +1059,14 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         self, dashboard_id: int
     ) -> FlaskResponse:
         """Save a dashboard's metadata"""
-        dash = db.session.query(Dashboard).get(dashboard_id)
+        session = db.session()
+        dash = session.query(Dashboard).get(dashboard_id)
         check_ownership(dash, raise_if_false=True)
         data = json.loads(request.form["data"])
         DashboardDAO.set_dash_metadata(dash, data)
-        db.session.merge(dash)
-        db.session.commit()
+        session.merge(dash)
+        session.commit()
+        session.close()
         return json_success(json.dumps({"status": "SUCCESS"}))
 
     @api
@@ -1067,12 +1077,14 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
     ) -> FlaskResponse:
         """Add and save slices to a dashboard"""
         data = json.loads(request.form["data"])
-        dash = db.session.query(Dashboard).get(dashboard_id)
+        session = db.session()
+        dash = session.query(Dashboard).get(dashboard_id)
         check_ownership(dash, raise_if_false=True)
-        new_slices = db.session.query(Slice).filter(Slice.id.in_(data["slice_ids"]))
+        new_slices = session.query(Slice).filter(Slice.id.in_(data["slice_ids"]))
         dash.slices += new_slices
-        db.session.merge(dash)
-        db.session.commit()
+        session.merge(dash)
+        session.commit()
+        session.close()
         return "SLICES ADDED"
 
     @api
@@ -1419,6 +1431,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
 
        Note that for slices a force refresh occurs.
         """
+        session = db.session()
         slice_id = request.args.get("slice_id")
         dashboard_id = request.args.get("dashboard_id")
         table_name = request.args.get("table_name")
@@ -1433,14 +1446,14 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
                 status=400,
             )
         if slice_id:
-            slices = db.session.query(Slice).filter_by(id=slice_id).all()
+            slices = session.query(Slice).filter_by(id=slice_id).all()
             if not slices:
                 return json_error_response(
                     __("Chart %(id)s not found", id=slice_id), status=404
                 )
         elif table_name and db_name:
             table = (
-                db.session.query(SqlaTable)
+                session.query(SqlaTable)
                 .join(models.Database)
                 .filter(
                     models.Database.database_name == db_name
@@ -1457,7 +1470,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
                     status=404,
                 )
             slices = (
-                db.session.query(Slice)
+                session.query(Slice)
                 .filter_by(datasource_id=table.id, datasource_type=table.type)
                 .all()
             )
@@ -1500,16 +1513,17 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         self, class_name: str, obj_id: int, action: str
     ) -> FlaskResponse:
         """Toggle favorite stars on Slices and Dashboard"""
+        session = db.session()
         FavStar = models.FavStar
         count = 0
         favs = (
-            db.session.query(FavStar)
+            session.query(FavStar)
             .filter_by(class_name=class_name, obj_id=obj_id, user_id=g.user.get_id())
             .all()
         )
         if action == "select":
             if not favs:
-                db.session.add(
+                session.add(
                     FavStar(
                         class_name=class_name,
                         obj_id=obj_id,
@@ -1520,10 +1534,10 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             count = 1
         elif action == "unselect":
             for fav in favs:
-                db.session.delete(fav)
+                session.delete(fav)
         else:
             count = len(favs)
-        db.session.commit()
+        session.commit()
         return json_success(json.dumps({"count": count}))
 
     @api
@@ -1536,13 +1550,12 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         logger.warning(
             "This API endpoint is deprecated and will be removed in version 1.0.0"
         )
+        session = db.session()
         Role = ab_models.Role
         dash = (
-            db.session.query(Dashboard)
-            .filter(Dashboard.id == dashboard_id)
-            .one_or_none()
+            session.query(Dashboard).filter(Dashboard.id == dashboard_id).one_or_none()
         )
-        admin_role = db.session.query(Role).filter(Role.name == "Admin").one_or_none()
+        admin_role = session.query(Role).filter(Role.name == "Admin").one_or_none()
 
         if request.method == "GET":
             if dash:
@@ -1561,7 +1574,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             )
 
         dash.published = str(request.form["published"]).lower() == "true"
-        db.session.commit()
+        session.commit()
         return json_success(json.dumps({"published": dash.published}))
 
     @has_access
@@ -1570,7 +1583,8 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         self, dashboard_id_or_slug: str
     ) -> FlaskResponse:
         """Server side rendering for a dashboard"""
-        qry = db.session.query(Dashboard)
+        session = db.session()
+        qry = session.query(Dashboard)
         if dashboard_id_or_slug.isdigit():
             qry = qry.filter_by(id=int(dashboard_id_or_slug))
         else:
@@ -2028,7 +2042,8 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
                 "SQL validation does not support template parameters", status=400
             )
 
-        mydb = db.session.query(models.Database).filter_by(id=database_id).one_or_none()
+        session = db.session()
+        mydb = session.query(models.Database).filter_by(id=database_id).one_or_none()
         if not mydb:
             return json_error_response(
                 "Database with id {} is missing.".format(database_id), status=400
@@ -2077,6 +2092,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
 
     @staticmethod
     def _sql_json_async(  # pylint: disable=too-many-arguments
+        session: Session,
         rendered_query: str,
         query: Query,
         expand_data: bool,
@@ -2085,6 +2101,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         """
         Send SQL JSON query to celery workers.
 
+        :param session: SQLAlchemy session object
         :param rendered_query: the rendered query to perform by workers
         :param query: The query (SQLAlchemy) object
         :return: A Flask Response
@@ -2115,7 +2132,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             )
             query.status = QueryStatus.FAILED
             query.error_message = msg
-            db.session.commit()
+            session.commit()
             return json_error_response("{}".format(msg))
         resp = json_success(
             json.dumps(
@@ -2125,11 +2142,12 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             ),
             status=202,
         )
-        db.session.commit()
+        session.commit()
         return resp
 
     @staticmethod
     def _sql_json_sync(
+        _session: Session,
         rendered_query: str,
         query: Query,
         expand_data: bool,
@@ -2223,7 +2241,8 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         tab_name: str = cast(str, query_params.get("tab"))
         status: str = QueryStatus.PENDING if async_flag else QueryStatus.RUNNING
 
-        mydb = db.session.query(models.Database).get(database_id)
+        session = db.session()
+        mydb = session.query(models.Database).get(database_id)
         if not mydb:
             return json_error_response("Database with id %i is missing.", database_id)
 
@@ -2254,13 +2273,13 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             client_id=client_id,
         )
         try:
-            db.session.add(query)
-            db.session.flush()
+            session.add(query)
+            session.flush()
             query_id = query.id
-            db.session.commit()  # shouldn't be necessary
+            session.commit()  # shouldn't be necessary
         except SQLAlchemyError as ex:
             logger.error("Errors saving query details %s", str(ex))
-            db.session.rollback()
+            session.rollback()
             raise Exception(_("Query record was not created as expected."))
         if not query_id:
             raise Exception(_("Query record was not created as expected."))
@@ -2271,7 +2290,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             query.raise_for_access()
         except SupersetSecurityException as ex:
             query.status = QueryStatus.FAILED
-            db.session.commit()
+            session.commit()
             return json_errors_response([ex.error], status=403)
 
         try:
@@ -2304,9 +2323,13 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
 
         # Async request.
         if async_flag:
-            return self._sql_json_async(rendered_query, query, expand_data, log_params)
+            return self._sql_json_async(
+                session, rendered_query, query, expand_data, log_params
+            )
         # Sync request.
-        return self._sql_json_sync(rendered_query, query, expand_data, log_params)
+        return self._sql_json_sync(
+            session, rendered_query, query, expand_data, log_params
+        )
 
     @has_access
     @expose("/csv/<client_id>")
@@ -2375,7 +2398,9 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         """
 
         datasource_id, datasource_type = request.args["datasourceKey"].split("__")
-        datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
+        datasource = ConnectorRegistry.get_datasource(
+            datasource_type, datasource_id, db.session
+        )
         # Check if datasource exists
         if not datasource:
             return json_error_response(DATASOURCE_MISSING_ERR)
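
Several hunks in this file thread a single session from sql_json_exec into the static _sql_json_async and _sql_json_sync helpers, so the Query row and its later status updates are committed on the same session that created it. A hedged sketch of that save-then-dispatch step (save_query and its arguments are illustrative, not the actual view code):

    from sqlalchemy.exc import SQLAlchemyError
    from sqlalchemy.orm import Session


    def save_query(session: Session, query) -> int:
        # persist the tracking row on the caller's session before
        # dispatching any work, so workers can look it up by id
        try:
            session.add(query)
            session.flush()      # assigns query.id without committing
            query_id = query.id
            session.commit()
        except SQLAlchemyError:
            session.rollback()   # keep the session usable for the error path
            raise
        return query_id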
diff --git a/superset/views/datasource.py b/superset/views/datasource.py
index c2affcb..2ce1102 100644
--- a/superset/views/datasource.py
+++ b/superset/views/datasource.py
@@ -47,7 +47,7 @@ class Datasource(BaseSupersetView):
         datasource_type = datasource_dict.get("type")
         database_id = datasource_dict["database"].get("id")
         orm_datasource = ConnectorRegistry.get_datasource(
-            datasource_type, datasource_id
+            datasource_type, datasource_id, db.session
         )
         orm_datasource.database_id = database_id
 
@@ -82,7 +82,7 @@ class Datasource(BaseSupersetView):
     def get(self, datasource_type: str, datasource_id: int) -> FlaskResponse:
         try:
             orm_datasource = ConnectorRegistry.get_datasource(
-                datasource_type, datasource_id
+                datasource_type, datasource_id, db.session
             )
             if not orm_datasource.data:
                 return json_error_response(
@@ -102,7 +102,7 @@ class Datasource(BaseSupersetView):
         """Gets column info from the source system"""
         if datasource_type == "druid":
             datasource = ConnectorRegistry.get_datasource(
-                datasource_type, datasource_id
+                datasource_type, datasource_id, db.session
             )
         elif datasource_type == "table":
             database = (
diff --git a/superset/views/utils.py b/superset/views/utils.py
index dd2aa06..2a8b2cc 100644
--- a/superset/views/utils.py
+++ b/superset/views/utils.py
@@ -105,7 +105,9 @@ def get_viz(
     form_data: FormData, datasource_type: str, datasource_id: int, force: bool = False
 ) -> BaseViz:
     viz_type = form_data.get("viz_type", "table")
-    datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
+    datasource = ConnectorRegistry.get_datasource(
+        datasource_type, datasource_id, db.session
+    )
     viz_obj = viz.viz_types[viz_type](datasource, form_data=form_data, force=force)
     return viz_obj
 
@@ -291,7 +293,8 @@ CONTAINER_TYPES = ["COLUMN", "GRID", "TABS", "TAB", "ROW"]
 def get_dashboard_extra_filters(
     slice_id: int, dashboard_id: int
 ) -> List[Dict[str, Any]]:
-    dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one_or_none()
+    session = db.session()
+    dashboard = session.query(Dashboard).filter_by(id=dashboard_id).one_or_none()
 
     # is chart in this dashboard?
     if (
diff --git a/tests/access_tests.py b/tests/access_tests.py
index 10a4867..d452d13 100644
--- a/tests/access_tests.py
+++ b/tests/access_tests.py
@@ -71,17 +71,13 @@ DB_ACCESS_ROLE = "db_access_role"
 SCHEMA_ACCESS_ROLE = "schema_access_role"
 
 
-def create_access_request(ds_type, ds_name, role_name, user_name):
+def create_access_request(session, ds_type, ds_name, role_name, user_name):
     ds_class = ConnectorRegistry.sources[ds_type]
     # TODO: generalize datasource names
     if ds_type == "table":
-        ds = db.session.query(ds_class).filter(ds_class.table_name == ds_name).first()
+        ds = session.query(ds_class).filter(ds_class.table_name == ds_name).first()
     else:
-        ds = (
-            db.session.query(ds_class)
-            .filter(ds_class.datasource_name == ds_name)
-            .first()
-        )
+        ds = session.query(ds_class).filter(ds_class.datasource_name == ds_name).first()
     ds_perm_view = security_manager.find_permission_view_menu(
         "datasource_access", ds.perm
     )
@@ -93,8 +89,8 @@ def create_access_request(ds_type, ds_name, role_name, user_name):
         datasource_type=ds_type,
         created_by_fk=security_manager.find_user(username=user_name).id,
     )
-    db.session.add(access_request)
-    db.session.commit()
+    session.add(access_request)
+    session.commit()
     return access_request
 
 
@@ -130,6 +126,7 @@ class TestRequestAccess(SupersetTestCase):
         override_me = security_manager.find_role("override_me")
         override_me.permissions = []
         db.session.commit()
+        db.session.close()
 
     def test_override_role_permissions_is_admin_only(self):
         self.logout()
@@ -214,6 +211,7 @@ class TestRequestAccess(SupersetTestCase):
         )
 
     def test_clean_requests_after_role_extend(self):
+        session = db.session
 
         # Case 1. Gamma and gamma2 requested test_role1 on energy_usage access
         # Gamma already has role test_role1
@@ -223,10 +221,12 @@ class TestRequestAccess(SupersetTestCase):
         # gamma2 and gamma request table_role on energy usage
         if app.config["ENABLE_ACCESS_REQUEST"]:
             access_request1 = create_access_request(
-                "table", "random_time_series", TEST_ROLE_1, "gamma2"
+                session, "table", "random_time_series", TEST_ROLE_1, "gamma2"
             )
             ds_1_id = access_request1.datasource_id
-            create_access_request("table", "random_time_series", TEST_ROLE_1, "gamma")
+            create_access_request(
+                session, "table", "random_time_series", TEST_ROLE_1, "gamma"
+            )
             access_requests = self.get_access_requests("gamma", "table", ds_1_id)
             self.assertTrue(access_requests)
             # gamma gets test_role1
@@ -244,20 +244,22 @@ class TestRequestAccess(SupersetTestCase):
             gamma_user.roles.remove(security_manager.find_role("test_role1"))
 
     def test_clean_requests_after_alpha_grant(self):
+        session = db.session
+
         # Case 2. Two access requests from gamma and gamma2
         # Gamma becomes alpha, gamma2 gets granted
         # Check if request by gamma has been deleted
 
         access_request1 = create_access_request(
-            "table", "birth_names", TEST_ROLE_1, "gamma"
+            session, "table", "birth_names", TEST_ROLE_1, "gamma"
         )
-        create_access_request("table", "birth_names", TEST_ROLE_2, "gamma2")
+        create_access_request(session, "table", "birth_names", TEST_ROLE_2, "gamma2")
         ds_1_id = access_request1.datasource_id
         # gamma becomes alpha
         alpha_role = security_manager.find_role("Alpha")
         gamma_user = security_manager.find_user(username="gamma")
         gamma_user.roles.append(alpha_role)
-        db.session.commit()
+        session.commit()
         access_requests = self.get_access_requests("gamma", "table", ds_1_id)
         self.assertTrue(access_requests)
         self.client.get(
@@ -268,21 +270,23 @@ class TestRequestAccess(SupersetTestCase):
 
         gamma_user = security_manager.find_user(username="gamma")
         gamma_user.roles.remove(security_manager.find_role("Alpha"))
-        db.session.commit()
+        session.commit()
 
     def test_clean_requests_after_db_grant(self):
+        session = db.session
+
         # Case 3. Two access requests from gamma and gamma2
         # Gamma gets database access, gamma2 access request granted
         # Check if request by gamma has been deleted
 
         gamma_user = security_manager.find_user(username="gamma")
         access_request1 = create_access_request(
-            "table", "energy_usage", TEST_ROLE_1, "gamma"
+            session, "table", "energy_usage", TEST_ROLE_1, "gamma"
         )
-        create_access_request("table", "energy_usage", TEST_ROLE_2, "gamma2")
+        create_access_request(session, "table", "energy_usage", TEST_ROLE_2, "gamma2")
         ds_1_id = access_request1.datasource_id
         # gamma gets granted database access
-        database = db.session.query(models.Database).first()
+        database = session.query(models.Database).first()
 
         security_manager.add_permission_view_menu("database_access", database.perm)
         ds_perm_view = security_manager.find_permission_view_menu(
@@ -292,7 +296,7 @@ class TestRequestAccess(SupersetTestCase):
             security_manager.find_role(DB_ACCESS_ROLE), ds_perm_view
         )
         gamma_user.roles.append(security_manager.find_role(DB_ACCESS_ROLE))
-        db.session.commit()
+        session.commit()
         access_requests = self.get_access_requests("gamma", "table", ds_1_id)
         self.assertTrue(access_requests)
         # gamma2 request gets fulfilled
@@ -304,21 +308,25 @@ class TestRequestAccess(SupersetTestCase):
         self.assertFalse(access_requests)
         gamma_user = security_manager.find_user(username="gamma")
         gamma_user.roles.remove(security_manager.find_role(DB_ACCESS_ROLE))
-        db.session.commit()
+        session.commit()
 
     def test_clean_requests_after_schema_grant(self):
+        session = db.session
+
         # Case 4. Two access requests from gamma and gamma2
         # Gamma gets schema access, gamma2 access request granted
         # Check if request by gamma has been deleted
 
         gamma_user = security_manager.find_user(username="gamma")
         access_request1 = create_access_request(
-            "table", "wb_health_population", TEST_ROLE_1, "gamma"
+            session, "table", "wb_health_population", TEST_ROLE_1, "gamma"
+        )
+        create_access_request(
+            session, "table", "wb_health_population", TEST_ROLE_2, "gamma2"
         )
-        create_access_request("table", "wb_health_population", TEST_ROLE_2, "gamma2")
         ds_1_id = access_request1.datasource_id
         ds = (
-            db.session.query(SqlaTable)
+            session.query(SqlaTable)
             .filter_by(table_name="wb_health_population")
             .first()
         )
@@ -332,7 +340,7 @@ class TestRequestAccess(SupersetTestCase):
             security_manager.find_role(SCHEMA_ACCESS_ROLE), schema_perm_view
         )
         gamma_user.roles.append(security_manager.find_role(SCHEMA_ACCESS_ROLE))
-        db.session.commit()
+        session.commit()
         # gamma2 request gets fulfilled
         self.client.get(
             EXTEND_ROLE_REQUEST.format("table", ds_1_id, "gamma2", TEST_ROLE_2)
@@ -343,24 +351,25 @@ class TestRequestAccess(SupersetTestCase):
         gamma_user.roles.remove(security_manager.find_role(SCHEMA_ACCESS_ROLE))
 
         ds = (
-            db.session.query(SqlaTable)
+            session.query(SqlaTable)
             .filter_by(table_name="wb_health_population")
             .first()
         )
         ds.schema = None
 
-        db.session.commit()
+        session.commit()
 
     @mock.patch("superset.utils.core.send_mime_email")
     def test_approve(self, mock_send_mime):
         if app.config["ENABLE_ACCESS_REQUEST"]:
+            session = db.session
             TEST_ROLE_NAME = "table_role"
             security_manager.add_role(TEST_ROLE_NAME)
 
             # Case 1. Grant new role to the user.
 
             access_request1 = create_access_request(
-                "table", "unicode_test", TEST_ROLE_NAME, "gamma"
+                session, "table", "unicode_test", TEST_ROLE_NAME, "gamma"
             )
             ds_1_id = access_request1.datasource_id
             self.get_resp(
@@ -395,7 +404,7 @@ class TestRequestAccess(SupersetTestCase):
             # Case 2. Extend the role to have access to the table
 
             access_request2 = create_access_request(
-                "table", "energy_usage", TEST_ROLE_NAME, "gamma"
+                session, "table", "energy_usage", TEST_ROLE_NAME, "gamma"
             )
             ds_2_id = access_request2.datasource_id
             energy_usage_perm = access_request2.datasource.perm
@@ -439,7 +448,7 @@ class TestRequestAccess(SupersetTestCase):
 
             security_manager.add_role("druid_role")
             access_request3 = create_access_request(
-                "druid", "druid_ds_1", "druid_role", "gamma"
+                session, "druid", "druid_ds_1", "druid_role", "gamma"
             )
             self.get_resp(
                 GRANT_ROLE_REQUEST.format(
@@ -454,7 +463,7 @@ class TestRequestAccess(SupersetTestCase):
             # Case 4. Extend the role to have access to the druid datasource
 
             access_request4 = create_access_request(
-                "druid", "druid_ds_2", "druid_role", "gamma"
+                session, "druid", "druid_ds_2", "druid_role", "gamma"
             )
             druid_ds_2_perm = access_request4.datasource.perm
 
@@ -474,18 +483,19 @@ class TestRequestAccess(SupersetTestCase):
             gamma_user = security_manager.find_user(username="gamma")
             gamma_user.roles.remove(security_manager.find_role("druid_role"))
             gamma_user.roles.remove(security_manager.find_role(TEST_ROLE_NAME))
-            db.session.delete(security_manager.find_role("druid_role"))
-            db.session.delete(security_manager.find_role(TEST_ROLE_NAME))
-            db.session.commit()
+            session.delete(security_manager.find_role("druid_role"))
+            session.delete(security_manager.find_role(TEST_ROLE_NAME))
+            session.commit()
 
     def test_request_access(self):
         if app.config["ENABLE_ACCESS_REQUEST"]:
+            session = db.session
             self.logout()
             self.login(username="gamma")
             gamma_user = security_manager.find_user(username="gamma")
             security_manager.add_role("dummy_role")
             gamma_user.roles.append(security_manager.find_role("dummy_role"))
-            db.session.commit()
+            session.commit()
 
             ACCESS_REQUEST = (
                 "/superset/request_access?"
@@ -501,7 +511,7 @@ class TestRequestAccess(SupersetTestCase):
             # Request table access; no existing roles have this table.
 
             table1 = (
-                db.session.query(SqlaTable)
+                session.query(SqlaTable)
                 .filter_by(table_name="random_time_series")
                 .first()
             )
@@ -516,7 +526,7 @@ class TestRequestAccess(SupersetTestCase):
             # Request access; roles exist that contain the table.
             # add table to the existing roles
             table3 = (
-                db.session.query(SqlaTable).filter_by(table_name="energy_usage").first()
+                session.query(SqlaTable).filter_by(table_name="energy_usage").first()
             )
             table_3_id = table3.id
             table3_perm = table3.perm
@@ -535,7 +545,7 @@ class TestRequestAccess(SupersetTestCase):
                     "datasource_access", table3_perm
                 ),
             )
-            db.session.commit()
+            session.commit()
 
             self.get_resp(ACCESS_REQUEST.format("table", table_3_id, "go"))
             access_request3 = self.get_access_requests("gamma", "table", table_3_id)
@@ -549,7 +559,7 @@ class TestRequestAccess(SupersetTestCase):
 
             # Request druid access; no existing roles have this datasource.
             druid_ds_4 = (
-                db.session.query(DruidDatasource)
+                session.query(DruidDatasource)
                 .filter_by(datasource_name="druid_ds_1")
                 .first()
             )
@@ -564,7 +574,7 @@ class TestRequestAccess(SupersetTestCase):
             # Case 5. Roles exist that contain the druid datasource.
             # add druid ds to the existing roles
             druid_ds_5 = (
-                db.session.query(DruidDatasource)
+                session.query(DruidDatasource)
                 .filter_by(datasource_name="druid_ds_2")
                 .first()
             )
@@ -585,7 +595,7 @@ class TestRequestAccess(SupersetTestCase):
                     "datasource_access", druid_ds_5_perm
                 ),
             )
-            db.session.commit()
+            session.commit()
 
             self.get_resp(ACCESS_REQUEST.format("druid", druid_ds_5_id, "go"))
             access_request5 = self.get_access_requests("gamma", "druid", druid_ds_5_id)
@@ -600,7 +610,7 @@ class TestRequestAccess(SupersetTestCase):
             # cleanup
             gamma_user = security_manager.find_user(username="gamma")
             gamma_user.roles.remove(security_manager.find_role("dummy_role"))
-            db.session.commit()
+            session.commit()
 
 
 if __name__ == "__main__":
diff --git a/tests/alerts_tests.py b/tests/alerts_tests.py
index 0720581..c78847c 100644
--- a/tests/alerts_tests.py
+++ b/tests/alerts_tests.py
@@ -32,118 +32,112 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 
-def setup_module():
+@pytest.yield_fixture(scope="module")
+def setup_database():
     with app.app_context():
         slice_id = db.session.query(Slice).all()[0].id
         database_id = utils.get_example_database().id
 
-        alerts = [
-            Alert(
-                id=1,
-                label="alert_1",
-                active=True,
-                crontab="*/1 * * * *",
-                sql="SELECT 0",
-                alert_type="email",
-                slice_id=slice_id,
-                database_id=database_id,
-            ),
-            Alert(
-                id=2,
-                label="alert_2",
-                active=True,
-                crontab="*/1 * * * *",
-                sql="SELECT 55",
-                alert_type="email",
-                slice_id=slice_id,
-                database_id=database_id,
-            ),
-            Alert(
-                id=3,
-                label="alert_3",
-                active=False,
-                crontab="*/1 * * * *",
-                sql="UPDATE 55",
-                alert_type="email",
-                slice_id=slice_id,
-                database_id=database_id,
-            ),
-            Alert(id=4, active=False, label="alert_4", database_id=-1),
-            Alert(id=5, active=False, label="alert_5", database_id=database_id),
-        ]
-
-        db.session.bulk_save_objects(alerts)
-        db.session.commit()
+        alert1 = Alert(
+            id=1,
+            label="alert_1",
+            active=True,
+            crontab="*/1 * * * *",
+            sql="SELECT 0",
+            alert_type="email",
+            slice_id=slice_id,
+            database_id=database_id,
+        )
+        alert2 = Alert(
+            id=2,
+            label="alert_2",
+            active=True,
+            crontab="*/1 * * * *",
+            sql="SELECT 55",
+            alert_type="email",
+            slice_id=slice_id,
+            database_id=database_id,
+        )
+        alert3 = Alert(
+            id=3,
+            label="alert_3",
+            active=False,
+            crontab="*/1 * * * *",
+            sql="UPDATE 55",
+            alert_type="email",
+            slice_id=slice_id,
+            database_id=database_id,
+        )
+        alert4 = Alert(id=4, active=False, label="alert_4", database_id=-1)
+        alert5 = Alert(id=5, active=False, label="alert_5", database_id=database_id)
 
+        for alert in [alert1, alert2, alert3, alert4, alert5]:
+            db.session.add(alert)
+        db.session.commit()
+        yield db.session
 
-def teardown_module():
-    with app.app_context():
         db.session.query(AlertLog).delete()
         db.session.query(Alert).delete()
 
 
 @patch("superset.tasks.schedules.deliver_alert")
 @patch("superset.tasks.schedules.logging.Logger.error")
-def test_run_alert_query(mock_error, mock_deliver_alert):
-    with app.app_context():
-        run_alert_query(db.session.query(Alert).filter_by(id=1).one().id)
-        alert1 = db.session.query(Alert).filter_by(id=1).one()
-        assert mock_deliver_alert.call_count == 0
-        assert len(alert1.logs) == 1
-        assert alert1.logs[0].alert_id == 1
-        assert alert1.logs[0].state == "pass"
-
-        run_alert_query(db.session.query(Alert).filter_by(id=2).one().id)
-        alert2 = db.session.query(Alert).filter_by(id=2).one()
-        assert mock_deliver_alert.call_count == 1
-        assert len(alert2.logs) == 1
-        assert alert2.logs[0].alert_id == 2
-        assert alert2.logs[0].state == "trigger"
-
-        run_alert_query(db.session.query(Alert).filter_by(id=3).one().id)
-        alert3 = db.session.query(Alert).filter_by(id=3).one()
-        assert mock_deliver_alert.call_count == 1
-        assert mock_error.call_count == 2
-        assert len(alert3.logs) == 1
-        assert alert3.logs[0].alert_id == 3
-        assert alert3.logs[0].state == "error"
-
-        run_alert_query(db.session.query(Alert).filter_by(id=4).one().id)
-        assert mock_deliver_alert.call_count == 1
-        assert mock_error.call_count == 3
-
-        run_alert_query(db.session.query(Alert).filter_by(id=5).one().id)
-        assert mock_deliver_alert.call_count == 1
-        assert mock_error.call_count == 4
+def test_run_alert_query(mock_error, mock_deliver, setup_database):
+    database = setup_database
+    run_alert_query(database.query(Alert).filter_by(id=1).one().id, database)
+    alert1 = database.query(Alert).filter_by(id=1).one()
+    assert mock_deliver.call_count == 0
+    assert len(alert1.logs) == 1
+    assert alert1.logs[0].alert_id == 1
+    assert alert1.logs[0].state == "pass"
+
+    run_alert_query(database.query(Alert).filter_by(id=2).one().id, database)
+    alert2 = database.query(Alert).filter_by(id=2).one()
+    assert mock_deliver.call_count == 1
+    assert len(alert2.logs) == 1
+    assert alert2.logs[0].alert_id == 2
+    assert alert2.logs[0].state == "trigger"
+
+    run_alert_query(database.query(Alert).filter_by(id=3).one().id, database)
+    alert3 = database.query(Alert).filter_by(id=3).one()
+    assert mock_deliver.call_count == 1
+    assert mock_error.call_count == 2
+    assert len(alert3.logs) == 1
+    assert alert3.logs[0].alert_id == 3
+    assert alert3.logs[0].state == "error"
+
+    run_alert_query(database.query(Alert).filter_by(id=4).one().id, database)
+    assert mock_deliver.call_count == 1
+    assert mock_error.call_count == 3
+
+    run_alert_query(database.query(Alert).filter_by(id=5).one().id, database)
+    assert mock_deliver.call_count == 1
+    assert mock_error.call_count == 4
 
 
 @patch("superset.tasks.schedules.deliver_alert")
 @patch("superset.tasks.schedules.run_alert_query")
-def test_schedule_alert_query(mock_run_alert, mock_deliver_alert):
-    with app.app_context():
-        active_alert = db.session.query(Alert).filter_by(id=1).one()
-        inactive_alert = db.session.query(Alert).filter_by(id=3).one()
-
-        # Test that inactive alerts are not processed
-        schedule_alert_query(
-            report_type=ScheduleType.alert, schedule_id=inactive_alert.id
-        )
-        assert mock_run_alert.call_count == 0
-        assert mock_deliver_alert.call_count == 0
-
-        # Test that active alerts with no recipients passed in are processed regularly
-        schedule_alert_query(
-            report_type=ScheduleType.alert, schedule_id=active_alert.id
-        )
-        assert mock_run_alert.call_count == 1
-        assert mock_deliver_alert.call_count == 0
-
-        # Test that active alerts sent as a test are delivered immediately
-        schedule_alert_query(
-            report_type=ScheduleType.alert,
-            schedule_id=active_alert.id,
-            recipients="testing@email.com",
-            is_test_alert=True,
-        )
-        assert mock_run_alert.call_count == 1
-        assert mock_deliver_alert.call_count == 1
+def test_schedule_alert_query(mock_run_alert, mock_deliver_alert, setup_database):
+    database = setup_database
+    active_alert = database.query(Alert).filter_by(id=1).one()
+    inactive_alert = database.query(Alert).filter_by(id=3).one()
+
+    # Test that inactive alerts are not processed
+    schedule_alert_query(report_type=ScheduleType.alert, schedule_id=inactive_alert.id)
+    assert mock_run_alert.call_count == 0
+    assert mock_deliver_alert.call_count == 0
+
+    # Test that active alerts with no recipients passed in are processed regularly
+    schedule_alert_query(report_type=ScheduleType.alert, schedule_id=active_alert.id)
+    assert mock_run_alert.call_count == 1
+    assert mock_deliver_alert.call_count == 0
+
+    # Test that active alerts sent as a test are delivered immediately
+    schedule_alert_query(
+        report_type=ScheduleType.alert,
+        schedule_id=active_alert.id,
+        recipients="testing@email.com",
+        is_test_alert=True,
+    )
+    assert mock_run_alert.call_count == 1
+    assert mock_deliver_alert.call_count == 1
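
The alert tests replace module-level setup_module/teardown_module with a module-scoped pytest fixture that seeds data, yields the session to the tests, and cleans up after the last test in the module. pytest.yield_fixture is the legacy spelling; current pytest accepts a yield inside a plain @pytest.fixture. A minimal runnable illustration of that fixture lifecycle, using plain data in place of Alert rows:

    import pytest


    @pytest.fixture(scope="module")
    def seeded_rows():
        # setup runs once per module, before the first test that requests it
        rows = [{"id": i, "label": f"alert_{i}"} for i in range(1, 6)]
        yield rows      # every test in the module receives this same object
        rows.clear()    # teardown runs once, after the module's last test


    def test_rows_present(seeded_rows):
        assert len(seeded_rows) == 5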
diff --git a/tests/base_tests.py b/tests/base_tests.py
index c74378f..e0a20a4 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -25,6 +25,7 @@ import pandas as pd
 from flask import Response
 from flask_appbuilder.security.sqla import models as ab_models
 from flask_testing import TestCase
+from sqlalchemy.orm import Session
 
 from tests.test_app import app
 from superset.sql_parse import CtasMethod
@@ -103,25 +104,24 @@ class SupersetTestCase(TestCase):
         # create druid cluster and druid datasources
 
         with app.app_context():
+            session = db.session
             cluster = (
-                db.session.query(DruidCluster)
-                .filter_by(cluster_name="druid_test")
-                .first()
+                session.query(DruidCluster).filter_by(cluster_name="druid_test").first()
             )
             if not cluster:
                 cluster = DruidCluster(cluster_name="druid_test")
-                db.session.add(cluster)
-                db.session.commit()
+                session.add(cluster)
+                session.commit()
 
                 druid_datasource1 = DruidDatasource(
                     datasource_name="druid_ds_1", cluster=cluster
                 )
-                db.session.add(druid_datasource1)
+                session.add(druid_datasource1)
                 druid_datasource2 = DruidDatasource(
                     datasource_name="druid_ds_2", cluster=cluster
                 )
-                db.session.add(druid_datasource2)
-                db.session.commit()
+                session.add(druid_datasource2)
+                session.commit()
 
     @staticmethod
     def get_table_by_id(table_id: int) -> SqlaTable:
@@ -135,23 +135,25 @@ class SupersetTestCase(TestCase):
         except ImportError:
             return False
 
-    def get_or_create(self, cls, criteria, **kwargs):
-        obj = db.session.query(cls).filter_by(**criteria).first()
+    def get_or_create(self, cls, criteria, session, **kwargs):
+        obj = session.query(cls).filter_by(**criteria).first()
         if not obj:
             obj = cls(**criteria)
         obj.__dict__.update(**kwargs)
-        db.session.add(obj)
-        db.session.commit()
+        session.add(obj)
+        session.commit()
         return obj
 
     def login(self, username="admin", password="general"):
         resp = self.get_resp("/login/", data=dict(username=username, password=password))
         self.assertNotIn("User confirmation needed", resp)
 
-    def get_slice(self, slice_name: str, expunge_from_session: bool = True) -> Slice:
-        slc = db.session.query(Slice).filter_by(slice_name=slice_name).one()
+    def get_slice(
+        self, slice_name: str, session: Session, expunge_from_session: bool = True
+    ) -> Slice:
+        slc = session.query(Slice).filter_by(slice_name=slice_name).one()
         if expunge_from_session:
-            db.session.expunge_all()
+            session.expunge_all()
         return slc
 
     @staticmethod
@@ -300,6 +302,7 @@ class SupersetTestCase(TestCase):
         return self.get_or_create(
             cls=models.Database,
             criteria={"database_name": database_name},
+            session=db.session,
             sqlalchemy_uri="sqlite:///:memory:",
             id=db_id,
             extra=extra,
@@ -321,6 +324,7 @@ class SupersetTestCase(TestCase):
         return self.get_or_create(
             cls=models.Database,
             criteria={"database_name": database_name},
+            session=db.session,
             sqlalchemy_uri="presto://user@host:8080/hive",
             id=db_id,
         )
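
get_or_create in the test base class likewise takes the session explicitly now. Sketched self-contained below is the classic get-or-create it implements; the Database model here is a stand-in, and setattr is used where the test helper pokes obj.__dict__ directly:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session, sessionmaker

    Base = declarative_base()


    class Database(Base):  # hypothetical stand-in for models.Database
        __tablename__ = "dbs"
        id = Column(Integer, primary_key=True)
        database_name = Column(String, unique=True)
        sqlalchemy_uri = Column(String)


    def get_or_create(session: Session, cls, criteria, **kwargs):
        # look up by the criteria, create if absent, then apply any extra
        # attributes and commit on whichever session the caller handed in
        obj = session.query(cls).filter_by(**criteria).first()
        if not obj:
            obj = cls(**criteria)
        for key, value in kwargs.items():
            setattr(obj, key, value)
        session.add(obj)
        session.commit()
        return obj


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()
    get_or_create(session, Database, {"database_name": "examples"},
                  sqlalchemy_uri="sqlite:///:memory:")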
diff --git a/tests/celery_tests.py b/tests/celery_tests.py
index 68a7213..53190cb 100644
--- a/tests/celery_tests.py
+++ b/tests/celery_tests.py
@@ -99,13 +99,15 @@ class TestAppContext(SupersetTestCase):
 
 class TestCelery(SupersetTestCase):
     def get_query_by_name(self, sql):
-        query = db.session.query(Query).filter_by(sql=sql).first()
-        db.session.close()
+        session = db.session
+        query = session.query(Query).filter_by(sql=sql).first()
+        session.close()
         return query
 
     def get_query_by_id(self, id):
-        query = db.session.query(Query).filter_by(id=id).first()
-        db.session.close()
+        session = db.session
+        query = session.query(Query).filter_by(id=id).first()
+        session.close()
         return query
 
     @classmethod
diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py
index 0820518..5048a0a 100644
--- a/tests/charts/api_tests.py
+++ b/tests/charts/api_tests.py
@@ -58,7 +58,9 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
         for owner in owners:
             user = db.session.query(security_manager.user_model).get(owner)
             obj_owners.append(user)
-        datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
+        datasource = ConnectorRegistry.get_datasource(
+            datasource_type, datasource_id, db.session
+        )
         slice = Slice(
             slice_name=slice_name,
             datasource_id=datasource.id,
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 2793795..ade0095 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -100,7 +100,7 @@ class TestCore(SupersetTestCase):
 
     def test_slice_endpoint(self):
         self.login(username="admin")
-        slc = self.get_slice("Girls")
+        slc = self.get_slice("Girls", db.session)
         resp = self.get_resp("/superset/slice/{}/".format(slc.id))
         assert "Time Column" in resp
         assert "List Roles" in resp
@@ -114,7 +114,7 @@ class TestCore(SupersetTestCase):
 
     def test_viz_cache_key(self):
         self.login(username="admin")
-        slc = self.get_slice("Girls")
+        slc = self.get_slice("Girls", db.session)
 
         viz = slc.viz
         qobj = viz.query_obj()
@@ -233,7 +233,7 @@ class TestCore(SupersetTestCase):
     def test_save_slice(self):
         self.login(username="admin")
         slice_name = f"Energy Sankey"
-        slice_id = self.get_slice(slice_name).id
+        slice_id = self.get_slice(slice_name, db.session).id
         copy_name_prefix = "Test Sankey"
         copy_name = f"{copy_name_prefix}[save]{random.random()}"
         tbl_id = self.table_ids.get("energy_usage")
@@ -299,7 +299,7 @@ class TestCore(SupersetTestCase):
     def test_filter_endpoint(self):
         self.login(username="admin")
         slice_name = "Energy Sankey"
-        slice_id = self.get_slice(slice_name).id
+        slice_id = self.get_slice(slice_name, db.session).id
         db.session.commit()
         tbl_id = self.table_ids.get("energy_usage")
         table = db.session.query(SqlaTable).filter(SqlaTable.id == tbl_id)
@@ -319,7 +319,9 @@ class TestCore(SupersetTestCase):
     def test_slice_data(self):
         # slice data should have some required attributes
         self.login(username="admin")
-        slc = self.get_slice(slice_name="Girls", expunge_from_session=False)
+        slc = self.get_slice(
+            slice_name="Girls", session=db.session, expunge_from_session=False
+        )
         slc_data_attributes = slc.data.keys()
         assert "changed_on" in slc_data_attributes
         assert "modified" in slc_data_attributes
@@ -370,7 +372,9 @@ class TestCore(SupersetTestCase):
         self.assertEqual(data, [])
 
         # make user owner of slice and verify that endpoint returns said slice
-        slc = self.get_slice(slice_name=slice_name, expunge_from_session=False)
+        slc = self.get_slice(
+            slice_name=slice_name, session=db.session, expunge_from_session=False
+        )
         slc.owners = [user]
         db.session.merge(slc)
         db.session.commit()
@@ -381,7 +385,9 @@ class TestCore(SupersetTestCase):
         self.assertEqual(data[0]["title"], slice_name)
 
         # remove ownership and ensure user no longer gets slice
-        slc = self.get_slice(slice_name=slice_name, expunge_from_session=False)
+        slc = self.get_slice(
+            slice_name=slice_name, session=db.session, expunge_from_session=False
+        )
         slc.owners = []
         db.session.merge(slc)
         db.session.commit()
@@ -559,7 +565,7 @@ class TestCore(SupersetTestCase):
         db.session.commit()
 
     def test_warm_up_cache(self):
-        slc = self.get_slice("Girls")
+        slc = self.get_slice("Girls", db.session)
         data = self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(slc.id))
         self.assertEqual(
             data, [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}]
@@ -784,7 +790,7 @@ class TestCore(SupersetTestCase):
 
     def test_user_profile(self, username="admin"):
         self.login(username=username)
-        slc = self.get_slice("Girls")
+        slc = self.get_slice("Girls", db.session)
 
         # Setting some faves
         url = f"/superset/favstar/Slice/{slc.id}/select/"
diff --git a/tests/database_api_tests.py b/tests/database_api_tests.py
index 7126c25..49ede7b 100644
--- a/tests/database_api_tests.py
+++ b/tests/database_api_tests.py
@@ -178,11 +178,12 @@ class TestDatabaseApi(SupersetTestCase):
         """
         Database API: Test get select star with datasource access
         """
+        session = db.session
         table = SqlaTable(
             schema="main", table_name="ab_permission", database=get_main_database()
         )
-        db.session.add(table)
-        db.session.commit()
+        session.add(table)
+        session.commit()
 
         tmp_table_perm = security_manager.find_permission_view_menu(
             "datasource_access", table.get_perm()
diff --git a/tests/datasets/api_tests.py b/tests/datasets/api_tests.py
index 7e078fd..08c8d5a 100644
--- a/tests/datasets/api_tests.py
+++ b/tests/datasets/api_tests.py
@@ -156,7 +156,7 @@ class TestDatasetApi(SupersetTestCase):
             "template_params": None,
         }
         for key, value in expected_result.items():
-            self.assertEqual(response["result"][key], value)
+            self.assertEqual(response["result"][key], expected_result[key])
         self.assertEqual(len(response["result"]["columns"]), 8)
         self.assertEqual(len(response["result"]["metrics"]), 2)
 
@@ -721,7 +721,10 @@ class TestDatasetApi(SupersetTestCase):
         )
 
         cli_export = export_to_dict(
-            recursive=True, back_references=False, include_defaults=False,
+            session=db.session,
+            recursive=True,
+            back_references=False,
+            include_defaults=False,
         )
         cli_export_tables = cli_export["databases"][0]["tables"]
         expected_response = []
diff --git a/tests/dict_import_export_tests.py b/tests/dict_import_export_tests.py
index 725ffdf..dc0b8d8 100644
--- a/tests/dict_import_export_tests.py
+++ b/tests/dict_import_export_tests.py
@@ -47,13 +47,14 @@ class TestDictImportExport(SupersetTestCase):
     def delete_imports(cls):
         with app.app_context():
             # Imported data clean up
-            for table in db.session.query(SqlaTable):
+            session = db.session
+            for table in session.query(SqlaTable):
                 if DBREF in table.params_dict:
-                    db.session.delete(table)
-            for datasource in db.session.query(DruidDatasource):
+                    session.delete(table)
+            for datasource in session.query(DruidDatasource):
                 if DBREF in datasource.params_dict:
-                    db.session.delete(datasource)
-            db.session.commit()
+                    session.delete(datasource)
+            session.commit()
 
     @classmethod
     def setUpClass(cls):
@@ -89,7 +90,9 @@ class TestDictImportExport(SupersetTestCase):
 
     def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]):
         cluster_name = "druid_test"
-        cluster = self.get_or_create(DruidCluster, {"cluster_name": cluster_name})
+        cluster = self.get_or_create(
+            DruidCluster, {"cluster_name": cluster_name}, db.session
+        )
 
         name = "{0}{1}".format(NAME_PREFIX, name)
         params = {DBREF: id, "database_name": cluster_name}
@@ -156,7 +159,7 @@ class TestDictImportExport(SupersetTestCase):
 
     def test_import_table_no_metadata(self):
         table, dict_table = self.create_table("pure_table", id=ID_PREFIX + 1)
-        new_table = SqlaTable.import_from_dict(dict_table)
+        new_table = SqlaTable.import_from_dict(db.session, dict_table)
         db.session.commit()
         imported_id = new_table.id
         imported = self.get_table_by_id(imported_id)
@@ -170,7 +173,7 @@ class TestDictImportExport(SupersetTestCase):
             cols_names=["col1"],
             metric_names=["metric1"],
         )
-        imported_table = SqlaTable.import_from_dict(dict_table)
+        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
         db.session.commit()
         imported = self.get_table_by_id(imported_table.id)
         self.assert_table_equals(table, imported)
@@ -186,7 +189,7 @@ class TestDictImportExport(SupersetTestCase):
             cols_names=["c1", "c2"],
             metric_names=["m1", "m2"],
         )
-        imported_table = SqlaTable.import_from_dict(dict_table)
+        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
         db.session.commit()
         imported = self.get_table_by_id(imported_table.id)
         self.assert_table_equals(table, imported)
@@ -196,7 +199,7 @@ class TestDictImportExport(SupersetTestCase):
         table, dict_table = self.create_table(
             "table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
         )
-        imported_table = SqlaTable.import_from_dict(dict_table)
+        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
         db.session.commit()
         table_over, dict_table_over = self.create_table(
             "table_override",
@@ -204,7 +207,7 @@ class TestDictImportExport(SupersetTestCase):
             cols_names=["new_col1", "col2", "col3"],
             metric_names=["new_metric1"],
         )
-        imported_over_table = SqlaTable.import_from_dict(dict_table_over)
+        imported_over_table = SqlaTable.import_from_dict(db.session, dict_table_over)
         db.session.commit()
 
         imported_over = self.get_table_by_id(imported_over_table.id)
@@ -224,7 +227,7 @@ class TestDictImportExport(SupersetTestCase):
         table, dict_table = self.create_table(
             "table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
         )
-        imported_table = SqlaTable.import_from_dict(dict_table)
+        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
         db.session.commit()
         table_over, dict_table_over = self.create_table(
             "table_override",
@@ -233,7 +236,7 @@ class TestDictImportExport(SupersetTestCase):
             metric_names=["new_metric1"],
         )
         imported_over_table = SqlaTable.import_from_dict(
-            dict_rep=dict_table_over, sync=["metrics", "columns"]
+            session=db.session, dict_rep=dict_table_over, sync=["metrics", "columns"]
         )
         db.session.commit()
 
@@ -257,7 +260,7 @@ class TestDictImportExport(SupersetTestCase):
             cols_names=["new_col1", "col2", "col3"],
             metric_names=["new_metric1"],
         )
-        imported_table = SqlaTable.import_from_dict(dict_table)
+        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
         db.session.commit()
         copy_table, dict_copy_table = self.create_table(
             "copy_cat",
@@ -265,7 +268,7 @@ class TestDictImportExport(SupersetTestCase):
             cols_names=["new_col1", "col2", "col3"],
             metric_names=["new_metric1"],
         )
-        imported_copy_table = SqlaTable.import_from_dict(dict_copy_table)
+        imported_copy_table = SqlaTable.import_from_dict(db.session, dict_copy_table)
         db.session.commit()
         self.assertEqual(imported_table.id, imported_copy_table.id)
         self.assert_table_equals(copy_table, self.get_table_by_id(imported_table.id))
@@ -278,7 +281,10 @@ class TestDictImportExport(SupersetTestCase):
         self.delete_fake_db()
 
         cli_export = export_to_dict(
-            recursive=True, back_references=False, include_defaults=False,
+            session=db.session,
+            recursive=True,
+            back_references=False,
+            include_defaults=False,
         )
         self.get_resp("/login/", data=dict(username="admin", password="general"))
         resp = self.get_resp(
@@ -297,7 +303,7 @@ class TestDictImportExport(SupersetTestCase):
         datasource, dict_datasource = self.create_druid_datasource(
             "pure_druid", id=ID_PREFIX + 1
         )
-        imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
+        imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
         db.session.commit()
         imported = self.get_datasource(imported_cluster.id)
         self.assert_datasource_equals(datasource, imported)
@@ -309,7 +315,7 @@ class TestDictImportExport(SupersetTestCase):
             cols_names=["col1"],
             metric_names=["metric1"],
         )
-        imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
+        imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
         db.session.commit()
         imported = self.get_datasource(imported_cluster.id)
         self.assert_datasource_equals(datasource, imported)
@@ -325,7 +331,7 @@ class TestDictImportExport(SupersetTestCase):
             cols_names=["c1", "c2"],
             metric_names=["m1", "m2"],
         )
-        imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
+        imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
         db.session.commit()
         imported = self.get_datasource(imported_cluster.id)
         self.assert_datasource_equals(datasource, imported)
@@ -334,7 +340,7 @@ class TestDictImportExport(SupersetTestCase):
         datasource, dict_datasource = self.create_druid_datasource(
             "druid_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
         )
-        imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
+        imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
         db.session.commit()
         table_over, table_over_dict = self.create_druid_datasource(
             "druid_override",
@@ -342,7 +348,9 @@ class TestDictImportExport(SupersetTestCase):
             cols_names=["new_col1", "col2", "col3"],
             metric_names=["new_metric1"],
         )
-        imported_over_cluster = DruidDatasource.import_from_dict(table_over_dict)
+        imported_over_cluster = DruidDatasource.import_from_dict(
+            db.session, table_over_dict
+        )
         db.session.commit()
         imported_over = self.get_datasource(imported_over_cluster.id)
         self.assertEqual(imported_cluster.id, imported_over.id)
@@ -358,7 +366,7 @@ class TestDictImportExport(SupersetTestCase):
         datasource, dict_datasource = self.create_druid_datasource(
             "druid_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
         )
-        imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
+        imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
         db.session.commit()
         table_over, table_over_dict = self.create_druid_datasource(
             "druid_override",
@@ -367,7 +375,7 @@ class TestDictImportExport(SupersetTestCase):
             metric_names=["new_metric1"],
         )
         imported_over_cluster = DruidDatasource.import_from_dict(
-            dict_rep=table_over_dict, sync=["metrics", "columns"]
+            session=db.session, dict_rep=table_over_dict, sync=["metrics", "columns"]
         )  # syncing metrics and columns
         db.session.commit()
         imported_over = self.get_datasource(imported_over_cluster.id)
@@ -387,7 +395,9 @@ class TestDictImportExport(SupersetTestCase):
             cols_names=["new_col1", "col2", "col3"],
             metric_names=["new_metric1"],
         )
-        imported = DruidDatasource.import_from_dict(dict_rep=dict_datasource)
+        imported = DruidDatasource.import_from_dict(
+            session=db.session, dict_rep=dict_datasource
+        )
         db.session.commit()
         copy_datasource, dict_cp_datasource = self.create_druid_datasource(
             "copy_cat",
@@ -395,7 +405,7 @@ class TestDictImportExport(SupersetTestCase):
             cols_names=["new_col1", "col2", "col3"],
             metric_names=["new_metric1"],
         )
-        imported_copy = DruidDatasource.import_from_dict(dict_cp_datasource)
+        imported_copy = DruidDatasource.import_from_dict(db.session, dict_cp_datasource)
         db.session.commit()
 
         self.assertEqual(imported.id, imported_copy.id)
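
    Note on the hunks above: the revert restores the explicit session argument
    that #10427 had removed from the dict import/export helpers. A minimal,
    illustrative sketch of the restored calling convention follows; the keyword
    names (session, dict_rep, sync) are taken verbatim from the hunks, while
    the payload and variable names are placeholders:

        from superset import db
        from superset.connectors.sqla.models import SqlaTable
        from superset.utils.dict_import_export import export_to_dict

        # Export side: the session is passed explicitly again.
        exported = export_to_dict(
            session=db.session,
            recursive=True,
            back_references=False,
            include_defaults=False,
        )

        # Import side: positional session, with optional syncing of
        # metrics/columns when overriding an existing table.
        dict_rep = {}  # placeholder: a dict produced by a prior export
        imported = SqlaTable.import_from_dict(db.session, dict_rep)
        imported = SqlaTable.import_from_dict(
            session=db.session, dict_rep=dict_rep, sync=["metrics", "columns"]
        )
        db.session.commit()  # the tests commit explicitly after each import
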
diff --git a/tests/druid_tests.py b/tests/druid_tests.py
index c757671..648eb32 100644
--- a/tests/druid_tests.py
+++ b/tests/druid_tests.py
@@ -212,7 +212,9 @@ class TestDruid(SupersetTestCase):
     def test_druid_sync_from_config(self):
         CLUSTER_NAME = "new_druid"
         self.login()
-        cluster = self.get_or_create(DruidCluster, {"cluster_name": CLUSTER_NAME})
+        cluster = self.get_or_create(
+            DruidCluster, {"cluster_name": CLUSTER_NAME}, db.session
+        )
 
         db.session.merge(cluster)
         db.session.commit()
@@ -300,12 +302,15 @@ class TestDruid(SupersetTestCase):
     @unittest.skipUnless(app.config["DRUID_IS_ACTIVE"], "DRUID_IS_ACTIVE is false")
     def test_filter_druid_datasource(self):
         CLUSTER_NAME = "new_druid"
-        cluster = self.get_or_create(DruidCluster, {"cluster_name": CLUSTER_NAME})
+        cluster = self.get_or_create(
+            DruidCluster, {"cluster_name": CLUSTER_NAME}, db.session
+        )
         db.session.merge(cluster)
 
         gamma_ds = self.get_or_create(
             DruidDatasource,
             {"datasource_name": "datasource_for_gamma", "cluster": cluster},
+            db.session,
         )
         gamma_ds.cluster = cluster
         db.session.merge(gamma_ds)
@@ -313,6 +318,7 @@ class TestDruid(SupersetTestCase):
         no_gamma_ds = self.get_or_create(
             DruidDatasource,
             {"datasource_name": "datasource_not_for_gamma", "cluster": cluster},
+            db.session,
         )
         no_gamma_ds.cluster = cluster
         db.session.merge(no_gamma_ds)
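
    The Druid test hunks reintroduce a trailing session parameter on the
    get_or_create test helper (a SupersetTestCase method, hence the
    self.get_or_create calls above). A sketch of what a helper with that
    signature typically does; the body here is illustrative, not the suite's
    exact code:

        from superset import db
        from superset.connectors.druid.models import DruidCluster

        def get_or_create(model_cls, criteria, session):
            """Return the first row matching criteria, creating it if absent."""
            instance = session.query(model_cls).filter_by(**criteria).first()
            if not instance:
                instance = model_cls(**criteria)
                session.add(instance)
                session.commit()
            return instance

        # Usage, matching the calls in the hunks above:
        cluster = get_or_create(
            DruidCluster, {"cluster_name": "new_druid"}, db.session
        )
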
diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py
index fc6ee51..e772d16 100644
--- a/tests/import_export_tests.py
+++ b/tests/import_export_tests.py
@@ -46,19 +46,20 @@ class TestImportExport(SupersetTestCase):
     def delete_imports(cls):
         with app.app_context():
             # Imported data clean up
-            for slc in db.session.query(Slice):
+            session = db.session
+            for slc in session.query(Slice):
                 if "remote_id" in slc.params_dict:
-                    db.session.delete(slc)
-            for dash in db.session.query(Dashboard):
+                    session.delete(slc)
+            for dash in session.query(Dashboard):
                 if "remote_id" in dash.params_dict:
-                    db.session.delete(dash)
-            for table in db.session.query(SqlaTable):
+                    session.delete(dash)
+            for table in session.query(SqlaTable):
                 if "remote_id" in table.params_dict:
-                    db.session.delete(table)
-            for datasource in db.session.query(DruidDatasource):
+                    session.delete(table)
+            for datasource in session.query(DruidDatasource):
                 if "remote_id" in datasource.params_dict:
-                    db.session.delete(datasource)
-            db.session.commit()
+                    session.delete(datasource)
+            session.commit()
 
     @classmethod
     def setUpClass(cls):
@@ -125,7 +126,9 @@ class TestImportExport(SupersetTestCase):
 
     def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]):
         cluster_name = "druid_test"
-        cluster = self.get_or_create(DruidCluster, {"cluster_name": cluster_name})
+        cluster = self.get_or_create(
+            DruidCluster, {"cluster_name": cluster_name}, db.session
+        )
 
         params = {"remote_id": id, "database_name": cluster_name}
         datasource = DruidDatasource(
diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py
index e016b13..f816bcd 100644
--- a/tests/query_context_tests.py
+++ b/tests/query_context_tests.py
@@ -83,6 +83,7 @@ class TestQueryContext(SupersetTestCase):
         datasource = ConnectorRegistry.get_datasource(
             datasource_type=payload["datasource"]["type"],
             datasource_id=payload["datasource"]["id"],
+            session=db.session,
         )
         description_original = datasource.description
         datasource.description = "temporary description"
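
    Here ConnectorRegistry.get_datasource again receives the session it should
    query with instead of resolving one internally. A hedged usage sketch, with
    argument names as shown in the hunk and placeholder values elsewhere:

        from superset import db
        from superset.connectors.connector_registry import ConnectorRegistry

        datasource = ConnectorRegistry.get_datasource(
            datasource_type="table",  # "druid" for Druid datasources
            datasource_id=1,          # placeholder id
            session=db.session,
        )
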
diff --git a/tests/security_tests.py b/tests/security_tests.py
index a161ada..60d20fd 100644
--- a/tests/security_tests.py
+++ b/tests/security_tests.py
@@ -69,8 +69,9 @@ class TestRolePermission(SupersetTestCase):
     """Testing export role permissions."""
 
     def setUp(self):
+        session = db.session
         security_manager.add_role(SCHEMA_ACCESS_ROLE)
-        db.session.commit()
+        session.commit()
 
         ds = (
             db.session.query(SqlaTable)
@@ -81,7 +82,7 @@ class TestRolePermission(SupersetTestCase):
         ds.schema_perm = ds.get_schema_perm()
 
         ds_slices = (
-            db.session.query(Slice)
+            session.query(Slice)
             .filter_by(datasource_type="table")
             .filter_by(datasource_id=ds.id)
             .all()
@@ -91,11 +92,12 @@ class TestRolePermission(SupersetTestCase):
         create_schema_perm("[examples].[temp_schema]")
         gamma_user = security_manager.find_user(username="gamma")
         gamma_user.roles.append(security_manager.find_role(SCHEMA_ACCESS_ROLE))
-        db.session.commit()
+        session.commit()
 
     def tearDown(self):
+        session = db.session
         ds = (
-            db.session.query(SqlaTable)
+            session.query(SqlaTable)
             .filter_by(table_name="wb_health_population")
             .first()
         )
@@ -103,7 +105,7 @@ class TestRolePermission(SupersetTestCase):
         ds.schema = None
         ds.schema_perm = None
         ds_slices = (
-            db.session.query(Slice)
+            session.query(Slice)
             .filter_by(datasource_type="table")
             .filter_by(datasource_id=ds.id)
             .all()
@@ -112,20 +114,21 @@ class TestRolePermission(SupersetTestCase):
             s.schema_perm = None
 
         delete_schema_perm(schema_perm)
-        db.session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE))
-        db.session.commit()
+        session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE))
+        session.commit()
 
     def test_set_perm_sqla_table(self):
+        session = db.session
         table = SqlaTable(
             schema="tmp_schema",
             table_name="tmp_perm_table",
             database=get_example_database(),
         )
-        db.session.add(table)
-        db.session.commit()
+        session.add(table)
+        session.commit()
 
         stored_table = (
-            db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one()
+            session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one()
         )
         self.assertEqual(
             stored_table.perm, f"[examples].[tmp_perm_table](id:{stored_table.id})"
@@ -144,9 +147,9 @@ class TestRolePermission(SupersetTestCase):
 
         # table name change
         stored_table.table_name = "tmp_perm_table_v2"
-        db.session.commit()
+        session.commit()
         stored_table = (
-            db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
+            session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
         )
         self.assertEqual(
             stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})"
@@ -166,9 +169,9 @@ class TestRolePermission(SupersetTestCase):
 
         # schema name change
         stored_table.schema = "tmp_schema_v2"
-        db.session.commit()
+        session.commit()
         stored_table = (
-            db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
+            session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
         )
         self.assertEqual(
             stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})"
@@ -188,13 +191,13 @@ class TestRolePermission(SupersetTestCase):
 
         # database change
         new_db = Database(sqlalchemy_uri="some_uri", database_name="tmp_db")
-        db.session.add(new_db)
+        session.add(new_db)
         stored_table.database = (
-            db.session.query(Database).filter_by(database_name="tmp_db").one()
+            session.query(Database).filter_by(database_name="tmp_db").one()
         )
-        db.session.commit()
+        session.commit()
         stored_table = (
-            db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
+            session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
         )
         self.assertEqual(
             stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})"
@@ -214,9 +217,9 @@ class TestRolePermission(SupersetTestCase):
 
         # no schema
         stored_table.schema = None
-        db.session.commit()
+        session.commit()
         stored_table = (
-            db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
+            session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
         )
         self.assertEqual(
             stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})"
@@ -228,25 +231,26 @@ class TestRolePermission(SupersetTestCase):
         )
         self.assertIsNone(stored_table.schema_perm)
 
-        db.session.delete(new_db)
-        db.session.delete(stored_table)
-        db.session.commit()
+        session.delete(new_db)
+        session.delete(stored_table)
+        session.commit()
 
     def test_set_perm_druid_datasource(self):
+        session = db.session
         druid_cluster = (
-            db.session.query(DruidCluster).filter_by(cluster_name="druid_test").one()
+            session.query(DruidCluster).filter_by(cluster_name="druid_test").one()
         )
         datasource = DruidDatasource(
             datasource_name="tmp_datasource",
             cluster=druid_cluster,
             cluster_id=druid_cluster.id,
         )
-        db.session.add(datasource)
-        db.session.commit()
+        session.add(datasource)
+        session.commit()
 
         # store without a schema
         stored_datasource = (
-            db.session.query(DruidDatasource)
+            session.query(DruidDatasource)
             .filter_by(datasource_name="tmp_datasource")
             .one()
         )
@@ -263,7 +267,7 @@ class TestRolePermission(SupersetTestCase):
 
         # store with a schema
         stored_datasource.datasource_name = "tmp_schema.tmp_datasource"
-        db.session.commit()
+        session.commit()
         self.assertEqual(
             stored_datasource.perm,
             f"[druid_test].[tmp_schema.tmp_datasource](id:{stored_datasource.id})",
@@ -280,15 +284,16 @@ class TestRolePermission(SupersetTestCase):
             )
         )
 
-        db.session.delete(stored_datasource)
-        db.session.commit()
+        session.delete(stored_datasource)
+        session.commit()
 
     def test_set_perm_druid_cluster(self):
+        session = db.session
         cluster = DruidCluster(cluster_name="tmp_druid_cluster")
-        db.session.add(cluster)
+        session.add(cluster)
 
         stored_cluster = (
-            db.session.query(DruidCluster)
+            session.query(DruidCluster)
             .filter_by(cluster_name="tmp_druid_cluster")
             .one()
         )
@@ -302,7 +307,7 @@ class TestRolePermission(SupersetTestCase):
         )
 
         stored_cluster.cluster_name = "tmp_druid_cluster2"
-        db.session.commit()
+        session.commit()
         self.assertEqual(
             stored_cluster.perm, f"[tmp_druid_cluster2].(id:{stored_cluster.id})"
         )
@@ -312,17 +317,18 @@ class TestRolePermission(SupersetTestCase):
             )
         )
 
-        db.session.delete(stored_cluster)
-        db.session.commit()
+        session.delete(stored_cluster)
+        session.commit()
 
     def test_set_perm_database(self):
+        session = db.session
         database = Database(
             database_name="tmp_database", sqlalchemy_uri="sqlite://test"
         )
-        db.session.add(database)
+        session.add(database)
 
         stored_db = (
-            db.session.query(Database).filter_by(database_name="tmp_database").one()
+            session.query(Database).filter_by(database_name="tmp_database").one()
         )
         self.assertEqual(stored_db.perm, f"[tmp_database].(id:{stored_db.id})")
         self.assertIsNotNone(
@@ -332,9 +338,9 @@ class TestRolePermission(SupersetTestCase):
         )
 
         stored_db.database_name = "tmp_database2"
-        db.session.commit()
+        session.commit()
         stored_db = (
-            db.session.query(Database).filter_by(database_name="tmp_database2").one()
+            session.query(Database).filter_by(database_name="tmp_database2").one()
         )
         self.assertEqual(stored_db.perm, f"[tmp_database2].(id:{stored_db.id})")
         self.assertIsNotNone(
@@ -343,8 +349,8 @@ class TestRolePermission(SupersetTestCase):
             )
         )
 
-        db.session.delete(stored_db)
-        db.session.commit()
+        session.delete(stored_db)
+        session.commit()
 
     def test_hybrid_perm_druid_cluster(self):
         cluster = DruidCluster(cluster_name="tmp_druid_cluster3")
@@ -394,13 +400,14 @@ class TestRolePermission(SupersetTestCase):
         db.session.commit()
 
     def test_set_perm_slice(self):
+        session = db.session
         database = Database(
             database_name="tmp_database", sqlalchemy_uri="sqlite://test"
         )
         table = SqlaTable(table_name="tmp_perm_table", database=database)
-        db.session.add(database)
-        db.session.add(table)
-        db.session.commit()
+        session.add(database)
+        session.add(table)
+        session.commit()
 
         # no schema permission
         slice = Slice(
@@ -409,10 +416,10 @@ class TestRolePermission(SupersetTestCase):
             datasource_name="tmp_perm_table",
             slice_name="slice_name",
         )
-        db.session.add(slice)
-        db.session.commit()
+        session.add(slice)
+        session.commit()
 
-        slice = db.session.query(Slice).filter_by(slice_name="slice_name").one()
+        slice = session.query(Slice).filter_by(slice_name="slice_name").one()
         self.assertEqual(slice.perm, table.perm)
         self.assertEqual(slice.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})")
         self.assertEqual(slice.schema_perm, table.schema_perm)
@@ -420,7 +427,7 @@ class TestRolePermission(SupersetTestCase):
 
         table.schema = "tmp_perm_schema"
         table.table_name = "tmp_perm_table_v2"
-        db.session.commit()
+        session.commit()
         # TODO(bogdan): modify slice permissions on the table update.
         self.assertNotEquals(slice.perm, table.perm)
         self.assertEqual(slice.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})")
@@ -433,7 +440,7 @@ class TestRolePermission(SupersetTestCase):
 
         # updating slice refreshes the permissions
         slice.slice_name = "slice_name_v2"
-        db.session.commit()
+        session.commit()
         self.assertEqual(slice.perm, table.perm)
         self.assertEqual(
             slice.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})"
@@ -441,10 +448,11 @@ class TestRolePermission(SupersetTestCase):
         self.assertEqual(slice.schema_perm, table.schema_perm)
         self.assertEqual(slice.schema_perm, "[tmp_database].[tmp_perm_schema]")
 
-        db.session.delete(slice)
-        db.session.delete(table)
-        db.session.delete(database)
-        db.session.commit()
+        session.delete(slice)
+        session.delete(table)
+        session.delete(database)
+
+        session.commit()
 
         # TODO test slice permission
 
@@ -524,11 +532,11 @@ class TestRolePermission(SupersetTestCase):
         self.assertNotIn("Girl Name Cloud", data)  # birth_names slice, no access
 
     def test_sqllab_gamma_user_schema_access_to_sqllab(self):
-        example_db = (
-            db.session.query(Database).filter_by(database_name="examples").one()
-        )
+        session = db.session
+
+        example_db = session.query(Database).filter_by(database_name="examples").one()
         example_db.expose_in_sqllab = True
-        db.session.commit()
+        session.commit()
 
         arguments = {
             "keys": ["none"],
@@ -951,10 +959,12 @@ class TestRowLevelSecurity(SupersetTestCase):
     rls_entry = None
 
     def setUp(self):
+        session = db.session
+
         # Create the RowLevelSecurityFilter
         self.rls_entry = RowLevelSecurityFilter()
         self.rls_entry.tables.extend(
-            db.session.query(SqlaTable)
+            session.query(SqlaTable)
             .filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"]))
             .all()
         )
@@ -964,11 +974,13 @@ class TestRowLevelSecurity(SupersetTestCase):
         )  # db.session.query(Role).filter_by(name="Gamma").first())
         self.rls_entry.roles.append(security_manager.find_role("Alpha"))
         db.session.add(self.rls_entry)
+
         db.session.commit()
 
     def tearDown(self):
-        db.session.delete(self.rls_entry)
-        db.session.commit()
+        session = db.session
+        session.delete(self.rls_entry)
+        session.commit()
 
     # Do another test to make sure it doesn't alter another query
     def test_rls_filter_alters_query(self):
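
    Most of the churn in these security tests is one mechanical pattern: each
    setUp/tearDown/test method binds the scoped session to a local name once
    and reuses that alias for every add/delete/commit. In sketch form, with
    illustrative names:

        def test_something(self):
            session = db.session  # bind the scoped session once
            session.add(obj)      # obj stands in for any model instance
            session.commit()
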
diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py
index 9315d39..bff8d9d 100644
--- a/tests/sqllab_tests.py
+++ b/tests/sqllab_tests.py
@@ -63,6 +63,7 @@ class TestSqlLab(SupersetTestCase):
         self.logout()
         db.session.query(Query).delete()
         db.session.commit()
+        db.session.close()
 
     def test_sql_json(self):
         self.login("admin")
@@ -459,6 +460,7 @@ class TestSqlLab(SupersetTestCase):
         Test query api with can_access_all_queries perm added to
         gamma and make sure all queries show up.
         """
+        session = db.session
 
         # Add all_query_access perm to Gamma user
         all_queries_view = security_manager.find_permission_view_menu(
@@ -468,7 +470,7 @@ class TestSqlLab(SupersetTestCase):
         security_manager.add_permission_role(
             security_manager.find_role("gamma_sqllab"), all_queries_view
         )
-        db.session.commit()
+        session.commit()
 
         # Test search_queries for Admin user
         self.run_some_queries()
@@ -485,7 +487,7 @@ class TestSqlLab(SupersetTestCase):
             security_manager.find_role("gamma_sqllab"), all_queries_view
         )
 
-        db.session.commit()
+        session.commit()
 
     def test_query_admin_can_access_all_queries(self) -> None:
         """
diff --git a/tests/strategy_tests.py b/tests/strategy_tests.py
index 49e2349..c4f0019 100644
--- a/tests/strategy_tests.py
+++ b/tests/strategy_tests.py
@@ -194,7 +194,7 @@ class TestCacheWarmUp(SupersetTestCase):
             db.session.commit()
 
     def test_dashboard_tags(self):
-        tag1 = get_tag("tag1", TagTypes.custom)
+        tag1 = get_tag("tag1", db.session, TagTypes.custom)
         # delete first to make test idempotent
         self.reset_tag(tag1)
 
@@ -204,7 +204,7 @@ class TestCacheWarmUp(SupersetTestCase):
         self.assertEqual(result, expected)
 
         # tag dashboard 'births' with `tag1`
-        tag1 = get_tag("tag1", TagTypes.custom)
+        tag1 = get_tag("tag1", db.session, TagTypes.custom)
         dash = self.get_dash_by_slug("births")
         tag1_urls = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices])
         tagged_object = TaggedObject(
@@ -216,7 +216,7 @@ class TestCacheWarmUp(SupersetTestCase):
         self.assertEqual(sorted(strategy.get_urls()), tag1_urls)
 
         strategy = DashboardTagsStrategy(["tag2"])
-        tag2 = get_tag("tag2", TagTypes.custom)
+        tag2 = get_tag("tag2", db.session, TagTypes.custom)
         self.reset_tag(tag2)
 
         result = sorted(strategy.get_urls())