You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by hu...@apache.org on 2023/10/05 17:38:02 UTC

[superset] branch master updated: fix(tags): fix clears delete on Tags Modal (#25470)

This is an automated email from the ASF dual-hosted git repository.

hugh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new dcfebfce9d fix(tags): fix clears delete on Tags Modal (#25470)
dcfebfce9d is described below

commit dcfebfce9d0c3f9e249fb4146edaab2a11b77734
Author: Hugh A. Miles II <hu...@gmail.com>
AuthorDate: Thu Oct 5 13:37:53 2023 -0400

    fix(tags): fix clears delete on Tags Modal (#25470)
    
    Co-authored-by: Beto Dealmeida <ro...@dealmeida.net>
---
 superset/daos/tag.py                          |  7 ++-
 superset/tags/commands/create.py              | 61 +++++++++++----------------
 superset/tags/commands/update.py              | 15 +++----
 superset/tags/schemas.py                      |  4 +-
 tests/unit_tests/tags/commands/create_test.py | 26 +++++++-----
 5 files changed, 53 insertions(+), 60 deletions(-)

diff --git a/superset/daos/tag.py b/superset/daos/tag.py
index b6872a5376..2acd221a35 100644
--- a/superset/daos/tag.py
+++ b/superset/daos/tag.py
@@ -390,7 +390,12 @@ class TagDAO(BaseDAO[Tag]):
         updated_tagged_objects = {
             (to_object_type(obj[0]), obj[1]) for obj in objects_to_tag
         }
-        tagged_objects_to_delete = current_tagged_objects - updated_tagged_objects
+
+        tagged_objects_to_delete = (
+            current_tagged_objects
+            if not objects_to_tag
+            else current_tagged_objects - updated_tagged_objects
+        )
 
         for object_type, object_id in updated_tagged_objects:
             # create rows for new objects, and skip tags that already exist
diff --git a/superset/tags/commands/create.py b/superset/tags/commands/create.py
index e8311ad520..883c498bc3 100644
--- a/superset/tags/commands/create.py
+++ b/superset/tags/commands/create.py
@@ -67,25 +67,22 @@ class CreateCustomTagCommand(CreateMixin, BaseCommand):
 
 class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand):
     def __init__(self, data: dict[str, Any], bulk_create: bool = False):
-        self._tag = data["name"]
-        self._objects_to_tag = data.get("objects_to_tag")
-        self._description = data.get("description")
+        self._properties = data.copy()
         self._bulk_create = bulk_create
 
     def run(self) -> None:
         self.validate()
 
         try:
-            tag = TagDAO.get_by_name(self._tag.strip(), TagTypes.custom)
-            if self._objects_to_tag:
-                TagDAO.create_tag_relationship(
-                    objects_to_tag=self._objects_to_tag,
-                    tag=tag,
-                    bulk_create=self._bulk_create,
-                )
+            tag_name = self._properties["name"]
+            tag = TagDAO.get_by_name(tag_name.strip(), TagTypes.custom)
+            TagDAO.create_tag_relationship(
+                objects_to_tag=self._properties.get("objects_to_tag", []),
+                tag=tag,
+                bulk_create=self._bulk_create,
+            )
 
-            if self._description:
-                tag.description = self._description
+            tag.description = self._properties.get("description", "")
 
             db.session.commit()
 
@@ -95,31 +92,21 @@ class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand):
 
     def validate(self) -> None:
         exceptions = []
-        # Validate object_id
-        if self._objects_to_tag:
-            if any(obj_id == 0 for obj_type, obj_id in self._objects_to_tag):
-                exceptions.append(TagInvalidError())
-
-            # Validate object type
-            skipped_tagged_objects: list[tuple[str, int]] = []
-            for obj_type, obj_id in self._objects_to_tag:
-                skipped_tagged_objects = []
-                object_type = to_object_type(obj_type)
-
-                if not object_type:
-                    exceptions.append(
-                        TagInvalidError(f"invalid object type {object_type}")
-                    )
-                try:
-                    model = to_object_model(object_type, obj_id)  # type: ignore
-                    security_manager.raise_for_ownership(model)
-                except SupersetSecurityException:
-                    # skip the object if the user doesn't have access
-                    skipped_tagged_objects.append((obj_type, obj_id))
-
-            self._objects_to_tag = set(self._objects_to_tag) - set(
-                skipped_tagged_objects
-            )
+        objects_to_tag = set(self._properties.get("objects_to_tag", []))
+        skipped_tagged_objects: set[tuple[str, int]] = set()
+        for obj_type, obj_id in objects_to_tag:
+            object_type = to_object_type(obj_type)
+
+            if not object_type:
+                exceptions.append(TagInvalidError(f"invalid object type {object_type}"))
+            try:
+                model = to_object_model(object_type, obj_id)  # type: ignore
+                security_manager.raise_for_ownership(model)
+            except SupersetSecurityException:
+                # skip the object if the user doesn't have access
+                skipped_tagged_objects.add((obj_type, obj_id))
+
+        self._properties["objects_to_tag"] = objects_to_tag - skipped_tagged_objects
 
         if exceptions:
             raise TagInvalidError(exceptions=exceptions)
diff --git a/superset/tags/commands/update.py b/superset/tags/commands/update.py
index a13e4e8e7b..cc1c9a2be7 100644
--- a/superset/tags/commands/update.py
+++ b/superset/tags/commands/update.py
@@ -38,12 +38,10 @@ class UpdateTagCommand(UpdateMixin, BaseCommand):
     def run(self) -> Model:
         self.validate()
         if self._model:
-            if self._properties.get("objects_to_tag"):
-                # todo(hugh): can this manage duplication
-                TagDAO.create_tag_relationship(
-                    objects_to_tag=self._properties["objects_to_tag"],
-                    tag=self._model,
-                )
+            TagDAO.create_tag_relationship(
+                objects_to_tag=self._properties.get("objects_to_tag", []),
+                tag=self._model,
+            )
             if description := self._properties.get("description"):
                 self._model.description = description
             if tag_name := self._properties.get("name"):
@@ -63,11 +61,8 @@ class UpdateTagCommand(UpdateMixin, BaseCommand):
 
         # Validate object_id
         if objects_to_tag := self._properties.get("objects_to_tag"):
-            if any(obj_id == 0 for obj_type, obj_id in objects_to_tag):
-                exceptions.append(TagInvalidError(" invalid object_id"))
-
             # Validate object type
-            for obj_type, obj_id in objects_to_tag:
+            for obj_type, _ in objects_to_tag:
                 object_type = to_object_type(obj_type)
                 if not object_type:
                     exceptions.append(
diff --git a/superset/tags/schemas.py b/superset/tags/schemas.py
index 75fdc2410a..a391fd2b80 100644
--- a/superset/tags/schemas.py
+++ b/superset/tags/schemas.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 from marshmallow import fields, Schema
+from marshmallow.validate import Range
 
 from superset.dashboards.schemas import UserSchema
 
@@ -60,7 +61,8 @@ class TagObjectSchema(Schema):
     name = fields.String()
     description = fields.String(required=False, allow_none=True)
     objects_to_tag = fields.List(
-        fields.Tuple((fields.String(), fields.Int())), required=False
+        fields.Tuple((fields.String(), fields.Int(validate=Range(min=1)))),
+        required=False,
     )
 
 
diff --git a/tests/unit_tests/tags/commands/create_test.py b/tests/unit_tests/tags/commands/create_test.py
index 639372a70f..d4143bd4ae 100644
--- a/tests/unit_tests/tags/commands/create_test.py
+++ b/tests/unit_tests/tags/commands/create_test.py
@@ -91,18 +91,16 @@ def test_create_command_success(session_with_data: Session, mocker: MockFixture)
         )
 
 
-def test_create_command_failed_validate(
-    session_with_data: Session, mocker: MockFixture
-):
+def test_create_command_success_clear(session_with_data: Session, mocker: MockFixture):
     from superset.connectors.sqla.models import SqlaTable
     from superset.daos.tag import TagDAO
     from superset.models.dashboard import Dashboard
     from superset.models.slice import Slice
     from superset.models.sql_lab import Query, SavedQuery
     from superset.tags.commands.create import CreateCustomTagWithRelationshipsCommand
-    from superset.tags.commands.exceptions import TagInvalidError
     from superset.tags.models import ObjectTypes, TaggedObject
 
+    # Define a list of objects to tag
     query = session_with_data.query(SavedQuery).first()
     chart = session_with_data.query(Slice).first()
     dashboard = session_with_data.query(Dashboard).first()
@@ -110,16 +108,22 @@ def test_create_command_failed_validate(
     mocker.patch(
         "superset.security.SupersetSecurityManager.is_admin", return_value=True
     )
-    mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=query)
-    mocker.patch("superset.daos.query.SavedQueryDAO.find_by_id", return_value=chart)
+    mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=chart)
+    mocker.patch("superset.daos.query.SavedQueryDAO.find_by_id", return_value=query)
 
     objects_to_tag = [
         (ObjectTypes.query, query.id),
         (ObjectTypes.chart, chart.id),
-        (ObjectTypes.dashboard, 0),
+        (ObjectTypes.dashboard, dashboard.id),
     ]
 
-    with pytest.raises(TagInvalidError):
-        CreateCustomTagWithRelationshipsCommand(
-            data={"name": "test_tag", "objects_to_tag": objects_to_tag}
-        ).run()
+    CreateCustomTagWithRelationshipsCommand(
+        data={"name": "test_tag", "objects_to_tag": objects_to_tag}
+    ).run()
+    assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
+
+    CreateCustomTagWithRelationshipsCommand(
+        data={"name": "test_tag", "objects_to_tag": []}
+    ).run()
+
+    assert len(session_with_data.query(TaggedObject).all()) == 0