Posted to commits@superset.apache.org by hu...@apache.org on 2022/12/29 19:20:52 UTC

[superset] branch master updated: chore(ssh-tunnel): Refactor establishing raw connection with contextmanager (#22366)

This is an automated email from the ASF dual-hosted git repository.

hugh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c0d6c51f1 chore(ssh-tunnel): Refactor establishing raw connection with contextmanager (#22366)
9c0d6c51f1 is described below

commit 9c0d6c51f154cf97f4323d634282239fead40aa6
Author: Hugh A. Miles II <hu...@gmail.com>
AuthorDate: Thu Dec 29 14:20:45 2022 -0500

    chore(ssh-tunnel): Refactor establishing raw connection with contextmanager (#22366)
---
 superset/connectors/sqla/utils.py                  |  39 +++-----
 superset/db_engine_specs/base.py                   |  15 ++-
 superset/db_engine_specs/gsheets.py                |  11 +--
 superset/db_engine_specs/presto.py                 |  30 +++---
 superset/models/core.py                            |  13 +++
 superset/sql_lab.py                                | 104 ++++++++++-----------
 superset/sql_validators/presto_db.py               |   2 +
 .../db_engine_specs/presto_tests.py                |  32 ++-----
 tests/integration_tests/sqllab_tests.py            |   6 +-
 9 files changed, 112 insertions(+), 140 deletions(-)
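
For readers skimming the patch: the refactor adds a Database.get_raw_connection()
context manager (see superset/models/core.py below) that wraps the previous
two-step pattern of entering get_sqla_engine_with_context() and then closing the
engine's raw_connection(). Below is a minimal sketch of the call-site change,
paraphrased from the hunks that follow; the fetch_one_before/fetch_one_after
helpers and their database/schema/sql parameters are illustrative placeholders,
not part of the patch:

    # Before: each call site built the engine and wrapped the raw DBAPI
    # connection in contextlib.closing() itself.
    from contextlib import closing

    def fetch_one_before(database, schema, sql):
        with database.get_sqla_engine_with_context(schema=schema) as engine:
            with closing(engine.raw_connection()) as conn:
                cursor = conn.cursor()
                cursor.execute(sql)
                return cursor.fetchone()

    # After: the nesting lives inside Database.get_raw_connection(), so a
    # call site only deals with the DBAPI connection.
    def fetch_one_after(database, schema, sql):
        with database.get_raw_connection(schema=schema) as conn:
            cursor = conn.cursor()
            cursor.execute(sql)
            return cursor.fetchone()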

diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 05cf8cea13..e3745dac2a 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -17,7 +17,6 @@
 from __future__ import annotations
 
 import logging
-from contextlib import closing
 from typing import (
     Any,
     Callable,
@@ -136,18 +135,13 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
     # TODO(villebro): refactor to use same code that's used by
     #  sql_lab.py:execute_sql_statements
     try:
-        with dataset.database.get_sqla_engine_with_context(
-            schema=dataset.schema
-        ) as engine:
-            with closing(engine.raw_connection()) as conn:
-                cursor = conn.cursor()
-                query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
-                db_engine_spec.execute(cursor, query)
-                result = db_engine_spec.fetch_data(cursor, limit=1)
-                result_set = SupersetResultSet(
-                    result, cursor.description, db_engine_spec
-                )
-                cols = result_set.columns
+        with dataset.database.get_raw_connection(schema=dataset.schema) as conn:
+            cursor = conn.cursor()
+            query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
+            db_engine_spec.execute(cursor, query)
+            result = db_engine_spec.fetch_data(cursor, limit=1)
+            result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
+            cols = result_set.columns
     except Exception as ex:
         raise SupersetGenericDBErrorException(message=str(ex)) from ex
     return cols
@@ -159,17 +153,14 @@ def get_columns_description(
 ) -> List[ResultSetColumnType]:
     db_engine_spec = database.db_engine_spec
     try:
-        with database.get_sqla_engine_with_context() as engine:
-            with closing(engine.raw_connection()) as conn:
-                cursor = conn.cursor()
-                query = database.apply_limit_to_sql(query, limit=1)
-                cursor.execute(query)
-                db_engine_spec.execute(cursor, query)
-                result = db_engine_spec.fetch_data(cursor, limit=1)
-                result_set = SupersetResultSet(
-                    result, cursor.description, db_engine_spec
-                )
-                return result_set.columns
+        with database.get_raw_connection() as conn:
+            cursor = conn.cursor()
+            query = database.apply_limit_to_sql(query, limit=1)
+            cursor.execute(query)
+            db_engine_spec.execute(cursor, query)
+            result = db_engine_spec.fetch_data(cursor, limit=1)
+            result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
+            return result_set.columns
     except Exception as ex:
         raise SupersetGenericDBErrorException(message=str(ex)) from ex
 
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 43dd607876..0f124de34a 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -21,7 +21,6 @@ from __future__ import annotations
 import json
 import logging
 import re
-from contextlib import closing
 from datetime import datetime
 from typing import (
     Any,
@@ -1299,14 +1298,12 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         statements = parsed_query.get_statements()
 
         costs = []
-        with cls.get_engine(database, schema=schema, source=source) as engine:
-            with closing(engine.raw_connection()) as conn:
-                cursor = conn.cursor()
-                for statement in statements:
-                    processed_statement = cls.process_statement(statement, database)
-                    costs.append(
-                        cls.estimate_statement_cost(processed_statement, cursor)
-                    )
+        with database.get_raw_connection(schema=schema, source=source) as conn:
+            cursor = conn.cursor()
+            for statement in statements:
+                processed_statement = cls.process_statement(statement, database)
+                costs.append(cls.estimate_statement_cost(processed_statement, cursor))
+
         return costs
 
     @classmethod
diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py
index 805a7ee400..9438f4d566 100644
--- a/superset/db_engine_specs/gsheets.py
+++ b/superset/db_engine_specs/gsheets.py
@@ -16,7 +16,6 @@
 # under the License.
 import json
 import re
-from contextlib import closing
 from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
 
 from apispec import APISpec
@@ -109,12 +108,10 @@ class GSheetsEngineSpec(SqliteEngineSpec):
         table_name: str,
         schema_name: Optional[str],
     ) -> Dict[str, Any]:
-        with cls.get_engine(database, schema=schema_name) as engine:
-            with closing(engine.raw_connection()) as conn:
-                cursor = conn.cursor()
-                cursor.execute(f'SELECT GET_METADATA("{table_name}")')
-                results = cursor.fetchone()[0]
-
+        with database.get_raw_connection(schema=schema_name) as conn:
+            cursor = conn.cursor()
+            cursor.execute(f'SELECT GET_METADATA("{table_name}")')
+            results = cursor.fetchone()[0]
         try:
             metadata = json.loads(results)
         except Exception:  # pylint: disable=broad-except
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index ba5df3e28d..e1aa918879 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -22,7 +22,6 @@ import re
 import time
 from abc import ABCMeta
 from collections import defaultdict, deque
-from contextlib import closing
 from datetime import datetime
 from distutils.version import StrictVersion
 from textwrap import dedent
@@ -667,13 +666,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
             ).strip()
             params = {}
 
-        with cls.get_engine(database, schema=schema) as engine:
-            with closing(engine.raw_connection()) as conn:
-                cursor = conn.cursor()
-                cursor.execute(sql, params)
-                results = cursor.fetchall()
-
-        return {row[0] for row in results}
+        with database.get_raw_connection(schema=schema) as conn:
+            cursor = conn.cursor()
+            cursor.execute(sql, params)
+            results = cursor.fetchall()
+            return {row[0] for row in results}
 
     @classmethod
     def _create_column_info(
@@ -1196,16 +1193,15 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
         # pylint: disable=import-outside-toplevel
         from pyhive.exc import DatabaseError
 
-        with cls.get_engine(database, schema=schema) as engine:
-            with closing(engine.raw_connection()) as conn:
-                cursor = conn.cursor()
-                sql = f"SHOW CREATE VIEW {schema}.{table}"
-                try:
-                    cls.execute(cursor, sql)
+        with database.get_raw_connection(schema=schema) as conn:
+            cursor = conn.cursor()
+            sql = f"SHOW CREATE VIEW {schema}.{table}"
+            try:
+                cls.execute(cursor, sql)
+            except DatabaseError:  # not a VIEW
+                return None
+            rows = cls.fetch_data(cursor, 1)
 
-                except DatabaseError:  # not a VIEW
-                    return None
-                rows = cls.fetch_data(cursor, 1)
             return rows[0][0]
 
     @classmethod
diff --git a/superset/models/core.py b/superset/models/core.py
index 1f2fa7f71c..12ce9ef95e 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -429,6 +429,19 @@ class Database(
         except Exception as ex:
             raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
 
+    @contextmanager
+    def get_raw_connection(
+        self,
+        schema: Optional[str] = None,
+        nullpool: bool = True,
+        source: Optional[utils.QuerySource] = None,
+    ) -> Connection:
+        with self.get_sqla_engine_with_context(
+            schema=schema, nullpool=nullpool, source=source
+        ) as engine:
+            with closing(engine.raw_connection()) as conn:
+                yield conn
+
     @property
     def quote_identifier(self) -> Callable[[str], str]:
         """Add quotes to potential identifiter expressions if needed"""
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 143806c7f5..c8b3bca2b1 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -464,66 +464,56 @@ def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-loca
             )
         )
 
-    with database.get_sqla_engine_with_context(
-        query.schema, source=QuerySource.SQL_LAB
-    ) as engine:
+    with database.get_raw_connection(query.schema, source=QuerySource.SQL_LAB) as conn:
         # Sharing a single connection and cursor across the
         # execution of all statements (if many)
-        with closing(engine.raw_connection()) as conn:
-            # closing the connection closes the cursor as well
-            cursor = conn.cursor()
-            cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
-            if cancel_query_id is not None:
-                query.set_extra_json_key(QUERY_CANCEL_KEY, cancel_query_id)
-                session.commit()
-            statement_count = len(statements)
-            for i, statement in enumerate(statements):
-                # Check if stopped
-                session.refresh(query)
-                if query.status == QueryStatus.STOPPED:
-                    payload.update({"status": query.status})
-                    return payload
-
-                # For CTAS we create the table only on the last statement
-                apply_ctas = query.select_as_cta and (
-                    query.ctas_method == CtasMethod.VIEW
-                    or (
-                        query.ctas_method == CtasMethod.TABLE
-                        and i == len(statements) - 1
-                    )
+        cursor = conn.cursor()
+        cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
+        if cancel_query_id is not None:
+            query.set_extra_json_key(QUERY_CANCEL_KEY, cancel_query_id)
+            session.commit()
+        statement_count = len(statements)
+        for i, statement in enumerate(statements):
+            # Check if stopped
+            session.refresh(query)
+            if query.status == QueryStatus.STOPPED:
+                payload.update({"status": query.status})
+                return payload
+            # For CTAS we create the table only on the last statement
+            apply_ctas = query.select_as_cta and (
+                query.ctas_method == CtasMethod.VIEW
+                or (query.ctas_method == CtasMethod.TABLE and i == len(statements) - 1)
+            )
+            # Run statement
+            msg = f"Running statement {i+1} out of {statement_count}"
+            logger.info("Query %s: %s", str(query_id), msg)
+            query.set_extra_json_key("progress", msg)
+            session.commit()
+            try:
+                result_set = execute_sql_statement(
+                    statement,
+                    query,
+                    session,
+                    cursor,
+                    log_params,
+                    apply_ctas,
                 )
-
-                # Run statement
-                msg = f"Running statement {i+1} out of {statement_count}"
-                logger.info("Query %s: %s", str(query_id), msg)
-                query.set_extra_json_key("progress", msg)
-                session.commit()
-                try:
-                    result_set = execute_sql_statement(
-                        statement,
-                        query,
-                        session,
-                        cursor,
-                        log_params,
-                        apply_ctas,
-                    )
-                except SqlLabQueryStoppedException:
-                    payload.update({"status": QueryStatus.STOPPED})
-                    return payload
-                except Exception as ex:  # pylint: disable=broad-except
-                    msg = str(ex)
-                    prefix_message = (
-                        f"[Statement {i+1} out of {statement_count}]"
-                        if statement_count > 1
-                        else ""
-                    )
-                    payload = handle_query_error(
-                        ex, query, session, payload, prefix_message
-                    )
-                    return payload
-
-            # Commit the connection so CTA queries will create the table.
-            conn.commit()
+            except SqlLabQueryStoppedException:
+                payload.update({"status": QueryStatus.STOPPED})
+                return payload
+            except Exception as ex:  # pylint: disable=broad-except
+                msg = str(ex)
+                prefix_message = (
+                    f"[Statement {i+1} out of {statement_count}]"
+                    if statement_count > 1
+                    else ""
+                )
+                payload = handle_query_error(
+                    ex, query, session, payload, prefix_message
+                )
+                return payload
+        # Commit the connection so CTA queries will create the table.
+        conn.commit()
 
     # Success, updating the query entry in database
     query.rows = result_set.size
diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py
index 37375e484d..5bc844751b 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -162,6 +162,8 @@ class PrestoDBSQLValidator(BaseSQLValidator):
         statements = parsed_query.get_statements()
 
         logger.info("Validating %i statement(s)", len(statements))
+        # todo(hughhh): update this to use new database.get_raw_connection()
+        # this function keeps stalling CI
         with database.get_sqla_engine_with_context(
             schema, source=QuerySource.SQL_LAB
         ) as engine:
diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py
index eef3bb8d36..9099dbb7d7 100644
--- a/tests/integration_tests/db_engine_specs/presto_tests.py
+++ b/tests/integration_tests/db_engine_specs/presto_tests.py
@@ -37,10 +37,8 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
     def test_get_view_names_with_schema(self):
         database = mock.MagicMock()
         mock_execute = mock.MagicMock()
-        database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
-            mock_execute
-        )
-        database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
+        database.get_raw_connection().__enter__().cursor().execute = mock_execute
+        database.get_raw_connection().__enter__().cursor().fetchall = mock.MagicMock(
             return_value=[["a", "b,", "c"], ["d", "e"]]
         )
 
@@ -61,10 +59,8 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
     def test_get_view_names_without_schema(self):
         database = mock.MagicMock()
         mock_execute = mock.MagicMock()
-        database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
-            mock_execute
-        )
-        database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock(
+        database.get_raw_connection().__enter__().cursor().execute = mock_execute
+        database.get_raw_connection().__enter__().cursor().fetchall = mock.MagicMock(
             return_value=[["a", "b,", "c"], ["d", "e"]]
         )
         result = PrestoEngineSpec.get_view_names(database, mock.Mock(), None)
@@ -823,15 +819,9 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
         mock_execute = mock.MagicMock()
         mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]])
         database = mock.MagicMock()
-        database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
-            mock_execute
-        )
-        database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = (
-            mock_fetchall
-        )
-        database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.poll.return_value = (
-            False
-        )
+        database.get_raw_connection().__enter__().cursor().execute = mock_execute
+        database.get_raw_connection().__enter__().cursor().fetchall = mock_fetchall
+        database.get_raw_connection().__enter__().cursor().return_value = False
         schema = "schema"
         table = "table"
         result = PrestoEngineSpec.get_create_view(database, schema=schema, table=table)
@@ -841,9 +831,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
     def test_get_create_view_exception(self):
         mock_execute = mock.MagicMock(side_effect=Exception())
         database = mock.MagicMock()
-        database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
-            mock_execute
-        )
+        database.get_raw_connection().__enter__().cursor().execute = mock_execute
         schema = "schema"
         table = "table"
         with self.assertRaises(Exception):
@@ -854,9 +842,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
 
         mock_execute = mock.MagicMock(side_effect=DatabaseError())
         database = mock.MagicMock()
-        database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = (
-            mock_execute
-        )
+        database.get_raw_connection().__enter__().cursor().execute = mock_execute
         schema = "schema"
         table = "table"
         result = PrestoEngineSpec.get_create_view(database, schema=schema, table=table)
diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py
index ed37eece96..b1b0480d56 100644
--- a/tests/integration_tests/sqllab_tests.py
+++ b/tests/integration_tests/sqllab_tests.py
@@ -733,7 +733,7 @@ class TestSqlLab(SupersetTestCase):
         mock_query = mock.MagicMock()
         mock_query.database.allow_run_async = False
         mock_cursor = mock.MagicMock()
-        mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
+        mock_query.database.get_raw_connection().__enter__().cursor.return_value = (
             mock_cursor
         )
         mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
@@ -786,7 +786,7 @@ class TestSqlLab(SupersetTestCase):
         mock_query = mock.MagicMock()
         mock_query.database.allow_run_async = True
         mock_cursor = mock.MagicMock()
-        mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
+        mock_query.database.get_raw_connection().__enter__().cursor.return_value = (
             mock_cursor
         )
         mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
@@ -836,7 +836,7 @@ class TestSqlLab(SupersetTestCase):
         mock_query = mock.MagicMock()
         mock_query.database.allow_run_async = False
         mock_cursor = mock.MagicMock()
-        mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
+        mock_query.database.get_raw_connection().__enter__().cursor.return_value = (
             mock_cursor
         )
         mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
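
Note on the test updates above: because get_raw_connection() is a context
manager, the MagicMock chains in the tests now stub __enter__() directly instead
of going through get_sqla_engine_with_context().__enter__().raw_connection().
A minimal, self-contained sketch of that stubbing pattern follows; the
fake_cursor name and the bare mock setup are illustrative, not taken verbatim
from the patch:

    from unittest import mock

    database = mock.MagicMock()
    fake_cursor = mock.MagicMock()
    # Entering the mocked context manager yields the raw DBAPI connection,
    # whose .cursor() should hand back our fake cursor.
    database.get_raw_connection().__enter__().cursor.return_value = fake_cursor

    with database.get_raw_connection() as conn:
        assert conn.cursor() is fake_cursor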