You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself was not preserved in this plain-text rendering).
Posted to commits@superset.apache.org by hu...@apache.org on 2022/12/07 22:36:18 UTC

[superset] 03/03: switch out all places using connection

This is an automated email from the ASF dual-hosted git repository.

hugh pushed a commit to branch refactor-tc
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 70f0a52d54ac24cb076fd30d58f7fe5572ce7abd
Author: hughhhh <hu...@gmail.com>
AuthorDate: Wed Dec 7 17:35:57 2022 -0500

    switch out all places using connection
---
 superset/connectors/sqla/utils.py    | 42 +++++++++++++++++-------------------
 superset/db_engine_specs/base.py     | 16 +++++++-------
 superset/db_engine_specs/gsheets.py  |  4 +---
 superset/db_engine_specs/presto.py   | 31 ++++++++++++--------------
 superset/sql_lab.py                  | 17 ++++++---------
 superset/sql_validators/presto_db.py | 19 +++++++---------
 6 files changed, 57 insertions(+), 72 deletions(-)

diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 05cf8cea13..f9bcfc0bd8 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -136,18 +136,17 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
     # TODO(villebro): refactor to use same code that's used by
     #  sql_lab.py:execute_sql_statements
     try:
-        with dataset.database.get_sqla_engine_with_context(
+        with dataset.database.get_raw_connection(
             schema=dataset.schema
-        ) as engine:
-            with closing(engine.raw_connection()) as conn:
-                cursor = conn.cursor()
-                query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
-                db_engine_spec.execute(cursor, query)
-                result = db_engine_spec.fetch_data(cursor, limit=1)
-                result_set = SupersetResultSet(
-                    result, cursor.description, db_engine_spec
-                )
-                cols = result_set.columns
+        ) as conn:
+            cursor = conn.cursor()
+            query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
+            db_engine_spec.execute(cursor, query)
+            result = db_engine_spec.fetch_data(cursor, limit=1)
+            result_set = SupersetResultSet(
+                result, cursor.description, db_engine_spec
+            )
+            cols = result_set.columns
     except Exception as ex:
         raise SupersetGenericDBErrorException(message=str(ex)) from ex
     return cols
@@ -159,17 +158,16 @@ def get_columns_description(
 ) -> List[ResultSetColumnType]:
     db_engine_spec = database.db_engine_spec
     try:
-        with database.get_sqla_engine_with_context() as engine:
-            with closing(engine.raw_connection()) as conn:
-                cursor = conn.cursor()
-                query = database.apply_limit_to_sql(query, limit=1)
-                cursor.execute(query)
-                db_engine_spec.execute(cursor, query)
-                result = db_engine_spec.fetch_data(cursor, limit=1)
-                result_set = SupersetResultSet(
-                    result, cursor.description, db_engine_spec
-                )
-                return result_set.columns
+        with database.get_raw_connection() as conn:           
+            cursor = conn.cursor()
+            query = database.apply_limit_to_sql(query, limit=1)
+            cursor.execute(query)
+            db_engine_spec.execute(cursor, query)
+            result = db_engine_spec.fetch_data(cursor, limit=1)
+            result_set = SupersetResultSet(
+                result, cursor.description, db_engine_spec
+            )
+            return result_set.columns
     except Exception as ex:
         raise SupersetGenericDBErrorException(message=str(ex)) from ex
 
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 3e2c0f56ba..c6978a3fd2 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1296,14 +1296,14 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         statements = parsed_query.get_statements()
 
         costs = []
-        with cls.get_engine(database, schema=schema, source=source) as engine:
-            with closing(engine.raw_connection()) as conn:
-                cursor = conn.cursor()
-                for statement in statements:
-                    processed_statement = cls.process_statement(statement, database)
-                    costs.append(
-                        cls.estimate_statement_cost(processed_statement, cursor)
-                    )
+        with database.get_raw_connection(schema=schema, source=source) as conn:
+            cursor = conn.cursor()
+            for statement in statements:
+                processed_statement = cls.process_statement(statement, database)
+                costs.append(
+                    cls.estimate_statement_cost(processed_statement, cursor)
+                )
+
         return costs
 
     @classmethod
diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py
index 805a7ee400..9b0dd3fb8f 100644
--- a/superset/db_engine_specs/gsheets.py
+++ b/superset/db_engine_specs/gsheets.py
@@ -109,12 +109,10 @@ class GSheetsEngineSpec(SqliteEngineSpec):
         table_name: str,
         schema_name: Optional[str],
     ) -> Dict[str, Any]:
