You are viewing a plain-text version of this content. (The link to its canonical version was a hyperlink and is not preserved in this plain-text copy.)
Posted to commits@superset.apache.org by mi...@apache.org on 2023/08/24 16:21:26 UTC

[superset] 01/04: fix: Ensure SQLAlchemy sessions are closed (#25031)

This is an automated email from the ASF dual-hosted git repository.

michaelsmolina pushed a commit to branch 3.0
in repository https://gitbox.apache.org/repos/asf/superset.git

commit ad89ea549b1d3a0f2f21a1f2df77640dc04c4ca4
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Wed Aug 23 11:57:36 2023 -0700

    fix: Ensure SQLAlchemy sessions are closed (#25031)
    
    (cherry picked from commit adaab3550c4487b17868a8880cfa146a7806422a)
---
 superset/models/dashboard.py |  48 +++++++++--------
 superset/tags/models.py      | 126 +++++++++++++++++++++++--------------------
 superset/tasks/cache.py      | 106 ++++++++++++++++++++----------------
 3 files changed, 153 insertions(+), 127 deletions(-)

diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py
index f837c76610..0848126012 100644
--- a/superset/models/dashboard.py
+++ b/superset/models/dashboard.py
@@ -74,28 +74,31 @@ def copy_dashboard(_mapper: Mapper, connection: Connection, target: Dashboard) -
 
     session_class = sessionmaker(autoflush=False)
     session = session_class(bind=connection)
-    new_user = session.query(User).filter_by(id=target.id).first()
-
-    # copy template dashboard to user
-    template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
-    dashboard = Dashboard(
-        dashboard_title=template.dashboard_title,
-        position_json=template.position_json,
-        description=template.description,
-        css=template.css,
-        json_metadata=template.json_metadata,
-        slices=template.slices,
-        owners=[new_user],
-    )
-    session.add(dashboard)
-    session.commit()
 
-    # set dashboard as the welcome dashboard
-    extra_attributes = UserAttribute(
-        user_id=target.id, welcome_dashboard_id=dashboard.id
-    )
-    session.add(extra_attributes)
-    session.commit()
+    try:
+        new_user = session.query(User).filter_by(id=target.id).first()
+
+        # copy template dashboard to user
+        template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
+        dashboard = Dashboard(
+            dashboard_title=template.dashboard_title,
+            position_json=template.position_json,
+            description=template.description,
+            css=template.css,
+            json_metadata=template.json_metadata,
+            slices=template.slices,
+            owners=[new_user],
+        )
+        session.add(dashboard)
+
+        # set dashboard as the welcome dashboard
+        extra_attributes = UserAttribute(
+            user_id=target.id, welcome_dashboard_id=dashboard.id
+        )
+        session.add(extra_attributes)
+        session.commit()
+    finally:
+        session.close()
 
 
 sqla.event.listen(User, "after_insert", copy_dashboard)
@@ -411,13 +414,12 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
                 "native_filter_configuration", []
             )
             for native_filter in native_filter_configuration:
-                session = db.session()
                 for target in native_filter.get("targets", []):
                     id_ = target.get("datasetId")
                     if id_ is None:
                         continue
                     datasource = DatasourceDAO.get_datasource(
-                        session, utils.DatasourceType.TABLE, id_
+                        db.session, utils.DatasourceType.TABLE, id_
                     )
                     datasource_ids.add((datasource.id, datasource.type))
 
diff --git a/superset/tags/models.py b/superset/tags/models.py
index 87670a0d29..1ce682c4a4 100644
--- a/superset/tags/models.py
+++ b/superset/tags/models.py
@@ -156,17 +156,19 @@ class ObjectUpdater:
     ) -> None:
         session = Session(bind=connection)
 
-        # add `owner:` tags
-        cls._add_owners(session, target)
+        try:
+            # add `owner:` tags
+            cls._add_owners(session, target)
 
-        # add `type:` tags
-        tag = get_tag(f"type:{cls.object_type}", session, TagTypes.type)
-        tagged_object = TaggedObject(
-            tag_id=tag.id, object_id=target.id, object_type=cls.object_type
-        )
-        session.add(tagged_object)
-
-        session.commit()
+            # add `type:` tags
+            tag = get_tag(f"type:{cls.object_type}", session, TagTypes.type)
+            tagged_object = TaggedObject(
+                tag_id=tag.id, object_id=target.id, object_type=cls.object_type
+            )
+            session.add(tagged_object)
+            session.commit()
+        finally:
+            session.close()
 
     @classmethod
     def after_update(
@@ -177,25 +179,27 @@ class ObjectUpdater:
     ) -> None:
         session = Session(bind=connection)
 
-        # delete current `owner:` tags
-        query = (
-            session.query(TaggedObject.id)
-            .join(Tag)
-            .filter(
-                TaggedObject.object_type == cls.object_type,
-                TaggedObject.object_id == target.id,
-                Tag.type == TagTypes.owner,
+        try:
+            # delete current `owner:` tags
+            query = (
+                session.query(TaggedObject.id)
+                .join(Tag)
+                .filter(
+                    TaggedObject.object_type == cls.object_type,
+                    TaggedObject.object_id == target.id,
+                    Tag.type == TagTypes.owner,
+                )
+            )
+            ids = [row[0] for row in query]
+            session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
+                synchronize_session=False
             )
-        )
-        ids = [row[0] for row in query]
-        session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
-            synchronize_session=False
-        )
-
-        # add `owner:` tags
-        cls._add_owners(session, target)
 
-        session.commit()
+            # add `owner:` tags
+            cls._add_owners(session, target)
+            session.commit()
+        finally:
+            session.close()
 
     @classmethod
     def after_delete(
@@ -206,13 +210,16 @@ class ObjectUpdater:
     ) -> None:
         session = Session(bind=connection)
 
-        # delete row from `tagged_objects`
-        session.query(TaggedObject).filter(
-            TaggedObject.object_type == cls.object_type,
-            TaggedObject.object_id == target.id,
-        ).delete()
+        try:
+            # delete row from `tagged_objects`
+            session.query(TaggedObject).filter(
+                TaggedObject.object_type == cls.object_type,
+                TaggedObject.object_id == target.id,
+            ).delete()
 
-        session.commit()
+            session.commit()
+        finally:
+            session.close()
 
 
 class ChartUpdater(ObjectUpdater):
@@ -253,35 +260,40 @@ class FavStarUpdater:
         cls, _mapper: Mapper, connection: Connection, target: FavStar
     ) -> None:
         session = Session(bind=connection)
-        name = f"favorited_by:{target.user_id}"
-        tag = get_tag(name, session, TagTypes.favorited_by)
-        tagged_object = TaggedObject(
-            tag_id=tag.id,
-            object_id=target.obj_id,
-            object_type=get_object_type(target.class_name),
-        )
-        session.add(tagged_object)
-
-        session.commit()
+        try:
+            name = f"favorited_by:{target.user_id}"
+            tag = get_tag(name, session, TagTypes.favorited_by)
+            tagged_object = TaggedObject(
+                tag_id=tag.id,
+                object_id=target.obj_id,
+                object_type=get_object_type(target.class_name),
+            )
+            session.add(tagged_object)
+            session.commit()
+        finally:
+            session.close()
 
     @classmethod
     def after_delete(
         cls, _mapper: Mapper, connection: Connection, target: FavStar
     ) -> None:
         session = Session(bind=connection)
-        name = f"favorited_by:{target.user_id}"
-        query = (
-            session.query(TaggedObject.id)
-            .join(Tag)
-            .filter(
-                TaggedObject.object_id == target.obj_id,
-                Tag.type == TagTypes.favorited_by,
-                Tag.name == name,
+        try:
+            name = f"favorited_by:{target.user_id}"
+            query = (
+                session.query(TaggedObject.id)
+                .join(Tag)
+                .filter(
+                    TaggedObject.object_id == target.obj_id,
+                    Tag.type == TagTypes.favorited_by,
+                    Tag.name == name,
+                )
+            )
+            ids = [row[0] for row in query]
+            session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
+                synchronize_session=False
             )
-        )
-        ids = [row[0] for row in query]
-        session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
-            synchronize_session=False
-        )
 
