You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain text copy.
Posted to commits@superset.apache.org by jo...@apache.org on 2024/02/13 17:20:23 UTC
(superset) branch master updated: refactor: Ensure Flask framework leverages the Flask-SQLAlchemy session (Phase II) (#26909)
This is an automated email from the ASF dual-hosted git repository.
johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 847ed3f5b0 refactor: Ensure Flask framework leverages the Flask-SQLAlchemy session (Phase II) (#26909)
847ed3f5b0 is described below
commit 847ed3f5b0016abed968abfacb6b9980e74dc1bf
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Wed Feb 14 06:20:15 2024 +1300
refactor: Ensure Flask framework leverages the Flask-SQLAlchemy session (Phase II) (#26909)
---
.pylintrc | 2 +-
superset/cli/importexport.py | 3 +-
superset/commands/chart/importers/v1/__init__.py | 10 ++-
superset/commands/chart/importers/v1/utils.py | 13 ++--
.../commands/dashboard/importers/v1/__init__.py | 19 +++--
superset/commands/dashboard/importers/v1/utils.py | 11 ++-
.../commands/database/importers/v1/__init__.py | 8 +--
superset/commands/database/importers/v1/utils.py | 13 ++--
superset/commands/dataset/importers/v0.py | 57 +++++++--------
superset/commands/dataset/importers/v1/__init__.py | 8 +--
superset/commands/dataset/importers/v1/utils.py | 20 +++---
superset/commands/importers/v1/__init__.py | 6 +-
superset/commands/importers/v1/assets.py | 21 +++---
superset/commands/importers/v1/examples.py | 11 +--
superset/commands/query/importers/v1/__init__.py | 8 +--
superset/commands/query/importers/v1/utils.py | 13 ++--
superset/connectors/sqla/models.py | 15 ++--
superset/connectors/sqla/utils.py | 6 +-
superset/daos/base.py | 13 ++--
superset/databases/filters.py | 4 +-
superset/db_engine_specs/gsheets.py | 6 +-
superset/extensions/__init__.py | 2 +-
superset/models/dashboard.py | 2 +-
superset/models/helpers.py | 15 ++--
superset/security/manager.py | 15 ++--
superset/sqllab/schemas.py | 2 +-
superset/tables/models.py | 2 +-
superset/tags/models.py | 24 ++++---
superset/utils/dashboard_import_export.py | 7 +-
superset/utils/dict_import_export.py | 7 +-
tests/integration_tests/base_tests.py | 30 ++++----
tests/integration_tests/cache_tests.py | 4 +-
tests/integration_tests/charts/api_tests.py | 10 +--
tests/integration_tests/charts/commands_tests.py | 2 +-
tests/integration_tests/core_tests.py | 13 ++--
tests/integration_tests/dashboards/api_tests.py | 4 +-
.../dashboards/filter_state/api_tests.py | 7 +-
.../dashboards/permalink/api_tests.py | 6 +-
.../dashboards/superset_factory_util.py | 58 ++++++++--------
tests/integration_tests/databases/api_tests.py | 10 ++-
tests/integration_tests/datasets/api_tests.py | 1 -
tests/integration_tests/datasource_tests.py | 36 +++++-----
.../db_engine_specs/databricks_tests.py | 16 ++---
.../db_engine_specs/hive_tests.py | 14 ++--
.../db_engine_specs/postgres_tests.py | 24 +++----
.../db_engine_specs/presto_tests.py | 14 ++--
.../integration_tests/dict_import_export_tests.py | 26 ++++---
tests/integration_tests/explore/api_tests.py | 10 ++-
.../explore/form_data/api_tests.py | 10 ++-
.../explore/form_data/commands_tests.py | 32 ++++-----
.../explore/permalink/api_tests.py | 3 +-
.../explore/permalink/commands_tests.py | 32 ++++-----
tests/integration_tests/fixtures/datasource.py | 13 ++--
tests/integration_tests/import_export_tests.py | 15 ++--
.../key_value/commands/fixtures.py | 3 +-
.../security/guest_token_security_tests.py | 15 ++--
.../security/migrate_roles_tests.py | 5 +-
.../security/row_level_security_tests.py | 29 ++++----
tests/integration_tests/security_tests.py | 4 +-
tests/integration_tests/sqllab_tests.py | 6 +-
tests/integration_tests/test_jinja_context.py | 38 +++++-----
tests/integration_tests/utils/get_dashboards.py | 5 +-
tests/integration_tests/utils_tests.py | 4 +-
.../charts/commands/importers/v1/import_test.py | 12 ++--
tests/unit_tests/charts/dao/dao_tests.py | 16 ++---
tests/unit_tests/charts/test_post_processing.py | 7 +-
tests/unit_tests/columns/test_models.py | 7 +-
.../commands/importers/v1/assets_test.py | 34 ++++-----
tests/unit_tests/config_test.py | 4 +-
tests/unit_tests/conftest.py | 4 +-
tests/unit_tests/dao/dataset_test.py | 5 +-
tests/unit_tests/dao/queries_test.py | 80 ++++++++++++----------
tests/unit_tests/dao/tag_test.py | 2 +-
.../commands/importers/v1/import_test.py | 14 ++--
tests/unit_tests/dashboards/dao_tests.py | 12 ++--
tests/unit_tests/databases/api_test.py | 46 +++++++------
.../databases/commands/importers/v1/import_test.py | 27 ++++----
tests/unit_tests/databases/dao/dao_tests.py | 10 +--
.../databases/ssh_tunnel/commands/create_test.py | 12 ++--
.../databases/ssh_tunnel/commands/delete_test.py | 10 +--
.../databases/ssh_tunnel/commands/update_test.py | 10 +--
tests/unit_tests/databases/ssh_tunnel/dao_tests.py | 4 +-
tests/unit_tests/datasets/api_tests.py | 8 ++-
tests/unit_tests/datasets/commands/export_test.py | 8 ++-
.../datasets/commands/importers/v1/import_test.py | 58 ++++++++--------
tests/unit_tests/datasets/dao/dao_tests.py | 10 +--
tests/unit_tests/datasource/dao_tests.py | 14 ++--
tests/unit_tests/db_engine_specs/test_druid.py | 16 ++---
tests/unit_tests/db_engine_specs/test_pinot.py | 8 +--
tests/unit_tests/extensions/test_sqlalchemy.py | 27 ++++----
tests/unit_tests/queries/dao_test.py | 4 +-
tests/unit_tests/sql_lab_test.py | 10 +--
tests/unit_tests/sql_parse_tests.py | 3 +-
tests/unit_tests/tables/test_models.py | 10 +--
tests/unit_tests/tags/commands/create_test.py | 29 ++++----
tests/unit_tests/tags/commands/update_test.py | 23 ++++---
96 files changed, 656 insertions(+), 730 deletions(-)
diff --git a/.pylintrc b/.pylintrc
index 6083060624..1cab7a587a 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -108,7 +108,7 @@ evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / stateme
good-names=_,df,ex,f,i,id,j,k,l,o,pk,Run,ts,v,x,y
# Bad variable names which should always be refused, separated by a comma
-bad-names=fd,foo,bar,baz,toto,tutu,tata
+bad-names=bar,baz,db,fd,foo,sesh,session,tata,toto,tutu
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
diff --git a/superset/cli/importexport.py b/superset/cli/importexport.py
index fc6a9ad3c4..ebf94b444a 100755
--- a/superset/cli/importexport.py
+++ b/superset/cli/importexport.py
@@ -214,7 +214,7 @@ def legacy_export_dashboards(
# pylint: disable=import-outside-toplevel
from superset.utils import dashboard_import_export
- data = dashboard_import_export.export_dashboards(db.session)
+ data = dashboard_import_export.export_dashboards()
if print_stdout or not dashboard_file:
print(data)
if dashboard_file:
@@ -263,7 +263,6 @@ def legacy_export_datasources(
from superset.utils import dict_import_export
data = dict_import_export.export_to_dict(
- session=db.session,
recursive=True,
back_references=back_references,
include_defaults=include_defaults,
diff --git a/superset/commands/chart/importers/v1/__init__.py b/superset/commands/chart/importers/v1/__init__.py
index f99fbb9008..7f2537383f 100644
--- a/superset/commands/chart/importers/v1/__init__.py
+++ b/superset/commands/chart/importers/v1/__init__.py
@@ -47,9 +47,7 @@ class ImportChartsCommand(ImportModelsCommand):
import_error = ChartImportError
@staticmethod
- def _import(
- session: Session, configs: dict[str, Any], overwrite: bool = False
- ) -> None:
+ def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
# discover datasets associated with charts
dataset_uuids: set[str] = set()
for file_name, config in configs.items():
@@ -66,7 +64,7 @@ class ImportChartsCommand(ImportModelsCommand):
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
- database = import_database(session, config, overwrite=False)
+ database = import_database(config, overwrite=False)
database_ids[str(database.uuid)] = database.id
# import datasets with the correct parent ref
@@ -77,7 +75,7 @@ class ImportChartsCommand(ImportModelsCommand):
and config["database_uuid"] in database_ids
):
config["database_id"] = database_ids[config["database_uuid"]]
- dataset = import_dataset(session, config, overwrite=False)
+ dataset = import_dataset(config, overwrite=False)
datasets[str(dataset.uuid)] = dataset
# import charts with the correct parent ref
@@ -101,4 +99,4 @@ class ImportChartsCommand(ImportModelsCommand):
if "query_context" in config:
config["query_context"] = None
- import_chart(session, config, overwrite=overwrite)
+ import_chart(config, overwrite=overwrite)
diff --git a/superset/commands/chart/importers/v1/utils.py b/superset/commands/chart/importers/v1/utils.py
index 2aac3ea9c4..f1b38e7ddc 100644
--- a/superset/commands/chart/importers/v1/utils.py
+++ b/superset/commands/chart/importers/v1/utils.py
@@ -20,9 +20,7 @@ import json
from inspect import isclass
from typing import Any
-from sqlalchemy.orm import Session
-
-from superset import security_manager
+from superset import db, security_manager
from superset.commands.exceptions import ImportFailedError
from superset.migrations.shared.migrate_viz import processors
from superset.migrations.shared.migrate_viz.base import MigrateViz
@@ -46,13 +44,12 @@ def filter_chart_annotations(chart_config: dict[str, Any]) -> None:
def import_chart(
- session: Session,
config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
) -> Slice:
can_write = ignore_permissions or security_manager.can_access("can_write", "Chart")
- existing = session.query(Slice).filter_by(uuid=config["uuid"]).first()
+ existing = db.session.query(Slice).filter_by(uuid=config["uuid"]).first()
if existing:
if overwrite and can_write and get_user():
if not security_manager.can_access_chart(existing):
@@ -76,11 +73,9 @@ def import_chart(
# migrate old viz types to new ones
config = migrate_chart(config)
- chart = Slice.import_from_dict(
- session, config, recursive=False, allow_reparenting=True
- )
+ chart = Slice.import_from_dict(config, recursive=False, allow_reparenting=True)
if chart.id is None:
- session.flush()
+ db.session.flush()
if user := get_user():
chart.owners.append(user)
diff --git a/superset/commands/dashboard/importers/v1/__init__.py b/superset/commands/dashboard/importers/v1/__init__.py
index 62f5f393e9..77d28696cf 100644
--- a/superset/commands/dashboard/importers/v1/__init__.py
+++ b/superset/commands/dashboard/importers/v1/__init__.py
@@ -21,6 +21,7 @@ from marshmallow import Schema
from sqlalchemy.orm import Session
from sqlalchemy.sql import select
+from superset import db
from superset.charts.schemas import ImportV1ChartSchema
from superset.commands.chart.importers.v1.utils import import_chart
from superset.commands.dashboard.exceptions import DashboardImportError
@@ -59,9 +60,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
# TODO (betodealmeida): refactor to use code from other commands
# pylint: disable=too-many-branches, too-many-locals
@staticmethod
- def _import(
- session: Session, configs: dict[str, Any], overwrite: bool = False
- ) -> None:
+ def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
# discover charts and datasets associated with dashboards
chart_uuids: set[str] = set()
dataset_uuids: set[str] = set()
@@ -87,7 +86,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
- database = import_database(session, config, overwrite=False)
+ database = import_database(config, overwrite=False)
database_ids[str(database.uuid)] = database.id
# import datasets with the correct parent ref
@@ -98,7 +97,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
and config["database_uuid"] in database_ids
):
config["database_id"] = database_ids[config["database_uuid"]]
- dataset = import_dataset(session, config, overwrite=False)
+ dataset = import_dataset(config, overwrite=False)
dataset_info[str(dataset.uuid)] = {
"datasource_id": dataset.id,
"datasource_type": dataset.datasource_type,
@@ -122,12 +121,12 @@ class ImportDashboardsCommand(ImportModelsCommand):
if "query_context" in config:
config["query_context"] = None
- chart = import_chart(session, config, overwrite=False)
+ chart = import_chart(config, overwrite=False)
charts.append(chart)
chart_ids[str(chart.uuid)] = chart.id
# store the existing relationship between dashboards and charts
- existing_relationships = session.execute(
+ existing_relationships = db.session.execute(
select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id])
).fetchall()
@@ -137,7 +136,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
for file_name, config in configs.items():
if file_name.startswith("dashboards/"):
config = update_id_refs(config, chart_ids, dataset_info)
- dashboard = import_dashboard(session, config, overwrite=overwrite)
+ dashboard = import_dashboard(config, overwrite=overwrite)
dashboards.append(dashboard)
for uuid in find_chart_uuids(config["position"]):
if uuid not in chart_ids:
@@ -151,7 +150,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
{"dashboard_id": dashboard_id, "slice_id": chart_id}
for (dashboard_id, chart_id) in dashboard_chart_ids
]
- session.execute(dashboard_slices.insert(), values)
+ db.session.execute(dashboard_slices.insert(), values)
# Migrate any filter-box charts to native dashboard filters.
for dashboard in dashboards:
@@ -160,4 +159,4 @@ class ImportDashboardsCommand(ImportModelsCommand):
# Remove all obsolete filter-box charts.
for chart in charts:
if chart.viz_type == "filter_box":
- session.delete(chart)
+ db.session.delete(chart)
diff --git a/superset/commands/dashboard/importers/v1/utils.py b/superset/commands/dashboard/importers/v1/utils.py
index b8ac3144db..09be75a6ea 100644
--- a/superset/commands/dashboard/importers/v1/utils.py
+++ b/superset/commands/dashboard/importers/v1/utils.py
@@ -19,9 +19,7 @@ import json
import logging
from typing import Any
-from sqlalchemy.orm import Session
-
-from superset import security_manager
+from superset import db, security_manager
from superset.commands.exceptions import ImportFailedError
from superset.models.dashboard import Dashboard
from superset.utils.core import get_user
@@ -146,7 +144,6 @@ def update_id_refs( # pylint: disable=too-many-locals
def import_dashboard(
- session: Session,
config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
@@ -155,7 +152,7 @@ def import_dashboard(
"can_write",
"Dashboard",
)
- existing = session.query(Dashboard).filter_by(uuid=config["uuid"]).first()
+ existing = db.session.query(Dashboard).filter_by(uuid=config["uuid"]).first()
if existing:
if overwrite and can_write and get_user():
if not security_manager.can_access_dashboard(existing):
@@ -187,9 +184,9 @@ def import_dashboard(
except TypeError:
logger.info("Unable to encode `%s` field: %s", key, value)
- dashboard = Dashboard.import_from_dict(session, config, recursive=False)
+ dashboard = Dashboard.import_from_dict(config, recursive=False)
if dashboard.id is None:
- session.flush()
+ db.session.flush()
if user := get_user():
dashboard.owners.append(user)
diff --git a/superset/commands/database/importers/v1/__init__.py b/superset/commands/database/importers/v1/__init__.py
index 73b1bca531..203f0e3089 100644
--- a/superset/commands/database/importers/v1/__init__.py
+++ b/superset/commands/database/importers/v1/__init__.py
@@ -43,14 +43,12 @@ class ImportDatabasesCommand(ImportModelsCommand):
import_error = DatabaseImportError
@staticmethod
- def _import(
- session: Session, configs: dict[str, Any], overwrite: bool = False
- ) -> None:
+ def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
# first import databases
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/"):
- database = import_database(session, config, overwrite=overwrite)
+ database = import_database(config, overwrite=overwrite)
database_ids[str(database.uuid)] = database.id
# import related datasets
@@ -61,4 +59,4 @@ class ImportDatabasesCommand(ImportModelsCommand):
):
config["database_id"] = database_ids[config["database_uuid"]]
# overwrite=False prevents deleting any non-imported columns/metrics
- import_dataset(session, config, overwrite=False)
+ import_dataset(config, overwrite=False)
diff --git a/superset/commands/database/importers/v1/utils.py b/superset/commands/database/importers/v1/utils.py
index c8c2847b9f..17b8488b44 100644
--- a/superset/commands/database/importers/v1/utils.py
+++ b/superset/commands/database/importers/v1/utils.py
@@ -18,9 +18,7 @@
import json
from typing import Any
-from sqlalchemy.orm import Session
-
-from superset import app, security_manager
+from superset import app, db, security_manager
from superset.commands.exceptions import ImportFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
@@ -30,7 +28,6 @@ from superset.security.analytics_db_safety import check_sqlalchemy_uri
def import_database(
- session: Session,
config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
@@ -39,7 +36,7 @@ def import_database(
"can_write",
"Database",
)
- existing = session.query(Database).filter_by(uuid=config["uuid"]).first()
+ existing = db.session.query(Database).filter_by(uuid=config["uuid"]).first()
if existing:
if not overwrite or not can_write:
return existing
@@ -67,12 +64,12 @@ def import_database(
# Before it gets removed in import_from_dict
ssh_tunnel = config.pop("ssh_tunnel", None)
- database = Database.import_from_dict(session, config, recursive=False)
+ database = Database.import_from_dict(config, recursive=False)
if database.id is None:
- session.flush()
+ db.session.flush()
if ssh_tunnel:
ssh_tunnel["database_id"] = database.id
- SSHTunnel.import_from_dict(session, ssh_tunnel, recursive=False)
+ SSHTunnel.import_from_dict(ssh_tunnel, recursive=False)
return database
diff --git a/superset/commands/dataset/importers/v0.py b/superset/commands/dataset/importers/v0.py
index d389a17651..6c1d79779e 100644
--- a/superset/commands/dataset/importers/v0.py
+++ b/superset/commands/dataset/importers/v0.py
@@ -20,7 +20,6 @@ from typing import Any, Callable, Optional
import yaml
from flask_appbuilder import Model
-from sqlalchemy.orm import Session
from sqlalchemy.orm.session import make_transient
from superset import db
@@ -86,7 +85,6 @@ def import_dataset(
raise DatasetInvalidError
return import_datasource(
- db.session,
i_datasource,
lookup_database,
lookup_datasource,
@@ -95,9 +93,9 @@ def import_dataset(
)
-def lookup_sqla_metric(session: Session, metric: SqlMetric) -> SqlMetric:
+def lookup_sqla_metric(metric: SqlMetric) -> SqlMetric:
return (
- session.query(SqlMetric)
+ db.session.query(SqlMetric)
.filter(
SqlMetric.table_id == metric.table_id,
SqlMetric.metric_name == metric.metric_name,
@@ -106,13 +104,13 @@ def lookup_sqla_metric(session: Session, metric: SqlMetric) -> SqlMetric:
)
-def import_metric(session: Session, metric: SqlMetric) -> SqlMetric:
- return import_simple_obj(session, metric, lookup_sqla_metric)
+def import_metric(metric: SqlMetric) -> SqlMetric:
+ return import_simple_obj(metric, lookup_sqla_metric)
-def lookup_sqla_column(session: Session, column: TableColumn) -> TableColumn:
+def lookup_sqla_column(column: TableColumn) -> TableColumn:
return (
- session.query(TableColumn)
+ db.session.query(TableColumn)
.filter(
TableColumn.table_id == column.table_id,
TableColumn.column_name == column.column_name,
@@ -121,12 +119,11 @@ def lookup_sqla_column(session: Session, column: TableColumn) -> TableColumn:
)
-def import_column(session: Session, column: TableColumn) -> TableColumn:
- return import_simple_obj(session, column, lookup_sqla_column)
+def import_column(column: TableColumn) -> TableColumn:
+ return import_simple_obj(column, lookup_sqla_column)
-def import_datasource( # pylint: disable=too-many-arguments
- session: Session,
+def import_datasource(
i_datasource: Model,
lookup_database: Callable[[Model], Optional[Model]],
lookup_datasource: Callable[[Model], Optional[Model]],
@@ -155,11 +152,11 @@ def import_datasource( # pylint: disable=too-many-arguments
if datasource:
datasource.override(i_datasource)
- session.flush()
+ db.session.flush()
else:
datasource = i_datasource.copy()
- session.add(datasource)
- session.flush()
+ db.session.add(datasource)
+ db.session.flush()
for metric in i_datasource.metrics:
new_m = metric.copy()
@@ -169,7 +166,7 @@ def import_datasource( # pylint: disable=too-many-arguments
new_m.to_json(),
i_datasource.full_name,
)
- imported_m = import_metric(session, new_m)
+ imported_m = import_metric(new_m)
if imported_m.metric_name not in [m.metric_name for m in datasource.metrics]:
datasource.metrics.append(imported_m)
@@ -181,44 +178,40 @@ def import_datasource( # pylint: disable=too-many-arguments
new_c.to_json(),
i_datasource.full_name,
)
- imported_c = import_column(session, new_c)
+ imported_c = import_column(new_c)
if imported_c.column_name not in [c.column_name for c in datasource.columns]:
datasource.columns.append(imported_c)
- session.flush()
+ db.session.flush()
return datasource.id
-def import_simple_obj(
- session: Session, i_obj: Model, lookup_obj: Callable[[Session, Model], Model]
-) -> Model:
+def import_simple_obj(i_obj: Model, lookup_obj: Callable[[Model], Model]) -> Model:
make_transient(i_obj)
i_obj.id = None
i_obj.table = None
# find if the column was already imported
- existing_column = lookup_obj(session, i_obj)
+ existing_column = lookup_obj(i_obj)
i_obj.table = None
if existing_column:
existing_column.override(i_obj)
- session.flush()
+ db.session.flush()
return existing_column
- session.add(i_obj)
- session.flush()
+ db.session.add(i_obj)
+ db.session.flush()
return i_obj
-def import_from_dict(
- session: Session, data: dict[str, Any], sync: Optional[list[str]] = None
-) -> None:
+def import_from_dict(data: dict[str, Any], sync: Optional[list[str]] = None) -> None:
"""Imports databases from dictionary"""
if not sync:
sync = []
if isinstance(data, dict):
logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
for database in data.get(DATABASES_KEY, []):
- Database.import_from_dict(session, database, sync=sync)
- session.commit()
+ Database.import_from_dict(database, sync=sync)
+ db.session.commit()
else:
logger.info("Supplied object is not a dictionary.")
@@ -254,7 +247,7 @@ class ImportDatasetsCommand(BaseCommand):
for file_name, config in self._configs.items():
logger.info("Importing dataset from file %s", file_name)
if isinstance(config, dict):
- import_from_dict(db.session, config, sync=self.sync)
+ import_from_dict(config, sync=self.sync)
else: # list
for dataset in config:
# UI exports don't have the database metadata, so we assume
@@ -266,7 +259,7 @@ class ImportDatasetsCommand(BaseCommand):
.one()
)
dataset["database_id"] = database.id
- SqlaTable.import_from_dict(db.session, dataset, sync=self.sync)
+ SqlaTable.import_from_dict(dataset, sync=self.sync)
def validate(self) -> None:
# ensure all files are YAML
diff --git a/superset/commands/dataset/importers/v1/__init__.py b/superset/commands/dataset/importers/v1/__init__.py
index 600a39bf48..29f850258c 100644
--- a/superset/commands/dataset/importers/v1/__init__.py
+++ b/superset/commands/dataset/importers/v1/__init__.py
@@ -43,9 +43,7 @@ class ImportDatasetsCommand(ImportModelsCommand):
import_error = DatasetImportError
@staticmethod
- def _import(
- session: Session, configs: dict[str, Any], overwrite: bool = False
- ) -> None:
+ def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
# discover databases associated with datasets
database_uuids: set[str] = set()
for file_name, config in configs.items():
@@ -56,7 +54,7 @@ class ImportDatasetsCommand(ImportModelsCommand):
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
- database = import_database(session, config, overwrite=False)
+ database = import_database(config, overwrite=False)
database_ids[str(database.uuid)] = database.id
# import datasets with the correct parent ref
@@ -66,4 +64,4 @@ class ImportDatasetsCommand(ImportModelsCommand):
and config["database_uuid"] in database_ids
):
config["database_id"] = database_ids[config["database_uuid"]]
- import_dataset(session, config, overwrite=overwrite)
+ import_dataset(config, overwrite=overwrite)
diff --git a/superset/commands/dataset/importers/v1/utils.py b/superset/commands/dataset/importers/v1/utils.py
index 014a864da4..04fc81e241 100644
--- a/superset/commands/dataset/importers/v1/utils.py
+++ b/superset/commands/dataset/importers/v1/utils.py
@@ -25,10 +25,9 @@ import pandas as pd
from flask import current_app
from sqlalchemy import BigInteger, Boolean, Date, DateTime, Float, String, Text
from sqlalchemy.exc import MultipleResultsFound
-from sqlalchemy.orm import Session
from sqlalchemy.sql.visitors import VisitableType
-from superset import security_manager
+from superset import db, security_manager
from superset.commands.dataset.exceptions import DatasetForbiddenDataURI
from superset.commands.exceptions import ImportFailedError
from superset.connectors.sqla.models import SqlaTable
@@ -103,7 +102,6 @@ def validate_data_uri(data_uri: str) -> None:
def import_dataset(
- session: Session,
config: dict[str, Any],
overwrite: bool = False,
force_data: bool = False,
@@ -113,7 +111,7 @@ def import_dataset(
"can_write",
"Dataset",
)
- existing = session.query(SqlaTable).filter_by(uuid=config["uuid"]).first()
+ existing = db.session.query(SqlaTable).filter_by(uuid=config["uuid"]).first()
if existing:
if not overwrite or not can_write:
return existing
@@ -150,7 +148,7 @@ def import_dataset(
# import recursively to include columns and metrics
try:
- dataset = SqlaTable.import_from_dict(session, config, recursive=True, sync=sync)
+ dataset = SqlaTable.import_from_dict(config, recursive=True, sync=sync)
except MultipleResultsFound:
# Finding multiple results when importing a dataset only happens because initially
# datasets were imported without schemas (eg, `examples.NULL.users`), and later
@@ -160,10 +158,10 @@ def import_dataset(
# `examples.public.users`, resulting in a conflict.
#
# When that happens, we return the original dataset, unmodified.
- dataset = session.query(SqlaTable).filter_by(uuid=config["uuid"]).one()
+ dataset = db.session.query(SqlaTable).filter_by(uuid=config["uuid"]).one()
if dataset.id is None:
- session.flush()
+ db.session.flush()
try:
table_exists = dataset.database.has_table_by_name(dataset.table_name)
@@ -175,7 +173,7 @@ def import_dataset(
table_exists = True
if data_uri and (not table_exists or force_data):
- load_data(data_uri, dataset, dataset.database, session)
+ load_data(data_uri, dataset, dataset.database)
if user := get_user():
dataset.owners.append(user)
@@ -183,9 +181,7 @@ def import_dataset(
return dataset
-def load_data(
- data_uri: str, dataset: SqlaTable, database: Database, session: Session
-) -> None:
+def load_data(data_uri: str, dataset: SqlaTable, database: Database) -> None:
"""
Load data from a data URI into a dataset.
@@ -208,7 +204,7 @@ def load_data(
# reuse session when loading data if possible, to make import atomic
if database.sqlalchemy_uri == current_app.config.get("SQLALCHEMY_DATABASE_URI"):
logger.info("Loading data inside the import transaction")
- connection = session.connection()
+ connection = db.session.connection()
df.to_sql(
dataset.table_name,
con=connection,
diff --git a/superset/commands/importers/v1/__init__.py b/superset/commands/importers/v1/__init__.py
index 38d6568af4..8d90875fd3 100644
--- a/superset/commands/importers/v1/__init__.py
+++ b/superset/commands/importers/v1/__init__.py
@@ -60,9 +60,7 @@ class ImportModelsCommand(BaseCommand):
self._configs: dict[str, Any] = {}
@staticmethod
- def _import(
- session: Session, configs: dict[str, Any], overwrite: bool = False
- ) -> None:
+ def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
raise NotImplementedError("Subclasses MUST implement _import")
@classmethod
@@ -74,7 +72,7 @@ class ImportModelsCommand(BaseCommand):
# rollback to prevent partial imports
try:
- self._import(db.session, self._configs, self.overwrite)
+ self._import(self._configs, self.overwrite)
db.session.commit()
except CommandException as ex:
db.session.rollback()
diff --git a/superset/commands/importers/v1/assets.py b/superset/commands/importers/v1/assets.py
index fe9539ac80..876ce509ae 100644
--- a/superset/commands/importers/v1/assets.py
+++ b/superset/commands/importers/v1/assets.py
@@ -18,7 +18,6 @@ from typing import Any, Optional
from marshmallow import Schema
from marshmallow.exceptions import ValidationError
-from sqlalchemy.orm import Session
from sqlalchemy.sql import delete, insert
from superset import db
@@ -80,26 +79,26 @@ class ImportAssetsCommand(BaseCommand):
# pylint: disable=too-many-locals
@staticmethod
- def _import(session: Session, configs: dict[str, Any]) -> None:
+ def _import(configs: dict[str, Any]) -> None:
# import databases first
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/"):
- database = import_database(session, config, overwrite=True)
+ database = import_database(config, overwrite=True)
database_ids[str(database.uuid)] = database.id
# import saved queries
for file_name, config in configs.items():
if file_name.startswith("queries/"):
config["db_id"] = database_ids[config["database_uuid"]]
- import_saved_query(session, config, overwrite=True)
+ import_saved_query(config, overwrite=True)
# import datasets
dataset_info: dict[str, dict[str, Any]] = {}
for file_name, config in configs.items():
if file_name.startswith("datasets/"):
config["database_id"] = database_ids[config["database_uuid"]]
- dataset = import_dataset(session, config, overwrite=True)
+ dataset = import_dataset(config, overwrite=True)
dataset_info[str(dataset.uuid)] = {
"datasource_id": dataset.id,
"datasource_type": dataset.datasource_type,
@@ -118,7 +117,7 @@ class ImportAssetsCommand(BaseCommand):
config["params"].update({"datasource": dataset_uid})
if "query_context" in config:
config["query_context"] = None
- chart = import_chart(session, config, overwrite=True)
+ chart = import_chart(config, overwrite=True)
charts.append(chart)
chart_ids[str(chart.uuid)] = chart.id
@@ -126,7 +125,7 @@ class ImportAssetsCommand(BaseCommand):
for file_name, config in configs.items():
if file_name.startswith("dashboards/"):
config = update_id_refs(config, chart_ids, dataset_info)
- dashboard = import_dashboard(session, config, overwrite=True)
+ dashboard = import_dashboard(config, overwrite=True)
# set ref in the dashboard_slices table
dashboard_chart_ids: list[dict[str, int]] = []
@@ -140,12 +139,12 @@ class ImportAssetsCommand(BaseCommand):
}
dashboard_chart_ids.append(dashboard_chart_id)
- session.execute(
+ db.session.execute(
delete(dashboard_slices).where(
dashboard_slices.c.dashboard_id == dashboard.id
)
)
- session.execute(insert(dashboard_slices).values(dashboard_chart_ids))
+ db.session.execute(insert(dashboard_slices).values(dashboard_chart_ids))
# Migrate any filter-box charts to native dashboard filters.
migrate_dashboard(dashboard)
@@ -153,14 +152,14 @@ class ImportAssetsCommand(BaseCommand):
# Remove all obsolete filter-box charts.
for chart in charts:
if chart.viz_type == "filter_box":
- session.delete(chart)
+ db.session.delete(chart)
def run(self) -> None:
self.validate()
# rollback to prevent partial imports
try:
- self._import(db.session, self._configs)
+ self._import(self._configs)
db.session.commit()
except Exception as ex:
db.session.rollback()
diff --git a/superset/commands/importers/v1/examples.py b/superset/commands/importers/v1/examples.py
index 87280033eb..ff69aadc45 100644
--- a/superset/commands/importers/v1/examples.py
+++ b/superset/commands/importers/v1/examples.py
@@ -18,7 +18,6 @@ from typing import Any
from marshmallow import Schema
from sqlalchemy.exc import MultipleResultsFound
-from sqlalchemy.orm import Session
from sqlalchemy.sql import select
from superset import db
@@ -70,7 +69,6 @@ class ImportExamplesCommand(ImportModelsCommand):
# rollback to prevent partial imports
try:
self._import(
- db.session,
self._configs,
self.overwrite,
self.force_data,
@@ -92,7 +90,6 @@ class ImportExamplesCommand(ImportModelsCommand):
@staticmethod
def _import( # pylint: disable=too-many-locals, too-many-branches
- session: Session,
configs: dict[str, Any],
overwrite: bool = False,
force_data: bool = False,
@@ -102,7 +99,6 @@ class ImportExamplesCommand(ImportModelsCommand):
for file_name, config in configs.items():
if file_name.startswith("databases/"):
database = import_database(
- session,
config,
overwrite=overwrite,
ignore_permissions=True,
@@ -133,7 +129,6 @@ class ImportExamplesCommand(ImportModelsCommand):
try:
dataset = import_dataset(
- session,
config,
overwrite=overwrite,
force_data=force_data,
@@ -164,7 +159,6 @@ class ImportExamplesCommand(ImportModelsCommand):
# update datasource id, type, and name
config.update(dataset_info[config["dataset_uuid"]])
chart = import_chart(
- session,
config,
overwrite=overwrite,
ignore_permissions=True,
@@ -172,7 +166,7 @@ class ImportExamplesCommand(ImportModelsCommand):
chart_ids[str(chart.uuid)] = chart.id
# store the existing relationship between dashboards and charts
- existing_relationships = session.execute(
+ existing_relationships = db.session.execute(
select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id])
).fetchall()
@@ -186,7 +180,6 @@ class ImportExamplesCommand(ImportModelsCommand):
continue
dashboard = import_dashboard(
- session,
config,
overwrite=overwrite,
ignore_permissions=True,
@@ -203,4 +196,4 @@ class ImportExamplesCommand(ImportModelsCommand):
{"dashboard_id": dashboard_id, "slice_id": chart_id}
for (dashboard_id, chart_id) in dashboard_chart_ids
]
- session.execute(dashboard_slices.insert(), values)
+ db.session.execute(dashboard_slices.insert(), values)
diff --git a/superset/commands/query/importers/v1/__init__.py b/superset/commands/query/importers/v1/__init__.py
index fa1f21b6fc..f251759c38 100644
--- a/superset/commands/query/importers/v1/__init__.py
+++ b/superset/commands/query/importers/v1/__init__.py
@@ -43,9 +43,7 @@ class ImportSavedQueriesCommand(ImportModelsCommand):
import_error = SavedQueryImportError
@staticmethod
- def _import(
- session: Session, configs: dict[str, Any], overwrite: bool = False
- ) -> None:
+ def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
# discover databases associated with saved queries
database_uuids: set[str] = set()
for file_name, config in configs.items():
@@ -56,7 +54,7 @@ class ImportSavedQueriesCommand(ImportModelsCommand):
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
- database = import_database(session, config, overwrite=False)
+ database = import_database(config, overwrite=False)
database_ids[str(database.uuid)] = database.id
# import saved queries with the correct parent ref
@@ -66,4 +64,4 @@ class ImportSavedQueriesCommand(ImportModelsCommand):
and config["database_uuid"] in database_ids
):
config["db_id"] = database_ids[config["database_uuid"]]
- import_saved_query(session, config, overwrite=overwrite)
+ import_saved_query(config, overwrite=overwrite)
diff --git a/superset/commands/query/importers/v1/utils.py b/superset/commands/query/importers/v1/utils.py
index 813f3c2295..d611aa5e3a 100644
--- a/superset/commands/query/importers/v1/utils.py
+++ b/superset/commands/query/importers/v1/utils.py
@@ -17,22 +17,19 @@
from typing import Any
-from sqlalchemy.orm import Session
-
+from superset import db
from superset.models.sql_lab import SavedQuery
-def import_saved_query(
- session: Session, config: dict[str, Any], overwrite: bool = False
-) -> SavedQuery:
- existing = session.query(SavedQuery).filter_by(uuid=config["uuid"]).first()
+def import_saved_query(config: dict[str, Any], overwrite: bool = False) -> SavedQuery:
+ existing = db.session.query(SavedQuery).filter_by(uuid=config["uuid"]).first()
if existing:
if not overwrite:
return existing
config["id"] = existing.id
- saved_query = SavedQuery.import_from_dict(session, config, recursive=False)
+ saved_query = SavedQuery.import_from_dict(config, recursive=False)
if saved_query.id is None:
- session.flush()
+ db.session.flush()
return saved_query
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 08dc923c21..2552740695 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -65,7 +65,6 @@ from sqlalchemy.orm import (
reconstructor,
relationship,
RelationshipProperty,
- Session,
)
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.schema import UniqueConstraint
@@ -1902,13 +1901,12 @@ class SqlaTable(
@classmethod
def query_datasources_by_name(
cls,
- session: Session,
database: Database,
datasource_name: str,
schema: str | None = None,
) -> list[SqlaTable]:
query = (
- session.query(cls)
+ db.session.query(cls)
.filter_by(database_id=database.id)
.filter_by(table_name=datasource_name)
)
@@ -1919,14 +1917,13 @@ class SqlaTable(
@classmethod
def query_datasources_by_permissions( # pylint: disable=invalid-name
cls,
- session: Session,
database: Database,
permissions: set[str],
schema_perms: set[str],
) -> list[SqlaTable]:
# TODO(hughhhh): add unit test
return (
- session.query(cls)
+ db.session.query(cls)
.filter_by(database_id=database.id)
.filter(
or_(
@@ -1951,8 +1948,8 @@ class SqlaTable(
)
@classmethod
- def get_all_datasources(cls, session: Session) -> list[SqlaTable]:
- qry = session.query(cls)
+ def get_all_datasources(cls) -> list[SqlaTable]:
+ qry = db.session.query(cls)
qry = cls.default_query(qry)
return qry.all()
@@ -2034,7 +2031,7 @@ class SqlaTable(
:param connection: Unused.
:param target: The metric or column that was updated.
"""
- session = inspect(target).session
+ session = inspect(target).session # pylint: disable=disallowed-name
# Forces an update to the table's changed_on value when a metric or column on the
# table is updated. This busts the cache key for all charts that use the table.
@@ -2068,7 +2065,7 @@ class SqlaTable(
if self.database_id and (
not self.database or self.database.id != self.database_id
):
- session = inspect(self).session
+ session = inspect(self).session # pylint: disable=disallowed-name
self.database = session.query(Database).filter_by(id=self.database_id).one()
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 688be53515..58a90e6eca 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -26,10 +26,10 @@ from flask_babel import lazy_gettext as _
from sqlalchemy.engine.url import URL as SqlaURL
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.ext.declarative import DeclarativeMeta
-from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import ObjectDeletedError
from sqlalchemy.sql.type_api import TypeEngine
+from superset import db
from superset.constants import LRU_CACHE_MAX_SIZE
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
@@ -168,14 +168,12 @@ logger = logging.getLogger(__name__)
def find_cached_objects_in_session(
- session: Session,
cls: type[DeclarativeModel],
ids: Iterable[int] | None = None,
uuids: Iterable[UUID] | None = None,
) -> Iterator[DeclarativeModel]:
"""Find known ORM instances in cached SQLA session states.
- :param session: a SQLA session
:param cls: a SQLA DeclarativeModel
:param ids: ids of the desired model instances (optional)
:param uuids: uuids of the desired instances, will be ignored if `ids` are provided
@@ -184,7 +182,7 @@ def find_cached_objects_in_session(
return iter([])
uuids = uuids or []
try:
- items = list(session)
+ items = list(db.session)
except ObjectDeletedError:
logger.warning("ObjectDeletedError", exc_info=True)
return iter(())
diff --git a/superset/daos/base.py b/superset/daos/base.py
index 1133a76a1e..ed6471ac81 100644
--- a/superset/daos/base.py
+++ b/superset/daos/base.py
@@ -22,7 +22,6 @@ from flask_appbuilder.models.filters import BaseFilter
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.models.sqla.interface import SQLAInterface
from sqlalchemy.exc import SQLAlchemyError, StatementError
-from sqlalchemy.orm import Session
from superset.daos.exceptions import (
DAOCreateFailedError,
@@ -59,16 +58,14 @@ class BaseDAO(Generic[T]):
def find_by_id(
cls,
model_id: str | int,
- session: Session = None,
skip_base_filter: bool = False,
) -> T | None:
"""
Find a model by id, if defined applies `base_filter`
"""
- session = session or db.session
- query = session.query(cls.model_cls)
+ query = db.session.query(cls.model_cls)
if cls.base_filter and not skip_base_filter:
- data_model = SQLAInterface(cls.model_cls, session)
+ data_model = SQLAInterface(cls.model_cls, db.session)
query = cls.base_filter( # pylint: disable=not-callable
cls.id_column_name, data_model
).apply(query, None)
@@ -83,7 +80,6 @@ class BaseDAO(Generic[T]):
def find_by_ids(
cls,
model_ids: list[str] | list[int],
- session: Session = None,
skip_base_filter: bool = False,
) -> list[T]:
"""
@@ -92,10 +88,9 @@ class BaseDAO(Generic[T]):
id_col = getattr(cls.model_cls, cls.id_column_name, None)
if id_col is None:
return []
- session = session or db.session
- query = session.query(cls.model_cls).filter(id_col.in_(model_ids))
+ query = db.session.query(cls.model_cls).filter(id_col.in_(model_ids))
if cls.base_filter and not skip_base_filter:
- data_model = SQLAInterface(cls.model_cls, session)
+ data_model = SQLAInterface(cls.model_cls, db.session)
query = cls.base_filter( # pylint: disable=not-callable
cls.id_column_name, data_model
).apply(query, None)
diff --git a/superset/databases/filters.py b/superset/databases/filters.py
index 384a62c9d3..33748da4b6 100644
--- a/superset/databases/filters.py
+++ b/superset/databases/filters.py
@@ -86,8 +86,8 @@ class DatabaseUploadEnabledFilter(BaseFilter): # pylint: disable=too-few-public
if hasattr(g, "user"):
allowed_schemas = [
- app.config["ALLOWED_USER_CSV_SCHEMA_FUNC"](db, g.user)
- for db in datasource_access_databases
+ app.config["ALLOWED_USER_CSV_SCHEMA_FUNC"](database, g.user)
+ for database in datasource_access_databases
]
if len(allowed_schemas):
diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py
index b78a24ec12..18349f4314 100644
--- a/superset/db_engine_specs/gsheets.py
+++ b/superset/db_engine_specs/gsheets.py
@@ -310,7 +310,7 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
@staticmethod
def _do_post(
- session: Session,
+ session: Session, # pylint: disable=disallowed-name
url: str,
body: dict[str, Any],
**kwargs: Any,
@@ -385,7 +385,9 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
conn,
spreadsheet_url or EXAMPLE_GSHEETS_URL,
)
- session = adapter._get_session() # pylint: disable=protected-access
+ session = ( # pylint: disable=disallowed-name
+ adapter._get_session() # pylint: disable=protected-access
+ )
# clear existing sheet, or create a new one
if spreadsheet_url:
diff --git a/superset/extensions/__init__.py b/superset/extensions/__init__.py
index c68332738b..65ba7eebc8 100644
--- a/superset/extensions/__init__.py
+++ b/superset/extensions/__init__.py
@@ -122,7 +122,7 @@ async_query_manager: AsyncQueryManager = LocalProxy(
cache_manager = CacheManager()
celery_app = celery.Celery()
csrf = CSRFProtect()
-db = SQLA()
+db = SQLA() # pylint: disable=disallowed-name
_event_logger: dict[str, Any] = {}
encrypted_field_factory = EncryptedFieldFactory()
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py
index ef346dbd62..01b1bf9624 100644
--- a/superset/models/dashboard.py
+++ b/superset/models/dashboard.py
@@ -64,7 +64,7 @@ def copy_dashboard(_mapper: Mapper, _connection: Connection, target: Dashboard)
if dashboard_id is None:
return
- session = sqla.inspect(target).session
+ session = sqla.inspect(target).session # pylint: disable=disallowed-name
new_user = session.query(User).filter_by(id=target.id).first()
# copy template dashboard to user
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 9322e8c46d..fb2f959f31 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -46,13 +46,13 @@ from jinja2.exceptions import TemplateError
from sqlalchemy import and_, Column, or_, UniqueConstraint
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.ext.declarative import declared_attr
-from sqlalchemy.orm import Mapper, Session, validates
+from sqlalchemy.orm import Mapper, validates
from sqlalchemy.sql.elements import ColumnElement, literal_column, TextClause
from sqlalchemy.sql.expression import Label, Select, TextAsFrom
from sqlalchemy.sql.selectable import Alias, TableClause
from sqlalchemy_utils import UUIDType
-from superset import app, is_feature_enabled, security_manager
+from superset import app, db, is_feature_enabled, security_manager
from superset.advanced_data_type.types import AdvancedDataTypeResponse
from superset.common.db_query_status import QueryStatus
from superset.common.utils.time_range_utils import get_since_until_from_time_range
@@ -245,7 +245,6 @@ class ImportExportMixin:
def import_from_dict(
# pylint: disable=too-many-arguments,too-many-branches,too-many-locals
cls,
- session: Session,
dict_rep: dict[Any, Any],
parent: Optional[Any] = None,
recursive: bool = True,
@@ -303,7 +302,7 @@ class ImportExportMixin:
# Check if object already exists in DB, break if more than one is found
try:
- obj_query = session.query(cls).filter(and_(*filters))
+ obj_query = db.session.query(cls).filter(and_(*filters))
obj = obj_query.one_or_none()
except MultipleResultsFound as ex:
logger.error(
@@ -322,7 +321,7 @@ class ImportExportMixin:
logger.info("Importing new %s %s", obj.__tablename__, str(obj))
if cls.export_parent and parent:
setattr(obj, cls.export_parent, parent)
- session.add(obj)
+ db.session.add(obj)
else:
is_new_obj = False
logger.info("Updating %s %s", obj.__tablename__, str(obj))
@@ -341,7 +340,7 @@ class ImportExportMixin:
for c_obj in new_children.get(child, []):
added.append(
child_class.import_from_dict(
- session=session, dict_rep=c_obj, parent=obj, sync=sync
+ dict_rep=c_obj, parent=obj, sync=sync
)
)
# If children should get synced, delete the ones that did not
@@ -353,11 +352,11 @@ class ImportExportMixin:
for k in back_refs.keys()
]
to_delete = set(
- session.query(child_class).filter(and_(*delete_filters))
+ db.session.query(child_class).filter(and_(*delete_filters))
).difference(set(added))
for o in to_delete:
logger.info("Deleting %s %s", child, str(obj))
- session.delete(o)
+ db.session.delete(o)
return obj
diff --git a/superset/security/manager.py b/superset/security/manager.py
index ffc4da250f..356ea06852 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -453,7 +453,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
level=ErrorLevel.ERROR,
)
- def get_chart_access_error_object( # pylint: disable=invalid-name
+ def get_chart_access_error_object(
self,
dashboard: "Dashboard", # pylint: disable=unused-argument
) -> SupersetError:
@@ -576,7 +576,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
)
# group all datasources by database
- all_datasources = SqlaTable.get_all_datasources(self.get_session)
+ all_datasources = SqlaTable.get_all_datasources()
datasources_by_database: dict["Database", set["SqlaTable"]] = defaultdict(set)
for datasource in all_datasources:
datasources_by_database[datasource.database].add(datasource)
@@ -714,7 +714,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
user_perms = self.user_view_menu_names("datasource_access")
schema_perms = self.user_view_menu_names("schema_access")
user_datasources = SqlaTable.query_datasources_by_permissions(
- self.get_session, database, user_perms, schema_perms
+ database, user_perms, schema_perms
)
if schema:
names = {d.table_name for d in user_datasources if d.schema == schema}
@@ -781,7 +781,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
self.add_permission_view_menu(view_menu, perm)
logger.info("Creating missing datasource permissions.")
- datasources = SqlaTable.get_all_datasources(self.get_session)
+ datasources = SqlaTable.get_all_datasources()
for datasource in datasources:
merge_pv("datasource_access", datasource.get_perm())
merge_pv("schema_access", datasource.get_schema_perm())
@@ -797,8 +797,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
"""
logger.info("Cleaning faulty perms")
- sesh = self.get_session
- pvms = sesh.query(PermissionView).filter(
+ pvms = self.get_session.query(PermissionView).filter(
or_(
PermissionView.permission # pylint: disable=singleton-comparison
== None,
@@ -806,7 +805,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
== None,
)
)
- sesh.commit()
+ self.get_session.commit()
if deleted_count := pvms.delete():
logger.info("Deleted %i faulty permissions", deleted_count)
@@ -1925,7 +1924,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
if not (schema_perm and self.can_access("schema_access", schema_perm)):
datasources = SqlaTable.query_datasources_by_name(
- self.get_session, database, table_.table, schema=table_.schema
+ database, table_.table, schema=table_.schema
)
# Access to any datasource suffices.
diff --git a/superset/sqllab/schemas.py b/superset/sqllab/schemas.py
index 66f90a6e92..dba54cd3b5 100644
--- a/superset/sqllab/schemas.py
+++ b/superset/sqllab/schemas.py
@@ -66,7 +66,7 @@ class ExecutePayloadSchema(Schema):
class QueryResultSchema(Schema):
changed_on = fields.DateTime()
dbId = fields.Integer()
- db = fields.String() # pylint: disable=invalid-name
+ db = fields.String() # pylint: disable=disallowed-name
endDttm = fields.Float()
errorMessage = fields.String(allow_none=True)
executedSql = fields.String()
diff --git a/superset/tables/models.py b/superset/tables/models.py
index 11f1021197..2616aaf90f 100644
--- a/superset/tables/models.py
+++ b/superset/tables/models.py
@@ -169,7 +169,7 @@ class Table(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
)
default_props = default_props or {}
- session: Session = inspect(database).session
+ session: Session = inspect(database).session # pylint: disable=disallowed-name
# load existing tables
predicate = or_(
*[
diff --git a/superset/tags/models.py b/superset/tags/models.py
index 1e8ca7de1a..7361441940 100644
--- a/superset/tags/models.py
+++ b/superset/tags/models.py
@@ -131,7 +131,9 @@ class TaggedObject(Model, AuditMixinNullable):
return f"<TaggedObject: {self.object_type}:{self.object_id} TAG:{self.tag_id}>"
-def get_tag(name: str, session: orm.Session, type_: TagType) -> Tag:
+def get_tag(
+ name: str, session: orm.Session, type_: TagType # pylint: disable=disallowed-name
+) -> Tag:
tag_name = name.strip()
tag = session.query(Tag).filter_by(name=tag_name, type=type_).one_or_none()
if tag is None:
@@ -168,7 +170,7 @@ class ObjectUpdater:
@classmethod
def get_owner_tag_ids(
cls,
- session: orm.Session,
+ session: orm.Session, # pylint: disable=disallowed-name
target: Dashboard | FavStar | Slice | Query | SqlaTable,
) -> set[int]:
tag_ids = set()
@@ -181,7 +183,7 @@ class ObjectUpdater:
@classmethod
def _add_owners(
cls,
- session: orm.Session,
+ session: orm.Session, # pylint: disable=disallowed-name
target: Dashboard | FavStar | Slice | Query | SqlaTable,
) -> None:
for owner_id in cls.get_owners_ids(target):
@@ -193,7 +195,11 @@ class ObjectUpdater:
@classmethod
def add_tag_object_if_not_tagged(
- cls, session: orm.Session, tag_id: int, object_id: int, object_type: str
+ cls,
+ session: orm.Session, # pylint: disable=disallowed-name
+ tag_id: int,
+ object_id: int,
+ object_type: str,
) -> None:
# Check if the object is already tagged
exists_query = exists().where(
@@ -217,7 +223,7 @@ class ObjectUpdater:
connection: Connection,
target: Dashboard | FavStar | Slice | Query | SqlaTable,
) -> None:
- with Session(bind=connection) as session:
+ with Session(bind=connection) as session: # pylint: disable=disallowed-name
# add `owner:` tags
cls._add_owners(session, target)
@@ -235,7 +241,7 @@ class ObjectUpdater:
connection: Connection,
target: Dashboard | FavStar | Slice | Query | SqlaTable,
) -> None:
- with Session(bind=connection) as session:
+ with Session(bind=connection) as session: # pylint: disable=disallowed-name
# Fetch current owner tags
existing_tags = (
session.query(TaggedObject)
@@ -274,7 +280,7 @@ class ObjectUpdater:
connection: Connection,
target: Dashboard | FavStar | Slice | Query | SqlaTable,
) -> None:
- with Session(bind=connection) as session:
+ with Session(bind=connection) as session: # pylint: disable=disallowed-name
# delete row from `tagged_objects`
session.query(TaggedObject).filter(
TaggedObject.object_type == cls.object_type,
@@ -321,7 +327,7 @@ class FavStarUpdater:
def after_insert(
cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None:
- with Session(bind=connection) as session:
+ with Session(bind=connection) as session: # pylint: disable=disallowed-name
name = f"favorited_by:{target.user_id}"
tag = get_tag(name, session, TagType.favorited_by)
tagged_object = TaggedObject(
@@ -336,7 +342,7 @@ class FavStarUpdater:
def after_delete(
cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None:
- with Session(bind=connection) as session:
+ with Session(bind=connection) as session: # pylint: disable=disallowed-name
name = f"favorited_by:{target.user_id}"
query = (
session.query(TaggedObject.id)
diff --git a/superset/utils/dashboard_import_export.py b/superset/utils/dashboard_import_export.py
index eef8cbe6df..c21761dadb 100644
--- a/superset/utils/dashboard_import_export.py
+++ b/superset/utils/dashboard_import_export.py
@@ -16,17 +16,16 @@
# under the License.
import logging
-from sqlalchemy.orm import Session
-
+from superset import db
from superset.models.dashboard import Dashboard
logger = logging.getLogger(__name__)
-def export_dashboards(session: Session) -> str:
+def export_dashboards() -> str:
"""Returns all dashboards metadata as a json dump"""
logger.info("Starting export")
- dashboards = session.query(Dashboard)
+ dashboards = db.session.query(Dashboard)
dashboard_ids = set()
for dashboard in dashboards:
dashboard_ids.add(dashboard.id)
diff --git a/superset/utils/dict_import_export.py b/superset/utils/dict_import_export.py
index fbd9db7d81..7b3d995249 100644
--- a/superset/utils/dict_import_export.py
+++ b/superset/utils/dict_import_export.py
@@ -17,8 +17,7 @@
import logging
from typing import Any
-from sqlalchemy.orm import Session
-
+from superset import db
from superset.models.core import Database
EXPORT_VERSION = "1.0.0"
@@ -38,11 +37,11 @@ def export_schema_to_dict(back_references: bool) -> dict[str, Any]:
def export_to_dict(
- session: Session, recursive: bool, back_references: bool, include_defaults: bool
+ recursive: bool, back_references: bool, include_defaults: bool
) -> dict[str, Any]:
"""Exports databases to a dictionary"""
logger.info("Starting export")
- dbs = session.query(Database)
+ dbs = db.session.query(Database)
databases = [
database.export_to_dict(
recursive=recursive,
diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py
index 0040ec60f6..81ad8ccc9f 100644
--- a/tests/integration_tests/base_tests.py
+++ b/tests/integration_tests/base_tests.py
@@ -187,24 +187,22 @@ class SupersetTestCase(TestCase):
except ImportError:
return False
- def get_or_create(self, cls, criteria, session, **kwargs):
- obj = session.query(cls).filter_by(**criteria).first()
+ def get_or_create(self, cls, criteria, **kwargs):
+ obj = db.session.query(cls).filter_by(**criteria).first()
if not obj:
obj = cls(**criteria)
obj.__dict__.update(**kwargs)
- session.add(obj)
- session.commit()
+ db.session.add(obj)
+ db.session.commit()
return obj
def login(self, username="admin", password="general"):
return login(self.client, username, password)
- def get_slice(
- self, slice_name: str, session: Session, expunge_from_session: bool = True
- ) -> Slice:
- slc = session.query(Slice).filter_by(slice_name=slice_name).one()
+ def get_slice(self, slice_name: str, expunge_from_session: bool = True) -> Slice:
+ slc = db.session.query(Slice).filter_by(slice_name=slice_name).one()
if expunge_from_session:
- session.expunge_all()
+ db.session.expunge_all()
return slc
@staticmethod
@@ -353,7 +351,6 @@ class SupersetTestCase(TestCase):
return self.get_or_create(
cls=models.Database,
criteria={"database_name": database_name},
- session=db.session,
sqlalchemy_uri="sqlite:///:memory:",
id=db_id,
extra=extra,
@@ -375,7 +372,6 @@ class SupersetTestCase(TestCase):
database = self.get_or_create(
cls=models.Database,
criteria={"database_name": database_name},
- session=db.session,
sqlalchemy_uri="db_for_macros_testing://user@host:8080/hive",
id=db_id,
)
@@ -398,8 +394,7 @@ class SupersetTestCase(TestCase):
db.session.commit()
def get_dash_by_slug(self, dash_slug):
- sesh = db.session()
- return sesh.query(Dashboard).filter_by(slug=dash_slug).first()
+ return db.session.query(Dashboard).filter_by(slug=dash_slug).first()
def get_assert_metric(self, uri: str, func_name: str) -> Response:
"""
@@ -522,11 +517,10 @@ class SupersetTestCase(TestCase):
@contextmanager
def db_insert_temp_object(obj: DeclarativeMeta):
"""Insert a temporary object in database; delete when done."""
- session = db.session
try:
- session.add(obj)
- session.commit()
+ db.session.add(obj)
+ db.session.commit()
yield obj
finally:
- session.delete(obj)
- session.commit()
+ db.session.delete(obj)
+ db.session.commit()
diff --git a/tests/integration_tests/cache_tests.py b/tests/integration_tests/cache_tests.py
index b2a8704dfb..89093db864 100644
--- a/tests/integration_tests/cache_tests.py
+++ b/tests/integration_tests/cache_tests.py
@@ -46,7 +46,7 @@ class TestCache(SupersetTestCase):
app.config["DATA_CACHE_CONFIG"] = {"CACHE_TYPE": "NullCache"}
cache_manager.init_app(app)
- slc = self.get_slice("Top 10 Girl Name Share", db.session)
+ slc = self.get_slice("Top 10 Girl Name Share")
json_endpoint = "/superset/explore_json/{}/{}/".format(
slc.datasource_type, slc.datasource_id
)
@@ -73,7 +73,7 @@ class TestCache(SupersetTestCase):
}
cache_manager.init_app(app)
- slc = self.get_slice("Top 10 Girl Name Share", db.session)
+ slc = self.get_slice("Top 10 Girl Name Share")
json_endpoint = "/superset/explore_json/{}/{}/".format(
slc.datasource_type, slc.datasource_id
)
diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py
index a58ce1779e..d0985124e2 100644
--- a/tests/integration_tests/charts/api_tests.py
+++ b/tests/integration_tests/charts/api_tests.py
@@ -453,7 +453,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
"""
Chart API: Test create chart
"""
- dashboards_ids = get_dashboards_ids(db, ["world_health", "births"])
+ dashboards_ids = get_dashboards_ids(["world_health", "births"])
admin_id = self.get_user("admin").id
chart_data = {
"slice_name": "name1",
@@ -1736,7 +1736,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_warm_up_cache(self, slice_name):
self.login()
- slc = self.get_slice(slice_name, db.session)
+ slc = self.get_slice(slice_name)
rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": slc.id})
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
@@ -1815,7 +1815,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_warm_up_cache_error(self) -> None:
self.login()
- slc = self.get_slice("Pivot Table v2", db.session)
+ slc = self.get_slice("Pivot Table v2")
with mock.patch.object(ChartDataCommand, "run") as mock_run:
mock_run.side_effect = ChartDataQueryFailedError(
@@ -1843,7 +1843,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_warm_up_cache_no_query_context(self) -> None:
self.login()
- slc = self.get_slice("Pivot Table v2", db.session)
+ slc = self.get_slice("Pivot Table v2")
with mock.patch.object(Slice, "get_query_context") as mock_get_query_context:
mock_get_query_context.return_value = None
@@ -1866,7 +1866,7 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_warm_up_cache_no_datasource(self) -> None:
self.login()
- slc = self.get_slice("Top 10 Girl Name Share", db.session)
+ slc = self.get_slice("Top 10 Girl Name Share")
with mock.patch.object(
Slice,
diff --git a/tests/integration_tests/charts/commands_tests.py b/tests/integration_tests/charts/commands_tests.py
index a72a716d17..6ee3e45b5f 100644
--- a/tests/integration_tests/charts/commands_tests.py
+++ b/tests/integration_tests/charts/commands_tests.py
@@ -413,7 +413,7 @@ class TestChartWarmUpCacheCommand(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_warm_up_cache(self):
- slc = self.get_slice("Top 10 Girl Name Share", db.session)
+ slc = self.get_slice("Top 10 Girl Name Share")
result = ChartWarmUpCacheCommand(slc.id, None, None).run()
self.assertEqual(
result, {"chart_id": slc.id, "viz_error": None, "viz_status": "success"}
diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py
index 9e1a9ad11c..1b1e128b07 100644
--- a/tests/integration_tests/core_tests.py
+++ b/tests/integration_tests/core_tests.py
@@ -135,7 +135,7 @@ class TestCore(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_viz_cache_key(self):
self.login(username="admin")
- slc = self.get_slice("Top 10 Girl Name Share", db.session)
+ slc = self.get_slice("Top 10 Girl Name Share")
viz = slc.viz
qobj = viz.query_obj()
@@ -175,7 +175,7 @@ class TestCore(SupersetTestCase):
def test_save_slice(self):
self.login(username="admin")
slice_name = f"Energy Sankey"
- slice_id = self.get_slice(slice_name, db.session).id
+ slice_id = self.get_slice(slice_name).id
copy_name_prefix = "Test Sankey"
copy_name = f"{copy_name_prefix}[save]{random.random()}"
tbl_id = self.table_ids.get("energy_usage")
@@ -242,7 +242,6 @@ class TestCore(SupersetTestCase):
self.login(username="admin")
slc = self.get_slice(
slice_name="Top 10 Girl Name Share",
- session=db.session,
expunge_from_session=False,
)
slc_data_attributes = slc.data.keys()
@@ -356,7 +355,7 @@ class TestCore(SupersetTestCase):
)
def test_warm_up_cache(self):
self.login()
- slc = self.get_slice("Top 10 Girl Name Share", db.session)
+ slc = self.get_slice("Top 10 Girl Name Share")
data = self.get_json_resp(f"/superset/warm_up_cache?slice_id={slc.id}")
self.assertEqual(
data, [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}]
@@ -381,7 +380,7 @@ class TestCore(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_warm_up_cache_error(self) -> None:
self.login()
- slc = self.get_slice("Pivot Table v2", db.session)
+ slc = self.get_slice("Pivot Table v2")
with mock.patch.object(
ChartDataCommand,
@@ -406,7 +405,7 @@ class TestCore(SupersetTestCase):
self.login("admin")
store_cache_keys = app.config["STORE_CACHE_KEYS_IN_METADATA_DB"]
app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = True
- slc = self.get_slice("Top 10 Girl Name Share", db.session)
+ slc = self.get_slice("Top 10 Girl Name Share")
self.get_json_resp(f"/superset/warm_up_cache?slice_id={slc.id}")
ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first()
assert ck.datasource_uid == f"{slc.table.id}__table"
@@ -1172,7 +1171,7 @@ class TestCore(SupersetTestCase):
random_key = "random_key"
mock_command.return_value = random_key
slice_name = f"Energy Sankey"
- slice_id = self.get_slice(slice_name, db.session).id
+ slice_id = self.get_slice(slice_name).id
form_data = {"slice_id": slice_id, "viz_type": "line", "datasource": "1__table"}
rv = self.client.get(
f"/superset/explore/?form_data={quote(json.dumps(form_data))}"
diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py
index d809880bf7..623572c713 100644
--- a/tests/integration_tests/dashboards/api_tests.py
+++ b/tests/integration_tests/dashboards/api_tests.py
@@ -1661,7 +1661,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
Dashboard API: Test dashboard export
"""
self.login(username="admin")
- dashboards_ids = get_dashboards_ids(db, ["world_health", "births"])
+ dashboards_ids = get_dashboards_ids(["world_health", "births"])
uri = f"api/v1/dashboard/export/?q={prison.dumps(dashboards_ids)}"
rv = self.get_assert_metric(uri, "export")
@@ -1699,7 +1699,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
"""
Dashboard API: Test dashboard export
"""
- dashboards_ids = get_dashboards_ids(db, ["world_health", "births"])
+ dashboards_ids = get_dashboards_ids(["world_health", "births"])
uri = f"api/v1/dashboard/export/?q={prison.dumps(dashboards_ids)}"
self.login(username="admin")
diff --git a/tests/integration_tests/dashboards/filter_state/api_tests.py b/tests/integration_tests/dashboards/filter_state/api_tests.py
index 3538e14012..4dd02bfb65 100644
--- a/tests/integration_tests/dashboards/filter_state/api_tests.py
+++ b/tests/integration_tests/dashboards/filter_state/api_tests.py
@@ -22,6 +22,7 @@ from flask.ctx import AppContext
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.orm import Session
+from superset import db
from superset.commands.dashboard.exceptions import DashboardAccessDeniedError
from superset.commands.temporary_cache.entry import Entry
from superset.extensions import cache_manager
@@ -40,15 +41,13 @@ UPDATED_VALUE = json.dumps({"test": "updated value"})
@pytest.fixture
def dashboard_id(app_context: AppContext, load_world_bank_dashboard_with_slices) -> int:
- session: Session = app_context.app.appbuilder.get_session
- dashboard = session.query(Dashboard).filter_by(slug="world_health").one()
+ dashboard = db.session.query(Dashboard).filter_by(slug="world_health").one()
return dashboard.id
@pytest.fixture
def admin_id(app_context: AppContext) -> int:
- session: Session = app_context.app.appbuilder.get_session
- admin = session.query(User).filter_by(username="admin").one_or_none()
+ admin = db.session.query(User).filter_by(username="admin").one_or_none()
return admin.id
diff --git a/tests/integration_tests/dashboards/permalink/api_tests.py b/tests/integration_tests/dashboards/permalink/api_tests.py
index a49f1e6f4c..bfa20fd8a3 100644
--- a/tests/integration_tests/dashboards/permalink/api_tests.py
+++ b/tests/integration_tests/dashboards/permalink/api_tests.py
@@ -42,10 +42,8 @@ STATE = {
@pytest.fixture
def dashboard_id(load_world_bank_dashboard_with_slices) -> int:
- with app.app_context() as ctx:
- session: Session = ctx.app.appbuilder.get_session
- dashboard = session.query(Dashboard).filter_by(slug="world_health").one()
- return dashboard.id
+ dashboard = db.session.query(Dashboard).filter_by(slug="world_health").one()
+ return dashboard.id
@pytest.fixture
diff --git a/tests/integration_tests/dashboards/superset_factory_util.py b/tests/integration_tests/dashboards/superset_factory_util.py
index 88495b03b4..aeae6171df 100644
--- a/tests/integration_tests/dashboards/superset_factory_util.py
+++ b/tests/integration_tests/dashboards/superset_factory_util.py
@@ -38,8 +38,6 @@ from tests.integration_tests.dashboards.dashboard_test_utils import (
logger = logging.getLogger(__name__)
-session = db.session
-
inserted_dashboards_ids = []
inserted_databases_ids = []
inserted_sqltables_ids = []
@@ -99,9 +97,9 @@ def create_dashboard(
def insert_model(dashboard: Model) -> None:
- session.add(dashboard)
- session.commit()
- session.refresh(dashboard)
+ db.session.add(dashboard)
+ db.session.commit()
+ db.session.refresh(dashboard)
def create_slice_to_db(
@@ -193,7 +191,7 @@ def delete_all_inserted_objects() -> None:
def delete_all_inserted_dashboards():
try:
dashboards_to_delete: list[Dashboard] = (
- session.query(Dashboard)
+ db.session.query(Dashboard)
.filter(Dashboard.id.in_(inserted_dashboards_ids))
.all()
)
@@ -204,7 +202,7 @@ def delete_all_inserted_dashboards():
logger.error(f"failed to delete {dashboard.id}", exc_info=True)
raise ex
if len(inserted_dashboards_ids) > 0:
- session.commit()
+ db.session.commit()
inserted_dashboards_ids.clear()
except Exception as ex2:
logger.error("delete_all_inserted_dashboards failed", exc_info=True)
@@ -216,25 +214,25 @@ def delete_dashboard(dashboard: Dashboard, do_commit: bool = False) -> None:
delete_dashboard_roles_associations(dashboard)
delete_dashboard_users_associations(dashboard)
delete_dashboard_slices_associations(dashboard)
- session.delete(dashboard)
+ db.session.delete(dashboard)
if do_commit:
- session.commit()
+ db.session.commit()
def delete_dashboard_users_associations(dashboard: Dashboard) -> None:
- session.execute(
+ db.session.execute(
dashboard_user.delete().where(dashboard_user.c.dashboard_id == dashboard.id)
)
def delete_dashboard_roles_associations(dashboard: Dashboard) -> None:
- session.execute(
+ db.session.execute(
DashboardRoles.delete().where(DashboardRoles.c.dashboard_id == dashboard.id)
)
def delete_dashboard_slices_associations(dashboard: Dashboard) -> None:
- session.execute(
+ db.session.execute(
dashboard_slices.delete().where(dashboard_slices.c.dashboard_id == dashboard.id)
)
@@ -242,7 +240,7 @@ def delete_dashboard_slices_associations(dashboard: Dashboard) -> None:
def delete_all_inserted_slices():
try:
slices_to_delete: list[Slice] = (
- session.query(Slice).filter(Slice.id.in_(inserted_slices_ids)).all()
+ db.session.query(Slice).filter(Slice.id.in_(inserted_slices_ids)).all()
)
for slice in slices_to_delete:
try:
@@ -251,7 +249,7 @@ def delete_all_inserted_slices():
logger.error(f"failed to delete {slice.id}", exc_info=True)
raise ex
if len(inserted_slices_ids) > 0:
- session.commit()
+ db.session.commit()
inserted_slices_ids.clear()
except Exception as ex2:
logger.error("delete_all_inserted_slices failed", exc_info=True)
@@ -261,19 +259,19 @@ def delete_all_inserted_slices():
def delete_slice(slice_: Slice, do_commit: bool = False) -> None:
logger.info(f"deleting slice{slice_.id}")
delete_slice_users_associations(slice_)
- session.delete(slice_)
+ db.session.delete(slice_)
if do_commit:
- session.commit()
+ db.session.commit()
def delete_slice_users_associations(slice_: Slice) -> None:
- session.execute(slice_user.delete().where(slice_user.c.slice_id == slice_.id))
+ db.session.execute(slice_user.delete().where(slice_user.c.slice_id == slice_.id))
def delete_all_inserted_tables():
try:
tables_to_delete: list[SqlaTable] = (
- session.query(SqlaTable)
+ db.session.query(SqlaTable)
.filter(SqlaTable.id.in_(inserted_sqltables_ids))
.all()
)
@@ -284,7 +282,7 @@ def delete_all_inserted_tables():
logger.error(f"failed to delete {table.id}", exc_info=True)
raise ex
if len(inserted_sqltables_ids) > 0:
- session.commit()
+ db.session.commit()
inserted_sqltables_ids.clear()
except Exception as ex2:
logger.error("delete_all_inserted_tables failed", exc_info=True)
@@ -294,32 +292,32 @@ def delete_all_inserted_tables():
def delete_sqltable(table: SqlaTable, do_commit: bool = False) -> None:
logger.info(f"deleting table{table.id}")
delete_table_users_associations(table)
- session.delete(table)
+ db.session.delete(table)
if do_commit:
- session.commit()
+ db.session.commit()
def delete_table_users_associations(table: SqlaTable) -> None:
- session.execute(
+ db.session.execute(
sqlatable_user.delete().where(sqlatable_user.c.table_id == table.id)
)
def delete_all_inserted_dbs():
try:
- dbs_to_delete: list[Database] = (
- session.query(Database)
+ databases_to_delete: list[Database] = (
+ db.session.query(Database)
.filter(Database.id.in_(inserted_databases_ids))
.all()
)
- for db in dbs_to_delete:
+ for database in databases_to_delete:
try:
- delete_database(db, False)
+ delete_database(database, False)
except Exception as ex:
- logger.error(f"failed to delete {db.id}", exc_info=True)
+ logger.error(f"failed to delete {database.id}", exc_info=True)
raise ex
if len(inserted_databases_ids) > 0:
- session.commit()
+ db.session.commit()
inserted_databases_ids.clear()
except Exception as ex2:
logger.error("delete_all_inserted_databases failed", exc_info=True)
@@ -328,6 +326,6 @@ def delete_all_inserted_dbs():
def delete_database(database: Database, do_commit: bool = False) -> None:
logger.info(f"deleting database{database.id}")
- session.delete(database)
+ db.session.delete(database)
if do_commit:
- session.commit()
+ db.session.commit()
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index f7b8cc0ec8..ebabc16e87 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -1365,12 +1365,11 @@ class TestDatabaseApi(SupersetTestCase):
"""
Database API: Test get select star with datasource access
"""
- session = db.session
table = SqlaTable(
schema="main", table_name="ab_permission", database=get_main_database()
)
- session.add(table)
- session.commit()
+ db.session.add(table)
+ db.session.commit()
tmp_table_perm = security_manager.find_permission_view_menu(
"datasource_access", table.get_perm()
@@ -1732,15 +1731,14 @@ class TestDatabaseApi(SupersetTestCase):
with self.create_app().app_context():
main_db = get_main_database()
main_db.allow_file_upload = True
- session = db.session
table = SqlaTable(
schema="public",
table_name="ab_permission",
database=get_main_database(),
)
- session.add(table)
- session.commit()
+ db.session.add(table)
+ db.session.commit()
tmp_table_perm = security_manager.find_permission_view_menu(
"datasource_access", table.get_perm()
)
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index 3530bdec1a..1ebe5bd1f7 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -1748,7 +1748,6 @@ class TestDatasetApi(SupersetTestCase):
assert rv.status_code == 200
cli_export = export_to_dict(
- session=db.session,
recursive=True,
back_references=False,
include_defaults=False,
diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py
index 4e05b63002..91e843fc3f 100644
--- a/tests/integration_tests/datasource_tests.py
+++ b/tests/integration_tests/datasource_tests.py
@@ -79,7 +79,6 @@ class TestDatasource(SupersetTestCase):
def test_always_filter_main_dttm(self):
self.login(username="admin")
- session = db.session
database = get_example_database()
sql = f"SELECT DATE() as default_dttm, DATE() as additional_dttm, 1 as metric;"
@@ -115,8 +114,8 @@ class TestDatasource(SupersetTestCase):
sql=sql,
)
- session.add(table)
- session.commit()
+ db.session.add(table)
+ db.session.commit()
table.always_filter_main_dttm = False
result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause)
@@ -126,27 +125,26 @@ class TestDatasource(SupersetTestCase):
result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause)
assert "default_dttm" in result and "additional_dttm" in result
- session.delete(table)
- session.commit()
+ db.session.delete(table)
+ db.session.commit()
def test_external_metadata_for_virtual_table(self):
self.login(username="admin")
- session = db.session
table = SqlaTable(
table_name="dummy_sql_table",
database=get_example_database(),
schema=get_example_default_schema(),
sql="select 123 as intcol, 'abc' as strcol",
)
- session.add(table)
- session.commit()
+ db.session.add(table)
+ db.session.commit()
table = self.get_table(name="dummy_sql_table")
url = f"/datasource/external_metadata/table/{table.id}/"
resp = self.get_json_resp(url)
assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
- session.delete(table)
- session.commit()
+ db.session.delete(table)
+ db.session.commit()
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_external_metadata_by_name_for_physical_table(self):
@@ -171,15 +169,14 @@ class TestDatasource(SupersetTestCase):
def test_external_metadata_by_name_for_virtual_table(self):
self.login(username="admin")
- session = db.session
table = SqlaTable(
table_name="dummy_sql_table",
database=get_example_database(),
schema=get_example_default_schema(),
sql="select 123 as intcol, 'abc' as strcol",
)
- session.add(table)
- session.commit()
+ db.session.add(table)
+ db.session.commit()
tbl = self.get_table(name="dummy_sql_table")
params = prison.dumps(
@@ -195,8 +192,8 @@ class TestDatasource(SupersetTestCase):
url = f"/datasource/external_metadata_by_name/?q={params}"
resp = self.get_json_resp(url)
assert {o.get("column_name") for o in resp} == {"intcol", "strcol"}
- session.delete(tbl)
- session.commit()
+ db.session.delete(tbl)
+ db.session.commit()
def test_external_metadata_by_name_from_sqla_inspector(self):
self.login(username="admin")
@@ -265,7 +262,6 @@ class TestDatasource(SupersetTestCase):
def test_external_metadata_for_virtual_table_template_params(self):
self.login(username="admin")
- session = db.session
table = SqlaTable(
table_name="dummy_sql_table_with_template_params",
database=get_example_database(),
@@ -273,15 +269,15 @@ class TestDatasource(SupersetTestCase):
sql="select {{ foo }} as intcol",
template_params=json.dumps({"foo": "123"}),
)
- session.add(table)
- session.commit()
+ db.session.add(table)
+ db.session.commit()
table = self.get_table(name="dummy_sql_table_with_template_params")
url = f"/datasource/external_metadata/table/{table.id}/"
resp = self.get_json_resp(url)
assert {o.get("column_name") for o in resp} == {"intcol"}
- session.delete(table)
- session.commit()
+ db.session.delete(table)
+ db.session.commit()
def test_external_metadata_for_malicious_virtual_table(self):
self.login(username="admin")
diff --git a/tests/integration_tests/db_engine_specs/databricks_tests.py b/tests/integration_tests/db_engine_specs/databricks_tests.py
index 5ff20b7347..bf4d7e8b9f 100644
--- a/tests/integration_tests/db_engine_specs/databricks_tests.py
+++ b/tests/integration_tests/db_engine_specs/databricks_tests.py
@@ -33,10 +33,10 @@ class TestDatabricksDbEngineSpec(TestDbEngineSpec):
assert get_engine_spec("databricks", "pyhive").engine == "databricks"
def test_extras_without_ssl(self):
- db = mock.Mock()
- db.extra = default_db_extra
- db.server_cert = None
- extras = DatabricksNativeEngineSpec.get_extra_params(db)
+ database = mock.Mock()
+ database.extra = default_db_extra
+ database.server_cert = None
+ extras = DatabricksNativeEngineSpec.get_extra_params(database)
assert extras == {
"engine_params": {
"connect_args": {
@@ -50,12 +50,12 @@ class TestDatabricksDbEngineSpec(TestDbEngineSpec):
}
def test_extras_with_ssl_custom(self):
- db = mock.Mock()
- db.extra = default_db_extra.replace(
+ database = mock.Mock()
+ database.extra = default_db_extra.replace(
'"engine_params": {}',
'"engine_params": {"connect_args": {"ssl": "1"}}',
)
- db.server_cert = ssl_certificate
- extras = DatabricksNativeEngineSpec.get_extra_params(db)
+ database.server_cert = ssl_certificate
+ extras = DatabricksNativeEngineSpec.get_extra_params(database)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["ssl"] == "1"
diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py
index 341b494927..374d99c02e 100644
--- a/tests/integration_tests/db_engine_specs/hive_tests.py
+++ b/tests/integration_tests/db_engine_specs/hive_tests.py
@@ -337,14 +337,14 @@ def test_fetch_data_success(fetch_data_mock):
@mock.patch("superset.db_engine_specs.hive.HiveEngineSpec._latest_partition_from_df")
def test_where_latest_partition(mock_method):
mock_method.return_value = ("01-01-19", 1)
- db = mock.Mock()
- db.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
- db.get_extra = mock.Mock(return_value={})
- db.get_df = mock.Mock()
+ database = mock.Mock()
+ database.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
+ database.get_extra = mock.Mock(return_value={})
+ database.get_df = mock.Mock()
columns = [{"name": "ds"}, {"name": "hour"}]
with app.app_context():
result = HiveEngineSpec.where_latest_partition(
- "test_table", "test_schema", db, select(), columns
+ "test_table", "test_schema", database, select(), columns
)
query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
assert "SELECT \nWHERE ds = '01-01-19' AND hour = 1" == query_result
@@ -353,11 +353,11 @@ def test_where_latest_partition(mock_method):
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.latest_partition")
def test_where_latest_partition_super_method_exception(mock_method):
mock_method.side_effect = Exception()
- db = mock.Mock()
+ database = mock.Mock()
columns = [{"name": "ds"}, {"name": "hour"}]
with app.app_context():
result = HiveEngineSpec.where_latest_partition(
- "test_table", "test_schema", db, select(), columns
+ "test_table", "test_schema", database, select(), columns
)
assert result is None
mock_method.assert_called()
diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py
index 2b543e8e25..0f4841fb35 100644
--- a/tests/integration_tests/db_engine_specs/postgres_tests.py
+++ b/tests/integration_tests/db_engine_specs/postgres_tests.py
@@ -119,29 +119,29 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
assert "postgres" in backends
def test_extras_without_ssl(self):
- db = mock.Mock()
- db.extra = default_db_extra
- db.server_cert = None
- extras = PostgresEngineSpec.get_extra_params(db)
+ database = mock.Mock()
+ database.extra = default_db_extra
+ database.server_cert = None
+ extras = PostgresEngineSpec.get_extra_params(database)
assert "connect_args" not in extras["engine_params"]
def test_extras_with_ssl_default(self):
- db = mock.Mock()
- db.extra = default_db_extra
- db.server_cert = ssl_certificate
- extras = PostgresEngineSpec.get_extra_params(db)
+ database = mock.Mock()
+ database.extra = default_db_extra
+ database.server_cert = ssl_certificate
+ extras = PostgresEngineSpec.get_extra_params(database)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["sslmode"] == "verify-full"
assert "sslrootcert" in connect_args
def test_extras_with_ssl_custom(self):
- db = mock.Mock()
- db.extra = default_db_extra.replace(
+ database = mock.Mock()
+ database.extra = default_db_extra.replace(
'"engine_params": {}',
'"engine_params": {"connect_args": {"sslmode": "verify-ca"}}',
)
- db.server_cert = ssl_certificate
- extras = PostgresEngineSpec.get_extra_params(db)
+ database.server_cert = ssl_certificate
+ extras = PostgresEngineSpec.get_extra_params(database)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["sslmode"] == "verify-ca"
assert "sslrootcert" in connect_args
diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py
index 7e151648a6..c28e78afe6 100644
--- a/tests/integration_tests/db_engine_specs/presto_tests.py
+++ b/tests/integration_tests/db_engine_specs/presto_tests.py
@@ -550,13 +550,17 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
def test_presto_extra_table_metadata(self):
- db = mock.Mock()
- db.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}])
- db.get_extra = mock.Mock(return_value={})
+ database = mock.Mock()
+ database.get_indexes = mock.Mock(
+ return_value=[{"column_names": ["ds", "hour"]}]
+ )
+ database.get_extra = mock.Mock(return_value={})
df = pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})
- db.get_df = mock.Mock(return_value=df)
+ database.get_df = mock.Mock(return_value=df)
PrestoEngineSpec.get_create_view = mock.Mock(return_value=None)
- result = PrestoEngineSpec.extra_table_metadata(db, "test_table", "test_schema")
+ result = PrestoEngineSpec.extra_table_metadata(
+ database, "test_table", "test_schema"
+ )
assert result["partitions"]["cols"] == ["ds", "hour"]
assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1}
diff --git a/tests/integration_tests/dict_import_export_tests.py b/tests/integration_tests/dict_import_export_tests.py
index 6018e59a92..b4dddff09f 100644
--- a/tests/integration_tests/dict_import_export_tests.py
+++ b/tests/integration_tests/dict_import_export_tests.py
@@ -43,11 +43,10 @@ class TestDictImportExport(SupersetTestCase):
def delete_imports(cls):
with app.app_context():
# Imported data clean up
- session = db.session
- for table in session.query(SqlaTable):
+ for table in db.session.query(SqlaTable):
if DBREF in table.params_dict:
- session.delete(table)
- session.commit()
+ db.session.delete(table)
+ db.session.commit()
@classmethod
def setUpClass(cls):
@@ -124,7 +123,7 @@ class TestDictImportExport(SupersetTestCase):
def test_import_table_no_metadata(self):
table, dict_table = self.create_table("pure_table", id=ID_PREFIX + 1)
- new_table = SqlaTable.import_from_dict(db.session, dict_table)
+ new_table = SqlaTable.import_from_dict(dict_table)
db.session.commit()
imported_id = new_table.id
imported = self.get_table_by_id(imported_id)
@@ -139,7 +138,7 @@ class TestDictImportExport(SupersetTestCase):
cols_uuids=[uuid4()],
metric_names=["metric1"],
)
- imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+ imported_table = SqlaTable.import_from_dict(dict_table)
db.session.commit()
imported = self.get_table_by_id(imported_table.id)
self.assert_table_equals(table, imported)
@@ -156,7 +155,7 @@ class TestDictImportExport(SupersetTestCase):
cols_uuids=[uuid4(), uuid4()],
metric_names=["m1", "m2"],
)
- imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+ imported_table = SqlaTable.import_from_dict(dict_table)
db.session.commit()
imported = self.get_table_by_id(imported_table.id)
self.assert_table_equals(table, imported)
@@ -166,7 +165,7 @@ class TestDictImportExport(SupersetTestCase):
table, dict_table = self.create_table(
"table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
)
- imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+ imported_table = SqlaTable.import_from_dict(dict_table)
db.session.commit()
table_over, dict_table_over = self.create_table(
"table_override",
@@ -174,7 +173,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
- imported_over_table = SqlaTable.import_from_dict(db.session, dict_table_over)
+ imported_over_table = SqlaTable.import_from_dict(dict_table_over)
db.session.commit()
imported_over = self.get_table_by_id(imported_over_table.id)
@@ -195,7 +194,7 @@ class TestDictImportExport(SupersetTestCase):
table, dict_table = self.create_table(
"table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
)
- imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+ imported_table = SqlaTable.import_from_dict(dict_table)
db.session.commit()
table_over, dict_table_over = self.create_table(
"table_override",
@@ -204,7 +203,7 @@ class TestDictImportExport(SupersetTestCase):
metric_names=["new_metric1"],
)
imported_over_table = SqlaTable.import_from_dict(
- session=db.session, dict_rep=dict_table_over, sync=["metrics", "columns"]
+ dict_rep=dict_table_over, sync=["metrics", "columns"]
)
db.session.commit()
@@ -229,7 +228,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
- imported_table = SqlaTable.import_from_dict(db.session, dict_table)
+ imported_table = SqlaTable.import_from_dict(dict_table)
db.session.commit()
copy_table, dict_copy_table = self.create_table(
"copy_cat",
@@ -237,7 +236,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
- imported_copy_table = SqlaTable.import_from_dict(db.session, dict_copy_table)
+ imported_copy_table = SqlaTable.import_from_dict(dict_copy_table)
db.session.commit()
self.assertEqual(imported_table.id, imported_copy_table.id)
self.assert_table_equals(copy_table, self.get_table_by_id(imported_table.id))
@@ -250,7 +249,6 @@ class TestDictImportExport(SupersetTestCase):
self.delete_fake_db()
cli_export = export_to_dict(
- session=db.session,
recursive=True,
back_references=False,
include_defaults=False,
diff --git a/tests/integration_tests/explore/api_tests.py b/tests/integration_tests/explore/api_tests.py
index e37200e310..c0b7f5fcd4 100644
--- a/tests/integration_tests/explore/api_tests.py
+++ b/tests/integration_tests/explore/api_tests.py
@@ -21,6 +21,7 @@ import pytest
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.orm import Session
+from superset import db
from superset.commands.explore.form_data.state import TemporaryExploreState
from superset.connectors.sqla.models import SqlaTable
from superset.explore.exceptions import DatasetAccessDeniedError
@@ -39,25 +40,22 @@ FORM_DATA = {"test": "test value"}
@pytest.fixture
def chart_id(load_world_bank_dashboard_with_slices) -> int:
with app.app_context() as ctx:
- session: Session = ctx.app.appbuilder.get_session
- chart = session.query(Slice).filter_by(slice_name="World's Population").one()
+ chart = db.session.query(Slice).filter_by(slice_name="World's Population").one()
return chart.id
@pytest.fixture
def admin_id() -> int:
with app.app_context() as ctx:
- session: Session = ctx.app.appbuilder.get_session
- admin = session.query(User).filter_by(username="admin").one()
+ admin = db.session.query(User).filter_by(username="admin").one()
return admin.id
@pytest.fixture
def dataset() -> int:
with app.app_context() as ctx:
- session: Session = ctx.app.appbuilder.get_session
dataset = (
- session.query(SqlaTable)
+ db.session.query(SqlaTable)
.filter_by(table_name="wb_health_population")
.first()
)
diff --git a/tests/integration_tests/explore/form_data/api_tests.py b/tests/integration_tests/explore/form_data/api_tests.py
index 5dbd67d4f5..9187e46213 100644
--- a/tests/integration_tests/explore/form_data/api_tests.py
+++ b/tests/integration_tests/explore/form_data/api_tests.py
@@ -21,6 +21,7 @@ import pytest
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.orm import Session
+from superset import db
from superset.commands.dataset.exceptions import DatasetAccessDeniedError
from superset.commands.explore.form_data.state import TemporaryExploreState
from superset.connectors.sqla.models import SqlaTable
@@ -41,25 +42,22 @@ UPDATED_FORM_DATA = json.dumps({"test": "updated value"})
@pytest.fixture
def chart_id(load_world_bank_dashboard_with_slices) -> int:
with app.app_context() as ctx:
- session: Session = ctx.app.appbuilder.get_session
- chart = session.query(Slice).filter_by(slice_name="World's Population").one()
+ chart = db.session.query(Slice).filter_by(slice_name="World's Population").one()
return chart.id
@pytest.fixture
def admin_id() -> int:
with app.app_context() as ctx:
- session: Session = ctx.app.appbuilder.get_session
- admin = session.query(User).filter_by(username="admin").one()
+ admin = db.session.query(User).filter_by(username="admin").one()
return admin.id
@pytest.fixture
def datasource() -> int:
with app.app_context() as ctx:
- session: Session = ctx.app.appbuilder.get_session
dataset = (
- session.query(SqlaTable)
+ db.session.query(SqlaTable)
.filter_by(table_name="wb_health_population")
.first()
)
diff --git a/tests/integration_tests/explore/form_data/commands_tests.py b/tests/integration_tests/explore/form_data/commands_tests.py
index 781c4fdbb2..293a2c556f 100644
--- a/tests/integration_tests/explore/form_data/commands_tests.py
+++ b/tests/integration_tests/explore/form_data/commands_tests.py
@@ -45,22 +45,22 @@ class TestCreateFormDataCommand(SupersetTestCase):
schema=get_example_default_schema(),
sql="select 123 as intcol, 'abc' as strcol",
)
- session = db.session
- session.add(dataset)
- session.commit()
+ db.session.add(dataset)
+ db.session.commit()
yield dataset
# rollback
- session.delete(dataset)
- session.commit()
+ db.session.delete(dataset)
+ db.session.commit()
@pytest.fixture()
def create_slice(self):
with self.create_app().app_context():
- session = db.session
dataset = (
- session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
+ db.session.query(SqlaTable)
+ .filter_by(table_name="dummy_sql_table")
+ .first()
)
slice = Slice(
datasource_id=dataset.id,
@@ -69,34 +69,32 @@ class TestCreateFormDataCommand(SupersetTestCase):
slice_name="slice_name",
)
- session.add(slice)
- session.commit()
+ db.session.add(slice)
+ db.session.commit()
yield slice
# rollback
- session.delete(slice)
- session.commit()
+ db.session.delete(slice)
+ db.session.commit()
@pytest.fixture()
def create_query(self):
with self.create_app().app_context():
- session = db.session
-
query = Query(
sql="select 1 as foo;",
client_id="sldkfjlk",
database=get_example_database(),
)
- session.add(query)
- session.commit()
+ db.session.add(query)
+ db.session.commit()
yield query
# rollback
- session.delete(query)
- session.commit()
+ db.session.delete(query)
+ db.session.commit()
@patch("superset.security.manager.g")
@pytest.mark.usefixtures("create_dataset", "create_slice")
diff --git a/tests/integration_tests/explore/permalink/api_tests.py b/tests/integration_tests/explore/permalink/api_tests.py
index 81be2f0de8..a171504cc6 100644
--- a/tests/integration_tests/explore/permalink/api_tests.py
+++ b/tests/integration_tests/explore/permalink/api_tests.py
@@ -38,8 +38,7 @@ from tests.integration_tests.test_app import app
@pytest.fixture
def chart(app_context, load_world_bank_dashboard_with_slices) -> Slice:
- session: Session = app_context.app.appbuilder.get_session
- chart = session.query(Slice).filter_by(slice_name="World's Population").one()
+ chart = db.session.query(Slice).filter_by(slice_name="World's Population").one()
return chart
diff --git a/tests/integration_tests/explore/permalink/commands_tests.py b/tests/integration_tests/explore/permalink/commands_tests.py
index 5402a419bc..f499591aa5 100644
--- a/tests/integration_tests/explore/permalink/commands_tests.py
+++ b/tests/integration_tests/explore/permalink/commands_tests.py
@@ -43,22 +43,22 @@ class TestCreatePermalinkDataCommand(SupersetTestCase):
schema=get_example_default_schema(),
sql="select 123 as intcol, 'abc' as strcol",
)
- session = db.session
- session.add(dataset)
- session.commit()
+ db.session.add(dataset)
+ db.session.commit()
yield dataset
# rollback
- session.delete(dataset)
- session.commit()
+ db.session.delete(dataset)
+ db.session.commit()
@pytest.fixture()
def create_slice(self):
with self.create_app().app_context():
- session = db.session
dataset = (
- session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
+ db.session.query(SqlaTable)
+ .filter_by(table_name="dummy_sql_table")
+ .first()
)
slice = Slice(
datasource_id=dataset.id,
@@ -67,34 +67,32 @@ class TestCreatePermalinkDataCommand(SupersetTestCase):
slice_name="slice_name",
)
- session.add(slice)
- session.commit()
+ db.session.add(slice)
+ db.session.commit()
yield slice
# rollback
- session.delete(slice)
- session.commit()
+ db.session.delete(slice)
+ db.session.commit()
@pytest.fixture()
def create_query(self):
with self.create_app().app_context():
- session = db.session
-
query = Query(
sql="select 1 as foo;",
client_id="sldkfjlk",
database=get_example_database(),
)
- session.add(query)
- session.commit()
+ db.session.add(query)
+ db.session.commit()
yield query
# rollback
- session.delete(query)
- session.commit()
+ db.session.delete(query)
+ db.session.commit()
@patch("superset.security.manager.g")
@pytest.mark.usefixtures("create_dataset", "create_slice")
diff --git a/tests/integration_tests/fixtures/datasource.py b/tests/integration_tests/fixtures/datasource.py
index 279b67eda0..fd7c69deca 100644
--- a/tests/integration_tests/fixtures/datasource.py
+++ b/tests/integration_tests/fixtures/datasource.py
@@ -177,7 +177,6 @@ def load_dataset_with_columns() -> Generator[SqlaTable, None, None]:
with app.app_context():
engine = create_engine(app.config["SQLALCHEMY_DATABASE_URI"], echo=True)
meta = MetaData()
- session = db.session
students = Table(
"students",
@@ -196,8 +195,8 @@ def load_dataset_with_columns() -> Generator[SqlaTable, None, None]:
)
column = TableColumn(table_id=dataset.id, column_name="name")
dataset.columns = [column]
- session.add(dataset)
- session.commit()
+ db.session.add(dataset)
+ db.session.commit()
yield dataset
# cleanup
@@ -205,8 +204,8 @@ def load_dataset_with_columns() -> Generator[SqlaTable, None, None]:
if students_table is not None:
base = declarative_base()
# needed for sqlite
- session.commit()
+ db.session.commit()
base.metadata.drop_all(engine, [students_table], checkfirst=True)
- session.delete(dataset)
- session.delete(column)
- session.commit()
+ db.session.delete(dataset)
+ db.session.delete(column)
+ db.session.commit()
diff --git a/tests/integration_tests/import_export_tests.py b/tests/integration_tests/import_export_tests.py
index adc398e785..4a1558ffd8 100644
--- a/tests/integration_tests/import_export_tests.py
+++ b/tests/integration_tests/import_export_tests.py
@@ -53,17 +53,16 @@ from .base_tests import SupersetTestCase
def delete_imports():
with app.app_context():
# Imported data clean up
- session = db.session
- for slc in session.query(Slice):
+ for slc in db.session.query(Slice):
if "remote_id" in slc.params_dict:
- session.delete(slc)
- for dash in session.query(Dashboard):
+ db.session.delete(slc)
+ for dash in db.session.query(Dashboard):
if "remote_id" in dash.params_dict:
- session.delete(dash)
- for table in session.query(SqlaTable):
+ db.session.delete(dash)
+ for table in db.session.query(SqlaTable):
if "remote_id" in table.params_dict:
- session.delete(table)
- session.commit()
+ db.session.delete(table)
+ db.session.commit()
@pytest.fixture(autouse=True, scope="module")
diff --git a/tests/integration_tests/key_value/commands/fixtures.py b/tests/integration_tests/key_value/commands/fixtures.py
index ac33d003e0..6ba09c8a18 100644
--- a/tests/integration_tests/key_value/commands/fixtures.py
+++ b/tests/integration_tests/key_value/commands/fixtures.py
@@ -66,6 +66,5 @@ def key_value_entry() -> Generator[KeyValueEntry, None, None]:
@pytest.fixture
def admin() -> User:
with app.app_context() as ctx:
- session: Session = ctx.app.appbuilder.get_session
- admin = session.query(User).filter_by(username="admin").one()
+ admin = db.session.query(User).filter_by(username="admin").one()
return admin
diff --git a/tests/integration_tests/security/guest_token_security_tests.py b/tests/integration_tests/security/guest_token_security_tests.py
index b812929433..44a4cdd3ce 100644
--- a/tests/integration_tests/security/guest_token_security_tests.py
+++ b/tests/integration_tests/security/guest_token_security_tests.py
@@ -230,15 +230,14 @@ class TestGuestUserDatasourceAccess(SupersetTestCase):
schema=get_example_default_schema(),
sql="select 123 as intcol, 'abc' as strcol",
)
- session = db.session
- session.add(dataset)
- session.commit()
+ db.session.add(dataset)
+ db.session.commit()
yield dataset
# rollback
- session.delete(dataset)
- session.commit()
+ db.session.delete(dataset)
+ db.session.commit()
def setUp(self) -> None:
self.dash = self.get_dash_by_slug("births")
@@ -258,11 +257,9 @@ class TestGuestUserDatasourceAccess(SupersetTestCase):
],
}
)
- self.chart = self.get_slice("Girls", db.session, expunge_from_session=False)
+ self.chart = self.get_slice("Girls", expunge_from_session=False)
self.datasource = self.chart.datasource
- self.other_chart = self.get_slice(
- "Treemap", db.session, expunge_from_session=False
- )
+ self.other_chart = self.get_slice("Treemap", expunge_from_session=False)
self.other_datasource = self.other_chart.datasource
self.native_filter_datasource = (
db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first()
diff --git a/tests/integration_tests/security/migrate_roles_tests.py b/tests/integration_tests/security/migrate_roles_tests.py
index 39d66a82aa..4ab73a713e 100644
--- a/tests/integration_tests/security/migrate_roles_tests.py
+++ b/tests/integration_tests/security/migrate_roles_tests.py
@@ -245,11 +245,10 @@ def test_migrate_role(
logger.info(description)
with create_old_role(pvm_map, external_pvms) as old_role:
role_name = old_role.name
- session = db.session
# Run migrations
- add_pvms(session, new_pvms)
- migrate_roles(session, pvm_map)
+ add_pvms(db.session, new_pvms)
+ migrate_roles(db.session, pvm_map)
role = db.session.query(Role).filter(Role.name == role_name).one_or_none()
for old_pvm, new_pvms in pvm_map.items():
diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py
index 41ca0d5e79..7518621ddd 100644
--- a/tests/integration_tests/security/row_level_security_tests.py
+++ b/tests/integration_tests/security/row_level_security_tests.py
@@ -74,8 +74,6 @@ class TestRowLevelSecurity(SupersetTestCase):
BASE_FILTER_REGEX = re.compile(r"gender = 'boy'")
def setUp(self):
- session = db.session
-
# Create roles
self.role_ab = security_manager.add_role(self.NAME_AB_ROLE)
self.role_q = security_manager.add_role(self.NAME_Q_ROLE)
@@ -83,13 +81,13 @@ class TestRowLevelSecurity(SupersetTestCase):
gamma_user.roles.append(self.role_ab)
gamma_user.roles.append(self.role_q)
self.create_user_with_roles("NoRlsRoleUser", ["Gamma"])
- session.commit()
+ db.session.commit()
# Create regular RowLevelSecurityFilter (energy_usage, unicode_test)
self.rls_entry1 = RowLevelSecurityFilter()
self.rls_entry1.name = "rls_entry1"
self.rls_entry1.tables.extend(
- session.query(SqlaTable)
+ db.session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"]))
.all()
)
@@ -104,7 +102,7 @@ class TestRowLevelSecurity(SupersetTestCase):
self.rls_entry2 = RowLevelSecurityFilter()
self.rls_entry2.name = "rls_entry2"
self.rls_entry2.tables.extend(
- session.query(SqlaTable)
+ db.session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["birth_names"]))
.all()
)
@@ -118,7 +116,7 @@ class TestRowLevelSecurity(SupersetTestCase):
self.rls_entry3 = RowLevelSecurityFilter()
self.rls_entry3.name = "rls_entry3"
self.rls_entry3.tables.extend(
- session.query(SqlaTable)
+ db.session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["birth_names"]))
.all()
)
@@ -132,7 +130,7 @@ class TestRowLevelSecurity(SupersetTestCase):
self.rls_entry4 = RowLevelSecurityFilter()
self.rls_entry4.name = "rls_entry4"
self.rls_entry4.tables.extend(
- session.query(SqlaTable)
+ db.session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["birth_names"]))
.all()
)
@@ -145,15 +143,14 @@ class TestRowLevelSecurity(SupersetTestCase):
db.session.commit()
def tearDown(self):
- session = db.session
- session.delete(self.rls_entry1)
- session.delete(self.rls_entry2)
- session.delete(self.rls_entry3)
- session.delete(self.rls_entry4)
- session.delete(security_manager.find_role("NameAB"))
- session.delete(security_manager.find_role("NameQ"))
- session.delete(self.get_user("NoRlsRoleUser"))
- session.commit()
+ db.session.delete(self.rls_entry1)
+ db.session.delete(self.rls_entry2)
+ db.session.delete(self.rls_entry3)
+ db.session.delete(self.rls_entry4)
+ db.session.delete(security_manager.find_role("NameAB"))
+ db.session.delete(security_manager.find_role("NameQ"))
+ db.session.delete(self.get_user("NoRlsRoleUser"))
+ db.session.commit()
@pytest.fixture()
def create_dataset(self):
diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py
index ece9afcccb..b1f66b0d6c 100644
--- a/tests/integration_tests/security_tests.py
+++ b/tests/integration_tests/security_tests.py
@@ -1704,11 +1704,11 @@ class TestSecurityManager(SupersetTestCase):
mock_is_owner,
):
births = self.get_dash_by_slug("births")
- girls = self.get_slice("Girls", db.session, expunge_from_session=False)
+ girls = self.get_slice("Girls", expunge_from_session=False)
birth_names = girls.datasource
world_health = self.get_dash_by_slug("world_health")
- treemap = self.get_slice("Treemap", db.session, expunge_from_session=False)
+ treemap = self.get_slice("Treemap", expunge_from_session=False)
births.json_metadata = json.dumps(
{
diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py
index 4410f19782..0dc4e26aca 100644
--- a/tests/integration_tests/sqllab_tests.py
+++ b/tests/integration_tests/sqllab_tests.py
@@ -434,8 +434,6 @@ class TestSqlLab(SupersetTestCase):
Test query api with can_access_all_queries perm added to
gamma and make sure all queries show up.
"""
- session = db.session
-
# Add all_query_access perm to Gamma user
all_queries_view = security_manager.find_permission_view_menu(
"all_query_access", "all_query_access"
@@ -444,7 +442,7 @@ class TestSqlLab(SupersetTestCase):
security_manager.add_permission_role(
security_manager.find_role("gamma_sqllab"), all_queries_view
)
- session.commit()
+ db.session.commit()
# Test search_queries for Admin user
self.run_some_queries()
@@ -461,7 +459,7 @@ class TestSqlLab(SupersetTestCase):
security_manager.find_role("gamma_sqllab"), all_queries_view
)
- session.commit()
+ db.session.commit()
def test_query_admin_can_access_all_queries(self) -> None:
"""
diff --git a/tests/integration_tests/test_jinja_context.py b/tests/integration_tests/test_jinja_context.py
index 8c2db6920d..6f776017fb 100644
--- a/tests/integration_tests/test_jinja_context.py
+++ b/tests/integration_tests/test_jinja_context.py
@@ -114,10 +114,10 @@ def test_template_hive(app_context: AppContext, mocker: MockFixture) -> None:
"superset.jinja_context.HiveTemplateProcessor.latest_partition"
)
lp_mock.return_value = "the_latest"
- db = mock.Mock()
- db.backend = "hive"
+ database = mock.Mock()
+ database.backend = "hive"
template = "{{ hive.latest_partition('my_table') }}"
- tp = get_template_processor(database=db)
+ tp = get_template_processor(database=database)
assert tp.process_template(template) == "the_latest"
@@ -126,15 +126,15 @@ def test_template_trino(app_context: AppContext, mocker: MockFixture) -> None:
"superset.jinja_context.TrinoTemplateProcessor.latest_partition"
)
lp_mock.return_value = "the_latest"
- db = mock.Mock()
- db.backend = "trino"
+ database = mock.Mock()
+ database.backend = "trino"
template = "{{ trino.latest_partition('my_table') }}"
- tp = get_template_processor(database=db)
+ tp = get_template_processor(database=database)
assert tp.process_template(template) == "the_latest"
# Backwards compatibility if migrating from Presto.
template = "{{ presto.latest_partition('my_table') }}"
- tp = get_template_processor(database=db)
+ tp = get_template_processor(database=database)
assert tp.process_template(template) == "the_latest"
@@ -154,9 +154,9 @@ def test_custom_process_template(app_context: AppContext, mocker: MockFixture) -
"tests.integration_tests.superset_test_custom_template_processors.datetime"
)
mock_dt.utcnow = mock.Mock(return_value=datetime(1970, 1, 1))
- db = mock.Mock()
- db.backend = "db_for_macros_testing"
- tp = get_template_processor(database=db)
+ database = mock.Mock()
+ database.backend = "db_for_macros_testing"
+ tp = get_template_processor(database=database)
template = "SELECT '$DATE()'"
assert tp.process_template(template) == f"SELECT '1970-01-01'"
@@ -168,28 +168,28 @@ def test_custom_process_template(app_context: AppContext, mocker: MockFixture) -
def test_custom_get_template_kwarg(app_context: AppContext) -> None:
"""Test macro passed as kwargs when getting template processor
works in custom template processor."""
- db = mock.Mock()
- db.backend = "db_for_macros_testing"
+ database = mock.Mock()
+ database.backend = "db_for_macros_testing"
template = "$foo()"
- tp = get_template_processor(database=db, foo=lambda: "bar")
+ tp = get_template_processor(database=database, foo=lambda: "bar")
assert tp.process_template(template) == "bar"
def test_custom_template_kwarg(app_context: AppContext) -> None:
"""Test macro passed as kwargs when processing template
works in custom template processor."""
- db = mock.Mock()
- db.backend = "db_for_macros_testing"
+ database = mock.Mock()
+ database.backend = "db_for_macros_testing"
template = "$foo()"
- tp = get_template_processor(database=db)
+ tp = get_template_processor(database=database)
assert tp.process_template(template, foo=lambda: "bar") == "bar"
def test_custom_template_processors_overwrite(app_context: AppContext) -> None:
"""Test template processor for presto gets overwritten by custom one."""
- db = mock.Mock()
- db.backend = "db_for_macros_testing"
- tp = get_template_processor(database=db)
+ database = mock.Mock()
+ database.backend = "db_for_macros_testing"
+ tp = get_template_processor(database=database)
template = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
assert tp.process_template(template) == template
diff --git a/tests/integration_tests/utils/get_dashboards.py b/tests/integration_tests/utils/get_dashboards.py
index 7012bf08a0..b23b372310 100644
--- a/tests/integration_tests/utils/get_dashboards.py
+++ b/tests/integration_tests/utils/get_dashboards.py
@@ -15,12 +15,11 @@
# specific language governing permissions and limitations
# under the License.
-from flask_appbuilder import SQLA
-
+from superset import db
from superset.models.dashboard import Dashboard
-def get_dashboards_ids(db: SQLA, dashboard_slugs: list[str]) -> list[int]:
+def get_dashboards_ids(dashboard_slugs: list[str]) -> list[int]:
result = (
db.session.query(Dashboard.id).filter(Dashboard.slug.in_(dashboard_slugs)).all()
)
diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py
index bdbb912eec..b4ab08dc55 100644
--- a/tests/integration_tests/utils_tests.py
+++ b/tests/integration_tests/utils_tests.py
@@ -898,7 +898,7 @@ class TestUtils(SupersetTestCase):
def test_log_this(self) -> None:
# TODO: Add additional scenarios.
self.login(username="admin")
- slc = self.get_slice("Top 10 Girl Name Share", db.session)
+ slc = self.get_slice("Top 10 Girl Name Share")
dashboard_id = 1
assert slc.viz is not None
@@ -956,7 +956,7 @@ class TestUtils(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_extract_dataframe_dtypes(self):
- slc = self.get_slice("Girls", db.session)
+ slc = self.get_slice("Girls")
cols: tuple[tuple[str, GenericDataType, list[Any]], ...] = (
("dt", GenericDataType.TEMPORAL, [date(2021, 2, 4), date(2021, 2, 4)]),
(
diff --git a/tests/unit_tests/charts/commands/importers/v1/import_test.py b/tests/unit_tests/charts/commands/importers/v1/import_test.py
index bcff3ee411..e6f6d00206 100644
--- a/tests/unit_tests/charts/commands/importers/v1/import_test.py
+++ b/tests/unit_tests/charts/commands/importers/v1/import_test.py
@@ -24,7 +24,7 @@ from flask_appbuilder.security.sqla.models import Role, User
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
-from superset import security_manager
+from superset import db, security_manager
from superset.commands.chart.importers.v1.utils import import_chart
from superset.commands.exceptions import ImportFailedError
from superset.connectors.sqla.models import Database, SqlaTable
@@ -82,7 +82,7 @@ def test_import_chart(mocker: MockFixture, session_with_schema: Session) -> None
config["datasource_id"] = 1
config["datasource_type"] = "table"
- chart = import_chart(session_with_schema, config)
+ chart = import_chart(config)
assert chart.slice_name == "Deck Path"
assert chart.viz_type == "deck_path"
assert chart.is_managed_externally is False
@@ -106,7 +106,7 @@ def test_import_chart_managed_externally(
config["is_managed_externally"] = True
config["external_url"] = "https://example.org/my_chart"
- chart = import_chart(session_with_schema, config)
+ chart = import_chart(config)
assert chart.is_managed_externally is True
assert chart.external_url == "https://example.org/my_chart"
@@ -128,7 +128,7 @@ def test_import_chart_without_permission(
config["datasource_type"] = "table"
with pytest.raises(ImportFailedError) as excinfo:
- import_chart(session_with_schema, config)
+ import_chart(config)
assert (
str(excinfo.value)
== "Chart doesn't exist and user doesn't have permission to create charts"
@@ -173,7 +173,7 @@ def test_import_existing_chart_without_permission(
with override_user("admin"):
with pytest.raises(ImportFailedError) as excinfo:
- import_chart(session_with_data, chart_config, overwrite=True)
+ import_chart(chart_config, overwrite=True)
assert (
str(excinfo.value)
== "A chart already exists and user doesn't have permissions to overwrite it"
@@ -213,7 +213,7 @@ def test_import_existing_chart_with_permission(
)
with override_user(admin):
- import_chart(session_with_data, config, overwrite=True)
+ import_chart(config, overwrite=True)
# Assert that the can write to chart was checked
security_manager.can_access.assert_called_once_with("can_write", "Chart")
security_manager.can_access_chart.assert_called_once_with(slice)
diff --git a/tests/unit_tests/charts/dao/dao_tests.py b/tests/unit_tests/charts/dao/dao_tests.py
index e8c58b5600..e811223a98 100644
--- a/tests/unit_tests/charts/dao/dao_tests.py
+++ b/tests/unit_tests/charts/dao/dao_tests.py
@@ -48,7 +48,7 @@ def test_slice_find_by_id_skip_base_filter(session_with_data: Session) -> None:
from superset.daos.chart import ChartDAO
from superset.models.slice import Slice
- result = ChartDAO.find_by_id(1, session=session_with_data, skip_base_filter=True)
+ result = ChartDAO.find_by_id(1, skip_base_filter=True)
assert result
assert 1 == result.id
@@ -57,20 +57,18 @@ def test_slice_find_by_id_skip_base_filter(session_with_data: Session) -> None:
def test_datasource_find_by_id_skip_base_filter_not_found(
- session_with_data: Session,
+ session: Session,
) -> None:
from superset.daos.chart import ChartDAO
- result = ChartDAO.find_by_id(
- 125326326, session=session_with_data, skip_base_filter=True
- )
+ result = ChartDAO.find_by_id(125326326, skip_base_filter=True)
assert result is None
-def test_add_favorite(session_with_data: Session) -> None:
+def test_add_favorite(session: Session) -> None:
from superset.daos.chart import ChartDAO
- chart = ChartDAO.find_by_id(1, session=session_with_data, skip_base_filter=True)
+ chart = ChartDAO.find_by_id(1, skip_base_filter=True)
if not chart:
return
assert len(ChartDAO.favorited_ids([chart])) == 0
@@ -82,10 +80,10 @@ def test_add_favorite(session_with_data: Session) -> None:
assert len(ChartDAO.favorited_ids([chart])) == 1
-def test_remove_favorite(session_with_data: Session) -> None:
+def test_remove_favorite(session: Session) -> None:
from superset.daos.chart import ChartDAO
- chart = ChartDAO.find_by_id(1, session=session_with_data, skip_base_filter=True)
+ chart = ChartDAO.find_by_id(1, skip_base_filter=True)
if not chart:
return
assert len(ChartDAO.favorited_ids([chart])) == 0
diff --git a/tests/unit_tests/charts/test_post_processing.py b/tests/unit_tests/charts/test_post_processing.py
index 945b337fad..9f8962f85c 100644
--- a/tests/unit_tests/charts/test_post_processing.py
+++ b/tests/unit_tests/charts/test_post_processing.py
@@ -1965,12 +1965,13 @@ def test_apply_post_process_json_format_data_is_none():
def test_apply_post_process_verbose_map(session: Session):
+ from superset import db
from superset.connectors.sqla.models import SqlaTable, SqlMetric
from superset.models.core import Database
- engine = session.get_bind()
+ engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
- db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
@@ -1982,7 +1983,7 @@ def test_apply_post_process_verbose_map(session: Session):
expression="COUNT(*)",
)
],
- database=db,
+ database=database,
)
result = {
diff --git a/tests/unit_tests/columns/test_models.py b/tests/unit_tests/columns/test_models.py
index 068557e7a6..0ea230da17 100644
--- a/tests/unit_tests/columns/test_models.py
+++ b/tests/unit_tests/columns/test_models.py
@@ -24,9 +24,10 @@ def test_column_model(session: Session) -> None:
"""
Test basic attributes of a ``Column``.
"""
+ from superset import db
from superset.columns.models import Column
- engine = session.get_bind()
+ engine = db.session.get_bind()
Column.metadata.create_all(engine) # pylint: disable=no-member
column = Column(
@@ -35,8 +36,8 @@ def test_column_model(session: Session) -> None:
expression="ds",
)
- session.add(column)
- session.flush()
+ db.session.add(column)
+ db.session.flush()
assert column.id == 1
assert column.uuid is not None
diff --git a/tests/unit_tests/commands/importers/v1/assets_test.py b/tests/unit_tests/commands/importers/v1/assets_test.py
index d48eed1be7..9609b0b45c 100644
--- a/tests/unit_tests/commands/importers/v1/assets_test.py
+++ b/tests/unit_tests/commands/importers/v1/assets_test.py
@@ -35,14 +35,14 @@ def test_import_new_assets(mocker: MockFixture, session: Session) -> None:
"""
Test that all new assets are imported correctly.
"""
- from superset import security_manager
+ from superset import db, security_manager
from superset.commands.importers.v1.assets import ImportAssetsCommand
from superset.models.dashboard import dashboard_slices
from superset.models.slice import Slice
mocker.patch.object(security_manager, "can_access", return_value=True)
- engine = session.get_bind()
+ engine = db.session.get_bind()
Slice.metadata.create_all(engine) # pylint: disable=no-member
configs = {
**copy.deepcopy(databases_config),
@@ -53,11 +53,11 @@ def test_import_new_assets(mocker: MockFixture, session: Session) -> None:
expected_number_of_dashboards = len(dashboards_config_1)
expected_number_of_charts = len(charts_config_1)
- ImportAssetsCommand._import(session, configs)
- dashboard_ids = session.scalars(
+ ImportAssetsCommand._import(configs)
+ dashboard_ids = db.session.scalars(
select(dashboard_slices.c.dashboard_id).distinct()
).all()
- chart_ids = session.scalars(select(dashboard_slices.c.slice_id)).all()
+ chart_ids = db.session.scalars(select(dashboard_slices.c.slice_id)).all()
assert len(chart_ids) == expected_number_of_charts
assert len(dashboard_ids) == expected_number_of_dashboards
@@ -67,14 +67,14 @@ def test_import_adds_dashboard_charts(mocker: MockFixture, session: Session) ->
"""
Test that existing dashboards are updated with new charts.
"""
- from superset import security_manager
+ from superset import db, security_manager
from superset.commands.importers.v1.assets import ImportAssetsCommand
from superset.models.dashboard import dashboard_slices
from superset.models.slice import Slice
mocker.patch.object(security_manager, "can_access", return_value=True)
- engine = session.get_bind()
+ engine = db.session.get_bind()
Slice.metadata.create_all(engine) # pylint: disable=no-member
base_configs = {
**copy.deepcopy(databases_config),
@@ -91,12 +91,12 @@ def test_import_adds_dashboard_charts(mocker: MockFixture, session: Session) ->
expected_number_of_dashboards = len(dashboards_config_1)
expected_number_of_charts = len(charts_config_1)
- ImportAssetsCommand._import(session, base_configs)
- ImportAssetsCommand._import(session, new_configs)
- dashboard_ids = session.scalars(
+ ImportAssetsCommand._import(base_configs)
+ ImportAssetsCommand._import(new_configs)
+ dashboard_ids = db.session.scalars(
select(dashboard_slices.c.dashboard_id).distinct()
).all()
- chart_ids = session.scalars(select(dashboard_slices.c.slice_id)).all()
+ chart_ids = db.session.scalars(select(dashboard_slices.c.slice_id)).all()
assert len(chart_ids) == expected_number_of_charts
assert len(dashboard_ids) == expected_number_of_dashboards
@@ -106,14 +106,14 @@ def test_import_removes_dashboard_charts(mocker: MockFixture, session: Session)
"""
Test that existing dashboards are updated without old charts.
"""
- from superset import security_manager
+ from superset import db, security_manager
from superset.commands.importers.v1.assets import ImportAssetsCommand
from superset.models.dashboard import dashboard_slices
from superset.models.slice import Slice
mocker.patch.object(security_manager, "can_access", return_value=True)
- engine = session.get_bind()
+ engine = db.session.get_bind()
Slice.metadata.create_all(engine) # pylint: disable=no-member
base_configs = {
**copy.deepcopy(databases_config),
@@ -130,12 +130,12 @@ def test_import_removes_dashboard_charts(mocker: MockFixture, session: Session)
expected_number_of_dashboards = len(dashboards_config_2)
expected_number_of_charts = len(charts_config_2)
- ImportAssetsCommand._import(session, base_configs)
- ImportAssetsCommand._import(session, new_configs)
- dashboard_ids = session.scalars(
+ ImportAssetsCommand._import(base_configs)
+ ImportAssetsCommand._import(new_configs)
+ dashboard_ids = db.session.scalars(
select(dashboard_slices.c.dashboard_id).distinct()
).all()
- chart_ids = session.scalars(select(dashboard_slices.c.slice_id)).all()
+ chart_ids = db.session.scalars(select(dashboard_slices.c.slice_id)).all()
assert len(chart_ids) == expected_number_of_charts
assert len(dashboard_ids) == expected_number_of_dashboards
diff --git a/tests/unit_tests/config_test.py b/tests/unit_tests/config_test.py
index a69d9eaede..837c53ec07 100644
--- a/tests/unit_tests/config_test.py
+++ b/tests/unit_tests/config_test.py
@@ -23,6 +23,8 @@ import pytest
from pytest_mock import MockerFixture
from sqlalchemy.orm.session import Session
+from superset import db
+
if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable
@@ -81,7 +83,7 @@ def test_table(session: Session) -> "SqlaTable":
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.models.core import Database
- engine = session.get_bind()
+ engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
columns = [
diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py
index beb4e99472..3824d7b7b7 100644
--- a/tests/unit_tests/conftest.py
+++ b/tests/unit_tests/conftest.py
@@ -41,7 +41,7 @@ from superset.initialization import SupersetAppInitializer
@pytest.fixture
def get_session(mocker: MockFixture) -> Callable[[], Session]:
"""
- Create an in-memory SQLite session to test models.
+ Create an in-memory SQLite session to test models.
"""
engine = create_engine("sqlite://")
@@ -49,7 +49,7 @@ def get_session(mocker: MockFixture) -> Callable[[], Session]:
Session_ = sessionmaker(bind=engine) # pylint: disable=invalid-name
in_memory_session = Session_()
- # flask calls session.remove()
+ # flask calls db.session.remove()
in_memory_session.remove = lambda: None
# patch session
diff --git a/tests/unit_tests/dao/dataset_test.py b/tests/unit_tests/dao/dataset_test.py
index 288f68cae0..1e3d1ec975 100644
--- a/tests/unit_tests/dao/dataset_test.py
+++ b/tests/unit_tests/dao/dataset_test.py
@@ -27,6 +27,7 @@ def test_validate_update_uniqueness(session: Session) -> None:
In particular, allow datasets with the same name in the same database as long as they
are in different schemas
"""
+ from superset import db
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
@@ -46,8 +47,8 @@ def test_validate_update_uniqueness(session: Session) -> None:
schema="dev",
database=database,
)
- session.add_all([database, dataset1, dataset2])
- session.flush()
+ db.session.add_all([database, dataset1, dataset2])
+ db.session.flush()
# same table name, different schema
assert (
diff --git a/tests/unit_tests/dao/queries_test.py b/tests/unit_tests/dao/queries_test.py
index 65e9bbfbfb..eb84b288fd 100644
--- a/tests/unit_tests/dao/queries_test.py
+++ b/tests/unit_tests/dao/queries_test.py
@@ -25,17 +25,18 @@ from superset.exceptions import QueryNotFoundException, SupersetCancelQueryExcep
def test_query_dao_save_metadata(session: Session) -> None:
+ from superset import db
from superset.models.core import Database
from superset.models.sql_lab import Query
- engine = session.get_bind()
+ engine = db.session.get_bind()
Query.metadata.create_all(engine) # pylint: disable=no-member
- db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
query_obj = Query(
client_id="foo",
- database=db,
+ database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from bar",
@@ -48,30 +49,31 @@ def test_query_dao_save_metadata(session: Session) -> None:
results_key="abc",
)
- session.add(db)
- session.add(query_obj)
+ db.session.add(database)
+ db.session.add(query_obj)
from superset.daos.query import QueryDAO
- query = session.query(Query).one()
+ query = db.session.query(Query).one()
QueryDAO.save_metadata(query=query, payload={"columns": []})
assert query.extra.get("columns", None) == []
def test_query_dao_get_queries_changed_after(session: Session) -> None:
+ from superset import db
from superset.models.core import Database
from superset.models.sql_lab import Query
- engine = session.get_bind()
+ engine = db.session.get_bind()
Query.metadata.create_all(engine) # pylint: disable=no-member
- db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
now = datetime.utcnow()
old_query_obj = Query(
client_id="foo",
- database=db,
+ database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from bar",
@@ -87,7 +89,7 @@ def test_query_dao_get_queries_changed_after(session: Session) -> None:
updated_query_obj = Query(
client_id="updated_foo",
- database=db,
+ database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from foo",
@@ -101,9 +103,9 @@ def test_query_dao_get_queries_changed_after(session: Session) -> None:
changed_on=now - timedelta(days=1),
)
- session.add(db)
- session.add(old_query_obj)
- session.add(updated_query_obj)
+ db.session.add(database)
+ db.session.add(old_query_obj)
+ db.session.add(updated_query_obj)
from superset.daos.query import QueryDAO
@@ -116,18 +118,19 @@ def test_query_dao_get_queries_changed_after(session: Session) -> None:
def test_query_dao_stop_query_not_found(
mocker: MockFixture, app: Any, session: Session
) -> None:
+ from superset import db
from superset.common.db_query_status import QueryStatus
from superset.models.core import Database
from superset.models.sql_lab import Query
- engine = session.get_bind()
+ engine = db.session.get_bind()
Query.metadata.create_all(engine) # pylint: disable=no-member
- db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
query_obj = Query(
client_id="foo",
- database=db,
+ database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from bar",
@@ -141,8 +144,8 @@ def test_query_dao_stop_query_not_found(
status=QueryStatus.RUNNING,
)
- session.add(db)
- session.add(query_obj)
+ db.session.add(database)
+ db.session.add(query_obj)
mocker.patch("superset.sql_lab.cancel_query", return_value=False)
@@ -151,25 +154,26 @@ def test_query_dao_stop_query_not_found(
with pytest.raises(QueryNotFoundException):
QueryDAO.stop_query("foo2")
- query = session.query(Query).one()
+ query = db.session.query(Query).one()
assert query.status == QueryStatus.RUNNING
def test_query_dao_stop_query_not_running(
mocker: MockFixture, app: Any, session: Session
) -> None:
+ from superset import db
from superset.common.db_query_status import QueryStatus
from superset.models.core import Database
from superset.models.sql_lab import Query
- engine = session.get_bind()
+ engine = db.session.get_bind()
Query.metadata.create_all(engine) # pylint: disable=no-member
- db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
query_obj = Query(
client_id="foo",
- database=db,
+ database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from bar",
@@ -183,31 +187,32 @@ def test_query_dao_stop_query_not_running(
status=QueryStatus.FAILED,
)
- session.add(db)
- session.add(query_obj)
+ db.session.add(database)
+ db.session.add(query_obj)
from superset.daos.query import QueryDAO
QueryDAO.stop_query(query_obj.client_id)
- query = session.query(Query).one()
+ query = db.session.query(Query).one()
assert query.status == QueryStatus.FAILED
def test_query_dao_stop_query_failed(
mocker: MockFixture, app: Any, session: Session
) -> None:
+ from superset import db
from superset.common.db_query_status import QueryStatus
from superset.models.core import Database
from superset.models.sql_lab import Query
- engine = session.get_bind()
+ engine = db.session.get_bind()
Query.metadata.create_all(engine) # pylint: disable=no-member
- db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
query_obj = Query(
client_id="foo",
- database=db,
+ database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from bar",
@@ -221,8 +226,8 @@ def test_query_dao_stop_query_failed(
status=QueryStatus.RUNNING,
)
- session.add(db)
- session.add(query_obj)
+ db.session.add(database)
+ db.session.add(query_obj)
mocker.patch("superset.sql_lab.cancel_query", return_value=False)
@@ -231,23 +236,24 @@ def test_query_dao_stop_query_failed(
with pytest.raises(SupersetCancelQueryException):
QueryDAO.stop_query(query_obj.client_id)
- query = session.query(Query).one()
+ query = db.session.query(Query).one()
assert query.status == QueryStatus.RUNNING
def test_query_dao_stop_query(mocker: MockFixture, app: Any, session: Session) -> None:
+ from superset import db
from superset.common.db_query_status import QueryStatus
from superset.models.core import Database
from superset.models.sql_lab import Query
- engine = session.get_bind()
+ engine = db.session.get_bind()
Query.metadata.create_all(engine) # pylint: disable=no-member
- db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
query_obj = Query(
client_id="foo",
- database=db,
+ database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from bar",
@@ -261,13 +267,13 @@ def test_query_dao_stop_query(mocker: MockFixture, app: Any, session: Session) -
status=QueryStatus.RUNNING,
)
- session.add(db)
- session.add(query_obj)
+ db.session.add(database)
+ db.session.add(query_obj)
mocker.patch("superset.sql_lab.cancel_query", return_value=True)
from superset.daos.query import QueryDAO
QueryDAO.stop_query(query_obj.client_id)
- query = session.query(Query).one()
+ query = db.session.query(Query).one()
assert query.status == QueryStatus.STOPPED
diff --git a/tests/unit_tests/dao/tag_test.py b/tests/unit_tests/dao/tag_test.py
index 5f29d0f28c..652d3729b7 100644
--- a/tests/unit_tests/dao/tag_test.py
+++ b/tests/unit_tests/dao/tag_test.py
@@ -70,7 +70,7 @@ def test_remove_user_favorite_tag(mocker):
# Check that users_favorited no longer contains the user
assert mock_user not in mock_tag.users_favorited
- # Check that the session was committed
+ # Check that db.session was committed
mock_session.commit.assert_called_once()
diff --git a/tests/unit_tests/dashboards/commands/importers/v1/import_test.py b/tests/unit_tests/dashboards/commands/importers/v1/import_test.py
index afbce49cd9..ac3d2a919b 100644
--- a/tests/unit_tests/dashboards/commands/importers/v1/import_test.py
+++ b/tests/unit_tests/dashboards/commands/importers/v1/import_test.py
@@ -24,7 +24,7 @@ from flask_appbuilder.security.sqla.models import Role, User
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
-from superset import security_manager
+from superset import db, security_manager
from superset.commands.dashboard.importers.v1.utils import import_dashboard
from superset.commands.exceptions import ImportFailedError
from superset.models.dashboard import Dashboard
@@ -67,7 +67,7 @@ def test_import_dashboard(mocker: MockFixture, session_with_schema: Session) ->
"""
mocker.patch.object(security_manager, "can_access", return_value=True)
- dashboard = import_dashboard(session_with_schema, dashboard_config)
+ dashboard = import_dashboard(dashboard_config)
assert dashboard.dashboard_title == "Test dash"
assert dashboard.description is None
assert dashboard.is_managed_externally is False
@@ -88,8 +88,7 @@ def test_import_dashboard_managed_externally(
config = copy.deepcopy(dashboard_config)
config["is_managed_externally"] = True
config["external_url"] = "https://example.org/my_dashboard"
-
- dashboard = import_dashboard(session_with_schema, config)
+ dashboard = import_dashboard(config)
assert dashboard.is_managed_externally is True
assert dashboard.external_url == "https://example.org/my_dashboard"
@@ -107,7 +106,7 @@ def test_import_dashboard_without_permission(
mocker.patch.object(security_manager, "can_access", return_value=False)
with pytest.raises(ImportFailedError) as excinfo:
- import_dashboard(session_with_schema, dashboard_config)
+ import_dashboard(dashboard_config)
assert (
str(excinfo.value)
== "Dashboard doesn't exist and user doesn't have permission to create dashboards"
@@ -135,7 +134,7 @@ def test_import_existing_dashboard_without_permission(
with override_user("admin"):
with pytest.raises(ImportFailedError) as excinfo:
- import_dashboard(session_with_data, dashboard_config, overwrite=True)
+ import_dashboard(dashboard_config, overwrite=True)
assert (
str(excinfo.value)
== "A dashboard already exists and user doesn't have permissions to overwrite it"
@@ -171,7 +170,8 @@ def test_import_existing_dashboard_with_permission(
)
with override_user(admin):
- import_dashboard(session_with_data, dashboard_config, overwrite=True)
+ import_dashboard(dashboard_config, overwrite=True)
+
# Assert that the can write to dashboard was checked
security_manager.can_access.assert_called_once_with("can_write", "Dashboard")
security_manager.can_access_dashboard.assert_called_once_with(dashboard)
diff --git a/tests/unit_tests/dashboards/dao_tests.py b/tests/unit_tests/dashboards/dao_tests.py
index 3bf4038f16..09edfacd44 100644
--- a/tests/unit_tests/dashboards/dao_tests.py
+++ b/tests/unit_tests/dashboards/dao_tests.py
@@ -42,12 +42,10 @@ def session_with_data(session: Session) -> Iterator[Session]:
session.rollback()
-def test_add_favorite(session_with_data: Session) -> None:
+def test_add_favorite(session: Session) -> None:
from superset.daos.dashboard import DashboardDAO
- dashboard = DashboardDAO.find_by_id(
- 100, session=session_with_data, skip_base_filter=True
- )
+ dashboard = DashboardDAO.find_by_id(100, skip_base_filter=True)
if not dashboard:
return
assert len(DashboardDAO.favorited_ids([dashboard])) == 0
@@ -59,12 +57,10 @@ def test_add_favorite(session_with_data: Session) -> None:
assert len(DashboardDAO.favorited_ids([dashboard])) == 1
-def test_remove_favorite(session_with_data: Session) -> None:
+def test_remove_favorite(session: Session) -> None:
from superset.daos.dashboard import DashboardDAO
- dashboard = DashboardDAO.find_by_id(
- 100, session=session_with_data, skip_base_filter=True
- )
+ dashboard = DashboardDAO.find_by_id(100, skip_base_filter=True)
if not dashboard:
return
assert len(DashboardDAO.favorited_ids([dashboard])) == 0
diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py
index cf3e64c306..f867f82a98 100644
--- a/tests/unit_tests/databases/api_test.py
+++ b/tests/unit_tests/databases/api_test.py
@@ -28,6 +28,8 @@ from flask import current_app
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
+from superset import db
+
def test_filter_by_uuid(
session: Session,
@@ -49,14 +51,14 @@ def test_filter_by_uuid(
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
- session.add(
+ db.session.add(
Database(
database_name="my_db",
sqlalchemy_uri="sqlite://",
uuid=UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"),
)
)
- session.commit()
+ db.session.commit()
response = client.get(
"/api/v1/database/?q=(filters:!((col:uuid,opr:eq,value:"
@@ -96,7 +98,7 @@ def test_post_with_uuid(
payload = response.json
assert payload["result"]["uuid"] == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"
- database = session.query(Database).one()
+ database = db.session.query(Database).one()
assert database.uuid == UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb")
@@ -139,8 +141,8 @@ def test_password_mask(
}
),
)
- session.add(database)
- session.commit()
+ db.session.add(database)
+ db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
@@ -195,8 +197,8 @@ def test_database_connection(
}
),
)
- session.add(database)
- session.commit()
+ db.session.add(database)
+ db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
@@ -331,8 +333,8 @@ def test_update_with_password_mask(
}
),
)
- session.add(database)
- session.commit()
+ db.session.add(database)
+ db.session.commit()
client.put(
"/api/v1/database/1",
@@ -347,7 +349,7 @@ def test_update_with_password_mask(
),
},
)
- database = session.query(Database).one()
+ database = db.session.query(Database).one()
assert (
database.encrypted_extra
== '{"service_account_info": {"project_id": "yellow-unicorn-314419", "private_key": "SECRET"}}'
@@ -429,8 +431,8 @@ def test_delete_ssh_tunnel(
}
),
)
- session.add(database)
- session.commit()
+ db.session.add(database)
+ db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
@@ -446,8 +448,8 @@ def test_delete_ssh_tunnel(
database=database,
)
- session.add(tunnel)
- session.commit()
+ db.session.add(tunnel)
+ db.session.commit()
# Get our recently created SSHTunnel
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
@@ -505,8 +507,8 @@ def test_delete_ssh_tunnel_not_found(
}
),
)
- session.add(database)
- session.commit()
+ db.session.add(database)
+ db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
@@ -522,8 +524,8 @@ def test_delete_ssh_tunnel_not_found(
database=database,
)
- session.add(tunnel)
- session.commit()
+ db.session.add(tunnel)
+ db.session.commit()
# Delete the recently created SSHTunnel
response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/")
@@ -576,8 +578,8 @@ def test_apply_dynamic_database_filter(
}
),
)
- session.add(database)
- session.commit()
+ db.session.add(database)
+ db.session.commit()
# Create our Second Database
database = Database(
@@ -592,8 +594,8 @@ def test_apply_dynamic_database_filter(
}
),
)
- session.add(database)
- session.commit()
+ db.session.add(database)
+ db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
diff --git a/tests/unit_tests/databases/commands/importers/v1/import_test.py b/tests/unit_tests/databases/commands/importers/v1/import_test.py
index 5fb4d12ce5..ad18f0157c 100644
--- a/tests/unit_tests/databases/commands/importers/v1/import_test.py
+++ b/tests/unit_tests/databases/commands/importers/v1/import_test.py
@@ -23,6 +23,7 @@ import pytest
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
+from superset import db
from superset.commands.exceptions import ImportFailedError
@@ -37,11 +38,11 @@ def test_import_database(mocker: MockFixture, session: Session) -> None:
mocker.patch.object(security_manager, "can_access", return_value=True)
- engine = session.get_bind()
+ engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
config = copy.deepcopy(database_config)
- database = import_database(session, config)
+ database = import_database(config)
assert database.database_name == "imported_database"
assert database.sqlalchemy_uri == "someengine://user:pass@host1"
assert database.cache_timeout is None
@@ -60,9 +61,9 @@ def test_import_database(mocker: MockFixture, session: Session) -> None:
# missing
config = copy.deepcopy(database_config)
del config["allow_dml"]
- session.delete(database)
- session.flush()
- database = import_database(session, config)
+ db.session.delete(database)
+ db.session.flush()
+ database = import_database(config)
assert database.allow_dml is False
@@ -78,12 +79,12 @@ def test_import_database_sqlite_invalid(mocker: MockFixture, session: Session) -
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True
mocker.patch.object(security_manager, "can_access", return_value=True)
- engine = session.get_bind()
+ engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
config = copy.deepcopy(database_config_sqlite)
with pytest.raises(ImportFailedError) as excinfo:
- _ = import_database(session, config)
+ _ = import_database(config)
assert (
str(excinfo.value)
== "SQLiteDialect_pysqlite cannot be used as a data source for security reasons."
@@ -106,14 +107,14 @@ def test_import_database_managed_externally(
mocker.patch.object(security_manager, "can_access", return_value=True)
- engine = session.get_bind()
+ engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
config = copy.deepcopy(database_config)
config["is_managed_externally"] = True
config["external_url"] = "https://example.org/my_database"
- database = import_database(session, config)
+ database = import_database(config)
assert database.is_managed_externally is True
assert database.external_url == "https://example.org/my_database"
@@ -132,13 +133,13 @@ def test_import_database_without_permission(
mocker.patch.object(security_manager, "can_access", return_value=False)
- engine = session.get_bind()
+ engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
config = copy.deepcopy(database_config)
with pytest.raises(ImportFailedError) as excinfo:
- import_database(session, config)
+ import_database(config)
assert (
str(excinfo.value)
== "Database doesn't exist and user doesn't have permission to create databases"
@@ -156,10 +157,10 @@ def test_import_database_with_version(mocker: MockFixture, session: Session) ->
mocker.patch.object(security_manager, "can_access", return_value=True)
- engine = session.get_bind()
+ engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
config = copy.deepcopy(database_config)
config["extra"]["version"] = "1.1.1"
- database = import_database(session, config)
+ database = import_database(config)
assert json.loads(database.extra)["version"] == "1.1.1"
diff --git a/tests/unit_tests/databases/dao/dao_tests.py b/tests/unit_tests/databases/dao/dao_tests.py
index b792a65336..a826d01be8 100644
--- a/tests/unit_tests/databases/dao/dao_tests.py
+++ b/tests/unit_tests/databases/dao/dao_tests.py
@@ -30,19 +30,19 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
- db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
- database=db,
+ database=database,
)
ssh_tunnel = SSHTunnel(
- database_id=db.id,
- database=db,
+ database_id=database.id,
+ database=database,
)
- session.add(db)
+ session.add(database)
session.add(sqla_table)
session.add(ssh_tunnel)
session.flush()
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
index 1777bdc2e1..4b05cce637 100644
--- a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
+++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
@@ -27,17 +27,17 @@ def test_create_ssh_tunnel_command() -> None:
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
- db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
properties = {
- "database_id": db.id,
+ "database_id": database.id,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
}
- result = CreateSSHTunnelCommand(db, properties).run()
+ result = CreateSSHTunnelCommand(database, properties).run()
assert result is not None
assert isinstance(result, SSHTunnel)
@@ -48,19 +48,19 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
- db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
# If we are trying to create a tunnel with a private_key_password
# then a private_key is mandatory
properties = {
- "database": db,
+ "database": database,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"private_key_password": "bar",
}
- command = CreateSSHTunnelCommand(db, properties)
+ command = CreateSSHTunnelCommand(database, properties)
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py
index 14838ddc58..78f9c1142c 100644
--- a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py
+++ b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py
@@ -31,19 +31,19 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
- db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
- database=db,
+ database=database,
)
ssh_tunnel = SSHTunnel(
- database_id=db.id,
- database=db,
+ database_id=database.id,
+ database=database,
)
- session.add(db)
+ session.add(database)
session.add(sqla_table)
session.add(ssh_tunnel)
session.flush()
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py
index 5c3907b016..54e54d05da 100644
--- a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py
+++ b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py
@@ -32,16 +32,18 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
- db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
- database=db,
+ database=database,
+ )
+ ssh_tunnel = SSHTunnel(
+ database_id=database.id, database=database, server_address="Test"
)
- ssh_tunnel = SSHTunnel(database_id=db.id, database=db, server_address="Test")
- session.add(db)
+ session.add(database)
session.add(sqla_table)
session.add(ssh_tunnel)
session.flush()
diff --git a/tests/unit_tests/databases/ssh_tunnel/dao_tests.py b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py
index 7a88807597..4646e12c1f 100644
--- a/tests/unit_tests/databases/ssh_tunnel/dao_tests.py
+++ b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py
@@ -25,11 +25,11 @@ def test_create_ssh_tunnel():
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
- db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
result = SSHTunnelDAO.create(
attributes={
- "database_id": db.id,
+ "database_id": database.id,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
diff --git a/tests/unit_tests/datasets/api_tests.py b/tests/unit_tests/datasets/api_tests.py
index de93720fa6..e0786afaa3 100644
--- a/tests/unit_tests/datasets/api_tests.py
+++ b/tests/unit_tests/datasets/api_tests.py
@@ -19,6 +19,8 @@ from typing import Any
from sqlalchemy.orm.session import Session
+from superset import db
+
def test_put_invalid_dataset(
session: Session,
@@ -31,7 +33,7 @@ def test_put_invalid_dataset(
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
- SqlaTable.metadata.create_all(session.get_bind())
+ SqlaTable.metadata.create_all(db.session.get_bind())
database = Database(
database_name="my_db",
@@ -41,8 +43,8 @@ def test_put_invalid_dataset(
table_name="test_put_invalid_dataset",
database=database,
)
- session.add(dataset)
- session.flush()
+ db.session.add(dataset)
+ db.session.flush()
response = client.put(
"/api/v1/dataset/1",
diff --git a/tests/unit_tests/datasets/commands/export_test.py b/tests/unit_tests/datasets/commands/export_test.py
index 20565da5bc..73f383859b 100644
--- a/tests/unit_tests/datasets/commands/export_test.py
+++ b/tests/unit_tests/datasets/commands/export_test.py
@@ -20,6 +20,8 @@ import json
from sqlalchemy.orm.session import Session
+from superset import db
+
def test_export(session: Session) -> None:
"""
@@ -29,12 +31,12 @@ def test_export(session: Session) -> None:
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.models.core import Database
- engine = session.get_bind()
+ engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
- session.add(database)
- session.flush()
+ db.session.add(database)
+ db.session.flush()
columns = [
TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"),
diff --git a/tests/unit_tests/datasets/commands/importers/v1/import_test.py b/tests/unit_tests/datasets/commands/importers/v1/import_test.py
index 5089838e69..a7660d6c0b 100644
--- a/tests/unit_tests/datasets/commands/importers/v1/import_test.py
+++ b/tests/unit_tests/datasets/commands/importers/v1/import_test.py
@@ -28,6 +28,7 @@ from flask import current_app
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
+from superset import db
from superset.commands.dataset.exceptions import (
DatasetForbiddenDataURI,
ImportFailedError,
@@ -46,12 +47,12 @@ def test_import_dataset(mocker: MockFixture, session: Session) -> None:
mocker.patch.object(security_manager, "can_access", return_value=True)
- engine = session.get_bind()
+ engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
- session.add(database)
- session.flush()
+ db.session.add(database)
+ db.session.flush()
dataset_uuid = uuid.uuid4()
config = {
@@ -108,7 +109,7 @@ def test_import_dataset(mocker: MockFixture, session: Session) -> None:
"database_id": database.id,
}
- sqla_table = import_dataset(session, config)
+ sqla_table = import_dataset(config)
assert sqla_table.table_name == "my_table"
assert sqla_table.main_dttm_col == "ds"
assert sqla_table.description == "This is the description"
@@ -162,23 +163,23 @@ def test_import_dataset_duplicate_column(mocker: MockFixture, session: Session)
mocker.patch.object(security_manager, "can_access", return_value=True)
- engine = session.get_bind()
+ engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
dataset_uuid = uuid.uuid4()
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
- session.add(database)
- session.flush()
+ db.session.add(database)
+ db.session.flush()
dataset = SqlaTable(
uuid=dataset_uuid, table_name="existing_dataset", database_id=database.id
)
column = TableColumn(column_name="existing_column")
- session.add(dataset)
- session.add(column)
- session.flush()
+ db.session.add(dataset)
+ db.session.add(column)
+ db.session.flush()
config = {
"table_name": dataset.table_name,
@@ -234,7 +235,7 @@ def test_import_dataset_duplicate_column(mocker: MockFixture, session: Session)
"database_id": database.id,
}
- sqla_table = import_dataset(session, config, overwrite=True)
+ sqla_table = import_dataset(config, overwrite=True)
assert sqla_table.table_name == dataset.table_name
assert sqla_table.main_dttm_col == "ds"
assert sqla_table.description == "This is the description"
@@ -288,12 +289,12 @@ def test_import_column_extra_is_string(mocker: MockFixture, session: Session) ->
mocker.patch.object(security_manager, "can_access", return_value=True)
- engine = session.get_bind()
+ engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
- session.add(database)
- session.flush()
+ db.session.add(database)
+ db.session.flush()
dataset_uuid = uuid.uuid4()
yaml_config: dict[str, Any] = {
@@ -352,7 +353,7 @@ def test_import_column_extra_is_string(mocker: MockFixture, session: Session) ->
schema = ImportV1DatasetSchema()
dataset_config = schema.load(yaml_config)
dataset_config["database_id"] = database.id
- sqla_table = import_dataset(session, dataset_config)
+ sqla_table = import_dataset(dataset_config)
assert sqla_table.metrics[0].extra == '{"warning_markdown": null}'
assert sqla_table.columns[0].extra == '{"certified_by": "User"}'
@@ -373,12 +374,12 @@ def test_import_dataset_extra_empty_string(
mocker.patch.object(security_manager, "can_access", return_value=True)
- engine = session.get_bind()
+ engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
- session.add(database)
- session.flush()
+ db.session.add(database)
+ db.session.flush()
dataset_uuid = uuid.uuid4()
yaml_config: dict[str, Any] = {
@@ -417,7 +418,7 @@ def test_import_dataset_extra_empty_string(
schema = ImportV1DatasetSchema()
dataset_config = schema.load(yaml_config)
dataset_config["database_id"] = database.id
- sqla_table = import_dataset(session, dataset_config)
+ sqla_table = import_dataset(dataset_config)
assert sqla_table.extra == None
@@ -443,12 +444,12 @@ def test_import_column_allowed_data_url(
mocker.patch.object(security_manager, "can_access", return_value=True)
- engine = session.get_bind()
+ engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
- session.add(database)
- session.flush()
+ db.session.add(database)
+ db.session.flush()
dataset_uuid = uuid.uuid4()
yaml_config: dict[str, Any] = {
@@ -495,9 +496,8 @@ def test_import_column_allowed_data_url(
schema = ImportV1DatasetSchema()
dataset_config = schema.load(yaml_config)
dataset_config["database_id"] = database.id
- _ = import_dataset(session, dataset_config, force_data=True)
- session.connection()
- assert [("value1",), ("value2",)] == session.execute(
+ _ = import_dataset(dataset_config, force_data=True)
+ assert [("value1",), ("value2",)] == db.session.execute(
"SELECT * FROM my_table"
).fetchall()
@@ -517,19 +517,19 @@ def test_import_dataset_managed_externally(
mocker.patch.object(security_manager, "can_access", return_value=True)
- engine = session.get_bind()
+ engine = db.session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
- session.add(database)
- session.flush()
+ db.session.add(database)
+ db.session.flush()
config = copy.deepcopy(dataset_config)
config["is_managed_externally"] = True
config["external_url"] = "https://example.org/my_table"
config["database_id"] = database.id
- sqla_table = import_dataset(session, config)
+ sqla_table = import_dataset(config)
assert sqla_table.is_managed_externally is True
assert sqla_table.external_url == "https://example.org/my_table"
diff --git a/tests/unit_tests/datasets/dao/dao_tests.py b/tests/unit_tests/datasets/dao/dao_tests.py
index 3302f2dc04..a4632fad3d 100644
--- a/tests/unit_tests/datasets/dao/dao_tests.py
+++ b/tests/unit_tests/datasets/dao/dao_tests.py
@@ -29,15 +29,15 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
- db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
- database=db,
+ database=database,
)
- session.add(db)
+ session.add(database)
session.add(sqla_table)
session.flush()
yield session
@@ -50,7 +50,6 @@ def test_datasource_find_by_id_skip_base_filter(session_with_data: Session) -> N
result = DatasetDAO.find_by_id(
1,
- session=session_with_data,
skip_base_filter=True,
)
@@ -67,7 +66,6 @@ def test_datasource_find_by_id_skip_base_filter_not_found(
result = DatasetDAO.find_by_id(
125326326,
- session=session_with_data,
skip_base_filter=True,
)
assert result is None
@@ -79,7 +77,6 @@ def test_datasource_find_by_ids_skip_base_filter(session_with_data: Session) ->
result = DatasetDAO.find_by_ids(
[1, 125326326],
- session=session_with_data,
skip_base_filter=True,
)
@@ -96,7 +93,6 @@ def test_datasource_find_by_ids_skip_base_filter_not_found(
result = DatasetDAO.find_by_ids(
[125326326, 125326326125326326],
- session=session_with_data,
skip_base_filter=True,
)
diff --git a/tests/unit_tests/datasource/dao_tests.py b/tests/unit_tests/datasource/dao_tests.py
index b4ce162c0c..adc674d0fd 100644
--- a/tests/unit_tests/datasource/dao_tests.py
+++ b/tests/unit_tests/datasource/dao_tests.py
@@ -35,7 +35,7 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
- db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
columns = [
TableColumn(column_name="a", type="INTEGER"),
@@ -45,12 +45,12 @@ def session_with_data(session: Session) -> Iterator[Session]:
table_name="my_sqla_table",
columns=columns,
metrics=[],
- database=db,
+ database=database,
)
query_obj = Query(
client_id="foo",
- database=db,
+ database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from bar",
@@ -63,13 +63,13 @@ def session_with_data(session: Session) -> Iterator[Session]:
results_key="abc",
)
- saved_query = SavedQuery(database=db, sql="select * from foo")
+ saved_query = SavedQuery(database=database, sql="select * from foo")
table = Table(
name="my_table",
schema="my_schema",
catalog="my_catalog",
- database=db,
+ database=database,
columns=[],
)
@@ -93,7 +93,7 @@ FROM my_catalog.my_schema.my_table
session.add(table)
session.add(saved_query)
session.add(query_obj)
- session.add(db)
+ session.add(database)
session.add(sqla_table)
session.flush()
yield session
@@ -190,7 +190,7 @@ def test_get_datasource_w_str_param(session_with_data: Session) -> None:
def test_get_all_datasources(session_with_data: Session) -> None:
from superset.connectors.sqla.models import SqlaTable
- result = SqlaTable.get_all_datasources(session=session_with_data)
+ result = SqlaTable.get_all_datasources()
assert len(result) == 1
diff --git a/tests/unit_tests/db_engine_specs/test_druid.py b/tests/unit_tests/db_engine_specs/test_druid.py
index d090dffcde..0ab4688214 100644
--- a/tests/unit_tests/db_engine_specs/test_druid.py
+++ b/tests/unit_tests/db_engine_specs/test_druid.py
@@ -74,10 +74,10 @@ def test_extras_without_ssl() -> None:
from superset.db_engine_specs.druid import DruidEngineSpec
from tests.integration_tests.fixtures.database import default_db_extra
- db = mock.Mock()
- db.extra = default_db_extra
- db.server_cert = None
- extras = DruidEngineSpec.get_extra_params(db)
+ database = mock.Mock()
+ database.extra = default_db_extra
+ database.server_cert = None
+ extras = DruidEngineSpec.get_extra_params(database)
assert "connect_args" not in extras["engine_params"]
@@ -86,10 +86,10 @@ def test_extras_with_ssl() -> None:
from tests.integration_tests.fixtures.certificates import ssl_certificate
from tests.integration_tests.fixtures.database import default_db_extra
- db = mock.Mock()
- db.extra = default_db_extra
- db.server_cert = ssl_certificate
- extras = DruidEngineSpec.get_extra_params(db)
+ database = mock.Mock()
+ database.extra = default_db_extra
+ database.server_cert = ssl_certificate
+ extras = DruidEngineSpec.get_extra_params(database)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["scheme"] == "https"
assert "ssl_verify_cert" in connect_args
diff --git a/tests/unit_tests/db_engine_specs/test_pinot.py b/tests/unit_tests/db_engine_specs/test_pinot.py
index a1648f5f60..72c8267816 100644
--- a/tests/unit_tests/db_engine_specs/test_pinot.py
+++ b/tests/unit_tests/db_engine_specs/test_pinot.py
@@ -50,8 +50,8 @@ def test_extras_without_ssl() -> None:
from superset.db_engine_specs.pinot import PinotEngineSpec as spec
from tests.integration_tests.fixtures.database import default_db_extra
- db = mock.Mock()
- db.extra = default_db_extra
- db.server_cert = None
- extras = spec.get_extra_params(db)
+ database = mock.Mock()
+ database.extra = default_db_extra
+ database.server_cert = None
+ extras = spec.get_extra_params(database)
assert "connect_args" not in extras["engine_params"]
diff --git a/tests/unit_tests/extensions/test_sqlalchemy.py b/tests/unit_tests/extensions/test_sqlalchemy.py
index cc738fd6c6..caa141aaf7 100644
--- a/tests/unit_tests/extensions/test_sqlalchemy.py
+++ b/tests/unit_tests/extensions/test_sqlalchemy.py
@@ -26,6 +26,7 @@ from sqlalchemy.engine import create_engine
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.orm.session import Session
+from superset import db
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetSecurityException
from tests.unit_tests.conftest import with_feature_flags
@@ -38,7 +39,7 @@ if TYPE_CHECKING:
def database1(session: Session) -> Iterator["Database"]:
from superset.models.core import Database
- engine = session.connection().engine
+ engine = db.session.connection().engine
Database.metadata.create_all(engine) # pylint: disable=no-member
database = Database(
@@ -46,13 +47,13 @@ def database1(session: Session) -> Iterator["Database"]:
sqlalchemy_uri="sqlite:///database1.db",
allow_dml=True,
)
- session.add(database)
- session.commit()
+ db.session.add(database)
+ db.session.commit()
yield database
- session.delete(database)
- session.commit()
+ db.session.delete(database)
+ db.session.commit()
os.unlink("database1.db")
@@ -62,12 +63,12 @@ def table1(session: Session, database1: "Database") -> Iterator[None]:
conn = engine.connect()
conn.execute("CREATE TABLE table1 (a INTEGER NOT NULL PRIMARY KEY, b INTEGER)")
conn.execute("INSERT INTO table1 (a, b) VALUES (1, 10), (2, 20)")
- session.commit()
+ db.session.commit()
yield
conn.execute("DROP TABLE table1")
- session.commit()
+ db.session.commit()
@pytest.fixture
@@ -79,13 +80,13 @@ def database2(session: Session) -> Iterator["Database"]:
sqlalchemy_uri="sqlite:///database2.db",
allow_dml=False,
)
- session.add(database)
- session.commit()
+ db.session.add(database)
+ db.session.commit()
yield database
- session.delete(database)
- session.commit()
+ db.session.delete(database)
+ db.session.commit()
os.unlink("database2.db")
@@ -95,12 +96,12 @@ def table2(session: Session, database2: "Database") -> Iterator[None]:
conn = engine.connect()
conn.execute("CREATE TABLE table2 (a INTEGER NOT NULL PRIMARY KEY, b TEXT)")
conn.execute("INSERT INTO table2 (a, b) VALUES (1, 'ten'), (2, 'twenty')")
- session.commit()
+ db.session.commit()
yield
conn.execute("DROP TABLE table2")
- session.commit()
+ db.session.commit()
@with_feature_flags(ENABLE_SUPERSET_META_DB=True)
diff --git a/tests/unit_tests/queries/dao_test.py b/tests/unit_tests/queries/dao_test.py
index a0221b8019..dbca78a9d3 100644
--- a/tests/unit_tests/queries/dao_test.py
+++ b/tests/unit_tests/queries/dao_test.py
@@ -22,10 +22,10 @@ def test_column_attributes_on_query():
from superset.models.core import Database
from superset.models.sql_lab import Query
- db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
query_obj = Query(
client_id="foo",
- database=db,
+ database=database,
tab_name="test_tab",
sql_editor_id="test_editor_id",
sql="select * from bar",
diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py
index 8265277372..83e7c373c8 100644
--- a/tests/unit_tests/sql_lab_test.py
+++ b/tests/unit_tests/sql_lab_test.py
@@ -125,7 +125,7 @@ def test_sql_lab_insert_rls_as_subquery(
from superset.sql_lab import execute_sql_statement
from superset.utils.core import RowLevelSecurityFilterType
- engine = session.connection().engine
+ engine = db.session.connection().engine
Query.metadata.create_all(engine) # pylint: disable=no-member
connection = engine.raw_connection()
@@ -143,8 +143,8 @@ def test_sql_lab_insert_rls_as_subquery(
limit=5,
select_as_cta_used=False,
)
- session.add(query)
- session.commit()
+ db.session.add(query)
+ db.session.commit()
admin = User(
first_name="Alice",
@@ -185,8 +185,8 @@ def test_sql_lab_insert_rls_as_subquery(
group_key=None,
clause="c > 5",
)
- session.add(rls)
- session.flush()
+ db.session.add(rls)
+ db.session.flush()
mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin)
mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True)
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index f650b77734..f05e16ae85 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -1759,8 +1759,7 @@ def test_get_rls_for_table(mocker: MockerFixture) -> None:
Tests for ``get_rls_for_table``.
"""
candidate = Identifier([Token(Name, "some_table")])
- db = mocker.patch("superset.db")
- dataset = db.session.query().filter().one_or_none()
+ dataset = mocker.patch("superset.db").session.query().filter().one_or_none()
dataset.__str__.return_value = "some_table"
dataset.get_sqla_row_level_filters.return_value = [text("organization_id = 1")]
diff --git a/tests/unit_tests/tables/test_models.py b/tests/unit_tests/tables/test_models.py
index 7705dba6aa..926e059261 100644
--- a/tests/unit_tests/tables/test_models.py
+++ b/tests/unit_tests/tables/test_models.py
@@ -14,11 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
# pylint: disable=import-outside-toplevel, unused-argument
-
from sqlalchemy.orm.session import Session
+from superset import db
+
def test_table_model(session: Session) -> None:
"""
@@ -28,7 +28,7 @@ def test_table_model(session: Session) -> None:
from superset.models.core import Database
from superset.tables.models import Table
- engine = session.get_bind()
+ engine = db.session.get_bind()
Table.metadata.create_all(engine) # pylint: disable=no-member
table = Table(
@@ -44,8 +44,8 @@ def test_table_model(session: Session) -> None:
)
],
)
- session.add(table)
- session.flush()
+ db.session.add(table)
+ db.session.flush()
assert table.id == 1
assert table.uuid is not None
diff --git a/tests/unit_tests/tags/commands/create_test.py b/tests/unit_tests/tags/commands/create_test.py
index b18144521a..1e1895bb77 100644
--- a/tests/unit_tests/tags/commands/create_test.py
+++ b/tests/unit_tests/tags/commands/create_test.py
@@ -18,6 +18,7 @@ import pytest
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
+from superset import db
from superset.utils.core import DatasourceType
@@ -40,13 +41,15 @@ def session_with_data(session: Session):
slice_name="slice_name",
)
- db = Database(database_name="my_database", sqlalchemy_uri="postgresql://")
+ database = Database(database_name="my_database", sqlalchemy_uri="postgresql://")
columns = [
TableColumn(column_name="a", type="INTEGER"),
]
- saved_query = SavedQuery(label="test_query", database=db, sql="select * from foo")
+ saved_query = SavedQuery(
+ label="test_query", database=database, sql="select * from foo"
+ )
dashboard_obj = Dashboard(
id=100,
@@ -57,7 +60,7 @@ def session_with_data(session: Session):
)
session.add(slice_obj)
- session.add(db)
+ session.add(database)
session.add(saved_query)
session.add(dashboard_obj)
session.commit()
@@ -74,9 +77,9 @@ def test_create_command_success(session_with_data: Session, mocker: MockFixture)
from superset.tags.models import ObjectType, TaggedObject
# Define a list of objects to tag
- query = session_with_data.query(SavedQuery).first()
- chart = session_with_data.query(Slice).first()
- dashboard = session_with_data.query(Dashboard).first()
+ query = db.session.query(SavedQuery).first()
+ chart = db.session.query(Slice).first()
+ dashboard = db.session.query(Dashboard).first()
mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
@@ -94,10 +97,10 @@ def test_create_command_success(session_with_data: Session, mocker: MockFixture)
data={"name": "test_tag", "objects_to_tag": objects_to_tag}
).run()
- assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
+ assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag)
for object_type, object_id in objects_to_tag:
assert (
- session_with_data.query(TaggedObject)
+ db.session.query(TaggedObject)
.filter(
TaggedObject.object_type == object_type,
TaggedObject.object_id == object_id,
@@ -117,9 +120,9 @@ def test_create_command_success_clear(session_with_data: Session, mocker: MockFi
from superset.tags.models import ObjectType, TaggedObject
# Define a list of objects to tag
- query = session_with_data.query(SavedQuery).first()
- chart = session_with_data.query(Slice).first()
- dashboard = session_with_data.query(Dashboard).first()
+ query = db.session.query(SavedQuery).first()
+ chart = db.session.query(Slice).first()
+ dashboard = db.session.query(Dashboard).first()
mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
@@ -136,10 +139,10 @@ def test_create_command_success_clear(session_with_data: Session, mocker: MockFi
CreateCustomTagWithRelationshipsCommand(
data={"name": "test_tag", "objects_to_tag": objects_to_tag}
).run()
- assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
+ assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag)
CreateCustomTagWithRelationshipsCommand(
data={"name": "test_tag", "objects_to_tag": []}
).run()
- assert len(session_with_data.query(TaggedObject).all()) == 0
+ assert len(db.session.query(TaggedObject).all()) == 0
diff --git a/tests/unit_tests/tags/commands/update_test.py b/tests/unit_tests/tags/commands/update_test.py
index e488321228..75636ab0af 100644
--- a/tests/unit_tests/tags/commands/update_test.py
+++ b/tests/unit_tests/tags/commands/update_test.py
@@ -18,6 +18,7 @@ import pytest
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session
+from superset import db
from superset.utils.core import DatasourceType
@@ -41,7 +42,7 @@ def session_with_data(session: Session):
slice_name="slice_name",
)
- db = Database(database_name="my_database", sqlalchemy_uri="postgresql://")
+ database = Database(database_name="my_database", sqlalchemy_uri="postgresql://")
columns = [
TableColumn(column_name="a", type="INTEGER"),
@@ -51,7 +52,7 @@ def session_with_data(session: Session):
table_name="my_sqla_table",
columns=columns,
metrics=[],
- database=db,
+ database=database,
)
dashboard_obj = Dashboard(
@@ -62,7 +63,9 @@ def session_with_data(session: Session):
published=True,
)
- saved_query = SavedQuery(label="test_query", database=db, sql="select * from foo")
+ saved_query = SavedQuery(
+ label="test_query", database=database, sql="select * from foo"
+ )
tag = Tag(name="test_name", description="test_description")
@@ -79,7 +82,7 @@ def test_update_command_success(session_with_data: Session, mocker: MockFixture)
from superset.models.dashboard import Dashboard
from superset.tags.models import ObjectType, TaggedObject
- dashboard = session_with_data.query(Dashboard).first()
+ dashboard = db.session.query(Dashboard).first()
mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
)
@@ -104,7 +107,7 @@ def test_update_command_success(session_with_data: Session, mocker: MockFixture)
updated_tag = TagDAO.find_by_name("new_name")
assert updated_tag is not None
assert updated_tag.description == "new_description"
- assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
+ assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag)
def test_update_command_success_duplicates(
@@ -117,8 +120,8 @@ def test_update_command_success_duplicates(
from superset.models.slice import Slice
from superset.tags.models import ObjectType, TaggedObject
- dashboard = session_with_data.query(Dashboard).first()
- chart = session_with_data.query(Slice).first()
+ dashboard = db.session.query(Dashboard).first()
+ chart = db.session.query(Slice).first()
mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
@@ -153,7 +156,7 @@ def test_update_command_success_duplicates(
updated_tag = TagDAO.find_by_name("new_name")
assert updated_tag is not None
assert updated_tag.description == "new_description"
- assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
+ assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag)
assert changed_model.objects[0].object_id == chart.id
@@ -168,8 +171,8 @@ def test_update_command_failed_validation(
from superset.models.slice import Slice
from superset.tags.models import ObjectType
- dashboard = session_with_data.query(Dashboard).first()
- chart = session_with_data.query(Slice).first()
+ dashboard = db.session.query(Dashboard).first()
+ chart = db.session.query(Slice).first()
objects_to_tag = [
(ObjectType.chart, chart.id),
]