You are viewing a plain-text version of this content. The canonical link for it (shown as "here" in the original page) is not preserved in this plain-text rendering.
Posted to commits@superset.apache.org by be...@apache.org on 2021/01/06 01:53:12 UTC

[superset] branch master updated: fix: load example data into correct DB (#12292)

This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 6b2b208  fix: load example data into correct DB (#12292)
6b2b208 is described below

commit 6b2b208b3b5d61573d2abce15326ed8b055da07b
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Tue Jan 5 17:52:42 2021 -0800

    fix: load example data into correct DB (#12292)
    
    * fix: load example data into correct DB
    
    * Fix force_data
    
    * Fix lint
---
 superset/cli.py                                  |  2 +-
 superset/commands/importers/v1/examples.py       | 26 +++++++--
 superset/datasets/commands/importers/v1/utils.py | 68 +++++++++++++++---------
 superset/examples/utils.py                       |  4 +-
 4 files changed, 70 insertions(+), 30 deletions(-)

diff --git a/superset/cli.py b/superset/cli.py
index 72b02d5..557cc9d 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -168,7 +168,7 @@ def load_examples_run(
     examples.load_tabbed_dashboard(only_metadata)
 
     # load examples that are stored as YAML config files
-    examples.load_from_configs()
+    examples.load_from_configs(force)
 
 
 @with_appcontext
diff --git a/superset/commands/importers/v1/examples.py b/superset/commands/importers/v1/examples.py
index c5b2e6e..2b56ee0 100644
--- a/superset/commands/importers/v1/examples.py
+++ b/superset/commands/importers/v1/examples.py
@@ -55,10 +55,28 @@ class ImportExamplesCommand(ImportModelsCommand):
     }
     import_error = CommandException
 
-    # pylint: disable=too-many-locals
+    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+        super().__init__(contents, *args, **kwargs)
+        self.force_data = kwargs.get("force_data", False)
+
+    def run(self) -> None:
+        self.validate()
+
+        # rollback to prevent partial imports
+        try:
+            self._import(db.session, self._configs, self.overwrite, self.force_data)
+            db.session.commit()
+        except Exception:
+            db.session.rollback()
+            raise self.import_error()
+
+    # pylint: disable=too-many-locals, arguments-differ
     @staticmethod
     def _import(
-        session: Session, configs: Dict[str, Any], overwrite: bool = False
+        session: Session,
+        configs: Dict[str, Any],
+        overwrite: bool = False,
+        force_data: bool = False,
     ) -> None:
         # import databases
         database_ids: Dict[str, int] = {}
@@ -78,7 +96,9 @@ class ImportExamplesCommand(ImportModelsCommand):
         for file_name, config in configs.items():
             if file_name.startswith("datasets/"):
                 config["database_id"] = examples_id
-                dataset = import_dataset(session, config, overwrite=overwrite)
+                dataset = import_dataset(
+                    session, config, overwrite=overwrite, force_data=force_data
+                )
                 dataset_info[str(dataset.uuid)] = {
                     "datasource_id": dataset.id,
                     "datasource_type": "view" if dataset.is_sqllab_view else "table",
diff --git a/superset/datasets/commands/importers/v1/utils.py b/superset/datasets/commands/importers/v1/utils.py
index d1cd0bb..5e59545 100644
--- a/superset/datasets/commands/importers/v1/utils.py
+++ b/superset/datasets/commands/importers/v1/utils.py
@@ -29,6 +29,8 @@ from sqlalchemy.orm import Session
 from sqlalchemy.sql.visitors import VisitableType
 
 from superset.connectors.sqla.models import SqlaTable
+from superset.models.core import Database
+from superset.utils.core import get_example_database, get_main_database
 
 logger = logging.getLogger(__name__)
 
@@ -74,7 +76,10 @@ def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> Dict[str, VisitableType]:
 
 
 def import_dataset(
-    session: Session, config: Dict[str, Any], overwrite: bool = False
+    session: Session,
+    config: Dict[str, Any],
+    overwrite: bool = False,
+    force_data: bool = False,
 ) -> SqlaTable:
     existing = session.query(SqlaTable).filter_by(uuid=config["uuid"]).first()
     if existing:
@@ -108,28 +113,43 @@ def import_dataset(
     if dataset.id is None:
         session.flush()
 
-    # load data
-    if data_uri:
-        data = request.urlopen(data_uri)
-        if data_uri.endswith(".gz"):
-            data = gzip.open(data)
-        df = pd.read_csv(data, encoding="utf-8")
-        dtype = get_dtype(df, dataset)
-
-        # convert temporal columns
-        for column_name, sqla_type in dtype.items():
-            if isinstance(sqla_type, (Date, DateTime)):
-                df[column_name] = pd.to_datetime(df[column_name])
-
-        df.to_sql(
-            dataset.table_name,
-            con=session.connection(),
-            schema=dataset.schema,
-            if_exists="replace",
-            chunksize=CHUNKSIZE,
-            dtype=dtype,
-            index=False,
-            method="multi",
-        )
+    example_database = get_example_database()
+    table_exists = example_database.has_table_by_name(dataset.table_name)
+    if data_uri and (not table_exists or force_data):
+        load_data(data_uri, dataset, example_database, session)
 
     return dataset
+
+
+def load_data(
+    data_uri: str, dataset: SqlaTable, example_database: Database, session: Session
+) -> None:
+    data = request.urlopen(data_uri)
+    if data_uri.endswith(".gz"):
+        data = gzip.open(data)
+    df = pd.read_csv(data, encoding="utf-8")
+    dtype = get_dtype(df, dataset)
+
+    # convert temporal columns
+    for column_name, sqla_type in dtype.items():
+        if isinstance(sqla_type, (Date, DateTime)):
+            df[column_name] = pd.to_datetime(df[column_name])
+
+    # reuse session when loading data if possible, to make import atomic
+    if example_database.sqlalchemy_uri == get_main_database().sqlalchemy_uri:
+        logger.info("Loading data inside the import transaction")
+        connection = session.connection()
+    else:
+        logger.warning("Loading data outside the import transaction")
+        connection = example_database.get_sqla_engine()
+
+    df.to_sql(
+        dataset.table_name,
+        con=connection,
+        schema=dataset.schema,
+        if_exists="replace",
+        chunksize=CHUNKSIZE,
+        dtype=dtype,
+        index=False,
+        method="multi",
+    )
diff --git a/superset/examples/utils.py b/superset/examples/utils.py
index 723f2bc..951b741 100644
--- a/superset/examples/utils.py
+++ b/superset/examples/utils.py
@@ -24,9 +24,9 @@ from superset.commands.importers.v1.examples import ImportExamplesCommand
 YAML_EXTENSIONS = {".yaml", ".yml"}
 
 
-def load_from_configs() -> None:
+def load_from_configs(force_data: bool = False) -> None:
     contents = load_contents()
-    command = ImportExamplesCommand(contents, overwrite=True)
+    command = ImportExamplesCommand(contents, overwrite=True, force_data=force_data)
     command.run()