You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by vi...@apache.org on 2023/07/28 12:07:46 UTC

[superset] 01/02: fix(permalink): migrate to marshmallow codec (#24166)

This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch 2.1
in repository https://gitbox.apache.org/repos/asf/superset.git

commit b7e8a84b9fb503118c2c8ff25c0645d5822428c6
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Mon May 22 13:35:58 2023 +0300

    fix(permalink): migrate to marshmallow codec (#24166)
---
 superset/dashboards/permalink/api.py               |   6 +-
 superset/dashboards/permalink/commands/base.py     |   9 +-
 superset/dashboards/permalink/commands/create.py   |   3 +
 superset/dashboards/permalink/commands/get.py      |   7 +-
 superset/dashboards/permalink/schemas.py           |  11 +-
 superset/explore/permalink/api.py                  |   6 +-
 superset/explore/permalink/commands/base.py        |   9 +-
 superset/explore/permalink/commands/create.py      |   3 +
 superset/explore/permalink/commands/get.py         |   7 +-
 superset/explore/permalink/schemas.py              |  26 ++++-
 superset/key_value/exceptions.py                   |  12 ++
 superset/key_value/types.py                        |  36 +++++-
 .../explore/permalink/api_tests.py                 |  16 ++-
 tests/unit_tests/key_value/codec_test.py           | 122 +++++++++++++++++++++
 14 files changed, 251 insertions(+), 22 deletions(-)

diff --git a/superset/dashboards/permalink/api.py b/superset/dashboards/permalink/api.py
index a8664f0ddd..d9211df2aa 100644
--- a/superset/dashboards/permalink/api.py
+++ b/superset/dashboards/permalink/api.py
@@ -30,7 +30,7 @@ from superset.dashboards.permalink.commands.create import (
 )
 from superset.dashboards.permalink.commands.get import GetDashboardPermalinkCommand
 from superset.dashboards.permalink.exceptions import DashboardPermalinkInvalidStateError
-from superset.dashboards.permalink.schemas import DashboardPermalinkPostSchema
+from superset.dashboards.permalink.schemas import DashboardPermalinkStateSchema
 from superset.extensions import event_logger
 from superset.key_value.exceptions import KeyValueAccessDeniedError
 from superset.views.base_api import BaseSupersetApi, requires_json
@@ -39,13 +39,13 @@ logger = logging.getLogger(__name__)
 
 
 class DashboardPermalinkRestApi(BaseSupersetApi):
-    add_model_schema = DashboardPermalinkPostSchema()
+    add_model_schema = DashboardPermalinkStateSchema()
     method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
     allow_browser_login = True
     class_permission_name = "DashboardPermalinkRestApi"
     resource_name = "dashboard"
     openapi_spec_tag = "Dashboard Permanent Link"
-    openapi_spec_component_schemas = (DashboardPermalinkPostSchema,)
+    openapi_spec_component_schemas = (DashboardPermalinkStateSchema,)
 
     @expose("/<pk>/permalink", methods=["POST"])
     @protect()
diff --git a/superset/dashboards/permalink/commands/base.py b/superset/dashboards/permalink/commands/base.py
index 82e24264ca..4bfb78ea26 100644
--- a/superset/dashboards/permalink/commands/base.py
+++ b/superset/dashboards/permalink/commands/base.py
@@ -17,13 +17,18 @@
 from abc import ABC
 
 from superset.commands.base import BaseCommand
+from superset.dashboards.permalink.schemas import DashboardPermalinkSchema
 from superset.key_value.shared_entries import get_permalink_salt
-from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey
+from superset.key_value.types import (
+    KeyValueResource,
+    MarshmallowKeyValueCodec,
+    SharedKey,
+)
 
 
 class BaseDashboardPermalinkCommand(BaseCommand, ABC):
     resource = KeyValueResource.DASHBOARD_PERMALINK
-    codec = JsonKeyValueCodec()
+    codec = MarshmallowKeyValueCodec(DashboardPermalinkSchema())
 
     @property
     def salt(self) -> str:
diff --git a/superset/dashboards/permalink/commands/create.py b/superset/dashboards/permalink/commands/create.py
index 2b6151fbb2..0487041070 100644
--- a/superset/dashboards/permalink/commands/create.py
+++ b/superset/dashboards/permalink/commands/create.py
@@ -23,6 +23,7 @@ from superset.dashboards.permalink.commands.base import BaseDashboardPermalinkCo
 from superset.dashboards.permalink.exceptions import DashboardPermalinkCreateFailedError
 from superset.dashboards.permalink.types import DashboardPermalinkState
 from superset.key_value.commands.upsert import UpsertKeyValueCommand
+from superset.key_value.exceptions import KeyValueCodecEncodeException
 from superset.key_value.utils import encode_permalink_key, get_deterministic_uuid
 from superset.utils.core import get_user_id
 
@@ -62,6 +63,8 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
             ).run()
             assert key.id  # for type checks
             return encode_permalink_key(key=key.id, salt=self.salt)
+        except KeyValueCodecEncodeException as ex:
+            raise DashboardPermalinkCreateFailedError(str(ex)) from ex
         except SQLAlchemyError as ex:
             logger.exception("Error running create command")
             raise DashboardPermalinkCreateFailedError() from ex
diff --git a/superset/dashboards/permalink/commands/get.py b/superset/dashboards/permalink/commands/get.py
index 4206263a37..da54ae0b66 100644
--- a/superset/dashboards/permalink/commands/get.py
+++ b/superset/dashboards/permalink/commands/get.py
@@ -25,7 +25,11 @@ from superset.dashboards.permalink.commands.base import BaseDashboardPermalinkCo
 from superset.dashboards.permalink.exceptions import DashboardPermalinkGetFailedError
 from superset.dashboards.permalink.types import DashboardPermalinkValue
 from superset.key_value.commands.get import GetKeyValueCommand
-from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError
+from superset.key_value.exceptions import (
+    KeyValueCodecDecodeException,
+    KeyValueGetFailedError,
+    KeyValueParseKeyError,
+)
 from superset.key_value.utils import decode_permalink_id
 
 logger = logging.getLogger(__name__)
@@ -51,6 +55,7 @@ class GetDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
             return None
         except (
             DashboardNotFoundError,
+            KeyValueCodecDecodeException,
             KeyValueGetFailedError,
             KeyValueParseKeyError,
         ) as ex:
diff --git a/superset/dashboards/permalink/schemas.py b/superset/dashboards/permalink/schemas.py
index ce222d7ed6..acbfec5a17 100644
--- a/superset/dashboards/permalink/schemas.py
+++ b/superset/dashboards/permalink/schemas.py
@@ -17,7 +17,7 @@
 from marshmallow import fields, Schema
 
 
-class DashboardPermalinkPostSchema(Schema):
+class DashboardPermalinkStateSchema(Schema):
     dataMask = fields.Dict(
         required=False,
         allow_none=True,
@@ -48,3 +48,12 @@ class DashboardPermalinkPostSchema(Schema):
         allow_none=True,
         description="Optional anchor link added to url hash",
     )
+
+
+class DashboardPermalinkSchema(Schema):  # codec schema for stored dashboard permalinks
+    dashboardId = fields.String(
+        required=True,
+        allow_none=False,
+        metadata={"description": "The id or slug of the dashboard"},
+    )
+    state = fields.Nested(DashboardPermalinkStateSchema())  # schema of the POST body
diff --git a/superset/explore/permalink/api.py b/superset/explore/permalink/api.py
index 88e819aa2b..2a8ff1998d 100644
--- a/superset/explore/permalink/api.py
+++ b/superset/explore/permalink/api.py
@@ -32,7 +32,7 @@ from superset.datasets.commands.exceptions import (
 from superset.explore.permalink.commands.create import CreateExplorePermalinkCommand
 from superset.explore.permalink.commands.get import GetExplorePermalinkCommand
 from superset.explore.permalink.exceptions import ExplorePermalinkInvalidStateError
-from superset.explore.permalink.schemas import ExplorePermalinkPostSchema
+from superset.explore.permalink.schemas import ExplorePermalinkStateSchema
 from superset.extensions import event_logger
 from superset.key_value.exceptions import KeyValueAccessDeniedError
 from superset.views.base_api import BaseSupersetApi, requires_json, statsd_metrics
@@ -41,13 +41,13 @@ logger = logging.getLogger(__name__)
 
 
 class ExplorePermalinkRestApi(BaseSupersetApi):
-    add_model_schema = ExplorePermalinkPostSchema()
+    add_model_schema = ExplorePermalinkStateSchema()
     method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
     allow_browser_login = True
     class_permission_name = "ExplorePermalinkRestApi"
     resource_name = "explore"
     openapi_spec_tag = "Explore Permanent Link"
-    openapi_spec_component_schemas = (ExplorePermalinkPostSchema,)
+    openapi_spec_component_schemas = (ExplorePermalinkStateSchema,)
 
     @expose("/permalink", methods=["POST"])
     @protect()
diff --git a/superset/explore/permalink/commands/base.py b/superset/explore/permalink/commands/base.py
index a87183b7e9..0b7cfbb8ec 100644
--- a/superset/explore/permalink/commands/base.py
+++ b/superset/explore/permalink/commands/base.py
@@ -17,13 +17,18 @@
 from abc import ABC
 
 from superset.commands.base import BaseCommand
+from superset.explore.permalink.schemas import ExplorePermalinkSchema
 from superset.key_value.shared_entries import get_permalink_salt
-from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey
+from superset.key_value.types import (
+    KeyValueResource,
+    MarshmallowKeyValueCodec,
+    SharedKey,
+)
 
 
 class BaseExplorePermalinkCommand(BaseCommand, ABC):
     resource: KeyValueResource = KeyValueResource.EXPLORE_PERMALINK
-    codec = JsonKeyValueCodec()
+    codec = MarshmallowKeyValueCodec(ExplorePermalinkSchema())
 
     @property
     def salt(self) -> str:
diff --git a/superset/explore/permalink/commands/create.py b/superset/explore/permalink/commands/create.py
index 21c0f4e42f..90e64f6df7 100644
--- a/superset/explore/permalink/commands/create.py
+++ b/superset/explore/permalink/commands/create.py
@@ -23,6 +23,7 @@ from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand
 from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError
 from superset.explore.utils import check_access as check_chart_access
 from superset.key_value.commands.create import CreateKeyValueCommand
+from superset.key_value.exceptions import KeyValueCodecEncodeException
 from superset.key_value.utils import encode_permalink_key
 from superset.utils.core import DatasourceType
 
@@ -58,6 +59,8 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
             if key.id is None:
                 raise ExplorePermalinkCreateFailedError("Unexpected missing key id")
             return encode_permalink_key(key=key.id, salt=self.salt)
+        except KeyValueCodecEncodeException as ex:
+            raise ExplorePermalinkCreateFailedError(str(ex)) from ex
         except SQLAlchemyError as ex:
             logger.exception("Error running create command")
             raise ExplorePermalinkCreateFailedError() from ex
diff --git a/superset/explore/permalink/commands/get.py b/superset/explore/permalink/commands/get.py
index 4823117ece..1aa093b380 100644
--- a/superset/explore/permalink/commands/get.py
+++ b/superset/explore/permalink/commands/get.py
@@ -25,7 +25,11 @@ from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
 from superset.explore.permalink.types import ExplorePermalinkValue
 from superset.explore.utils import check_access as check_chart_access
 from superset.key_value.commands.get import GetKeyValueCommand
-from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError
+from superset.key_value.exceptions import (
+    KeyValueCodecDecodeException,
+    KeyValueGetFailedError,
+    KeyValueParseKeyError,
+)
 from superset.key_value.utils import decode_permalink_id
 from superset.utils.core import DatasourceType
 
@@ -59,6 +63,7 @@ class GetExplorePermalinkCommand(BaseExplorePermalinkCommand):
             return None
         except (
             DatasetNotFoundError,
+            KeyValueCodecDecodeException,
             KeyValueGetFailedError,
             KeyValueParseKeyError,
         ) as ex:
diff --git a/superset/explore/permalink/schemas.py b/superset/explore/permalink/schemas.py
index e1f9d069b8..8b1ae129e8 100644
--- a/superset/explore/permalink/schemas.py
+++ b/superset/explore/permalink/schemas.py
@@ -17,7 +17,7 @@
 from marshmallow import fields, Schema
 
 
-class ExplorePermalinkPostSchema(Schema):
+class ExplorePermalinkStateSchema(Schema):
     formData = fields.Dict(
         required=True,
         allow_none=False,
@@ -37,3 +37,27 @@ class ExplorePermalinkPostSchema(Schema):
         allow_none=True,
         description="URL Parameters",
     )
+
+
+class ExplorePermalinkSchema(Schema):  # codec schema for stored explore permalinks
+    chartId = fields.Integer(
+        required=False,
+        allow_none=True,
+        metadata={"description": "The id of the chart"},
+    )
+    datasourceType = fields.String(
+        required=True,
+        allow_none=False,
+        metadata={"description": "The type of the datasource"},
+    )
+    datasourceId = fields.Integer(
+        required=False,
+        allow_none=True,
+        metadata={"description": "The id of the datasource"},
+    )
+    datasource = fields.String(
+        required=False,
+        allow_none=True,
+        metadata={"description": "The fully qualified datasource reference"},
+    )
+    state = fields.Nested(ExplorePermalinkStateSchema())  # schema of the POST body
diff --git a/superset/key_value/exceptions.py b/superset/key_value/exceptions.py
index b05daf6b89..e16f961872 100644
--- a/superset/key_value/exceptions.py
+++ b/superset/key_value/exceptions.py
@@ -52,3 +52,15 @@ class KeyValueUpsertFailedError(UpdateFailedError):
 
 class KeyValueAccessDeniedError(ForbiddenError):
     message = _("You don't have permission to modify the value.")
+
+
+class KeyValueCodecException(SupersetException):  # common base for codec errors
+    pass
+
+
+class KeyValueCodecEncodeException(KeyValueCodecException):  # raised by encode()
+    message = _("Unable to encode value")
+
+
+class KeyValueCodecDecodeException(KeyValueCodecException):  # raised by decode()
+    message = _("Unable to decode value")
diff --git a/superset/key_value/types.py b/superset/key_value/types.py
index 07d06414f6..fb9c31899f 100644
--- a/superset/key_value/types.py
+++ b/superset/key_value/types.py
@@ -24,6 +24,13 @@ from enum import Enum
 from typing import Any, Optional, TypedDict
 from uuid import UUID
 
+from marshmallow import Schema, ValidationError
+
+from superset.key_value.exceptions import (
+    KeyValueCodecDecodeException,
+    KeyValueCodecEncodeException,
+)
+
 
 @dataclass
 class Key:
@@ -61,10 +68,16 @@ class KeyValueCodec(ABC):
 
 class JsonKeyValueCodec(KeyValueCodec):
     def encode(self, value: dict[Any, Any]) -> bytes:
-        return bytes(json.dumps(value), encoding="utf-8")
+        try:
+            return bytes(json.dumps(value), encoding="utf-8")
+        except (TypeError, ValueError) as ex:  # ValueError: circular reference
+            raise KeyValueCodecEncodeException(str(ex)) from ex
 
     def decode(self, value: bytes) -> dict[Any, Any]:
-        return json.loads(value)
+        try:
+            return json.loads(value)
+        except (TypeError, ValueError) as ex:  # JSONDecodeError is a ValueError
+            raise KeyValueCodecDecodeException(str(ex)) from ex
 
 
 class PickleKeyValueCodec(KeyValueCodec):
@@ -73,3 +86,22 @@ class PickleKeyValueCodec(KeyValueCodec):
 
     def decode(self, value: bytes) -> dict[Any, Any]:
         return pickle.loads(value)
+
+
+class MarshmallowKeyValueCodec(JsonKeyValueCodec):  # schema-aware JSON codec
+    def __init__(self, schema: Schema):
+        self.schema = schema
+
+    def encode(self, value: dict[Any, Any]) -> bytes:
+        try:
+            obj = self.schema.dump(value)  # dump() does not validate
+            return super().encode(obj)
+        except ValidationError as ex:
+            raise KeyValueCodecEncodeException(message=str(ex)) from ex
+
+    def decode(self, value: bytes) -> dict[Any, Any]:
+        try:
+            obj = super().decode(value)
+            return self.schema.load(obj)
+        except ValidationError as ex:
+            raise KeyValueCodecDecodeException(message=str(ex)) from ex
diff --git a/tests/integration_tests/explore/permalink/api_tests.py b/tests/integration_tests/explore/permalink/api_tests.py
index 4c6a3c12dd..3a07bd977a 100644
--- a/tests/integration_tests/explore/permalink/api_tests.py
+++ b/tests/integration_tests/explore/permalink/api_tests.py
@@ -22,8 +22,9 @@ import pytest
 from sqlalchemy.orm import Session
 
 from superset import db
+from superset.explore.permalink.schemas import ExplorePermalinkSchema
 from superset.key_value.models import KeyValueEntry
-from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
+from superset.key_value.types import KeyValueResource, MarshmallowKeyValueCodec
 from superset.key_value.utils import decode_permalink_id, encode_permalink_key
 from superset.models.slice import Slice
 from superset.utils.core import DatasourceType
@@ -94,14 +95,17 @@ def test_get_missing_chart(
     chart_id = 1234
     entry = KeyValueEntry(
         resource=KeyValueResource.EXPLORE_PERMALINK,
-        value=JsonKeyValueCodec().encode(
+        value=MarshmallowKeyValueCodec(ExplorePermalinkSchema()).encode(
             {
                 "chartId": chart_id,
                 "datasourceId": chart.datasource.id,
-                "datasourceType": DatasourceType.TABLE,
-                "formData": {
-                    "slice_id": chart_id,
-                    "datasource": f"{chart.datasource.id}__{chart.datasource.type}",
+                "datasourceType": DatasourceType.TABLE.value,
+                "state": {
+                    "urlParams": [["foo", "bar"]],
+                    "formData": {
+                        "slice_id": chart_id,
+                        "datasource": f"{chart.datasource.id}__{chart.datasource.type}",
+                    },
                 },
             }
         ),
diff --git a/tests/unit_tests/key_value/codec_test.py b/tests/unit_tests/key_value/codec_test.py
new file mode 100644
index 0000000000..1442a3a95a
--- /dev/null
+++ b/tests/unit_tests/key_value/codec_test.py
@@ -0,0 +1,126 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from contextlib import nullcontext
+from typing import Any
+
+import pytest
+from marshmallow import Schema
+
+from superset.dashboards.permalink.schemas import DashboardPermalinkSchema
+from superset.key_value.exceptions import (
+    KeyValueCodecDecodeException,
+    KeyValueCodecEncodeException,
+)
+from superset.key_value.types import (
+    JsonKeyValueCodec,
+    MarshmallowKeyValueCodec,
+    PickleKeyValueCodec,
+)
+
+
+@pytest.mark.parametrize(
+    "input_,expected_result",
+    [
+        (
+            {"foo": "bar"},
+            {"foo": "bar"},
+        ),
+        (
+            {"foo": (1, 2, 3)},
+            {"foo": [1, 2, 3]},
+        ),
+        (
+            {1, 2, 3},
+            KeyValueCodecEncodeException(),
+        ),
+        (
+            object(),
+            KeyValueCodecEncodeException(),
+        ),
+    ],
+)
+def test_json_codec(input_: Any, expected_result: Any):
+    cm = (
+        pytest.raises(type(expected_result))
+        if isinstance(expected_result, Exception)
+        else nullcontext()
+    )
+    with cm:
+        codec = JsonKeyValueCodec()
+        encoded_value = codec.encode(input_)
+        assert expected_result == codec.decode(encoded_value)
+
+
+@pytest.mark.parametrize(
+    "schema,input_,expected_result",
+    [
+        (
+            DashboardPermalinkSchema(),
+            {
+                "dashboardId": "1",
+                "state": {
+                    "urlParams": [["foo", "bar"], ["foo", "baz"]],
+                },
+            },
+            {
+                "dashboardId": "1",
+                "state": {
+                    "urlParams": [("foo", "bar"), ("foo", "baz")],
+                },
+            },
+        ),
+        (
+            DashboardPermalinkSchema(),
+            {"foo": "bar"},
+            # dump() does not validate, so this value fails on decode
+            KeyValueCodecDecodeException(),
+        ),
+    ],
+)
+def test_marshmallow_codec(schema: Schema, input_: Any, expected_result: Any):
+    cm = (
+        pytest.raises(type(expected_result))
+        if isinstance(expected_result, Exception)
+        else nullcontext()
+    )
+    with cm:
+        codec = MarshmallowKeyValueCodec(schema)
+        encoded_value = codec.encode(input_)
+        assert expected_result == codec.decode(encoded_value)
+
+
+@pytest.mark.parametrize(
+    "input_,expected_result",
+    [
+        (
+            {1, 2, 3},
+            {1, 2, 3},
+        ),
+        (
+            {"foo": 1, "bar": {1: (1, 2, 3)}, "baz": {1, 2, 3}},
+            {
+                "foo": 1,
+                "bar": {1: (1, 2, 3)},
+                "baz": {1, 2, 3},
+            },
+        ),
+    ],
+)
+def test_pickle_codec(input_: Any, expected_result: Any):
+    codec = PickleKeyValueCodec()
+    encoded_value = codec.encode(input_)
+    assert expected_result == codec.decode(encoded_value)