You are viewing a plain-text version of this content. The canonical link to it is on the original archive page (the hyperlink is not preserved in this plain-text rendering).
Posted to commits@superset.apache.org by mi...@apache.org on 2023/08/14 12:27:04 UTC

[superset] 02/06: fix: calls to `_get_sqla_engine` (#24953)

This is an automated email from the ASF dual-hosted git repository.

michaelsmolina pushed a commit to branch 3.0
in repository https://gitbox.apache.org/repos/asf/superset.git

commit ff2ec231029c508196b4e40538c78b24e54493ba
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Fri Aug 11 03:54:01 2023 -0700

    fix: calls to `_get_sqla_engine` (#24953)
    
    (cherry picked from commit 6f24a4e7a84cd25185b911c079aa622fb085fc29)
---
 superset/db_engine_specs/trino.py                |  7 ++--
 superset/models/core.py                          | 51 +++++++++++-------------
 tests/integration_tests/celery_tests.py          |  5 +--
 tests/integration_tests/charts/data/api_tests.py |  7 ++--
 tests/integration_tests/model_tests.py           |  3 +-
 5 files changed, 36 insertions(+), 37 deletions(-)

diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index f05bd67ec3..c3bdccc775 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -86,9 +86,10 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
             }
 
         if database.has_view_by_name(table_name, schema_name):
-            metadata["view"] = database.inspector.get_view_definition(
-                table_name, schema_name
-            )
+            with database.get_inspector_with_context() as inspector:
+                metadata["view"] = inspector.get_view_definition(
+                    table_name, schema_name
+                )
 
         return metadata
 
diff --git a/superset/models/core.py b/superset/models/core.py
index 21f8eddba4..e3f91e1379 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -563,7 +563,8 @@ class Database(
         mutator: Callable[[pd.DataFrame], None] | None = None,
     ) -> pd.DataFrame:
         sqls = self.db_engine_spec.parse_sql(sql)
-        engine = self._get_sqla_engine(schema)
+        with self.get_sqla_engine_with_context(schema) as engine:
+            engine_url = engine.url
         mutate_after_split = config["MUTATE_AFTER_SPLIT"]
         sql_query_mutator = config["SQL_QUERY_MUTATOR"]
 
@@ -577,7 +578,7 @@ class Database(
         def _log_query(sql: str) -> None:
             if log_query:
                 log_query(
-                    engine.url,
+                    engine_url,
                     sql,
                     schema,
                     __name__,
@@ -624,13 +625,12 @@ class Database(
             return df
 
     def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str:
-        engine = self._get_sqla_engine(schema=schema)
+        with self.get_sqla_engine_with_context(schema) as engine:
+            sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
 
-        sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
-
-        # pylint: disable=protected-access
-        if engine.dialect.identifier_preparer._double_percents:  # noqa
-            sql = sql.replace("%%", "%")
+            # pylint: disable=protected-access
+            if engine.dialect.identifier_preparer._double_percents:  # noqa
+                sql = sql.replace("%%", "%")
 
         return sql
 
@@ -645,18 +645,18 @@ class Database(
         cols: list[ResultSetColumnType] | None = None,
     ) -> str:
         """Generates a ``select *`` statement in the proper dialect"""
-        eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
-        return self.db_engine_spec.select_star(
-            self,
-            table_name,
-            schema=schema,
-            engine=eng,
-            limit=limit,
-            show_cols=show_cols,
-            indent=indent,
-            latest_partition=latest_partition,
-            cols=cols,
-        )
+        with self.get_sqla_engine_with_context(schema) as engine:
+            return self.db_engine_spec.select_star(
+                self,
+                table_name,
+                schema=schema,
+                engine=engine,
+                limit=limit,
+                show_cols=show_cols,
+                indent=indent,
+                latest_partition=latest_partition,
+                cols=cols,
+            )
 
     def apply_limit_to_sql(
         self, sql: str, limit: int = 1000, force: bool = False
@@ -668,11 +668,6 @@ class Database(
     def safe_sqlalchemy_uri(self) -> str:
         return self.sqlalchemy_uri
 
-    @property
-    def inspector(self) -> Inspector:
-        engine = self._get_sqla_engine()
-        return sqla.inspect(engine)
-
     @cache_util.memoized_func(
         key="db:{self.id}:schema:{schema}:table_list",
         cache=cache_manager.cache,
@@ -955,8 +950,10 @@ class Database(
         return view_name in view_names
 
     def has_view(self, view_name: str, schema: str | None = None) -> bool:
-        engine = self._get_sqla_engine()
-        return engine.run_callable(self._has_view, engine.dialect, view_name, schema)
+        with self.get_sqla_engine_with_context(schema) as engine:
+            return engine.run_callable(
+                self._has_view, engine.dialect, view_name, schema
+            )
 
     def has_view_by_name(self, view_name: str, schema: str | None = None) -> bool:
         return self.has_view(view_name=view_name, schema=schema)
diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py
index 8693a88887..29a1f7a66a 100644
--- a/tests/integration_tests/celery_tests.py
+++ b/tests/integration_tests/celery_tests.py
@@ -120,9 +120,8 @@ def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None:
 def quote_f(value: Optional[str]):
     if not value:
         return value
-    return get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier(
-        value
-    )
+    with get_example_database().get_inspector_with_context() as inspector:
+        return inspector.engine.dialect.identifier_preparer.quote_identifier(value)
 
 
 def cta_result(ctas_method: CtasMethod):
diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py
index da3a28f1ba..dc82026986 100644
--- a/tests/integration_tests/charts/data/api_tests.py
+++ b/tests/integration_tests/charts/data/api_tests.py
@@ -113,9 +113,10 @@ class BaseTestChartDataApi(SupersetTestCase):
 
     def quote_name(self, name: str):
         if get_main_database().backend in {"presto", "hive"}:
-            return get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier(
-                name
-            )
+            with get_example_database().get_inspector_with_context() as inspector:  # E: Ne
+                return inspector.engine.dialect.identifier_preparer.quote_identifier(
+                    name
+                )
         return name
 
 
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index 3a5f7c0a77..5222c1cb34 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -296,7 +296,8 @@ class TestDatabaseModel(SupersetTestCase):
         db = get_example_database()
         table_name = "energy_usage"
         sql = db.select_star(table_name, show_cols=False, latest_partition=False)
-        quote = db.inspector.engine.dialect.identifier_preparer.quote_identifier
+        with db.get_sqla_engine_with_context() as engine:
+            quote = engine.dialect.identifier_preparer.quote_identifier
         expected = (
             textwrap.dedent(
                 f"""\