Posted to commits@superset.apache.org by jo...@apache.org on 2024/02/13 17:20:23 UTC

(superset) branch master updated: refactor: Ensure Flask framework leverages the Flask-SQLAlchemy session (Phase II) (#26909)

This is an automated email from the ASF dual-hosted git repository.

johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 847ed3f5b0 refactor: Ensure Flask framework leverages the Flask-SQLAlchemy session (Phase II) (#26909)
847ed3f5b0 is described below

commit 847ed3f5b0016abed968abfacb6b9980e74dc1bf
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Wed Feb 14 06:20:15 2024 +1300

    refactor: Ensure Flask framework leverages the Flask-SQLAlchemy session (Phase II) (#26909)
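
    In practice, importer helpers stop threading an explicit `session: Session`
    argument through every call and instead reach for the Flask-SQLAlchemy
    scoped session (`db.session`) directly. A minimal sketch of the pattern,
    abridged from the chart-import hunks below (bodies elided with `...`):

        # Before: each helper took the session as its first argument.
        from typing import Any
        from sqlalchemy.orm import Session
        from superset.models.slice import Slice

        def import_chart(
            session: Session, config: dict[str, Any], overwrite: bool = False
        ) -> Slice:
            existing = session.query(Slice).filter_by(uuid=config["uuid"]).first()
            ...

        # After: each helper uses the scoped session itself.
        from superset import db

        def import_chart(config: dict[str, Any], overwrite: bool = False) -> Slice:
            existing = db.session.query(Slice).filter_by(uuid=config["uuid"]).first()
            ...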
---
 .pylintrc                                          |  2 +-
 superset/cli/importexport.py                       |  3 +-
 superset/commands/chart/importers/v1/__init__.py   | 10 ++-
 superset/commands/chart/importers/v1/utils.py      | 13 ++--
 .../commands/dashboard/importers/v1/__init__.py    | 19 +++--
 superset/commands/dashboard/importers/v1/utils.py  | 11 ++-
 .../commands/database/importers/v1/__init__.py     |  8 +--
 superset/commands/database/importers/v1/utils.py   | 13 ++--
 superset/commands/dataset/importers/v0.py          | 57 +++++++--------
 superset/commands/dataset/importers/v1/__init__.py |  8 +--
 superset/commands/dataset/importers/v1/utils.py    | 20 +++---
 superset/commands/importers/v1/__init__.py         |  6 +-
 superset/commands/importers/v1/assets.py           | 21 +++---
 superset/commands/importers/v1/examples.py         | 11 +--
 superset/commands/query/importers/v1/__init__.py   |  8 +--
 superset/commands/query/importers/v1/utils.py      | 13 ++--
 superset/connectors/sqla/models.py                 | 15 ++--
 superset/connectors/sqla/utils.py                  |  6 +-
 superset/daos/base.py                              | 13 ++--
 superset/databases/filters.py                      |  4 +-
 superset/db_engine_specs/gsheets.py                |  6 +-
 superset/extensions/__init__.py                    |  2 +-
 superset/models/dashboard.py                       |  2 +-
 superset/models/helpers.py                         | 15 ++--
 superset/security/manager.py                       | 15 ++--
 superset/sqllab/schemas.py                         |  2 +-
 superset/tables/models.py                          |  2 +-
 superset/tags/models.py                            | 24 ++++---
 superset/utils/dashboard_import_export.py          |  7 +-
 superset/utils/dict_import_export.py               |  7 +-
 tests/integration_tests/base_tests.py              | 30 ++++----
 tests/integration_tests/cache_tests.py             |  4 +-
 tests/integration_tests/charts/api_tests.py        | 10 +--
 tests/integration_tests/charts/commands_tests.py   |  2 +-
 tests/integration_tests/core_tests.py              | 13 ++--
 tests/integration_tests/dashboards/api_tests.py    |  4 +-
 .../dashboards/filter_state/api_tests.py           |  7 +-
 .../dashboards/permalink/api_tests.py              |  6 +-
 .../dashboards/superset_factory_util.py            | 58 ++++++++--------
 tests/integration_tests/databases/api_tests.py     | 10 ++-
 tests/integration_tests/datasets/api_tests.py      |  1 -
 tests/integration_tests/datasource_tests.py        | 36 +++++-----
 .../db_engine_specs/databricks_tests.py            | 16 ++---
 .../db_engine_specs/hive_tests.py                  | 14 ++--
 .../db_engine_specs/postgres_tests.py              | 24 +++----
 .../db_engine_specs/presto_tests.py                | 14 ++--
 .../integration_tests/dict_import_export_tests.py  | 26 ++++---
 tests/integration_tests/explore/api_tests.py       | 10 ++-
 .../explore/form_data/api_tests.py                 | 10 ++-
 .../explore/form_data/commands_tests.py            | 32 ++++-----
 .../explore/permalink/api_tests.py                 |  3 +-
 .../explore/permalink/commands_tests.py            | 32 ++++-----
 tests/integration_tests/fixtures/datasource.py     | 13 ++--
 tests/integration_tests/import_export_tests.py     | 15 ++--
 .../key_value/commands/fixtures.py                 |  3 +-
 .../security/guest_token_security_tests.py         | 15 ++--
 .../security/migrate_roles_tests.py                |  5 +-
 .../security/row_level_security_tests.py           | 29 ++++----
 tests/integration_tests/security_tests.py          |  4 +-
 tests/integration_tests/sqllab_tests.py            |  6 +-
 tests/integration_tests/test_jinja_context.py      | 38 +++++-----
 tests/integration_tests/utils/get_dashboards.py    |  5 +-
 tests/integration_tests/utils_tests.py             |  4 +-
 .../charts/commands/importers/v1/import_test.py    | 12 ++--
 tests/unit_tests/charts/dao/dao_tests.py           | 16 ++---
 tests/unit_tests/charts/test_post_processing.py    |  7 +-
 tests/unit_tests/columns/test_models.py            |  7 +-
 .../commands/importers/v1/assets_test.py           | 34 ++++-----
 tests/unit_tests/config_test.py                    |  4 +-
 tests/unit_tests/conftest.py                       |  4 +-
 tests/unit_tests/dao/dataset_test.py               |  5 +-
 tests/unit_tests/dao/queries_test.py               | 80 ++++++++++++----------
 tests/unit_tests/dao/tag_test.py                   |  2 +-
 .../commands/importers/v1/import_test.py           | 14 ++--
 tests/unit_tests/dashboards/dao_tests.py           | 12 ++--
 tests/unit_tests/databases/api_test.py             | 46 +++++++------
 .../databases/commands/importers/v1/import_test.py | 27 ++++----
 tests/unit_tests/databases/dao/dao_tests.py        | 10 +--
 .../databases/ssh_tunnel/commands/create_test.py   | 12 ++--
 .../databases/ssh_tunnel/commands/delete_test.py   | 10 +--
 .../databases/ssh_tunnel/commands/update_test.py   | 10 +--
 tests/unit_tests/databases/ssh_tunnel/dao_tests.py |  4 +-
 tests/unit_tests/datasets/api_tests.py             |  8 ++-
 tests/unit_tests/datasets/commands/export_test.py  |  8 ++-
 .../datasets/commands/importers/v1/import_test.py  | 58 ++++++++--------
 tests/unit_tests/datasets/dao/dao_tests.py         | 10 +--
 tests/unit_tests/datasource/dao_tests.py           | 14 ++--
 tests/unit_tests/db_engine_specs/test_druid.py     | 16 ++---
 tests/unit_tests/db_engine_specs/test_pinot.py     |  8 +--
 tests/unit_tests/extensions/test_sqlalchemy.py     | 27 ++++----
 tests/unit_tests/queries/dao_test.py               |  4 +-
 tests/unit_tests/sql_lab_test.py                   | 10 +--
 tests/unit_tests/sql_parse_tests.py                |  3 +-
 tests/unit_tests/tables/test_models.py             | 10 +--
 tests/unit_tests/tags/commands/create_test.py      | 29 ++++----
 tests/unit_tests/tags/commands/update_test.py      | 23 ++++---
 96 files changed, 656 insertions(+), 730 deletions(-)

diff --git a/.pylintrc b/.pylintrc
index 6083060624..1cab7a587a 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -108,7 +108,7 @@ evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / stateme
 good-names=_,df,ex,f,i,id,j,k,l,o,pk,Run,ts,v,x,y
 
 # Bad variable names which should always be refused, separated by a comma
-bad-names=fd,foo,bar,baz,toto,tutu,tata
+bad-names=bar,baz,db,fd,foo,sesh,session,tata,toto,tutu
 
 # Colon-delimited sets of names that determine each other's naming style when
 # the name regexes allow several styles.
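
Note on the .pylintrc hunk above: with `db`, `session`, and `sesh` added to
`bad-names`, pylint reports C0104 (`disallowed-name`) wherever a variable is
bound to one of those names. The few places below that must keep a local
session (e.g. SQLAlchemy event listeners) therefore carry an inline
suppression. Illustrative only; `target` stands for any mapped instance:

    from sqlalchemy import inspect

    session = inspect(target).session                # flagged: C0104 disallowed-name
    session = inspect(target).session  # pylint: disable=disallowed-name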
diff --git a/superset/cli/importexport.py b/superset/cli/importexport.py
index fc6a9ad3c4..ebf94b444a 100755
--- a/superset/cli/importexport.py
+++ b/superset/cli/importexport.py
@@ -214,7 +214,7 @@ def legacy_export_dashboards(
     # pylint: disable=import-outside-toplevel
     from superset.utils import dashboard_import_export
 
-    data = dashboard_import_export.export_dashboards(db.session)
+    data = dashboard_import_export.export_dashboards()
     if print_stdout or not dashboard_file:
         print(data)
     if dashboard_file:
@@ -263,7 +263,6 @@ def legacy_export_datasources(
     from superset.utils import dict_import_export
 
     data = dict_import_export.export_to_dict(
-        session=db.session,
         recursive=True,
         back_references=back_references,
         include_defaults=include_defaults,
diff --git a/superset/commands/chart/importers/v1/__init__.py b/superset/commands/chart/importers/v1/__init__.py
index f99fbb9008..7f2537383f 100644
--- a/superset/commands/chart/importers/v1/__init__.py
+++ b/superset/commands/chart/importers/v1/__init__.py
@@ -47,9 +47,7 @@ class ImportChartsCommand(ImportModelsCommand):
     import_error = ChartImportError
 
     @staticmethod
-    def _import(
-        session: Session, configs: dict[str, Any], overwrite: bool = False
-    ) -> None:
+    def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
         # discover datasets associated with charts
         dataset_uuids: set[str] = set()
         for file_name, config in configs.items():
@@ -66,7 +64,7 @@ class ImportChartsCommand(ImportModelsCommand):
         database_ids: dict[str, int] = {}
         for file_name, config in configs.items():
             if file_name.startswith("databases/") and config["uuid"] in database_uuids:
-                database = import_database(session, config, overwrite=False)
+                database = import_database(config, overwrite=False)
                 database_ids[str(database.uuid)] = database.id
 
         # import datasets with the correct parent ref
@@ -77,7 +75,7 @@ class ImportChartsCommand(ImportModelsCommand):
                 and config["database_uuid"] in database_ids
             ):
                 config["database_id"] = database_ids[config["database_uuid"]]
-                dataset = import_dataset(session, config, overwrite=False)
+                dataset = import_dataset(config, overwrite=False)
                 datasets[str(dataset.uuid)] = dataset
 
         # import charts with the correct parent ref
@@ -101,4 +99,4 @@ class ImportChartsCommand(ImportModelsCommand):
                 if "query_context" in config:
                     config["query_context"] = None
 
-                import_chart(session, config, overwrite=overwrite)
+                import_chart(config, overwrite=overwrite)
diff --git a/superset/commands/chart/importers/v1/utils.py b/superset/commands/chart/importers/v1/utils.py
index 2aac3ea9c4..f1b38e7ddc 100644
--- a/superset/commands/chart/importers/v1/utils.py
+++ b/superset/commands/chart/importers/v1/utils.py
@@ -20,9 +20,7 @@ import json
 from inspect import isclass
 from typing import Any
 
-from sqlalchemy.orm import Session
-
-from superset import security_manager
+from superset import db, security_manager
 from superset.commands.exceptions import ImportFailedError
 from superset.migrations.shared.migrate_viz import processors
 from superset.migrations.shared.migrate_viz.base import MigrateViz
@@ -46,13 +44,12 @@ def filter_chart_annotations(chart_config: dict[str, Any]) -> None:
 
 
 def import_chart(
-    session: Session,
     config: dict[str, Any],
     overwrite: bool = False,
     ignore_permissions: bool = False,
 ) -> Slice:
     can_write = ignore_permissions or security_manager.can_access("can_write", "Chart")
-    existing = session.query(Slice).filter_by(uuid=config["uuid"]).first()
+    existing = db.session.query(Slice).filter_by(uuid=config["uuid"]).first()
     if existing:
         if overwrite and can_write and get_user():
             if not security_manager.can_access_chart(existing):
@@ -76,11 +73,9 @@ def import_chart(
     # migrate old viz types to new ones
     config = migrate_chart(config)
 
-    chart = Slice.import_from_dict(
-        session, config, recursive=False, allow_reparenting=True
-    )
+    chart = Slice.import_from_dict(config, recursive=False, allow_reparenting=True)
     if chart.id is None:
-        session.flush()
+        db.session.flush()
 
     if user := get_user():
         chart.owners.append(user)
diff --git a/superset/commands/dashboard/importers/v1/__init__.py b/superset/commands/dashboard/importers/v1/__init__.py
index 62f5f393e9..77d28696cf 100644
--- a/superset/commands/dashboard/importers/v1/__init__.py
+++ b/superset/commands/dashboard/importers/v1/__init__.py
@@ -21,6 +21,7 @@ from marshmallow import Schema
 from sqlalchemy.orm import Session
 from sqlalchemy.sql import select
 
+from superset import db
 from superset.charts.schemas import ImportV1ChartSchema
 from superset.commands.chart.importers.v1.utils import import_chart
 from superset.commands.dashboard.exceptions import DashboardImportError
@@ -59,9 +60,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
     # TODO (betodealmeida): refactor to use code from other commands
     # pylint: disable=too-many-branches, too-many-locals
     @staticmethod
-    def _import(
-        session: Session, configs: dict[str, Any], overwrite: bool = False
-    ) -> None:
+    def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
         # discover charts and datasets associated with dashboards
         chart_uuids: set[str] = set()
         dataset_uuids: set[str] = set()
@@ -87,7 +86,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
         database_ids: dict[str, int] = {}
         for file_name, config in configs.items():
             if file_name.startswith("databases/") and config["uuid"] in database_uuids:
-                database = import_database(session, config, overwrite=False)
+                database = import_database(config, overwrite=False)
                 database_ids[str(database.uuid)] = database.id
 
         # import datasets with the correct parent ref
@@ -98,7 +97,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
                 and config["database_uuid"] in database_ids
             ):
                 config["database_id"] = database_ids[config["database_uuid"]]
-                dataset = import_dataset(session, config, overwrite=False)
+                dataset = import_dataset(config, overwrite=False)
                 dataset_info[str(dataset.uuid)] = {
                     "datasource_id": dataset.id,
                     "datasource_type": dataset.datasource_type,
@@ -122,12 +121,12 @@ class ImportDashboardsCommand(ImportModelsCommand):
                 if "query_context" in config:
                     config["query_context"] = None
 
-                chart = import_chart(session, config, overwrite=False)
+                chart = import_chart(config, overwrite=False)
                 charts.append(chart)
                 chart_ids[str(chart.uuid)] = chart.id
 
         # store the existing relationship between dashboards and charts
-        existing_relationships = session.execute(
+        existing_relationships = db.session.execute(
             select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id])
         ).fetchall()
 
@@ -137,7 +136,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
         for file_name, config in configs.items():
             if file_name.startswith("dashboards/"):
                 config = update_id_refs(config, chart_ids, dataset_info)
-                dashboard = import_dashboard(session, config, overwrite=overwrite)
+                dashboard = import_dashboard(config, overwrite=overwrite)
                 dashboards.append(dashboard)
                 for uuid in find_chart_uuids(config["position"]):
                     if uuid not in chart_ids:
@@ -151,7 +150,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
             {"dashboard_id": dashboard_id, "slice_id": chart_id}
             for (dashboard_id, chart_id) in dashboard_chart_ids
         ]
-        session.execute(dashboard_slices.insert(), values)
+        db.session.execute(dashboard_slices.insert(), values)
 
         # Migrate any filter-box charts to native dashboard filters.
         for dashboard in dashboards:
@@ -160,4 +159,4 @@ class ImportDashboardsCommand(ImportModelsCommand):
         # Remove all obsolete filter-box charts.
         for chart in charts:
             if chart.viz_type == "filter_box":
-                session.delete(chart)
+                db.session.delete(chart)
diff --git a/superset/commands/dashboard/importers/v1/utils.py b/superset/commands/dashboard/importers/v1/utils.py
index b8ac3144db..09be75a6ea 100644
--- a/superset/commands/dashboard/importers/v1/utils.py
+++ b/superset/commands/dashboard/importers/v1/utils.py
@@ -19,9 +19,7 @@ import json
 import logging
 from typing import Any
 
-from sqlalchemy.orm import Session
-
-from superset import security_manager
+from superset import db, security_manager
 from superset.commands.exceptions import ImportFailedError
 from superset.models.dashboard import Dashboard
 from superset.utils.core import get_user
@@ -146,7 +144,6 @@ def update_id_refs(  # pylint: disable=too-many-locals
 
 
 def import_dashboard(
-    session: Session,
     config: dict[str, Any],
     overwrite: bool = False,
     ignore_permissions: bool = False,
@@ -155,7 +152,7 @@ def import_dashboard(
         "can_write",
         "Dashboard",
     )
-    existing = session.query(Dashboard).filter_by(uuid=config["uuid"]).first()
+    existing = db.session.query(Dashboard).filter_by(uuid=config["uuid"]).first()
     if existing:
         if overwrite and can_write and get_user():
             if not security_manager.can_access_dashboard(existing):
@@ -187,9 +184,9 @@ def import_dashboard(
             except TypeError:
                 logger.info("Unable to encode `%s` field: %s", key, value)
 
-    dashboard = Dashboard.import_from_dict(session, config, recursive=False)
+    dashboard = Dashboard.import_from_dict(config, recursive=False)
     if dashboard.id is None:
-        session.flush()
+        db.session.flush()
 
     if user := get_user():
         dashboard.owners.append(user)
diff --git a/superset/commands/database/importers/v1/__init__.py b/superset/commands/database/importers/v1/__init__.py
index 73b1bca531..203f0e3089 100644
--- a/superset/commands/database/importers/v1/__init__.py
+++ b/superset/commands/database/importers/v1/__init__.py
@@ -43,14 +43,12 @@ class ImportDatabasesCommand(ImportModelsCommand):
     import_error = DatabaseImportError
 
     @staticmethod
-    def _import(
-        session: Session, configs: dict[str, Any], overwrite: bool = False
-    ) -> None:
+    def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
         # first import databases
         database_ids: dict[str, int] = {}
         for file_name, config in configs.items():
             if file_name.startswith("databases/"):
-                database = import_database(session, config, overwrite=overwrite)
+                database = import_database(config, overwrite=overwrite)
                 database_ids[str(database.uuid)] = database.id
 
         # import related datasets
@@ -61,4 +59,4 @@ class ImportDatabasesCommand(ImportModelsCommand):
             ):
                 config["database_id"] = database_ids[config["database_uuid"]]
                 # overwrite=False prevents deleting any non-imported columns/metrics
-                import_dataset(session, config, overwrite=False)
+                import_dataset(config, overwrite=False)
diff --git a/superset/commands/database/importers/v1/utils.py b/superset/commands/database/importers/v1/utils.py
index c8c2847b9f..17b8488b44 100644
--- a/superset/commands/database/importers/v1/utils.py
+++ b/superset/commands/database/importers/v1/utils.py
@@ -18,9 +18,7 @@
 import json
 from typing import Any
 
-from sqlalchemy.orm import Session
-
-from superset import app, security_manager
+from superset import app, db, security_manager
 from superset.commands.exceptions import ImportFailedError
 from superset.databases.ssh_tunnel.models import SSHTunnel
 from superset.databases.utils import make_url_safe
@@ -30,7 +28,6 @@ from superset.security.analytics_db_safety import check_sqlalchemy_uri
 
 
 def import_database(
-    session: Session,
     config: dict[str, Any],
     overwrite: bool = False,
     ignore_permissions: bool = False,
@@ -39,7 +36,7 @@ def import_database(
         "can_write",
         "Database",
     )
-    existing = session.query(Database).filter_by(uuid=config["uuid"]).first()
+    existing = db.session.query(Database).filter_by(uuid=config["uuid"]).first()
     if existing:
         if not overwrite or not can_write:
             return existing
@@ -67,12 +64,12 @@ def import_database(
     # Before it gets removed in import_from_dict
     ssh_tunnel = config.pop("ssh_tunnel", None)
 
-    database = Database.import_from_dict(session, config, recursive=False)
+    database = Database.import_from_dict(config, recursive=False)
     if database.id is None:
-        session.flush()
+        db.session.flush()
 
     if ssh_tunnel:
         ssh_tunnel["database_id"] = database.id
-        SSHTunnel.import_from_dict(session, ssh_tunnel, recursive=False)
+        SSHTunnel.import_from_dict(ssh_tunnel, recursive=False)
 
     return database
diff --git a/superset/commands/dataset/importers/v0.py b/superset/commands/dataset/importers/v0.py
index d389a17651..6c1d79779e 100644
--- a/superset/commands/dataset/importers/v0.py
+++ b/superset/commands/dataset/importers/v0.py
@@ -20,7 +20,6 @@ from typing import Any, Callable, Optional
 
 import yaml
 from flask_appbuilder import Model
-from sqlalchemy.orm import Session
 from sqlalchemy.orm.session import make_transient
 
 from superset import db
@@ -86,7 +85,6 @@ def import_dataset(
         raise DatasetInvalidError
 
     return import_datasource(
-        db.session,
         i_datasource,
         lookup_database,
         lookup_datasource,
@@ -95,9 +93,9 @@ def import_dataset(
     )
 
 
-def lookup_sqla_metric(session: Session, metric: SqlMetric) -> SqlMetric:
+def lookup_sqla_metric(metric: SqlMetric) -> SqlMetric:
     return (
-        session.query(SqlMetric)
+        db.session.query(SqlMetric)
         .filter(
             SqlMetric.table_id == metric.table_id,
             SqlMetric.metric_name == metric.metric_name,
@@ -106,13 +104,13 @@ def lookup_sqla_metric(session: Session, metric: SqlMetric) -> SqlMetric:
     )
 
 
-def import_metric(session: Session, metric: SqlMetric) -> SqlMetric:
-    return import_simple_obj(session, metric, lookup_sqla_metric)
+def import_metric(metric: SqlMetric) -> SqlMetric:
+    return import_simple_obj(metric, lookup_sqla_metric)
 
 
-def lookup_sqla_column(session: Session, column: TableColumn) -> TableColumn:
+def lookup_sqla_column(column: TableColumn) -> TableColumn:
     return (
-        session.query(TableColumn)
+        db.session.query(TableColumn)
         .filter(
             TableColumn.table_id == column.table_id,
             TableColumn.column_name == column.column_name,
@@ -121,12 +119,11 @@ def lookup_sqla_column(session: Session, column: TableColumn) -> TableColumn:
     )
 
 
-def import_column(session: Session, column: TableColumn) -> TableColumn:
-    return import_simple_obj(session, column, lookup_sqla_column)
+def import_column(column: TableColumn) -> TableColumn:
+    return import_simple_obj(column, lookup_sqla_column)
 
 
-def import_datasource(  # pylint: disable=too-many-arguments
-    session: Session,
+def import_datasource(
     i_datasource: Model,
     lookup_database: Callable[[Model], Optional[Model]],
     lookup_datasource: Callable[[Model], Optional[Model]],
@@ -155,11 +152,11 @@ def import_datasource(  # pylint: disable=too-many-arguments
 
     if datasource:
         datasource.override(i_datasource)
-        session.flush()
+        db.session.flush()
     else:
         datasource = i_datasource.copy()
-        session.add(datasource)
-        session.flush()
+        db.session.add(datasource)
+        db.session.flush()
 
     for metric in i_datasource.metrics:
         new_m = metric.copy()
@@ -169,7 +166,7 @@ def import_datasource(  # pylint: disable=too-many-arguments
             new_m.to_json(),
             i_datasource.full_name,
         )
-        imported_m = import_metric(session, new_m)
+        imported_m = import_metric(new_m)
         if imported_m.metric_name not in [m.metric_name for m in datasource.metrics]:
             datasource.metrics.append(imported_m)
 
@@ -181,44 +178,40 @@ def import_datasource(  # pylint: disable=too-many-arguments
             new_c.to_json(),
             i_datasource.full_name,
         )
-        imported_c = import_column(session, new_c)
+        imported_c = import_column(new_c)
         if imported_c.column_name not in [c.column_name for c in datasource.columns]:
             datasource.columns.append(imported_c)
-    session.flush()
+    db.session.flush()
     return datasource.id
 
 
-def import_simple_obj(
-    session: Session, i_obj: Model, lookup_obj: Callable[[Session, Model], Model]
-) -> Model:
+def import_simple_obj(i_obj: Model, lookup_obj: Callable[[Model], Model]) -> Model:
     make_transient(i_obj)
     i_obj.id = None
     i_obj.table = None
 
     # find if the column was already imported
-    existing_column = lookup_obj(session, i_obj)
+    existing_column = lookup_obj(i_obj)
     i_obj.table = None
     if existing_column:
         existing_column.override(i_obj)
-        session.flush()
+        db.session.flush()
         return existing_column
 
-    session.add(i_obj)
-    session.flush()
+    db.session.add(i_obj)
+    db.session.flush()
     return i_obj
 
 
-def import_from_dict(
-    session: Session, data: dict[str, Any], sync: Optional[list[str]] = None
-) -> None:
+def import_from_dict(data: dict[str, Any], sync: Optional[list[str]] = None) -> None:
     """Imports databases from dictionary"""
     if not sync:
         sync = []
     if isinstance(data, dict):
         logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
         for database in data.get(DATABASES_KEY, []):
-            Database.import_from_dict(session, database, sync=sync)
-        session.commit()
+            Database.import_from_dict(database, sync=sync)
+        db.session.commit()
     else:
         logger.info("Supplied object is not a dictionary.")
 
@@ -254,7 +247,7 @@ class ImportDatasetsCommand(BaseCommand):
         for file_name, config in self._configs.items():
             logger.info("Importing dataset from file %s", file_name)
             if isinstance(config, dict):
-                import_from_dict(db.session, config, sync=self.sync)
+                import_from_dict(config, sync=self.sync)
             else:  # list
                 for dataset in config:
                     # UI exports don't have the database metadata, so we assume
@@ -266,7 +259,7 @@ class ImportDatasetsCommand(BaseCommand):
                         .one()
                     )
                     dataset["database_id"] = database.id
-                    SqlaTable.import_from_dict(db.session, dataset, sync=self.sync)
+                    SqlaTable.import_from_dict(dataset, sync=self.sync)
 
     def validate(self) -> None:
         # ensure all files are YAML
diff --git a/superset/commands/dataset/importers/v1/__init__.py b/superset/commands/dataset/importers/v1/__init__.py
index 600a39bf48..29f850258c 100644
--- a/superset/commands/dataset/importers/v1/__init__.py
+++ b/superset/commands/dataset/importers/v1/__init__.py
@@ -43,9 +43,7 @@ class ImportDatasetsCommand(ImportModelsCommand):
     import_error = DatasetImportError
 
     @staticmethod
-    def _import(
-        session: Session, configs: dict[str, Any], overwrite: bool = False
-    ) -> None:
+    def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
         # discover databases associated with datasets
         database_uuids: set[str] = set()
         for file_name, config in configs.items():
@@ -56,7 +54,7 @@ class ImportDatasetsCommand(ImportModelsCommand):
         database_ids: dict[str, int] = {}
         for file_name, config in configs.items():
             if file_name.startswith("databases/") and config["uuid"] in database_uuids:
-                database = import_database(session, config, overwrite=False)
+                database = import_database(config, overwrite=False)
                 database_ids[str(database.uuid)] = database.id
 
         # import datasets with the correct parent ref
@@ -66,4 +64,4 @@ class ImportDatasetsCommand(ImportModelsCommand):
                 and config["database_uuid"] in database_ids
             ):
                 config["database_id"] = database_ids[config["database_uuid"]]
-                import_dataset(session, config, overwrite=overwrite)
+                import_dataset(config, overwrite=overwrite)
diff --git a/superset/commands/dataset/importers/v1/utils.py b/superset/commands/dataset/importers/v1/utils.py
index 014a864da4..04fc81e241 100644
--- a/superset/commands/dataset/importers/v1/utils.py
+++ b/superset/commands/dataset/importers/v1/utils.py
@@ -25,10 +25,9 @@ import pandas as pd
 from flask import current_app
 from sqlalchemy import BigInteger, Boolean, Date, DateTime, Float, String, Text
 from sqlalchemy.exc import MultipleResultsFound
-from sqlalchemy.orm import Session
 from sqlalchemy.sql.visitors import VisitableType
 
-from superset import security_manager
+from superset import db, security_manager
 from superset.commands.dataset.exceptions import DatasetForbiddenDataURI
 from superset.commands.exceptions import ImportFailedError
 from superset.connectors.sqla.models import SqlaTable
@@ -103,7 +102,6 @@ def validate_data_uri(data_uri: str) -> None:
 
 
 def import_dataset(
-    session: Session,
     config: dict[str, Any],
     overwrite: bool = False,
     force_data: bool = False,
@@ -113,7 +111,7 @@ def import_dataset(
         "can_write",
         "Dataset",
     )
-    existing = session.query(SqlaTable).filter_by(uuid=config["uuid"]).first()
+    existing = db.session.query(SqlaTable).filter_by(uuid=config["uuid"]).first()
     if existing:
         if not overwrite or not can_write:
             return existing
@@ -150,7 +148,7 @@ def import_dataset(
 
     # import recursively to include columns and metrics
     try:
-        dataset = SqlaTable.import_from_dict(session, config, recursive=True, sync=sync)
+        dataset = SqlaTable.import_from_dict(config, recursive=True, sync=sync)
     except MultipleResultsFound:
         # Finding multiple results when importing a dataset only happens because initially
         # datasets were imported without schemas (eg, `examples.NULL.users`), and later
@@ -160,10 +158,10 @@ def import_dataset(
         # `examples.public.users`, resulting in a conflict.
         #
         # When that happens, we return the original dataset, unmodified.
-        dataset = session.query(SqlaTable).filter_by(uuid=config["uuid"]).one()
+        dataset = db.session.query(SqlaTable).filter_by(uuid=config["uuid"]).one()
 
     if dataset.id is None:
-        session.flush()
+        db.session.flush()
 
     try:
         table_exists = dataset.database.has_table_by_name(dataset.table_name)
@@ -175,7 +173,7 @@ def import_dataset(
         table_exists = True
 
     if data_uri and (not table_exists or force_data):
-        load_data(data_uri, dataset, dataset.database, session)
+        load_data(data_uri, dataset, dataset.database)
 
     if user := get_user():
         dataset.owners.append(user)
@@ -183,9 +181,7 @@ def import_dataset(
     return dataset
 
 
-def load_data(
-    data_uri: str, dataset: SqlaTable, database: Database, session: Session
-) -> None:
+def load_data(data_uri: str, dataset: SqlaTable, database: Database) -> None:
     """
     Load data from a data URI into a dataset.
 
@@ -208,7 +204,7 @@ def load_data(
     # reuse session when loading data if possible, to make import atomic
     if database.sqlalchemy_uri == current_app.config.get("SQLALCHEMY_DATABASE_URI"):
         logger.info("Loading data inside the import transaction")
-        connection = session.connection()
+        connection = db.session.connection()
         df.to_sql(
             dataset.table_name,
             con=connection,
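
Note on the load_data hunk above: when the example data targets the metadata
database itself, rows are written through the import session's own connection,
so they commit or roll back together with the rest of the import. A standalone
sketch of the idea; the `to_sql` flags are illustrative, not Superset's exact
call:

    import pandas as pd

    from superset import db

    def load_inside_import_transaction(df: pd.DataFrame, table_name: str) -> None:
        # The session's connection joins its current transaction, so nothing
        # is persisted until the surrounding import commits.
        connection = db.session.connection()
        df.to_sql(table_name, con=connection, if_exists="replace", index=False)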
diff --git a/superset/commands/importers/v1/__init__.py b/superset/commands/importers/v1/__init__.py
index 38d6568af4..8d90875fd3 100644
--- a/superset/commands/importers/v1/__init__.py
+++ b/superset/commands/importers/v1/__init__.py
@@ -60,9 +60,7 @@ class ImportModelsCommand(BaseCommand):
         self._configs: dict[str, Any] = {}
 
     @staticmethod
-    def _import(
-        session: Session, configs: dict[str, Any], overwrite: bool = False
-    ) -> None:
+    def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
         raise NotImplementedError("Subclasses MUST implement _import")
 
     @classmethod
@@ -74,7 +72,7 @@ class ImportModelsCommand(BaseCommand):
 
         # rollback to prevent partial imports
         try:
-            self._import(db.session, self._configs, self.overwrite)
+            self._import(self._configs, self.overwrite)
             db.session.commit()
         except CommandException as ex:
             db.session.rollback()
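
Note on the hunk above: run() keeps imports atomic by committing only after
_import succeeds and rolling back otherwise. The same discipline could be
factored into a context manager; a hypothetical sketch, not part of this
commit:

    import contextlib

    from superset import db

    @contextlib.contextmanager
    def atomic_import():
        # Commit only if the whole import succeeds; otherwise undo any
        # partially flushed state.
        try:
            yield db.session
            db.session.commit()
        except Exception:
            db.session.rollback()
            raise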
diff --git a/superset/commands/importers/v1/assets.py b/superset/commands/importers/v1/assets.py
index fe9539ac80..876ce509ae 100644
--- a/superset/commands/importers/v1/assets.py
+++ b/superset/commands/importers/v1/assets.py
@@ -18,7 +18,6 @@ from typing import Any, Optional
 
 from marshmallow import Schema
 from marshmallow.exceptions import ValidationError
-from sqlalchemy.orm import Session
 from sqlalchemy.sql import delete, insert
 
 from superset import db
@@ -80,26 +79,26 @@ class ImportAssetsCommand(BaseCommand):
 
     # pylint: disable=too-many-locals
     @staticmethod
-    def _import(session: Session, configs: dict[str, Any]) -> None:
+    def _import(configs: dict[str, Any]) -> None:
         # import databases first
         database_ids: dict[str, int] = {}
         for file_name, config in configs.items():
             if file_name.startswith("databases/"):
-                database = import_database(session, config, overwrite=True)
+                database = import_database(config, overwrite=True)
                 database_ids[str(database.uuid)] = database.id
 
         # import saved queries
         for file_name, config in configs.items():
             if file_name.startswith("queries/"):
                 config["db_id"] = database_ids[config["database_uuid"]]
-                import_saved_query(session, config, overwrite=True)
+                import_saved_query(config, overwrite=True)
 
         # import datasets
         dataset_info: dict[str, dict[str, Any]] = {}
         for file_name, config in configs.items():
             if file_name.startswith("datasets/"):
                 config["database_id"] = database_ids[config["database_uuid"]]
-                dataset = import_dataset(session, config, overwrite=True)
+                dataset = import_dataset(config, overwrite=True)
                 dataset_info[str(dataset.uuid)] = {
                     "datasource_id": dataset.id,
                     "datasource_type": dataset.datasource_type,
@@ -118,7 +117,7 @@ class ImportAssetsCommand(BaseCommand):
                 config["params"].update({"datasource": dataset_uid})
                 if "query_context" in config:
                     config["query_context"] = None
-                chart = import_chart(session, config, overwrite=True)
+                chart = import_chart(config, overwrite=True)
                 charts.append(chart)
                 chart_ids[str(chart.uuid)] = chart.id
 
@@ -126,7 +125,7 @@ class ImportAssetsCommand(BaseCommand):
         for file_name, config in configs.items():
             if file_name.startswith("dashboards/"):
                 config = update_id_refs(config, chart_ids, dataset_info)
-                dashboard = import_dashboard(session, config, overwrite=True)
+                dashboard = import_dashboard(config, overwrite=True)
 
                 # set ref in the dashboard_slices table
                 dashboard_chart_ids: list[dict[str, int]] = []
@@ -140,12 +139,12 @@ class ImportAssetsCommand(BaseCommand):
                     }
                     dashboard_chart_ids.append(dashboard_chart_id)
 
-                session.execute(
+                db.session.execute(
                     delete(dashboard_slices).where(
                         dashboard_slices.c.dashboard_id == dashboard.id
                     )
                 )
-                session.execute(insert(dashboard_slices).values(dashboard_chart_ids))
+                db.session.execute(insert(dashboard_slices).values(dashboard_chart_ids))
 
                 # Migrate any filter-box charts to native dashboard filters.
                 migrate_dashboard(dashboard)
@@ -153,14 +152,14 @@ class ImportAssetsCommand(BaseCommand):
         # Remove all obsolete filter-box charts.
         for chart in charts:
             if chart.viz_type == "filter_box":
-                session.delete(chart)
+                db.session.delete(chart)
 
     def run(self) -> None:
         self.validate()
 
         # rollback to prevent partial imports
         try:
-            self._import(db.session, self._configs)
+            self._import(self._configs)
             db.session.commit()
         except Exception as ex:
             db.session.rollback()
diff --git a/superset/commands/importers/v1/examples.py b/superset/commands/importers/v1/examples.py
index 87280033eb..ff69aadc45 100644
--- a/superset/commands/importers/v1/examples.py
+++ b/superset/commands/importers/v1/examples.py
@@ -18,7 +18,6 @@ from typing import Any
 
 from marshmallow import Schema
 from sqlalchemy.exc import MultipleResultsFound
-from sqlalchemy.orm import Session
 from sqlalchemy.sql import select
 
 from superset import db
@@ -70,7 +69,6 @@ class ImportExamplesCommand(ImportModelsCommand):
         # rollback to prevent partial imports
         try:
             self._import(
-                db.session,
                 self._configs,
                 self.overwrite,
                 self.force_data,
@@ -92,7 +90,6 @@ class ImportExamplesCommand(ImportModelsCommand):
 
     @staticmethod
     def _import(  # pylint: disable=too-many-locals, too-many-branches
-        session: Session,
         configs: dict[str, Any],
         overwrite: bool = False,
         force_data: bool = False,
@@ -102,7 +99,6 @@ class ImportExamplesCommand(ImportModelsCommand):
         for file_name, config in configs.items():
             if file_name.startswith("databases/"):
                 database = import_database(
-                    session,
                     config,
                     overwrite=overwrite,
                     ignore_permissions=True,
@@ -133,7 +129,6 @@ class ImportExamplesCommand(ImportModelsCommand):
 
                 try:
                     dataset = import_dataset(
-                        session,
                         config,
                         overwrite=overwrite,
                         force_data=force_data,
@@ -164,7 +159,6 @@ class ImportExamplesCommand(ImportModelsCommand):
                 # update datasource id, type, and name
                 config.update(dataset_info[config["dataset_uuid"]])
                 chart = import_chart(
-                    session,
                     config,
                     overwrite=overwrite,
                     ignore_permissions=True,
@@ -172,7 +166,7 @@ class ImportExamplesCommand(ImportModelsCommand):
                 chart_ids[str(chart.uuid)] = chart.id
 
         # store the existing relationship between dashboards and charts
-        existing_relationships = session.execute(
+        existing_relationships = db.session.execute(
             select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id])
         ).fetchall()
 
@@ -186,7 +180,6 @@ class ImportExamplesCommand(ImportModelsCommand):
                     continue
 
                 dashboard = import_dashboard(
-                    session,
                     config,
                     overwrite=overwrite,
                     ignore_permissions=True,
@@ -203,4 +196,4 @@ class ImportExamplesCommand(ImportModelsCommand):
             {"dashboard_id": dashboard_id, "slice_id": chart_id}
             for (dashboard_id, chart_id) in dashboard_chart_ids
         ]
-        session.execute(dashboard_slices.insert(), values)
+        db.session.execute(dashboard_slices.insert(), values)
diff --git a/superset/commands/query/importers/v1/__init__.py b/superset/commands/query/importers/v1/__init__.py
index fa1f21b6fc..f251759c38 100644
--- a/superset/commands/query/importers/v1/__init__.py
+++ b/superset/commands/query/importers/v1/__init__.py
@@ -43,9 +43,7 @@ class ImportSavedQueriesCommand(ImportModelsCommand):
     import_error = SavedQueryImportError
 
     @staticmethod
-    def _import(
-        session: Session, configs: dict[str, Any], overwrite: bool = False
-    ) -> None:
+    def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
         # discover databases associated with saved queries
         database_uuids: set[str] = set()
         for file_name, config in configs.items():
@@ -56,7 +54,7 @@ class ImportSavedQueriesCommand(ImportModelsCommand):
         database_ids: dict[str, int] = {}
         for file_name, config in configs.items():
             if file_name.startswith("databases/") and config["uuid"] in database_uuids:
-                database = import_database(session, config, overwrite=False)
+                database = import_database(config, overwrite=False)
                 database_ids[str(database.uuid)] = database.id
 
         # import saved queries with the correct parent ref
@@ -66,4 +64,4 @@ class ImportSavedQueriesCommand(ImportModelsCommand):
                 and config["database_uuid"] in database_ids
             ):
                 config["db_id"] = database_ids[config["database_uuid"]]
-                import_saved_query(session, config, overwrite=overwrite)
+                import_saved_query(config, overwrite=overwrite)
diff --git a/superset/commands/query/importers/v1/utils.py b/superset/commands/query/importers/v1/utils.py
index 813f3c2295..d611aa5e3a 100644
--- a/superset/commands/query/importers/v1/utils.py
+++ b/superset/commands/query/importers/v1/utils.py
@@ -17,22 +17,19 @@
 
 from typing import Any
 
-from sqlalchemy.orm import Session
-
+from superset import db
 from superset.models.sql_lab import SavedQuery
 
 
-def import_saved_query(
-    session: Session, config: dict[str, Any], overwrite: bool = False
-) -> SavedQuery:
-    existing = session.query(SavedQuery).filter_by(uuid=config["uuid"]).first()
+def import_saved_query(config: dict[str, Any], overwrite: bool = False) -> SavedQuery:
+    existing = db.session.query(SavedQuery).filter_by(uuid=config["uuid"]).first()
     if existing:
         if not overwrite:
             return existing
         config["id"] = existing.id
 
-    saved_query = SavedQuery.import_from_dict(session, config, recursive=False)
+    saved_query = SavedQuery.import_from_dict(config, recursive=False)
     if saved_query.id is None:
-        session.flush()
+        db.session.flush()
 
     return saved_query
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 08dc923c21..2552740695 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -65,7 +65,6 @@ from sqlalchemy.orm import (
     reconstructor,
     relationship,
     RelationshipProperty,
-    Session,
 )
 from sqlalchemy.orm.mapper import Mapper
 from sqlalchemy.schema import UniqueConstraint
@@ -1902,13 +1901,12 @@ class SqlaTable(
     @classmethod
     def query_datasources_by_name(
         cls,
-        session: Session,
         database: Database,
         datasource_name: str,
         schema: str | None = None,
     ) -> list[SqlaTable]:
         query = (
-            session.query(cls)
+            db.session.query(cls)
             .filter_by(database_id=database.id)
             .filter_by(table_name=datasource_name)
         )
@@ -1919,14 +1917,13 @@ class SqlaTable(
     @classmethod
     def query_datasources_by_permissions(  # pylint: disable=invalid-name
         cls,
-        session: Session,
         database: Database,
         permissions: set[str],
         schema_perms: set[str],
     ) -> list[SqlaTable]:
         # TODO(hughhhh): add unit test
         return (
-            session.query(cls)
+            db.session.query(cls)
             .filter_by(database_id=database.id)
             .filter(
                 or_(
@@ -1951,8 +1948,8 @@ class SqlaTable(
         )
 
     @classmethod
-    def get_all_datasources(cls, session: Session) -> list[SqlaTable]:
-        qry = session.query(cls)
+    def get_all_datasources(cls) -> list[SqlaTable]:
+        qry = db.session.query(cls)
         qry = cls.default_query(qry)
         return qry.all()
 
@@ -2034,7 +2031,7 @@ class SqlaTable(
         :param connection: Unused.
         :param target: The metric or column that was updated.
         """
-        session = inspect(target).session
+        session = inspect(target).session  # pylint: disable=disallowed-name
 
         # Forces an update to the table's changed_on value when a metric or column on the
         # table is updated. This busts the cache key for all charts that use the table.
@@ -2068,7 +2065,7 @@ class SqlaTable(
         if self.database_id and (
             not self.database or self.database.id != self.database_id
         ):
-            session = inspect(self).session
+            session = inspect(self).session  # pylint: disable=disallowed-name
             self.database = session.query(Database).filter_by(id=self.database_id).one()
 
 
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 688be53515..58a90e6eca 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -26,10 +26,10 @@ from flask_babel import lazy_gettext as _
 from sqlalchemy.engine.url import URL as SqlaURL
 from sqlalchemy.exc import NoSuchTableError
 from sqlalchemy.ext.declarative import DeclarativeMeta
-from sqlalchemy.orm import Session
 from sqlalchemy.orm.exc import ObjectDeletedError
 from sqlalchemy.sql.type_api import TypeEngine
 
+from superset import db
 from superset.constants import LRU_CACHE_MAX_SIZE
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import (
@@ -168,14 +168,12 @@ logger = logging.getLogger(__name__)
 
 
 def find_cached_objects_in_session(
-    session: Session,
     cls: type[DeclarativeModel],
     ids: Iterable[int] | None = None,
     uuids: Iterable[UUID] | None = None,
 ) -> Iterator[DeclarativeModel]:
     """Find known ORM instances in cached SQLA session states.
 
-    :param session: a SQLA session
     :param cls: a SQLA DeclarativeModel
     :param ids: ids of the desired model instances (optional)
     :param uuids: uuids of the desired instances, will be ignored if `ids` are provided
@@ -184,7 +182,7 @@ def find_cached_objects_in_session(
         return iter([])
     uuids = uuids or []
     try:
-        items = list(session)
+        items = list(db.session)
     except ObjectDeletedError:
         logger.warning("ObjectDeletedError", exc_info=True)
         return iter(())
diff --git a/superset/daos/base.py b/superset/daos/base.py
index 1133a76a1e..ed6471ac81 100644
--- a/superset/daos/base.py
+++ b/superset/daos/base.py
@@ -22,7 +22,6 @@ from flask_appbuilder.models.filters import BaseFilter
 from flask_appbuilder.models.sqla import Model
 from flask_appbuilder.models.sqla.interface import SQLAInterface
 from sqlalchemy.exc import SQLAlchemyError, StatementError
-from sqlalchemy.orm import Session
 
 from superset.daos.exceptions import (
     DAOCreateFailedError,
@@ -59,16 +58,14 @@ class BaseDAO(Generic[T]):
     def find_by_id(
         cls,
         model_id: str | int,
-        session: Session = None,
         skip_base_filter: bool = False,
     ) -> T | None:
         """
         Find a model by id, if defined applies `base_filter`
         """
-        session = session or db.session
-        query = session.query(cls.model_cls)
+        query = db.session.query(cls.model_cls)
         if cls.base_filter and not skip_base_filter:
-            data_model = SQLAInterface(cls.model_cls, session)
+            data_model = SQLAInterface(cls.model_cls, db.session)
             query = cls.base_filter(  # pylint: disable=not-callable
                 cls.id_column_name, data_model
             ).apply(query, None)
@@ -83,7 +80,6 @@ class BaseDAO(Generic[T]):
     def find_by_ids(
         cls,
         model_ids: list[str] | list[int],
-        session: Session = None,
         skip_base_filter: bool = False,
     ) -> list[T]:
         """
@@ -92,10 +88,9 @@ class BaseDAO(Generic[T]):
         id_col = getattr(cls.model_cls, cls.id_column_name, None)
         if id_col is None:
             return []
-        session = session or db.session
-        query = session.query(cls.model_cls).filter(id_col.in_(model_ids))
+        query = db.session.query(cls.model_cls).filter(id_col.in_(model_ids))
         if cls.base_filter and not skip_base_filter:
-            data_model = SQLAInterface(cls.model_cls, session)
+            data_model = SQLAInterface(cls.model_cls, db.session)
             query = cls.base_filter(  # pylint: disable=not-callable
                 cls.id_column_name, data_model
             ).apply(query, None)
diff --git a/superset/databases/filters.py b/superset/databases/filters.py
index 384a62c9d3..33748da4b6 100644
--- a/superset/databases/filters.py
+++ b/superset/databases/filters.py
@@ -86,8 +86,8 @@ class DatabaseUploadEnabledFilter(BaseFilter):  # pylint: disable=too-few-public
 
         if hasattr(g, "user"):
             allowed_schemas = [
-                app.config["ALLOWED_USER_CSV_SCHEMA_FUNC"](db, g.user)
-                for db in datasource_access_databases
+                app.config["ALLOWED_USER_CSV_SCHEMA_FUNC"](database, g.user)
+                for database in datasource_access_databases
             ]
 
             if len(allowed_schemas):
diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py
index b78a24ec12..18349f4314 100644
--- a/superset/db_engine_specs/gsheets.py
+++ b/superset/db_engine_specs/gsheets.py
@@ -310,7 +310,7 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
 
     @staticmethod
     def _do_post(
-        session: Session,
+        session: Session,  # pylint: disable=disallowed-name
         url: str,
         body: dict[str, Any],
         **kwargs: Any,
@@ -385,7 +385,9 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
                     conn,
                     spreadsheet_url or EXAMPLE_GSHEETS_URL,
                 )
-                session = adapter._get_session()  # pylint: disable=protected-access
+                session = (  # pylint: disable=disallowed-name
+                    adapter._get_session()  # pylint: disable=protected-access
+                )
 
         # clear existing sheet, or create a new one
         if spreadsheet_url:
diff --git a/superset/extensions/__init__.py b/superset/extensions/__init__.py
index c68332738b..65ba7eebc8 100644
--- a/superset/extensions/__init__.py
+++ b/superset/extensions/__init__.py
@@ -122,7 +122,7 @@ async_query_manager: AsyncQueryManager = LocalProxy(
 cache_manager = CacheManager()
 celery_app = celery.Celery()
 csrf = CSRFProtect()
-db = SQLA()
+db = SQLA()  # pylint: disable=disallowed-name
 _event_logger: dict[str, Any] = {}
 encrypted_field_factory = EncryptedFieldFactory()
 event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py
index ef346dbd62..01b1bf9624 100644
--- a/superset/models/dashboard.py
+++ b/superset/models/dashboard.py
@@ -64,7 +64,7 @@ def copy_dashboard(_mapper: Mapper, _connection: Connection, target: Dashboard)
     if dashboard_id is None:
         return
 
-    session = sqla.inspect(target).session
+    session = sqla.inspect(target).session  # pylint: disable=disallowed-name
     new_user = session.query(User).filter_by(id=target.id).first()
 
     # copy template dashboard to user
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 9322e8c46d..fb2f959f31 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -46,13 +46,13 @@ from jinja2.exceptions import TemplateError
 from sqlalchemy import and_, Column, or_, UniqueConstraint
 from sqlalchemy.exc import MultipleResultsFound
 from sqlalchemy.ext.declarative import declared_attr
-from sqlalchemy.orm import Mapper, Session, validates
+from sqlalchemy.orm import Mapper, validates
 from sqlalchemy.sql.elements import ColumnElement, literal_column, TextClause
 from sqlalchemy.sql.expression import Label, Select, TextAsFrom
 from sqlalchemy.sql.selectable import Alias, TableClause
 from sqlalchemy_utils import UUIDType
 
-from superset import app, is_feature_enabled, security_manager
+from superset import app, db, is_feature_enabled, security_manager
 from superset.advanced_data_type.types import AdvancedDataTypeResponse
 from superset.common.db_query_status import QueryStatus
 from superset.common.utils.time_range_utils import get_since_until_from_time_range
@@ -245,7 +245,6 @@ class ImportExportMixin:
     def import_from_dict(
         # pylint: disable=too-many-arguments,too-many-branches,too-many-locals
         cls,
-        session: Session,
         dict_rep: dict[Any, Any],
         parent: Optional[Any] = None,
         recursive: bool = True,
@@ -303,7 +302,7 @@ class ImportExportMixin:
 
         # Check if object already exists in DB, break if more than one is found
         try:
-            obj_query = session.query(cls).filter(and_(*filters))
+            obj_query = db.session.query(cls).filter(and_(*filters))
             obj = obj_query.one_or_none()
         except MultipleResultsFound as ex:
             logger.error(
@@ -322,7 +321,7 @@ class ImportExportMixin:
             logger.info("Importing new %s %s", obj.__tablename__, str(obj))
             if cls.export_parent and parent:
                 setattr(obj, cls.export_parent, parent)
-            session.add(obj)
+            db.session.add(obj)
         else:
             is_new_obj = False
             logger.info("Updating %s %s", obj.__tablename__, str(obj))
@@ -341,7 +340,7 @@ class ImportExportMixin:
                 for c_obj in new_children.get(child, []):
                     added.append(
                         child_class.import_from_dict(
-                            session=session, dict_rep=c_obj, parent=obj, sync=sync
+                            dict_rep=c_obj, parent=obj, sync=sync
                         )
                     )
                 # If children should get synced, delete the ones that did not
@@ -353,11 +352,11 @@ class ImportExportMixin:
                         for k in back_refs.keys()
                     ]
                     to_delete = set(
-                        session.query(child_class).filter(and_(*delete_filters))
+                        db.session.query(child_class).filter(and_(*delete_filters))
                     ).difference(set(added))
                     for o in to_delete:
                         logger.info("Deleting %s %s", child, str(obj))
-                        session.delete(o)
+                        db.session.delete(o)
 
         return obj
 
diff --git a/superset/security/manager.py b/superset/security/manager.py
index ffc4da250f..356ea06852 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -453,7 +453,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
             level=ErrorLevel.ERROR,
         )
 
-    def get_chart_access_error_object(  # pylint: disable=invalid-name
+    def get_chart_access_error_object(
         self,
         dashboard: "Dashboard",  # pylint: disable=unused-argument
     ) -> SupersetError:
@@ -576,7 +576,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         )
 
         # group all datasources by database
-        all_datasources = SqlaTable.get_all_datasources(self.get_session)
+        all_datasources = SqlaTable.get_all_datasources()
         datasources_by_database: dict["Database", set["SqlaTable"]] = defaultdict(set)
         for datasource in all_datasources:
             datasources_by_database[datasource.database].add(datasource)
@@ -714,7 +714,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         user_perms = self.user_view_menu_names("datasource_access")
         schema_perms = self.user_view_menu_names("schema_access")
         user_datasources = SqlaTable.query_datasources_by_permissions(
-            self.get_session, database, user_perms, schema_perms
+            database, user_perms, schema_perms
         )
         if schema:
             names = {d.table_name for d in user_datasources if d.schema == schema}
@@ -781,7 +781,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
                 self.add_permission_view_menu(view_menu, perm)
 
         logger.info("Creating missing datasource permissions.")
-        datasources = SqlaTable.get_all_datasources(self.get_session)
+        datasources = SqlaTable.get_all_datasources()
         for datasource in datasources:
             merge_pv("datasource_access", datasource.get_perm())
             merge_pv("schema_access", datasource.get_schema_perm())
@@ -797,8 +797,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         """
 
         logger.info("Cleaning faulty perms")
-        sesh = self.get_session
-        pvms = sesh.query(PermissionView).filter(
+        pvms = self.get_session.query(PermissionView).filter(
             or_(
                 PermissionView.permission  # pylint: disable=singleton-comparison
                 == None,
@@ -806,7 +805,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
                 == None,
             )
         )
-        sesh.commit()
+        self.get_session.commit()
         if deleted_count := pvms.delete():
             logger.info("Deleted %i faulty permissions", deleted_count)
 
@@ -1925,7 +1924,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
 
                 if not (schema_perm and self.can_access("schema_access", schema_perm)):
                     datasources = SqlaTable.query_datasources_by_name(
-                        self.get_session, database, table_.table, schema=table_.schema
+                        database, table_.table, schema=table_.schema
                     )
 
                     # Access to any datasource is sufficient.
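
A sketch of the updated SqlaTable class-method call sites used above; the
session parameter is dropped across the board (database, user_perms and
schema_perms are placeholders for values the security manager already holds,
and the table name is illustrative):

    from superset.connectors.sqla.models import SqlaTable

    all_datasources = SqlaTable.get_all_datasources()
    by_permission = SqlaTable.query_datasources_by_permissions(
        database, user_perms, schema_perms
    )
    by_name = SqlaTable.query_datasources_by_name(database, "birth_names", schema=None)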
diff --git a/superset/sqllab/schemas.py b/superset/sqllab/schemas.py
index 66f90a6e92..dba54cd3b5 100644
--- a/superset/sqllab/schemas.py
+++ b/superset/sqllab/schemas.py
@@ -66,7 +66,7 @@ class ExecutePayloadSchema(Schema):
 class QueryResultSchema(Schema):
     changed_on = fields.DateTime()
     dbId = fields.Integer()
-    db = fields.String()  # pylint: disable=invalid-name
+    db = fields.String()  # pylint: disable=disallowed-name
     endDttm = fields.Float()
     errorMessage = fields.String(allow_none=True)
     executedSql = fields.String()
diff --git a/superset/tables/models.py b/superset/tables/models.py
index 11f1021197..2616aaf90f 100644
--- a/superset/tables/models.py
+++ b/superset/tables/models.py
@@ -169,7 +169,7 @@ class Table(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
             )
 
         default_props = default_props or {}
-        session: Session = inspect(database).session
+        session: Session = inspect(database).session  # pylint: disable=disallowed-name
         # load existing tables
         predicate = or_(
             *[
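
For context on the inspect() call above: SQLAlchemy's inspect() on a mapped
instance returns its InstanceState, whose .session attribute is the Session
the instance is currently attached to, so the helper keeps operating in
whatever session loaded the database object. A minimal sketch, assuming an
active app context:

    from sqlalchemy import inspect

    from superset import db
    from superset.models.core import Database

    database = db.session.query(Database).first()
    if database is not None:
        session = inspect(database).session  # the Session that owns this instance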
diff --git a/superset/tags/models.py b/superset/tags/models.py
index 1e8ca7de1a..7361441940 100644
--- a/superset/tags/models.py
+++ b/superset/tags/models.py
@@ -131,7 +131,9 @@ class TaggedObject(Model, AuditMixinNullable):
         return f"<TaggedObject: {self.object_type}:{self.object_id} TAG:{self.tag_id}>"
 
 
-def get_tag(name: str, session: orm.Session, type_: TagType) -> Tag:
+def get_tag(
+    name: str, session: orm.Session, type_: TagType  # pylint: disable=disallowed-name
+) -> Tag:
     tag_name = name.strip()
     tag = session.query(Tag).filter_by(name=tag_name, type=type_).one_or_none()
     if tag is None:
@@ -168,7 +170,7 @@ class ObjectUpdater:
     @classmethod
     def get_owner_tag_ids(
         cls,
-        session: orm.Session,
+        session: orm.Session,  # pylint: disable=disallowed-name
         target: Dashboard | FavStar | Slice | Query | SqlaTable,
     ) -> set[int]:
         tag_ids = set()
@@ -181,7 +183,7 @@ class ObjectUpdater:
     @classmethod
     def _add_owners(
         cls,
-        session: orm.Session,
+        session: orm.Session,  # pylint: disable=disallowed-name
         target: Dashboard | FavStar | Slice | Query | SqlaTable,
     ) -> None:
         for owner_id in cls.get_owners_ids(target):
@@ -193,7 +195,11 @@ class ObjectUpdater:
 
     @classmethod
     def add_tag_object_if_not_tagged(
-        cls, session: orm.Session, tag_id: int, object_id: int, object_type: str
+        cls,
+        session: orm.Session,  # pylint: disable=disallowed-name
+        tag_id: int,
+        object_id: int,
+        object_type: str,
     ) -> None:
         # Check if the object is already tagged
         exists_query = exists().where(
@@ -217,7 +223,7 @@ class ObjectUpdater:
         connection: Connection,
         target: Dashboard | FavStar | Slice | Query | SqlaTable,
     ) -> None:
-        with Session(bind=connection) as session:
+        with Session(bind=connection) as session:  # pylint: disable=disallowed-name
             # add `owner:` tags
             cls._add_owners(session, target)
 
@@ -235,7 +241,7 @@ class ObjectUpdater:
         connection: Connection,
         target: Dashboard | FavStar | Slice | Query | SqlaTable,
     ) -> None:
-        with Session(bind=connection) as session:
+        with Session(bind=connection) as session:  # pylint: disable=disallowed-name
             # Fetch current owner tags
             existing_tags = (
                 session.query(TaggedObject)
@@ -274,7 +280,7 @@ class ObjectUpdater:
         connection: Connection,
         target: Dashboard | FavStar | Slice | Query | SqlaTable,
     ) -> None:
-        with Session(bind=connection) as session:
+        with Session(bind=connection) as session:  # pylint: disable=disallowed-name
             # delete row from `tagged_objects`
             session.query(TaggedObject).filter(
                 TaggedObject.object_type == cls.object_type,
@@ -321,7 +327,7 @@ class FavStarUpdater:
     def after_insert(
         cls, _mapper: Mapper, connection: Connection, target: FavStar
     ) -> None:
-        with Session(bind=connection) as session:
+        with Session(bind=connection) as session:  # pylint: disable=disallowed-name
             name = f"favorited_by:{target.user_id}"
             tag = get_tag(name, session, TagType.favorited_by)
             tagged_object = TaggedObject(
@@ -336,7 +342,7 @@ class FavStarUpdater:
     def after_delete(
         cls, _mapper: Mapper, connection: Connection, target: FavStar
     ) -> None:
-        with Session(bind=connection) as session:
+        with Session(bind=connection) as session:  # pylint: disable=disallowed-name
             name = f"favorited_by:{target.user_id}"
             query = (
                 session.query(TaggedObject.id)
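
Note that these event listeners keep opening an explicit Session(bind=connection)
rather than reusing db.session: after_insert/after_delete hooks fire mid-flush
and must write through the Connection SQLAlchemy hands them. A minimal sketch
of the pattern, where SomeModel and AuditRow are hypothetical names, not part
of this change:

    from sqlalchemy import event
    from sqlalchemy.orm import Session

    @event.listens_for(SomeModel, "after_insert")
    def _audit(mapper, connection, target):
        with Session(bind=connection) as session:
            session.add(AuditRow(ref_id=target.id))
            session.commit()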
diff --git a/superset/utils/dashboard_import_export.py b/superset/utils/dashboard_import_export.py
index eef8cbe6df..c21761dadb 100644
--- a/superset/utils/dashboard_import_export.py
+++ b/superset/utils/dashboard_import_export.py
@@ -16,17 +16,16 @@
 # under the License.
 import logging
 
-from sqlalchemy.orm import Session
-
+from superset import db
 from superset.models.dashboard import Dashboard
 
 logger = logging.getLogger(__name__)
 
 
-def export_dashboards(session: Session) -> str:
+def export_dashboards() -> str:
     """Returns all dashboards metadata as a json dump"""
     logger.info("Starting export")
-    dashboards = session.query(Dashboard)
+    dashboards = db.session.query(Dashboard)
     dashboard_ids = set()
     for dashboard in dashboards:
         dashboard_ids.add(dashboard.id)
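
With the session argument removed, callers of export_dashboards only need an
active Flask application context, since db.session is scoped to it. A usage
sketch, assuming the standard Superset app factory:

    from superset.app import create_app
    from superset.utils.dashboard_import_export import export_dashboards

    app = create_app()
    with app.app_context():
        dump = export_dashboards()  # previously export_dashboards(db.session)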
diff --git a/superset/utils/dict_import_export.py b/superset/utils/dict_import_export.py
index fbd9db7d81..7b3d995249 100644
--- a/superset/utils/dict_import_export.py
+++ b/superset/utils/dict_import_export.py
@@ -17,8 +17,7 @@
 import logging
 from typing import Any
 
-from sqlalchemy.orm import Session
-
+from superset import db
 from superset.models.core import Database
 
 EXPORT_VERSION = "1.0.0"
@@ -38,11 +37,11 @@ def export_schema_to_dict(back_references: bool) -> dict[str, Any]:
 
 
 def export_to_dict(
-    session: Session, recursive: bool, back_references: bool, include_defaults: bool
+    recursive: bool, back_references: bool, include_defaults: bool
 ) -> dict[str, Any]:
     """Exports databases to a dictionary"""
     logger.info("Starting export")
-    dbs = session.query(Database)
+    dbs = db.session.query(Database)
     databases = [
         database.export_to_dict(
             recursive=recursive,
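
export_to_dict follows the same pattern; inside an app context, callers now
pass only the export flags, exactly as the updated tests further below do:

    from superset.utils.dict_import_export import export_to_dict

    payload = export_to_dict(
        recursive=True,
        back_references=False,
        include_defaults=False,
    )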
diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py
index 0040ec60f6..81ad8ccc9f 100644
--- a/tests/integration_tests/base_tests.py
+++ b/tests/integration_tests/base_tests.py
@@ -187,24 +187,22 @@ class SupersetTestCase(TestCase):
         except ImportError:
             return False
 
-    def get_or_create(self, cls, criteria, session, **kwargs):
-        obj = session.query(cls).filter_by(**criteria).first()
+    def get_or_create(self, cls, criteria, **kwargs):
+        obj = db.session.query(cls).filter_by(**criteria).first()
         if not obj:
             obj = cls(**criteria)
         obj.__dict__.update(**kwargs)
-        session.add(obj)
-        session.commit()
+        db.session.add(obj)
+        db.session.commit()
         return obj
 
     def login(self, username="admin", password="general"):
         return login(self.client, username, password)
 
-    def get_slice(
-        self, slice_name: str, session: Session, expunge_from_session: bool = True
-    ) -> Slice:
-        slc = session.query(Slice).filter_by(slice_name=slice_name).one()
+    def get_slice(self, slice_name: str, expunge_from_session: bool = True) -> Slice:
+        slc = db.session.query(Slice).filter_by(slice_name=slice_name).one()
         if expunge_from_session:
-            session.expunge_all()
+            db.session.expunge_all()
         return slc
 
     @staticmethod
@@ -353,7 +351,6 @@ class SupersetTestCase(TestCase):
         return self.get_or_create(
             cls=models.Database,
             criteria={"database_name": database_name},
-            session=db.session,
             sqlalchemy_uri="sqlite:///:memory:",
             id=db_id,
             extra=extra,
@@ -375,7 +372,6 @@ class SupersetTestCase(TestCase):
         database = self.get_or_create(
             cls=models.Database,
             criteria={"database_name": database_name},
-            session=db.session,
             sqlalchemy_uri="db_for_macros_testing://user@host:8080/hive",
             id=db_id,
         )
@@ -398,8 +394,7 @@ class SupersetTestCase(TestCase):
             db.session.commit()
 
     def get_dash_by_slug(self, dash_slug):
-        sesh = db.session()
-        return sesh.query(Dashboard).filter_by(slug=dash_slug).first()
+        return db.session.query(Dashboard).filter_by(slug=dash_slug).first()
 
     def get_assert_metric(self, uri: str, func_name: str) -> Response:
         """
@@ -522,11 +517,10 @@ class SupersetTestCase(TestCase):
 @contextmanager
 def db_insert_temp_object(obj: DeclarativeMeta):
     """Insert a temporary object in database; delete when done."""
-    session = db.session
     try:
-        session.add(obj)
-        session.commit()
+        db.session.add(obj)
+        db.session.commit()
         yield obj
     finally:
-        session.delete(obj)
-        session.commit()
+        db.session.delete(obj)
+        db.session.commit()
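
A usage sketch of the updated helper; behaviour is unchanged, it simply talks
to db.session directly (using Dashboard purely for illustration):

    from superset.models.dashboard import Dashboard

    with db_insert_temp_object(Dashboard(dashboard_title="tmp")) as dash:
        assert dash.id is not None  # the commit above populated the primary key
    # on exit the object is deleted and the deletion committed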
diff --git a/tests/integration_tests/cache_tests.py b/tests/integration_tests/cache_tests.py
index b2a8704dfb..89093db864 100644
--- a/tests/integration_tests/cache_tests.py
+++ b/tests/integration_tests/cache_tests.py
@@ -46,7 +46,7 @@ class TestCache(SupersetTestCase):
         app.config["DATA_CACHE_CONFIG"] = {"CACHE_TYPE": "NullCache"}
         cache_manager.init_app(app)
 
-        slc = self.get_slice("Top 10 Girl Name Share", db.session)
+        slc = self.get_slice("Top 10 Girl Name Share")
         json_endpoint = "/superset/explore_json/{}/{}/".format(
             slc.datasource_type, slc.datasource_id
         )
@@ -73,7 +73,7 @@ class TestCache(SupersetTestCase):
         }
         cache_manager.init_app(app)
 
-        slc = self.get_slice("Top 10 Girl Name Share", db.session)
+        slc = self.get_slice("Top 10 Girl Name Share")
         json_endpoint = "/superset/explore_json/{}/{}/".format(
             slc.datasource_type, slc.datasource_id
         )
diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py
index a58ce1779e..d0985124e2 100644
--- a/tests/integration_tests/charts/api_tests.py
+++ b/tests/integration_tests/charts/api_tests.py
@@ -453,7 +453,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
         """
         Chart API: Test create chart
         """
-        dashboards_ids = get_dashboards_ids(db, ["world_health", "births"])
+        dashboards_ids = get_dashboards_ids(["world_health", "births"])
         admin_id = self.get_user("admin").id
         chart_data = {
             "slice_name": "name1",
@@ -1736,7 +1736,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_warm_up_cache(self, slice_name):
         self.login()
-        slc = self.get_slice(slice_name, db.session)
+        slc = self.get_slice(slice_name)
         rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": slc.id})
         self.assertEqual(rv.status_code, 200)
         data = json.loads(rv.data.decode("utf-8"))
@@ -1815,7 +1815,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_warm_up_cache_error(self) -> None:
         self.login()
-        slc = self.get_slice("Pivot Table v2", db.session)
+        slc = self.get_slice("Pivot Table v2")
 
         with mock.patch.object(ChartDataCommand, "run") as mock_run:
             mock_run.side_effect = ChartDataQueryFailedError(
@@ -1843,7 +1843,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_warm_up_cache_no_query_context(self) -> None:
         self.login()
-        slc = self.get_slice("Pivot Table v2", db.session)
+        slc = self.get_slice("Pivot Table v2")
 
         with mock.patch.object(Slice, "get_query_context") as mock_get_query_context:
             mock_get_query_context.return_value = None
@@ -1866,7 +1866,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_warm_up_cache_no_datasource(self) -> None:
         self.login()
-        slc = self.get_slice("Top 10 Girl Name Share", db.session)
+        slc = self.get_slice("Top 10 Girl Name Share")
 
         with mock.patch.object(
             Slice,
diff --git a/tests/integration_tests/charts/commands_tests.py b/tests/integration_tests/charts/commands_tests.py
index a72a716d17..6ee3e45b5f 100644
--- a/tests/integration_tests/charts/commands_tests.py
+++ b/tests/integration_tests/charts/commands_tests.py
@@ -413,7 +413,7 @@ class TestChartWarmUpCacheCommand(SupersetTestCase):
 
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_warm_up_cache(self):
-        slc = self.get_slice("Top 10 Girl Name Share", db.session)
+        slc = self.get_slice("Top 10 Girl Name Share")
         result = ChartWarmUpCacheCommand(slc.id, None, None).run()
         self.assertEqual(
             result, {"chart_id": slc.id, "viz_error": None, "viz_status": "success"}
diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py
index 9e1a9ad11c..1b1e128b07 100644
--- a/tests/integration_tests/core_tests.py
+++ b/tests/integration_tests/core_tests.py
@@ -135,7 +135,7 @@ class TestCore(SupersetTestCase):
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_viz_cache_key(self):
         self.login(username="admin")
-        slc = self.get_slice("Top 10 Girl Name Share", db.session)
+        slc = self.get_slice("Top 10 Girl Name Share")
 
         viz = slc.viz
         qobj = viz.query_obj()
@@ -175,7 +175,7 @@ class TestCore(SupersetTestCase):
     def test_save_slice(self):
         self.login(username="admin")
         slice_name = f"Energy Sankey"
-        slice_id = self.get_slice(slice_name, db.session).id
+        slice_id = self.get_slice(slice_name).id
         copy_name_prefix = "Test Sankey"
         copy_name = f"{copy_name_prefix}[save]{random.random()}"
         tbl_id = self.table_ids.get("energy_usage")
@@ -242,7 +242,6 @@ class TestCore(SupersetTestCase):
         self.login(username="admin")
         slc = self.get_slice(
             slice_name="Top 10 Girl Name Share",
-            session=db.session,
             expunge_from_session=False,
         )
         slc_data_attributes = slc.data.keys()
@@ -356,7 +355,7 @@ class TestCore(SupersetTestCase):
     )
     def test_warm_up_cache(self):
         self.login()
-        slc = self.get_slice("Top 10 Girl Name Share", db.session)
+        slc = self.get_slice("Top 10 Girl Name Share")
         data = self.get_json_resp(f"/superset/warm_up_cache?slice_id={slc.id}")
         self.assertEqual(
             data, [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}]
@@ -381,7 +380,7 @@ class TestCore(SupersetTestCase):
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_warm_up_cache_error(self) -> None:
         self.login()
-        slc = self.get_slice("Pivot Table v2", db.session)
+        slc = self.get_slice("Pivot Table v2")
 
         with mock.patch.object(
             ChartDataCommand,
@@ -406,7 +405,7 @@ class TestCore(SupersetTestCase):
         self.login("admin")
         store_cache_keys = app.config["STORE_CACHE_KEYS_IN_METADATA_DB"]
         app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = True
-        slc = self.get_slice("Top 10 Girl Name Share", db.session)
+        slc = self.get_slice("Top 10 Girl Name Share")
         self.get_json_resp(f"/superset/warm_up_cache?slice_id={slc.id}")
         ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first()
         assert ck.datasource_uid == f"{slc.table.id}__table"
@@ -1172,7 +1171,7 @@ class TestCore(SupersetTestCase):
         random_key = "random_key"
         mock_command.return_value = random_key
         slice_name = f"Energy Sankey"
-        slice_id = self.get_slice(slice_name, db.session).id
+        slice_id = self.get_slice(slice_name).id
         form_data = {"slice_id": slice_id, "viz_type": "line", "datasource": "1__table"}
         rv = self.client.get(
             f"/superset/explore/?form_data={quote(json.dumps(form_data))}"
diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py
index d809880bf7..623572c713 100644
--- a/tests/integration_tests/dashboards/api_tests.py
+++ b/tests/integration_tests/dashboards/api_tests.py
@@ -1661,7 +1661,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
         Dashboard API: Test dashboard export
         """
         self.login(username="admin")
-        dashboards_ids = get_dashboards_ids(db, ["world_health", "births"])
+        dashboards_ids = get_dashboards_ids(["world_health", "births"])
         uri = f"api/v1/dashboard/export/?q={prison.dumps(dashboards_ids)}"
 
         rv = self.get_assert_metric(uri, "export")
@@ -1699,7 +1699,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
         """
         Dashboard API: Test dashboard export
         """
-        dashboards_ids = get_dashboards_ids(db, ["world_health", "births"])
+        dashboards_ids = get_dashboards_ids(["world_health", "births"])
         uri = f"api/v1/dashboard/export/?q={prison.dumps(dashboards_ids)}"
 
         self.login(username="admin")
diff --git a/tests/integration_tests/dashboards/filter_state/api_tests.py b/tests/integration_tests/dashboards/filter_state/api_tests.py
index 3538e14012..4dd02bfb65 100644
--- a/tests/integration_tests/dashboards/filter_state/api_tests.py
+++ b/tests/integration_tests/dashboards/filter_state/api_tests.py
@@ -22,6 +22,7 @@ from flask.ctx import AppContext
 from flask_appbuilder.security.sqla.models import User
 from sqlalchemy.orm import Session
 
+from superset import db
 from superset.commands.dashboard.exceptions import DashboardAccessDeniedError
 from superset.commands.temporary_cache.entry import Entry
 from superset.extensions import cache_manager
@@ -40,15 +41,13 @@ UPDATED_VALUE = json.dumps({"test": "updated value"})
 
 @pytest.fixture
 def dashboard_id(app_context: AppContext, load_world_bank_dashboard_with_slices) -> int:
-    session: Session = app_context.app.appbuilder.get_session
-    dashboard = session.query(Dashboard).filter_by(slug="world_health").one()
+    dashboard = db.session.query(Dashboard).filter_by(slug="world_health").one()
     return dashboard.id
 
 
 @pytest.fixture
 def admin_id(app_context: AppContext) -> int:
-    session: Session = app_context.app.appbuilder.get_session
-    admin = session.query(User).filter_by(username="admin").one_or_none()
+    admin = db.session.query(User).filter_by(username="admin").one_or_none()
     return admin.id
 
 
diff --git a/tests/integration_tests/dashboards/permalink/api_tests.py b/tests/integration_tests/dashboards/permalink/api_tests.py
index a49f1e6f4c..bfa20fd8a3 100644
--- a/tests/integration_tests/dashboards/permalink/api_tests.py
+++ b/tests/integration_tests/dashboards/permalink/api_tests.py
@@ -42,10 +42,8 @@ STATE = {
 
 @pytest.fixture
 def dashboard_id(load_world_bank_dashboard_with_slices) -> int:
-    with app.app_context() as ctx:
-        session: Session = ctx.app.appbuilder.get_session
-        dashboard = session.query(Dashboard).filter_by(slug="world_health").one()
-        return dashboard.id
+    dashboard = db.session.query(Dashboard).filter_by(slug="world_health").one()
+    return dashboard.id
 
 
 @pytest.fixture
diff --git a/tests/integration_tests/dashboards/superset_factory_util.py b/tests/integration_tests/dashboards/superset_factory_util.py
index 88495b03b4..aeae6171df 100644
--- a/tests/integration_tests/dashboards/superset_factory_util.py
+++ b/tests/integration_tests/dashboards/superset_factory_util.py
@@ -38,8 +38,6 @@ from tests.integration_tests.dashboards.dashboard_test_utils import (
 
 logger = logging.getLogger(__name__)
 
-session = db.session
-
 inserted_dashboards_ids = []
 inserted_databases_ids = []
 inserted_sqltables_ids = []
@@ -99,9 +97,9 @@ def create_dashboard(
 
 
 def insert_model(dashboard: Model) -> None:
-    session.add(dashboard)
-    session.commit()
-    session.refresh(dashboard)
+    db.session.add(dashboard)
+    db.session.commit()
+    db.session.refresh(dashboard)
 
 
 def create_slice_to_db(
@@ -193,7 +191,7 @@ def delete_all_inserted_objects() -> None:
 def delete_all_inserted_dashboards():
     try:
         dashboards_to_delete: list[Dashboard] = (
-            session.query(Dashboard)
+            db.session.query(Dashboard)
             .filter(Dashboard.id.in_(inserted_dashboards_ids))
             .all()
         )
@@ -204,7 +202,7 @@ def delete_all_inserted_dashboards():
                 logger.error(f"failed to delete {dashboard.id}", exc_info=True)
                 raise ex
         if len(inserted_dashboards_ids) > 0:
-            session.commit()
+            db.session.commit()
             inserted_dashboards_ids.clear()
     except Exception as ex2:
         logger.error("delete_all_inserted_dashboards failed", exc_info=True)
@@ -216,25 +214,25 @@ def delete_dashboard(dashboard: Dashboard, do_commit: bool = False) -> None:
     delete_dashboard_roles_associations(dashboard)
     delete_dashboard_users_associations(dashboard)
     delete_dashboard_slices_associations(dashboard)
-    session.delete(dashboard)
+    db.session.delete(dashboard)
     if do_commit:
-        session.commit()
+        db.session.commit()
 
 
 def delete_dashboard_users_associations(dashboard: Dashboard) -> None:
-    session.execute(
+    db.session.execute(
         dashboard_user.delete().where(dashboard_user.c.dashboard_id == dashboard.id)
     )
 
 
 def delete_dashboard_roles_associations(dashboard: Dashboard) -> None:
-    session.execute(
+    db.session.execute(
         DashboardRoles.delete().where(DashboardRoles.c.dashboard_id == dashboard.id)
     )
 
 
 def delete_dashboard_slices_associations(dashboard: Dashboard) -> None:
-    session.execute(
+    db.session.execute(
         dashboard_slices.delete().where(dashboard_slices.c.dashboard_id == dashboard.id)
     )
 
@@ -242,7 +240,7 @@ def delete_dashboard_slices_associations(dashboard: Dashboard) -> None:
 def delete_all_inserted_slices():
     try:
         slices_to_delete: list[Slice] = (
-            session.query(Slice).filter(Slice.id.in_(inserted_slices_ids)).all()
+            db.session.query(Slice).filter(Slice.id.in_(inserted_slices_ids)).all()
         )
         for slice in slices_to_delete:
             try:
@@ -251,7 +249,7 @@ def delete_all_inserted_slices():
                 logger.error(f"failed to delete {slice.id}", exc_info=True)
                 raise ex
         if len(inserted_slices_ids) > 0:
-            session.commit()
+            db.session.commit()
             inserted_slices_ids.clear()
     except Exception as ex2:
         logger.error("delete_all_inserted_slices failed", exc_info=True)
@@ -261,19 +259,19 @@ def delete_all_inserted_slices():
 def delete_slice(slice_: Slice, do_commit: bool = False) -> None:
     logger.info(f"deleting slice{slice_.id}")
     delete_slice_users_associations(slice_)
-    session.delete(slice_)
+    db.session.delete(slice_)
     if do_commit:
-        session.commit()
+        db.session.commit()
 
 
 def delete_slice_users_associations(slice_: Slice) -> None:
-    session.execute(slice_user.delete().where(slice_user.c.slice_id == slice_.id))
+    db.session.execute(slice_user.delete().where(slice_user.c.slice_id == slice_.id))
 
 
 def delete_all_inserted_tables():
     try:
         tables_to_delete: list[SqlaTable] = (
-            session.query(SqlaTable)
+            db.session.query(SqlaTable)
             .filter(SqlaTable.id.in_(inserted_sqltables_ids))
             .all()
         )
@@ -284,7 +282,7 @@ def delete_all_inserted_tables():
                 logger.error(f"failed to delete {table.id}", exc_info=True)
                 raise ex
         if len(inserted_sqltables_ids) > 0:
-            session.commit()
+            db.session.commit()
             inserted_sqltables_ids.clear()
     except Exception as ex2:
         logger.error("delete_all_inserted_tables failed", exc_info=True)
@@ -294,32 +292,32 @@ def delete_all_inserted_tables():
 def delete_sqltable(table: SqlaTable, do_commit: bool = False) -> None:
     logger.info(f"deleting table{table.id}")
     delete_table_users_associations(table)
-    session.delete(table)
+    db.session.delete(table)
     if do_commit:
-        session.commit()
+        db.session.commit()
 
 
 def delete_table_users_associations(table: SqlaTable) -> None:
-    session.execute(
+    db.session.execute(
         sqlatable_user.delete().where(sqlatable_user.c.table_id == table.id)
     )
 
 
 def delete_all_inserted_dbs():
     try:
-        dbs_to_delete: list[Database] = (
-            session.query(Database)
+        databases_to_delete: list[Database] = (
+            db.session.query(Database)
             .filter(Database.id.in_(inserted_databases_ids))
             .all()
         )
-        for db in dbs_to_delete:
+        for database in databases_to_delete:
             try:
-                delete_database(db, False)
+                delete_database(database, False)
             except Exception as ex:
-                logger.error(f"failed to delete {db.id}", exc_info=True)
+                logger.error(f"failed to delete {database.id}", exc_info=True)
                 raise ex
         if len(inserted_databases_ids) > 0:
-            session.commit()
+            db.session.commit()
             inserted_databases_ids.clear()
     except Exception as ex2:
         logger.error("delete_all_inserted_databases failed", exc_info=True)
@@ -328,6 +326,6 @@ def delete_all_inserted_dbs():
 
 def delete_database(database: Database, do_commit: bool = False) -> None:
     logger.info(f"deleting database{database.id}")
-    session.delete(database)
+    db.session.delete(database)
     if do_commit:
-        session.commit()
+        db.session.commit()
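
One subtlety in delete_all_inserted_dbs above: the old loop variable was named
db, and a `for db in ...` target makes db function-local in Python, so the
function's new db.session calls would have raised UnboundLocalError once the
module-level session alias was removed. The rename sidesteps this; a contrived
sketch:

    from superset import db

    databases_to_delete = []  # illustrative; normally a query result
    for database in databases_to_delete:  # not `for db in ...`
        db.session.delete(database)
    db.session.commit()  # db still refers to the superset module here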
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index f7b8cc0ec8..ebabc16e87 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -1365,12 +1365,11 @@ class TestDatabaseApi(SupersetTestCase):
         """
         Database API: Test get select star with datasource access
         """
-        session = db.session
         table = SqlaTable(
             schema="main", table_name="ab_permission", database=get_main_database()
         )
-        session.add(table)
-        session.commit()
+        db.session.add(table)
+        db.session.commit()
 
         tmp_table_perm = security_manager.find_permission_view_menu(
             "datasource_access", table.get_perm()
@@ -1732,15 +1731,14 @@ class TestDatabaseApi(SupersetTestCase):
         with self.create_app().app_context():
             main_db = get_main_database()
             main_db.allow_file_upload = True
-            session = db.session
             table = SqlaTable(
                 schema="public",
                 table_name="ab_permission",
                 database=get_main_database(),
             )
 
-            session.add(table)
-            session.commit()
+            db.session.add(table)
+            db.session.commit()
             tmp_table_perm = security_manager.find_permission_view_menu(
                 "datasource_access", table.get_perm()
             )
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index 3530bdec1a..1ebe5bd1f7 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -1748,7 +1748,6 @@ class TestDatasetApi(SupersetTestCase):
         assert rv.status_code == 200
 
         cli_export = export_to_dict(
-            session=db.session,
             recursive=True,
             back_references=False,
             include_defaults=False,
diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py
index 4e05b63002..91e843fc3f 100644
--- a/tests/integration_tests/datasource_tests.py
+++ b/tests/integration_tests/datasource_tests.py
@@ -79,7 +79,6 @@ class TestDatasource(SupersetTestCase):
 
     def test_always_filter_main_dttm(self):
         self.login(username="admin")
-        session = db.session
         database = get_example_database()
 
         sql = f"SELECT DATE() as default_dttm, DATE() as additional_dttm, 1 as metric;"
@@ -115,8 +114,8 @@ class TestDatasource(SupersetTestCase):
             sql=sql,
         )
 
-        session.add(table)
-        session.commit()
+        db.session.add(table)
+        db.session.commit()
 
         table.always_filter_main_dttm = False
         result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause)
@@ -126,27 +125,26 @@ class TestDatasource(SupersetTestCase):
         result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause)
         assert "default_dttm" in result and "additional_dttm" in result
 
-        session.delete(table)
-        session.commit()
+        db.session.delete(table)
+        db.session.commit()
 
     def test_external_metadata_for_virtual_table(self):
         self.login(username="admin")
-        session = db.session
         table = SqlaTable(
             table_name="dummy_sql_table",
             database=get_example_database(),
             schema=get_example_default_schema(),
             sql="select 123 as intcol, 'abc' as strcol",
         )
-        session.add(table)
-        session.commit()
+        db.session.add(table)
+        db.session.commit()
 
         table = self.get_table(name="dummy_sql_table")
         url = f"/datasource/external_metadata/table/{table.id}/"
         resp = self.get_json_resp(url)
         assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
-        session.delete(table)
-        session.commit()
+        db.session.delete(table)
+        db.session.commit()
 
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_external_metadata_by_name_for_physical_table(self):
@@ -171,15 +169,14 @@ class TestDatasource(SupersetTestCase):
 
     def test_external_metadata_by_name_for_virtual_table(self):
         self.login(username="admin")
-        session = db.session
         table = SqlaTable(
             table_name="dummy_sql_table",
             database=get_example_database(),
             schema=get_example_default_schema(),
             sql="select 123 as intcol, 'abc' as strcol",
         )
-        session.add(table)
-        session.commit()
+        db.session.add(table)
+        db.session.commit()
 
         tbl = self.get_table(name="dummy_sql_table")
         params = prison.dumps(
@@ -195,8 +192,8 @@ class TestDatasource(SupersetTestCase):
         url = f"/datasource/external_metadata_by_name/?q={params}"
         resp = self.get_json_resp(url)
         assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
-        session.delete(tbl)
-        session.commit()
+        db.session.delete(tbl)
+        db.session.commit()
 
     def test_external_metadata_by_name_from_sqla_inspector(self):
         self.login(username="admin")
@@ -265,7 +262,6 @@ class TestDatasource(SupersetTestCase):
 
     def test_external_metadata_for_virtual_table_template_params(self):
         self.login(username="admin")
-        session = db.session
         table = SqlaTable(
             table_name="dummy_sql_table_with_template_params",
             database=get_example_database(),
@@ -273,15 +269,15 @@ class TestDatasource(SupersetTestCase):
             sql="select {{ foo }} as intcol",
             template_params=json.dumps({"foo": "123"}),
         )
-        session.add(table)
-        session.commit()
+        db.session.add(table)
+        db.session.commit()
 
         table = self.get_table(name="dummy_sql_table_with_template_params")
         url = f"/datasource/external_metadata/table/{table.id}/"
         resp = self.get_json_resp(url)
         assert {o.get("column_name") for o in resp} == {"intcol"}
-        session.delete(table)
-        session.commit()
+        db.session.delete(table)
+        db.session.commit()
 
     def test_external_metadata_for_malicious_virtual_table(self):
         self.login(username="admin")
diff --git a/tests/integration_tests/db_engine_specs/databricks_tests.py b/tests/integration_tests/db_engine_specs/databricks_tests.py
index 5ff20b7347..bf4d7e8b9f 100644
--- a/tests/integration_tests/db_engine_specs/databricks_tests.py
+++ b/tests/integration_tests/db_engine_specs/databricks_tests.py
@@ -33,10 +33,10 @@ class TestDatabricksDbEngineSpec(TestDbEngineSpec):
         assert get_engine_spec("databricks", "pyhive").engine == "databricks"
 
     def test_extras_without_ssl(self):
-        db = mock.Mock()
-        db.extra = default_db_extra
-        db.server_cert = None
-        extras = DatabricksNativeEngineSpec.get_extra_params(db)
+        database = mock.Mock()
+        database.extra = default_db_extra
+        database.server_cert = None
+        extras = DatabricksNativeEngineSpec.get_extra_params(database)
         assert extras == {
             "engine_params": {
                 "connect_args": {
@@ -50,12 +50,12 @@ class TestDatabricksDbEngineSpec(TestDbEngineSpec):
         }
 
     def test_extras_with_ssl_custom(self):
-        db = mock.Mock()
-        db.extra = default_db_extra.replace(
+        database = mock.Mock()
+        database.extra = default_db_extra.replace(
             '"engine_params": {}',
             '"engine_params": {"connect_args": {"ssl": "1"}}',
         )
-        db.server_cert = ssl_certificate
-        extras = DatabricksNativeEngineSpec.get_extra_params(db)
+        database.server_cert = ssl_certificate
+        extras = DatabricksNativeEngineSpec.get_extra_params(database)
         connect_args = extras["engine_params"]["connect_args"]
         assert connect_args["ssl"] == "1"
diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py
index 341b494927..374d99c02e 100644
--- a/tests/integration_tests/db_engine_specs/hive_tests.py
+++ b/tests/integration_tests/db_engine_specs/hive_tests.py
@@ -337,14 +337,14 @@ def test_fetch_data_success(fetch_data_mock):
 @mock.patch("superset.db_engine_specs.hive.HiveEngineSpec._latest_partition_from_df")
 def test_where_latest_partition(mock_method):
     mock_method.return_value = ("01-01-19", 1)
-    db = mock.Mock()
-    db.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
-    db.get_extra = mock.Mock(return_value={})
-    db.get_df = mock.Mock()
+    database = mock.Mock()
+    database.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
+    database.get_extra = mock.Mock(return_value={})
+    database.get_df = mock.Mock()
     columns = [{"name": "ds"}, {"name": "hour"}]
     with app.app_context():
         result = HiveEngineSpec.where_latest_partition(
-            "test_table", "test_schema", db, select(), columns
+            "test_table", "test_schema", database, select(), columns
         )
     query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
     assert "SELECT  \nWHERE ds = '01-01-19' AND hour = 1" == query_result
@@ -353,11 +353,11 @@ def test_where_latest_partition(mock_method):
 @mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.latest_partition")
 def test_where_latest_partition_super_method_exception(mock_method):
     mock_method.side_effect = Exception()
-    db = mock.Mock()
+    database = mock.Mock()
     columns = [{"name": "ds"}, {"name": "hour"}]
     with app.app_context():
         result = HiveEngineSpec.where_latest_partition(
-            "test_table", "test_schema", db, select(), columns
+            "test_table", "test_schema", database, select(), columns
         )
     assert result is None
     mock_method.assert_called()
diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py
index 2b543e8e25..0f4841fb35 100644
--- a/tests/integration_tests/db_engine_specs/postgres_tests.py
+++ b/tests/integration_tests/db_engine_specs/postgres_tests.py
@@ -119,29 +119,29 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
         assert "postgres" in backends
 
     def test_extras_without_ssl(self):
-        db = mock.Mock()
-        db.extra = default_db_extra
-        db.server_cert = None
-        extras = PostgresEngineSpec.get_extra_params(db)
+        database = mock.Mock()
+        database.extra = default_db_extra
+        database.server_cert = None
+        extras = PostgresEngineSpec.get_extra_params(database)
         assert "connect_args" not in extras["engine_params"]
 
     def test_extras_with_ssl_default(self):
-        db = mock.Mock()
-        db.extra = default_db_extra
-        db.server_cert = ssl_certificate
-        extras = PostgresEngineSpec.get_extra_params(db)
+        database = mock.Mock()
+        database.extra = default_db_extra
+        database.server_cert = ssl_certificate
+        extras = PostgresEngineSpec.get_extra_params(database)
         connect_args = extras["engine_params"]["connect_args"]
         assert connect_args["sslmode"] == "verify-full"
         assert "sslrootcert" in connect_args
 
     def test_extras_with_ssl_custom(self):
-        db = mock.Mock()
-        db.extra = default_db_extra.replace(
+        database = mock.Mock()
+        database.extra = default_db_extra.replace(
             '"engine_params": {}',
             '"engine_params": {"connect_args": {"sslmode": "verify-ca"}}',
         )
-        db.server_cert = ssl_certificate
-        extras = PostgresEngineSpec.get_extra_params(db)
+        database.server_cert = ssl_certificate
+        extras = PostgresEngineSpec.get_extra_params(database)
         connect_args = extras["engine_params"]["connect_args"]
         assert connect_args["sslmode"] == "verify-ca"
         assert "sslrootcert" in connect_args
diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py
index 7e151648a6..c28e78afe6 100644
--- a/tests/integration_tests/db_engine_specs/presto_tests.py
+++ b/tests/integration_tests/db_engine_specs/presto_tests.py
@@ -550,13 +550,17 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
         self.assertEqual(actual_expanded_cols, expected_expanded_cols)
 
     def test_presto_extra_table_metadata(self):
-        db = mock.Mock()
-        db.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
-        db.get_extra = mock.Mock(return_value={})
+        database = mock.Mock()
+        database.get_indexes = mock.Mock(
+            return_value=[{"column_names": ["ds", "hour"]}]
+        )
+        database.get_extra = mock.Mock(return_value={})
         df = pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})
-        db.get_df = mock.Mock(return_value=df)
+        database.get_df = mock.Mock(return_value=df)
         PrestoEngineSpec.get_create_view = mock.Mock(return_value=None)
-        result = PrestoEngineSpec.extra_table_metadata(db, "test_table", "test_schema")
+        result = PrestoEngineSpec.extra_table_metadata(
+            database, "test_table", "test_schema"
+        )
         assert result["partitions"]["cols"] == ["ds", "hour"]
         assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1}
 
diff --git a/tests/integration_tests/dict_import_export_tests.py b/tests/integration_tests/dict_import_export_tests.py
index 6018e59a92..b4dddff09f 100644
--- a/tests/integration_tests/dict_import_export_tests.py
+++ b/tests/integration_tests/dict_import_export_tests.py
@@ -43,11 +43,10 @@ class TestDictImportExport(SupersetTestCase):
     def delete_imports(cls):
         with app.app_context():
             # Imported data clean up
-            session = db.session
-            for table in session.query(SqlaTable):
+            for table in db.session.query(SqlaTable):
                 if DBREF in table.params_dict:
-                    session.delete(table)
-            session.commit()
+                    db.session.delete(table)
+            db.session.commit()
 
     @classmethod
     def setUpClass(cls):
@@ -124,7 +123,7 @@ class TestDictImportExport(SupersetTestCase):
 
     def test_import_table_no_metadata(self):
         table, dict_table = self.create_table("pure_table", id=ID_PREFIX + 1)
-        new_table = SqlaTable.import_from_dict(db.session, dict_table)
+        new_table = SqlaTable.import_from_dict(dict_table)
         db.session.commit()
         imported_id = new_table.id
         imported = self.get_table_by_id(imported_id)
@@ -139,7 +138,7 @@ class TestDictImportExport(SupersetTestCase):
             cols_uuids=[uuid4()],
             metric_names=["metric1"],
         )
-        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+        imported_table = SqlaTable.import_from_dict(dict_table)
         db.session.commit()
         imported = self.get_table_by_id(imported_table.id)
         self.assert_table_equals(table, imported)
@@ -156,7 +155,7 @@ class TestDictImportExport(SupersetTestCase):
             cols_uuids=[uuid4(), uuid4()],
             metric_names=["m1", "m2"],
         )
-        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+        imported_table = SqlaTable.import_from_dict(dict_table)
         db.session.commit()
         imported = self.get_table_by_id(imported_table.id)
         self.assert_table_equals(table, imported)
@@ -166,7 +165,7 @@ class TestDictImportExport(SupersetTestCase):
         table, dict_table = self.create_table(
             "table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
         )
-        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+        imported_table = SqlaTable.import_from_dict(dict_table)
         db.session.commit()
         table_over, dict_table_over = self.create_table(
             "table_override",
@@ -174,7 +173,7 @@ class TestDictImportExport(SupersetTestCase):
             cols_names=["new_col1", "col2", "col3"],
             metric_names=["new_metric1"],
         )
-        imported_over_table = SqlaTable.import_from_dict(db.session, dict_table_over)
+        imported_over_table = SqlaTable.import_from_dict(dict_table_over)
         db.session.commit()
 
         imported_over = self.get_table_by_id(imported_over_table.id)
@@ -195,7 +194,7 @@ class TestDictImportExport(SupersetTestCase):
         table, dict_table = self.create_table(
             "table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
         )
-        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+        imported_table = SqlaTable.import_from_dict(dict_table)
         db.session.commit()
         table_over, dict_table_over = self.create_table(
             "table_override",
@@ -204,7 +203,7 @@ class TestDictImportExport(SupersetTestCase):
             metric_names=["new_metric1"],
         )
         imported_over_table = SqlaTable.import_from_dict(
-            session=db.session, dict_rep=dict_table_over, sync=["metrics", "columns"]
+            dict_rep=dict_table_over, sync=["metrics", "columns"]
         )
         db.session.commit()
 
@@ -229,7 +228,7 @@ class TestDictImportExport(SupersetTestCase):
             cols_names=["new_col1", "col2", "col3"],
             metric_names=["new_metric1"],
         )
-        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+        imported_table = SqlaTable.import_from_dict(dict_table)
         db.session.commit()
         copy_table, dict_copy_table = self.create_table(
             "copy_cat",
@@ -237,7 +236,7 @@ class TestDictImportExport(SupersetTestCase):
             cols_names=["new_col1", "col2", "col3"],
             metric_names=["new_metric1"],
         )
-        imported_copy_table = SqlaTable.import_from_dict(db.session, dict_copy_table)
+        imported_copy_table = SqlaTable.import_from_dict(dict_copy_table)
         db.session.commit()
         self.assertEqual(imported_table.id, imported_copy_table.id)
         self.assert_table_equals(copy_table, self.get_table_by_id(imported_table.id))
@@ -250,7 +249,6 @@ class TestDictImportExport(SupersetTestCase):
         self.delete_fake_db()
 
         cli_export = export_to_dict(
-            session=db.session,
             recursive=True,
             back_references=False,
             include_defaults=False,
diff --git a/tests/integration_tests/explore/api_tests.py b/tests/integration_tests/explore/api_tests.py
index e37200e310..c0b7f5fcd4 100644
--- a/tests/integration_tests/explore/api_tests.py
+++ b/tests/integration_tests/explore/api_tests.py
@@ -21,6 +21,7 @@ import pytest
 from flask_appbuilder.security.sqla.models import User
 from sqlalchemy.orm import Session
 
+from superset import db
 from superset.commands.explore.form_data.state import TemporaryExploreState
 from superset.connectors.sqla.models import SqlaTable
 from superset.explore.exceptions import DatasetAccessDeniedError
@@ -39,25 +40,22 @@ FORM_DATA = {"test": "test value"}
 @pytest.fixture
 def chart_id(load_world_bank_dashboard_with_slices) -> int:
     with app.app_context() as ctx:
-        session: Session = ctx.app.appbuilder.get_session
-        chart = session.query(Slice).filter_by(slice_name="World's Population").one()
+        chart = db.session.query(Slice).filter_by(slice_name="World's Population").one()
         return chart.id
 
 
 @pytest.fixture
 def admin_id() -> int:
     with app.app_context() as ctx:
-        session: Session = ctx.app.appbuilder.get_session
-        admin = session.query(User).filter_by(username="admin").one()
+        admin = db.session.query(User).filter_by(username="admin").one()
         return admin.id
 
 
 @pytest.fixture
 def dataset() -> int:
     with app.app_context() as ctx:
-        session: Session = ctx.app.appbuilder.get_session
         dataset = (
-            session.query(SqlaTable)
+            db.session.query(SqlaTable)
             .filter_by(table_name="wb_health_population")
             .first()
         )
diff --git a/tests/integration_tests/explore/form_data/api_tests.py b/tests/integration_tests/explore/form_data/api_tests.py
index 5dbd67d4f5..9187e46213 100644
--- a/tests/integration_tests/explore/form_data/api_tests.py
+++ b/tests/integration_tests/explore/form_data/api_tests.py
@@ -21,6 +21,7 @@ import pytest
 from flask_appbuilder.security.sqla.models import User
 from sqlalchemy.orm import Session
 
+from superset import db
 from superset.commands.dataset.exceptions import DatasetAccessDeniedError
 from superset.commands.explore.form_data.state import TemporaryExploreState
 from superset.connectors.sqla.models import SqlaTable
@@ -41,25 +42,22 @@ UPDATED_FORM_DATA = json.dumps({"test": "updated value"})
 @pytest.fixture
 def chart_id(load_world_bank_dashboard_with_slices) -> int:
     with app.app_context() as ctx:
-        session: Session = ctx.app.appbuilder.get_session
-        chart = session.query(Slice).filter_by(slice_name="World's Population").one()
+        chart = db.session.query(Slice).filter_by(slice_name="World's Population").one()
         return chart.id
 
 
 @pytest.fixture
 def admin_id() -> int:
     with app.app_context() as ctx:
-        session: Session = ctx.app.appbuilder.get_session
-        admin = session.query(User).filter_by(username="admin").one()
+        admin = db.session.query(User).filter_by(username="admin").one()
         return admin.id
 
 
 @pytest.fixture
 def datasource() -> int:
     with app.app_context() as ctx:
-        session: Session = ctx.app.appbuilder.get_session
         dataset = (
-            session.query(SqlaTable)
+            db.session.query(SqlaTable)
             .filter_by(table_name="wb_health_population")
             .first()
         )
diff --git a/tests/integration_tests/explore/form_data/commands_tests.py b/tests/integration_tests/explore/form_data/commands_tests.py
index 781c4fdbb2..293a2c556f 100644
--- a/tests/integration_tests/explore/form_data/commands_tests.py
+++ b/tests/integration_tests/explore/form_data/commands_tests.py
@@ -45,22 +45,22 @@ class TestCreateFormDataCommand(SupersetTestCase):
                 schema=get_example_default_schema(),
                 sql="select 123 as intcol, 'abc' as strcol",
             )
-            session = db.session
-            session.add(dataset)
-            session.commit()
+            db.session.add(dataset)
+            db.session.commit()
 
             yield dataset
 
             # rollback
-            session.delete(dataset)
-            session.commit()
+            db.session.delete(dataset)
+            db.session.commit()
 
     @pytest.fixture()
     def create_slice(self):
         with self.create_app().app_context():
-            session = db.session
             dataset = (
-                session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
+                db.session.query(SqlaTable)
+                .filter_by(table_name="dummy_sql_table")
+                .first()
             )
             slice = Slice(
                 datasource_id=dataset.id,
@@ -69,34 +69,32 @@ class TestCreateFormDataCommand(SupersetTestCase):
                 slice_name="slice_name",
             )
 
-            session.add(slice)
-            session.commit()
+            db.session.add(slice)
+            db.session.commit()
 
             yield slice
 
             # rollback
-            session.delete(slice)
-            session.commit()
+            db.session.delete(slice)
+            db.session.commit()
 
     @pytest.fixture()
     def create_query(self):
         with self.create_app().app_context():
-            session = db.session
-
             query = Query(
                 sql="select 1 as foo;",
                 client_id="sldkfjlk",
                 database=get_example_database(),
             )
 
-            session.add(query)
-            session.commit()
+            db.session.add(query)
+            db.session.commit()
 
             yield query
 
             # rollback
-            session.delete(query)
-            session.commit()
+            db.session.delete(query)
+            db.session.commit()
 
     @patch("superset.security.manager.g")
     @pytest.mark.usefixtures("create_dataset", "create_slice")
diff --git a/tests/integration_tests/explore/permalink/api_tests.py b/tests/integration_tests/explore/permalink/api_tests.py
index 81be2f0de8..a171504cc6 100644
--- a/tests/integration_tests/explore/permalink/api_tests.py
+++ b/tests/integration_tests/explore/permalink/api_tests.py
@@ -38,8 +38,7 @@ from tests.integration_tests.test_app import app
 
 @pytest.fixture
 def chart(app_context, load_world_bank_dashboard_with_slices) -> Slice:
-    session: Session = app_context.app.appbuilder.get_session
-    chart = session.query(Slice).filter_by(slice_name="World's Population").one()
+    chart = db.session.query(Slice).filter_by(slice_name="World's Population").one()
     return chart
 
 
diff --git a/tests/integration_tests/explore/permalink/commands_tests.py b/tests/integration_tests/explore/permalink/commands_tests.py
index 5402a419bc..f499591aa5 100644
--- a/tests/integration_tests/explore/permalink/commands_tests.py
+++ b/tests/integration_tests/explore/permalink/commands_tests.py
@@ -43,22 +43,22 @@ class TestCreatePermalinkDataCommand(SupersetTestCase):
                 schema=get_example_default_schema(),
                 sql="select 123 as intcol, 'abc' as strcol",
             )
-            session = db.session
-            session.add(dataset)
-            session.commit()
+            db.session.add(dataset)
+            db.session.commit()
 
             yield dataset
 
             # rollback
-            session.delete(dataset)
-            session.commit()
+            db.session.delete(dataset)
+            db.session.commit()
 
     @pytest.fixture()
     def create_slice(self):
         with self.create_app().app_context():
-            session = db.session
             dataset = (
-                session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
+                db.session.query(SqlaTable)
+                .filter_by(table_name="dummy_sql_table")
+                .first()
             )
             slice = Slice(
                 datasource_id=dataset.id,
@@ -67,34 +67,32 @@ class TestCreatePermalinkDataCommand(SupersetTestCase):
                 slice_name="slice_name",
             )
 
-            session.add(slice)
-            session.commit()
+            db.session.add(slice)
+            db.session.commit()
 
             yield slice
 
             # rollback
-            session.delete(slice)
-            session.commit()
+            db.session.delete(slice)
+            db.session.commit()
 
     @pytest.fixture()
     def create_query(self):
         with self.create_app().app_context():
-            session = db.session
-
             query = Query(
                 sql="select 1 as foo;",
                 client_id="sldkfjlk",
                 database=get_example_database(),
             )
 
-            session.add(query)
-            session.commit()
+            db.session.add(query)
+            db.session.commit()
 
             yield query
 
             # rollback
-            session.delete(query)
-            session.commit()
+            db.session.delete(query)
+            db.session.commit()
 
     @patch("superset.security.manager.g")
     @pytest.mark.usefixtures("create_dataset", "create_slice")
diff --git a/tests/integration_tests/fixtures/datasource.py b/tests/integration_tests/fixtures/datasource.py
index 279b67eda0..fd7c69deca 100644
--- a/tests/integration_tests/fixtures/datasource.py
+++ b/tests/integration_tests/fixtures/datasource.py
@@ -177,7 +177,6 @@ def load_dataset_with_columns() -> Generator[SqlaTable, None, None]:
     with app.app_context():
         engine = create_engine(app.config["SQLALCHEMY_DATABASE_URI"], echo=True)
         meta = MetaData()
-        session = db.session
 
         students = Table(
             "students",
@@ -196,8 +195,8 @@ def load_dataset_with_columns() -> Generator[SqlaTable, None, None]:
         )
         column = TableColumn(table_id=dataset.id, column_name="name")
         dataset.columns = [column]
-        session.add(dataset)
-        session.commit()
+        db.session.add(dataset)
+        db.session.commit()
         yield dataset
 
         # cleanup
@@ -205,8 +204,8 @@ def load_dataset_with_columns() -> Generator[SqlaTable, None, None]:
         if students_table is not None:
             base = declarative_base()
             # needed for sqlite
-            session.commit()
+            db.session.commit()
             base.metadata.drop_all(engine, [students_table], checkfirst=True)
-        session.delete(dataset)
-        session.delete(column)
-        session.commit()
+        db.session.delete(dataset)
+        db.session.delete(column)
+        db.session.commit()
diff --git a/tests/integration_tests/import_export_tests.py b/tests/integration_tests/import_export_tests.py
index adc398e785..4a1558ffd8 100644
--- a/tests/integration_tests/import_export_tests.py
+++ b/tests/integration_tests/import_export_tests.py
@@ -53,17 +53,16 @@ from .base_tests import SupersetTestCase
 def delete_imports():
     with app.app_context():
         # Imported data clean up
-        session = db.session
-        for slc in session.query(Slice):
+        for slc in db.session.query(Slice):
             if "remote_id" in slc.params_dict:
-                session.delete(slc)
-        for dash in session.query(Dashboard):
+                db.session.delete(slc)
+        for dash in db.session.query(Dashboard):
             if "remote_id" in dash.params_dict:
-                session.delete(dash)
-        for table in session.query(SqlaTable):
+                db.session.delete(dash)
+        for table in db.session.query(SqlaTable):
             if "remote_id" in table.params_dict:
-                session.delete(table)
-        session.commit()
+                db.session.delete(table)
+        db.session.commit()
 
 
 @pytest.fixture(autouse=True, scope="module")
diff --git a/tests/integration_tests/key_value/commands/fixtures.py b/tests/integration_tests/key_value/commands/fixtures.py
index ac33d003e0..6ba09c8a18 100644
--- a/tests/integration_tests/key_value/commands/fixtures.py
+++ b/tests/integration_tests/key_value/commands/fixtures.py
@@ -66,6 +66,5 @@ def key_value_entry() -> Generator[KeyValueEntry, None, None]:
 @pytest.fixture
 def admin() -> User:
     with app.app_context() as ctx:
-        session: Session = ctx.app.appbuilder.get_session
-        admin = session.query(User).filter_by(username="admin").one()
+        admin = db.session.query(User).filter_by(username="admin").one()
         return admin
diff --git a/tests/integration_tests/security/guest_token_security_tests.py b/tests/integration_tests/security/guest_token_security_tests.py
index b812929433..44a4cdd3ce 100644
--- a/tests/integration_tests/security/guest_token_security_tests.py
+++ b/tests/integration_tests/security/guest_token_security_tests.py
@@ -230,15 +230,14 @@ class TestGuestUserDatasourceAccess(SupersetTestCase):
                 schema=get_example_default_schema(),
                 sql="select 123 as intcol, 'abc' as strcol",
             )
-            session = db.session
-            session.add(dataset)
-            session.commit()
+            db.session.add(dataset)
+            db.session.commit()
 
             yield dataset
 
             # rollback
-            session.delete(dataset)
-            session.commit()
+            db.session.delete(dataset)
+            db.session.commit()
 
     def setUp(self) -> None:
         self.dash = self.get_dash_by_slug("births")
@@ -258,11 +257,9 @@ class TestGuestUserDatasourceAccess(SupersetTestCase):
                 ],
             }
         )
-        self.chart = self.get_slice("Girls", db.session, expunge_from_session=False)
+        self.chart = self.get_slice("Girls", expunge_from_session=False)
         self.datasource = self.chart.datasource
-        self.other_chart = self.get_slice(
-            "Treemap", db.session, expunge_from_session=False
-        )
+        self.other_chart = self.get_slice("Treemap", expunge_from_session=False)
         self.other_datasource = self.other_chart.datasource
         self.native_filter_datasource = (
             db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
diff --git a/tests/integration_tests/security/migrate_roles_tests.py b/tests/integration_tests/security/migrate_roles_tests.py
index 39d66a82aa..4ab73a713e 100644
--- a/tests/integration_tests/security/migrate_roles_tests.py
+++ b/tests/integration_tests/security/migrate_roles_tests.py
@@ -245,11 +245,10 @@ def test_migrate_role(
     logger.info(description)
     with create_old_role(pvm_map, external_pvms) as old_role:
         role_name = old_role.name
-        session = db.session
 
         # Run migrations
-        add_pvms(session, new_pvms)
-        migrate_roles(session, pvm_map)
+        add_pvms(db.session, new_pvms)
+        migrate_roles(db.session, pvm_map)
 
         role = db.session.query(Role).filter(Role.name == role_name).one_or_none()
         for old_pvm, new_pvms in pvm_map.items():
diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py
index 41ca0d5e79..7518621ddd 100644
--- a/tests/integration_tests/security/row_level_security_tests.py
+++ b/tests/integration_tests/security/row_level_security_tests.py
@@ -74,8 +74,6 @@ class TestRowLevelSecurity(SupersetTestCase):
     BASE_FILTER_REGEX = re.compile(r"gender = 'boy'")
 
     def setUp(self):
-        session = db.session
-
         # Create roles
         self.role_ab = security_manager.add_role(self.NAME_AB_ROLE)
         self.role_q = security_manager.add_role(self.NAME_Q_ROLE)
@@ -83,13 +81,13 @@ class TestRowLevelSecurity(SupersetTestCase):
         gamma_user.roles.append(self.role_ab)
         gamma_user.roles.append(self.role_q)
         self.create_user_with_roles("NoRlsRoleUser", ["Gamma"])
-        session.commit()
+        db.session.commit()
 
         # Create regular RowLevelSecurityFilter (energy_usage, unicode_test)
         self.rls_entry1 = RowLevelSecurityFilter()
         self.rls_entry1.name = "rls_entry1"
         self.rls_entry1.tables.extend(
-            session.query(SqlaTable)
+            db.session.query(SqlaTable)
             .filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"]))
             .all()
         )
@@ -104,7 +102,7 @@ class TestRowLevelSecurity(SupersetTestCase):
         self.rls_entry2 = RowLevelSecurityFilter()
         self.rls_entry2.name = "rls_entry2"
         self.rls_entry2.tables.extend(
-            session.query(SqlaTable)
+            db.session.query(SqlaTable)
             .filter(SqlaTable.table_name.in_(["birth_names"]))
             .all()
         )
@@ -118,7 +116,7 @@ class TestRowLevelSecurity(SupersetTestCase):
         self.rls_entry3 = RowLevelSecurityFilter()
         self.rls_entry3.name = "rls_entry3"
         self.rls_entry3.tables.extend(
-            session.query(SqlaTable)
+            db.session.query(SqlaTable)
             .filter(SqlaTable.table_name.in_(["birth_names"]))
             .all()
         )
@@ -132,7 +130,7 @@ class TestRowLevelSecurity(SupersetTestCase):
         self.rls_entry4 = RowLevelSecurityFilter()
         self.rls_entry4.name = "rls_entry4"
         self.rls_entry4.tables.extend(
-            session.query(SqlaTable)
+            db.session.query(SqlaTable)
             .filter(SqlaTable.table_name.in_(["birth_names"]))
             .all()
         )
@@ -145,15 +143,14 @@ class TestRowLevelSecurity(SupersetTestCase):
         db.session.commit()
 
     def tearDown(self):
-        session = db.session
-        session.delete(self.rls_entry1)
-        session.delete(self.rls_entry2)
-        session.delete(self.rls_entry3)
-        session.delete(self.rls_entry4)
-        session.delete(security_manager.find_role("NameAB"))
-        session.delete(security_manager.find_role("NameQ"))
-        session.delete(self.get_user("NoRlsRoleUser"))
-        session.commit()
+        db.session.delete(self.rls_entry1)
+        db.session.delete(self.rls_entry2)
+        db.session.delete(self.rls_entry3)
+        db.session.delete(self.rls_entry4)
+        db.session.delete(security_manager.find_role("NameAB"))
+        db.session.delete(security_manager.find_role("NameQ"))
+        db.session.delete(self.get_user("NoRlsRoleUser"))
+        db.session.commit()
 
     @pytest.fixture()
     def create_dataset(self):
diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py
index ece9afcccb..b1f66b0d6c 100644
--- a/tests/integration_tests/security_tests.py
+++ b/tests/integration_tests/security_tests.py
@@ -1704,11 +1704,11 @@ class TestSecurityManager(SupersetTestCase):
         mock_is_owner,
     ):
         births = self.get_dash_by_slug("births")
-        girls = self.get_slice("Girls", db.session, expunge_from_session=False)
+        girls = self.get_slice("Girls", expunge_from_session=False)
         birth_names = girls.datasource
 
         world_health = self.get_dash_by_slug("world_health")
-        treemap = self.get_slice("Treemap", db.session, expunge_from_session=False)
+        treemap = self.get_slice("Treemap", expunge_from_session=False)
 
         births.json_metadata = json.dumps(
             {
diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py
index 4410f19782..0dc4e26aca 100644
--- a/tests/integration_tests/sqllab_tests.py
+++ b/tests/integration_tests/sqllab_tests.py
@@ -434,8 +434,6 @@ class TestSqlLab(SupersetTestCase):
         Test query api with can_access_all_queries perm added to
         gamma and make sure all queries show up.
         """
-        session = db.session
-
         # Add all_query_access perm to Gamma user
         all_queries_view = security_manager.find_permission_view_menu(
             "all_query_access", "all_query_access"
@@ -444,7 +442,7 @@ class TestSqlLab(SupersetTestCase):
         security_manager.add_permission_role(
             security_manager.find_role("gamma_sqllab"), all_queries_view
         )
-        session.commit()
+        db.session.commit()
 
         # Test search_queries for Admin user
         self.run_some_queries()
@@ -461,7 +459,7 @@ class TestSqlLab(SupersetTestCase):
             security_manager.find_role("gamma_sqllab"), all_queries_view
         )
 
-        session.commit()
+        db.session.commit()
 
     def test_query_admin_can_access_all_queries(self) -> None:
         """
diff --git a/tests/integration_tests/test_jinja_context.py b/tests/integration_tests/test_jinja_context.py
index 8c2db6920d..6f776017fb 100644
--- a/tests/integration_tests/test_jinja_context.py
+++ b/tests/integration_tests/test_jinja_context.py
@@ -114,10 +114,10 @@ def test_template_hive(app_context: AppContext, mocker: MockFixture) -> None:
         "superset.jinja_context.HiveTemplateProcessor.latest_partition"
     )
     lp_mock.return_value = "the_latest"
-    db = mock.Mock()
-    db.backend = "hive"
+    database = mock.Mock()
+    database.backend = "hive"
     template = "{{ hive.latest_partition('my_table') }}"
-    tp = get_template_processor(database=db)
+    tp = get_template_processor(database=database)
     assert tp.process_template(template) == "the_latest"
 
 
@@ -126,15 +126,15 @@ def test_template_trino(app_context: AppContext, mocker: MockFixture) -> None:
         "superset.jinja_context.TrinoTemplateProcessor.latest_partition"
     )
     lp_mock.return_value = "the_latest"
-    db = mock.Mock()
-    db.backend = "trino"
+    database = mock.Mock()
+    database.backend = "trino"
     template = "{{ trino.latest_partition('my_table') }}"
-    tp = get_template_processor(database=db)
+    tp = get_template_processor(database=database)
     assert tp.process_template(template) == "the_latest"
 
     # Backwards compatibility if migrating from Presto.
     template = "{{ presto.latest_partition('my_table') }}"
-    tp = get_template_processor(database=db)
+    tp = get_template_processor(database=database)
     assert tp.process_template(template) == "the_latest"
 
 
@@ -154,9 +154,9 @@ def test_custom_process_template(app_context: AppContext, mocker: MockFixture) -
         "tests.integration_tests.superset_test_custom_template_processors.datetime"
     )
     mock_dt.utcnow = mock.Mock(return_value=datetime(1970, 1, 1))
-    db = mock.Mock()
-    db.backend = "db_for_macros_testing"
-    tp = get_template_processor(database=db)
+    database = mock.Mock()
+    database.backend = "db_for_macros_testing"
+    tp = get_template_processor(database=database)
 
     template = "SELECT '$DATE()'"
     assert tp.process_template(template) == f"SELECT '1970-01-01'"
@@ -168,28 +168,28 @@ def test_custom_process_template(app_context: AppContext, mocker: MockFixture) -
 def test_custom_get_template_kwarg(app_context: AppContext) -> None:
     """Test macro passed as kwargs when getting template processor
     works in custom template processor."""
-    db = mock.Mock()
-    db.backend = "db_for_macros_testing"
+    database = mock.Mock()
+    database.backend = "db_for_macros_testing"
     template = "$foo()"
-    tp = get_template_processor(database=db, foo=lambda: "bar")
+    tp = get_template_processor(database=database, foo=lambda: "bar")
     assert tp.process_template(template) == "bar"
 
 
 def test_custom_template_kwarg(app_context: AppContext) -> None:
     """Test macro passed as kwargs when processing template
     works in custom template processor."""
-    db = mock.Mock()
-    db.backend = "db_for_macros_testing"
+    database = mock.Mock()
+    database.backend = "db_for_macros_testing"
     template = "$foo()"
-    tp = get_template_processor(database=db)
+    tp = get_template_processor(database=database)
     assert tp.process_template(template, foo=lambda: "bar") == "bar"
 
 
 def test_custom_template_processors_overwrite(app_context: AppContext) -> None:
     """Test template processor for presto gets overwritten by custom one."""
-    db = mock.Mock()
-    db.backend = "db_for_macros_testing"
-    tp = get_template_processor(database=db)
+    database = mock.Mock()
+    database.backend = "db_for_macros_testing"
+    tp = get_template_processor(database=database)
 
     template = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
     assert tp.process_template(template) == template
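
The renames in test_jinja_context.py are not cosmetic: a local Mock named "db" shadows the superset.db module for the rest of the test body, so a later db.session access would silently hit the Mock. A contrived illustration (not code from this commit):

    from unittest import mock

    from superset import db  # the Flask-SQLAlchemy handle


    def test_no_shadowing() -> None:
        # Naming this Mock "db" would shadow the import above;
        # "database" keeps both names usable in the same scope.
        database = mock.Mock()
        database.backend = "trino"
        assert db is not database  # db still refers to superset.db
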
diff --git a/tests/integration_tests/utils/get_dashboards.py b/tests/integration_tests/utils/get_dashboards.py
index 7012bf08a0..b23b372310 100644
--- a/tests/integration_tests/utils/get_dashboards.py
+++ b/tests/integration_tests/utils/get_dashboards.py
@@ -15,12 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from flask_appbuilder import SQLA
-
+from superset import db
 from superset.models.dashboard import Dashboard
 
 
-def get_dashboards_ids(db: SQLA, dashboard_slugs: list[str]) -> list[int]:
+def get_dashboards_ids(dashboard_slugs: list[str]) -> list[int]:
     result = (
         db.session.query(Dashboard.id).filter(Dashboard.slug.in_(dashboard_slugs)).all()
     )
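
Call sites of this helper now pass only the slugs, since the session is resolved internally via superset.db. A hypothetical call, assuming the world_health dashboard fixtures are loaded:

    from tests.integration_tests.utils.get_dashboards import get_dashboards_ids

    # was: get_dashboards_ids(db, ["world_health"])
    dashboard_ids = get_dashboards_ids(["world_health"])
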
diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py
index bdbb912eec..b4ab08dc55 100644
--- a/tests/integration_tests/utils_tests.py
+++ b/tests/integration_tests/utils_tests.py
@@ -898,7 +898,7 @@ class TestUtils(SupersetTestCase):
     def test_log_this(self) -> None:
         # TODO: Add additional scenarios.
         self.login(username="admin")
-        slc = self.get_slice("Top 10 Girl Name Share", db.session)
+        slc = self.get_slice("Top 10 Girl Name Share")
         dashboard_id = 1
 
         assert slc.viz is not None
@@ -956,7 +956,7 @@ class TestUtils(SupersetTestCase):
 
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_extract_dataframe_dtypes(self):
-        slc = self.get_slice("Girls", db.session)
+        slc = self.get_slice("Girls")
         cols: tuple[tuple[str, GenericDataType, list[Any]], ...] = (
             ("dt", GenericDataType.TEMPORAL, [date(2021, 2, 4), date(2021, 2, 4)]),
             (
diff --git a/tests/unit_tests/charts/commands/importers/v1/import_test.py b/tests/unit_tests/charts/commands/importers/v1/import_test.py
index bcff3ee411..e6f6d00206 100644
--- a/tests/unit_tests/charts/commands/importers/v1/import_test.py
+++ b/tests/unit_tests/charts/commands/importers/v1/import_test.py
@@ -24,7 +24,7 @@ from flask_appbuilder.security.sqla.models import Role, User
 from pytest_mock import MockFixture
 from sqlalchemy.orm.session import Session
 
-from superset import security_manager
+from superset import db, security_manager
 from superset.commands.chart.importers.v1.utils import import_chart
 from superset.commands.exceptions import ImportFailedError
 from superset.connectors.sqla.models import Database, SqlaTable
@@ -82,7 +82,7 @@ def test_import_chart(mocker: MockFixture, session_with_schema: Session) -> None
     config["datasource_id"] = 1
     config["datasource_type"] = "table"
 
-    chart = import_chart(session_with_schema, config)
+    chart = import_chart(config)
     assert chart.slice_name == "Deck Path"
     assert chart.viz_type == "deck_path"
     assert chart.is_managed_externally is False
@@ -106,7 +106,7 @@ def test_import_chart_managed_externally(
     config["is_managed_externally"] = True
     config["external_url"] = "https://example.org/my_chart"
 
-    chart = import_chart(session_with_schema, config)
+    chart = import_chart(config)
     assert chart.is_managed_externally is True
     assert chart.external_url == "https://example.org/my_chart"
 
@@ -128,7 +128,7 @@ def test_import_chart_without_permission(
     config["datasource_type"] = "table"
 
     with pytest.raises(ImportFailedError) as excinfo:
-        import_chart(session_with_schema, config)
+        import_chart(config)
     assert (
         str(excinfo.value)
         == "Chart doesn't exist and user doesn't have permission to create charts"
@@ -173,7 +173,7 @@ def test_import_existing_chart_without_permission(
 
     with override_user("admin"):
         with pytest.raises(ImportFailedError) as excinfo:
-            import_chart(session_with_data, chart_config, overwrite=True)
+            import_chart(chart_config, overwrite=True)
         assert (
             str(excinfo.value)
             == "A chart already exists and user doesn't have permissions to overwrite it"
@@ -213,7 +213,7 @@ def test_import_existing_chart_with_permission(
     )
 
     with override_user(admin):
-        import_chart(session_with_data, config, overwrite=True)
+        import_chart(config, overwrite=True)
     # Assert that the can write to chart was checked
     security_manager.can_access.assert_called_once_with("can_write", "Chart")
     security_manager.can_access_chart.assert_called_once_with(slice)
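
The same calling convention applies throughout the importer tests: import_chart now takes only the config plus the optional overwrite flag. A hedged sketch, with a hypothetical pre-validated config dict:

    from typing import Any

    from superset.commands.chart.importers.v1.utils import import_chart


    def import_validated_chart(config: dict[str, Any]) -> str:
        # The importer reads db.session internally; no session argument.
        chart = import_chart(config, overwrite=True)
        return chart.slice_name
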
diff --git a/tests/unit_tests/charts/dao/dao_tests.py b/tests/unit_tests/charts/dao/dao_tests.py
index e8c58b5600..e811223a98 100644
--- a/tests/unit_tests/charts/dao/dao_tests.py
+++ b/tests/unit_tests/charts/dao/dao_tests.py
@@ -48,7 +48,7 @@ def test_slice_find_by_id_skip_base_filter(session_with_data: Session) -> None:
     from superset.daos.chart import ChartDAO
     from superset.models.slice import Slice
 
-    result = ChartDAO.find_by_id(1, session=session_with_data, skip_base_filter=True)
+    result = ChartDAO.find_by_id(1, skip_base_filter=True)
 
     assert result
     assert 1 == result.id
@@ -57,20 +57,18 @@ def test_slice_find_by_id_skip_base_filter(session_with_data: Session) -> None:
 
 
 def test_datasource_find_by_id_skip_base_filter_not_found(
-    session_with_data: Session,
+    session: Session,
 ) -> None:
     from superset.daos.chart import ChartDAO
 
-    result = ChartDAO.find_by_id(
-        125326326, session=session_with_data, skip_base_filter=True
-    )
+    result = ChartDAO.find_by_id(125326326, skip_base_filter=True)
     assert result is None
 
 
-def test_add_favorite(session_with_data: Session) -> None:
+def test_add_favorite(session: Session) -> None:
     from superset.daos.chart import ChartDAO
 
-    chart = ChartDAO.find_by_id(1, session=session_with_data, skip_base_filter=True)
+    chart = ChartDAO.find_by_id(1, skip_base_filter=True)
     if not chart:
         return
     assert len(ChartDAO.favorited_ids([chart])) == 0
@@ -82,10 +80,10 @@ def test_add_favorite(session_with_data: Session) -> None:
     assert len(ChartDAO.favorited_ids([chart])) == 1
 
 
-def test_remove_favorite(session_with_data: Session) -> None:
+def test_remove_favorite(session: Session) -> None:
     from superset.daos.chart import ChartDAO
 
-    chart = ChartDAO.find_by_id(1, session=session_with_data, skip_base_filter=True)
+    chart = ChartDAO.find_by_id(1, skip_base_filter=True)
     if not chart:
         return
     assert len(ChartDAO.favorited_ids([chart])) == 0
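
The DAO tests above show the same simplification on the read path: find_by_id no longer threads a session through. A minimal sketch, assuming an app context and an existing chart with id 1:

    from superset.daos.chart import ChartDAO

    # skip_base_filter bypasses the DAO's base filter, as the tests do;
    # the session argument is gone because the DAO uses db.session.
    chart = ChartDAO.find_by_id(1, skip_base_filter=True)
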
diff --git a/tests/unit_tests/charts/test_post_processing.py b/tests/unit_tests/charts/test_post_processing.py
index 945b337fad..9f8962f85c 100644
--- a/tests/unit_tests/charts/test_post_processing.py
+++ b/tests/unit_tests/charts/test_post_processing.py
@@ -1965,12 +1965,13 @@ def test_apply_post_process_json_format_data_is_none():
 
 
 def test_apply_post_process_verbose_map(session: Session):
+    from superset import db
     from superset.connectors.sqla.models import SqlaTable, SqlMetric
     from superset.models.core import Database
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
-    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
     sqla_table = SqlaTable(
         table_name="my_sqla_table",
         columns=[],
@@ -1982,7 +1983,7 @@ def test_apply_post_process_verbose_map(session: Session):
                 expression="COUNT(*)",
             )
         ],
-        database=db,
+        database=database,
     )
 
     result = {
diff --git a/tests/unit_tests/columns/test_models.py b/tests/unit_tests/columns/test_models.py
index 068557e7a6..0ea230da17 100644
--- a/tests/unit_tests/columns/test_models.py
+++ b/tests/unit_tests/columns/test_models.py
@@ -24,9 +24,10 @@ def test_column_model(session: Session) -> None:
     """
     Test basic attributes of a ``Column``.
     """
+    from superset import db
     from superset.columns.models import Column
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     Column.metadata.create_all(engine)  # pylint: disable=no-member
 
     column = Column(
@@ -35,8 +36,8 @@ def test_column_model(session: Session) -> None:
         expression="ds",
     )
 
-    session.add(column)
-    session.flush()
+    db.session.add(column)
+    db.session.flush()
 
     assert column.id == 1
     assert column.uuid is not None
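
The unit tests that touch models now bootstrap their tables off the engine behind db.session instead of a locally passed session. A minimal sketch of that setup step (illustrative only):

    from superset import db
    from superset.columns.models import Column


    def create_column_tables() -> None:
        # In the unit-test suite this engine is an in-memory SQLite
        # database; create the Column model's tables on it.
        engine = db.session.get_bind()
        Column.metadata.create_all(engine)
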
diff --git a/tests/unit_tests/commands/importers/v1/assets_test.py b/tests/unit_tests/commands/importers/v1/assets_test.py
index d48eed1be7..9609b0b45c 100644
--- a/tests/unit_tests/commands/importers/v1/assets_test.py
+++ b/tests/unit_tests/commands/importers/v1/assets_test.py
@@ -35,14 +35,14 @@ def test_import_new_assets(mocker: MockFixture, session: Session) -> None:
     """
     Test that all new assets are imported correctly.
     """
-    from superset import security_manager
+    from superset import db, security_manager
     from superset.commands.importers.v1.assets import ImportAssetsCommand
     from superset.models.dashboard import dashboard_slices
     from superset.models.slice import Slice
 
     mocker.patch.object(security_manager, "can_access", return_value=True)
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     Slice.metadata.create_all(engine)  # pylint: disable=no-member
     configs = {
         **copy.deepcopy(databases_config),
@@ -53,11 +53,11 @@ def test_import_new_assets(mocker: MockFixture, session: Session) -> None:
     expected_number_of_dashboards = len(dashboards_config_1)
     expected_number_of_charts = len(charts_config_1)
 
-    ImportAssetsCommand._import(session, configs)
-    dashboard_ids = session.scalars(
+    ImportAssetsCommand._import(configs)
+    dashboard_ids = db.session.scalars(
         select(dashboard_slices.c.dashboard_id).distinct()
     ).all()
-    chart_ids = session.scalars(select(dashboard_slices.c.slice_id)).all()
+    chart_ids = db.session.scalars(select(dashboard_slices.c.slice_id)).all()
 
     assert len(chart_ids) == expected_number_of_charts
     assert len(dashboard_ids) == expected_number_of_dashboards
@@ -67,14 +67,14 @@ def test_import_adds_dashboard_charts(mocker: MockFixture, session: Session) ->
     """
     Test that existing dashboards are updated with new charts.
     """
-    from superset import security_manager
+    from superset import db, security_manager
     from superset.commands.importers.v1.assets import ImportAssetsCommand
     from superset.models.dashboard import dashboard_slices
     from superset.models.slice import Slice
 
     mocker.patch.object(security_manager, "can_access", return_value=True)
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     Slice.metadata.create_all(engine)  # pylint: disable=no-member
     base_configs = {
         **copy.deepcopy(databases_config),
@@ -91,12 +91,12 @@ def test_import_adds_dashboard_charts(mocker: MockFixture, session: Session) ->
     expected_number_of_dashboards = len(dashboards_config_1)
     expected_number_of_charts = len(charts_config_1)
 
-    ImportAssetsCommand._import(session, base_configs)
-    ImportAssetsCommand._import(session, new_configs)
-    dashboard_ids = session.scalars(
+    ImportAssetsCommand._import(base_configs)
+    ImportAssetsCommand._import(new_configs)
+    dashboard_ids = db.session.scalars(
         select(dashboard_slices.c.dashboard_id).distinct()
     ).all()
-    chart_ids = session.scalars(select(dashboard_slices.c.slice_id)).all()
+    chart_ids = db.session.scalars(select(dashboard_slices.c.slice_id)).all()
 
     assert len(chart_ids) == expected_number_of_charts
     assert len(dashboard_ids) == expected_number_of_dashboards
@@ -106,14 +106,14 @@ def test_import_removes_dashboard_charts(mocker: MockFixture, session: Session)
     """
     Test that existing dashboards are updated without old charts.
     """
-    from superset import security_manager
+    from superset import db, security_manager
     from superset.commands.importers.v1.assets import ImportAssetsCommand
     from superset.models.dashboard import dashboard_slices
     from superset.models.slice import Slice
 
     mocker.patch.object(security_manager, "can_access", return_value=True)
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     Slice.metadata.create_all(engine)  # pylint: disable=no-member
     base_configs = {
         **copy.deepcopy(databases_config),
@@ -130,12 +130,12 @@ def test_import_removes_dashboard_charts(mocker: MockFixture, session: Session)
     expected_number_of_dashboards = len(dashboards_config_2)
     expected_number_of_charts = len(charts_config_2)
 
-    ImportAssetsCommand._import(session, base_configs)
-    ImportAssetsCommand._import(session, new_configs)
-    dashboard_ids = session.scalars(
+    ImportAssetsCommand._import(base_configs)
+    ImportAssetsCommand._import(new_configs)
+    dashboard_ids = db.session.scalars(
         select(dashboard_slices.c.dashboard_id).distinct()
     ).all()
-    chart_ids = session.scalars(select(dashboard_slices.c.slice_id)).all()
+    chart_ids = db.session.scalars(select(dashboard_slices.c.slice_id)).all()
 
     assert len(chart_ids) == expected_number_of_charts
     assert len(dashboard_ids) == expected_number_of_dashboards
diff --git a/tests/unit_tests/config_test.py b/tests/unit_tests/config_test.py
index a69d9eaede..837c53ec07 100644
--- a/tests/unit_tests/config_test.py
+++ b/tests/unit_tests/config_test.py
@@ -23,6 +23,8 @@ import pytest
 from pytest_mock import MockerFixture
 from sqlalchemy.orm.session import Session
 
+from superset import db
+
 if TYPE_CHECKING:
     from superset.connectors.sqla.models import SqlaTable
 
@@ -81,7 +83,7 @@ def test_table(session: Session) -> "SqlaTable":
     from superset.connectors.sqla.models import SqlaTable, TableColumn
     from superset.models.core import Database
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
 
     columns = [
diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py
index beb4e99472..3824d7b7b7 100644
--- a/tests/unit_tests/conftest.py
+++ b/tests/unit_tests/conftest.py
@@ -41,7 +41,7 @@ from superset.initialization import SupersetAppInitializer
 @pytest.fixture
 def get_session(mocker: MockFixture) -> Callable[[], Session]:
     """
-    Create an in-memory SQLite session to test models.
+    Create an in-memory SQLite db.session to test models.

     """
     engine = create_engine("sqlite://")
 
@@ -49,7 +49,7 @@ def get_session(mocker: MockFixture) -> Callable[[], Session]:
         Session_ = sessionmaker(bind=engine)  # pylint: disable=invalid-name
         in_memory_session = Session_()
 
-        # flask calls session.remove()
+        # flask calls db.session.remove()
         in_memory_session.remove = lambda: None
 
         # patch session
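
For reference, the mechanism this fixture relies on looks roughly like the following, assuming pytest-mock's mocker (names are illustrative, not code from this commit):

    from pytest_mock import MockFixture
    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session, sessionmaker

    from superset import db


    def make_in_memory_session(mocker: MockFixture) -> Session:
        engine = create_engine("sqlite://")
        in_memory_session = sessionmaker(bind=engine)()

        # Flask-SQLAlchemy teardown calls session.remove(); a plain
        # Session lacks that method, so stub it out.
        in_memory_session.remove = lambda: None  # type: ignore[attr-defined]

        # Route db.session accesses in code under test to the
        # throwaway SQLite session.
        mocker.patch.object(db, "session", in_memory_session)
        return in_memory_session
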
diff --git a/tests/unit_tests/dao/dataset_test.py b/tests/unit_tests/dao/dataset_test.py
index 288f68cae0..1e3d1ec975 100644
--- a/tests/unit_tests/dao/dataset_test.py
+++ b/tests/unit_tests/dao/dataset_test.py
@@ -27,6 +27,7 @@ def test_validate_update_uniqueness(session: Session) -> None:
     In particular, allow datasets with the same name in the same database as long as they
     are in different schemas
     """
+    from superset import db
     from superset.connectors.sqla.models import SqlaTable
     from superset.models.core import Database
 
@@ -46,8 +47,8 @@ def test_validate_update_uniqueness(session: Session) -> None:
         schema="dev",
         database=database,
     )
-    session.add_all([database, dataset1, dataset2])
-    session.flush()
+    db.session.add_all([database, dataset1, dataset2])
+    db.session.flush()
 
     # same table name, different schema
     assert (
diff --git a/tests/unit_tests/dao/queries_test.py b/tests/unit_tests/dao/queries_test.py
index 65e9bbfbfb..eb84b288fd 100644
--- a/tests/unit_tests/dao/queries_test.py
+++ b/tests/unit_tests/dao/queries_test.py
@@ -25,17 +25,18 @@ from superset.exceptions import QueryNotFoundException, SupersetCancelQueryExcep
 
 
 def test_query_dao_save_metadata(session: Session) -> None:
+    from superset import db
     from superset.models.core import Database
     from superset.models.sql_lab import Query
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     Query.metadata.create_all(engine)  # pylint: disable=no-member
 
-    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
 
     query_obj = Query(
         client_id="foo",
-        database=db,
+        database=database,
         tab_name="test_tab",
         sql_editor_id="test_editor_id",
         sql="select * from bar",
@@ -48,30 +49,31 @@ def test_query_dao_save_metadata(session: Session) -> None:
         results_key="abc",
     )
 
-    session.add(db)
-    session.add(query_obj)
+    db.session.add(database)
+    db.session.add(query_obj)
 
     from superset.daos.query import QueryDAO
 
-    query = session.query(Query).one()
+    query = db.session.query(Query).one()
     QueryDAO.save_metadata(query=query, payload={"columns": []})
     assert query.extra.get("columns", None) == []
 
 
 def test_query_dao_get_queries_changed_after(session: Session) -> None:
+    from superset import db
     from superset.models.core import Database
     from superset.models.sql_lab import Query
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     Query.metadata.create_all(engine)  # pylint: disable=no-member
 
-    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
 
     now = datetime.utcnow()
 
     old_query_obj = Query(
         client_id="foo",
-        database=db,
+        database=database,
         tab_name="test_tab",
         sql_editor_id="test_editor_id",
         sql="select * from bar",
@@ -87,7 +89,7 @@ def test_query_dao_get_queries_changed_after(session: Session) -> None:
 
     updated_query_obj = Query(
         client_id="updated_foo",
-        database=db,
+        database=database,
         tab_name="test_tab",
         sql_editor_id="test_editor_id",
         sql="select * from foo",
@@ -101,9 +103,9 @@ def test_query_dao_get_queries_changed_after(session: Session) -> None:
         changed_on=now - timedelta(days=1),
     )
 
-    session.add(db)
-    session.add(old_query_obj)
-    session.add(updated_query_obj)
+    db.session.add(database)
+    db.session.add(old_query_obj)
+    db.session.add(updated_query_obj)
 
     from superset.daos.query import QueryDAO
 
@@ -116,18 +118,19 @@ def test_query_dao_get_queries_changed_after(session: Session) -> None:
 def test_query_dao_stop_query_not_found(
     mocker: MockFixture, app: Any, session: Session
 ) -> None:
+    from superset import db
     from superset.common.db_query_status import QueryStatus
     from superset.models.core import Database
     from superset.models.sql_lab import Query
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     Query.metadata.create_all(engine)  # pylint: disable=no-member
 
-    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
 
     query_obj = Query(
         client_id="foo",
-        database=db,
+        database=database,
         tab_name="test_tab",
         sql_editor_id="test_editor_id",
         sql="select * from bar",
@@ -141,8 +144,8 @@ def test_query_dao_stop_query_not_found(
         status=QueryStatus.RUNNING,
     )
 
-    session.add(db)
-    session.add(query_obj)
+    db.session.add(database)
+    db.session.add(query_obj)
 
     mocker.patch("superset.sql_lab.cancel_query", return_value=False)
 
@@ -151,25 +154,26 @@ def test_query_dao_stop_query_not_found(
     with pytest.raises(QueryNotFoundException):
         QueryDAO.stop_query("foo2")
 
-    query = session.query(Query).one()
+    query = db.session.query(Query).one()
     assert query.status == QueryStatus.RUNNING
 
 
 def test_query_dao_stop_query_not_running(
     mocker: MockFixture, app: Any, session: Session
 ) -> None:
+    from superset import db
     from superset.common.db_query_status import QueryStatus
     from superset.models.core import Database
     from superset.models.sql_lab import Query
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     Query.metadata.create_all(engine)  # pylint: disable=no-member
 
-    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
 
     query_obj = Query(
         client_id="foo",
-        database=db,
+        database=database,
         tab_name="test_tab",
         sql_editor_id="test_editor_id",
         sql="select * from bar",
@@ -183,31 +187,32 @@ def test_query_dao_stop_query_not_running(
         status=QueryStatus.FAILED,
     )
 
-    session.add(db)
-    session.add(query_obj)
+    db.session.add(database)
+    db.session.add(query_obj)
 
     from superset.daos.query import QueryDAO
 
     QueryDAO.stop_query(query_obj.client_id)
-    query = session.query(Query).one()
+    query = db.session.query(Query).one()
     assert query.status == QueryStatus.FAILED
 
 
 def test_query_dao_stop_query_failed(
     mocker: MockFixture, app: Any, session: Session
 ) -> None:
+    from superset import db
     from superset.common.db_query_status import QueryStatus
     from superset.models.core import Database
     from superset.models.sql_lab import Query
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     Query.metadata.create_all(engine)  # pylint: disable=no-member
 
-    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
 
     query_obj = Query(
         client_id="foo",
-        database=db,
+        database=database,
         tab_name="test_tab",
         sql_editor_id="test_editor_id",
         sql="select * from bar",
@@ -221,8 +226,8 @@ def test_query_dao_stop_query_failed(
         status=QueryStatus.RUNNING,
     )
 
-    session.add(db)
-    session.add(query_obj)
+    db.session.add(database)
+    db.session.add(query_obj)
 
     mocker.patch("superset.sql_lab.cancel_query", return_value=False)
 
@@ -231,23 +236,24 @@ def test_query_dao_stop_query_failed(
     with pytest.raises(SupersetCancelQueryException):
         QueryDAO.stop_query(query_obj.client_id)
 
-    query = session.query(Query).one()
+    query = db.session.query(Query).one()
     assert query.status == QueryStatus.RUNNING
 
 
 def test_query_dao_stop_query(mocker: MockFixture, app: Any, session: Session) -> None:
+    from superset import db
     from superset.common.db_query_status import QueryStatus
     from superset.models.core import Database
     from superset.models.sql_lab import Query
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     Query.metadata.create_all(engine)  # pylint: disable=no-member
 
-    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
 
     query_obj = Query(
         client_id="foo",
-        database=db,
+        database=database,
         tab_name="test_tab",
         sql_editor_id="test_editor_id",
         sql="select * from bar",
@@ -261,13 +267,13 @@ def test_query_dao_stop_query(mocker: MockFixture, app: Any, session: Session) -
         status=QueryStatus.RUNNING,
     )
 
-    session.add(db)
-    session.add(query_obj)
+    db.session.add(database)
+    db.session.add(query_obj)
 
     mocker.patch("superset.sql_lab.cancel_query", return_value=True)
 
     from superset.daos.query import QueryDAO
 
     QueryDAO.stop_query(query_obj.client_id)
-    query = session.query(Query).one()
+    query = db.session.query(Query).one()
     assert query.status == QueryStatus.STOPPED
diff --git a/tests/unit_tests/dao/tag_test.py b/tests/unit_tests/dao/tag_test.py
index 5f29d0f28c..652d3729b7 100644
--- a/tests/unit_tests/dao/tag_test.py
+++ b/tests/unit_tests/dao/tag_test.py
@@ -70,7 +70,7 @@ def test_remove_user_favorite_tag(mocker):
     # Check that users_favorited no longer contains the user
     assert mock_user not in mock_tag.users_favorited
 
-    # Check that the session was committed
+    # Check that the db.session was committed
     mock_session.commit.assert_called_once()
 
 
diff --git a/tests/unit_tests/dashboards/commands/importers/v1/import_test.py b/tests/unit_tests/dashboards/commands/importers/v1/import_test.py
index afbce49cd9..ac3d2a919b 100644
--- a/tests/unit_tests/dashboards/commands/importers/v1/import_test.py
+++ b/tests/unit_tests/dashboards/commands/importers/v1/import_test.py
@@ -24,7 +24,7 @@ from flask_appbuilder.security.sqla.models import Role, User
 from pytest_mock import MockFixture
 from sqlalchemy.orm.session import Session
 
-from superset import security_manager
+from superset import db, security_manager
 from superset.commands.dashboard.importers.v1.utils import import_dashboard
 from superset.commands.exceptions import ImportFailedError
 from superset.models.dashboard import Dashboard
@@ -67,7 +67,7 @@ def test_import_dashboard(mocker: MockFixture, session_with_schema: Session) ->
     """
     mocker.patch.object(security_manager, "can_access", return_value=True)
 
-    dashboard = import_dashboard(session_with_schema, dashboard_config)
+    dashboard = import_dashboard(dashboard_config)
     assert dashboard.dashboard_title == "Test dash"
     assert dashboard.description is None
     assert dashboard.is_managed_externally is False
@@ -88,8 +88,7 @@ def test_import_dashboard_managed_externally(
     config = copy.deepcopy(dashboard_config)
     config["is_managed_externally"] = True
     config["external_url"] = "https://example.org/my_dashboard"
-
-    dashboard = import_dashboard(session_with_schema, config)
+    dashboard = import_dashboard(config)
     assert dashboard.is_managed_externally is True
     assert dashboard.external_url == "https://example.org/my_dashboard"
 
@@ -107,7 +106,7 @@ def test_import_dashboard_without_permission(
     mocker.patch.object(security_manager, "can_access", return_value=False)
 
     with pytest.raises(ImportFailedError) as excinfo:
-        import_dashboard(session_with_schema, dashboard_config)
+        import_dashboard(dashboard_config)
     assert (
         str(excinfo.value)
         == "Dashboard doesn't exist and user doesn't have permission to create dashboards"
@@ -135,7 +134,7 @@ def test_import_existing_dashboard_without_permission(
 
     with override_user("admin"):
         with pytest.raises(ImportFailedError) as excinfo:
-            import_dashboard(session_with_data, dashboard_config, overwrite=True)
+            import_dashboard(dashboard_config, overwrite=True)
         assert (
             str(excinfo.value)
             == "A dashboard already exists and user doesn't have permissions to overwrite it"
@@ -171,7 +170,8 @@ def test_import_existing_dashboard_with_permission(
     )
 
     with override_user(admin):
-        import_dashboard(session_with_data, dashboard_config, overwrite=True)
+        import_dashboard(dashboard_config, overwrite=True)
+
     # Assert that the can write to dashboard was checked
     security_manager.can_access.assert_called_once_with("can_write", "Dashboard")
     security_manager.can_access_dashboard.assert_called_once_with(dashboard)
diff --git a/tests/unit_tests/dashboards/dao_tests.py b/tests/unit_tests/dashboards/dao_tests.py
index 3bf4038f16..09edfacd44 100644
--- a/tests/unit_tests/dashboards/dao_tests.py
+++ b/tests/unit_tests/dashboards/dao_tests.py
@@ -42,12 +42,10 @@ def session_with_data(session: Session) -> Iterator[Session]:
     session.rollback()
 
 
-def test_add_favorite(session_with_data: Session) -> None:
+def test_add_favorite(session: Session) -> None:
     from superset.daos.dashboard import DashboardDAO
 
-    dashboard = DashboardDAO.find_by_id(
-        100, session=session_with_data, skip_base_filter=True
-    )
+    dashboard = DashboardDAO.find_by_id(100, skip_base_filter=True)
     if not dashboard:
         return
     assert len(DashboardDAO.favorited_ids([dashboard])) == 0
@@ -59,12 +57,10 @@ def test_add_favorite(session_with_data: Session) -> None:
     assert len(DashboardDAO.favorited_ids([dashboard])) == 1
 
 
-def test_remove_favorite(session_with_data: Session) -> None:
+def test_remove_favorite(session: Session) -> None:
     from superset.daos.dashboard import DashboardDAO
 
-    dashboard = DashboardDAO.find_by_id(
-        100, session=session_with_data, skip_base_filter=True
-    )
+    dashboard = DashboardDAO.find_by_id(100, skip_base_filter=True)
     if not dashboard:
         return
     assert len(DashboardDAO.favorited_ids([dashboard])) == 0
diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py
index cf3e64c306..f867f82a98 100644
--- a/tests/unit_tests/databases/api_test.py
+++ b/tests/unit_tests/databases/api_test.py
@@ -28,6 +28,8 @@ from flask import current_app
 from pytest_mock import MockFixture
 from sqlalchemy.orm.session import Session
 
+from superset import db
+
 
 def test_filter_by_uuid(
     session: Session,
@@ -49,14 +51,14 @@ def test_filter_by_uuid(
 
     # create table for databases
     Database.metadata.create_all(session.get_bind())  # pylint: disable=no-member
-    session.add(
+    db.session.add(
         Database(
             database_name="my_db",
             sqlalchemy_uri="sqlite://",
             uuid=UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"),
         )
     )
-    session.commit()
+    db.session.commit()
 
     response = client.get(
         "/api/v1/database/?q=(filters:!((col:uuid,opr:eq,value:"
@@ -96,7 +98,7 @@ def test_post_with_uuid(
     payload = response.json
     assert payload["result"]["uuid"] == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"
 
-    database = session.query(Database).one()
+    database = db.session.query(Database).one()
     assert database.uuid == UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb")
 
 
@@ -139,8 +141,8 @@ def test_password_mask(
             }
         ),
     )
-    session.add(database)
-    session.commit()
+    db.session.add(database)
+    db.session.commit()
 
     # mock the lookup so that we don't need to include the driver
     mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
@@ -195,8 +197,8 @@ def test_database_connection(
             }
         ),
     )
-    session.add(database)
-    session.commit()
+    db.session.add(database)
+    db.session.commit()
 
     # mock the lookup so that we don't need to include the driver
     mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
@@ -331,8 +333,8 @@ def test_update_with_password_mask(
             }
         ),
     )
-    session.add(database)
-    session.commit()
+    db.session.add(database)
+    db.session.commit()
 
     client.put(
         "/api/v1/database/1",
@@ -347,7 +349,7 @@ def test_update_with_password_mask(
             ),
         },
     )
-    database = session.query(Database).one()
+    database = db.session.query(Database).one()
     assert (
         database.encrypted_extra
         == '{"service_account_info": {"project_id": "yellow-unicorn-314419", "private_key": "SECRET"}}'
@@ -429,8 +431,8 @@ def test_delete_ssh_tunnel(
                 }
             ),
         )
-        session.add(database)
-        session.commit()
+        db.session.add(database)
+        db.session.commit()
 
         # mock the lookup so that we don't need to include the driver
         mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
@@ -446,8 +448,8 @@ def test_delete_ssh_tunnel(
             database=database,
         )
 
-        session.add(tunnel)
-        session.commit()
+        db.session.add(tunnel)
+        db.session.commit()
 
         # Get our recently created SSHTunnel
         response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
@@ -505,8 +507,8 @@ def test_delete_ssh_tunnel_not_found(
                 }
             ),
         )
-        session.add(database)
-        session.commit()
+        db.session.add(database)
+        db.session.commit()
 
         # mock the lookup so that we don't need to include the driver
         mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
@@ -522,8 +524,8 @@ def test_delete_ssh_tunnel_not_found(
             database=database,
         )
 
-        session.add(tunnel)
-        session.commit()
+        db.session.add(tunnel)
+        db.session.commit()
 
         # Delete the recently created SSHTunnel
         response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/")
@@ -576,8 +578,8 @@ def test_apply_dynamic_database_filter(
                 }
             ),
         )
-        session.add(database)
-        session.commit()
+        db.session.add(database)
+        db.session.commit()
 
         # Create our Second Database
         database = Database(
@@ -592,8 +594,8 @@ def test_apply_dynamic_database_filter(
                 }
             ),
         )
-        session.add(database)
-        session.commit()
+        db.session.add(database)
+        db.session.commit()
 
         # mock the lookup so that we don't need to include the driver
         mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
diff --git a/tests/unit_tests/databases/commands/importers/v1/import_test.py b/tests/unit_tests/databases/commands/importers/v1/import_test.py
index 5fb4d12ce5..ad18f0157c 100644
--- a/tests/unit_tests/databases/commands/importers/v1/import_test.py
+++ b/tests/unit_tests/databases/commands/importers/v1/import_test.py
@@ -23,6 +23,7 @@ import pytest
 from pytest_mock import MockFixture
 from sqlalchemy.orm.session import Session
 
+from superset import db
 from superset.commands.exceptions import ImportFailedError
 
 
@@ -37,11 +38,11 @@ def test_import_database(mocker: MockFixture, session: Session) -> None:
 
     mocker.patch.object(security_manager, "can_access", return_value=True)
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     Database.metadata.create_all(engine)  # pylint: disable=no-member
 
     config = copy.deepcopy(database_config)
-    database = import_database(session, config)
+    database = import_database(config)
     assert database.database_name == "imported_database"
     assert database.sqlalchemy_uri == "someengine://user:pass@host1"
     assert database.cache_timeout is None
@@ -60,9 +61,9 @@ def test_import_database(mocker: MockFixture, session: Session) -> None:
     # missing
     config = copy.deepcopy(database_config)
     del config["allow_dml"]
-    session.delete(database)
-    session.flush()
-    database = import_database(session, config)
+    db.session.delete(database)
+    db.session.flush()
+    database = import_database(config)
     assert database.allow_dml is False
 
 
@@ -78,12 +79,12 @@ def test_import_database_sqlite_invalid(mocker: MockFixture, session: Session) -
     app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True
     mocker.patch.object(security_manager, "can_access", return_value=True)
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     Database.metadata.create_all(engine)  # pylint: disable=no-member
 
     config = copy.deepcopy(database_config_sqlite)
     with pytest.raises(ImportFailedError) as excinfo:
-        _ = import_database(session, config)
+        _ = import_database(config)
     assert (
         str(excinfo.value)
         == "SQLiteDialect_pysqlite cannot be used as a data source for security reasons."
@@ -106,14 +107,14 @@ def test_import_database_managed_externally(
 
     mocker.patch.object(security_manager, "can_access", return_value=True)
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     Database.metadata.create_all(engine)  # pylint: disable=no-member
 
     config = copy.deepcopy(database_config)
     config["is_managed_externally"] = True
     config["external_url"] = "https://example.org/my_database"
 
-    database = import_database(session, config)
+    database = import_database(config)
     assert database.is_managed_externally is True
     assert database.external_url == "https://example.org/my_database"
 
@@ -132,13 +133,13 @@ def test_import_database_without_permission(
 
     mocker.patch.object(security_manager, "can_access", return_value=False)
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     Database.metadata.create_all(engine)  # pylint: disable=no-member
 
     config = copy.deepcopy(database_config)
 
     with pytest.raises(ImportFailedError) as excinfo:
-        import_database(session, config)
+        import_database(config)
     assert (
         str(excinfo.value)
         == "Database doesn't exist and user doesn't have permission to create databases"
@@ -156,10 +157,10 @@ def test_import_database_with_version(mocker: MockFixture, session: Session) ->
 
     mocker.patch.object(security_manager, "can_access", return_value=True)
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     Database.metadata.create_all(engine)  # pylint: disable=no-member
 
     config = copy.deepcopy(database_config)
     config["extra"]["version"] = "1.1.1"
-    database = import_database(session, config)
+    database = import_database(config)
     assert json.loads(database.extra)["version"] == "1.1.1"
diff --git a/tests/unit_tests/databases/dao/dao_tests.py b/tests/unit_tests/databases/dao/dao_tests.py
index b792a65336..a826d01be8 100644
--- a/tests/unit_tests/databases/dao/dao_tests.py
+++ b/tests/unit_tests/databases/dao/dao_tests.py
@@ -30,19 +30,19 @@ def session_with_data(session: Session) -> Iterator[Session]:
     engine = session.get_bind()
     SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
 
-    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
     sqla_table = SqlaTable(
         table_name="my_sqla_table",
         columns=[],
         metrics=[],
-        database=db,
+        database=database,
     )
     ssh_tunnel = SSHTunnel(
-        database_id=db.id,
-        database=db,
+        database_id=database.id,
+        database=database,
     )
 
-    session.add(db)
+    session.add(database)
     session.add(sqla_table)
     session.add(ssh_tunnel)
     session.flush()
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
index 1777bdc2e1..4b05cce637 100644
--- a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
+++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
@@ -27,17 +27,17 @@ def test_create_ssh_tunnel_command() -> None:
     from superset.databases.ssh_tunnel.models import SSHTunnel
     from superset.models.core import Database
 
-    db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
+    database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
 
     properties = {
-        "database_id": db.id,
+        "database_id": database.id,
         "server_address": "123.132.123.1",
         "server_port": "3005",
         "username": "foo",
         "password": "bar",
     }
 
-    result = CreateSSHTunnelCommand(db, properties).run()
+    result = CreateSSHTunnelCommand(database, properties).run()
 
     assert result is not None
     assert isinstance(result, SSHTunnel)
@@ -48,19 +48,19 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
     from superset.databases.ssh_tunnel.models import SSHTunnel
     from superset.models.core import Database
 
-    db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
+    database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
 
     # If we are trying to create a tunnel with a private_key_password
     # then a private_key is mandatory
     properties = {
-        "database": db,
+        "database": database,
         "server_address": "123.132.123.1",
         "server_port": "3005",
         "username": "foo",
         "private_key_password": "bar",
     }
 
-    command = CreateSSHTunnelCommand(db, properties)
+    command = CreateSSHTunnelCommand(database, properties)
 
     with pytest.raises(SSHTunnelInvalidError) as excinfo:
         command.run()
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py
index 14838ddc58..78f9c1142c 100644
--- a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py
+++ b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py
@@ -31,19 +31,19 @@ def session_with_data(session: Session) -> Iterator[Session]:
     engine = session.get_bind()
     SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
 
-    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
     sqla_table = SqlaTable(
         table_name="my_sqla_table",
         columns=[],
         metrics=[],
-        database=db,
+        database=database,
     )
     ssh_tunnel = SSHTunnel(
-        database_id=db.id,
-        database=db,
+        database_id=database.id,
+        database=database,
     )
 
-    session.add(db)
+    session.add(database)
     session.add(sqla_table)
     session.add(ssh_tunnel)
     session.flush()
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py
index 5c3907b016..54e54d05da 100644
--- a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py
+++ b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py
@@ -32,16 +32,18 @@ def session_with_data(session: Session) -> Iterator[Session]:
     engine = session.get_bind()
     SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
 
-    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
     sqla_table = SqlaTable(
         table_name="my_sqla_table",
         columns=[],
         metrics=[],
-        database=db,
+        database=database,
+    )
+    ssh_tunnel = SSHTunnel(
+        database_id=database.id, database=database, server_address="Test"
     )
-    ssh_tunnel = SSHTunnel(database_id=db.id, database=db, server_address="Test")
 
-    session.add(db)
+    session.add(database)
     session.add(sqla_table)
     session.add(ssh_tunnel)
     session.flush()
diff --git a/tests/unit_tests/databases/ssh_tunnel/dao_tests.py b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py
index 7a88807597..4646e12c1f 100644
--- a/tests/unit_tests/databases/ssh_tunnel/dao_tests.py
+++ b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py
@@ -25,11 +25,11 @@ def test_create_ssh_tunnel():
     from superset.databases.ssh_tunnel.models import SSHTunnel
     from superset.models.core import Database
 
-    db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
+    database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
 
     result = SSHTunnelDAO.create(
         attributes={
-            "database_id": db.id,
+            "database_id": database.id,
             "server_address": "123.132.123.1",
             "server_port": "3005",
             "username": "foo",
diff --git a/tests/unit_tests/datasets/api_tests.py b/tests/unit_tests/datasets/api_tests.py
index de93720fa6..e0786afaa3 100644
--- a/tests/unit_tests/datasets/api_tests.py
+++ b/tests/unit_tests/datasets/api_tests.py
@@ -19,6 +19,8 @@ from typing import Any
 
 from sqlalchemy.orm.session import Session
 
+from superset import db
+
 
 def test_put_invalid_dataset(
     session: Session,
@@ -31,7 +33,7 @@ def test_put_invalid_dataset(
     from superset.connectors.sqla.models import SqlaTable
     from superset.models.core import Database
 
-    SqlaTable.metadata.create_all(session.get_bind())
+    SqlaTable.metadata.create_all(db.session.get_bind())
 
     database = Database(
         database_name="my_db",
@@ -41,8 +43,8 @@ def test_put_invalid_dataset(
         table_name="test_put_invalid_dataset",
         database=database,
     )
-    session.add(dataset)
-    session.flush()
+    db.session.add(dataset)
+    db.session.flush()
 
     response = client.put(
         "/api/v1/dataset/1",
diff --git a/tests/unit_tests/datasets/commands/export_test.py b/tests/unit_tests/datasets/commands/export_test.py
index 20565da5bc..73f383859b 100644
--- a/tests/unit_tests/datasets/commands/export_test.py
+++ b/tests/unit_tests/datasets/commands/export_test.py
@@ -20,6 +20,8 @@ import json
 
 from sqlalchemy.orm.session import Session
 
+from superset import db
+
 
 def test_export(session: Session) -> None:
     """
@@ -29,12 +31,12 @@ def test_export(session: Session) -> None:
     from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
     from superset.models.core import Database
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
 
     database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
-    session.add(database)
-    session.flush()
+    db.session.add(database)
+    db.session.flush()
 
     columns = [
         TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
diff --git a/tests/unit_tests/datasets/commands/importers/v1/import_test.py b/tests/unit_tests/datasets/commands/importers/v1/import_test.py
index 5089838e69..a7660d6c0b 100644
--- a/tests/unit_tests/datasets/commands/importers/v1/import_test.py
+++ b/tests/unit_tests/datasets/commands/importers/v1/import_test.py
@@ -28,6 +28,7 @@ from flask import current_app
 from pytest_mock import MockFixture
 from sqlalchemy.orm.session import Session
 
+from superset import db
 from superset.commands.dataset.exceptions import (
     DatasetForbiddenDataURI,
     ImportFailedError,
@@ -46,12 +47,12 @@ def test_import_dataset(mocker: MockFixture, session: Session) -> None:
 
     mocker.patch.object(security_manager, "can_access", return_value=True)
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
 
     database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
-    session.add(database)
-    session.flush()
+    db.session.add(database)
+    db.session.flush()
 
     dataset_uuid = uuid.uuid4()
     config = {
@@ -108,7 +109,7 @@ def test_import_dataset(mocker: MockFixture, session: Session) -> None:
         "database_id": database.id,
     }
 
-    sqla_table = import_dataset(session, config)
+    sqla_table = import_dataset(config)
     assert sqla_table.table_name == "my_table"
     assert sqla_table.main_dttm_col == "ds"
     assert sqla_table.description == "This is the description"
@@ -162,23 +163,23 @@ def test_import_dataset_duplicate_column(mocker: MockFixture, session: Session)
 
     mocker.patch.object(security_manager, "can_access", return_value=True)
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
 
     dataset_uuid = uuid.uuid4()
 
     database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
 
-    session.add(database)
-    session.flush()
+    db.session.add(database)
+    db.session.flush()
 
     dataset = SqlaTable(
         uuid=dataset_uuid, table_name="existing_dataset", database_id=database.id
     )
     column = TableColumn(column_name="existing_column")
-    session.add(dataset)
-    session.add(column)
-    session.flush()
+    db.session.add(dataset)
+    db.session.add(column)
+    db.session.flush()
 
     config = {
         "table_name": dataset.table_name,
@@ -234,7 +235,7 @@ def test_import_dataset_duplicate_column(mocker: MockFixture, session: Session)
         "database_id": database.id,
     }
 
-    sqla_table = import_dataset(session, config, overwrite=True)
+    sqla_table = import_dataset(config, overwrite=True)
     assert sqla_table.table_name == dataset.table_name
     assert sqla_table.main_dttm_col == "ds"
     assert sqla_table.description == "This is the description"
@@ -288,12 +289,12 @@ def test_import_column_extra_is_string(mocker: MockFixture, session: Session) ->
 
     mocker.patch.object(security_manager, "can_access", return_value=True)
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
 
     database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
-    session.add(database)
-    session.flush()
+    db.session.add(database)
+    db.session.flush()
 
     dataset_uuid = uuid.uuid4()
     yaml_config: dict[str, Any] = {
@@ -352,7 +353,7 @@ def test_import_column_extra_is_string(mocker: MockFixture, session: Session) ->
     schema = ImportV1DatasetSchema()
     dataset_config = schema.load(yaml_config)
     dataset_config["database_id"] = database.id
-    sqla_table = import_dataset(session, dataset_config)
+    sqla_table = import_dataset(dataset_config)
 
     assert sqla_table.metrics[0].extra == '{"warning_markdown": null}'
     assert sqla_table.columns[0].extra == '{"certified_by": "User"}'
@@ -373,12 +374,12 @@ def test_import_dataset_extra_empty_string(
 
     mocker.patch.object(security_manager, "can_access", return_value=True)
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
 
     database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
-    session.add(database)
-    session.flush()
+    db.session.add(database)
+    db.session.flush()
 
     dataset_uuid = uuid.uuid4()
     yaml_config: dict[str, Any] = {
@@ -417,7 +418,7 @@ def test_import_dataset_extra_empty_string(
     schema = ImportV1DatasetSchema()
     dataset_config = schema.load(yaml_config)
     dataset_config["database_id"] = database.id
-    sqla_table = import_dataset(session, dataset_config)
+    sqla_table = import_dataset(dataset_config)
 
     assert sqla_table.extra is None
 
@@ -443,12 +444,12 @@ def test_import_column_allowed_data_url(
 
     mocker.patch.object(security_manager, "can_access", return_value=True)
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
 
     database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
-    session.add(database)
-    session.flush()
+    db.session.add(database)
+    db.session.flush()
 
     dataset_uuid = uuid.uuid4()
     yaml_config: dict[str, Any] = {
@@ -495,9 +496,8 @@ def test_import_column_allowed_data_url(
     schema = ImportV1DatasetSchema()
     dataset_config = schema.load(yaml_config)
     dataset_config["database_id"] = database.id
-    _ = import_dataset(session, dataset_config, force_data=True)
-    session.connection()
-    assert [("value1",), ("value2",)] == session.execute(
+    _ = import_dataset(dataset_config, force_data=True)
+    assert [("value1",), ("value2",)] == db.session.execute(
         "SELECT * FROM my_table"
     ).fetchall()
 
@@ -517,19 +517,19 @@ def test_import_dataset_managed_externally(
 
     mocker.patch.object(security_manager, "can_access", return_value=True)
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
 
     database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
-    session.add(database)
-    session.flush()
+    db.session.add(database)
+    db.session.flush()
 
     config = copy.deepcopy(dataset_config)
     config["is_managed_externally"] = True
     config["external_url"] = "https://example.org/my_table"
     config["database_id"] = database.id
 
-    sqla_table = import_dataset(session, config)
+    sqla_table = import_dataset(config)
     assert sqla_table.is_managed_externally is True
     assert sqla_table.external_url == "https://example.org/my_table"
 
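A side note on the `test_import_column_allowed_data_url` hunk above: the explicit `session.connection()` priming call could be dropped because `Session.execute()` acquires a connection from the session's bind on demand. Reusing the table created by the import in that test, the assertion reduces to:

    from superset import db

    # Session.execute() opens a connection automatically, so no separate
    # session.connection() call is needed before querying.
    rows = db.session.execute("SELECT * FROM my_table").fetchall()
    assert rows == [("value1",), ("value2",)]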
diff --git a/tests/unit_tests/datasets/dao/dao_tests.py b/tests/unit_tests/datasets/dao/dao_tests.py
index 3302f2dc04..a4632fad3d 100644
--- a/tests/unit_tests/datasets/dao/dao_tests.py
+++ b/tests/unit_tests/datasets/dao/dao_tests.py
@@ -29,15 +29,15 @@ def session_with_data(session: Session) -> Iterator[Session]:
     engine = session.get_bind()
     SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
 
-    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
     sqla_table = SqlaTable(
         table_name="my_sqla_table",
         columns=[],
         metrics=[],
-        database=db,
+        database=database,
     )
 
-    session.add(db)
+    session.add(database)
     session.add(sqla_table)
     session.flush()
     yield session
@@ -50,7 +50,6 @@ def test_datasource_find_by_id_skip_base_filter(session_with_data: Session) -> N
 
     result = DatasetDAO.find_by_id(
         1,
-        session=session_with_data,
         skip_base_filter=True,
     )
 
@@ -67,7 +66,6 @@ def test_datasource_find_by_id_skip_base_filter_not_found(
 
     result = DatasetDAO.find_by_id(
         125326326,
-        session=session_with_data,
         skip_base_filter=True,
     )
     assert result is None
@@ -79,7 +77,6 @@ def test_datasource_find_by_ids_skip_base_filter(session_with_data: Session) ->
 
     result = DatasetDAO.find_by_ids(
         [1, 125326326],
-        session=session_with_data,
         skip_base_filter=True,
     )
 
@@ -96,7 +93,6 @@ def test_datasource_find_by_ids_skip_base_filter_not_found(
 
     result = DatasetDAO.find_by_ids(
         [125326326, 125326326125326326],
-        session=session_with_data,
         skip_base_filter=True,
     )
 
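The dataset DAO hunks drop the `session=` keyword entirely; `BaseDAO.find_by_id` and `find_by_ids` are now expected to read from the shared scoped session. A rough sketch of the assumed post-refactor shape (illustrative, not the verbatim code in superset/daos/base.py):

    from superset import db

    class BaseDAO:
        model_cls = None  # set by concrete DAOs such as DatasetDAO

        @classmethod
        def find_by_id(cls, model_id, skip_base_filter: bool = False):
            # Query against the scoped session; base filtering elided here.
            query = db.session.query(cls.model_cls)
            return query.filter_by(id=model_id).one_or_none()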
diff --git a/tests/unit_tests/datasource/dao_tests.py b/tests/unit_tests/datasource/dao_tests.py
index b4ce162c0c..adc674d0fd 100644
--- a/tests/unit_tests/datasource/dao_tests.py
+++ b/tests/unit_tests/datasource/dao_tests.py
@@ -35,7 +35,7 @@ def session_with_data(session: Session) -> Iterator[Session]:
     engine = session.get_bind()
     SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
 
-    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
 
     columns = [
         TableColumn(column_name="a", type="INTEGER"),
@@ -45,12 +45,12 @@ def session_with_data(session: Session) -> Iterator[Session]:
         table_name="my_sqla_table",
         columns=columns,
         metrics=[],
-        database=db,
+        database=database,
     )
 
     query_obj = Query(
         client_id="foo",
-        database=db,
+        database=database,
         tab_name="test_tab",
         sql_editor_id="test_editor_id",
         sql="select * from bar",
@@ -63,13 +63,13 @@ def session_with_data(session: Session) -> Iterator[Session]:
         results_key="abc",
     )
 
-    saved_query = SavedQuery(database=db, sql="select * from foo")
+    saved_query = SavedQuery(database=database, sql="select * from foo")
 
     table = Table(
         name="my_table",
         schema="my_schema",
         catalog="my_catalog",
-        database=db,
+        database=database,
         columns=[],
     )
 
@@ -93,7 +93,7 @@ FROM my_catalog.my_schema.my_table
     session.add(table)
     session.add(saved_query)
     session.add(query_obj)
-    session.add(db)
+    session.add(database)
     session.add(sqla_table)
     session.flush()
     yield session
@@ -190,7 +190,7 @@ def test_get_datasource_w_str_param(session_with_data: Session) -> None:
 def test_get_all_datasources(session_with_data: Session) -> None:
     from superset.connectors.sqla.models import SqlaTable
 
-    result = SqlaTable.get_all_datasources(session=session_with_data)
+    result = SqlaTable.get_all_datasources()
     assert len(result) == 1
 
 
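Likewise, `SqlaTable.get_all_datasources` loses its `session` parameter. Assuming the body now queries the scoped session directly, it would look roughly like this (a sketch, not the exact implementation):

    from superset import db

    @classmethod
    def get_all_datasources(cls) -> list["SqlaTable"]:
        # Same query as before, sourced from db.session rather than a
        # caller-supplied Session.
        return db.session.query(cls).all()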
diff --git a/tests/unit_tests/db_engine_specs/test_druid.py b/tests/unit_tests/db_engine_specs/test_druid.py
index d090dffcde..0ab4688214 100644
--- a/tests/unit_tests/db_engine_specs/test_druid.py
+++ b/tests/unit_tests/db_engine_specs/test_druid.py
@@ -74,10 +74,10 @@ def test_extras_without_ssl() -> None:
     from superset.db_engine_specs.druid import DruidEngineSpec
     from tests.integration_tests.fixtures.database import default_db_extra
 
-    db = mock.Mock()
-    db.extra = default_db_extra
-    db.server_cert = None
-    extras = DruidEngineSpec.get_extra_params(db)
+    database = mock.Mock()
+    database.extra = default_db_extra
+    database.server_cert = None
+    extras = DruidEngineSpec.get_extra_params(database)
     assert "connect_args" not in extras["engine_params"]
 
 
@@ -86,10 +86,10 @@ def test_extras_with_ssl() -> None:
     from tests.integration_tests.fixtures.certificates import ssl_certificate
     from tests.integration_tests.fixtures.database import default_db_extra
 
-    db = mock.Mock()
-    db.extra = default_db_extra
-    db.server_cert = ssl_certificate
-    extras = DruidEngineSpec.get_extra_params(db)
+    database = mock.Mock()
+    database.extra = default_db_extra
+    database.server_cert = ssl_certificate
+    extras = DruidEngineSpec.get_extra_params(database)
     connect_args = extras["engine_params"]["connect_args"]
     assert connect_args["scheme"] == "https"
     assert "ssl_verify_cert" in connect_args
diff --git a/tests/unit_tests/db_engine_specs/test_pinot.py b/tests/unit_tests/db_engine_specs/test_pinot.py
index a1648f5f60..72c8267816 100644
--- a/tests/unit_tests/db_engine_specs/test_pinot.py
+++ b/tests/unit_tests/db_engine_specs/test_pinot.py
@@ -50,8 +50,8 @@ def test_extras_without_ssl() -> None:
     from superset.db_engine_specs.pinot import PinotEngineSpec as spec
     from tests.integration_tests.fixtures.database import default_db_extra
 
-    db = mock.Mock()
-    db.extra = default_db_extra
-    db.server_cert = None
-    extras = spec.get_extra_params(db)
+    database = mock.Mock()
+    database.extra = default_db_extra
+    database.server_cert = None
+    extras = spec.get_extra_params(database)
     assert "connect_args" not in extras["engine_params"]
diff --git a/tests/unit_tests/extensions/test_sqlalchemy.py b/tests/unit_tests/extensions/test_sqlalchemy.py
index cc738fd6c6..caa141aaf7 100644
--- a/tests/unit_tests/extensions/test_sqlalchemy.py
+++ b/tests/unit_tests/extensions/test_sqlalchemy.py
@@ -26,6 +26,7 @@ from sqlalchemy.engine import create_engine
 from sqlalchemy.exc import ProgrammingError
 from sqlalchemy.orm.session import Session
 
+from superset import db
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import SupersetSecurityException
 from tests.unit_tests.conftest import with_feature_flags
@@ -38,7 +39,7 @@ if TYPE_CHECKING:
 def database1(session: Session) -> Iterator["Database"]:
     from superset.models.core import Database
 
-    engine = session.connection().engine
+    engine = db.session.connection().engine
     Database.metadata.create_all(engine)  # pylint: disable=no-member
 
     database = Database(
@@ -46,13 +47,13 @@ def database1(session: Session) -> Iterator["Database"]:
         sqlalchemy_uri="sqlite:///database1.db",
         allow_dml=True,
     )
-    session.add(database)
-    session.commit()
+    db.session.add(database)
+    db.session.commit()
 
     yield database
 
-    session.delete(database)
-    session.commit()
+    db.session.delete(database)
+    db.session.commit()
     os.unlink("database1.db")
 
 
@@ -62,12 +63,12 @@ def table1(session: Session, database1: "Database") -> Iterator[None]:
         conn = engine.connect()
         conn.execute("CREATE TABLE table1 (a INTEGER NOT NULL PRIMARY KEY, b INTEGER)")
         conn.execute("INSERT INTO table1 (a, b) VALUES (1, 10), (2, 20)")
-        session.commit()
+        db.session.commit()
 
         yield
 
         conn.execute("DROP TABLE table1")
-        session.commit()
+        db.session.commit()
 
 
 @pytest.fixture
@@ -79,13 +80,13 @@ def database2(session: Session) -> Iterator["Database"]:
         sqlalchemy_uri="sqlite:///database2.db",
         allow_dml=False,
     )
-    session.add(database)
-    session.commit()
+    db.session.add(database)
+    db.session.commit()
 
     yield database
 
-    session.delete(database)
-    session.commit()
+    db.session.delete(database)
+    db.session.commit()
     os.unlink("database2.db")
 
 
@@ -95,12 +96,12 @@ def table2(session: Session, database2: "Database") -> Iterator[None]:
         conn = engine.connect()
         conn.execute("CREATE TABLE table2 (a INTEGER NOT NULL PRIMARY KEY, b TEXT)")
         conn.execute("INSERT INTO table2 (a, b) VALUES (1, 'ten'), (2, 'twenty')")
-        session.commit()
+        db.session.commit()
 
         yield
 
         conn.execute("DROP TABLE table2")
-        session.commit()
+        db.session.commit()
 
 
 @with_feature_flags(ENABLE_SUPERSET_META_DB=True)
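
One subtlety in the fixture diffs above: the functions keep their `session: Session` parameter even though their bodies now go through `db.session`. The argument is retained so pytest still runs the `session` fixture's setup (in-memory database and app context) before the body executes; it is simply no longer referenced directly, hence the `unused-argument` pylint disables elsewhere in this commit. A minimal illustration of the pattern:

    from superset import db

    def test_example(session) -> None:  # pylint: disable=unused-argument
        # `session` is requested only for its setup side effects; the
        # actual work goes through the Flask-SQLAlchemy scoped session.
        db.session.execute("SELECT 1")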
diff --git a/tests/unit_tests/queries/dao_test.py b/tests/unit_tests/queries/dao_test.py
index a0221b8019..dbca78a9d3 100644
--- a/tests/unit_tests/queries/dao_test.py
+++ b/tests/unit_tests/queries/dao_test.py
@@ -22,10 +22,10 @@ def test_column_attributes_on_query():
     from superset.models.core import Database
     from superset.models.sql_lab import Query
 
-    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
     query_obj = Query(
         client_id="foo",
-        database=db,
+        database=database,
         tab_name="test_tab",
         sql_editor_id="test_editor_id",
         sql="select * from bar",
diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py
index 8265277372..83e7c373c8 100644
--- a/tests/unit_tests/sql_lab_test.py
+++ b/tests/unit_tests/sql_lab_test.py
@@ -125,7 +125,7 @@ def test_sql_lab_insert_rls_as_subquery(
     from superset.sql_lab import execute_sql_statement
     from superset.utils.core import RowLevelSecurityFilterType
 
-    engine = session.connection().engine
+    engine = db.session.connection().engine
     Query.metadata.create_all(engine)  # pylint: disable=no-member
 
     connection = engine.raw_connection()
@@ -143,8 +143,8 @@ def test_sql_lab_insert_rls_as_subquery(
         limit=5,
         select_as_cta_used=False,
     )
-    session.add(query)
-    session.commit()
+    db.session.add(query)
+    db.session.commit()
 
     admin = User(
         first_name="Alice",
@@ -185,8 +185,8 @@ def test_sql_lab_insert_rls_as_subquery(
         group_key=None,
         clause="c > 5",
     )
-    session.add(rls)
-    session.flush()
+    db.session.add(rls)
+    db.session.flush()
     mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin)
     mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)
 
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index f650b77734..f05e16ae85 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -1759,8 +1759,7 @@ def test_get_rls_for_table(mocker: MockerFixture) -> None:
     Tests for ``get_rls_for_table``.
     """
     candidate = Identifier([Token(Name, "some_table")])
-    db = mocker.patch("superset.db")
-    dataset = db.session.query().filter().one_or_none()
+    dataset = mocker.patch("superset.db").session.query().filter().one_or_none()
     dataset.__str__.return_value = "some_table"
 
     dataset.get_sqla_row_level_filters.return_value = [text("organization_id = 1")]
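
This consolidation works because `mocker.patch` returns a `MagicMock`, and every attribute access or call on a `MagicMock` yields another mock, so the whole `session.query().filter().one_or_none()` chain can be built in one expression; the intermediate `db` name carried no information. A small self-contained example of the same idiom (hypothetical test, same chaining as above):

    from pytest_mock import MockerFixture

    def test_chained_mock(mocker: MockerFixture) -> None:
        # Each attribute access/call on a MagicMock returns another mock,
        # so a deep stub can be constructed inline.
        dataset = mocker.patch("superset.db").session.query().filter().one_or_none()
        dataset.__str__.return_value = "some_table"
        assert str(dataset) == "some_table"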
diff --git a/tests/unit_tests/tables/test_models.py b/tests/unit_tests/tables/test_models.py
index 7705dba6aa..926e059261 100644
--- a/tests/unit_tests/tables/test_models.py
+++ b/tests/unit_tests/tables/test_models.py
@@ -14,11 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 # pylint: disable=import-outside-toplevel, unused-argument
-
 from sqlalchemy.orm.session import Session
 
+from superset import db
+
 
 def test_table_model(session: Session) -> None:
     """
@@ -28,7 +28,7 @@ def test_table_model(session: Session) -> None:
     from superset.models.core import Database
     from superset.tables.models import Table
 
-    engine = session.get_bind()
+    engine = db.session.get_bind()
     Table.metadata.create_all(engine)  # pylint: disable=no-member
 
     table = Table(
@@ -44,8 +44,8 @@ def test_table_model(session: Session) -> None:
             )
         ],
     )
-    session.add(table)
-    session.flush()
+    db.session.add(table)
+    db.session.flush()
 
     assert table.id == 1
     assert table.uuid is not None
diff --git a/tests/unit_tests/tags/commands/create_test.py b/tests/unit_tests/tags/commands/create_test.py
index b18144521a..1e1895bb77 100644
--- a/tests/unit_tests/tags/commands/create_test.py
+++ b/tests/unit_tests/tags/commands/create_test.py
@@ -18,6 +18,7 @@ import pytest
 from pytest_mock import MockFixture
 from sqlalchemy.orm.session import Session
 
+from superset import db
 from superset.utils.core import DatasourceType
 
 
@@ -40,13 +41,15 @@ def session_with_data(session: Session):
         slice_name="slice_name",
     )
 
-    db = Database(database_name="my_database", sqlalchemy_uri="postgresql://")
+    database = Database(database_name="my_database", sqlalchemy_uri="postgresql://")
 
     columns = [
         TableColumn(column_name="a", type="INTEGER"),
     ]
 
-    saved_query = SavedQuery(label="test_query", database=db, sql="select * from foo")
+    saved_query = SavedQuery(
+        label="test_query", database=database, sql="select * from foo"
+    )
 
     dashboard_obj = Dashboard(
         id=100,
@@ -57,7 +60,7 @@ def session_with_data(session: Session):
     )
 
     session.add(slice_obj)
-    session.add(db)
+    session.add(database)
     session.add(saved_query)
     session.add(dashboard_obj)
     session.commit()
@@ -74,9 +77,9 @@ def test_create_command_success(session_with_data: Session, mocker: MockFixture)
     from superset.tags.models import ObjectType, TaggedObject
 
     # Define a list of objects to tag
-    query = session_with_data.query(SavedQuery).first()
-    chart = session_with_data.query(Slice).first()
-    dashboard = session_with_data.query(Dashboard).first()
+    query = db.session.query(SavedQuery).first()
+    chart = db.session.query(Slice).first()
+    dashboard = db.session.query(Dashboard).first()
 
     mocker.patch(
         "superset.security.SupersetSecurityManager.is_admin", return_value=True
@@ -94,10 +97,10 @@ def test_create_command_success(session_with_data: Session, mocker: MockFixture)
         data={"name": "test_tag", "objects_to_tag": objects_to_tag}
     ).run()
 
-    assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
+    assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag)
     for object_type, object_id in objects_to_tag:
         assert (
-            session_with_data.query(TaggedObject)
+            db.session.query(TaggedObject)
             .filter(
                 TaggedObject.object_type == object_type,
                 TaggedObject.object_id == object_id,
@@ -117,9 +120,9 @@ def test_create_command_success_clear(session_with_data: Session, mocker: MockFi
     from superset.tags.models import ObjectType, TaggedObject
 
     # Define a list of objects to tag
-    query = session_with_data.query(SavedQuery).first()
-    chart = session_with_data.query(Slice).first()
-    dashboard = session_with_data.query(Dashboard).first()
+    query = db.session.query(SavedQuery).first()
+    chart = db.session.query(Slice).first()
+    dashboard = db.session.query(Dashboard).first()
 
     mocker.patch(
         "superset.security.SupersetSecurityManager.is_admin", return_value=True
@@ -136,10 +139,10 @@ def test_create_command_success_clear(session_with_data: Session, mocker: MockFi
     CreateCustomTagWithRelationshipsCommand(
         data={"name": "test_tag", "objects_to_tag": objects_to_tag}
     ).run()
-    assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
+    assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag)
 
     CreateCustomTagWithRelationshipsCommand(
         data={"name": "test_tag", "objects_to_tag": []}
     ).run()
 
-    assert len(session_with_data.query(TaggedObject).all()) == 0
+    assert len(db.session.query(TaggedObject).all()) == 0
diff --git a/tests/unit_tests/tags/commands/update_test.py b/tests/unit_tests/tags/commands/update_test.py
index e488321228..75636ab0af 100644
--- a/tests/unit_tests/tags/commands/update_test.py
+++ b/tests/unit_tests/tags/commands/update_test.py
@@ -18,6 +18,7 @@ import pytest
 from pytest_mock import MockFixture
 from sqlalchemy.orm.session import Session
 
+from superset import db
 from superset.utils.core import DatasourceType
 
 
@@ -41,7 +42,7 @@ def session_with_data(session: Session):
         slice_name="slice_name",
     )
 
-    db = Database(database_name="my_database", sqlalchemy_uri="postgresql://")
+    database = Database(database_name="my_database", sqlalchemy_uri="postgresql://")
 
     columns = [
         TableColumn(column_name="a", type="INTEGER"),
@@ -51,7 +52,7 @@ def session_with_data(session: Session):
         table_name="my_sqla_table",
         columns=columns,
         metrics=[],
-        database=db,
+        database=database,
     )
 
     dashboard_obj = Dashboard(
@@ -62,7 +63,9 @@ def session_with_data(session: Session):
         published=True,
     )
 
-    saved_query = SavedQuery(label="test_query", database=db, sql="select * from foo")
+    saved_query = SavedQuery(
+        label="test_query", database=database, sql="select * from foo"
+    )
 
     tag = Tag(name="test_name", description="test_description")
 
@@ -79,7 +82,7 @@ def test_update_command_success(session_with_data: Session, mocker: MockFixture)
     from superset.models.dashboard import Dashboard
     from superset.tags.models import ObjectType, TaggedObject
 
-    dashboard = session_with_data.query(Dashboard).first()
+    dashboard = db.session.query(Dashboard).first()
     mocker.patch(
         "superset.security.SupersetSecurityManager.is_admin", return_value=True
     )
@@ -104,7 +107,7 @@ def test_update_command_success(session_with_data: Session, mocker: MockFixture)
     updated_tag = TagDAO.find_by_name("new_name")
     assert updated_tag is not None
     assert updated_tag.description == "new_description"
-    assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
+    assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag)
 
 
 def test_update_command_success_duplicates(
@@ -117,8 +120,8 @@ def test_update_command_success_duplicates(
     from superset.models.slice import Slice
     from superset.tags.models import ObjectType, TaggedObject
 
-    dashboard = session_with_data.query(Dashboard).first()
-    chart = session_with_data.query(Slice).first()
+    dashboard = db.session.query(Dashboard).first()
+    chart = db.session.query(Slice).first()
 
     mocker.patch(
         "superset.security.SupersetSecurityManager.is_admin", return_value=True
@@ -153,7 +156,7 @@ def test_update_command_success_duplicates(
     updated_tag = TagDAO.find_by_name("new_name")
     assert updated_tag is not None
     assert updated_tag.description == "new_description"
-    assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
+    assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag)
     assert changed_model.objects[0].object_id == chart.id
 
 
@@ -168,8 +171,8 @@ def test_update_command_failed_validation(
     from superset.models.slice import Slice
     from superset.tags.models import ObjectType
 
-    dashboard = session_with_data.query(Dashboard).first()
-    chart = session_with_data.query(Slice).first()
+    dashboard = db.session.query(Dashboard).first()
+    chart = db.session.query(Slice).first()
     objects_to_tag = [
         (ObjectType.chart, chart.id),
     ]