You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this plain-text version.
Posted to commits@superset.apache.org by hu...@apache.org on 2022/10/26 18:04:02 UTC

[superset] branch ref-get-sqla-engine-2 created (now 158da8d200)

This is an automated email from the ASF dual-hosted git repository.

hugh pushed a change to branch ref-get-sqla-engine-2
in repository https://gitbox.apache.org/repos/asf/superset.git


      at 158da8d200 init

This branch includes the following new commits:

     new 158da8d200 init

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  Revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[superset] 01/01: init

Posted by hu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

hugh pushed a commit to branch ref-get-sqla-engine-2
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 158da8d2008fc26eb191a600d538d6796caffc3a
Author: hughhhh <hu...@gmail.com>
AuthorDate: Wed Oct 26 14:03:32 2022 -0400

    init
---
 superset/connectors/sqla/models.py               |  12 +--
 superset/connectors/sqla/utils.py                |  39 ++++---
 superset/databases/commands/test_connection.py   |  56 +++++-----
 superset/databases/commands/validate.py          |  46 ++++-----
 superset/datasets/commands/importers/v1/utils.py |   3 +-
 superset/db_engine_specs/base.py                 |   5 +-
 superset/sql_lab.py                              | 126 ++++++++++++-----------
 7 files changed, 154 insertions(+), 133 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 4855dd1af3..98aac4906f 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -958,13 +958,13 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         if self.fetch_values_predicate:
             qry = qry.where(self.get_fetch_values_predicate())
 
-        engine = self.database.get_sqla_engine()
-        sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
-        sql = self._apply_cte(sql, cte)
-        sql = self.mutate_query_from_config(sql)
+        with self.database.get_sqla_engine_with_context() as engine:
+            sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
+            sql = self._apply_cte(sql, cte)
+            sql = self.mutate_query_from_config(sql)
 
