You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by hu...@apache.org on 2023/10/05 17:38:02 UTC
[superset] branch master updated: fix(tags): fix clears delete on Tags Modal (#25470)
This is an automated email from the ASF dual-hosted git repository.
hugh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new dcfebfce9d fix(tags): fix clears delete on Tags Modal (#25470)
dcfebfce9d is described below
commit dcfebfce9d0c3f9e249fb4146edaab2a11b77734
Author: Hugh A. Miles II <hu...@gmail.com>
AuthorDate: Thu Oct 5 13:37:53 2023 -0400
fix(tags): fix clears delete on Tags Modal (#25470)
Co-authored-by: Beto Dealmeida <ro...@dealmeida.net>
---
superset/daos/tag.py | 7 ++-
superset/tags/commands/create.py | 61 +++++++++++----------------
superset/tags/commands/update.py | 15 +++----
superset/tags/schemas.py | 4 +-
tests/unit_tests/tags/commands/create_test.py | 26 +++++++-----
5 files changed, 53 insertions(+), 60 deletions(-)
diff --git a/superset/daos/tag.py b/superset/daos/tag.py
index b6872a5376..2acd221a35 100644
--- a/superset/daos/tag.py
+++ b/superset/daos/tag.py
@@ -390,7 +390,12 @@ class TagDAO(BaseDAO[Tag]):
updated_tagged_objects = {
(to_object_type(obj[0]), obj[1]) for obj in objects_to_tag
}
- tagged_objects_to_delete = current_tagged_objects - updated_tagged_objects
+
+ tagged_objects_to_delete = (
+ current_tagged_objects
+ if not objects_to_tag
+ else current_tagged_objects - updated_tagged_objects
+ )
for object_type, object_id in updated_tagged_objects:
# create rows for new objects, and skip tags that already exist
diff --git a/superset/tags/commands/create.py b/superset/tags/commands/create.py
index e8311ad520..883c498bc3 100644
--- a/superset/tags/commands/create.py
+++ b/superset/tags/commands/create.py
@@ -67,25 +67,22 @@ class CreateCustomTagCommand(CreateMixin, BaseCommand):
class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand):
def __init__(self, data: dict[str, Any], bulk_create: bool = False):
- self._tag = data["name"]
- self._objects_to_tag = data.get("objects_to_tag")
- self._description = data.get("description")
+ self._properties = data.copy()
self._bulk_create = bulk_create
def run(self) -> None:
self.validate()
try:
- tag = TagDAO.get_by_name(self._tag.strip(), TagTypes.custom)
- if self._objects_to_tag:
- TagDAO.create_tag_relationship(
- objects_to_tag=self._objects_to_tag,
- tag=tag,
- bulk_create=self._bulk_create,
- )
+ tag_name = self._properties["name"]
+ tag = TagDAO.get_by_name(tag_name.strip(), TagTypes.custom)
+ TagDAO.create_tag_relationship(
+ objects_to_tag=self._properties.get("objects_to_tag", []),
+ tag=tag,
+ bulk_create=self._bulk_create,
+ )
- if self._description:
- tag.description = self._description
+ tag.description = self._properties.get("description", "")
db.session.commit()
@@ -95,31 +92,21 @@ class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand):
def validate(self) -> None:
exceptions = []
- # Validate object_id
- if self._objects_to_tag:
- if any(obj_id == 0 for obj_type, obj_id in self._objects_to_tag):
- exceptions.append(TagInvalidError())
-
- # Validate object type
- skipped_tagged_objects: list[tuple[str, int]] = []
- for obj_type, obj_id in self._objects_to_tag:
- skipped_tagged_objects = []
- object_type = to_object_type(obj_type)
-
- if not object_type:
- exceptions.append(
- TagInvalidError(f"invalid object type {object_type}")
- )
- try:
- model = to_object_model(object_type, obj_id) # type: ignore
- security_manager.raise_for_ownership(model)
- except SupersetSecurityException:
- # skip the object if the user doesn't have access
- skipped_tagged_objects.append((obj_type, obj_id))
-
- self._objects_to_tag = set(self._objects_to_tag) - set(
- skipped_tagged_objects
- )
+ objects_to_tag = set(self._properties.get("objects_to_tag", []))
+ skipped_tagged_objects: set[tuple[str, int]] = set()
+ for obj_type, obj_id in objects_to_tag:
+ object_type = to_object_type(obj_type)
+
+ if not object_type:
+ exceptions.append(TagInvalidError(f"invalid object type {object_type}"))
+ try:
+ model = to_object_model(object_type, obj_id) # type: ignore
+ security_manager.raise_for_ownership(model)
+ except SupersetSecurityException:
+ # skip the object if the user doesn't have access
+ skipped_tagged_objects.add((obj_type, obj_id))
+
+ self._properties["objects_to_tag"] = objects_to_tag - skipped_tagged_objects
if exceptions:
raise TagInvalidError(exceptions=exceptions)
diff --git a/superset/tags/commands/update.py b/superset/tags/commands/update.py
index a13e4e8e7b..cc1c9a2be7 100644
--- a/superset/tags/commands/update.py
+++ b/superset/tags/commands/update.py
@@ -38,12 +38,10 @@ class UpdateTagCommand(UpdateMixin, BaseCommand):
def run(self) -> Model:
self.validate()
if self._model:
- if self._properties.get("objects_to_tag"):
- # todo(hugh): can this manage duplication
- TagDAO.create_tag_relationship(
- objects_to_tag=self._properties["objects_to_tag"],
- tag=self._model,
- )
+ TagDAO.create_tag_relationship(
+ objects_to_tag=self._properties.get("objects_to_tag", []),
+ tag=self._model,
+ )
if description := self._properties.get("description"):
self._model.description = description
if tag_name := self._properties.get("name"):
@@ -63,11 +61,8 @@ class UpdateTagCommand(UpdateMixin, BaseCommand):
# Validate object_id
if objects_to_tag := self._properties.get("objects_to_tag"):
- if any(obj_id == 0 for obj_type, obj_id in objects_to_tag):
- exceptions.append(TagInvalidError(" invalid object_id"))
-
# Validate object type
- for obj_type, obj_id in objects_to_tag:
+ for obj_type, _ in objects_to_tag:
object_type = to_object_type(obj_type)
if not object_type:
exceptions.append(
diff --git a/superset/tags/schemas.py b/superset/tags/schemas.py
index 75fdc2410a..a391fd2b80 100644
--- a/superset/tags/schemas.py
+++ b/superset/tags/schemas.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from marshmallow import fields, Schema
+from marshmallow.validate import Range
from superset.dashboards.schemas import UserSchema
@@ -60,7 +61,8 @@ class TagObjectSchema(Schema):
name = fields.String()
description = fields.String(required=False, allow_none=True)
objects_to_tag = fields.List(
- fields.Tuple((fields.String(), fields.Int())), required=False
+ fields.Tuple((fields.String(), fields.Int(validate=Range(min=1)))),
+ required=False,
)
diff --git a/tests/unit_tests/tags/commands/create_test.py b/tests/unit_tests/tags/commands/create_test.py
index 639372a70f..d4143bd4ae 100644
--- a/tests/unit_tests/tags/commands/create_test.py
+++ b/tests/unit_tests/tags/commands/create_test.py
@@ -91,18 +91,16 @@ def test_create_command_success(session_with_data: Session, mocker: MockFixture)
)
-def test_create_command_failed_validate(
- session_with_data: Session, mocker: MockFixture
-):
+def test_create_command_success_clear(session_with_data: Session, mocker: MockFixture):
from superset.connectors.sqla.models import SqlaTable
from superset.daos.tag import TagDAO
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import Query, SavedQuery
from superset.tags.commands.create import CreateCustomTagWithRelationshipsCommand
- from superset.tags.commands.exceptions import TagInvalidError
from superset.tags.models import ObjectTypes, TaggedObject
+ # Define a list of objects to tag
query = session_with_data.query(SavedQuery).first()
chart = session_with_data.query(Slice).first()
dashboard = session_with_data.query(Dashboard).first()
@@ -110,16 +108,22 @@ def test_create_command_failed_validate(
mocker.patch(
"superset.security.SupersetSecurityManager.is_admin", return_value=True
)
- mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=query)
- mocker.patch("superset.daos.query.SavedQueryDAO.find_by_id", return_value=chart)
+ mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=chart)
+ mocker.patch("superset.daos.query.SavedQueryDAO.find_by_id", return_value=query)
objects_to_tag = [
(ObjectTypes.query, query.id),
(ObjectTypes.chart, chart.id),
- (ObjectTypes.dashboard, 0),
+ (ObjectTypes.dashboard, dashboard.id),
]
- with pytest.raises(TagInvalidError):
- CreateCustomTagWithRelationshipsCommand(
- data={"name": "test_tag", "objects_to_tag": objects_to_tag}
- ).run()
+ CreateCustomTagWithRelationshipsCommand(
+ data={"name": "test_tag", "objects_to_tag": objects_to_tag}
+ ).run()
+ assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag)
+
+ CreateCustomTagWithRelationshipsCommand(
+ data={"name": "test_tag", "objects_to_tag": []}
+ ).run()
+
+ assert len(session_with_data.query(TaggedObject).all()) == 0