You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2020/11/20 22:20:35 UTC

[incubator-superset] branch master updated: feat: add a command to import charts (#11743)

This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 2f4f877  feat: add a command to import charts (#11743)
2f4f877 is described below

commit 2f4f87795d10c3052b66a6aaface2235d59355ed
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Fri Nov 20 14:20:13 2020 -0800

    feat: add a command to import charts (#11743)
    
    * ImportChartsCommand
    
    * Fix type
---
 superset/charts/commands/importers/__init__.py    |  16 +++
 superset/charts/commands/importers/v1/__init__.py | 146 ++++++++++++++++++++
 superset/charts/commands/importers/v1/utils.py    |  42 ++++++
 superset/charts/schemas.py                        |   8 ++
 superset/models/helpers.py                        |  10 +-
 superset/models/slice.py                          |   1 +
 tests/charts/commands_tests.py                    | 157 +++++++++++++++++++++-
 tests/datasets/commands_tests.py                  |  33 ++++-
 tests/fixtures/importexport.py                    |  47 +++++++
 9 files changed, 449 insertions(+), 11 deletions(-)

diff --git a/superset/charts/commands/importers/__init__.py b/superset/charts/commands/importers/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/superset/charts/commands/importers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/superset/charts/commands/importers/v1/__init__.py b/superset/charts/commands/importers/v1/__init__.py
new file mode 100644
index 0000000..086a370
--- /dev/null
+++ b/superset/charts/commands/importers/v1/__init__.py
@@ -0,0 +1,146 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Any, Dict, List, Optional, Set
+
+from marshmallow import Schema, validate
+from marshmallow.exceptions import ValidationError
+from sqlalchemy.orm import Session
+
+from superset import db
+from superset.charts.commands.importers.v1.utils import import_chart
+from superset.charts.schemas import ImportV1ChartSchema
+from superset.commands.base import BaseCommand
+from superset.commands.exceptions import CommandInvalidError
+from superset.commands.importers.v1.utils import (
+    load_metadata,
+    load_yaml,
+    METADATA_FILE_NAME,
+)
+from superset.databases.commands.importers.v1.utils import import_database
+from superset.databases.schemas import ImportV1DatabaseSchema
+from superset.datasets.commands.importers.v1.utils import import_dataset
+from superset.datasets.schemas import ImportV1DatasetSchema
+from superset.models.slice import Slice
+
+schemas: Dict[str, Schema] = {
+    "charts/": ImportV1ChartSchema(),
+    "datasets/": ImportV1DatasetSchema(),
+    "databases/": ImportV1DatabaseSchema(),
+}
+
+
+class ImportChartsCommand(BaseCommand):
+
+    """Import charts"""
+
+    # pylint: disable=unused-argument
+    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+        self.contents = contents
+        self._configs: Dict[str, Any] = {}
+
+    def _import_bundle(self, session: Session) -> None:
+        # discover datasets associated with charts
+        dataset_uuids: Set[str] = set()
+        for file_name, config in self._configs.items():
+            if file_name.startswith("charts/"):
+                dataset_uuids.add(config["dataset_uuid"])
+
+        # discover databases associated with datasets
+        database_uuids: Set[str] = set()
+        for file_name, config in self._configs.items():
+            if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids:
+                database_uuids.add(config["database_uuid"])
+
+        # import related databases
+        database_ids: Dict[str, int] = {}
+        for file_name, config in self._configs.items():
+            if file_name.startswith("databases/") and config["uuid"] in database_uuids:
+                database = import_database(session, config, overwrite=False)
+                database_ids[str(database.uuid)] = database.id
+
+        # import datasets with the correct parent ref
+        dataset_info: Dict[str, Dict[str, Any]] = {}
+        for file_name, config in self._configs.items():
+            if (
+                file_name.startswith("datasets/")
+                and config["database_uuid"] in database_ids
+            ):
+                config["database_id"] = database_ids[config["database_uuid"]]
+                dataset = import_dataset(session, config, overwrite=False)
+                dataset_info[str(dataset.uuid)] = {
+                    "datasource_id": dataset.id,
+                    "datasource_type": "view" if dataset.is_sqllab_view else "table",
+                    "datasource_name": dataset.table_name,
+                }
+
+        # import charts with the correct parent ref
+        for file_name, config in self._configs.items():
+            if (
+                file_name.startswith("charts/")
+                and config["dataset_uuid"] in dataset_info
+            ):
+                # update datasource id, type, and name
+                config.update(dataset_info[config["dataset_uuid"]])
+                import_chart(session, config, overwrite=True)
+
+    def run(self) -> None:
+        self.validate()
+
+        # rollback to prevent partial imports
+        try:
+            self._import_bundle(db.session)
+            db.session.commit()
+        except Exception as exc:
+            db.session.rollback()
+            raise exc
+
+    def validate(self) -> None:
+        exceptions: List[ValidationError] = []
+
+        # verify that the metadata file is present and valid
+        try:
+            metadata: Optional[Dict[str, str]] = load_metadata(self.contents)
+        except ValidationError as exc:
+            exceptions.append(exc)
+            metadata = None
+
+        for file_name, content in self.contents.items():
+            prefix = file_name.split("/")[0]
+            schema = schemas.get(f"{prefix}/")
+            if schema:
+                try:
+                    config = load_yaml(file_name, content)
+                    schema.load(config)
+                    self._configs[file_name] = config
+                except ValidationError as exc:
+                    exc.messages = {file_name: exc.messages}
+                    exceptions.append(exc)
+
+        # validate that the type declared in METADATA_FILE_NAME is correct
+        if metadata:
+            type_validator = validate.Equal(Slice.__name__)
+            try:
+                type_validator(metadata["type"])
+            except ValidationError as exc:
+                exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}}
+                exceptions.append(exc)
+
+        if exceptions:
+            exception = CommandInvalidError("Error importing chart")
+            exception.add_list(exceptions)
+            raise exception
diff --git a/superset/charts/commands/importers/v1/utils.py b/superset/charts/commands/importers/v1/utils.py
new file mode 100644
index 0000000..b3d4237
--- /dev/null
+++ b/superset/charts/commands/importers/v1/utils.py
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+from typing import Any, Dict
+
+from sqlalchemy.orm import Session
+
+from superset.models.slice import Slice
+
+
+def import_chart(
+    session: Session, config: Dict[str, Any], overwrite: bool = False
+) -> Slice:
+    existing = session.query(Slice).filter_by(uuid=config["uuid"]).first()
+    if existing:
+        if not overwrite:
+            return existing
+        config["id"] = existing.id
+
+    # TODO (betodealmeida): move this logic to import_from_dict
+    config["params"] = json.dumps(config["params"])
+
+    chart = Slice.import_from_dict(session, config, recursive=False)
+    if chart.id is None:
+        session.flush()
+
+    return chart
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 0997038..ca1497a 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -1118,6 +1118,14 @@ class GetFavStarIdsSchema(Schema):
     )
 
 
+class ImportV1ChartSchema(Schema):
+    params = fields.Dict()
+    cache_timeout = fields.Integer(allow_none=True)
+    uuid = fields.UUID(required=True)
+    version = fields.String(required=True)
+    dataset_uuid = fields.UUID(required=True)
+
+
 CHART_SCHEMAS = (
     ChartDataQueryContextSchema,
     ChartDataResponseSchema,
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index f0e5c0b..623bb07 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -88,14 +88,6 @@ class ImportExportMixin:
     __mapper__: Mapper
 
     @classmethod
-    def _parent_foreign_key_mappings(cls) -> Dict[str, str]:
-        """Get a mapping of foreign name to the local name of foreign keys"""
-        parent_rel = cls.__mapper__.relationships.get(cls.export_parent)
-        if parent_rel:
-            return {l.name: r.name for (l, r) in parent_rel.local_remote_pairs}
-        return {}
-
-    @classmethod
     def _unique_constrains(cls) -> List[Set[str]]:
         """Get all (single column and multi column) unique constraints"""
         unique = [
@@ -171,7 +163,7 @@ class ImportExportMixin:
 
         # Remove fields that should not get imported
         for k in list(dict_rep):
-            if k not in export_fields:
+            if k not in export_fields and k not in parent_refs:
                 del dict_rep[k]
 
         if not parent:
diff --git a/superset/models/slice.py b/superset/models/slice.py
index 7254652..2fd55a7 100644
--- a/superset/models/slice.py
+++ b/superset/models/slice.py
@@ -93,6 +93,7 @@ class Slice(
         "params",
         "cache_timeout",
     ]
+    export_parent = "table"
 
     def __repr__(self) -> str:
         return self.slice_name or str(self.id)
diff --git a/tests/charts/commands_tests.py b/tests/charts/commands_tests.py
index 9189d4c..8523b83 100644
--- a/tests/charts/commands_tests.py
+++ b/tests/charts/commands_tests.py
@@ -14,16 +14,31 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=no-self-use, invalid-name
 
+import json
 from unittest.mock import patch
 
+import pytest
 import yaml
 
 from superset import db, security_manager
 from superset.charts.commands.exceptions import ChartNotFoundError
 from superset.charts.commands.export import ExportChartsCommand
+from superset.charts.commands.importers.v1 import ImportChartsCommand
+from superset.commands.exceptions import CommandInvalidError
+from superset.commands.importers.exceptions import IncorrectVersionError
+from superset.connectors.sqla.models import SqlaTable
+from superset.models.core import Database
 from superset.models.slice import Slice
 from tests.base_tests import SupersetTestCase
+from tests.fixtures.importexport import (
+    chart_config,
+    chart_metadata_config,
+    database_config,
+    database_metadata_config,
+    dataset_config,
+)
 
 
 class TestExportChartsCommand(SupersetTestCase):
@@ -49,7 +64,7 @@ class TestExportChartsCommand(SupersetTestCase):
             "viz_type": "sankey",
             "params": {
                 "collapsed_fieldsets": "",
-                "groupby": ["source", "target",],
+                "groupby": ["source", "target"],
                 "metric": "sum__value",
                 "row_limit": "5000",
                 "slice_name": "Energy Sankey",
@@ -100,3 +115,143 @@ class TestExportChartsCommand(SupersetTestCase):
             "version",
             "dataset_uuid",
         ]
+
+    def test_import_v1_chart(self):
+        """Test that we can import a chart"""
+        contents = {
+            "metadata.yaml": yaml.safe_dump(chart_metadata_config),
+            "databases/imported_database.yaml": yaml.safe_dump(database_config),
+            "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
+            "charts/imported_chart.yaml": yaml.safe_dump(chart_config),
+        }
+        command = ImportChartsCommand(contents)
+        command.run()
+
+        chart = db.session.query(Slice).filter_by(uuid=chart_config["uuid"]).one()
+        assert json.loads(chart.params) == {
+            "color_picker": {"a": 1, "b": 135, "g": 122, "r": 0},
+            "datasource": "12__table",
+            "js_columns": ["color"],
+            "js_data_mutator": "data => data.map(d => ({\\n    ...d,\\n    color: colors.hexToRGB(d.extraProps.color)\\n}));",
+            "js_onclick_href": "",
+            "js_tooltip": "",
+            "line_column": "path_json",
+            "line_type": "json",
+            "line_width": 150,
+            "mapbox_style": "mapbox://styles/mapbox/light-v9",
+            "reverse_long_lat": False,
+            "row_limit": 5000,
+            "slice_id": 43,
+            "time_grain_sqla": None,
+            "time_range": " : ",
+            "viewport": {
+                "altitude": 1.5,
+                "bearing": 0,
+                "height": 1094,
+                "latitude": 37.73671752604488,
+                "longitude": -122.18885402582598,
+                "maxLatitude": 85.05113,
+                "maxPitch": 60,
+                "maxZoom": 20,
+                "minLatitude": -85.05113,
+                "minPitch": 0,
+                "minZoom": 0,
+                "pitch": 0,
+                "width": 669,
+                "zoom": 9.51847667620428,
+            },
+            "viz_type": "deck_path",
+        }
+
+        dataset = (
+            db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one()
+        )
+        assert dataset.table_name == "imported_dataset"
+        assert chart.table == dataset
+
+        database = (
+            db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
+        )
+        assert database.database_name == "imported_database"
+        assert chart.table.database == database
+
+        db.session.delete(chart)
+        db.session.delete(dataset)
+        db.session.delete(database)
+        db.session.commit()
+
+    def test_import_v1_chart_multiple(self):
+        """Test that a chart can be imported multiple times"""
+        contents = {
+            "metadata.yaml": yaml.safe_dump(chart_metadata_config),
+            "databases/imported_database.yaml": yaml.safe_dump(database_config),
+            "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
+            "charts/imported_chart.yaml": yaml.safe_dump(chart_config),
+        }
+        command = ImportChartsCommand(contents)
+        command.run()
+        command.run()
+
+        dataset = (
+            db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one()
+        )
+        charts = db.session.query(Slice).filter_by(datasource_id=dataset.id).all()
+        assert len(charts) == 1
+
+        database = dataset.database
+
+        db.session.delete(charts[0])
+        db.session.delete(dataset)
+        db.session.delete(database)
+        db.session.commit()
+
+    def test_import_v1_chart_validation(self):
+        """Test different validations applied when importing a chart"""
+        # metadata.yaml must be present
+        contents = {
+            "databases/imported_database.yaml": yaml.safe_dump(database_config),
+            "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
+            "charts/imported_chart.yaml": yaml.safe_dump(chart_config),
+        }
+        command = ImportChartsCommand(contents)
+        with pytest.raises(IncorrectVersionError) as excinfo:
+            command.run()
+        assert str(excinfo.value) == "Missing metadata.yaml"
+
+        # version should be 1.0.0
+        contents["metadata.yaml"] = yaml.safe_dump(
+            {
+                "version": "2.0.0",
+                "type": "SqlaTable",
+                "timestamp": "2020-11-04T21:27:44.423819+00:00",
+            }
+        )
+        command = ImportChartsCommand(contents)
+        with pytest.raises(IncorrectVersionError) as excinfo:
+            command.run()
+        assert str(excinfo.value) == "Must be equal to 1.0.0."
+
+        # type should be Slice
+        contents["metadata.yaml"] = yaml.safe_dump(database_metadata_config)
+        command = ImportChartsCommand(contents)
+        with pytest.raises(CommandInvalidError) as excinfo:
+            command.run()
+        assert str(excinfo.value) == "Error importing chart"
+        assert excinfo.value.normalized_messages() == {
+            "metadata.yaml": {"type": ["Must be equal to Slice."]}
+        }
+
+        # must also validate datasets and databases
+        broken_config = database_config.copy()
+        del broken_config["database_name"]
+        contents["metadata.yaml"] = yaml.safe_dump(chart_metadata_config)
+        contents["databases/imported_database.yaml"] = yaml.safe_dump(broken_config)
+        command = ImportChartsCommand(contents)
+        with pytest.raises(CommandInvalidError) as excinfo:
+            command.run()
+        assert str(excinfo.value) == "Error importing chart"
+        assert excinfo.value.normalized_messages() == {
+            "databases/imported_database.yaml": {
+                "database_name": ["Missing data for required field."],
+            }
+        }
diff --git a/tests/datasets/commands_tests.py b/tests/datasets/commands_tests.py
index 94b5e79..a957ffc 100644
--- a/tests/datasets/commands_tests.py
+++ b/tests/datasets/commands_tests.py
@@ -26,9 +26,11 @@ from superset import db, security_manager
 from superset.commands.exceptions import CommandInvalidError
 from superset.commands.importers.exceptions import IncorrectVersionError
 from superset.connectors.sqla.models import SqlaTable
+from superset.databases.commands.importers.v1 import ImportDatabasesCommand
 from superset.datasets.commands.exceptions import DatasetNotFoundError
 from superset.datasets.commands.export import ExportDatasetsCommand
 from superset.datasets.commands.importers.v1 import ImportDatasetsCommand
+from superset.models.core import Database
 from superset.utils.core import get_example_database
 from tests.base_tests import SupersetTestCase
 from tests.fixtures.importexport import (
@@ -326,7 +328,7 @@ class TestExportDatasetsCommand(SupersetTestCase):
             command.run()
         assert str(excinfo.value) == "Error importing dataset"
         assert excinfo.value.normalized_messages() == {
-            "metadata.yaml": {"type": ["Must be equal to SqlaTable."],}
+            "metadata.yaml": {"type": ["Must be equal to SqlaTable."]}
         }
 
         # must also validate databases
@@ -343,3 +345,32 @@ class TestExportDatasetsCommand(SupersetTestCase):
                 "database_name": ["Missing data for required field."],
             }
         }
+
+    def test_import_v1_dataset_existing_database(self):
+        """Test that a dataset can be imported when the database already exists"""
+        # first import database...
+        contents = {
+            "metadata.yaml": yaml.safe_dump(database_metadata_config),
+            "databases/imported_database.yaml": yaml.safe_dump(database_config),
+        }
+        command = ImportDatabasesCommand(contents)
+        command.run()
+
+        database = (
+            db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
+        )
+        assert len(database.tables) == 0
+
+        # ...then dataset
+        contents = {
+            "metadata.yaml": yaml.safe_dump(dataset_metadata_config),
+            "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
+            "databases/imported_database.yaml": yaml.safe_dump(database_config),
+        }
+        command = ImportDatasetsCommand(contents)
+        command.run()
+
+        database = (
+            db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
+        )
+        assert len(database.tables) == 1
diff --git a/tests/fixtures/importexport.py b/tests/fixtures/importexport.py
index 8b64004..8312a81 100644
--- a/tests/fixtures/importexport.py
+++ b/tests/fixtures/importexport.py
@@ -30,6 +30,12 @@ dataset_metadata_config: Dict[str, Any] = {
     "timestamp": "2020-11-04T21:27:44.423819+00:00",
 }
 
+chart_metadata_config: Dict[str, Any] = {
+    "version": "1.0.0",
+    "type": "Slice",
+    "timestamp": "2020-11-04T21:27:44.423819+00:00",
+}
+
 database_config: Dict[str, Any] = {
     "allow_csv_upload": True,
     "allow_ctas": True,
@@ -88,3 +94,44 @@ dataset_config: Dict[str, Any] = {
     "uuid": "10808100-158b-42c4-842e-f32b99d88dfb",
     "database_uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
 }
+
+chart_config: Dict[str, Any] = {
+    "params": {
+        "color_picker": {"a": 1, "b": 135, "g": 122, "r": 0,},
+        "datasource": "12__table",
+        "js_columns": ["color"],
+        "js_data_mutator": r"data => data.map(d => ({\n    ...d,\n    color: colors.hexToRGB(d.extraProps.color)\n}));",
+        "js_onclick_href": "",
+        "js_tooltip": "",
+        "line_column": "path_json",
+        "line_type": "json",
+        "line_width": 150,
+        "mapbox_style": "mapbox://styles/mapbox/light-v9",
+        "reverse_long_lat": False,
+        "row_limit": 5000,
+        "slice_id": 43,
+        "time_grain_sqla": None,
+        "time_range": " : ",
+        "viewport": {
+            "altitude": 1.5,
+            "bearing": 0,
+            "height": 1094,
+            "latitude": 37.73671752604488,
+            "longitude": -122.18885402582598,
+            "maxLatitude": 85.05113,
+            "maxPitch": 60,
+            "maxZoom": 20,
+            "minLatitude": -85.05113,
+            "minPitch": 0,
+            "minZoom": 0,
+            "pitch": 0,
+            "width": 669,
+            "zoom": 9.51847667620428,
+        },
+        "viz_type": "deck_path",
+    },
+    "cache_timeout": None,
+    "uuid": "0c23747a-6528-4629-97bf-e4b78d3b9df1",
+    "version": "1.0.0",
+    "dataset_uuid": "10808100-158b-42c4-842e-f32b99d88dfb",
+}