-        df = pd.read_sql_query(sql=sql, con=engine)
-        return df[column_name].to_list()
+            df = pd.read_sql_query(sql=sql, con=engine)
+            return df[column_name].to_list()
 
     def mutate_query_from_config(self, sql: str) -> str:
         """Apply config's SQL_QUERY_MUTATOR
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 8151bfd44b..05cf8cea13 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -112,7 +112,6 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
         )
 
     db_engine_spec = dataset.database.db_engine_spec
-    engine = dataset.database.get_sqla_engine(schema=dataset.schema)
     sql = dataset.get_template_processor().process_template(
         dataset.sql, **dataset.template_params_dict
     )
@@ -137,13 +136,18 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
     # TODO(villebro): refactor to use same code that's used by
     #  sql_lab.py:execute_sql_statements
     try:
-        with closing(engine.raw_connection()) as conn:
-            cursor = conn.cursor()
-            query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
-            db_engine_spec.execute(cursor, query)
-            result = db_engine_spec.fetch_data(cursor, limit=1)
-            result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
-            cols = result_set.columns
+        with dataset.database.get_sqla_engine_with_context(
+            schema=dataset.schema
+        ) as engine:
+            with closing(engine.raw_connection()) as conn:
+                cursor = conn.cursor()
+                query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
+                db_engine_spec.execute(cursor, query)
+                result = db_engine_spec.fetch_data(cursor, limit=1)
+                result_set = SupersetResultSet(
+                    result, cursor.description, db_engine_spec
+                )
+                cols = result_set.columns
     except Exception as ex:
         raise SupersetGenericDBErrorException(message=str(ex)) from ex
     return cols
@@ -155,14 +159,17 @@ def get_columns_description(
 ) -> List[ResultSetColumnType]:
     db_engine_spec = database.db_engine_spec
     try:
-        with closing(database.get_sqla_engine().raw_connection()) as conn:
-            cursor = conn.cursor()
-            query = database.apply_limit_to_sql(query, limit=1)
-            cursor.execute(query)
-            db_engine_spec.execute(cursor, query)
-            result = db_engine_spec.fetch_data(cursor, limit=1)
-            result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
-            return result_set.columns
+        with database.get_sqla_engine_with_context() as engine:
+            with closing(engine.raw_connection()) as conn:
+                cursor = conn.cursor()
+                query = database.apply_limit_to_sql(query, limit=1)
+                cursor.execute(query)
+                db_engine_spec.execute(cursor, query)
+                result = db_engine_spec.fetch_data(cursor, limit=1)
+                result_set = SupersetResultSet(
+                    result, cursor.description, db_engine_spec
+                )
+                return result_set.columns
     except Exception as ex:
         raise SupersetGenericDBErrorException(message=str(ex)) from ex
 
diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py
index d7f7d90e49..2865174ff8 100644
--- a/superset/databases/commands/test_connection.py
+++ b/superset/databases/commands/test_connection.py
@@ -86,7 +86,6 @@ class TestConnectionDatabaseCommand(BaseCommand):
             database.set_sqlalchemy_uri(uri)
             database.db_engine_spec.mutate_db_for_connection_test(database)
 
-            engine = database.get_sqla_engine()
             event_logger.log_with_context(
                 action="test_connection_attempt",
                 engine=database.db_engine_spec.__name__,
@@ -96,31 +95,36 @@ class TestConnectionDatabaseCommand(BaseCommand):
                 with closing(engine.raw_connection()) as conn:
                     return engine.dialect.do_ping(conn)
 
-            try:
-                alive = func_timeout(
-                    int(app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds()),
-                    ping,
-                    args=(engine,),
-                )
-            except (sqlite3.ProgrammingError, RuntimeError):
-                # SQLite can't run on a separate thread, so ``func_timeout`` fails
-                # RuntimeError catches the equivalent error from duckdb.
-                alive = engine.dialect.do_ping(engine)
-            except FunctionTimedOut as ex:
-                raise SupersetTimeoutException(
-                    error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
-                    message=(
-                        "Please check your connection details and database settings, "
-                        "and ensure that your database is accepting connections, "
-                        "then try connecting again."
-                    ),
-                    level=ErrorLevel.ERROR,
-                    extra={"sqlalchemy_uri": database.sqlalchemy_uri},
-                ) from ex
-            except Exception:  # pylint: disable=broad-except
-                alive = False
-            if not alive:
-                raise DBAPIError(None, None, None)
+            with database.get_sqla_engine_with_context() as engine:
+                try:
+                    alive = func_timeout(
+                        int(
+                            app.config[
+                                "TEST_DATABASE_CONNECTION_TIMEOUT"
+                            ].total_seconds()
+                        ),
+                        ping,
+                        args=(engine,),
+                    )
+                except (sqlite3.ProgrammingError, RuntimeError):
+                    # SQLite can't run on a separate thread, so ``func_timeout`` fails
+                    # RuntimeError catches the equivalent error from duckdb.
+                    alive = engine.dialect.do_ping(engine)
+                except FunctionTimedOut as ex:
+                    raise SupersetTimeoutException(
+                        error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
+                        message=(
+                            "Please check your connection details and database settings, "
+                            "and ensure that your database is accepting connections, "
+                            "then try connecting again."
+                        ),
+                        level=ErrorLevel.ERROR,
+                        extra={"sqlalchemy_uri": database.sqlalchemy_uri},
+                    ) from ex
+                except Exception:  # pylint: disable=broad-except
+                    alive = False
+                if not alive:
+                    raise DBAPIError(None, None, None)
 
             # Log succesful connection test with engine
             event_logger.log_with_context(
diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py
index a8956257fa..a92fb79f83 100644
--- a/superset/databases/commands/validate.py
+++ b/superset/databases/commands/validate.py
@@ -101,30 +101,30 @@ class ValidateDatabaseParametersCommand(BaseCommand):
         database.set_sqlalchemy_uri(sqlalchemy_uri)
         database.db_engine_spec.mutate_db_for_connection_test(database)
 
-        engine = database.get_sqla_engine()
-        try:
-            with closing(engine.raw_connection()) as conn:
-                alive = engine.dialect.do_ping(conn)
-        except Exception as ex:
-            url = make_url_safe(sqlalchemy_uri)
-            context = {
-                "hostname": url.host,
-                "password": url.password,
-                "port": url.port,
-                "username": url.username,
-                "database": url.database,
-            }
-            errors = database.db_engine_spec.extract_errors(ex, context)
-            raise DatabaseTestConnectionFailedError(errors) from ex
+        with database.get_sqla_engine_with_context() as engine:
+            try:
+                with closing(engine.raw_connection()) as conn:
+                    alive = engine.dialect.do_ping(conn)
+            except Exception as ex:
+                url = make_url_safe(sqlalchemy_uri)
+                context = {
+                    "hostname": url.host,
+                    "password": url.password,
+                    "port": url.port,
+                    "username": url.username,
+                    "database": url.database,
+                }
+                errors = database.db_engine_spec.extract_errors(ex, context)
+                raise DatabaseTestConnectionFailedError(errors) from ex
 
-        if not alive:
-            raise DatabaseOfflineError(
-                SupersetError(
-                    message=__("Database is offline."),
-                    error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
-                    level=ErrorLevel.ERROR,
-                ),
-            )
+            if not alive:
+                raise DatabaseOfflineError(
+                    SupersetError(
+                        message=__("Database is offline."),
+                        error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
+                        level=ErrorLevel.ERROR,
+                    ),
+                )
 
     def validate(self) -> None:
         database_id = self._properties.get("id")
diff --git a/superset/datasets/commands/importers/v1/utils.py b/superset/datasets/commands/importers/v1/utils.py
index ba2b7df261..7d3998b3bb 100644
--- a/superset/datasets/commands/importers/v1/utils.py
+++ b/superset/datasets/commands/importers/v1/utils.py
@@ -168,7 +168,8 @@ def load_data(
         connection = session.connection()
     else:
         logger.warning("Loading data outside the import transaction")
-        connection = database.get_sqla_engine()
+        with database.get_sqla_engine_with_context() as engine:
+            connection = engine
 
     df.to_sql(
         dataset.table_name,
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index dabed0c7ae..9dd7594dc7 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -472,7 +472,10 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         schema: Optional[str] = None,
         source: Optional[utils.QuerySource] = None,
     ) -> Engine:
-        return database.get_sqla_engine(schema=schema, source=source)
+        with database.get_sqla_engine_with_context(
+            schema=schema, source=source
+        ) as engine:
+            return engine
 
     @classmethod
     def get_timestamp_expr(
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 96afc7f51e..6d9903c8f0 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -463,61 +463,66 @@ def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-loca
             )
         )
 
-    engine = database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB)
-    # Sharing a single connection and cursor across the
-    # execution of all statements (if many)
-    with closing(engine.raw_connection()) as conn:
-        # closing the connection closes the cursor as well
-        cursor = conn.cursor()
-        cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
-        if cancel_query_id is not None:
-            query.set_extra_json_key(cancel_query_key, cancel_query_id)
-            session.commit()
-        statement_count = len(statements)
-        for i, statement in enumerate(statements):
-            # Check if stopped
-            session.refresh(query)
-            if query.status == QueryStatus.STOPPED:
-                payload.update({"status": query.status})
-                return payload
-
-            # For CTAS we create the table only on the last statement
-            apply_ctas = query.select_as_cta and (
-                query.ctas_method == CtasMethod.VIEW
-                or (query.ctas_method == CtasMethod.TABLE and i == len(statements) - 1)
-            )
-
-            # Run statement
-            msg = f"Running statement {i+1} out of {statement_count}"
-            logger.info("Query %s: %s", str(query_id), msg)
-            query.set_extra_json_key("progress", msg)
-            session.commit()
-            try:
-                result_set = execute_sql_statement(
-                    statement,
-                    query,
-                    session,
-                    cursor,
-                    log_params,
-                    apply_ctas,
-                )
-            except SqlLabQueryStoppedException:
-                payload.update({"status": QueryStatus.STOPPED})
-                return payload
-            except Exception as ex:  # pylint: disable=broad-except
-                msg = str(ex)
-                prefix_message = (
-                    f"[Statement {i+1} out of {statement_count}]"
-                    if statement_count > 1
-                    else ""
+    with database.get_sqla_engine_with_context(
+        query.schema, source=QuerySource.SQL_LAB
+    ) as engine:
+        # Sharing a single connection and cursor across the
+        # execution of all statements (if many)
+        with closing(engine.raw_connection()) as conn:
+            # closing the connection closes the cursor as well
+            cursor = conn.cursor()
+            cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
+            if cancel_query_id is not None:
+                query.set_extra_json_key(cancel_query_key, cancel_query_id)
+                session.commit()
+            statement_count = len(statements)
+            for i, statement in enumerate(statements):
+                # Check if stopped
+                session.refresh(query)
+                if query.status == QueryStatus.STOPPED:
+                    payload.update({"status": query.status})
+                    return payload
+
+                # For CTAS we create the table only on the last statement
+                apply_ctas = query.select_as_cta and (
+                    query.ctas_method == CtasMethod.VIEW
+                    or (
+                        query.ctas_method == CtasMethod.TABLE
+                        and i == len(statements) - 1
+                    )
                 )
-                payload = handle_query_error(
-                    ex, query, session, payload, prefix_message
-                )
-                return payload
 
-        # Commit the connection so CTA queries will create the table.
-        conn.commit()
+                # Run statement
+                msg = f"Running statement {i+1} out of {statement_count}"
+                logger.info("Query %s: %s", str(query_id), msg)
+                query.set_extra_json_key("progress", msg)
+                session.commit()
+                try:
+                    result_set = execute_sql_statement(
+                        statement,
+                        query,
+                        session,
+                        cursor,
+                        log_params,
+                        apply_ctas,
+                    )
+                except SqlLabQueryStoppedException:
+                    payload.update({"status": QueryStatus.STOPPED})
+                    return payload
+                except Exception as ex:  # pylint: disable=broad-except
+                    msg = str(ex)
+                    prefix_message = (
+                        f"[Statement {i+1} out of {statement_count}]"
+                        if statement_count > 1
+                        else ""
+                    )
+                    payload = handle_query_error(
+                        ex, query, session, payload, prefix_message
+                    )
+                    return payload
+
+            # Commit the connection so CTA queries will create the table.
+            conn.commit()
 
     # Success, updating the query entry in database
     query.rows = result_set.size
@@ -622,10 +627,11 @@ def cancel_query(query: Query) -> bool:
     if cancel_query_id is None:
         return False
 
-    engine = query.database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB)
-
-    with closing(engine.raw_connection()) as conn:
-        with closing(conn.cursor()) as cursor:
-            return query.database.db_engine_spec.cancel_query(
-                cursor, query, cancel_query_id
-            )
+    with query.database.get_sqla_engine_with_context(
+        query.schema, source=QuerySource.SQL_LAB
+    ) as engine:
+        with closing(engine.raw_connection()) as conn:
+            with closing(conn.cursor()) as cursor:
+                return query.database.db_engine_spec.cancel_query(
+                    cursor, query, cancel_query_id
+                )