-        session.commit()
+            session.commit()
+        finally:
+            session.close()
diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py
index 68b5657a22..01f6351919 100644
--- a/superset/tasks/cache.py
+++ b/superset/tasks/cache.py
@@ -95,7 +95,11 @@ class DummyStrategy(Strategy):  # pylint: disable=too-few-public-methods
 
     def get_payloads(self) -> list[dict[str, int]]:
         session = db.create_scoped_session()
-        charts = session.query(Slice).all()
+
+        try:
+            charts = session.query(Slice).all()
+        finally:
+            session.close()
 
         return [get_payload(chart) for chart in charts]
 
@@ -129,20 +133,24 @@ class TopNDashboardsStrategy(Strategy):  # pylint: disable=too-few-public-method
         payloads = []
         session = db.create_scoped_session()
 
-        records = (
-            session.query(Log.dashboard_id, func.count(Log.dashboard_id))
-            .filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
-            .group_by(Log.dashboard_id)
-            .order_by(func.count(Log.dashboard_id).desc())
-            .limit(self.top_n)
-            .all()
-        )
-        dash_ids = [record.dashboard_id for record in records]
-        dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
-        for dashboard in dashboards:
-            for chart in dashboard.slices:
-                payloads.append(get_payload(chart, dashboard))
-
+        try:
+            records = (
+                session.query(Log.dashboard_id, func.count(Log.dashboard_id))
+                .filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
+                .group_by(Log.dashboard_id)
+                .order_by(func.count(Log.dashboard_id).desc())
+                .limit(self.top_n)
+                .all()
+            )
+            dash_ids = [record.dashboard_id for record in records]
+            dashboards = (
+                session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
+            )
+            for dashboard in dashboards:
+                for chart in dashboard.slices:
+                    payloads.append(get_payload(chart, dashboard))
+        finally:
+            session.close()
         return payloads
 
 
@@ -172,42 +180,46 @@ class DashboardTagsStrategy(Strategy):  # pylint: disable=too-few-public-methods
         payloads = []
         session = db.create_scoped_session()
 
-        tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all()
-        tag_ids = [tag.id for tag in tags]
-
-        # add dashboards that are tagged
-        tagged_objects = (
-            session.query(TaggedObject)
-            .filter(
-                and_(
-                    TaggedObject.object_type == "dashboard",
-                    TaggedObject.tag_id.in_(tag_ids),
+        try:
+            tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all()
+            tag_ids = [tag.id for tag in tags]
+
+            # add dashboards that are tagged
+            tagged_objects = (
+                session.query(TaggedObject)
+                .filter(
+                    and_(
+                        TaggedObject.object_type == "dashboard",
+                        TaggedObject.tag_id.in_(tag_ids),
+                    )
                 )
+                .all()
             )
-            .all()
-        )
-        dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
-        tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids))
-        for dashboard in tagged_dashboards:
-            for chart in dashboard.slices:
-                payloads.append(get_payload(chart))
-
-        # add charts that are tagged
-        tagged_objects = (
-            session.query(TaggedObject)
-            .filter(
-                and_(
-                    TaggedObject.object_type == "chart",
-                    TaggedObject.tag_id.in_(tag_ids),
+            dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
+            tagged_dashboards = session.query(Dashboard).filter(
+                Dashboard.id.in_(dash_ids)
+            )
+            for dashboard in tagged_dashboards:
+                for chart in dashboard.slices:
+                    payloads.append(get_payload(chart))
+
+            # add charts that are tagged
+            tagged_objects = (
+                session.query(TaggedObject)
+                .filter(
+                    and_(
+                        TaggedObject.object_type == "chart",
+                        TaggedObject.tag_id.in_(tag_ids),
+                    )
                 )
+                .all()
             )
-            .all()
-        )
-        chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
-        tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
-        for chart in tagged_charts:
-            payloads.append(get_payload(chart))
-
+            chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
+            tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
+            for chart in tagged_charts:
+                payloads.append(get_payload(chart))
+        finally:
+            session.close()
         return payloads