You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@superset.apache.org by hu...@apache.org on 2022/12/07 22:36:18 UTC
[superset] 03/03: switch out all places using connection
This is an automated email from the ASF dual-hosted git repository.
hugh pushed a commit to branch refactor-tc
in repository https://gitbox.apache.org/repos/asf/superset.git
commit 70f0a52d54ac24cb076fd30d58f7fe5572ce7abd
Author: hughhhh <hu...@gmail.com>
AuthorDate: Wed Dec 7 17:35:57 2022 -0500
switch out all places using connection
---
superset/connectors/sqla/utils.py | 42 +++++++++++++++++-------------------
superset/db_engine_specs/base.py | 16 +++++++-------
superset/db_engine_specs/gsheets.py | 4 +---
superset/db_engine_specs/presto.py | 31 ++++++++++++--------------
superset/sql_lab.py | 17 ++++++---------
superset/sql_validators/presto_db.py | 19 +++++++---------
6 files changed, 57 insertions(+), 72 deletions(-)
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 05cf8cea13..f9bcfc0bd8 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -136,18 +136,17 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
# TODO(villebro): refactor to use same code that's used by
# sql_lab.py:execute_sql_statements
try:
- with dataset.database.get_sqla_engine_with_context(
+ with dataset.database.get_raw_connection(
schema=dataset.schema
- ) as engine:
- with closing(engine.raw_connection()) as conn:
- cursor = conn.cursor()
- query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
- db_engine_spec.execute(cursor, query)
- result = db_engine_spec.fetch_data(cursor, limit=1)
- result_set = SupersetResultSet(
- result, cursor.description, db_engine_spec
- )
- cols = result_set.columns
+ ) as conn:
+ cursor = conn.cursor()
+ query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
+ db_engine_spec.execute(cursor, query)
+ result = db_engine_spec.fetch_data(cursor, limit=1)
+ result_set = SupersetResultSet(
+ result, cursor.description, db_engine_spec
+ )
+ cols = result_set.columns
except Exception as ex:
raise SupersetGenericDBErrorException(message=str(ex)) from ex
return cols
@@ -159,17 +158,16 @@ def get_columns_description(
) -> List[ResultSetColumnType]:
db_engine_spec = database.db_engine_spec
try:
- with database.get_sqla_engine_with_context() as engine:
- with closing(engine.raw_connection()) as conn:
- cursor = conn.cursor()
- query = database.apply_limit_to_sql(query, limit=1)
- cursor.execute(query)
- db_engine_spec.execute(cursor, query)
- result = db_engine_spec.fetch_data(cursor, limit=1)
- result_set = SupersetResultSet(
- result, cursor.description, db_engine_spec
- )
- return result_set.columns
+ with database.get_raw_connection() as conn:
+ cursor = conn.cursor()
+ query = database.apply_limit_to_sql(query, limit=1)
+ cursor.execute(query)
+ db_engine_spec.execute(cursor, query)
+ result = db_engine_spec.fetch_data(cursor, limit=1)
+ result_set = SupersetResultSet(
+ result, cursor.description, db_engine_spec
+ )
+ return result_set.columns
except Exception as ex:
raise SupersetGenericDBErrorException(message=str(ex)) from ex
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 3e2c0f56ba..c6978a3fd2 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1296,14 +1296,14 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
statements = parsed_query.get_statements()
costs = []
- with cls.get_engine(database, schema=schema, source=source) as engine:
- with closing(engine.raw_connection()) as conn:
- cursor = conn.cursor()
- for statement in statements:
- processed_statement = cls.process_statement(statement, database)
- costs.append(
- cls.estimate_statement_cost(processed_statement, cursor)
- )
+ with database.get_raw_connection(schema=schema, source=source) as conn:
+ cursor = conn.cursor()
+ for statement in statements:
+ processed_statement = cls.process_statement(statement, database)
+ costs.append(
+ cls.estimate_statement_cost(processed_statement, cursor)
+ )
+
return costs
@classmethod
diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py
index 805a7ee400..9b0dd3fb8f 100644
--- a/superset/db_engine_specs/gsheets.py
+++ b/superset/db_engine_specs/gsheets.py
@@ -109,12 +109,10 @@ class GSheetsEngineSpec(SqliteEngineSpec):
table_name: str,
schema_name: Optional[str],
) -> Dict[str, Any]:
- with cls.get_engine(database, schema=schema_name) as engine:
- with closing(engine.raw_connection()) as conn:
+ with database.get_raw_connection(schema=schema_name) as conn:
cursor = conn.cursor()
cursor.execute(f'SELECT GET_METADATA("{table_name}")')
results = cursor.fetchone()[0]
-
try:
metadata = json.loads(results)
except Exception: # pylint: disable=broad-except
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 6755039734..c1d67f13eb 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -470,13 +470,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
).strip()
params = {}
- with cls.get_engine(database, schema=schema) as engine:
- with closing(engine.raw_connection()) as conn:
- cursor = conn.cursor()
- cursor.execute(sql, params)
- results = cursor.fetchall()
-
- return {row[0] for row in results}
+ with database.get_raw_connection(schema=schema) as conn:
+ cursor = conn.cursor()
+ cursor.execute(sql, params)
+ results = cursor.fetchall()
+ return {row[0] for row in results}
@classmethod
def _create_column_info(
@@ -996,16 +994,15 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# pylint: disable=import-outside-toplevel
from pyhive.exc import DatabaseError
- with cls.get_engine(database, schema=schema) as engine:
- with closing(engine.raw_connection()) as conn:
- cursor = conn.cursor()
- sql = f"SHOW CREATE VIEW {schema}.{table}"
- try:
- cls.execute(cursor, sql)
-
- except DatabaseError: # not a VIEW
- return None
- rows = cls.fetch_data(cursor, 1)
+ with database.get_raw_connection(schema=schema) as conn:
+ cursor = conn.cursor()
+ sql = f"SHOW CREATE VIEW {schema}.{table}"
+ try:
+ cls.execute(cursor, sql)
+ except DatabaseError: # not a VIEW
+ return None
+ rows = cls.fetch_data(cursor, 1)
+
return rows[0][0]
@classmethod
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 6d9903c8f0..f40d4c1daf 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -463,13 +463,9 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
)
)
- with database.get_sqla_engine_with_context(
- query.schema, source=QuerySource.SQL_LAB
- ) as engine:
- # Sharing a single connection and cursor across the
- # execution of all statements (if many)
- with closing(engine.raw_connection()) as conn:
- # closing the connection closes the cursor as well
+ with database.get_raw_connection(
+ schema=query.schema,
+ source=QuerySource.SQL_LAB) as conn:
cursor = conn.cursor()
cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
if cancel_query_id is not None:
@@ -627,10 +623,9 @@ def cancel_query(query: Query) -> bool:
if cancel_query_id is None:
return False
- with query.database.get_sqla_engine_with_context(
- query.schema, source=QuerySource.SQL_LAB
- ) as engine:
- with closing(engine.raw_connection()) as conn:
+ with query.database.get_raw_connection(
+ schema=query.schema,
+ source=QuerySource.SQL_LAB) as conn:
with closing(conn.cursor()) as cursor:
return query.database.db_engine_spec.cancel_query(
cursor, query, cancel_query_id
diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py
index 37375e484d..cbefe154d8 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -162,18 +162,15 @@ class PrestoDBSQLValidator(BaseSQLValidator):
statements = parsed_query.get_statements()
logger.info("Validating %i statement(s)", len(statements))
- with database.get_sqla_engine_with_context(
- schema, source=QuerySource.SQL_LAB
- ) as engine:
+ with database.get_raw_connection(schema=schema, source=QuerySource.SQL_LAB) as conn:
# Sharing a single connection and cursor across the
# execution of all statements (if many)
- annotations: List[SQLValidationAnnotation] = []
- with closing(engine.raw_connection()) as conn:
- cursor = conn.cursor()
- for statement in parsed_query.get_statements():
- annotation = cls.validate_statement(statement, database, cursor)
- if annotation:
- annotations.append(annotation)
+ annotations: List[SQLValidationAnnotation] = []
+ cursor = conn.cursor()
+ for statement in parsed_query.get_statements():
+ annotation = cls.validate_statement(statement, database, cursor)
+ if annotation:
+ annotations.append(annotation)
logger.debug("Validation found %i error(s)", len(annotations))
-
+
return annotations