You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text rendering).
Posted to commits@superset.apache.org by jo...@apache.org on 2023/06/21 16:30:14 UTC

[superset] branch master updated: chore(dao): Add generic type for better type checking (#24465)

This is an automated email from the ASF dual-hosted git repository.

johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 92e2ee9d07 chore(dao): Add generic type for better type checking (#24465)
92e2ee9d07 is described below

commit 92e2ee9d0745f8717adea493636451751db0eb08
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Wed Jun 21 09:30:07 2023 -0700

    chore(dao): Add generic type for better type checking (#24465)
---
 .../annotations/commands/delete.py                 |  2 ++
 .../annotations/commands/update.py                 |  2 ++
 superset/annotation_layers/commands/delete.py      |  2 ++
 superset/annotation_layers/commands/update.py      |  2 ++
 superset/charts/commands/update.py                 |  2 ++
 superset/commands/export/models.py                 |  2 +-
 superset/daos/annotation.py                        |  8 ++---
 superset/daos/base.py                              | 35 +++++++++++-----------
 superset/daos/chart.py                             |  3 +-
 superset/daos/css.py                               |  4 +--
 superset/daos/dashboard.py                         | 10 ++-----
 superset/daos/database.py                          |  7 ++---
 superset/daos/dataset.py                           | 13 ++++----
 superset/daos/datasource.py                        |  2 +-
 superset/daos/log.py                               |  4 +--
 superset/daos/query.py                             |  6 ++--
 superset/daos/report.py                            |  3 +-
 superset/daos/security.py                          |  4 +--
 superset/daos/tag.py                               |  3 +-
 superset/dashboards/commands/delete.py             |  2 ++
 superset/dashboards/commands/update.py             |  2 ++
 superset/dashboards/filter_sets/commands/delete.py |  4 ++-
 superset/dashboards/filter_sets/commands/update.py |  2 ++
 superset/databases/commands/delete.py              |  2 ++
 superset/databases/ssh_tunnel/commands/delete.py   |  2 ++
 superset/reports/commands/delete.py                |  2 ++
 superset/row_level_security/commands/update.py     |  2 ++
 27 files changed, 68 insertions(+), 64 deletions(-)

diff --git a/superset/annotation_layers/annotations/commands/delete.py b/superset/annotation_layers/annotations/commands/delete.py
index b86ae997a4..2af01f57f4 100644
--- a/superset/annotation_layers/annotations/commands/delete.py
+++ b/superset/annotation_layers/annotations/commands/delete.py
@@ -38,6 +38,8 @@ class DeleteAnnotationCommand(BaseCommand):
 
     def run(self) -> Model:
         self.validate()
+        assert self._model
+
         try:
             annotation = AnnotationDAO.delete(self._model)
         except DAODeleteFailedError as ex:
diff --git a/superset/annotation_layers/annotations/commands/update.py b/superset/annotation_layers/annotations/commands/update.py
index 03797a555b..76287d24a9 100644
--- a/superset/annotation_layers/annotations/commands/update.py
+++ b/superset/annotation_layers/annotations/commands/update.py
@@ -45,6 +45,8 @@ class UpdateAnnotationCommand(BaseCommand):
 
     def run(self) -> Model:
         self.validate()
+        assert self._model
+
         try:
             annotation = AnnotationDAO.update(self._model, self._properties)
         except DAOUpdateFailedError as ex:
diff --git a/superset/annotation_layers/commands/delete.py b/superset/annotation_layers/commands/delete.py
index 0692d4dd83..1af4242dce 100644
--- a/superset/annotation_layers/commands/delete.py
+++ b/superset/annotation_layers/commands/delete.py
@@ -39,6 +39,8 @@ class DeleteAnnotationLayerCommand(BaseCommand):
 
     def run(self) -> Model:
         self.validate()
+        assert self._model
+
         try:
             annotation_layer = AnnotationLayerDAO.delete(self._model)
         except DAODeleteFailedError as ex:
diff --git a/superset/annotation_layers/commands/update.py b/superset/annotation_layers/commands/update.py
index ca3a288413..e7f6963e82 100644
--- a/superset/annotation_layers/commands/update.py
+++ b/superset/annotation_layers/commands/update.py
@@ -42,6 +42,8 @@ class UpdateAnnotationLayerCommand(BaseCommand):
 
     def run(self) -> Model:
         self.validate()
+        assert self._model
+
         try:
             annotation_layer = AnnotationLayerDAO.update(self._model, self._properties)
         except DAOUpdateFailedError as ex:
diff --git a/superset/charts/commands/update.py b/superset/charts/commands/update.py
index 9a5b4e1f29..32fd49e7cd 100644
--- a/superset/charts/commands/update.py
+++ b/superset/charts/commands/update.py
@@ -56,6 +56,8 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
 
     def run(self) -> Model:
         self.validate()
+        assert self._model
+
         try:
             if self._properties.get("query_context_generation") is None:
                 self._properties["last_saved_at"] = datetime.now()
diff --git a/superset/commands/export/models.py b/superset/commands/export/models.py
index 27f4572af3..61532d4a03 100644
--- a/superset/commands/export/models.py
+++ b/superset/commands/export/models.py
@@ -30,7 +30,7 @@ METADATA_FILE_NAME = "metadata.yaml"
 
 
 class ExportModelsCommand(BaseCommand):
-    dao: type[BaseDAO] = BaseDAO
+    dao: type[BaseDAO[Model]] = BaseDAO
     not_found: type[CommandException] = CommandException
 
     def __init__(self, model_ids: list[int], export_related: bool = True):
diff --git a/superset/daos/annotation.py b/superset/daos/annotation.py
index 171a708fa4..2df336647a 100644
--- a/superset/daos/annotation.py
+++ b/superset/daos/annotation.py
@@ -27,9 +27,7 @@ from superset.models.annotations import Annotation, AnnotationLayer
 logger = logging.getLogger(__name__)
 
 
-class AnnotationDAO(BaseDAO):
-    model_cls = Annotation
-
+class AnnotationDAO(BaseDAO[Annotation]):
     @staticmethod
     def bulk_delete(models: Optional[list[Annotation]], commit: bool = True) -> None:
         item_ids = [model.id for model in models] if models else []
@@ -64,9 +62,7 @@ class AnnotationDAO(BaseDAO):
         return not db.session.query(query.exists()).scalar()
 
 
-class AnnotationLayerDAO(BaseDAO):
-    model_cls = AnnotationLayer
-
+class AnnotationLayerDAO(BaseDAO[AnnotationLayer]):
     @staticmethod
     def bulk_delete(
         models: Optional[list[AnnotationLayer]], commit: bool = True
diff --git a/superset/daos/base.py b/superset/daos/base.py
index 6465e5b177..c0758f51dd 100644
--- a/superset/daos/base.py
+++ b/superset/daos/base.py
@@ -14,8 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=isinstance-second-argument-not-valid-type
-from typing import Any, Optional, Union
+from typing import Any, Generic, get_args, Optional, TypeVar, Union
 
 from flask_appbuilder.models.filters import BaseFilter
 from flask_appbuilder.models.sqla import Model
@@ -31,8 +30,10 @@ from superset.daos.exceptions import (
 )
 from superset.extensions import db
 
+T = TypeVar("T", bound=Model)  # pylint: disable=invalid-name
 
-class BaseDAO:
+
+class BaseDAO(Generic[T]):
     """
     Base DAO, implement base CRUD sqlalchemy operations
     """
@@ -48,6 +49,11 @@ class BaseDAO:
     """
     id_column_name = "id"
 
+    def __init_subclass__(cls) -> None:  # pylint: disable=arguments-differ
+        cls.model_cls = get_args(
+            cls.__orig_bases__[0]  # type: ignore  # pylint: disable=no-member
+        )[0]
+
     @classmethod
     def find_by_id(
         cls,
@@ -78,7 +84,7 @@ class BaseDAO:
         model_ids: Union[list[str], list[int]],
         session: Session = None,
         skip_base_filter: bool = False,
-    ) -> list[Model]:
+    ) -> list[T]:
         """
         Find a List of models by a list of ids, if defined applies `base_filter`
         """
@@ -95,7 +101,7 @@ class BaseDAO:
         return query.all()
 
     @classmethod
-    def find_all(cls) -> list[Model]:
+    def find_all(cls) -> list[T]:
         """
         Get all that fit the `base_filter`
         """
@@ -108,7 +114,7 @@ class BaseDAO:
         return query.all()
 
     @classmethod
-    def find_one_or_none(cls, **filter_by: Any) -> Optional[Model]:
+    def find_one_or_none(cls, **filter_by: Any) -> Optional[T]:
         """
         Get the first that fit the `base_filter`
         """
@@ -121,7 +127,7 @@ class BaseDAO:
         return query.filter_by(**filter_by).one_or_none()
 
     @classmethod
-    def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
+    def create(cls, properties: dict[str, Any], commit: bool = True) -> T:
         """
         Generic for creating models
         :raises: DAOCreateFailedError
@@ -141,17 +147,13 @@ class BaseDAO:
         return model
 
     @classmethod
-    def save(cls, instance_model: Model, commit: bool = True) -> Model:
+    def save(cls, instance_model: T, commit: bool = True) -> None:
         """
         Generic for saving models
         :raises: DAOCreateFailedError
         """
         if cls.model_cls is None:
             raise DAOConfigError()
-        if not isinstance(instance_model, cls.model_cls):
-            raise DAOCreateFailedError(
-                "the instance model is not a type of the model class"
-            )
         try:
             db.session.add(instance_model)
             if commit:
@@ -159,12 +161,9 @@ class BaseDAO:
         except SQLAlchemyError as ex:  # pragma: no cover
             db.session.rollback()
             raise DAOCreateFailedError(exception=ex) from ex
-        return instance_model
 
     @classmethod
-    def update(
-        cls, model: Model, properties: dict[str, Any], commit: bool = True
-    ) -> Model:
+    def update(cls, model: T, properties: dict[str, Any], commit: bool = True) -> T:
         """
         Generic update a model
         :raises: DAOCreateFailedError
@@ -181,7 +180,7 @@ class BaseDAO:
         return model
 
     @classmethod
-    def delete(cls, model: Model, commit: bool = True) -> Model:
+    def delete(cls, model: T, commit: bool = True) -> T:
         """
         Generic delete a model
         :raises: DAODeleteFailedError
@@ -196,7 +195,7 @@ class BaseDAO:
         return model
 
     @classmethod
-    def bulk_delete(cls, models: list[Model], commit: bool = True) -> None:
+    def bulk_delete(cls, models: list[T], commit: bool = True) -> None:
         try:
             for model in models:
                 cls.delete(model, False)
diff --git a/superset/daos/chart.py b/superset/daos/chart.py
index 838d93abdf..1a13965022 100644
--- a/superset/daos/chart.py
+++ b/superset/daos/chart.py
@@ -34,8 +34,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class ChartDAO(BaseDAO):
-    model_cls = Slice
+class ChartDAO(BaseDAO[Slice]):
     base_filter = ChartFilter
 
     @staticmethod
diff --git a/superset/daos/css.py b/superset/daos/css.py
index 224277a40a..3a1cbe8fda 100644
--- a/superset/daos/css.py
+++ b/superset/daos/css.py
@@ -27,9 +27,7 @@ from superset.models.core import CssTemplate
 logger = logging.getLogger(__name__)
 
 
-class CssTemplateDAO(BaseDAO):
-    model_cls = CssTemplate
-
+class CssTemplateDAO(BaseDAO[CssTemplate]):
     @staticmethod
     def bulk_delete(models: Optional[list[CssTemplate]], commit: bool = True) -> None:
         item_ids = [model.id for model in models] if models else []
diff --git a/superset/daos/dashboard.py b/superset/daos/dashboard.py
index 1e31591e1f..1650711d50 100644
--- a/superset/daos/dashboard.py
+++ b/superset/daos/dashboard.py
@@ -49,8 +49,7 @@ from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes
 logger = logging.getLogger(__name__)
 
 
-class DashboardDAO(BaseDAO):
-    model_cls = Dashboard
+class DashboardDAO(BaseDAO[Dashboard]):
     base_filter = DashboardAccessFilter
 
     @classmethod
@@ -379,8 +378,7 @@ class DashboardDAO(BaseDAO):
             db.session.commit()
 
 
-class EmbeddedDashboardDAO(BaseDAO):
-    model_cls = EmbeddedDashboard
+class EmbeddedDashboardDAO(BaseDAO[EmbeddedDashboard]):
     # There isn't really a regular scenario where we would rather get Embedded by id
     id_column_name = "uuid"
 
@@ -407,9 +405,7 @@ class EmbeddedDashboardDAO(BaseDAO):
         raise NotImplementedError("Use EmbeddedDashboardDAO.upsert() instead.")
 
 
-class FilterSetDAO(BaseDAO):
-    model_cls = FilterSet
-
+class FilterSetDAO(BaseDAO[FilterSet]):
     @classmethod
     def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
         if cls.model_cls is None:
diff --git a/superset/daos/database.py b/superset/daos/database.py
index 569568472a..0a3cb65b2e 100644
--- a/superset/daos/database.py
+++ b/superset/daos/database.py
@@ -31,8 +31,7 @@ from superset.utils.ssh_tunnel import unmask_password_info
 logger = logging.getLogger(__name__)
 
 
-class DatabaseDAO(BaseDAO):
-    model_cls = Database
+class DatabaseDAO(BaseDAO[Database]):
     base_filter = DatabaseFilter
 
     @classmethod
@@ -138,9 +137,7 @@ class DatabaseDAO(BaseDAO):
         return ssh_tunnel
 
 
-class SSHTunnelDAO(BaseDAO):
-    model_cls = SSHTunnel
-
+class SSHTunnelDAO(BaseDAO[SSHTunnel]):
     @classmethod
     def update(
         cls,
diff --git a/superset/daos/dataset.py b/superset/daos/dataset.py
index 3937e6c312..4634a7e46f 100644
--- a/superset/daos/dataset.py
+++ b/superset/daos/dataset.py
@@ -31,8 +31,7 @@ from superset.views.base import DatasourceFilter
 logger = logging.getLogger(__name__)
 
 
-class DatasetDAO(BaseDAO):  # pylint: disable=too-many-public-methods
-    model_cls = SqlaTable
+class DatasetDAO(BaseDAO[SqlaTable]):  # pylint: disable=too-many-public-methods
     base_filter = DatasourceFilter
 
     @staticmethod
@@ -151,7 +150,7 @@ class DatasetDAO(BaseDAO):  # pylint: disable=too-many-public-methods
         model: SqlaTable,
         properties: dict[str, Any],
         commit: bool = True,
-    ) -> Optional[SqlaTable]:
+    ) -> SqlaTable:
         """
         Updates a Dataset model on the metadata DB
         """
@@ -397,9 +396,9 @@ class DatasetDAO(BaseDAO):  # pylint: disable=too-many-public-methods
         )
 
 
-class DatasetColumnDAO(BaseDAO):
-    model_cls = TableColumn
+class DatasetColumnDAO(BaseDAO[TableColumn]):
+    pass
 
 
-class DatasetMetricDAO(BaseDAO):
-    model_cls = SqlMetric
+class DatasetMetricDAO(BaseDAO[SqlMetric]):
+    pass
diff --git a/superset/daos/datasource.py b/superset/daos/datasource.py
index 684106161c..2bdf4ca21f 100644
--- a/superset/daos/datasource.py
+++ b/superset/daos/datasource.py
@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
 Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery]
 
 
-class DatasourceDAO(BaseDAO):
+class DatasourceDAO(BaseDAO[Datasource]):
     sources: dict[Union[DatasourceType, str], type[Datasource]] = {
         DatasourceType.TABLE: SqlaTable,
         DatasourceType.QUERY: Query,
diff --git a/superset/daos/log.py b/superset/daos/log.py
index 81767a48cb..002c3f2307 100644
--- a/superset/daos/log.py
+++ b/superset/daos/log.py
@@ -30,9 +30,7 @@ from superset.utils.core import get_user_id
 from superset.utils.dates import datetime_to_epoch
 
 
-class LogDAO(BaseDAO):
-    model_cls = Log
-
+class LogDAO(BaseDAO[Log]):
     @staticmethod
     def get_recent_activity(
         actions: list[str],
diff --git a/superset/daos/query.py b/superset/daos/query.py
index 8996e27a3b..80b5d1ad4e 100644
--- a/superset/daos/query.py
+++ b/superset/daos/query.py
@@ -35,8 +35,7 @@ from superset.utils.dates import now_as_float
 logger = logging.getLogger(__name__)
 
 
-class QueryDAO(BaseDAO):
-    model_cls = Query
+class QueryDAO(BaseDAO[Query]):
     base_filter = QueryFilter
 
     @staticmethod
@@ -104,8 +103,7 @@ class QueryDAO(BaseDAO):
         db.session.commit()
 
 
-class SavedQueryDAO(BaseDAO):
-    model_cls = SavedQuery
+class SavedQueryDAO(BaseDAO[SavedQuery]):
     base_filter = SavedQueryFilter
 
     @staticmethod
diff --git a/superset/daos/report.py b/superset/daos/report.py
index 4f8d914adc..70a87a6454 100644
--- a/superset/daos/report.py
+++ b/superset/daos/report.py
@@ -42,8 +42,7 @@ logger = logging.getLogger(__name__)
 REPORT_SCHEDULE_ERROR_NOTIFICATION_MARKER = "Notification sent with error"
 
 
-class ReportScheduleDAO(BaseDAO):
-    model_cls = ReportSchedule
+class ReportScheduleDAO(BaseDAO[ReportSchedule]):
     base_filter = ReportScheduleFilter
 
     @staticmethod
diff --git a/superset/daos/security.py b/superset/daos/security.py
index a435f224a6..392d741e3d 100644
--- a/superset/daos/security.py
+++ b/superset/daos/security.py
@@ -19,5 +19,5 @@ from superset.connectors.sqla.models import RowLevelSecurityFilter
 from superset.daos.base import BaseDAO
 
 
-class RLSDAO(BaseDAO):
-    model_cls = RowLevelSecurityFilter
+class RLSDAO(BaseDAO[RowLevelSecurityFilter]):
+    pass
diff --git a/superset/daos/tag.py b/superset/daos/tag.py
index ec991edb13..90b0134ca7 100644
--- a/superset/daos/tag.py
+++ b/superset/daos/tag.py
@@ -31,8 +31,7 @@ from superset.tags.models import get_tag, ObjectTypes, Tag, TaggedObject, TagTyp
 logger = logging.getLogger(__name__)
 
 
-class TagDAO(BaseDAO):
-    model_cls = Tag
+class TagDAO(BaseDAO[Tag]):
     # base_filter = TagAccessFilter
 
     @staticmethod
diff --git a/superset/dashboards/commands/delete.py b/superset/dashboards/commands/delete.py
index f774b92a51..1f5eb4ae3b 100644
--- a/superset/dashboards/commands/delete.py
+++ b/superset/dashboards/commands/delete.py
@@ -44,6 +44,8 @@ class DeleteDashboardCommand(BaseCommand):
 
     def run(self) -> Model:
         self.validate()
+        assert self._model
+
         try:
             dashboard = DashboardDAO.delete(self._model)
         except DAODeleteFailedError as ex:
diff --git a/superset/dashboards/commands/update.py b/superset/dashboards/commands/update.py
index cd9c07e0fd..c880eebe89 100644
--- a/superset/dashboards/commands/update.py
+++ b/superset/dashboards/commands/update.py
@@ -48,6 +48,8 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand):
 
     def run(self) -> Model:
         self.validate()
+        assert self._model
+
         try:
             dashboard = DashboardDAO.update(self._model, self._properties, commit=False)
             if self._properties.get("json_metadata"):
diff --git a/superset/dashboards/filter_sets/commands/delete.py b/superset/dashboards/filter_sets/commands/delete.py
index 93f4383399..c058354245 100644
--- a/superset/dashboards/filter_sets/commands/delete.py
+++ b/superset/dashboards/filter_sets/commands/delete.py
@@ -36,8 +36,10 @@ class DeleteFilterSetCommand(BaseFilterSetCommand):
         self._filter_set_id = filter_set_id
 
     def run(self) -> Model:
+        self.validate()
+        assert self._filter_set
+
         try:
-            self.validate()
             return FilterSetDAO.delete(self._filter_set, commit=True)
         except DAODeleteFailedError as err:
             raise FilterSetDeleteFailedError(str(self._filter_set_id), "") from err
diff --git a/superset/dashboards/filter_sets/commands/update.py b/superset/dashboards/filter_sets/commands/update.py
index eecaa34aeb..a63c8d46f2 100644
--- a/superset/dashboards/filter_sets/commands/update.py
+++ b/superset/dashboards/filter_sets/commands/update.py
@@ -39,6 +39,8 @@ class UpdateFilterSetCommand(BaseFilterSetCommand):
     def run(self) -> Model:
         try:
             self.validate()
+            assert self._filter_set
+
             if (
                 OWNER_TYPE_FIELD in self._properties
                 and self._properties[OWNER_TYPE_FIELD] == "Dashboard"
diff --git a/superset/databases/commands/delete.py b/superset/databases/commands/delete.py
index b8eb3f6e5e..95d212e290 100644
--- a/superset/databases/commands/delete.py
+++ b/superset/databases/commands/delete.py
@@ -42,6 +42,8 @@ class DeleteDatabaseCommand(BaseCommand):
 
     def run(self) -> Model:
         self.validate()
+        assert self._model
+
         try:
             database = DatabaseDAO.delete(self._model)
         except DAODeleteFailedError as ex:
diff --git a/superset/databases/ssh_tunnel/commands/delete.py b/superset/databases/ssh_tunnel/commands/delete.py
index 910df35a19..375c496f2a 100644
--- a/superset/databases/ssh_tunnel/commands/delete.py
+++ b/superset/databases/ssh_tunnel/commands/delete.py
@@ -42,6 +42,8 @@ class DeleteSSHTunnelCommand(BaseCommand):
         if not is_feature_enabled("SSH_TUNNELING"):
             raise SSHTunnelingNotEnabledError()
         self.validate()
+        assert self._model
+
         try:
             ssh_tunnel = SSHTunnelDAO.delete(self._model)
         except DAODeleteFailedError as ex:
diff --git a/superset/reports/commands/delete.py b/superset/reports/commands/delete.py
index 3f7e4e5d23..f52d96f7f5 100644
--- a/superset/reports/commands/delete.py
+++ b/superset/reports/commands/delete.py
@@ -41,6 +41,8 @@ class DeleteReportScheduleCommand(BaseCommand):
 
     def run(self) -> Model:
         self.validate()
+        assert self._model
+
         try:
             report_schedule = ReportScheduleDAO.delete(self._model)
         except DAODeleteFailedError as ex:
diff --git a/superset/row_level_security/commands/update.py b/superset/row_level_security/commands/update.py
index d44aa3efaf..bc5ef368ba 100644
--- a/superset/row_level_security/commands/update.py
+++ b/superset/row_level_security/commands/update.py
@@ -41,6 +41,8 @@ class UpdateRLSRuleCommand(BaseCommand):
 
     def run(self) -> Any:
         self.validate()
+        assert self._model
+
         try:
             rule = RLSDAO.update(self._model, self._properties)
         except DAOUpdateFailedError as ex: