You are viewing a plain-text version of this content. The hyperlink to the canonical version was lost in the plain-text conversion.
Posted to commits@superset.apache.org by be...@apache.org on 2020/11/20 22:20:35 UTC
[incubator-superset] branch master updated: feat: add a command to import charts (#11743)
This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 2f4f877 feat: add a command to import charts (#11743)
2f4f877 is described below
commit 2f4f87795d10c3052b66a6aaface2235d59355ed
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Fri Nov 20 14:20:13 2020 -0800
feat: add a command to import charts (#11743)
* ImportChartsCommand
* Fix type
---
superset/charts/commands/importers/__init__.py | 16 +++
superset/charts/commands/importers/v1/__init__.py | 146 ++++++++++++++++++++
superset/charts/commands/importers/v1/utils.py | 42 ++++++
superset/charts/schemas.py | 8 ++
superset/models/helpers.py | 10 +-
superset/models/slice.py | 1 +
tests/charts/commands_tests.py | 157 +++++++++++++++++++++-
tests/datasets/commands_tests.py | 33 ++++-
tests/fixtures/importexport.py | 47 +++++++
9 files changed, 449 insertions(+), 11 deletions(-)
diff --git a/superset/charts/commands/importers/__init__.py b/superset/charts/commands/importers/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/superset/charts/commands/importers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/superset/charts/commands/importers/v1/__init__.py b/superset/charts/commands/importers/v1/__init__.py
new file mode 100644
index 0000000..086a370
--- /dev/null
+++ b/superset/charts/commands/importers/v1/__init__.py
@@ -0,0 +1,146 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Any, Dict, List, Optional, Set
+
+from marshmallow import Schema, validate
+from marshmallow.exceptions import ValidationError
+from sqlalchemy.orm import Session
+
+from superset import db
+from superset.charts.commands.importers.v1.utils import import_chart
+from superset.charts.schemas import ImportV1ChartSchema
+from superset.commands.base import BaseCommand
+from superset.commands.exceptions import CommandInvalidError
+from superset.commands.importers.v1.utils import (
+ load_metadata,
+ load_yaml,
+ METADATA_FILE_NAME,
+)
+from superset.databases.commands.importers.v1.utils import import_database
+from superset.databases.schemas import ImportV1DatabaseSchema
+from superset.datasets.commands.importers.v1.utils import import_dataset
+from superset.datasets.schemas import ImportV1DatasetSchema
+from superset.models.slice import Slice
+
+schemas: Dict[str, Schema] = {
+ "charts/": ImportV1ChartSchema(),
+ "datasets/": ImportV1DatasetSchema(),
+ "databases/": ImportV1DatabaseSchema(),
+}
+
+
+class ImportChartsCommand(BaseCommand):
+
+ """Import charts"""
+
+ # pylint: disable=unused-argument
+ def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+ self.contents = contents
+ self._configs: Dict[str, Any] = {}
+
+ def _import_bundle(self, session: Session) -> None:
+ # discover datasets associated with charts
+ dataset_uuids: Set[str] = set()
+ for file_name, config in self._configs.items():
+ if file_name.startswith("charts/"):
+ dataset_uuids.add(config["dataset_uuid"])
+
+ # discover databases associated with datasets
+ database_uuids: Set[str] = set()
+ for file_name, config in self._configs.items():
+ if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids:
+ database_uuids.add(config["database_uuid"])
+
+ # import related databases
+ database_ids: Dict[str, int] = {}
+ for file_name, config in self._configs.items():
+ if file_name.startswith("databases/") and config["uuid"] in database_uuids:
+ database = import_database(session, config, overwrite=False)
+ database_ids[str(database.uuid)] = database.id
+
+ # import datasets with the correct parent ref
+ dataset_info: Dict[str, Dict[str, Any]] = {}
+ for file_name, config in self._configs.items():
+ if (
+ file_name.startswith("datasets/")
+ and config["database_uuid"] in database_ids
+ ):
+ config["database_id"] = database_ids[config["database_uuid"]]
+ dataset = import_dataset(session, config, overwrite=False)
+ dataset_info[str(dataset.uuid)] = {
+ "datasource_id": dataset.id,
+ "datasource_type": "view" if dataset.is_sqllab_view else "table",
+ "datasource_name": dataset.table_name,
+ }
+
+ # import charts with the correct parent ref
+ for file_name, config in self._configs.items():
+ if (
+ file_name.startswith("charts/")
+ and config["dataset_uuid"] in dataset_info
+ ):
+ # update datasource id, type, and name
+ config.update(dataset_info[config["dataset_uuid"]])
+ import_chart(session, config, overwrite=True)
+
+ def run(self) -> None:
+ self.validate()
+
+ # rollback to prevent partial imports
+ try:
+ self._import_bundle(db.session)
+ db.session.commit()
+ except Exception as exc:
+ db.session.rollback()
+ raise exc
+
+ def validate(self) -> None:
+ exceptions: List[ValidationError] = []
+
+ # verify that the metadata file is present and valid
+ try:
+ metadata: Optional[Dict[str, str]] = load_metadata(self.contents)
+ except ValidationError as exc:
+ exceptions.append(exc)
+ metadata = None
+
+ for file_name, content in self.contents.items():
+ prefix = file_name.split("/")[0]
+ schema = schemas.get(f"{prefix}/")
+ if schema:
+ try:
+ config = load_yaml(file_name, content)
+ schema.load(config)
+ self._configs[file_name] = config
+ except ValidationError as exc:
+ exc.messages = {file_name: exc.messages}
+ exceptions.append(exc)
+
+ # validate that the type declared in METADATA_FILE_NAME is correct
+ if metadata:
+ type_validator = validate.Equal(Slice.__name__)
+ try:
+ type_validator(metadata["type"])
+ except ValidationError as exc:
+ exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}}
+ exceptions.append(exc)
+
+ if exceptions:
+ exception = CommandInvalidError("Error importing chart")
+ exception.add_list(exceptions)
+ raise exception
diff --git a/superset/charts/commands/importers/v1/utils.py b/superset/charts/commands/importers/v1/utils.py
new file mode 100644
index 0000000..b3d4237
--- /dev/null
+++ b/superset/charts/commands/importers/v1/utils.py
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+from typing import Any, Dict
+
+from sqlalchemy.orm import Session
+
+from superset.models.slice import Slice
+
+
+def import_chart(
+ session: Session, config: Dict[str, Any], overwrite: bool = False
+) -> Slice:
+ existing = session.query(Slice).filter_by(uuid=config["uuid"]).first()
+ if existing:
+ if not overwrite:
+ return existing
+ config["id"] = existing.id
+
+ # TODO (betodealmeida): move this logic to import_from_dict
+ config["params"] = json.dumps(config["params"])
+
+ chart = Slice.import_from_dict(session, config, recursive=False)
+ if chart.id is None:
+ session.flush()
+
+ return chart
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 0997038..ca1497a 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -1118,6 +1118,14 @@ class GetFavStarIdsSchema(Schema):
)
+class ImportV1ChartSchema(Schema):
+ params = fields.Dict()
+ cache_timeout = fields.Integer(allow_none=True)
+ uuid = fields.UUID(required=True)
+ version = fields.String(required=True)
+ dataset_uuid = fields.UUID(required=True)
+
+
CHART_SCHEMAS = (
ChartDataQueryContextSchema,
ChartDataResponseSchema,
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index f0e5c0b..623bb07 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -88,14 +88,6 @@ class ImportExportMixin:
__mapper__: Mapper
@classmethod
- def _parent_foreign_key_mappings(cls) -> Dict[str, str]:
- """Get a mapping of foreign name to the local name of foreign keys"""
- parent_rel = cls.__mapper__.relationships.get(cls.export_parent)
- if parent_rel:
- return {l.name: r.name for (l, r) in parent_rel.local_remote_pairs}
- return {}
-
- @classmethod
def _unique_constrains(cls) -> List[Set[str]]:
"""Get all (single column and multi column) unique constraints"""
unique = [
@@ -171,7 +163,7 @@ class ImportExportMixin:
# Remove fields that should not get imported
for k in list(dict_rep):
- if k not in export_fields:
+ if k not in export_fields and k not in parent_refs:
del dict_rep[k]
if not parent:
diff --git a/superset/models/slice.py b/superset/models/slice.py
index 7254652..2fd55a7 100644
--- a/superset/models/slice.py
+++ b/superset/models/slice.py
@@ -93,6 +93,7 @@ class Slice(
"params",
"cache_timeout",
]
+ export_parent = "table"
def __repr__(self) -> str:
return self.slice_name or str(self.id)
diff --git a/tests/charts/commands_tests.py b/tests/charts/commands_tests.py
index 9189d4c..8523b83 100644
--- a/tests/charts/commands_tests.py
+++ b/tests/charts/commands_tests.py
@@ -14,16 +14,31 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+# pylint: disable=no-self-use, invalid-name
+import json
from unittest.mock import patch
+import pytest
import yaml
from superset import db, security_manager
from superset.charts.commands.exceptions import ChartNotFoundError
from superset.charts.commands.export import ExportChartsCommand
+from superset.charts.commands.importers.v1 import ImportChartsCommand
+from superset.commands.exceptions import CommandInvalidError
+from superset.commands.importers.exceptions import IncorrectVersionError
+from superset.connectors.sqla.models import SqlaTable
+from superset.models.core import Database
from superset.models.slice import Slice
from tests.base_tests import SupersetTestCase
+from tests.fixtures.importexport import (
+ chart_config,
+ chart_metadata_config,
+ database_config,
+ database_metadata_config,
+ dataset_config,
+)
class TestExportChartsCommand(SupersetTestCase):
@@ -49,7 +64,7 @@ class TestExportChartsCommand(SupersetTestCase):
"viz_type": "sankey",
"params": {
"collapsed_fieldsets": "",
- "groupby": ["source", "target",],
+ "groupby": ["source", "target"],
"metric": "sum__value",
"row_limit": "5000",
"slice_name": "Energy Sankey",
@@ -100,3 +115,143 @@ class TestExportChartsCommand(SupersetTestCase):
"version",
"dataset_uuid",
]
+
+ def test_import_v1_chart(self):
+ """Test that we can import a chart"""
+ contents = {
+ "metadata.yaml": yaml.safe_dump(chart_metadata_config),
+ "databases/imported_database.yaml": yaml.safe_dump(database_config),
+ "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
+ "charts/imported_chart.yaml": yaml.safe_dump(chart_config),
+ }
+ command = ImportChartsCommand(contents)
+ command.run()
+
+ chart = db.session.query(Slice).filter_by(uuid=chart_config["uuid"]).one()
+ assert json.loads(chart.params) == {
+ "color_picker": {"a": 1, "b": 135, "g": 122, "r": 0},
+ "datasource": "12__table",
+ "js_columns": ["color"],
+ "js_data_mutator": "data => data.map(d => ({\\n ...d,\\n color: colors.hexToRGB(d.extraProps.color)\\n}));",
+ "js_onclick_href": "",
+ "js_tooltip": "",
+ "line_column": "path_json",
+ "line_type": "json",
+ "line_width": 150,
+ "mapbox_style": "mapbox://styles/mapbox/light-v9",
+ "reverse_long_lat": False,
+ "row_limit": 5000,
+ "slice_id": 43,
+ "time_grain_sqla": None,
+ "time_range": " : ",
+ "viewport": {
+ "altitude": 1.5,
+ "bearing": 0,
+ "height": 1094,
+ "latitude": 37.73671752604488,
+ "longitude": -122.18885402582598,
+ "maxLatitude": 85.05113,
+ "maxPitch": 60,
+ "maxZoom": 20,
+ "minLatitude": -85.05113,
+ "minPitch": 0,
+ "minZoom": 0,
+ "pitch": 0,
+ "width": 669,
+ "zoom": 9.51847667620428,
+ },
+ "viz_type": "deck_path",
+ }
+
+ dataset = (
+ db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one()
+ )
+ assert dataset.table_name == "imported_dataset"
+ assert chart.table == dataset
+
+ database = (
+ db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
+ )
+ assert database.database_name == "imported_database"
+ assert chart.table.database == database
+
+ db.session.delete(chart)
+ db.session.delete(dataset)
+ db.session.delete(database)
+ db.session.commit()
+
+ def test_import_v1_chart_multiple(self):
+ """Test that a dataset can be imported multiple times"""
+ contents = {
+ "metadata.yaml": yaml.safe_dump(chart_metadata_config),
+ "databases/imported_database.yaml": yaml.safe_dump(database_config),
+ "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
+ "charts/imported_chart.yaml": yaml.safe_dump(chart_config),
+ }
+ command = ImportChartsCommand(contents)
+ command.run()
+ command.run()
+
+ dataset = (
+ db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one()
+ )
+ charts = db.session.query(Slice).filter_by(datasource_id=dataset.id).all()
+ assert len(charts) == 1
+
+ database = dataset.database
+
+ db.session.delete(charts[0])
+ db.session.delete(dataset)
+ db.session.delete(database)
+ db.session.commit()
+
+ def test_import_v1_chart_validation(self):
+ """Test different validations applied when importing a chart"""
+ # metadata.yaml must be present
+ contents = {
+ "databases/imported_database.yaml": yaml.safe_dump(database_config),
+ "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
+ "charts/imported_chart.yaml": yaml.safe_dump(chart_config),
+ }
+ command = ImportChartsCommand(contents)
+ with pytest.raises(IncorrectVersionError) as excinfo:
+ command.run()
+ assert str(excinfo.value) == "Missing metadata.yaml"
+
+ # version should be 1.0.0
+ contents["metadata.yaml"] = yaml.safe_dump(
+ {
+ "version": "2.0.0",
+ "type": "SqlaTable",
+ "timestamp": "2020-11-04T21:27:44.423819+00:00",
+ }
+ )
+ command = ImportChartsCommand(contents)
+ with pytest.raises(IncorrectVersionError) as excinfo:
+ command.run()
+ assert str(excinfo.value) == "Must be equal to 1.0.0."
+
+ # type should be Slice
+ contents["metadata.yaml"] = yaml.safe_dump(database_metadata_config)
+ command = ImportChartsCommand(contents)
+ with pytest.raises(CommandInvalidError) as excinfo:
+ command.run()
+ assert str(excinfo.value) == "Error importing chart"
+ assert excinfo.value.normalized_messages() == {
+ "metadata.yaml": {"type": ["Must be equal to Slice."]}
+ }
+
+ # must also validate datasets and databases
+ broken_config = database_config.copy()
+ del broken_config["database_name"]
+ contents["metadata.yaml"] = yaml.safe_dump(chart_metadata_config)
+ contents["databases/imported_database.yaml"] = yaml.safe_dump(broken_config)
+ command = ImportChartsCommand(contents)
+ with pytest.raises(CommandInvalidError) as excinfo:
+ command.run()
+ assert str(excinfo.value) == "Error importing chart"
+ assert excinfo.value.normalized_messages() == {
+ "databases/imported_database.yaml": {
+ "database_name": ["Missing data for required field."],
+ }
+ }
diff --git a/tests/datasets/commands_tests.py b/tests/datasets/commands_tests.py
index 94b5e79..a957ffc 100644
--- a/tests/datasets/commands_tests.py
+++ b/tests/datasets/commands_tests.py
@@ -26,9 +26,11 @@ from superset import db, security_manager
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.exceptions import IncorrectVersionError
from superset.connectors.sqla.models import SqlaTable
+from superset.databases.commands.importers.v1 import ImportDatabasesCommand
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.datasets.commands.export import ExportDatasetsCommand
from superset.datasets.commands.importers.v1 import ImportDatasetsCommand
+from superset.models.core import Database
from superset.utils.core import get_example_database
from tests.base_tests import SupersetTestCase
from tests.fixtures.importexport import (
@@ -326,7 +328,7 @@ class TestExportDatasetsCommand(SupersetTestCase):
command.run()
assert str(excinfo.value) == "Error importing dataset"
assert excinfo.value.normalized_messages() == {
- "metadata.yaml": {"type": ["Must be equal to SqlaTable."],}
+ "metadata.yaml": {"type": ["Must be equal to SqlaTable."]}
}
# must also validate databases
@@ -343,3 +345,32 @@ class TestExportDatasetsCommand(SupersetTestCase):
"database_name": ["Missing data for required field."],
}
}
+
+ def test_import_v1_dataset_existing_database(self):
+ """Test that a dataset can be imported when the database already exists"""
+ # first import database...
+ contents = {
+ "metadata.yaml": yaml.safe_dump(database_metadata_config),
+ "databases/imported_database.yaml": yaml.safe_dump(database_config),
+ }
+ command = ImportDatabasesCommand(contents)
+ command.run()
+
+ database = (
+ db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
+ )
+ assert len(database.tables) == 0
+
+ # ...then dataset
+ contents = {
+ "metadata.yaml": yaml.safe_dump(dataset_metadata_config),
+ "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
+ "databases/imported_database.yaml": yaml.safe_dump(database_config),
+ }
+ command = ImportDatasetsCommand(contents)
+ command.run()
+
+ database = (
+ db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
+ )
+ assert len(database.tables) == 1
diff --git a/tests/fixtures/importexport.py b/tests/fixtures/importexport.py
index 8b64004..8312a81 100644
--- a/tests/fixtures/importexport.py
+++ b/tests/fixtures/importexport.py
@@ -30,6 +30,12 @@ dataset_metadata_config: Dict[str, Any] = {
"timestamp": "2020-11-04T21:27:44.423819+00:00",
}
+chart_metadata_config: Dict[str, Any] = {
+ "version": "1.0.0",
+ "type": "Slice",
+ "timestamp": "2020-11-04T21:27:44.423819+00:00",
+}
+
database_config: Dict[str, Any] = {
"allow_csv_upload": True,
"allow_ctas": True,
@@ -88,3 +94,44 @@ dataset_config: Dict[str, Any] = {
"uuid": "10808100-158b-42c4-842e-f32b99d88dfb",
"database_uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
}
+
+chart_config: Dict[str, Any] = {
+ "params": {
+ "color_picker": {"a": 1, "b": 135, "g": 122, "r": 0,},
+ "datasource": "12__table",
+ "js_columns": ["color"],
+ "js_data_mutator": r"data => data.map(d => ({\n ...d,\n color: colors.hexToRGB(d.extraProps.color)\n}));",
+ "js_onclick_href": "",
+ "js_tooltip": "",
+ "line_column": "path_json",
+ "line_type": "json",
+ "line_width": 150,
+ "mapbox_style": "mapbox://styles/mapbox/light-v9",
+ "reverse_long_lat": False,
+ "row_limit": 5000,
+ "slice_id": 43,
+ "time_grain_sqla": None,
+ "time_range": " : ",
+ "viewport": {
+ "altitude": 1.5,
+ "bearing": 0,
+ "height": 1094,
+ "latitude": 37.73671752604488,
+ "longitude": -122.18885402582598,
+ "maxLatitude": 85.05113,
+ "maxPitch": 60,
+ "maxZoom": 20,
+ "minLatitude": -85.05113,
+ "minPitch": 0,
+ "minZoom": 0,
+ "pitch": 0,
+ "width": 669,
+ "zoom": 9.51847667620428,
+ },
+ "viz_type": "deck_path",
+ },
+ "cache_timeout": None,
+ "uuid": "0c23747a-6528-4629-97bf-e4b78d3b9df1",
+ "version": "1.0.0",
+ "dataset_uuid": "10808100-158b-42c4-842e-f32b99d88dfb",
+}