You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@superset.apache.org by er...@apache.org on 2020/08/06 22:34:18 UTC
[incubator-superset] branch master updated: Revert "chore: Cleanup database sessions (#10427)" (#10537)
This is an automated email from the ASF dual-hosted git repository.
erikrit pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new fd2d1c5 Revert "chore: Cleanup database sessions (#10427)" (#10537)
fd2d1c5 is described below
commit fd2d1c58c566d9312d6cfc5641a06ac2b03e753a
Author: Erik Ritter <er...@airbnb.com>
AuthorDate: Thu Aug 6 15:33:48 2020 -0700
Revert "chore: Cleanup database sessions (#10427)" (#10537)
This reverts commit 7645fc85c3d6676a13ae76ca5133f83d8fb54dbe.
---
superset/cli.py | 12 +-
superset/commands/utils.py | 6 +-
superset/common/query_context.py | 4 +-
superset/connectors/connector_registry.py | 35 +++---
superset/connectors/druid/models.py | 52 ++++----
superset/connectors/druid/views.py | 5 +-
superset/connectors/sqla/models.py | 29 +++--
superset/dashboards/dao.py | 8 +-
superset/models/dashboard.py | 36 +++---
superset/models/helpers.py | 13 +-
superset/models/slice.py | 8 +-
superset/models/tags.py | 81 ++++++++-----
superset/security/manager.py | 6 +-
superset/sql_lab.py | 3 +-
superset/tasks/cache.py | 23 ++--
superset/tasks/schedules.py | 13 +-
superset/utils/dashboard_import_export.py | 11 +-
superset/utils/dict_import_export.py | 19 +--
superset/utils/import_datasource.py | 22 ++--
superset/views/base.py | 3 +-
superset/views/chart/views.py | 4 +-
superset/views/core.py | 129 ++++++++++++--------
superset/views/datasource.py | 6 +-
superset/views/utils.py | 7 +-
tests/access_tests.py | 92 +++++++-------
tests/alerts_tests.py | 192 +++++++++++++++---------------
tests/base_tests.py | 34 +++---
tests/celery_tests.py | 10 +-
tests/charts/api_tests.py | 4 +-
tests/core_tests.py | 24 ++--
tests/database_api_tests.py | 5 +-
tests/datasets/api_tests.py | 7 +-
tests/dict_import_export_tests.py | 60 ++++++----
tests/druid_tests.py | 10 +-
tests/import_export_tests.py | 23 ++--
tests/query_context_tests.py | 1 +
tests/security_tests.py | 132 ++++++++++----------
tests/sqllab_tests.py | 6 +-
tests/strategy_tests.py | 6 +-
39 files changed, 645 insertions(+), 496 deletions(-)
diff --git a/superset/cli.py b/superset/cli.py
index ef9ee84..ef17682 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -197,9 +197,10 @@ def set_database_uri(database_name: str, uri: str) -> None:
)
def refresh_druid(datasource: str, merge: bool) -> None:
"""Refresh druid datasources"""
+ session = db.session()
from superset.connectors.druid.models import DruidCluster
- for cluster in db.session.query(DruidCluster).all():
+ for cluster in session.query(DruidCluster).all():
try:
cluster.refresh_datasources(datasource_name=datasource, merge_flag=merge)
except Exception as ex: # pylint: disable=broad-except
@@ -207,7 +208,7 @@ def refresh_druid(datasource: str, merge: bool) -> None:
logger.exception(ex)
cluster.metadata_last_refreshed = datetime.now()
print("Refreshed metadata from cluster " "[" + cluster.cluster_name + "]")
- db.session.commit()
+ session.commit()
@superset.command()
@@ -249,7 +250,7 @@ def import_dashboards(path: str, recursive: bool, username: str) -> None:
logger.info("Importing dashboard from file %s", file_)
try:
with file_.open() as data_stream:
- dashboard_import_export.import_dashboards(data_stream)
+ dashboard_import_export.import_dashboards(db.session, data_stream)
except Exception as ex: # pylint: disable=broad-except
logger.error("Error when importing dashboard from file %s", file_)
logger.error(ex)
@@ -267,7 +268,7 @@ def export_dashboards(dashboard_file: str, print_stdout: bool) -> None:
"""Export dashboards to JSON"""
from superset.utils import dashboard_import_export
- data = dashboard_import_export.export_dashboards()
+ data = dashboard_import_export.export_dashboards(db.session)
if print_stdout or not dashboard_file:
print(data)
if dashboard_file:
@@ -320,7 +321,7 @@ def import_datasources(path: str, sync: str, recursive: bool) -> None:
try:
with file_.open() as data_stream:
dict_import_export.import_from_dict(
- yaml.safe_load(data_stream), sync=sync_array
+ db.session, yaml.safe_load(data_stream), sync=sync_array
)
except Exception as ex: # pylint: disable=broad-except
logger.error("Error when importing datasources from file %s", file_)
@@ -359,6 +360,7 @@ def export_datasources(
from superset.utils import dict_import_export
data = dict_import_export.export_to_dict(
+ session=db.session,
recursive=True,
back_references=back_references,
include_defaults=include_defaults,
diff --git a/superset/commands/utils.py b/superset/commands/utils.py
index 66fd543..c0bd8b7 100644
--- a/superset/commands/utils.py
+++ b/superset/commands/utils.py
@@ -25,7 +25,7 @@ from superset.commands.exceptions import (
)
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
-from superset.extensions import security_manager
+from superset.extensions import db, security_manager
def populate_owners(user: User, owners_ids: Optional[List[int]] = None) -> List[User]:
@@ -50,6 +50,8 @@ def populate_owners(user: User, owners_ids: Optional[List[int]] = None) -> List[
def get_datasource_by_id(datasource_id: int, datasource_type: str) -> BaseDatasource:
try:
- return ConnectorRegistry.get_datasource(datasource_type, datasource_id)
+ return ConnectorRegistry.get_datasource(
+ datasource_type, datasource_id, db.session
+ )
except (NoResultFound, KeyError):
raise DatasourceNotFoundValidationError()
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 401f262..e602fbf 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -23,7 +23,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Union
import numpy as np
import pandas as pd
-from superset import app, cache, security_manager
+from superset import app, cache, db, security_manager
from superset.common.query_object import QueryObject
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
@@ -64,7 +64,7 @@ class QueryContext:
result_format: Optional[utils.ChartDataResultFormat] = None,
) -> None:
self.datasource = ConnectorRegistry.get_datasource(
- str(datasource["type"]), int(datasource["id"])
+ str(datasource["type"]), int(datasource["id"]), db.session
)
self.queries = [QueryObject(**query_obj) for query_obj in queries]
self.force = force
diff --git a/superset/connectors/connector_registry.py b/superset/connectors/connector_registry.py
index 7c47bc7..fff2f8e 100644
--- a/superset/connectors/connector_registry.py
+++ b/superset/connectors/connector_registry.py
@@ -17,9 +17,7 @@
from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING
from sqlalchemy import or_
-from sqlalchemy.orm import subqueryload
-
-from superset.extensions import db
+from sqlalchemy.orm import Session, subqueryload
if TYPE_CHECKING:
# pylint: disable=unused-import
@@ -45,20 +43,20 @@ class ConnectorRegistry:
@classmethod
def get_datasource(
- cls, datasource_type: str, datasource_id: int
+ cls, datasource_type: str, datasource_id: int, session: Session
) -> "BaseDatasource":
return (
- db.session.query(cls.sources[datasource_type])
+ session.query(cls.sources[datasource_type])
.filter_by(id=datasource_id)
.one()
)
@classmethod
- def get_all_datasources(cls) -> List["BaseDatasource"]:
+ def get_all_datasources(cls, session: Session) -> List["BaseDatasource"]:
datasources: List["BaseDatasource"] = []
for source_type in ConnectorRegistry.sources:
source_class = ConnectorRegistry.sources[source_type]
- qry = db.session.query(source_class)
+ qry = session.query(source_class)
qry = source_class.default_query(qry)
datasources.extend(qry.all())
return datasources
@@ -66,6 +64,7 @@ class ConnectorRegistry:
@classmethod
def get_datasource_by_name( # pylint: disable=too-many-arguments
cls,
+ session: Session,
datasource_type: str,
datasource_name: str,
schema: str,
@@ -73,17 +72,21 @@ class ConnectorRegistry:
) -> Optional["BaseDatasource"]:
datasource_class = ConnectorRegistry.sources[datasource_type]
return datasource_class.get_datasource_by_name(
- datasource_name, schema, database_name
+ session, datasource_name, schema, database_name
)
@classmethod
def query_datasources_by_permissions( # pylint: disable=invalid-name
- cls, database: "Database", permissions: Set[str], schema_perms: Set[str],
+ cls,
+ session: Session,
+ database: "Database",
+ permissions: Set[str],
+ schema_perms: Set[str],
) -> List["BaseDatasource"]:
# TODO(bogdan): add unit test
datasource_class = ConnectorRegistry.sources[database.type]
return (
- db.session.query(datasource_class)
+ session.query(datasource_class)
.filter_by(database_id=database.id)
.filter(
or_(
@@ -96,12 +99,12 @@ class ConnectorRegistry:
@classmethod
def get_eager_datasource(
- cls, datasource_type: str, datasource_id: int
+ cls, session: Session, datasource_type: str, datasource_id: int
) -> "BaseDatasource":
"""Returns datasource with columns and metrics."""
datasource_class = ConnectorRegistry.sources[datasource_type]
return (
- db.session.query(datasource_class)
+ session.query(datasource_class)
.options(
subqueryload(datasource_class.columns),
subqueryload(datasource_class.metrics),
@@ -112,9 +115,13 @@ class ConnectorRegistry:
@classmethod
def query_datasources_by_name(
- cls, database: "Database", datasource_name: str, schema: Optional[str] = None,
+ cls,
+ session: Session,
+ database: "Database",
+ datasource_name: str,
+ schema: Optional[str] = None,
) -> List["BaseDatasource"]:
datasource_class = ConnectorRegistry.sources[database.type]
return datasource_class.query_datasources_by_name(
- database, datasource_name, schema=schema
+ session, database, datasource_name, schema=schema
)
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index 0068f11..162163f 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -45,7 +45,7 @@ from sqlalchemy import (
UniqueConstraint,
)
from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import backref, relationship
+from sqlalchemy.orm import backref, relationship, Session
from sqlalchemy.sql import expression
from sqlalchemy_utils import EncryptedType
@@ -223,8 +223,9 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
Fetches metadata for the specified datasources and
merges to the Superset database
"""
+ session = db.session
ds_list = (
- db.session.query(DruidDatasource)
+ session.query(DruidDatasource)
.filter(DruidDatasource.cluster_id == self.id)
.filter(DruidDatasource.datasource_name.in_(datasource_names))
)
@@ -233,8 +234,8 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
datasource = ds_map.get(ds_name, None)
if not datasource:
datasource = DruidDatasource(datasource_name=ds_name)
- with db.session.no_autoflush:
- db.session.add(datasource)
+ with session.no_autoflush:
+ session.add(datasource)
flasher(_("Adding new datasource [{}]").format(ds_name), "success")
ds_map[ds_name] = datasource
elif refresh_all:
@@ -244,7 +245,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
continue
datasource.cluster = self
datasource.merge_flag = merge_flag
- db.session.flush()
+ session.flush()
# Prepare multithreaded executation
pool = ThreadPool()
@@ -258,7 +259,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
cols = metadata[i]
if cols:
col_objs_list = (
- db.session.query(DruidColumn)
+ session.query(DruidColumn)
.filter(DruidColumn.datasource_id == datasource.id)
.filter(DruidColumn.column_name.in_(cols.keys()))
)
@@ -271,15 +272,15 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
col_obj = DruidColumn(
datasource_id=datasource.id, column_name=col
)
- with db.session.no_autoflush:
- db.session.add(col_obj)
+ with session.no_autoflush:
+ session.add(col_obj)
col_obj.type = cols[col]["type"]
col_obj.datasource = datasource
if col_obj.type == "STRING":
col_obj.groupby = True
col_obj.filterable = True
datasource.refresh_metrics()
- db.session.commit()
+ session.commit()
@hybrid_property
def perm(self) -> str:
@@ -389,7 +390,7 @@ class DruidColumn(Model, BaseColumn):
.first()
)
- return import_datasource.import_simple_obj(i_column, lookup_obj)
+ return import_datasource.import_simple_obj(db.session, i_column, lookup_obj)
class DruidMetric(Model, BaseMetric):
@@ -458,7 +459,7 @@ class DruidMetric(Model, BaseMetric):
.first()
)
- return import_datasource.import_simple_obj(i_metric, lookup_obj)
+ return import_datasource.import_simple_obj(db.session, i_metric, lookup_obj)
druiddatasource_user = Table(
@@ -634,7 +635,7 @@ class DruidDatasource(Model, BaseDatasource):
return db.session.query(DruidCluster).filter_by(id=d.cluster_id).first()
return import_datasource.import_datasource(
- i_datasource, lookup_cluster, lookup_datasource, import_time
+ db.session, i_datasource, lookup_cluster, lookup_datasource, import_time
)
def latest_metadata(self) -> Optional[Dict[str, Any]]:
@@ -704,10 +705,9 @@ class DruidDatasource(Model, BaseDatasource):
refresh: bool = True,
) -> None:
"""Merges the ds config from druid_config into one stored in the db."""
+ session = db.session
datasource = (
- db.session.query(cls)
- .filter_by(datasource_name=druid_config["name"])
- .first()
+ session.query(cls).filter_by(datasource_name=druid_config["name"]).first()
)
# Create a new datasource.
if not datasource:
@@ -718,13 +718,13 @@ class DruidDatasource(Model, BaseDatasource):
changed_by_fk=user.id,
created_by_fk=user.id,
)
- db.session.add(datasource)
+ session.add(datasource)
elif not refresh:
return
dimensions = druid_config["dimensions"]
col_objs = (
- db.session.query(DruidColumn)
+ session.query(DruidColumn)
.filter(DruidColumn.datasource_id == datasource.id)
.filter(DruidColumn.column_name.in_(dimensions))
)
@@ -741,10 +741,10 @@ class DruidDatasource(Model, BaseDatasource):
type="STRING",
datasource=datasource,
)
- db.session.add(col_obj)
+ session.add(col_obj)
# Import Druid metrics
metric_objs = (
- db.session.query(DruidMetric)
+ session.query(DruidMetric)
.filter(DruidMetric.datasource_id == datasource.id)
.filter(
DruidMetric.metric_name.in_(
@@ -777,8 +777,8 @@ class DruidDatasource(Model, BaseDatasource):
% druid_config["name"]
),
)
- db.session.add(metric_obj)
- db.session.commit()
+ session.add(metric_obj)
+ session.commit()
@staticmethod
def time_offset(granularity: Granularity) -> int:
@@ -788,10 +788,10 @@ class DruidDatasource(Model, BaseDatasource):
@classmethod
def get_datasource_by_name(
- cls, datasource_name: str, schema: str, database_name: str
+ cls, session: Session, datasource_name: str, schema: str, database_name: str
) -> Optional["DruidDatasource"]:
query = (
- db.session.query(cls)
+ session.query(cls)
.join(DruidCluster)
.filter(cls.datasource_name == datasource_name)
.filter(DruidCluster.cluster_name == database_name)
@@ -1724,7 +1724,11 @@ class DruidDatasource(Model, BaseDatasource):
@classmethod
def query_datasources_by_name(
- cls, database: Database, datasource_name: str, schema: Optional[str] = None,
+ cls,
+ session: Session,
+ database: Database,
+ datasource_name: str,
+ schema: Optional[str] = None,
) -> List["DruidDatasource"]:
return []
diff --git a/superset/connectors/druid/views.py b/superset/connectors/druid/views.py
index 4a22bc2..4c2fbf9 100644
--- a/superset/connectors/druid/views.py
+++ b/superset/connectors/druid/views.py
@@ -365,10 +365,11 @@ class Druid(BaseSupersetView):
self, refresh_all: bool = True
) -> FlaskResponse:
"""endpoint that refreshes druid datasources metadata"""
+ session = db.session()
DruidCluster = ConnectorRegistry.sources[ # pylint: disable=invalid-name
"druid"
].cluster_class
- for cluster in db.session.query(DruidCluster).all():
+ for cluster in session.query(DruidCluster).all():
cluster_name = cluster.cluster_name
valid_cluster = True
try:
@@ -390,7 +391,7 @@ class Druid(BaseSupersetView):
),
"info",
)
- db.session.commit()
+ session.commit()
return redirect("/druiddatasourcemodelview/list/")
@has_access
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 82f3017..530a2e1 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -41,7 +41,7 @@ from sqlalchemy import (
Text,
)
from sqlalchemy.exc import CompileError
-from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty
+from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
@@ -255,7 +255,7 @@ class TableColumn(Model, BaseColumn):
.first()
)
- return import_datasource.import_simple_obj(i_column, lookup_obj)
+ return import_datasource.import_simple_obj(db.session, i_column, lookup_obj)
def dttm_sql_literal(
self,
@@ -375,7 +375,7 @@ class SqlMetric(Model, BaseMetric):
.first()
)
- return import_datasource.import_simple_obj(i_metric, lookup_obj)
+ return import_datasource.import_simple_obj(db.session, i_metric, lookup_obj)
sqlatable_user = Table(
@@ -503,11 +503,15 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
@classmethod
def get_datasource_by_name(
- cls, datasource_name: str, schema: Optional[str], database_name: str,
+ cls,
+ session: Session,
+ datasource_name: str,
+ schema: Optional[str],
+ database_name: str,
) -> Optional["SqlaTable"]:
schema = schema or None
query = (
- db.session.query(cls)
+ session.query(cls)
.join(Database)
.filter(cls.table_name == datasource_name)
.filter(Database.database_name == database_name)
@@ -1292,15 +1296,24 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
)
return import_datasource.import_datasource(
- i_datasource, lookup_database, lookup_sqlatable, import_time, database_id,
+ db.session,
+ i_datasource,
+ lookup_database,
+ lookup_sqlatable,
+ import_time,
+ database_id,
)
@classmethod
def query_datasources_by_name(
- cls, database: Database, datasource_name: str, schema: Optional[str] = None,
+ cls,
+ session: Session,
+ database: Database,
+ datasource_name: str,
+ schema: Optional[str] = None,
) -> List["SqlaTable"]:
query = (
- db.session.query(cls)
+ session.query(cls)
.filter_by(database_id=database.id)
.filter_by(table_name=datasource_name)
)
diff --git a/superset/dashboards/dao.py b/superset/dashboards/dao.py
index 6345bb7..774e1c8 100644
--- a/superset/dashboards/dao.py
+++ b/superset/dashboards/dao.py
@@ -99,7 +99,9 @@ class DashboardDAO(BaseDAO):
except KeyError:
pass
- current_slices = db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
+ session = db.session()
+ current_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
+
dashboard.slices = current_slices
# update slice names. this assumes user has permissions to update the slice
@@ -109,8 +111,8 @@ class DashboardDAO(BaseDAO):
new_name = slice_id_to_name[slc.id]
if slc.slice_name != new_name:
slc.slice_name = new_name
- db.session.merge(slc)
- db.session.flush()
+ session.merge(slc)
+ session.flush()
except KeyError:
pass
diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py
index 04e9251..1844543 100644
--- a/superset/models/dashboard.py
+++ b/superset/models/dashboard.py
@@ -37,7 +37,7 @@ from sqlalchemy import (
UniqueConstraint,
)
from sqlalchemy.engine.base import Connection
-from sqlalchemy.orm import relationship, subqueryload
+from sqlalchemy.orm import relationship, sessionmaker, subqueryload
from sqlalchemy.orm.mapper import Mapper
from superset import app, ConnectorRegistry, db, is_feature_enabled, security_manager
@@ -62,17 +62,18 @@ config = app.config
logger = logging.getLogger(__name__)
-def copy_dashboard( # pylint: disable=unused-argument
- mapper: Mapper, connection: Connection, target: "Dashboard"
-) -> None:
+def copy_dashboard(mapper: Mapper, connection: Connection, target: "Dashboard") -> None:
+ # pylint: disable=unused-argument
dashboard_id = config["DASHBOARD_TEMPLATE_ID"]
if dashboard_id is None:
return
- new_user = db.session.query(User).filter_by(id=target.id).first()
+ session_class = sessionmaker(autoflush=False)
+ session = session_class(bind=connection)
+ new_user = session.query(User).filter_by(id=target.id).first()
# copy template dashboard to user
- template = db.session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
+ template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
dashboard = Dashboard(
dashboard_title=template.dashboard_title,
position_json=template.position_json,
@@ -82,15 +83,15 @@ def copy_dashboard( # pylint: disable=unused-argument
slices=template.slices,
owners=[new_user],
)
- db.session.add(dashboard)
- db.session.commit()
+ session.add(dashboard)
+ session.commit()
# set dashboard as the welcome dashboard
extra_attributes = UserAttribute(
user_id=target.id, welcome_dashboard_id=dashboard.id
)
- db.session.add(extra_attributes)
- db.session.commit()
+ session.add(extra_attributes)
+ session.commit()
sqla.event.listen(User, "after_insert", copy_dashboard)
@@ -306,6 +307,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
logger.info(
"Started import of the dashboard: %s", dashboard_to_import.to_json()
)
+ session = db.session
logger.info("Dashboard has %d slices", len(dashboard_to_import.slices))
# copy slices object as Slice.import_slice will mutate the slice
# and will remove the existing dashboard - slice association
@@ -322,7 +324,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
i_params_dict = dashboard_to_import.params_dict
remote_id_slice_map = {
slc.params_dict["remote_id"]: slc
- for slc in db.session.query(Slice).all()
+ for slc in session.query(Slice).all()
if "remote_id" in slc.params_dict
}
for slc in slices:
@@ -373,7 +375,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
# override the dashboard
existing_dashboard = None
- for dash in db.session.query(Dashboard).all():
+ for dash in session.query(Dashboard).all():
if (
"remote_id" in dash.params_dict
and dash.params_dict["remote_id"] == dashboard_to_import.id
@@ -400,7 +402,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
)
new_slices = (
- db.session.query(Slice)
+ session.query(Slice)
.filter(Slice.id.in_(old_to_new_slc_id_dict.values()))
.all()
)
@@ -408,12 +410,12 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
if existing_dashboard:
existing_dashboard.override(dashboard_to_import)
existing_dashboard.slices = new_slices
- db.session.flush()
+ session.flush()
return existing_dashboard.id
dashboard_to_import.slices = new_slices
- db.session.add(dashboard_to_import)
- db.session.flush()
+ session.add(dashboard_to_import)
+ session.flush()
return dashboard_to_import.id # type: ignore
@classmethod
@@ -455,7 +457,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
eager_datasources = []
for datasource_id, datasource_type in datasource_ids:
eager_datasource = ConnectorRegistry.get_eager_datasource(
- datasource_type, datasource_id
+ db.session, datasource_type, datasource_id
)
copied_datasource = eager_datasource.copy()
copied_datasource.alter_params(
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index c67c6b6..d903d27 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -34,9 +34,9 @@ from flask_appbuilder.models.mixins import AuditMixin
from flask_appbuilder.security.sqla.models import User
from sqlalchemy import and_, or_, UniqueConstraint
from sqlalchemy.ext.declarative import declared_attr
+from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import MultipleResultsFound
-from superset.extensions import db
from superset.utils.core import QueryStatus
logger = logging.getLogger(__name__)
@@ -127,6 +127,7 @@ class ImportMixin:
@classmethod
def import_from_dict( # pylint: disable=too-many-arguments,too-many-branches,too-many-locals
cls,
+ session: Session,
dict_rep: Dict[Any, Any],
parent: Optional[Any] = None,
recursive: bool = True,
@@ -177,7 +178,7 @@ class ImportMixin:
# Check if object already exists in DB, break if more than one is found
try:
- obj_query = db.session.query(cls).filter(and_(*filters))
+ obj_query = session.query(cls).filter(and_(*filters))
obj = obj_query.one_or_none()
except MultipleResultsFound as ex:
logger.error(
@@ -195,7 +196,7 @@ class ImportMixin:
logger.info("Importing new %s %s", obj.__tablename__, str(obj))
if cls.export_parent and parent:
setattr(obj, cls.export_parent, parent)
- db.session.add(obj)
+ session.add(obj)
else:
is_new_obj = False
logger.info("Updating %s %s", obj.__tablename__, str(obj))
@@ -213,7 +214,7 @@ class ImportMixin:
for c_obj in new_children.get(child, []):
added.append(
child_class.import_from_dict(
- dict_rep=c_obj, parent=obj, sync=sync
+ session=session, dict_rep=c_obj, parent=obj, sync=sync
)
)
# If children should get synced, delete the ones that did not
@@ -227,11 +228,11 @@ class ImportMixin:
for k in back_refs.keys()
]
to_delete = set(
- db.session.query(child_class).filter(and_(*delete_filters))
+ session.query(child_class).filter(and_(*delete_filters))
).difference(set(added))
for o in to_delete:
logger.info("Deleting %s %s", child, str(obj))
- db.session.delete(o)
+ session.delete(o)
return obj
diff --git a/superset/models/slice.py b/superset/models/slice.py
index 0a2e7d5..b7f9e05 100644
--- a/superset/models/slice.py
+++ b/superset/models/slice.py
@@ -300,6 +300,7 @@ class Slice(
:returns: The resulting id for the imported slice
:rtype: int
"""
+ session = db.session
make_transient(slc_to_import)
slc_to_import.dashboards = []
slc_to_import.alter_params(remote_id=slc_to_import.id, import_time=import_time)
@@ -308,6 +309,7 @@ class Slice(
slc_to_import.reset_ownership()
params = slc_to_import.params_dict
datasource = ConnectorRegistry.get_datasource_by_name(
+ session,
slc_to_import.datasource_type,
params["datasource_name"],
params["schema"],
@@ -316,11 +318,11 @@ class Slice(
slc_to_import.datasource_id = datasource.id # type: ignore
if slc_to_override:
slc_to_override.override(slc_to_import)
- db.session.flush()
+ session.flush()
return slc_to_override.id
- db.session.add(slc_to_import)
+ session.add(slc_to_import)
logger.info("Final slice: %s", str(slc_to_import.to_json()))
- db.session.flush()
+ session.flush()
return slc_to_import.id
@property
diff --git a/superset/models/tags.py b/superset/models/tags.py
index 1302ff5..c09bb16 100644
--- a/superset/models/tags.py
+++ b/superset/models/tags.py
@@ -22,11 +22,10 @@ from typing import List, Optional, TYPE_CHECKING, Union
from flask_appbuilder import Model
from sqlalchemy import Column, Enum, ForeignKey, Integer, String
from sqlalchemy.engine.base import Connection
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import relationship, Session, sessionmaker
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.mapper import Mapper
-from superset.extensions import db
from superset.models.helpers import AuditMixinNullable
if TYPE_CHECKING:
@@ -35,6 +34,8 @@ if TYPE_CHECKING:
from superset.models.slice import Slice # pylint: disable=unused-import
from superset.models.sql_lab import Query # pylint: disable=unused-import
+Session = sessionmaker(autoflush=False)
+
class TagTypes(enum.Enum):
@@ -87,13 +88,13 @@ class TaggedObject(Model, AuditMixinNullable):
tag = relationship("Tag", backref="objects")
-def get_tag(name: str, type_: TagTypes) -> Tag:
+def get_tag(name: str, session: Session, type_: TagTypes) -> Tag:
try:
- tag = db.session.query(Tag).filter_by(name=name, type=type_).one()
+ tag = session.query(Tag).filter_by(name=name, type=type_).one()
except NoResultFound:
tag = Tag(name=name, type=type_)
- db.session.add(tag)
- db.session.commit()
+ session.add(tag)
+ session.commit()
return tag
@@ -121,43 +122,52 @@ class ObjectUpdater:
raise NotImplementedError("Subclass should implement `get_owners_ids`")
@classmethod
- def _add_owners(cls, target: Union["Dashboard", "FavStar", "Slice"]) -> None:
+ def _add_owners(
+ cls, session: Session, target: Union["Dashboard", "FavStar", "Slice"]
+ ) -> None:
for owner_id in cls.get_owners_ids(target):
name = "owner:{0}".format(owner_id)
- tag = get_tag(name, TagTypes.owner)
+ tag = get_tag(name, session, TagTypes.owner)
tagged_object = TaggedObject(
tag_id=tag.id, object_id=target.id, object_type=cls.object_type
)
- db.session.add(tagged_object)
+ session.add(tagged_object)
@classmethod
- def after_insert( # pylint: disable=unused-argument
+ def after_insert(
cls,
mapper: Mapper,
connection: Connection,
target: Union["Dashboard", "FavStar", "Slice"],
) -> None:
+ # pylint: disable=unused-argument
+ session = Session(bind=connection)
+
# add `owner:` tags
- cls._add_owners(target)
+ cls._add_owners(session, target)
# add `type:` tags
- tag = get_tag("type:{0}".format(cls.object_type), TagTypes.type)
+ tag = get_tag("type:{0}".format(cls.object_type), session, TagTypes.type)
tagged_object = TaggedObject(
tag_id=tag.id, object_id=target.id, object_type=cls.object_type
)
- db.session.add(tagged_object)
- db.session.commit()
+ session.add(tagged_object)
+
+ session.commit()
@classmethod
- def after_update( # pylint: disable=unused-argument
+ def after_update(
cls,
mapper: Mapper,
connection: Connection,
target: Union["Dashboard", "FavStar", "Slice"],
) -> None:
+ # pylint: disable=unused-argument
+ session = Session(bind=connection)
+
# delete current `owner:` tags
query = (
- db.session.query(TaggedObject.id)
+ session.query(TaggedObject.id)
.join(Tag)
.filter(
TaggedObject.object_type == cls.object_type,
@@ -166,28 +176,32 @@ class ObjectUpdater:
)
)
ids = [row[0] for row in query]
- db.session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
+ session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)
# add `owner:` tags
- cls._add_owners(target)
- db.session.commit()
+ cls._add_owners(session, target)
+
+ session.commit()
@classmethod
- def after_delete( # pylint: disable=unused-argument
+ def after_delete(
cls,
mapper: Mapper,
connection: Connection,
target: Union["Dashboard", "FavStar", "Slice"],
) -> None:
+ # pylint: disable=unused-argument
+ session = Session(bind=connection)
+
# delete row from `tagged_objects`
- db.session.query(TaggedObject).filter(
+ session.query(TaggedObject).filter(
TaggedObject.object_type == cls.object_type,
TaggedObject.object_id == target.id,
).delete()
- db.session.commit()
+ session.commit()
class ChartUpdater(ObjectUpdater):
@@ -219,26 +233,31 @@ class QueryUpdater(ObjectUpdater):
class FavStarUpdater:
@classmethod
- def after_insert( # pylint: disable=unused-argument
+ def after_insert(
cls, mapper: Mapper, connection: Connection, target: "FavStar"
) -> None:
+ # pylint: disable=unused-argument
+ session = Session(bind=connection)
name = "favorited_by:{0}".format(target.user_id)
- tag = get_tag(name, TagTypes.favorited_by)
+ tag = get_tag(name, session, TagTypes.favorited_by)
tagged_object = TaggedObject(
tag_id=tag.id,
object_id=target.obj_id,
object_type=get_object_type(target.class_name),
)
- db.session.add(tagged_object)
- db.session.commit()
+ session.add(tagged_object)
+
+ session.commit()
@classmethod
- def after_delete( # pylint: disable=unused-argument
- cls, mapper: Mapper, connection: Connection, target: "FavStar",
+ def after_delete(
+ cls, mapper: Mapper, connection: Connection, target: "FavStar"
) -> None:
+ # pylint: disable=unused-argument
+ session = Session(bind=connection)
name = "favorited_by:{0}".format(target.user_id)
query = (
- db.session.query(TaggedObject.id)
+ session.query(TaggedObject.id)
.join(Tag)
.filter(
TaggedObject.object_id == target.obj_id,
@@ -247,8 +266,8 @@ class FavStarUpdater:
)
)
ids = [row[0] for row in query]
- db.session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
+ session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)
- db.session.commit()
+ session.commit()
diff --git a/superset/security/manager.py b/superset/security/manager.py
index f731c42..da92d16 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -507,7 +507,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
user_perms = self.user_view_menu_names("datasource_access")
schema_perms = self.user_view_menu_names("schema_access")
user_datasources = ConnectorRegistry.query_datasources_by_permissions(
- database, user_perms, schema_perms
+ self.get_session, database, user_perms, schema_perms
)
if schema:
names = {d.table_name for d in user_datasources if d.schema == schema}
@@ -568,7 +568,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
self.add_permission_view_menu(view_menu, perm)
logger.info("Creating missing datasource permissions.")
- datasources = ConnectorRegistry.get_all_datasources()
+ datasources = ConnectorRegistry.get_all_datasources(self.get_session)
for datasource in datasources:
merge_pv("datasource_access", datasource.get_perm())
merge_pv("schema_access", datasource.get_schema_perm())
@@ -901,7 +901,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
if not (schema_perm and self.can_access("schema_access", schema_perm)):
datasources = SqlaTable.query_datasources_by_name(
- database, table_.table, schema=table_.schema
+ self.get_session, database, table_.table, schema=table_.schema
)
# Access to any datasource is suffice.
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index d941473..8c3f24f 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -132,7 +132,8 @@ def session_scope(nullpool: bool) -> Iterator[Session]:
)
if nullpool:
engine = sqlalchemy.create_engine(database_uri, poolclass=NullPool)
- session_class = sessionmaker(bind=engine)
+ session_class = sessionmaker()
+ session_class.configure(bind=engine)
session = session_class()
else:
session = db.session()
diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py
index f4c9e32..54b0dc1 100644
--- a/superset/tasks/cache.py
+++ b/superset/tasks/cache.py
@@ -134,7 +134,8 @@ class DummyStrategy(Strategy):
name = "dummy"
def get_urls(self) -> List[str]:
- charts = db.session.query(Slice).all()
+ session = db.create_scoped_session()
+ charts = session.query(Slice).all()
return [get_url(chart) for chart in charts]
@@ -166,9 +167,10 @@ class TopNDashboardsStrategy(Strategy):
def get_urls(self) -> List[str]:
urls = []
+ session = db.create_scoped_session()
records = (
- db.session.query(Log.dashboard_id, func.count(Log.dashboard_id))
+ session.query(Log.dashboard_id, func.count(Log.dashboard_id))
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
.group_by(Log.dashboard_id)
.order_by(func.count(Log.dashboard_id).desc())
@@ -176,9 +178,7 @@ class TopNDashboardsStrategy(Strategy):
.all()
)
dash_ids = [record.dashboard_id for record in records]
- dashboards = (
- db.session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
- )
+ dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
for dashboard in dashboards:
for chart in dashboard.slices:
form_data_with_filters = get_form_data(chart.id, dashboard)
@@ -211,13 +211,14 @@ class DashboardTagsStrategy(Strategy):
def get_urls(self) -> List[str]:
urls = []
+ session = db.create_scoped_session()
- tags = db.session.query(Tag).filter(Tag.name.in_(self.tags)).all()
+ tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tag_ids = [tag.id for tag in tags]
# add dashboards that are tagged
tagged_objects = (
- db.session.query(TaggedObject)
+ session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "dashboard",
@@ -227,16 +228,14 @@ class DashboardTagsStrategy(Strategy):
.all()
)
dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
- tagged_dashboards = db.session.query(Dashboard).filter(
- Dashboard.id.in_(dash_ids)
- )
+ tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids))
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
urls.append(get_url(chart))
# add charts that are tagged
tagged_objects = (
- db.session.query(TaggedObject)
+ session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
@@ -246,7 +245,7 @@ class DashboardTagsStrategy(Strategy):
.all()
)
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
- tagged_charts = db.session.query(Slice).filter(Slice.id.in_(chart_ids))
+ tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
urls.append(get_url(chart))
diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py
index 0a74ddf..4fd55aa 100644
--- a/superset/tasks/schedules.py
+++ b/superset/tasks/schedules.py
@@ -47,6 +47,7 @@ from flask_login import login_user
from retry.api import retry_call
from selenium.common.exceptions import WebDriverException
from selenium.webdriver import chrome, firefox
+from sqlalchemy.orm import Session
from werkzeug.http import parse_cookie
from superset import app, db, security_manager, thumbnail_cache
@@ -541,7 +542,8 @@ def schedule_alert_query( # pylint: disable=unused-argument
is_test_alert: Optional[bool] = False,
) -> None:
model_cls = get_scheduler_model(report_type)
- schedule = db.session.query(model_cls).get(schedule_id)
+ dbsession = db.create_scoped_session()
+ schedule = dbsession.query(model_cls).get(schedule_id)
# The user may have disabled the schedule. If so, ignore this
if not schedule or not schedule.active:
@@ -553,7 +555,7 @@ def schedule_alert_query( # pylint: disable=unused-argument
deliver_alert(schedule.id, recipients)
return
- if run_alert_query(schedule.id):
+ if run_alert_query(schedule.id, dbsession):
# deliver_dashboard OR deliver_slice
return
else:
@@ -616,7 +618,7 @@ def deliver_alert(alert_id: int, recipients: Optional[str] = None) -> None:
_deliver_email(recipients, deliver_as_group, subject, body, data, images)
-def run_alert_query(alert_id: int) -> Optional[bool]:
+def run_alert_query(alert_id: int, dbsession: Session) -> Optional[bool]:
"""
Execute alert.sql and return value if any rows are returned
"""
@@ -670,7 +672,7 @@ def run_alert_query(alert_id: int) -> Optional[bool]:
state=state,
)
)
- db.session.commit()
+ dbsession.commit()
return None
@@ -710,7 +712,8 @@ def schedule_window(
if not model_cls:
return None
- schedules = db.session.query(model_cls).filter(model_cls.active.is_(True))
+ dbsession = db.create_scoped_session()
+ schedules = dbsession.query(model_cls).filter(model_cls.active.is_(True))
for schedule in schedules:
logging.info("Processing schedule %s", schedule)
diff --git a/superset/utils/dashboard_import_export.py b/superset/utils/dashboard_import_export.py
index f8f673d..6ae500b 100644
--- a/superset/utils/dashboard_import_export.py
+++ b/superset/utils/dashboard_import_export.py
@@ -22,10 +22,10 @@ from io import BytesIO
from typing import Any, Dict, Optional
from flask_babel import lazy_gettext as _
+from sqlalchemy.orm import Session
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.exceptions import DashboardImportException
-from superset.extensions import db
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
@@ -71,6 +71,7 @@ def decode_dashboards( # pylint: disable=too-many-return-statements
def import_dashboards(
+ session: Session,
data_stream: BytesIO,
database_id: Optional[int] = None,
import_time: Optional[int] = None,
@@ -83,16 +84,16 @@ def import_dashboards(
raise DashboardImportException(_("No data in file"))
for table in data["datasources"]:
type(table).import_obj(table, database_id, import_time=import_time)
- db.session.commit()
+ session.commit()
for dashboard in data["dashboards"]:
Dashboard.import_obj(dashboard, import_time=import_time)
- db.session.commit()
+ session.commit()
-def export_dashboards() -> str:
+def export_dashboards(session: Session) -> str:
"""Returns all dashboards metadata as a json dump"""
logger.info("Starting export")
- dashboards = db.session.query(Dashboard)
+ dashboards = session.query(Dashboard)
dashboard_ids = []
for dashboard in dashboards:
dashboard_ids.append(dashboard.id)
diff --git a/superset/utils/dict_import_export.py b/superset/utils/dict_import_export.py
index 8edae22..4d9e049 100644
--- a/superset/utils/dict_import_export.py
+++ b/superset/utils/dict_import_export.py
@@ -17,8 +17,9 @@
import logging
from typing import Any, Dict, List, Optional
+from sqlalchemy.orm import Session
+
from superset.connectors.druid.models import DruidCluster
-from superset.extensions import db
from superset.models.core import Database
DATABASES_KEY = "databases"
@@ -43,11 +44,11 @@ def export_schema_to_dict(back_references: bool) -> Dict[str, Any]:
def export_to_dict(
- recursive: bool, back_references: bool, include_defaults: bool
+ session: Session, recursive: bool, back_references: bool, include_defaults: bool
) -> Dict[str, Any]:
"""Exports databases and druid clusters to a dictionary"""
logger.info("Starting export")
- dbs = db.session.query(Database)
+ dbs = session.query(Database)
databases = [
database.export_to_dict(
recursive=recursive,
@@ -57,7 +58,7 @@ def export_to_dict(
for database in dbs
]
logger.info("Exported %d %s", len(databases), DATABASES_KEY)
- cls = db.session.query(DruidCluster)
+ cls = session.query(DruidCluster)
clusters = [
cluster.export_to_dict(
recursive=recursive,
@@ -75,20 +76,22 @@ def export_to_dict(
return data
-def import_from_dict(data: Dict[str, Any], sync: Optional[List[str]] = None) -> None:
+def import_from_dict(
+ session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None
+) -> None:
"""Imports databases and druid clusters from dictionary"""
if not sync:
sync = []
if isinstance(data, dict):
logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
for database in data.get(DATABASES_KEY, []):
- Database.import_from_dict(database, sync=sync)
+ Database.import_from_dict(session, database, sync=sync)
logger.info(
"Importing %d %s", len(data.get(DRUID_CLUSTERS_KEY, [])), DRUID_CLUSTERS_KEY
)
for datasource in data.get(DRUID_CLUSTERS_KEY, []):
- DruidCluster.import_from_dict(datasource, sync=sync)
- db.session.commit()
+ DruidCluster.import_from_dict(session, datasource, sync=sync)
+ session.commit()
else:
logger.info("Supplied object is not a dictionary.")
diff --git a/superset/utils/import_datasource.py b/superset/utils/import_datasource.py
index a59a3d6..25da876 100644
--- a/superset/utils/import_datasource.py
+++ b/superset/utils/import_datasource.py
@@ -18,14 +18,14 @@ import logging
from typing import Callable, Optional
from flask_appbuilder import Model
+from sqlalchemy.orm import Session
from sqlalchemy.orm.session import make_transient
-from superset.extensions import db
-
logger = logging.getLogger(__name__)
def import_datasource( # pylint: disable=too-many-arguments
+ session: Session,
i_datasource: Model,
lookup_database: Callable[[Model], Model],
lookup_datasource: Callable[[Model], Model],
@@ -52,11 +52,11 @@ def import_datasource( # pylint: disable=too-many-arguments
if datasource:
datasource.override(i_datasource)
- db.session.flush()
+ session.flush()
else:
datasource = i_datasource.copy()
- db.session.add(datasource)
- db.session.flush()
+ session.add(datasource)
+ session.flush()
for metric in i_datasource.metrics:
new_m = metric.copy()
@@ -81,11 +81,13 @@ def import_datasource( # pylint: disable=too-many-arguments
imported_c = i_datasource.column_class.import_obj(new_c)
if imported_c.column_name not in [c.column_name for c in datasource.columns]:
datasource.columns.append(imported_c)
- db.session.flush()
+ session.flush()
return datasource.id
-def import_simple_obj(i_obj: Model, lookup_obj: Callable[[Model], Model]) -> Model:
+def import_simple_obj(
+ session: Session, i_obj: Model, lookup_obj: Callable[[Model], Model]
+) -> Model:
make_transient(i_obj)
i_obj.id = None
i_obj.table = None
@@ -95,9 +97,9 @@ def import_simple_obj(i_obj: Model, lookup_obj: Callable[[Model], Model]) -> Mod
i_obj.table = None
if existing_column:
existing_column.override(i_obj)
- db.session.flush()
+ session.flush()
return existing_column
- db.session.add(i_obj)
- db.session.flush()
+ session.add(i_obj)
+ session.flush()
return i_obj
diff --git a/superset/views/base.py b/superset/views/base.py
index 58c4943..7aeae79 100644
--- a/superset/views/base.py
+++ b/superset/views/base.py
@@ -487,7 +487,8 @@ def check_ownership(obj: Any, raise_if_false: bool = True) -> bool:
roles = [r.name for r in get_user_roles()]
if "Admin" in roles:
return True
- orig_obj = db.session.query(obj.__class__).filter_by(id=obj.id).first()
+ scoped_session = db.create_scoped_session()
+ orig_obj = scoped_session.query(obj.__class__).filter_by(id=obj.id).first()
# Making a list of owners that works across ORM models
owners: List[User] = []
diff --git a/superset/views/chart/views.py b/superset/views/chart/views.py
index db100a7..0523e33 100644
--- a/superset/views/chart/views.py
+++ b/superset/views/chart/views.py
@@ -20,7 +20,7 @@ from flask_appbuilder import expose, has_access
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import lazy_gettext as _
-from superset import app
+from superset import app, db
from superset.connectors.connector_registry import ConnectorRegistry
from superset.constants import RouteMethod
from superset.models.slice import Slice
@@ -56,7 +56,7 @@ class SliceModelView(
def add(self) -> FlaskResponse:
datasources = [
{"value": str(d.id) + "__" + d.type, "label": repr(d)}
- for d in ConnectorRegistry.get_all_datasources()
+ for d in ConnectorRegistry.get_all_datasources(db.session)
]
return self.render_template(
"superset/add_slice.html",
diff --git a/superset/views/core.py b/superset/views/core.py
index bcbce02..f3a2264 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -40,6 +40,7 @@ from sqlalchemy.exc import (
OperationalError,
SQLAlchemyError,
)
+from sqlalchemy.orm.session import Session
from werkzeug.urls import Href
import superset.models.core as models
@@ -163,7 +164,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
sorted(
[
datasource.short_data
- for datasource in ConnectorRegistry.get_all_datasources()
+ for datasource in ConnectorRegistry.get_all_datasources(db.session)
if datasource.short_data.get("name")
],
key=lambda datasource: datasource["name"],
@@ -202,7 +203,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
)
db_ds_names.add(fullname)
- existing_datasources = ConnectorRegistry.get_all_datasources()
+ existing_datasources = ConnectorRegistry.get_all_datasources(db.session)
datasources = [d for d in existing_datasources if d.full_name in db_ds_names]
role = security_manager.find_role(role_name)
# remove all permissions
@@ -269,15 +270,15 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@has_access
@expose("/approve")
def approve(self) -> FlaskResponse: # pylint: disable=too-many-locals,no-self-use
- def clean_fulfilled_requests() -> None:
- for dar in db.session.query(DAR).all():
+ def clean_fulfilled_requests(session: Session) -> None:
+ for dar in session.query(DAR).all():
datasource = ConnectorRegistry.get_datasource(
- dar.datasource_type, dar.datasource_id
+ dar.datasource_type, dar.datasource_id, session
)
if not datasource or security_manager.can_access_datasource(datasource):
# datasource does not exist anymore
- db.session.delete(dar)
- db.session.commit()
+ session.delete(dar)
+ session.commit()
datasource_type = request.args["datasource_type"]
datasource_id = request.args["datasource_id"]
@@ -285,7 +286,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
role_to_grant = request.args.get("role_to_grant")
role_to_extend = request.args.get("role_to_extend")
- datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
+ session = db.session
+ datasource = ConnectorRegistry.get_datasource(
+ datasource_type, datasource_id, session
+ )
if not datasource:
flash(DATASOURCE_MISSING_ERR, "alert")
@@ -297,7 +301,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
return json_error_response(USER_MISSING_ERR)
requests = (
- db.session.query(DAR)
+ session.query(DAR)
.filter(
DAR.datasource_id == datasource_id,
DAR.datasource_type == datasource_type,
@@ -357,13 +361,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
app.config,
)
flash(msg, "info")
- clean_fulfilled_requests()
+ clean_fulfilled_requests(session)
else:
flash(__("You have no permission to approve this request"), "danger")
return redirect("/accessrequestsmodelview/list/")
for request_ in requests:
- db.session.delete(request_)
- db.session.commit()
+ session.delete(request_)
+ session.commit()
return redirect("/accessrequestsmodelview/list/")
@has_access
@@ -544,7 +548,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
database_id = request.form.get("db_id")
try:
dashboard_import_export.import_dashboards(
- import_file.stream, database_id
+ db.session, import_file.stream, database_id
)
success = True
except DatabaseNotFound as ex:
@@ -626,7 +630,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
return redirect(error_redirect)
datasource = ConnectorRegistry.get_datasource(
- cast(str, datasource_type), datasource_id
+ cast(str, datasource_type), datasource_id, db.session
)
if not datasource:
flash(DATASOURCE_MISSING_ERR, "danger")
@@ -745,7 +749,9 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
:raises SupersetSecurityException: If the user cannot access the resource
"""
# TODO: Cache endpoint by user, datasource and column
- datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
+ datasource = ConnectorRegistry.get_datasource(
+ datasource_type, datasource_id, db.session
+ )
if not datasource:
return json_error_response(DATASOURCE_MISSING_ERR)
@@ -1009,9 +1015,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
self, dashboard_id: int
) -> FlaskResponse:
"""Copy dashboard"""
+ session = db.session()
data = json.loads(request.form["data"])
dash = models.Dashboard()
- original_dash = db.session.query(Dashboard).get(dashboard_id)
+ original_dash = session.query(Dashboard).get(dashboard_id)
dash.owners = [g.user] if g.user else []
dash.dashboard_title = data["dashboard_title"]
@@ -1022,8 +1029,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
for slc in original_dash.slices:
new_slice = slc.clone()
new_slice.owners = [g.user] if g.user else []
- db.session.add(new_slice)
- db.session.flush()
+ session.add(new_slice)
+ session.flush()
new_slice.dashboards.append(dash)
old_to_new_slice_ids[slc.id] = new_slice.id
@@ -1039,9 +1046,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
dash.params = original_dash.params
DashboardDAO.set_dash_metadata(dash, data, old_to_new_slice_ids)
- db.session.add(dash)
- db.session.commit()
+ session.add(dash)
+ session.commit()
dash_json = json.dumps(dash.data)
+ session.close()
return json_success(dash_json)
@api
@@ -1051,12 +1059,14 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
self, dashboard_id: int
) -> FlaskResponse:
"""Save a dashboard's metadata"""
- dash = db.session.query(Dashboard).get(dashboard_id)
+ session = db.session()
+ dash = session.query(Dashboard).get(dashboard_id)
check_ownership(dash, raise_if_false=True)
data = json.loads(request.form["data"])
DashboardDAO.set_dash_metadata(dash, data)
- db.session.merge(dash)
- db.session.commit()
+ session.merge(dash)
+ session.commit()
+ session.close()
return json_success(json.dumps({"status": "SUCCESS"}))
@api
@@ -1067,12 +1077,14 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
) -> FlaskResponse:
"""Add and save slices to a dashboard"""
data = json.loads(request.form["data"])
- dash = db.session.query(Dashboard).get(dashboard_id)
+ session = db.session()
+ dash = session.query(Dashboard).get(dashboard_id)
check_ownership(dash, raise_if_false=True)
- new_slices = db.session.query(Slice).filter(Slice.id.in_(data["slice_ids"]))
+ new_slices = session.query(Slice).filter(Slice.id.in_(data["slice_ids"]))
dash.slices += new_slices
- db.session.merge(dash)
- db.session.commit()
+ session.merge(dash)
+ session.commit()
+ session.close()
return "SLICES ADDED"
@api
@@ -1419,6 +1431,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
Note for slices a force refresh occurs.
"""
+ session = db.session()
slice_id = request.args.get("slice_id")
dashboard_id = request.args.get("dashboard_id")
table_name = request.args.get("table_name")
@@ -1433,14 +1446,14 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
status=400,
)
if slice_id:
- slices = db.session.query(Slice).filter_by(id=slice_id).all()
+ slices = session.query(Slice).filter_by(id=slice_id).all()
if not slices:
return json_error_response(
__("Chart %(id)s not found", id=slice_id), status=404
)
elif table_name and db_name:
table = (
- db.session.query(SqlaTable)
+ session.query(SqlaTable)
.join(models.Database)
.filter(
models.Database.database_name == db_name
@@ -1457,7 +1470,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
status=404,
)
slices = (
- db.session.query(Slice)
+ session.query(Slice)
.filter_by(datasource_id=table.id, datasource_type=table.type)
.all()
)
@@ -1500,16 +1513,17 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
self, class_name: str, obj_id: int, action: str
) -> FlaskResponse:
"""Toggle favorite stars on Slices and Dashboard"""
+ session = db.session()
FavStar = models.FavStar
count = 0
favs = (
- db.session.query(FavStar)
+ session.query(FavStar)
.filter_by(class_name=class_name, obj_id=obj_id, user_id=g.user.get_id())
.all()
)
if action == "select":
if not favs:
- db.session.add(
+ session.add(
FavStar(
class_name=class_name,
obj_id=obj_id,
@@ -1520,10 +1534,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
count = 1
elif action == "unselect":
for fav in favs:
- db.session.delete(fav)
+ session.delete(fav)
else:
count = len(favs)
- db.session.commit()
+ session.commit()
return json_success(json.dumps({"count": count}))
@api
@@ -1536,13 +1550,12 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
logger.warning(
"This API endpoint is deprecated and will be removed in version 1.0.0"
)
+ session = db.session()
Role = ab_models.Role
dash = (
- db.session.query(Dashboard)
- .filter(Dashboard.id == dashboard_id)
- .one_or_none()
+ session.query(Dashboard).filter(Dashboard.id == dashboard_id).one_or_none()
)
- admin_role = db.session.query(Role).filter(Role.name == "Admin").one_or_none()
+ admin_role = session.query(Role).filter(Role.name == "Admin").one_or_none()
if request.method == "GET":
if dash:
@@ -1561,7 +1574,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
)
dash.published = str(request.form["published"]).lower() == "true"
- db.session.commit()
+ session.commit()
return json_success(json.dumps({"published": dash.published}))
@has_access
@@ -1570,7 +1583,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
self, dashboard_id_or_slug: str
) -> FlaskResponse:
"""Server side rendering for a dashboard"""
- qry = db.session.query(Dashboard)
+ session = db.session()
+ qry = session.query(Dashboard)
if dashboard_id_or_slug.isdigit():
qry = qry.filter_by(id=int(dashboard_id_or_slug))
else:
@@ -2028,7 +2042,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
"SQL validation does not support template parameters", status=400
)
- mydb = db.session.query(models.Database).filter_by(id=database_id).one_or_none()
+ session = db.session()
+ mydb = session.query(models.Database).filter_by(id=database_id).one_or_none()
if not mydb:
return json_error_response(
"Database with id {} is missing.".format(database_id), status=400
@@ -2077,6 +2092,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@staticmethod
def _sql_json_async( # pylint: disable=too-many-arguments
+ session: Session,
rendered_query: str,
query: Query,
expand_data: bool,
@@ -2085,6 +2101,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
"""
Send SQL JSON query to celery workers.
+ :param session: SQLAlchemy session object
:param rendered_query: the rendered query to perform by workers
:param query: The query (SQLAlchemy) object
:return: A Flask Response
@@ -2115,7 +2132,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
)
query.status = QueryStatus.FAILED
query.error_message = msg
- db.session.commit()
+ session.commit()
return json_error_response("{}".format(msg))
resp = json_success(
json.dumps(
@@ -2125,11 +2142,12 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
),
status=202,
)
- db.session.commit()
+ session.commit()
return resp
@staticmethod
def _sql_json_sync(
+ _session: Session,
rendered_query: str,
query: Query,
expand_data: bool,
@@ -2223,7 +2241,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
tab_name: str = cast(str, query_params.get("tab"))
status: str = QueryStatus.PENDING if async_flag else QueryStatus.RUNNING
- mydb = db.session.query(models.Database).get(database_id)
+ session = db.session()
+ mydb = session.query(models.Database).get(database_id)
if not mydb:
return json_error_response("Database with id %i is missing.", database_id)
@@ -2254,13 +2273,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
client_id=client_id,
)
try:
- db.session.add(query)
- db.session.flush()
+ session.add(query)
+ session.flush()
query_id = query.id
- db.session.commit() # shouldn't be necessary
+ session.commit() # shouldn't be necessary
except SQLAlchemyError as ex:
logger.error("Errors saving query details %s", str(ex))
- db.session.rollback()
+ session.rollback()
raise Exception(_("Query record was not created as expected."))
if not query_id:
raise Exception(_("Query record was not created as expected."))
@@ -2271,7 +2290,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
query.raise_for_access()
except SupersetSecurityException as ex:
query.status = QueryStatus.FAILED
- db.session.commit()
+ session.commit()
return json_errors_response([ex.error], status=403)
try:
@@ -2304,9 +2323,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
# Async request.
if async_flag:
- return self._sql_json_async(rendered_query, query, expand_data, log_params)
+ return self._sql_json_async(
+ session, rendered_query, query, expand_data, log_params
+ )
# Sync request.
- return self._sql_json_sync(rendered_query, query, expand_data, log_params)
+ return self._sql_json_sync(
+ session, rendered_query, query, expand_data, log_params
+ )
@has_access
@expose("/csv/<client_id>")
@@ -2375,7 +2398,9 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
"""
datasource_id, datasource_type = request.args["datasourceKey"].split("__")
- datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
+ datasource = ConnectorRegistry.get_datasource(
+ datasource_type, datasource_id, db.session
+ )
# Check if datasource exists
if not datasource:
return json_error_response(DATASOURCE_MISSING_ERR)
diff --git a/superset/views/datasource.py b/superset/views/datasource.py
index c2affcb..2ce1102 100644
--- a/superset/views/datasource.py
+++ b/superset/views/datasource.py
@@ -47,7 +47,7 @@ class Datasource(BaseSupersetView):
datasource_type = datasource_dict.get("type")
database_id = datasource_dict["database"].get("id")
orm_datasource = ConnectorRegistry.get_datasource(
- datasource_type, datasource_id
+ datasource_type, datasource_id, db.session
)
orm_datasource.database_id = database_id
@@ -82,7 +82,7 @@ class Datasource(BaseSupersetView):
def get(self, datasource_type: str, datasource_id: int) -> FlaskResponse:
try:
orm_datasource = ConnectorRegistry.get_datasource(
- datasource_type, datasource_id
+ datasource_type, datasource_id, db.session
)
if not orm_datasource.data:
return json_error_response(
@@ -102,7 +102,7 @@ class Datasource(BaseSupersetView):
"""Gets column info from the source system"""
if datasource_type == "druid":
datasource = ConnectorRegistry.get_datasource(
- datasource_type, datasource_id
+ datasource_type, datasource_id, db.session
)
elif datasource_type == "table":
database = (
diff --git a/superset/views/utils.py b/superset/views/utils.py
index dd2aa06..2a8b2cc 100644
--- a/superset/views/utils.py
+++ b/superset/views/utils.py
@@ -105,7 +105,9 @@ def get_viz(
form_data: FormData, datasource_type: str, datasource_id: int, force: bool = False
) -> BaseViz:
viz_type = form_data.get("viz_type", "table")
- datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
+ datasource = ConnectorRegistry.get_datasource(
+ datasource_type, datasource_id, db.session
+ )
viz_obj = viz.viz_types[viz_type](datasource, form_data=form_data, force=force)
return viz_obj
@@ -291,7 +293,8 @@ CONTAINER_TYPES = ["COLUMN", "GRID", "TABS", "TAB", "ROW"]
def get_dashboard_extra_filters(
slice_id: int, dashboard_id: int
) -> List[Dict[str, Any]]:
- dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one_or_none()
+ session = db.session()
+ dashboard = session.query(Dashboard).filter_by(id=dashboard_id).one_or_none()
# is chart in this dashboard?
if (
diff --git a/tests/access_tests.py b/tests/access_tests.py
index 10a4867..d452d13 100644
--- a/tests/access_tests.py
+++ b/tests/access_tests.py
@@ -71,17 +71,13 @@ DB_ACCESS_ROLE = "db_access_role"
SCHEMA_ACCESS_ROLE = "schema_access_role"
-def create_access_request(ds_type, ds_name, role_name, user_name):
+def create_access_request(session, ds_type, ds_name, role_name, user_name):
ds_class = ConnectorRegistry.sources[ds_type]
# TODO: generalize datasource names
if ds_type == "table":
- ds = db.session.query(ds_class).filter(ds_class.table_name == ds_name).first()
+ ds = session.query(ds_class).filter(ds_class.table_name == ds_name).first()
else:
- ds = (
- db.session.query(ds_class)
- .filter(ds_class.datasource_name == ds_name)
- .first()
- )
+ ds = session.query(ds_class).filter(ds_class.datasource_name == ds_name).first()
ds_perm_view = security_manager.find_permission_view_menu(
"datasource_access", ds.perm
)
@@ -93,8 +89,8 @@ def create_access_request(ds_type, ds_name, role_name, user_name):
datasource_type=ds_type,
created_by_fk=security_manager.find_user(username=user_name).id,
)
- db.session.add(access_request)
- db.session.commit()
+ session.add(access_request)
+ session.commit()
return access_request
@@ -130,6 +126,7 @@ class TestRequestAccess(SupersetTestCase):
override_me = security_manager.find_role("override_me")
override_me.permissions = []
db.session.commit()
+ db.session.close()
def test_override_role_permissions_is_admin_only(self):
self.logout()
@@ -214,6 +211,7 @@ class TestRequestAccess(SupersetTestCase):
)
def test_clean_requests_after_role_extend(self):
+ session = db.session
# Case 1. Gamma and gamma2 requested test_role1 on energy_usage access
# Gamma already has role test_role1
@@ -223,10 +221,12 @@ class TestRequestAccess(SupersetTestCase):
# gamma2 and gamma request table_role on energy usage
if app.config["ENABLE_ACCESS_REQUEST"]:
access_request1 = create_access_request(
- "table", "random_time_series", TEST_ROLE_1, "gamma2"
+ session, "table", "random_time_series", TEST_ROLE_1, "gamma2"
)
ds_1_id = access_request1.datasource_id
- create_access_request("table", "random_time_series", TEST_ROLE_1, "gamma")
+ create_access_request(
+ session, "table", "random_time_series", TEST_ROLE_1, "gamma"
+ )
access_requests = self.get_access_requests("gamma", "table", ds_1_id)
self.assertTrue(access_requests)
# gamma gets test_role1
@@ -244,20 +244,22 @@ class TestRequestAccess(SupersetTestCase):
gamma_user.roles.remove(security_manager.find_role("test_role1"))
def test_clean_requests_after_alpha_grant(self):
+ session = db.session
+
# Case 2. Two access requests from gamma and gamma2
# Gamma becomes alpha, gamma2 gets granted
# Check if request by gamma has been deleted
access_request1 = create_access_request(
- "table", "birth_names", TEST_ROLE_1, "gamma"
+ session, "table", "birth_names", TEST_ROLE_1, "gamma"
)
- create_access_request("table", "birth_names", TEST_ROLE_2, "gamma2")
+ create_access_request(session, "table", "birth_names", TEST_ROLE_2, "gamma2")
ds_1_id = access_request1.datasource_id
# gamma becomes alpha
alpha_role = security_manager.find_role("Alpha")
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.append(alpha_role)
- db.session.commit()
+ session.commit()
access_requests = self.get_access_requests("gamma", "table", ds_1_id)
self.assertTrue(access_requests)
self.client.get(
@@ -268,21 +270,23 @@ class TestRequestAccess(SupersetTestCase):
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.remove(security_manager.find_role("Alpha"))
- db.session.commit()
+ session.commit()
def test_clean_requests_after_db_grant(self):
+ session = db.session
+
# Case 3. Two access requests from gamma and gamma2
# Gamma gets database access, gamma2 access request granted
# Check if request by gamma has been deleted
gamma_user = security_manager.find_user(username="gamma")
access_request1 = create_access_request(
- "table", "energy_usage", TEST_ROLE_1, "gamma"
+ session, "table", "energy_usage", TEST_ROLE_1, "gamma"
)
- create_access_request("table", "energy_usage", TEST_ROLE_2, "gamma2")
+ create_access_request(session, "table", "energy_usage", TEST_ROLE_2, "gamma2")
ds_1_id = access_request1.datasource_id
# gamma gets granted database access
- database = db.session.query(models.Database).first()
+ database = session.query(models.Database).first()
security_manager.add_permission_view_menu("database_access", database.perm)
ds_perm_view = security_manager.find_permission_view_menu(
@@ -292,7 +296,7 @@ class TestRequestAccess(SupersetTestCase):
security_manager.find_role(DB_ACCESS_ROLE), ds_perm_view
)
gamma_user.roles.append(security_manager.find_role(DB_ACCESS_ROLE))
- db.session.commit()
+ session.commit()
access_requests = self.get_access_requests("gamma", "table", ds_1_id)
self.assertTrue(access_requests)
# gamma2 request gets fulfilled
@@ -304,21 +308,25 @@ class TestRequestAccess(SupersetTestCase):
self.assertFalse(access_requests)
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.remove(security_manager.find_role(DB_ACCESS_ROLE))
- db.session.commit()
+ session.commit()
def test_clean_requests_after_schema_grant(self):
+ session = db.session
+
# Case 4. Two access requests from gamma and gamma2
# Gamma gets schema access, gamma2 access request granted
# Check if request by gamma has been deleted
gamma_user = security_manager.find_user(username="gamma")
access_request1 = create_access_request(
- "table", "wb_health_population", TEST_ROLE_1, "gamma"
+ session, "table", "wb_health_population", TEST_ROLE_1, "gamma"
+ )
+ create_access_request(
+ session, "table", "wb_health_population", TEST_ROLE_2, "gamma2"
)
- create_access_request("table", "wb_health_population", TEST_ROLE_2, "gamma2")
ds_1_id = access_request1.datasource_id
ds = (
- db.session.query(SqlaTable)
+ session.query(SqlaTable)
.filter_by(table_name="wb_health_population")
.first()
)
@@ -332,7 +340,7 @@ class TestRequestAccess(SupersetTestCase):
security_manager.find_role(SCHEMA_ACCESS_ROLE), schema_perm_view
)
gamma_user.roles.append(security_manager.find_role(SCHEMA_ACCESS_ROLE))
- db.session.commit()
+ session.commit()
# gamma2 request gets fulfilled
self.client.get(
EXTEND_ROLE_REQUEST.format("table", ds_1_id, "gamma2", TEST_ROLE_2)
@@ -343,24 +351,25 @@ class TestRequestAccess(SupersetTestCase):
gamma_user.roles.remove(security_manager.find_role(SCHEMA_ACCESS_ROLE))
ds = (
- db.session.query(SqlaTable)
+ session.query(SqlaTable)
.filter_by(table_name="wb_health_population")
.first()
)
ds.schema = None
- db.session.commit()
+ session.commit()
@mock.patch("superset.utils.core.send_mime_email")
def test_approve(self, mock_send_mime):
if app.config["ENABLE_ACCESS_REQUEST"]:
+ session = db.session
TEST_ROLE_NAME = "table_role"
security_manager.add_role(TEST_ROLE_NAME)
# Case 1. Grant new role to the user.
access_request1 = create_access_request(
- "table", "unicode_test", TEST_ROLE_NAME, "gamma"
+ session, "table", "unicode_test", TEST_ROLE_NAME, "gamma"
)
ds_1_id = access_request1.datasource_id
self.get_resp(
@@ -395,7 +404,7 @@ class TestRequestAccess(SupersetTestCase):
# Case 2. Extend the role to have access to the table
access_request2 = create_access_request(
- "table", "energy_usage", TEST_ROLE_NAME, "gamma"
+ session, "table", "energy_usage", TEST_ROLE_NAME, "gamma"
)
ds_2_id = access_request2.datasource_id
energy_usage_perm = access_request2.datasource.perm
@@ -439,7 +448,7 @@ class TestRequestAccess(SupersetTestCase):
security_manager.add_role("druid_role")
access_request3 = create_access_request(
- "druid", "druid_ds_1", "druid_role", "gamma"
+ session, "druid", "druid_ds_1", "druid_role", "gamma"
)
self.get_resp(
GRANT_ROLE_REQUEST.format(
@@ -454,7 +463,7 @@ class TestRequestAccess(SupersetTestCase):
# Case 4. Extend the role to have access to the druid datasource
access_request4 = create_access_request(
- "druid", "druid_ds_2", "druid_role", "gamma"
+ session, "druid", "druid_ds_2", "druid_role", "gamma"
)
druid_ds_2_perm = access_request4.datasource.perm
@@ -474,18 +483,19 @@ class TestRequestAccess(SupersetTestCase):
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.remove(security_manager.find_role("druid_role"))
gamma_user.roles.remove(security_manager.find_role(TEST_ROLE_NAME))
- db.session.delete(security_manager.find_role("druid_role"))
- db.session.delete(security_manager.find_role(TEST_ROLE_NAME))
- db.session.commit()
+ session.delete(security_manager.find_role("druid_role"))
+ session.delete(security_manager.find_role(TEST_ROLE_NAME))
+ session.commit()
def test_request_access(self):
if app.config["ENABLE_ACCESS_REQUEST"]:
+ session = db.session
self.logout()
self.login(username="gamma")
gamma_user = security_manager.find_user(username="gamma")
security_manager.add_role("dummy_role")
gamma_user.roles.append(security_manager.find_role("dummy_role"))
- db.session.commit()
+ session.commit()
ACCESS_REQUEST = (
"/superset/request_access?"
@@ -501,7 +511,7 @@ class TestRequestAccess(SupersetTestCase):
# Request table access, there are no roles have this table.
table1 = (
- db.session.query(SqlaTable)
+ session.query(SqlaTable)
.filter_by(table_name="random_time_series")
.first()
)
@@ -516,7 +526,7 @@ class TestRequestAccess(SupersetTestCase):
# Request access, roles exist that contains the table.
# add table to the existing roles
table3 = (
- db.session.query(SqlaTable).filter_by(table_name="energy_usage").first()
+ session.query(SqlaTable).filter_by(table_name="energy_usage").first()
)
table_3_id = table3.id
table3_perm = table3.perm
@@ -535,7 +545,7 @@ class TestRequestAccess(SupersetTestCase):
"datasource_access", table3_perm
),
)
- db.session.commit()
+ session.commit()
self.get_resp(ACCESS_REQUEST.format("table", table_3_id, "go"))
access_request3 = self.get_access_requests("gamma", "table", table_3_id)
@@ -549,7 +559,7 @@ class TestRequestAccess(SupersetTestCase):
# Request druid access, there are no roles have this table.
druid_ds_4 = (
- db.session.query(DruidDatasource)
+ session.query(DruidDatasource)
.filter_by(datasource_name="druid_ds_1")
.first()
)
@@ -564,7 +574,7 @@ class TestRequestAccess(SupersetTestCase):
# Case 5. Roles exist that contains the druid datasource.
# add druid ds to the existing roles
druid_ds_5 = (
- db.session.query(DruidDatasource)
+ session.query(DruidDatasource)
.filter_by(datasource_name="druid_ds_2")
.first()
)
@@ -585,7 +595,7 @@ class TestRequestAccess(SupersetTestCase):
"datasource_access", druid_ds_5_perm
),
)
- db.session.commit()
+ session.commit()
self.get_resp(ACCESS_REQUEST.format("druid", druid_ds_5_id, "go"))
access_request5 = self.get_access_requests("gamma", "druid", druid_ds_5_id)
@@ -600,7 +610,7 @@ class TestRequestAccess(SupersetTestCase):
# cleanup
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.remove(security_manager.find_role("dummy_role"))
- db.session.commit()
+ session.commit()
if __name__ == "__main__":
diff --git a/tests/alerts_tests.py b/tests/alerts_tests.py
index 0720581..c78847c 100644
--- a/tests/alerts_tests.py
+++ b/tests/alerts_tests.py
@@ -32,118 +32,112 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
-def setup_module():
+@pytest.yield_fixture(scope="module")
+def setup_database():
with app.app_context():
slice_id = db.session.query(Slice).all()[0].id
database_id = utils.get_example_database().id
- alerts = [
- Alert(
- id=1,
- label="alert_1",
- active=True,
- crontab="*/1 * * * *",
- sql="SELECT 0",
- alert_type="email",
- slice_id=slice_id,
- database_id=database_id,
- ),
- Alert(
- id=2,
- label="alert_2",
- active=True,
- crontab="*/1 * * * *",
- sql="SELECT 55",
- alert_type="email",
- slice_id=slice_id,
- database_id=database_id,
- ),
- Alert(
- id=3,
- label="alert_3",
- active=False,
- crontab="*/1 * * * *",
- sql="UPDATE 55",
- alert_type="email",
- slice_id=slice_id,
- database_id=database_id,
- ),
- Alert(id=4, active=False, label="alert_4", database_id=-1),
- Alert(id=5, active=False, label="alert_5", database_id=database_id),
- ]
-
- db.session.bulk_save_objects(alerts)
- db.session.commit()
+ alert1 = Alert(
+ id=1,
+ label="alert_1",
+ active=True,
+ crontab="*/1 * * * *",
+ sql="SELECT 0",
+ alert_type="email",
+ slice_id=slice_id,
+ database_id=database_id,
+ )
+ alert2 = Alert(
+ id=2,
+ label="alert_2",
+ active=True,
+ crontab="*/1 * * * *",
+ sql="SELECT 55",
+ alert_type="email",
+ slice_id=slice_id,
+ database_id=database_id,
+ )
+ alert3 = Alert(
+ id=3,
+ label="alert_3",
+ active=False,
+ crontab="*/1 * * * *",
+ sql="UPDATE 55",
+ alert_type="email",
+ slice_id=slice_id,
+ database_id=database_id,
+ )
+ alert4 = Alert(id=4, active=False, label="alert_4", database_id=-1)
+ alert5 = Alert(id=5, active=False, label="alert_5", database_id=database_id)
+ for num in range(1, 6):
+ eval(f"db.session.add(alert{num})")
+ db.session.commit()
+ yield db.session
-def teardown_module():
- with app.app_context():
db.session.query(AlertLog).delete()
db.session.query(Alert).delete()
@patch("superset.tasks.schedules.deliver_alert")
@patch("superset.tasks.schedules.logging.Logger.error")
-def test_run_alert_query(mock_error, mock_deliver_alert):
- with app.app_context():
- run_alert_query(db.session.query(Alert).filter_by(id=1).one().id)
- alert1 = db.session.query(Alert).filter_by(id=1).one()
- assert mock_deliver_alert.call_count == 0
- assert len(alert1.logs) == 1
- assert alert1.logs[0].alert_id == 1
- assert alert1.logs[0].state == "pass"
-
- run_alert_query(db.session.query(Alert).filter_by(id=2).one().id)
- alert2 = db.session.query(Alert).filter_by(id=2).one()
- assert mock_deliver_alert.call_count == 1
- assert len(alert2.logs) == 1
- assert alert2.logs[0].alert_id == 2
- assert alert2.logs[0].state == "trigger"
-
- run_alert_query(db.session.query(Alert).filter_by(id=3).one().id)
- alert3 = db.session.query(Alert).filter_by(id=3).one()
- assert mock_deliver_alert.call_count == 1
- assert mock_error.call_count == 2
- assert len(alert3.logs) == 1
- assert alert3.logs[0].alert_id == 3
- assert alert3.logs[0].state == "error"
-
- run_alert_query(db.session.query(Alert).filter_by(id=4).one().id)
- assert mock_deliver_alert.call_count == 1
- assert mock_error.call_count == 3
-
- run_alert_query(db.session.query(Alert).filter_by(id=5).one().id)
- assert mock_deliver_alert.call_count == 1
- assert mock_error.call_count == 4
+def test_run_alert_query(mock_error, mock_deliver, setup_database):
+ database = setup_database
+ run_alert_query(database.query(Alert).filter_by(id=1).one().id, database)
+ alert1 = database.query(Alert).filter_by(id=1).one()
+ assert mock_deliver.call_count == 0
+ assert len(alert1.logs) == 1
+ assert alert1.logs[0].alert_id == 1
+ assert alert1.logs[0].state == "pass"
+
+ run_alert_query(database.query(Alert).filter_by(id=2).one().id, database)
+ alert2 = database.query(Alert).filter_by(id=2).one()
+ assert mock_deliver.call_count == 1
+ assert len(alert2.logs) == 1
+ assert alert2.logs[0].alert_id == 2
+ assert alert2.logs[0].state == "trigger"
+
+ run_alert_query(database.query(Alert).filter_by(id=3).one().id, database)
+ alert3 = database.query(Alert).filter_by(id=3).one()
+ assert mock_deliver.call_count == 1
+ assert mock_error.call_count == 2
+ assert len(alert3.logs) == 1
+ assert alert3.logs[0].alert_id == 3
+ assert alert3.logs[0].state == "error"
+
+ run_alert_query(database.query(Alert).filter_by(id=4).one().id, database)
+ assert mock_deliver.call_count == 1
+ assert mock_error.call_count == 3
+
+ run_alert_query(database.query(Alert).filter_by(id=5).one().id, database)
+ assert mock_deliver.call_count == 1
+ assert mock_error.call_count == 4
@patch("superset.tasks.schedules.deliver_alert")
@patch("superset.tasks.schedules.run_alert_query")
-def test_schedule_alert_query(mock_run_alert, mock_deliver_alert):
- with app.app_context():
- active_alert = db.session.query(Alert).filter_by(id=1).one()
- inactive_alert = db.session.query(Alert).filter_by(id=3).one()
-
- # Test that inactive alerts are no processed
- schedule_alert_query(
- report_type=ScheduleType.alert, schedule_id=inactive_alert.id
- )
- assert mock_run_alert.call_count == 0
- assert mock_deliver_alert.call_count == 0
-
- # Test that active alerts with no recipients passed in are processed regularly
- schedule_alert_query(
- report_type=ScheduleType.alert, schedule_id=active_alert.id
- )
- assert mock_run_alert.call_count == 1
- assert mock_deliver_alert.call_count == 0
-
- # Test that active alerts sent as a test are delivered immediately
- schedule_alert_query(
- report_type=ScheduleType.alert,
- schedule_id=active_alert.id,
- recipients="testing@email.com",
- is_test_alert=True,
- )
- assert mock_run_alert.call_count == 1
- assert mock_deliver_alert.call_count == 1
+def test_schedule_alert_query(mock_run_alert, mock_deliver_alert, setup_database):
+ database = setup_database
+ active_alert = database.query(Alert).filter_by(id=1).one()
+ inactive_alert = database.query(Alert).filter_by(id=3).one()
+
+ # Test that inactive alerts are no processed
+ schedule_alert_query(report_type=ScheduleType.alert, schedule_id=inactive_alert.id)
+ assert mock_run_alert.call_count == 0
+ assert mock_deliver_alert.call_count == 0
+
+ # Test that active alerts with no recipients passed in are processed regularly
+ schedule_alert_query(report_type=ScheduleType.alert, schedule_id=active_alert.id)
+ assert mock_run_alert.call_count == 1
+ assert mock_deliver_alert.call_count == 0
+
+ # Test that active alerts sent as a test are delivered immediately
+ schedule_alert_query(
+ report_type=ScheduleType.alert,
+ schedule_id=active_alert.id,
+ recipients="testing@email.com",
+ is_test_alert=True,
+ )
+ assert mock_run_alert.call_count == 1
+ assert mock_deliver_alert.call_count == 1
diff --git a/tests/base_tests.py b/tests/base_tests.py
index c74378f..e0a20a4 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -25,6 +25,7 @@ import pandas as pd
from flask import Response
from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase
+from sqlalchemy.orm import Session
from tests.test_app import app
from superset.sql_parse import CtasMethod
@@ -103,25 +104,24 @@ class SupersetTestCase(TestCase):
# create druid cluster and druid datasources
with app.app_context():
+ session = db.session
cluster = (
- db.session.query(DruidCluster)
- .filter_by(cluster_name="druid_test")
- .first()
+ session.query(DruidCluster).filter_by(cluster_name="druid_test").first()
)
if not cluster:
cluster = DruidCluster(cluster_name="druid_test")
- db.session.add(cluster)
- db.session.commit()
+ session.add(cluster)
+ session.commit()
druid_datasource1 = DruidDatasource(
datasource_name="druid_ds_1", cluster=cluster
)
- db.session.add(druid_datasource1)
+ session.add(druid_datasource1)
druid_datasource2 = DruidDatasource(
datasource_name="druid_ds_2", cluster=cluster
)
- db.session.add(druid_datasource2)
- db.session.commit()
+ session.add(druid_datasource2)
+ session.commit()
@staticmethod
def get_table_by_id(table_id: int) -> SqlaTable:
@@ -135,23 +135,25 @@ class SupersetTestCase(TestCase):
except ImportError:
return False
- def get_or_create(self, cls, criteria, **kwargs):
- obj = db.session.query(cls).filter_by(**criteria).first()
+ def get_or_create(self, cls, criteria, session, **kwargs):
+ obj = session.query(cls).filter_by(**criteria).first()
if not obj:
obj = cls(**criteria)
obj.__dict__.update(**kwargs)
- db.session.add(obj)
- db.session.commit()
+ session.add(obj)
+ session.commit()
return obj
def login(self, username="admin", password="general"):
resp = self.get_resp("/login/", data=dict(username=username, password=password))
self.assertNotIn("User confirmation needed", resp)
- def get_slice(self, slice_name: str, expunge_from_session: bool = True) -> Slice:
- slc = db.session.query(Slice).filter_by(slice_name=slice_name).one()
+ def get_slice(
+ self, slice_name: str, session: Session, expunge_from_session: bool = True
+ ) -> Slice:
+ slc = session.query(Slice).filter_by(slice_name=slice_name).one()
if expunge_from_session:
- db.session.expunge_all()
+ session.expunge_all()
return slc
@staticmethod
@@ -300,6 +302,7 @@ class SupersetTestCase(TestCase):
return self.get_or_create(
cls=models.Database,
criteria={"database_name": database_name},
+ session=db.session,
sqlalchemy_uri="sqlite:///:memory:",
id=db_id,
extra=extra,
@@ -321,6 +324,7 @@ class SupersetTestCase(TestCase):
return self.get_or_create(
cls=models.Database,
criteria={"database_name": database_name},
+ session=db.session,
sqlalchemy_uri="presto://user@host:8080/hive",
id=db_id,
)
diff --git a/tests/celery_tests.py b/tests/celery_tests.py
index 68a7213..53190cb 100644
--- a/tests/celery_tests.py
+++ b/tests/celery_tests.py
@@ -99,13 +99,15 @@ class TestAppContext(SupersetTestCase):
class TestCelery(SupersetTestCase):
def get_query_by_name(self, sql):
- query = db.session.query(Query).filter_by(sql=sql).first()
- db.session.close()
+ session = db.session
+ query = session.query(Query).filter_by(sql=sql).first()
+ session.close()
return query
def get_query_by_id(self, id):
- query = db.session.query(Query).filter_by(id=id).first()
- db.session.close()
+ session = db.session
+ query = session.query(Query).filter_by(id=id).first()
+ session.close()
return query
@classmethod
diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py
index 0820518..5048a0a 100644
--- a/tests/charts/api_tests.py
+++ b/tests/charts/api_tests.py
@@ -58,7 +58,9 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
for owner in owners:
user = db.session.query(security_manager.user_model).get(owner)
obj_owners.append(user)
- datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
+ datasource = ConnectorRegistry.get_datasource(
+ datasource_type, datasource_id, db.session
+ )
slice = Slice(
slice_name=slice_name,
datasource_id=datasource.id,
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 2793795..ade0095 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -100,7 +100,7 @@ class TestCore(SupersetTestCase):
def test_slice_endpoint(self):
self.login(username="admin")
- slc = self.get_slice("Girls")
+ slc = self.get_slice("Girls", db.session)
resp = self.get_resp("/superset/slice/{}/".format(slc.id))
assert "Time Column" in resp
assert "List Roles" in resp
@@ -114,7 +114,7 @@ class TestCore(SupersetTestCase):
def test_viz_cache_key(self):
self.login(username="admin")
- slc = self.get_slice("Girls")
+ slc = self.get_slice("Girls", db.session)
viz = slc.viz
qobj = viz.query_obj()
@@ -233,7 +233,7 @@ class TestCore(SupersetTestCase):
def test_save_slice(self):
self.login(username="admin")
slice_name = f"Energy Sankey"
- slice_id = self.get_slice(slice_name).id
+ slice_id = self.get_slice(slice_name, db.session).id
copy_name_prefix = "Test Sankey"
copy_name = f"{copy_name_prefix}[save]{random.random()}"
tbl_id = self.table_ids.get("energy_usage")
@@ -299,7 +299,7 @@ class TestCore(SupersetTestCase):
def test_filter_endpoint(self):
self.login(username="admin")
slice_name = "Energy Sankey"
- slice_id = self.get_slice(slice_name).id
+ slice_id = self.get_slice(slice_name, db.session).id
db.session.commit()
tbl_id = self.table_ids.get("energy_usage")
table = db.session.query(SqlaTable).filter(SqlaTable.id == tbl_id)
@@ -319,7 +319,9 @@ class TestCore(SupersetTestCase):
def test_slice_data(self):
# slice data should have some required attributes
self.login(username="admin")
- slc = self.get_slice(slice_name="Girls", expunge_from_session=False)
+ slc = self.get_slice(
+ slice_name="Girls", session=db.session, expunge_from_session=False
+ )
slc_data_attributes = slc.data.keys()
assert "changed_on" in slc_data_attributes
assert "modified" in slc_data_attributes
@@ -370,7 +372,9 @@ class TestCore(SupersetTestCase):
self.assertEqual(data, [])
# make user owner of slice and verify that endpoint returns said slice
- slc = self.get_slice(slice_name=slice_name, expunge_from_session=False)
+ slc = self.get_slice(
+ slice_name=slice_name, session=db.session, expunge_from_session=False
+ )
slc.owners = [user]
db.session.merge(slc)
db.session.commit()
@@ -381,7 +385,9 @@ class TestCore(SupersetTestCase):
self.assertEqual(data[0]["title"], slice_name)
# remove ownership and ensure user no longer gets slice
- slc = self.get_slice(slice_name=slice_name, expunge_from_session=False)
+ slc = self.get_slice(
+ slice_name=slice_name, session=db.session, expunge_from_session=False
+ )
slc.owners = []
db.session.merge(slc)
db.session.commit()
@@ -559,7 +565,7 @@ class TestCore(SupersetTestCase):
db.session.commit()
def test_warm_up_cache(self):
- slc = self.get_slice("Girls")
+ slc = self.get_slice("Girls", db.session)
data = self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(slc.id))
self.assertEqual(
data, [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}]
@@ -784,7 +790,7 @@ class TestCore(SupersetTestCase):
def test_user_profile(self, username="admin"):
self.login(username=username)
- slc = self.get_slice("Girls")
+ slc = self.get_slice("Girls", db.session)
# Setting some faves
url = f"/superset/favstar/Slice/{slc.id}/select/"
diff --git a/tests/database_api_tests.py b/tests/database_api_tests.py
index 7126c25..49ede7b 100644
--- a/tests/database_api_tests.py
+++ b/tests/database_api_tests.py
@@ -178,11 +178,12 @@ class TestDatabaseApi(SupersetTestCase):
"""
Database API: Test get select star with datasource access
"""
+ session = db.session
table = SqlaTable(
schema="main", table_name="ab_permission", database=get_main_database()
)
- db.session.add(table)
- db.session.commit()
+ session.add(table)
+ session.commit()
tmp_table_perm = security_manager.find_permission_view_menu(
"datasource_access", table.get_perm()
diff --git a/tests/datasets/api_tests.py b/tests/datasets/api_tests.py
index 7e078fd..08c8d5a 100644
--- a/tests/datasets/api_tests.py
+++ b/tests/datasets/api_tests.py
@@ -156,7 +156,7 @@ class TestDatasetApi(SupersetTestCase):
"template_params": None,
}
for key, value in expected_result.items():
- self.assertEqual(response["result"][key], value)
+ self.assertEqual(response["result"][key], expected_result[key])
self.assertEqual(len(response["result"]["columns"]), 8)
self.assertEqual(len(response["result"]["metrics"]), 2)
@@ -721,7 +721,10 @@ class TestDatasetApi(SupersetTestCase):
)
cli_export = export_to_dict(
- recursive=True, back_references=False, include_defaults=False,
+ session=db.session,
+ recursive=True,
+ back_references=False,
+ include_defaults=False,
)
cli_export_tables = cli_export["databases"][0]["tables"]
expected_response = []
diff --git a/tests/dict_import_export_tests.py b/tests/dict_import_export_tests.py
index 725ffdf..dc0b8d8 100644
--- a/tests/dict_import_export_tests.py
+++ b/tests/dict_import_export_tests.py
@@ -47,13 +47,14 @@ class TestDictImportExport(SupersetTestCase):
def delete_imports(cls):
with app.app_context():
# Imported data clean up
- for table in db.session.query(SqlaTable):
+ session = db.session
+ for table in session.query(SqlaTable):
if DBREF in table.params_dict:
- db.session.delete(table)
- for datasource in db.session.query(DruidDatasource):
+ session.delete(table)
+ for datasource in session.query(DruidDatasource):
if DBREF in datasource.params_dict:
- db.session.delete(datasource)
- db.session.commit()
+ session.delete(datasource)
+ session.commit()
@classmethod
def setUpClass(cls):
@@ -89,7 +90,9 @@ class TestDictImportExport(SupersetTestCase):
def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]):
cluster_name = "druid_test"
- cluster = self.get_or_create(DruidCluster, {"cluster_name": cluster_name})
+ cluster = self.get_or_create(
+ DruidCluster, {"cluster_name": cluster_name}, db.session
+ )
name = "{0}{1}".format(NAME_PREFIX, name)
params = {DBREF: id, "database_name": cluster_name}
@@ -156,7 +159,7 @@ class TestDictImportExport(SupersetTestCase):
def test_import_table_no_metadata(self):
table, dict_table = self.create_table("pure_table", id=ID_PREFIX + 1)
- new_table = SqlaTable.import_from_dict(dict_table)
+ new_table = SqlaTable.import_from_dict(db.session, dict_table)
db.session.commit()
imported_id = new_table.id
imported = self.get_table_by_id(imported_id)
@@ -170,7 +173,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["col1"],
metric_names=["metric1"],
)
- imported_table = SqlaTable.import_from_dict(dict_table)
+ imported_table = SqlaTable.import_from_dict(db.session, dict_table)
db.session.commit()
imported = self.get_table_by_id(imported_table.id)
self.assert_table_equals(table, imported)
@@ -186,7 +189,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["c1", "c2"],
metric_names=["m1", "m2"],
)
- imported_table = SqlaTable.import_from_dict(dict_table)
+ imported_table = SqlaTable.import_from_dict(db.session, dict_table)
db.session.commit()
imported = self.get_table_by_id(imported_table.id)
self.assert_table_equals(table, imported)
@@ -196,7 +199,7 @@ class TestDictImportExport(SupersetTestCase):
table, dict_table = self.create_table(
"table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
)
- imported_table = SqlaTable.import_from_dict(dict_table)
+ imported_table = SqlaTable.import_from_dict(db.session, dict_table)
db.session.commit()
table_over, dict_table_over = self.create_table(
"table_override",
@@ -204,7 +207,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
- imported_over_table = SqlaTable.import_from_dict(dict_table_over)
+ imported_over_table = SqlaTable.import_from_dict(db.session, dict_table_over)
db.session.commit()
imported_over = self.get_table_by_id(imported_over_table.id)
@@ -224,7 +227,7 @@ class TestDictImportExport(SupersetTestCase):
table, dict_table = self.create_table(
"table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
)
- imported_table = SqlaTable.import_from_dict(dict_table)
+ imported_table = SqlaTable.import_from_dict(db.session, dict_table)
db.session.commit()
table_over, dict_table_over = self.create_table(
"table_override",
@@ -233,7 +236,7 @@ class TestDictImportExport(SupersetTestCase):
metric_names=["new_metric1"],
)
imported_over_table = SqlaTable.import_from_dict(
- dict_rep=dict_table_over, sync=["metrics", "columns"]
+ session=db.session, dict_rep=dict_table_over, sync=["metrics", "columns"]
)
db.session.commit()
@@ -257,7 +260,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
- imported_table = SqlaTable.import_from_dict(dict_table)
+ imported_table = SqlaTable.import_from_dict(db.session, dict_table)
db.session.commit()
copy_table, dict_copy_table = self.create_table(
"copy_cat",
@@ -265,7 +268,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
- imported_copy_table = SqlaTable.import_from_dict(dict_copy_table)
+ imported_copy_table = SqlaTable.import_from_dict(db.session, dict_copy_table)
db.session.commit()
self.assertEqual(imported_table.id, imported_copy_table.id)
self.assert_table_equals(copy_table, self.get_table_by_id(imported_table.id))
@@ -278,7 +281,10 @@ class TestDictImportExport(SupersetTestCase):
self.delete_fake_db()
cli_export = export_to_dict(
- recursive=True, back_references=False, include_defaults=False,
+ session=db.session,
+ recursive=True,
+ back_references=False,
+ include_defaults=False,
)
self.get_resp("/login/", data=dict(username="admin", password="general"))
resp = self.get_resp(
@@ -297,7 +303,7 @@ class TestDictImportExport(SupersetTestCase):
datasource, dict_datasource = self.create_druid_datasource(
"pure_druid", id=ID_PREFIX + 1
)
- imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
+ imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
db.session.commit()
imported = self.get_datasource(imported_cluster.id)
self.assert_datasource_equals(datasource, imported)
@@ -309,7 +315,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["col1"],
metric_names=["metric1"],
)
- imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
+ imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
db.session.commit()
imported = self.get_datasource(imported_cluster.id)
self.assert_datasource_equals(datasource, imported)
@@ -325,7 +331,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["c1", "c2"],
metric_names=["m1", "m2"],
)
- imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
+ imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
db.session.commit()
imported = self.get_datasource(imported_cluster.id)
self.assert_datasource_equals(datasource, imported)
@@ -334,7 +340,7 @@ class TestDictImportExport(SupersetTestCase):
datasource, dict_datasource = self.create_druid_datasource(
"druid_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
)
- imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
+ imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
db.session.commit()
table_over, table_over_dict = self.create_druid_datasource(
"druid_override",
@@ -342,7 +348,9 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
- imported_over_cluster = DruidDatasource.import_from_dict(table_over_dict)
+ imported_over_cluster = DruidDatasource.import_from_dict(
+ db.session, table_over_dict
+ )
db.session.commit()
imported_over = self.get_datasource(imported_over_cluster.id)
self.assertEqual(imported_cluster.id, imported_over.id)
@@ -358,7 +366,7 @@ class TestDictImportExport(SupersetTestCase):
datasource, dict_datasource = self.create_druid_datasource(
"druid_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
)
- imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
+ imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
db.session.commit()
table_over, table_over_dict = self.create_druid_datasource(
"druid_override",
@@ -367,7 +375,7 @@ class TestDictImportExport(SupersetTestCase):
metric_names=["new_metric1"],
)
imported_over_cluster = DruidDatasource.import_from_dict(
- dict_rep=table_over_dict, sync=["metrics", "columns"]
+ session=db.session, dict_rep=table_over_dict, sync=["metrics", "columns"]
) # syncing metrics and columns
db.session.commit()
imported_over = self.get_datasource(imported_over_cluster.id)
@@ -387,7 +395,9 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
- imported = DruidDatasource.import_from_dict(dict_rep=dict_datasource)
+ imported = DruidDatasource.import_from_dict(
+ session=db.session, dict_rep=dict_datasource
+ )
db.session.commit()
copy_datasource, dict_cp_datasource = self.create_druid_datasource(
"copy_cat",
@@ -395,7 +405,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
- imported_copy = DruidDatasource.import_from_dict(dict_cp_datasource)
+ imported_copy = DruidDatasource.import_from_dict(db.session, dict_cp_datasource)
db.session.commit()
self.assertEqual(imported.id, imported_copy.id)
diff --git a/tests/druid_tests.py b/tests/druid_tests.py
index c757671..648eb32 100644
--- a/tests/druid_tests.py
+++ b/tests/druid_tests.py
@@ -212,7 +212,9 @@ class TestDruid(SupersetTestCase):
def test_druid_sync_from_config(self):
CLUSTER_NAME = "new_druid"
self.login()
- cluster = self.get_or_create(DruidCluster, {"cluster_name": CLUSTER_NAME})
+ cluster = self.get_or_create(
+ DruidCluster, {"cluster_name": CLUSTER_NAME}, db.session
+ )
db.session.merge(cluster)
db.session.commit()
@@ -300,12 +302,15 @@ class TestDruid(SupersetTestCase):
@unittest.skipUnless(app.config["DRUID_IS_ACTIVE"], "DRUID_IS_ACTIVE is false")
def test_filter_druid_datasource(self):
CLUSTER_NAME = "new_druid"
- cluster = self.get_or_create(DruidCluster, {"cluster_name": CLUSTER_NAME})
+ cluster = self.get_or_create(
+ DruidCluster, {"cluster_name": CLUSTER_NAME}, db.session
+ )
db.session.merge(cluster)
gamma_ds = self.get_or_create(
DruidDatasource,
{"datasource_name": "datasource_for_gamma", "cluster": cluster},
+ db.session,
)
gamma_ds.cluster = cluster
db.session.merge(gamma_ds)
@@ -313,6 +318,7 @@ class TestDruid(SupersetTestCase):
no_gamma_ds = self.get_or_create(
DruidDatasource,
{"datasource_name": "datasource_not_for_gamma", "cluster": cluster},
+ db.session,
)
no_gamma_ds.cluster = cluster
db.session.merge(no_gamma_ds)
diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py
index fc6ee51..e772d16 100644
--- a/tests/import_export_tests.py
+++ b/tests/import_export_tests.py
@@ -46,19 +46,20 @@ class TestImportExport(SupersetTestCase):
def delete_imports(cls):
with app.app_context():
# Imported data clean up
- for slc in db.session.query(Slice):
+ session = db.session
+ for slc in session.query(Slice):
if "remote_id" in slc.params_dict:
- db.session.delete(slc)
- for dash in db.session.query(Dashboard):
+ session.delete(slc)
+ for dash in session.query(Dashboard):
if "remote_id" in dash.params_dict:
- db.session.delete(dash)
- for table in db.session.query(SqlaTable):
+ session.delete(dash)
+ for table in session.query(SqlaTable):
if "remote_id" in table.params_dict:
- db.session.delete(table)
- for datasource in db.session.query(DruidDatasource):
+ session.delete(table)
+ for datasource in session.query(DruidDatasource):
if "remote_id" in datasource.params_dict:
- db.session.delete(datasource)
- db.session.commit()
+ session.delete(datasource)
+ session.commit()
@classmethod
def setUpClass(cls):
@@ -125,7 +126,9 @@ class TestImportExport(SupersetTestCase):
def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]):
cluster_name = "druid_test"
- cluster = self.get_or_create(DruidCluster, {"cluster_name": cluster_name})
+ cluster = self.get_or_create(
+ DruidCluster, {"cluster_name": cluster_name}, db.session
+ )
params = {"remote_id": id, "database_name": cluster_name}
datasource = DruidDatasource(
diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py
index e016b13..f816bcd 100644
--- a/tests/query_context_tests.py
+++ b/tests/query_context_tests.py
@@ -83,6 +83,7 @@ class TestQueryContext(SupersetTestCase):
datasource = ConnectorRegistry.get_datasource(
datasource_type=payload["datasource"]["type"],
datasource_id=payload["datasource"]["id"],
+ session=db.session,
)
description_original = datasource.description
datasource.description = "temporary description"
diff --git a/tests/security_tests.py b/tests/security_tests.py
index a161ada..60d20fd 100644
--- a/tests/security_tests.py
+++ b/tests/security_tests.py
@@ -69,8 +69,9 @@ class TestRolePermission(SupersetTestCase):
"""Testing export role permissions."""
def setUp(self):
+ session = db.session
security_manager.add_role(SCHEMA_ACCESS_ROLE)
- db.session.commit()
+ session.commit()
ds = (
db.session.query(SqlaTable)
@@ -81,7 +82,7 @@ class TestRolePermission(SupersetTestCase):
ds.schema_perm = ds.get_schema_perm()
ds_slices = (
- db.session.query(Slice)
+ session.query(Slice)
.filter_by(datasource_type="table")
.filter_by(datasource_id=ds.id)
.all()
@@ -91,11 +92,12 @@ class TestRolePermission(SupersetTestCase):
create_schema_perm("[examples].[temp_schema]")
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.append(security_manager.find_role(SCHEMA_ACCESS_ROLE))
- db.session.commit()
+ session.commit()
def tearDown(self):
+ session = db.session
ds = (
- db.session.query(SqlaTable)
+ session.query(SqlaTable)
.filter_by(table_name="wb_health_population")
.first()
)
@@ -103,7 +105,7 @@ class TestRolePermission(SupersetTestCase):
ds.schema = None
ds.schema_perm = None
ds_slices = (
- db.session.query(Slice)
+ session.query(Slice)
.filter_by(datasource_type="table")
.filter_by(datasource_id=ds.id)
.all()
@@ -112,20 +114,21 @@ class TestRolePermission(SupersetTestCase):
s.schema_perm = None
delete_schema_perm(schema_perm)
- db.session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE))
- db.session.commit()
+ session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE))
+ session.commit()
def test_set_perm_sqla_table(self):
+ session = db.session
table = SqlaTable(
schema="tmp_schema",
table_name="tmp_perm_table",
database=get_example_database(),
)
- db.session.add(table)
- db.session.commit()
+ session.add(table)
+ session.commit()
stored_table = (
- db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one()
+ session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one()
)
self.assertEqual(
stored_table.perm, f"[examples].[tmp_perm_table](id:{stored_table.id})"
@@ -144,9 +147,9 @@ class TestRolePermission(SupersetTestCase):
# table name change
stored_table.table_name = "tmp_perm_table_v2"
- db.session.commit()
+ session.commit()
stored_table = (
- db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
+ session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
)
self.assertEqual(
stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})"
@@ -166,9 +169,9 @@ class TestRolePermission(SupersetTestCase):
# schema name change
stored_table.schema = "tmp_schema_v2"
- db.session.commit()
+ session.commit()
stored_table = (
- db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
+ session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
)
self.assertEqual(
stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})"
@@ -188,13 +191,13 @@ class TestRolePermission(SupersetTestCase):
# database change
new_db = Database(sqlalchemy_uri="some_uri", database_name="tmp_db")
- db.session.add(new_db)
+ session.add(new_db)
stored_table.database = (
- db.session.query(Database).filter_by(database_name="tmp_db").one()
+ session.query(Database).filter_by(database_name="tmp_db").one()
)
- db.session.commit()
+ session.commit()
stored_table = (
- db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
+ session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
)
self.assertEqual(
stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})"
@@ -214,9 +217,9 @@ class TestRolePermission(SupersetTestCase):
# no schema
stored_table.schema = None
- db.session.commit()
+ session.commit()
stored_table = (
- db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
+ session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
)
self.assertEqual(
stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})"
@@ -228,25 +231,26 @@ class TestRolePermission(SupersetTestCase):
)
self.assertIsNone(stored_table.schema_perm)
- db.session.delete(new_db)
- db.session.delete(stored_table)
- db.session.commit()
+ session.delete(new_db)
+ session.delete(stored_table)
+ session.commit()
def test_set_perm_druid_datasource(self):
+ session = db.session
druid_cluster = (
- db.session.query(DruidCluster).filter_by(cluster_name="druid_test").one()
+ session.query(DruidCluster).filter_by(cluster_name="druid_test").one()
)
datasource = DruidDatasource(
datasource_name="tmp_datasource",
cluster=druid_cluster,
cluster_id=druid_cluster.id,
)
- db.session.add(datasource)
- db.session.commit()
+ session.add(datasource)
+ session.commit()
# store without a schema
stored_datasource = (
- db.session.query(DruidDatasource)
+ session.query(DruidDatasource)
.filter_by(datasource_name="tmp_datasource")
.one()
)
@@ -263,7 +267,7 @@ class TestRolePermission(SupersetTestCase):
# store with a schema
stored_datasource.datasource_name = "tmp_schema.tmp_datasource"
- db.session.commit()
+ session.commit()
self.assertEqual(
stored_datasource.perm,
f"[druid_test].[tmp_schema.tmp_datasource](id:{stored_datasource.id})",
@@ -280,15 +284,16 @@ class TestRolePermission(SupersetTestCase):
)
)
- db.session.delete(stored_datasource)
- db.session.commit()
+ session.delete(stored_datasource)
+ session.commit()
def test_set_perm_druid_cluster(self):
+ session = db.session
cluster = DruidCluster(cluster_name="tmp_druid_cluster")
- db.session.add(cluster)
+ session.add(cluster)
stored_cluster = (
- db.session.query(DruidCluster)
+ session.query(DruidCluster)
.filter_by(cluster_name="tmp_druid_cluster")
.one()
)
@@ -302,7 +307,7 @@ class TestRolePermission(SupersetTestCase):
)
stored_cluster.cluster_name = "tmp_druid_cluster2"
- db.session.commit()
+ session.commit()
self.assertEqual(
stored_cluster.perm, f"[tmp_druid_cluster2].(id:{stored_cluster.id})"
)
@@ -312,17 +317,18 @@ class TestRolePermission(SupersetTestCase):
)
)
- db.session.delete(stored_cluster)
- db.session.commit()
+ session.delete(stored_cluster)
+ session.commit()
def test_set_perm_database(self):
+ session = db.session
database = Database(
database_name="tmp_database", sqlalchemy_uri="sqlite://test"
)
- db.session.add(database)
+ session.add(database)
stored_db = (
- db.session.query(Database).filter_by(database_name="tmp_database").one()
+ session.query(Database).filter_by(database_name="tmp_database").one()
)
self.assertEqual(stored_db.perm, f"[tmp_database].(id:{stored_db.id})")
self.assertIsNotNone(
@@ -332,9 +338,9 @@ class TestRolePermission(SupersetTestCase):
)
stored_db.database_name = "tmp_database2"
- db.session.commit()
+ session.commit()
stored_db = (
- db.session.query(Database).filter_by(database_name="tmp_database2").one()
+ session.query(Database).filter_by(database_name="tmp_database2").one()
)
self.assertEqual(stored_db.perm, f"[tmp_database2].(id:{stored_db.id})")
self.assertIsNotNone(
@@ -343,8 +349,8 @@ class TestRolePermission(SupersetTestCase):
)
)
- db.session.delete(stored_db)
- db.session.commit()
+ session.delete(stored_db)
+ session.commit()
def test_hybrid_perm_druid_cluster(self):
cluster = DruidCluster(cluster_name="tmp_druid_cluster3")
@@ -394,13 +400,14 @@ class TestRolePermission(SupersetTestCase):
db.session.commit()
def test_set_perm_slice(self):
+ session = db.session
database = Database(
database_name="tmp_database", sqlalchemy_uri="sqlite://test"
)
table = SqlaTable(table_name="tmp_perm_table", database=database)
- db.session.add(database)
- db.session.add(table)
- db.session.commit()
+ session.add(database)
+ session.add(table)
+ session.commit()
# no schema permission
slice = Slice(
@@ -409,10 +416,10 @@ class TestRolePermission(SupersetTestCase):
datasource_name="tmp_perm_table",
slice_name="slice_name",
)
- db.session.add(slice)
- db.session.commit()
+ session.add(slice)
+ session.commit()
- slice = db.session.query(Slice).filter_by(slice_name="slice_name").one()
+ slice = session.query(Slice).filter_by(slice_name="slice_name").one()
self.assertEqual(slice.perm, table.perm)
self.assertEqual(slice.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})")
self.assertEqual(slice.schema_perm, table.schema_perm)
@@ -420,7 +427,7 @@ class TestRolePermission(SupersetTestCase):
table.schema = "tmp_perm_schema"
table.table_name = "tmp_perm_table_v2"
- db.session.commit()
+ session.commit()
# TODO(bogdan): modify slice permissions on the table update.
self.assertNotEquals(slice.perm, table.perm)
self.assertEqual(slice.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})")
@@ -433,7 +440,7 @@ class TestRolePermission(SupersetTestCase):
# updating slice refreshes the permissions
slice.slice_name = "slice_name_v2"
- db.session.commit()
+ session.commit()
self.assertEqual(slice.perm, table.perm)
self.assertEqual(
slice.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})"
@@ -441,10 +448,11 @@ class TestRolePermission(SupersetTestCase):
self.assertEqual(slice.schema_perm, table.schema_perm)
self.assertEqual(slice.schema_perm, "[tmp_database].[tmp_perm_schema]")
- db.session.delete(slice)
- db.session.delete(table)
- db.session.delete(database)
- db.session.commit()
+ session.delete(slice)
+ session.delete(table)
+ session.delete(database)
+
+ session.commit()
# TODO test slice permission
@@ -524,11 +532,11 @@ class TestRolePermission(SupersetTestCase):
self.assertNotIn("Girl Name Cloud", data) # birth_names slice, no access
def test_sqllab_gamma_user_schema_access_to_sqllab(self):
- example_db = (
- db.session.query(Database).filter_by(database_name="examples").one()
- )
+ session = db.session
+
+ example_db = session.query(Database).filter_by(database_name="examples").one()
example_db.expose_in_sqllab = True
- db.session.commit()
+ session.commit()
arguments = {
"keys": ["none"],
@@ -951,10 +959,12 @@ class TestRowLevelSecurity(SupersetTestCase):
rls_entry = None
def setUp(self):
+ session = db.session
+
# Create the RowLevelSecurityFilter
self.rls_entry = RowLevelSecurityFilter()
self.rls_entry.tables.extend(
- db.session.query(SqlaTable)
+ session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"]))
.all()
)
@@ -964,11 +974,13 @@ class TestRowLevelSecurity(SupersetTestCase):
) # db.session.query(Role).filter_by(name="Gamma").first())
self.rls_entry.roles.append(security_manager.find_role("Alpha"))
db.session.add(self.rls_entry)
+
db.session.commit()
def tearDown(self):
- db.session.delete(self.rls_entry)
- db.session.commit()
+ session = db.session
+ session.delete(self.rls_entry)
+ session.commit()
# Do another test to make sure it doesn't alter another query
def test_rls_filter_alters_query(self):
diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py
index 9315d39..bff8d9d 100644
--- a/tests/sqllab_tests.py
+++ b/tests/sqllab_tests.py
@@ -63,6 +63,7 @@ class TestSqlLab(SupersetTestCase):
self.logout()
db.session.query(Query).delete()
db.session.commit()
+ db.session.close()
def test_sql_json(self):
self.login("admin")
@@ -459,6 +460,7 @@ class TestSqlLab(SupersetTestCase):
Test query api with can_access_all_queries perm added to
gamma and make sure all queries show up.
"""
+ session = db.session
# Add all_query_access perm to Gamma user
all_queries_view = security_manager.find_permission_view_menu(
@@ -468,7 +470,7 @@ class TestSqlLab(SupersetTestCase):
security_manager.add_permission_role(
security_manager.find_role("gamma_sqllab"), all_queries_view
)
- db.session.commit()
+ session.commit()
# Test search_queries for Admin user
self.run_some_queries()
@@ -485,7 +487,7 @@ class TestSqlLab(SupersetTestCase):
security_manager.find_role("gamma_sqllab"), all_queries_view
)
- db.session.commit()
+ session.commit()
def test_query_admin_can_access_all_queries(self) -> None:
"""
diff --git a/tests/strategy_tests.py b/tests/strategy_tests.py
index 49e2349..c4f0019 100644
--- a/tests/strategy_tests.py
+++ b/tests/strategy_tests.py
@@ -194,7 +194,7 @@ class TestCacheWarmUp(SupersetTestCase):
db.session.commit()
def test_dashboard_tags(self):
- tag1 = get_tag("tag1", TagTypes.custom)
+ tag1 = get_tag("tag1", db.session, TagTypes.custom)
# delete first to make test idempotent
self.reset_tag(tag1)
@@ -204,7 +204,7 @@ class TestCacheWarmUp(SupersetTestCase):
self.assertEqual(result, expected)
# tag dashboard 'births' with `tag1`
- tag1 = get_tag("tag1", TagTypes.custom)
+ tag1 = get_tag("tag1", db.session, TagTypes.custom)
dash = self.get_dash_by_slug("births")
tag1_urls = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices])
tagged_object = TaggedObject(
@@ -216,7 +216,7 @@ class TestCacheWarmUp(SupersetTestCase):
self.assertEqual(sorted(strategy.get_urls()), tag1_urls)
strategy = DashboardTagsStrategy(["tag2"])
- tag2 = get_tag("tag2", TagTypes.custom)
+ tag2 = get_tag("tag2", db.session, TagTypes.custom)
self.reset_tag(tag2)
result = sorted(strategy.get_urls())