You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@superset.apache.org by hu...@apache.org on 2022/12/29 19:20:52 UTC
[superset] branch master updated: chore(ssh-tunnel): Refactor establishing raw connection with contextmanager (#22366)
This is an automated email from the ASF dual-hosted git repository.
hugh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 9c0d6c51f1 chore(ssh-tunnel): Refactor establishing raw connection with contextmanager (#22366)
9c0d6c51f1 is described below
commit 9c0d6c51f154cf97f4323d634282239fead40aa6
Author: Hugh A. Miles II <hu...@gmail.com>
AuthorDate: Thu Dec 29 14:20:45 2022 -0500
chore(ssh-tunnel): Refactor establishing raw connection with contextmanager (#22366)
---
superset/connectors/sqla/utils.py | 39 +++-----
superset/db_engine_specs/base.py | 15 ++-
superset/db_engine_specs/gsheets.py | 11 +--
superset/db_engine_specs/presto.py | 30 +++---
superset/models/core.py | 13 +++
superset/sql_lab.py | 104 ++++++++++-----------
superset/sql_validators/presto_db.py | 2 +
.../db_engine_specs/presto_tests.py | 32 ++-----
tests/integration_tests/sqllab_tests.py | 6 +-
9 files changed, 112 insertions(+), 140 deletions(-)
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 05cf8cea13..e3745dac2a 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -17,7 +17,6 @@
from __future__ import annotations
import logging
-from contextlib import closing
from typing import (
Any,
Callable,
@@ -136,18 +135,13 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
# TODO(villebro): refactor to use same code that's used by
# sql_lab.py:execute_sql_statements
try:
- with dataset.database.get_sqla_engine_with_context(
- schema=dataset.schema
- ) as engine:
- with closing(engine.raw_connection()) as conn:
- cursor = conn.cursor()
- query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
- db_engine_spec.execute(cursor, query)
- result = db_engine_spec.fetch_data(cursor, limit=1)
- result_set = SupersetResultSet(
- result, cursor.description, db_engine_spec
- )
- cols = result_set.columns
+ with dataset.database.get_raw_connection(schema=dataset.schema) as conn:
+ cursor = conn.cursor()
+ query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
+ db_engine_spec.execute(cursor, query)
+ result = db_engine_spec.fetch_data(cursor, limit=1)
+ result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
+ cols = result_set.columns
except Exception as ex:
raise SupersetGenericDBErrorException(message=str(ex)) from ex
return cols
@@ -159,17 +153,14 @@ def get_columns_description(
) -> List[ResultSetColumnType]:
db_engine_spec = database.db_engine_spec
try:
- with database.get_sqla_engine_with_context() as engine:
- with closing(engine.raw_connection()) as conn:
- cursor = conn.cursor()
- query = database.apply_limit_to_sql(query, limit=1)
- cursor.execute(query)
- db_engine_spec.execute(cursor, query)
- result = db_engine_spec.fetch_data(cursor, limit=1)
- result_set = SupersetResultSet(
- result, cursor.description, db_engine_spec
- )
- return result_set.columns
+ with database.get_raw_connection() as conn:
+ cursor = conn.cursor()
+ query = database.apply_limit_to_sql(query, limit=1)
+ cursor.execute(query)
+ db_engine_spec.execute(cursor, query)
+ result = db_engine_spec.fetch_data(cursor, limit=1)
+ result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
+ return result_set.columns
except Exception as ex:
raise SupersetGenericDBErrorException(message=str(ex)) from ex
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 43dd607876..0f124de34a 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -21,7 +21,6 @@ from __future__ import annotations
import json
import logging
import re
-from contextlib import closing
from datetime import datetime
from typing import (
Any,
@@ -1299,14 +1298,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
statements = parsed_query.get_statements()
costs = []
- with cls.get_engine(database, schema=schema, source=source) as engine:
- with closing(engine.raw_connection()) as conn:
- cursor = conn.cursor()
- for statement in statements:
- processed_statement = cls.process_statement(statement, database)
- costs.append(
- cls.estimate_statement_cost(processed_statement, cursor)
- )
+ with database.get_raw_connection(schema=schema, source=source) as conn:
+ cursor = conn.cursor()
+ for statement in statements:
+ processed_statement = cls.process_statement(statement, database)
+ costs.append(cls.estimate_statement_cost(processed_statement, cursor))
+
return costs
@classmethod
diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py
index 805a7ee400..9438f4d566 100644
--- a/superset/db_engine_specs/gsheets.py
+++ b/superset/db_engine_specs/gsheets.py
@@ -16,7 +16,6 @@
# under the License.
import json
import re
-from contextlib import closing
from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
from apispec import APISpec
@@ -109,12 +108,10 @@ class GSheetsEngineSpec(SqliteEngineSpec):
table_name: str,
schema_name: Optional[str],
) -> Dict[str, Any]:
- with cls.get_engine(database, schema=schema_name) as engine:
- with closing(engine.raw_connection()) as conn:
- cursor = conn.cursor()
- cursor.execute(f'SELECT GET_METADATA("{table_name}")')
- results = cursor.fetchone()[0]
-
+ with database.get_raw_connection(schema=schema_name) as conn:
+ cursor = conn.cursor()
+ cursor.execute(f'SELECT GET_METADATA("{table_name}")')
+ results = cursor.fetchone()[0]
try:
metadata = json.loads(results)
except Exception: # pylint: disable=broad-except
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index ba5df3e28d..e1aa918879 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -22,7 +22,6 @@ import re
import time
from abc import ABCMeta
from collections import defaultdict, deque
-from contextlib import closing
from datetime import datetime
from distutils.version import StrictVersion
from textwrap import dedent
@@ -667,13 +666,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
).strip()
params = {}
- with cls.get_engine(database, schema=schema) as engine:
- with closing(engine.raw_connection()) as conn:
- cursor = conn.cursor()
- cursor.execute(sql, params)
- results = cursor.fetchall()
-
- return {row[0] for row in results}
+ with database.get_raw_connection(schema=schema) as conn:
+ cursor = conn.cursor()
+ cursor.execute(sql, params)
+ results = cursor.fetchall()
+ return {row[0] for row in results}
@classmethod
def _create_column_info(
@@ -1196,16 +1193,15 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# pylint: disable=import-outside-toplevel
from pyhive.exc import DatabaseError
- with cls.get_engine(database, schema=schema) as engine:
- with closing(engine.raw_connection()) as conn:
- cursor = conn.cursor()
- sql = f"SHOW CREATE VIEW {schema}.{table}"
- try:
- cls.execute(cursor, sql)
+ with database.get_raw_connection(schema=schema) as conn:
+ cursor = conn.cursor()
+ sql = f"SHOW CREATE VIEW {schema}.{table}"
+ try:
+ cls.execute(cursor, sql)
+ except DatabaseError: # not a VIEW
+ return None
+ rows = cls.fetch_data(cursor, 1)
- except DatabaseError: # not a VIEW
- return None
- rows = cls.fetch_data(cursor, 1)
return rows[0][0]
@classmethod
diff --git a/superset/models/core.py b/superset/models/core.py
index 1f2fa7f71c..12ce9ef95e 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -429,6 +429,19 @@ class Database(
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
+ @contextmanager
+ def get_raw_connection(
+ self,
+ schema: Optional[str] = None,
+ nullpool: bool = True,
+ source: Optional[utils.QuerySource] = None,
+ ) -> Connection:
+ with self.get_sqla_engine_with_context(
+ schema=schema, nullpool=nullpool, source=source
+ ) as engine:
+ with closing(engine.raw_connection()) as conn:
+ yield conn
+
@property
def quote_identifier(self) -> Callable[[str], str]:
"""Add quotes to potential identifiter expressions if needed"""
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 143806c7f5..c8b3bca2b1 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -464,66 +464,56 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
)
)
- with database.get_sqla_engine_with_context(
- query.schema, source=QuerySource.SQL_LAB
- ) as engine:
+ with database.get_raw_connection(query.schema, source=QuerySource.SQL_LAB) as conn:
# Sharing a single connection and cursor across the
# execution of all statements (if many)
- with closing(engine.raw_connection()) as conn:
- # closing the connection closes the cursor as well
- cursor = conn.cursor()
- cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
- if cancel_query_id is not None:
- query.set_extra_json_key(QUERY_CANCEL_KEY, cancel_query_id)
- session.commit()
- statement_count = len(statements)
- for i, statement in enumerate(statements):
- # Check if stopped
- session.refresh(query)
- if query.status == QueryStatus.STOPPED:
- payload.update({"status": query.status})
- return payload
-
- # For CTAS we create the table only on the last statement
- apply_ctas = query.select_as_cta and (
- query.ctas_method == CtasMethod.VIEW
- or (
- query.ctas_method == CtasMethod.TABLE
- and i == len(statements) - 1
- )
+ cursor = conn.cursor()
+ cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
+ if cancel_query_id is not None:
+ query.set_extra_json_key(QUERY_CANCEL_KEY, cancel_query_id)
+ session.commit()
+ statement_count = len(statements)
+ for i, statement in enumerate(statements):
+ # Check if stopped
+ session.refresh(query)
+ if query.status == QueryStatus.STOPPED:
+ payload.update({"status": query.status})
+ return payload
+ # For CTAS we create the table only on the last statement
+ apply_ctas = query.select_as_cta and (
+ query.ctas_method == CtasMethod.VIEW
+ or (query.ctas_method == CtasMethod.TABLE and i == len(statements) - 1)
+ )
+ # Run statement
+ msg = f"Running statement {i+1} out of {statement_count}"
+ logger.info("Query %s: %s", str(query_id), msg)
+ query.set_extra_json_key("progress", msg)
+ session.commit()
+ try:
+ result_set = execute_sql_statement(
+ statement,
+ query,
+ session,
+ cursor,
+ log_params,
+ apply_ctas,
)
-
- # Run statement
- msg = f"Running statement {i+1} out of {statement_count}"
- logger.info("Query %s: %s", str(query_id), msg)
- query.set_extra_json_key("progress", msg)
- session.commit()
- try:
- result_set = execute_sql_statement(
- statement,
- query,
- session,
- cursor,
- log_params,
- apply_ctas,
- )
- except SqlLabQueryStoppedException:
- payload.update({"status": QueryStatus.STOPPED})
- return payload
- except Exception as ex: # pylint: disable=broad-except
- msg = str(ex)
- prefix_message = (
- f"[Statement {i+1} out of {statement_count}]"
- if statement_count > 1
- else ""
- )
- payload = handle_query_error(
- ex, query, session, payload, prefix_message
- )
- return payload
-
- # Commit the connection so CTA queries will create the table.
- conn.commit()
+ except SqlLabQueryStoppedException:
+ payload.update({"status": QueryStatus.STOPPED})
+ return payload
+ except Exception as ex: # pylint: disable=broad-except
+ msg = str(ex)
+ prefix_message = (
+ f"[Statement {i+1} out of {statement_count}]"
+ if statement_count > 1
+ else ""
+ )
+ payload = handle_query_error(
+ ex, query, session, payload, prefix_message
+ )
+ return payload
+ # Commit the connection so CTA queries will create the table.
+ conn.commit()
# Success, updating the query entry in database
query.rows = result_set.size
diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py
index 37375e484d..5bc844751b 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -162,6 +162,8 @@ class PrestoDBSQLValidator(BaseSQLValidator):
statements = parsed_query.get_statements()
logger.info("Validating %i statement(s)", len(statements))
+ # todo(hughhh): update this to use new database.get_raw_connection()
+ # this function keeps stalling CI
with database.get_sqla_engine_with_context(
schema, source=QuerySource.SQL_LAB
) as engine:
diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py
index eef3bb8d36..9099dbb7d7 100644
--- a/tests/integration_tests/db_engine_specs/presto_tests.py
+++ b/tests/integration_tests/db_engine_specs/presto_tests.py
@@ -37,10 +37,8 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
def test_get_view_names_with_schema(self):
database = mock.MagicMock()
mock_execute = mock.MagicMock()
- database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
- mock_execute
- )
- database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
+ database.get_raw_connection().__enter__().cursor().execute = mock_execute
+ database.get_raw_connection().__enter__().cursor().fetchall = mock.MagicMock(
return_value=[["a", "b,", "c"], ["d", "e"]]
)
@@ -61,10 +59,8 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
def test_get_view_names_without_schema(self):
database = mock.MagicMock()
mock_execute = mock.MagicMock()
- database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
- mock_execute
- )
- database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
+ database.get_raw_connection().__enter__().cursor().execute = mock_execute
+ database.get_raw_connection().__enter__().cursor().fetchall = mock.MagicMock(
return_value=[["a", "b,", "c"], ["d", "e"]]
)
result = PrestoEngineSpec.get_view_names(database, mock.Mock(), None)
@@ -823,15 +819,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
mock_execute = mock.MagicMock()
mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])
database = mock.MagicMock()
- database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
- mock_execute
- )
- database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
- mock_fetchall
- )
- database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.poll.return_value = (
- False
- )
+ database.get_raw_connection().__enter__().cursor().execute = mock_execute
+ database.get_raw_connection().__enter__().cursor().fetchall = mock_fetchall
+ database.get_raw_connection().__enter__().cursor().return_value = False
schema = "schema"
table = "table"
result = PrestoEngineSpec.get_create_view(database, schema=schema, table=table)
@@ -841,9 +831,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
def test_get_create_view_exception(self):
mock_execute = mock.MagicMock(side_effect=Exception())
database = mock.MagicMock()
- database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
- mock_execute
- )
+ database.get_raw_connection().__enter__().cursor().execute = mock_execute
schema = "schema"
table = "table"
with self.assertRaises(Exception):
@@ -854,9 +842,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
mock_execute = mock.MagicMock(side_effect=DatabaseError())
database = mock.MagicMock()
- database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
- mock_execute
- )
+ database.get_raw_connection().__enter__().cursor().execute = mock_execute
schema = "schema"
table = "table"
result = PrestoEngineSpec.get_create_view(database, schema=schema, table=table)
diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py
index ed37eece96..b1b0480d56 100644
--- a/tests/integration_tests/sqllab_tests.py
+++ b/tests/integration_tests/sqllab_tests.py
@@ -733,7 +733,7 @@ class TestSqlLab(SupersetTestCase):
mock_query = mock.MagicMock()
mock_query.database.allow_run_async = False
mock_cursor = mock.MagicMock()
- mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
+ mock_query.database.get_raw_connection().__enter__().cursor.return_value = (
mock_cursor
)
mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
@@ -786,7 +786,7 @@ class TestSqlLab(SupersetTestCase):
mock_query = mock.MagicMock()
mock_query.database.allow_run_async = True
mock_cursor = mock.MagicMock()
- mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
+ mock_query.database.get_raw_connection().__enter__().cursor.return_value = (
mock_cursor
)
mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
@@ -836,7 +836,7 @@ class TestSqlLab(SupersetTestCase):
mock_query = mock.MagicMock()
mock_query.database.allow_run_async = False
mock_cursor = mock.MagicMock()
- mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
+ mock_query.database.get_raw_connection().__enter__().cursor.return_value = (
mock_cursor
)
mock_query.database.db_engine_spec.run_multiple_statements_as_one = False