Posted to commits@superset.apache.org by be...@apache.org on 2020/11/12 06:04:47 UTC
[incubator-superset] branch master updated: chore: consolidate datasource import logic (#11533)
This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 45738ff chore: consolidate datasource import logic (#11533)
45738ff is described below
commit 45738ffc1dfe94deb6578b6e5b0d6687558969e9
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Wed Nov 11 22:04:16 2020 -0800
chore: consolidate datasource import logic (#11533)
* Consolidate dash import logic
* WIP
* Add license
* Fix lint
* Retrigger tests
* Fix lint
---
superset/cli.py | 19 +-
superset/connectors/druid/models.py | 58 +---
superset/connectors/sqla/models.py | 87 +-----
superset/dashboards/commands/importers/__init__.py | 16 ++
superset/dashboards/commands/importers/v0.py | 6 +-
superset/datasets/commands/importers/__init__.py | 16 ++
superset/datasets/commands/importers/v0.py | 303 +++++++++++++++++++++
superset/utils/dict_import_export.py | 23 +-
superset/utils/import_datasource.py | 105 -------
tests/import_export_tests.py | 29 +-
10 files changed, 365 insertions(+), 297 deletions(-)
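
In short: the per-connector classmethods (SqlaTable.import_obj, DruidDatasource.import_obj, and their column/metric counterparts) are removed in favor of a single module, superset/datasets/commands/importers/v0.py, that dispatches on the datasource type. A minimal before/after sketch, mirroring the calls made in the updated tests below:

    # Before: one classmethod per connector.
    SqlaTable.import_obj(table, db_id, import_time=1989)
    DruidDatasource.import_obj(datasource, import_time=1989)

    # After: a single entry point that dispatches on the datasource type.
    from superset.datasets.commands.importers.v0 import import_dataset

    import_dataset(table, db_id, import_time=1989)
    import_dataset(datasource, import_time=1989)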
diff --git a/superset/cli.py b/superset/cli.py
index 6d1d2fb..5130dbf 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -301,11 +301,11 @@ def export_dashboards(dashboard_file: str, print_stdout: bool) -> None:
)
def import_datasources(path: str, sync: str, recursive: bool) -> None:
"""Import datasources from YAML"""
- from superset.utils import dict_import_export
+ from superset.datasets.commands.importers.v0 import ImportDatasetsCommand
sync_array = sync.split(",")
path_object = Path(path)
- files = []
+ files: List[Path] = []
if path_object.is_file():
files.append(path_object)
elif path_object.exists() and not recursive:
@@ -314,16 +314,11 @@ def import_datasources(path: str, sync: str, recursive: bool) -> None:
elif path_object.exists() and recursive:
files.extend(path_object.rglob("*.yaml"))
files.extend(path_object.rglob("*.yml"))
- for file_ in files:
- logger.info("Importing datasources from file %s", file_)
- try:
- with file_.open() as data_stream:
- dict_import_export.import_from_dict(
- db.session, yaml.safe_load(data_stream), sync=sync_array
- )
- except Exception as ex: # pylint: disable=broad-except
- logger.error("Error when importing datasources from file %s", file_)
- logger.error(ex)
+ contents = {path.name: open(path).read() for path in files}
+ try:
+ ImportDatasetsCommand(contents, sync_array).run()
+ except Exception: # pylint: disable=broad-except
+ logger.exception("Error when importing dataset")
@superset.command()
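
The rewritten command body above reads every file up front and hands the whole batch to ImportDatasetsCommand. A minimal sketch of the same flow outside Click, assuming a Flask app context; the directory name and sync values are illustrative:

    from pathlib import Path

    from superset.datasets.commands.importers.v0 import ImportDatasetsCommand

    # Gather YAML exports keyed by file name, as import_datasources() now does.
    files = [
        *Path("exports").rglob("*.yaml"),
        *Path("exports").rglob("*.yml"),
    ]
    contents = {path.name: path.read_text() for path in files}

    # sync names the child objects to synchronize (e.g. metrics, columns)
    # rather than merely override.
    ImportDatasetsCommand(contents, sync=["metrics", "columns"]).run()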
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index a8d00fd..644d6a3 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -56,7 +56,7 @@ from superset.exceptions import SupersetException
from superset.models.core import Database
from superset.models.helpers import AuditMixinNullable, ImportExportMixin, QueryResult
from superset.typing import FilterValues, Granularity, Metric, QueryObjectDict
-from superset.utils import core as utils, import_datasource
+from superset.utils import core as utils
try:
import requests
@@ -378,20 +378,6 @@ class DruidColumn(Model, BaseColumn):
metric.datasource_id = self.datasource_id
db.session.add(metric)
- @classmethod
- def import_obj(cls, i_column: "DruidColumn") -> "DruidColumn":
- def lookup_obj(lookup_column: DruidColumn) -> Optional[DruidColumn]:
- return (
- db.session.query(DruidColumn)
- .filter(
- DruidColumn.datasource_id == lookup_column.datasource_id,
- DruidColumn.column_name == lookup_column.column_name,
- )
- .first()
- )
-
- return import_datasource.import_simple_obj(db.session, i_column, lookup_obj)
-
class DruidMetric(Model, BaseMetric):
@@ -447,20 +433,6 @@ class DruidMetric(Model, BaseMetric):
def get_perm(self) -> Optional[str]:
return self.perm
- @classmethod
- def import_obj(cls, i_metric: "DruidMetric") -> "DruidMetric":
- def lookup_obj(lookup_metric: DruidMetric) -> Optional[DruidMetric]:
- return (
- db.session.query(DruidMetric)
- .filter(
- DruidMetric.datasource_id == lookup_metric.datasource_id,
- DruidMetric.metric_name == lookup_metric.metric_name,
- )
- .first()
- )
-
- return import_datasource.import_simple_obj(db.session, i_metric, lookup_obj)
-
druiddatasource_user = Table(
"druiddatasource_user",
@@ -610,34 +582,6 @@ class DruidDatasource(Model, BaseDatasource):
def get_metric_obj(self, metric_name: str) -> Dict[str, Any]:
return [m.json_obj for m in self.metrics if m.metric_name == metric_name][0]
- @classmethod
- def import_obj(
- cls, i_datasource: "DruidDatasource", import_time: Optional[int] = None
- ) -> int:
- """Imports the datasource from the object to the database.
-
- Metrics and columns and datasource will be overridden if exists.
- This function can be used to import/export dashboards between multiple
- superset instances. Audit metadata isn't copies over.
- """
-
- def lookup_datasource(d: DruidDatasource) -> Optional[DruidDatasource]:
- return (
- db.session.query(DruidDatasource)
- .filter(
- DruidDatasource.datasource_name == d.datasource_name,
- DruidDatasource.cluster_id == d.cluster_id,
- )
- .first()
- )
-
- def lookup_cluster(d: DruidDatasource) -> Optional[DruidCluster]:
- return db.session.query(DruidCluster).filter_by(id=d.cluster_id).first()
-
- return import_datasource.import_datasource(
- db.session, i_datasource, lookup_cluster, lookup_datasource, import_time
- )
-
def latest_metadata(self) -> Optional[Dict[str, Any]]:
"""Returns segment metadata from the latest segment"""
logger.info("Syncing datasource [{}]".format(self.datasource_name))
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index a7c078d..34961f0 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -47,7 +47,6 @@ from sqlalchemy import (
)
from sqlalchemy.exc import CompileError
from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
-from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
from sqlalchemy.sql.expression import Label, Select, TextAsFrom
@@ -58,11 +57,7 @@ from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetr
from superset.constants import NULL_STRING
from superset.db_engine_specs.base import TimestampExpression
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
-from superset.exceptions import (
- DatabaseNotFound,
- QueryObjectValidationError,
- SupersetSecurityException,
-)
+from superset.exceptions import QueryObjectValidationError, SupersetSecurityException
from superset.jinja_context import (
BaseTemplateProcessor,
ExtraCache,
@@ -74,7 +69,7 @@ from superset.models.helpers import AuditMixinNullable, QueryResult
from superset.result_set import SupersetResultSet
from superset.sql_parse import ParsedQuery
from superset.typing import Metric, QueryObjectDict
-from superset.utils import core as utils, import_datasource
+from superset.utils import core as utils
config = app.config
metadata = Model.metadata # pylint: disable=no-member
@@ -290,20 +285,6 @@ class TableColumn(Model, BaseColumn):
)
return self.table.make_sqla_column_compatible(time_expr, label)
- @classmethod
- def import_obj(cls, i_column: "TableColumn") -> "TableColumn":
- def lookup_obj(lookup_column: TableColumn) -> TableColumn:
- return (
- db.session.query(TableColumn)
- .filter(
- TableColumn.table_id == lookup_column.table_id,
- TableColumn.column_name == lookup_column.column_name,
- )
- .first()
- )
-
- return import_datasource.import_simple_obj(db.session, i_column, lookup_obj)
-
def dttm_sql_literal(
self,
dttm: DateTime,
@@ -412,20 +393,6 @@ class SqlMetric(Model, BaseMetric):
def get_perm(self) -> Optional[str]:
return self.perm
- @classmethod
- def import_obj(cls, i_metric: "SqlMetric") -> "SqlMetric":
- def lookup_obj(lookup_metric: SqlMetric) -> SqlMetric:
- return (
- db.session.query(SqlMetric)
- .filter(
- SqlMetric.table_id == lookup_metric.table_id,
- SqlMetric.metric_name == lookup_metric.metric_name,
- )
- .first()
- )
-
- return import_datasource.import_simple_obj(db.session, i_metric, lookup_obj)
-
def get_extra_dict(self) -> Dict[str, Any]:
try:
return json.loads(self.extra)
@@ -1417,56 +1384,6 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
return results
@classmethod
- def import_obj(
- cls,
- i_datasource: "SqlaTable",
- database_id: Optional[int] = None,
- import_time: Optional[int] = None,
- ) -> int:
- """Imports the datasource from the object to the database.
-
- Metrics and columns and datasource will be overrided if exists.
- This function can be used to import/export dashboards between multiple
- superset instances. Audit metadata isn't copies over.
- """
-
- def lookup_sqlatable(table_: "SqlaTable") -> "SqlaTable":
- return (
- db.session.query(SqlaTable)
- .join(Database)
- .filter(
- SqlaTable.table_name == table_.table_name,
- SqlaTable.schema == table_.schema,
- Database.id == table_.database_id,
- )
- .first()
- )
-
- def lookup_database(table_: SqlaTable) -> Database:
- try:
- return (
- db.session.query(Database)
- .filter_by(database_name=table_.params_dict["database_name"])
- .one()
- )
- except NoResultFound:
- raise DatabaseNotFound(
- _(
- "Database '%(name)s' is not found",
- name=table_.params_dict["database_name"],
- )
- )
-
- return import_datasource.import_datasource(
- db.session,
- i_datasource,
- lookup_database,
- lookup_sqlatable,
- import_time,
- database_id,
- )
-
- @classmethod
def query_datasources_by_name(
cls,
session: Session,
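
On the SQLA side, the deleted lookup_database closure raised DatabaseNotFound; its module-level replacement raises DatabaseNotFoundError from superset.databases.commands.exceptions instead. A sketch of the replacement call and its failure mode, assuming a SqlaTable instance named table:

    from superset.databases.commands.exceptions import DatabaseNotFoundError
    from superset.datasets.commands.importers.v0 import import_dataset

    try:
        new_id = import_dataset(table, database_id=db_id, import_time=1991)
    except DatabaseNotFoundError:
        # The export referenced a database_name with no match in the metastore.
        raise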
diff --git a/superset/dashboards/commands/importers/__init__.py b/superset/dashboards/commands/importers/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/superset/dashboards/commands/importers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/superset/dashboards/commands/importers/v0.py b/superset/dashboards/commands/importers/v0.py
index 8040c24..851ecab 100644
--- a/superset/dashboards/commands/importers/v0.py
+++ b/superset/dashboards/commands/importers/v0.py
@@ -19,7 +19,7 @@ import logging
import time
from copy import copy
from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Optional
from flask_babel import lazy_gettext as _
from sqlalchemy.orm import make_transient, Session
@@ -27,6 +27,7 @@ from sqlalchemy.orm import make_transient, Session
from superset import ConnectorRegistry, db
from superset.commands.base import BaseCommand
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+from superset.datasets.commands.importers.v0 import import_dataset
from superset.exceptions import DashboardImportException
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
@@ -301,7 +302,7 @@ def import_dashboards(
if not data:
raise DashboardImportException(_("No data in file"))
for table in data["datasources"]:
- type(table).import_obj(table, database_id, import_time=import_time)
+ import_dataset(table, database_id, import_time=import_time)
session.commit()
for dashboard in data["dashboards"]:
import_dashboard(dashboard, import_time=import_time)
@@ -333,4 +334,5 @@ class ImportDashboardsCommand(BaseCommand):
try:
json.loads(content)
except ValueError:
+ logger.exception("Invalid JSON file")
raise
diff --git a/superset/datasets/commands/importers/__init__.py b/superset/datasets/commands/importers/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/superset/datasets/commands/importers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/superset/datasets/commands/importers/v0.py b/superset/datasets/commands/importers/v0.py
new file mode 100644
index 0000000..5b3ed25
--- /dev/null
+++ b/superset/datasets/commands/importers/v0.py
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import logging
+from typing import Any, Callable, Dict, List, Optional
+
+import yaml
+from flask_appbuilder import Model
+from sqlalchemy.orm import Session
+from sqlalchemy.orm.exc import NoResultFound
+from sqlalchemy.orm.session import make_transient
+
+from superset import db
+from superset.commands.base import BaseCommand
+from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
+from superset.connectors.druid.models import (
+ DruidCluster,
+ DruidColumn,
+ DruidDatasource,
+ DruidMetric,
+)
+from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+from superset.databases.commands.exceptions import DatabaseNotFoundError
+from superset.models.core import Database
+from superset.utils.dict_import_export import DATABASES_KEY, DRUID_CLUSTERS_KEY
+
+logger = logging.getLogger(__name__)
+
+
+def lookup_sqla_table(table: SqlaTable) -> Optional[SqlaTable]:
+ return (
+ db.session.query(SqlaTable)
+ .join(Database)
+ .filter(
+ SqlaTable.table_name == table.table_name,
+ SqlaTable.schema == table.schema,
+ Database.id == table.database_id,
+ )
+ .first()
+ )
+
+
+def lookup_sqla_database(table: SqlaTable) -> Optional[Database]:
+ try:
+ return (
+ db.session.query(Database)
+ .filter_by(database_name=table.params_dict["database_name"])
+ .one()
+ )
+ except NoResultFound:
+ raise DatabaseNotFoundError
+
+
+def lookup_druid_cluster(datasource: DruidDatasource) -> Optional[DruidCluster]:
+ return db.session.query(DruidCluster).filter_by(id=datasource.cluster_id).first()
+
+
+def lookup_druid_datasource(datasource: DruidDatasource) -> Optional[DruidDatasource]:
+ return (
+ db.session.query(DruidDatasource)
+ .filter(
+ DruidDatasource.datasource_name == datasource.datasource_name,
+ DruidDatasource.cluster_id == datasource.cluster_id,
+ )
+ .first()
+ )
+
+
+def import_dataset(
+ i_datasource: BaseDatasource,
+ database_id: Optional[int] = None,
+ import_time: Optional[int] = None,
+) -> int:
+ """Imports the datasource from the object to the database.
+
+ Metrics, columns and the datasource will be overridden if they exist.
+ This function can be used to import/export dashboards between multiple
+ Superset instances. Audit metadata isn't copied over.
+ """
+
+ lookup_database: Callable[[BaseDatasource], Optional[Database]]
+ lookup_datasource: Callable[[BaseDatasource], Optional[BaseDatasource]]
+ if isinstance(i_datasource, SqlaTable):
+ lookup_database = lookup_sqla_database
+ lookup_datasource = lookup_sqla_table
+ else:
+ lookup_database = lookup_druid_cluster
+ lookup_datasource = lookup_druid_datasource
+
+ return import_datasource(
+ db.session,
+ i_datasource,
+ lookup_database,
+ lookup_datasource,
+ import_time,
+ database_id,
+ )
+
+
+def lookup_sqla_metric(session: Session, metric: SqlMetric) -> SqlMetric:
+ return (
+ session.query(SqlMetric)
+ .filter(
+ SqlMetric.table_id == metric.table_id,
+ SqlMetric.metric_name == metric.metric_name,
+ )
+ .first()
+ )
+
+
+def lookup_druid_metric(session: Session, metric: DruidMetric) -> DruidMetric:
+ return (
+ session.query(DruidMetric)
+ .filter(
+ DruidMetric.datasource_id == metric.datasource_id,
+ DruidMetric.metric_name == metric.metric_name,
+ )
+ .first()
+ )
+
+
+def import_metric(session: Session, metric: BaseMetric) -> BaseMetric:
+ if isinstance(metric, SqlMetric):
+ lookup_metric = lookup_sqla_metric
+ else:
+ lookup_metric = lookup_druid_metric
+ return import_simple_obj(session, metric, lookup_metric)
+
+
+def lookup_sqla_column(session: Session, column: TableColumn) -> TableColumn:
+ return (
+ session.query(TableColumn)
+ .filter(
+ TableColumn.table_id == column.table_id,
+ TableColumn.column_name == column.column_name,
+ )
+ .first()
+ )
+
+
+def lookup_druid_column(session: Session, column: DruidColumn) -> DruidColumn:
+ return (
+ session.query(DruidColumn)
+ .filter(
+ DruidColumn.datasource_id == column.datasource_id,
+ DruidColumn.column_name == column.column_name,
+ )
+ .first()
+ )
+
+
+def import_column(session: Session, column: BaseColumn) -> BaseColumn:
+ if isinstance(column, TableColumn):
+ lookup_column = lookup_sqla_column
+ else:
+ lookup_column = lookup_druid_column
+ return import_simple_obj(session, column, lookup_column)
+
+
+def import_datasource( # pylint: disable=too-many-arguments
+ session: Session,
+ i_datasource: Model,
+ lookup_database: Callable[[Model], Optional[Model]],
+ lookup_datasource: Callable[[Model], Optional[Model]],
+ import_time: Optional[int] = None,
+ database_id: Optional[int] = None,
+) -> int:
+ """Imports the datasource from the object to the database.
+
+ Metrics, columns and the datasource will be overridden if they exist.
+ This function can be used to import/export datasources between multiple
+ Superset instances. Audit metadata isn't copied over.
+ """
+ make_transient(i_datasource)
+ logger.info("Started import of the datasource: %s", i_datasource.to_json())
+
+ i_datasource.id = None
+ i_datasource.database_id = (
+ database_id
+ if database_id
+ else getattr(lookup_database(i_datasource), "id", None)
+ )
+ i_datasource.alter_params(import_time=import_time)
+
+ # override the datasource
+ datasource = lookup_datasource(i_datasource)
+
+ if datasource:
+ datasource.override(i_datasource)
+ session.flush()
+ else:
+ datasource = i_datasource.copy()
+ session.add(datasource)
+ session.flush()
+
+ for metric in i_datasource.metrics:
+ new_m = metric.copy()
+ new_m.table_id = datasource.id
+ logger.info(
+ "Importing metric %s from the datasource: %s",
+ new_m.to_json(),
+ i_datasource.full_name,
+ )
+ imported_m = import_metric(session, new_m)
+ if imported_m.metric_name not in [m.metric_name for m in datasource.metrics]:
+ datasource.metrics.append(imported_m)
+
+ for column in i_datasource.columns:
+ new_c = column.copy()
+ new_c.table_id = datasource.id
+ logger.info(
+ "Importing column %s from the datasource: %s",
+ new_c.to_json(),
+ i_datasource.full_name,
+ )
+ imported_c = import_column(session, new_c)
+ if imported_c.column_name not in [c.column_name for c in datasource.columns]:
+ datasource.columns.append(imported_c)
+ session.flush()
+ return datasource.id
+
+
+def import_simple_obj(
+ session: Session, i_obj: Model, lookup_obj: Callable[[Session, Model], Model]
+) -> Model:
+ make_transient(i_obj)
+ i_obj.id = None
+ i_obj.table = None
+
+ # find if the column was already imported
+ existing_column = lookup_obj(session, i_obj)
+ i_obj.table = None
+ if existing_column:
+ existing_column.override(i_obj)
+ session.flush()
+ return existing_column
+
+ session.add(i_obj)
+ session.flush()
+ return i_obj
+
+
+def import_from_dict(
+ session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None
+) -> None:
+ """Imports databases and druid clusters from dictionary"""
+ if not sync:
+ sync = []
+ if isinstance(data, dict):
+ logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
+ for database in data.get(DATABASES_KEY, []):
+ Database.import_from_dict(session, database, sync=sync)
+
+ logger.info(
+ "Importing %d %s", len(data.get(DRUID_CLUSTERS_KEY, [])), DRUID_CLUSTERS_KEY
+ )
+ for datasource in data.get(DRUID_CLUSTERS_KEY, []):
+ DruidCluster.import_from_dict(session, datasource, sync=sync)
+ session.commit()
+ else:
+ logger.info("Supplied object is not a dictionary.")
+
+
+class ImportDatasetsCommand(BaseCommand):
+ """
+ Import datasources in YAML format.
+
+ This is the original unversioned format used to export and import datasources
+ in Superset.
+ """
+
+ def __init__(self, contents: Dict[str, str], sync: Optional[List[str]] = None):
+ self.contents = contents
+ self.sync = sync
+
+ def run(self) -> None:
+ self.validate()
+
+ for file_name, content in self.contents.items():
+ logger.info("Importing dataset from file %s", file_name)
+ import_from_dict(db.session, yaml.safe_load(content), sync=self.sync)
+
+ def validate(self) -> None:
+ # ensure all files are YAML
+ for content in self.contents.values():
+ try:
+ yaml.safe_load(content)
+ except yaml.parser.ParserError:
+ logger.exception("Invalid YAML file")
+ raise
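
Putting the new module together, an end-to-end sketch, assuming an app context. The top-level keys come from superset.utils.dict_import_export (DATABASES_KEY, DRUID_CLUSTERS_KEY); the payload below illustrates the v0 export shape and is not a definitive schema:

    import yaml

    from superset import db
    from superset.datasets.commands.importers.v0 import (
        ImportDatasetsCommand,
        import_from_dict,
    )

    payload = """
    databases:
    - database_name: examples
      tables:
      - table_name: my_table
        columns:
        - column_name: col1
        metrics:
        - metric_name: count
          expression: COUNT(*)
    """

    # Validate the YAML and run the import in one step...
    ImportDatasetsCommand({"datasources.yaml": payload}).run()

    # ...or feed the parsed dictionary to the helper directly.
    import_from_dict(db.session, yaml.safe_load(payload))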
diff --git a/superset/utils/dict_import_export.py b/superset/utils/dict_import_export.py
index 3a37e91..256f1b5 100644
--- a/superset/utils/dict_import_export.py
+++ b/superset/utils/dict_import_export.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict
from sqlalchemy.orm import Session
@@ -75,24 +75,3 @@ def export_to_dict(
if clusters:
data[DRUID_CLUSTERS_KEY] = clusters
return data
-
-
-def import_from_dict(
- session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None
-) -> None:
- """Imports databases and druid clusters from dictionary"""
- if not sync:
- sync = []
- if isinstance(data, dict):
- logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
- for database in data.get(DATABASES_KEY, []):
- Database.import_from_dict(session, database, sync=sync)
-
- logger.info(
- "Importing %d %s", len(data.get(DRUID_CLUSTERS_KEY, [])), DRUID_CLUSTERS_KEY
- )
- for datasource in data.get(DRUID_CLUSTERS_KEY, []):
- DruidCluster.import_from_dict(session, datasource, sync=sync)
- session.commit()
- else:
- logger.info("Supplied object is not a dictionary.")
diff --git a/superset/utils/import_datasource.py b/superset/utils/import_datasource.py
deleted file mode 100644
index 25da876..0000000
--- a/superset/utils/import_datasource.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import logging
-from typing import Callable, Optional
-
-from flask_appbuilder import Model
-from sqlalchemy.orm import Session
-from sqlalchemy.orm.session import make_transient
-
-logger = logging.getLogger(__name__)
-
-
-def import_datasource( # pylint: disable=too-many-arguments
- session: Session,
- i_datasource: Model,
- lookup_database: Callable[[Model], Model],
- lookup_datasource: Callable[[Model], Model],
- import_time: Optional[int] = None,
- database_id: Optional[int] = None,
-) -> int:
- """Imports the datasource from the object to the database.
-
- Metrics and columns and datasource will be overrided if exists.
- This function can be used to import/export datasources between multiple
- superset instances. Audit metadata isn't copies over.
- """
- make_transient(i_datasource)
- logger.info("Started import of the datasource: %s", i_datasource.to_json())
-
- i_datasource.id = None
- i_datasource.database_id = (
- database_id if database_id else lookup_database(i_datasource).id
- )
- i_datasource.alter_params(import_time=import_time)
-
- # override the datasource
- datasource = lookup_datasource(i_datasource)
-
- if datasource:
- datasource.override(i_datasource)
- session.flush()
- else:
- datasource = i_datasource.copy()
- session.add(datasource)
- session.flush()
-
- for metric in i_datasource.metrics:
- new_m = metric.copy()
- new_m.table_id = datasource.id
- logger.info(
- "Importing metric %s from the datasource: %s",
- new_m.to_json(),
- i_datasource.full_name,
- )
- imported_m = i_datasource.metric_class.import_obj(new_m)
- if imported_m.metric_name not in [m.metric_name for m in datasource.metrics]:
- datasource.metrics.append(imported_m)
-
- for column in i_datasource.columns:
- new_c = column.copy()
- new_c.table_id = datasource.id
- logger.info(
- "Importing column %s from the datasource: %s",
- new_c.to_json(),
- i_datasource.full_name,
- )
- imported_c = i_datasource.column_class.import_obj(new_c)
- if imported_c.column_name not in [c.column_name for c in datasource.columns]:
- datasource.columns.append(imported_c)
- session.flush()
- return datasource.id
-
-
-def import_simple_obj(
- session: Session, i_obj: Model, lookup_obj: Callable[[Model], Model]
-) -> Model:
- make_transient(i_obj)
- i_obj.id = None
- i_obj.table = None
-
- # find if the column was already imported
- existing_column = lookup_obj(i_obj)
- i_obj.table = None
- if existing_column:
- existing_column.override(i_obj)
- session.flush()
- return existing_column
-
- session.add(i_obj)
- session.flush()
- return i_obj
diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py
index c19b957..aac3a51 100644
--- a/tests/import_export_tests.py
+++ b/tests/import_export_tests.py
@@ -33,6 +33,7 @@ from superset.connectors.druid.models import (
)
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.dashboards.commands.importers.v0 import import_chart, import_dashboard
+from superset.datasets.commands.importers.v0 import import_dataset
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.core import get_example_database
@@ -567,7 +568,7 @@ class TestImportExport(SupersetTestCase):
def test_import_table_no_metadata(self):
db_id = get_example_database().id
table = self.create_table("pure_table", id=10001)
- imported_id = SqlaTable.import_obj(table, db_id, import_time=1989)
+ imported_id = import_dataset(table, db_id, import_time=1989)
imported = self.get_table_by_id(imported_id)
self.assert_table_equals(table, imported)
@@ -576,7 +577,7 @@ class TestImportExport(SupersetTestCase):
"table_1_col_1_met", id=10002, cols_names=["col1"], metric_names=["metric1"]
)
db_id = get_example_database().id
- imported_id = SqlaTable.import_obj(table, db_id, import_time=1990)
+ imported_id = import_dataset(table, db_id, import_time=1990)
imported = self.get_table_by_id(imported_id)
self.assert_table_equals(table, imported)
self.assertEqual(
@@ -592,7 +593,7 @@ class TestImportExport(SupersetTestCase):
metric_names=["m1", "m2"],
)
db_id = get_example_database().id
- imported_id = SqlaTable.import_obj(table, db_id, import_time=1991)
+ imported_id = import_dataset(table, db_id, import_time=1991)
imported = self.get_table_by_id(imported_id)
self.assert_table_equals(table, imported)
@@ -602,7 +603,7 @@ class TestImportExport(SupersetTestCase):
"table_override", id=10003, cols_names=["col1"], metric_names=["m1"]
)
db_id = get_example_database().id
- imported_id = SqlaTable.import_obj(table, db_id, import_time=1991)
+ imported_id = import_dataset(table, db_id, import_time=1991)
table_over = self.create_table(
"table_override",
@@ -610,7 +611,7 @@ class TestImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
- imported_over_id = SqlaTable.import_obj(table_over, db_id, import_time=1992)
+ imported_over_id = import_dataset(table_over, db_id, import_time=1992)
imported_over = self.get_table_by_id(imported_over_id)
self.assertEqual(imported_id, imported_over.id)
@@ -630,7 +631,7 @@ class TestImportExport(SupersetTestCase):
metric_names=["new_metric1"],
)
db_id = get_example_database().id
- imported_id = SqlaTable.import_obj(table, db_id, import_time=1993)
+ imported_id = import_dataset(table, db_id, import_time=1993)
copy_table = self.create_table(
"copy_cat",
@@ -638,14 +639,14 @@ class TestImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
- imported_id_copy = SqlaTable.import_obj(copy_table, db_id, import_time=1994)
+ imported_id_copy = import_dataset(copy_table, db_id, import_time=1994)
self.assertEqual(imported_id, imported_id_copy)
self.assert_table_equals(copy_table, self.get_table_by_id(imported_id))
def test_import_druid_no_metadata(self):
datasource = self.create_druid_datasource("pure_druid", id=10001)
- imported_id = DruidDatasource.import_obj(datasource, import_time=1989)
+ imported_id = import_dataset(datasource, import_time=1989)
imported = self.get_datasource(imported_id)
self.assert_datasource_equals(datasource, imported)
@@ -653,7 +654,7 @@ class TestImportExport(SupersetTestCase):
datasource = self.create_druid_datasource(
"druid_1_col_1_met", id=10002, cols_names=["col1"], metric_names=["metric1"]
)
- imported_id = DruidDatasource.import_obj(datasource, import_time=1990)
+ imported_id = import_dataset(datasource, import_time=1990)
imported = self.get_datasource(imported_id)
self.assert_datasource_equals(datasource, imported)
self.assertEqual(
@@ -668,7 +669,7 @@ class TestImportExport(SupersetTestCase):
cols_names=["c1", "c2"],
metric_names=["m1", "m2"],
)
- imported_id = DruidDatasource.import_obj(datasource, import_time=1991)
+ imported_id = import_dataset(datasource, import_time=1991)
imported = self.get_datasource(imported_id)
self.assert_datasource_equals(datasource, imported)
@@ -676,14 +677,14 @@ class TestImportExport(SupersetTestCase):
datasource = self.create_druid_datasource(
"druid_override", id=10004, cols_names=["col1"], metric_names=["m1"]
)
- imported_id = DruidDatasource.import_obj(datasource, import_time=1991)
+ imported_id = import_dataset(datasource, import_time=1991)
table_over = self.create_druid_datasource(
"druid_override",
id=10004,
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
- imported_over_id = DruidDatasource.import_obj(table_over, import_time=1992)
+ imported_over_id = import_dataset(table_over, import_time=1992)
imported_over = self.get_datasource(imported_over_id)
self.assertEqual(imported_id, imported_over.id)
@@ -702,7 +703,7 @@ class TestImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
- imported_id = DruidDatasource.import_obj(datasource, import_time=1993)
+ imported_id = import_dataset(datasource, import_time=1993)
copy_datasource = self.create_druid_datasource(
"copy_cat",
@@ -710,7 +711,7 @@ class TestImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
- imported_id_copy = DruidDatasource.import_obj(copy_datasource, import_time=1994)
+ imported_id_copy = import_dataset(copy_datasource, import_time=1994)
self.assertEqual(imported_id, imported_id_copy)
self.assert_datasource_equals(copy_datasource, self.get_datasource(imported_id))