You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@superset.apache.org by hu...@apache.org on 2023/08/11 10:54:10 UTC
[superset] branch master updated: fix: calls to `_get_sqla_engine` (#24953)
This is an automated email from the ASF dual-hosted git repository.
hugh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 6f24a4e7a8 fix: calls to `_get_sqla_engine` (#24953)
6f24a4e7a8 is described below
commit 6f24a4e7a84cd25185b911c079aa622fb085fc29
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Fri Aug 11 03:54:01 2023 -0700
fix: calls to `_get_sqla_engine` (#24953)
---
superset/db_engine_specs/trino.py | 7 ++--
superset/models/core.py | 51 +++++++++++-------------
tests/integration_tests/celery_tests.py | 5 +--
tests/integration_tests/charts/data/api_tests.py | 7 ++--
tests/integration_tests/model_tests.py | 3 +-
5 files changed, 36 insertions(+), 37 deletions(-)
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index f05bd67ec3..c3bdccc775 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -86,9 +86,10 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
}
if database.has_view_by_name(table_name, schema_name):
- metadata["view"] = database.inspector.get_view_definition(
- table_name, schema_name
- )
+ with database.get_inspector_with_context() as inspector:
+ metadata["view"] = inspector.get_view_definition(
+ table_name, schema_name
+ )
return metadata
diff --git a/superset/models/core.py b/superset/models/core.py
index e76da0dcd5..f59cd1159b 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -563,7 +563,8 @@ class Database(
mutator: Callable[[pd.DataFrame], None] | None = None,
) -> pd.DataFrame:
sqls = self.db_engine_spec.parse_sql(sql)
- engine = self._get_sqla_engine(schema)
+ with self.get_sqla_engine_with_context(schema) as engine:
+ engine_url = engine.url
mutate_after_split = config["MUTATE_AFTER_SPLIT"]
sql_query_mutator = config["SQL_QUERY_MUTATOR"]
@@ -577,7 +578,7 @@ class Database(
def _log_query(sql: str) -> None:
if log_query:
log_query(
- engine.url,
+ engine_url,
sql,
schema,
__name__,
@@ -624,13 +625,12 @@ class Database(
return df
def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str:
- engine = self._get_sqla_engine(schema=schema)
+ with self.get_sqla_engine_with_context(schema) as engine:
+ sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
- sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
-
- # pylint: disable=protected-access
- if engine.dialect.identifier_preparer._double_percents: # noqa
- sql = sql.replace("%%", "%")
+ # pylint: disable=protected-access
+ if engine.dialect.identifier_preparer._double_percents: # noqa
+ sql = sql.replace("%%", "%")
return sql
@@ -645,18 +645,18 @@ class Database(
cols: list[ResultSetColumnType] | None = None,
) -> str:
"""Generates a ``select *`` statement in the proper dialect"""
- eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
- return self.db_engine_spec.select_star(
- self,
- table_name,
- schema=schema,
- engine=eng,
- limit=limit,
- show_cols=show_cols,
- indent=indent,
- latest_partition=latest_partition,
- cols=cols,
- )
+ with self.get_sqla_engine_with_context(schema) as engine:
+ return self.db_engine_spec.select_star(
+ self,
+ table_name,
+ schema=schema,
+ engine=engine,
+ limit=limit,
+ show_cols=show_cols,
+ indent=indent,
+ latest_partition=latest_partition,
+ cols=cols,
+ )
def apply_limit_to_sql(
self, sql: str, limit: int = 1000, force: bool = False
@@ -668,11 +668,6 @@ class Database(
def safe_sqlalchemy_uri(self) -> str:
return self.sqlalchemy_uri
- @property
- def inspector(self) -> Inspector:
- engine = self._get_sqla_engine()
- return sqla.inspect(engine)
-
@cache_util.memoized_func(
key="db:{self.id}:schema:{schema}:table_list",
cache=cache_manager.cache,
@@ -955,8 +950,10 @@ class Database(
return view_name in view_names
def has_view(self, view_name: str, schema: str | None = None) -> bool:
- engine = self._get_sqla_engine()
- return engine.run_callable(self._has_view, engine.dialect, view_name, schema)
+ with self.get_sqla_engine_with_context(schema) as engine:
+ return engine.run_callable(
+ self._has_view, engine.dialect, view_name, schema
+ )
def has_view_by_name(self, view_name: str, schema: str | None = None) -> bool:
return self.has_view(view_name=view_name, schema=schema)
diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py
index 8693a88887..29a1f7a66a 100644
--- a/tests/integration_tests/celery_tests.py
+++ b/tests/integration_tests/celery_tests.py
@@ -120,9 +120,8 @@ def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None:
def quote_f(value: Optional[str]):
if not value:
return value
- return get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier(
- value
- )
+ with get_example_database().get_inspector_with_context() as inspector:
+ return inspector.engine.dialect.identifier_preparer.quote_identifier(value)
def cta_result(ctas_method: CtasMethod):
diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py
index da3a28f1ba..dc82026986 100644
--- a/tests/integration_tests/charts/data/api_tests.py
+++ b/tests/integration_tests/charts/data/api_tests.py
@@ -113,9 +113,10 @@ class BaseTestChartDataApi(SupersetTestCase):
def quote_name(self, name: str):
if get_main_database().backend in {"presto", "hive"}:
- return get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier(
- name
- )
+ with get_example_database().get_inspector_with_context() as inspector: # E: Ne
+ return inspector.engine.dialect.identifier_preparer.quote_identifier(
+ name
+ )
return name
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index 3a5f7c0a77..5222c1cb34 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -296,7 +296,8 @@ class TestDatabaseModel(SupersetTestCase):
db = get_example_database()
table_name = "energy_usage"
sql = db.select_star(table_name, show_cols=False, latest_partition=False)
- quote = db.inspector.engine.dialect.identifier_preparer.quote_identifier
+ with db.get_sqla_engine_with_context() as engine:
+ quote = engine.dialect.identifier_preparer.quote_identifier
expected = (
textwrap.dedent(
f"""\