-        with cls.get_engine(database, schema=schema_name) as engine:
-            with closing(engine.raw_connection()) as conn:
+        with database.get_raw_connection(schema=schema_name) as conn:
                 cursor = conn.cursor()
                 cursor.execute(f'SELECT GET_METADATA("{table_name}")')
                 results = cursor.fetchone()[0]
-
         try:
             metadata = json.loads(results)
         except Exception:  # pylint: disable=broad-except
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 6755039734..c1d67f13eb 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -470,13 +470,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
             ).strip()
             params = {}
 
-        with cls.get_engine(database, schema=schema) as engine:
-            with closing(engine.raw_connection()) as conn:
-                cursor = conn.cursor()
-                cursor.execute(sql, params)
-                results = cursor.fetchall()
-
-        return {row[0] for row in results}
+        with database.get_raw_connection(schema=schema) as conn:
+            cursor = conn.cursor()
+            cursor.execute(sql, params)
+            results = cursor.fetchall()
+            return {row[0] for row in results}
 
     @classmethod
     def _create_column_info(
@@ -996,16 +994,15 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
         # pylint: disable=import-outside-toplevel
         from pyhive.exc import DatabaseError
 
-        with cls.get_engine(database, schema=schema) as engine:
-            with closing(engine.raw_connection()) as conn:
-                cursor = conn.cursor()
-                sql = f"SHOW CREATE VIEW {schema}.{table}"
-                try:
-                    cls.execute(cursor, sql)
-
-                except DatabaseError:  # not a VIEW
-                    return None
-                rows = cls.fetch_data(cursor, 1)
+        with database.get_raw_connection(schema=schema) as conn:
+            cursor = conn.cursor()
+            sql = f"SHOW CREATE VIEW {schema}.{table}"
+            try:
+                cls.execute(cursor, sql)
+            except DatabaseError:  # not a VIEW
+                return None
+            rows = cls.fetch_data(cursor, 1)
+            
             return rows[0][0]
 
     @classmethod
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 6d9903c8f0..f40d4c1daf 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -463,13 +463,9 @@ def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-loca
             )
         )
 
-    with database.get_sqla_engine_with_context(
-        query.schema, source=QuerySource.SQL_LAB
-    ) as engine:
-        # Sharing a single connection and cursor across the
-        # execution of all statements (if many)
-        with closing(engine.raw_connection()) as conn:
-            # closing the connection closes the cursor as well
+    with database.get_raw_connection(
+        schema=query.schema, 
+        source=QuerySource.SQL_LAB) as conn:
             cursor = conn.cursor()
             cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
             if cancel_query_id is not None:
@@ -627,10 +623,9 @@ def cancel_query(query: Query) -> bool:
     if cancel_query_id is None:
         return False
 
-    with query.database.get_sqla_engine_with_context(
-        query.schema, source=QuerySource.SQL_LAB
-    ) as engine:
-        with closing(engine.raw_connection()) as conn:
+    with query.database.get_raw_connection(
+        schema=query.schema, 
+        source=QuerySource.SQL_LAB) as conn:
             with closing(conn.cursor()) as cursor:
                 return query.database.db_engine_spec.cancel_query(
                     cursor, query, cancel_query_id
diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py
index 37375e484d..cbefe154d8 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -162,18 +162,15 @@ class PrestoDBSQLValidator(BaseSQLValidator):
         statements = parsed_query.get_statements()
 
         logger.info("Validating %i statement(s)", len(statements))
-        with database.get_sqla_engine_with_context(
-            schema, source=QuerySource.SQL_LAB
-        ) as engine:
+        with database.get_raw_connection(schema=schema, source=QuerySource.SQL_LAB) as conn:
             # Sharing a single connection and cursor across the
             # execution of all statements (if many)
-            annotations: List[SQLValidationAnnotation] = []
-            with closing(engine.raw_connection()) as conn:
-                cursor = conn.cursor()
-                for statement in parsed_query.get_statements():
-                    annotation = cls.validate_statement(statement, database, cursor)
-                    if annotation:
-                        annotations.append(annotation)
+            annotations: List[SQLValidationAnnotation] = []         
+            cursor = conn.cursor()
+            for statement in parsed_query.get_statements():
+                annotation = cls.validate_statement(statement, database, cursor)
+                if annotation:
+                    annotations.append(annotation)
             logger.debug("Validation found %i error(s)", len(annotations))
-
+        
         return annotations