You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@superset.apache.org by hu...@apache.org on 2022/10/25 18:12:59 UTC

[superset] branch master updated: feat: create function for get_sqla_engine with context (#21790)

This is an automated email from the ASF dual-hosted git repository.

hugh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 7600da8041 feat: create function for get_sqla_engine with context (#21790)
7600da8041 is described below

commit 7600da80412186d0f5d0c85e6cd831fbae2e9d9e
Author: Hugh A. Miles II <hu...@gmail.com>
AuthorDate: Tue Oct 25 14:12:48 2022 -0400

    feat: create function for get_sqla_engine with context (#21790)
---
 superset/models/core.py                            | 14 +++-
 tests/conftest.py                                  |  3 +-
 tests/integration_tests/access_tests.py            |  8 +--
 tests/integration_tests/celery_tests.py            |  4 +-
 tests/integration_tests/conftest.py                | 18 +++---
 tests/integration_tests/csv_upload_tests.py        | 16 ++---
 tests/integration_tests/datasets/api_tests.py      | 27 ++++----
 tests/integration_tests/datasource_tests.py        | 19 +++---
 .../integration_tests/fixtures/energy_dashboard.py | 27 ++++----
 .../fixtures/unicode_dashboard.py                  | 21 +++---
 .../fixtures/world_bank_dashboard.py               | 21 +++---
 tests/integration_tests/model_tests.py             | 75 +++++++++++++++-------
 tests/integration_tests/reports/commands_tests.py  | 16 ++---
 tests/integration_tests/sqllab_tests.py            | 56 ++++++++--------
 14 files changed, 182 insertions(+), 143 deletions(-)

diff --git a/superset/models/core.py b/superset/models/core.py
index 008230ef48..d0a32a1864 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -21,7 +21,7 @@ import json
 import logging
 import textwrap
 from ast import literal_eval
-from contextlib import closing
+from contextlib import closing, contextmanager
 from copy import deepcopy
 from datetime import datetime
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
@@ -362,6 +362,18 @@ class Database(
             else None
         )
 
+    @contextmanager
+    def get_sqla_engine_with_context(
+        self,
+        schema: Optional[str] = None,
+        nullpool: bool = True,
+        source: Optional[utils.QuerySource] = None,
+    ) -> Engine:
+        try:
+            yield self.get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
+        except Exception as ex:
+            raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
+
     def get_sqla_engine(
         self,
         schema: Optional[str] = None,
diff --git a/tests/conftest.py b/tests/conftest.py
index 2c129965f1..a5945f2f5c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -70,7 +70,8 @@ def example_db_provider() -> Callable[[], Database]:
 
 @fixture(scope="session")
 def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine:
-    return example_db_provider().get_sqla_engine()
+    with example_db_provider().get_sqla_engine_with_context() as engine:
+        return engine
 
 
 @fixture(scope="session")
diff --git a/tests/integration_tests/access_tests.py b/tests/integration_tests/access_tests.py
index 5ab03055d9..ae8b39a8d2 100644
--- a/tests/integration_tests/access_tests.py
+++ b/tests/integration_tests/access_tests.py
@@ -158,8 +158,8 @@ class TestRequestAccess(SupersetTestCase):
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_override_role_permissions_1_table(self):
         database = get_example_database()
-        engine = database.get_sqla_engine()
-        schema = inspect(engine).default_schema_name
+        with database.get_sqla_engine_with_context() as engine:
+            schema = inspect(engine).default_schema_name
 
         perm_data = ROLE_TABLES_PERM_DATA.copy()
         perm_data["database"][0]["schema"][0]["name"] = schema
@@ -186,8 +186,8 @@ class TestRequestAccess(SupersetTestCase):
     )
     def test_override_role_permissions_drops_absent_perms(self):
         database = get_example_database()
-        engine = database.get_sqla_engine()
-        schema = inspect(engine).default_schema_name
+        with database.get_sqla_engine_with_context() as engine:
+            schema = inspect(engine).default_schema_name
 
         override_me = security_manager.find_role("override_me")
         override_me.permissions.append(
diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py
index f057d3128e..da6db727e7 100644
--- a/tests/integration_tests/celery_tests.py
+++ b/tests/integration_tests/celery_tests.py
@@ -112,7 +112,9 @@ def run_sql(
 def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None:
     """Drop table if it exists, works on any DB"""
     sql = f"DROP {table_type} IF EXISTS  {table_name}"
-    get_example_database().get_sqla_engine().execute(sql)
+    database = get_example_database()
+    with database.get_sqla_engine_with_context() as engine:
+        engine.execute(sql)
 
 
 def quote_f(value: Optional[str]):
diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py
index 463f93b833..efbc6bf7f0 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -134,8 +134,6 @@ def setup_sample_data() -> Any:
     yield
 
     with app.app_context():
-        engine = get_example_database().get_sqla_engine()
-
         # drop sqlachemy tables
 
         db.session.commit()
@@ -210,14 +208,14 @@ def setup_presto_if_needed():
 
     if backend in {"presto", "hive"}:
         database = get_example_database()
-        engine = database.get_sqla_engine()
-        drop_from_schema(engine, CTAS_SCHEMA_NAME)
-        engine.execute(f"DROP SCHEMA IF EXISTS {CTAS_SCHEMA_NAME}")
-        engine.execute(f"CREATE SCHEMA {CTAS_SCHEMA_NAME}")
-
-        drop_from_schema(engine, ADMIN_SCHEMA_NAME)
-        engine.execute(f"DROP SCHEMA IF EXISTS {ADMIN_SCHEMA_NAME}")
-        engine.execute(f"CREATE SCHEMA {ADMIN_SCHEMA_NAME}")
+        with database.get_sqla_engine_with_context() as engine:
+            drop_from_schema(engine, CTAS_SCHEMA_NAME)
+            engine.execute(f"DROP SCHEMA IF EXISTS {CTAS_SCHEMA_NAME}")
+            engine.execute(f"CREATE SCHEMA {CTAS_SCHEMA_NAME}")
+
+            drop_from_schema(engine, ADMIN_SCHEMA_NAME)
+            engine.execute(f"DROP SCHEMA IF EXISTS {ADMIN_SCHEMA_NAME}")
+            engine.execute(f"CREATE SCHEMA {ADMIN_SCHEMA_NAME}")
 
 
 def with_feature_flags(**mock_feature_flags):
diff --git a/tests/integration_tests/csv_upload_tests.py b/tests/integration_tests/csv_upload_tests.py
index 6cd228b9cb..3941606aba 100644
--- a/tests/integration_tests/csv_upload_tests.py
+++ b/tests/integration_tests/csv_upload_tests.py
@@ -71,14 +71,14 @@ def setup_csv_upload(login_as_admin):
     yield
 
     upload_db = get_upload_db()
-    engine = upload_db.get_sqla_engine()
-    engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE}")
-    engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}")
-    engine.execute(f"DROP TABLE IF EXISTS {PARQUET_UPLOAD_TABLE}")
-    engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_SCHEMA}")
-    engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_EXPLORE}")
-    db.session.delete(upload_db)
-    db.session.commit()
+    with upload_db.get_sqla_engine_with_context() as engine:
+        engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE}")
+        engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}")
+        engine.execute(f"DROP TABLE IF EXISTS {PARQUET_UPLOAD_TABLE}")
+        engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_SCHEMA}")
+        engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_EXPLORE}")
+        db.session.delete(upload_db)
+        db.session.commit()
 
 
 @pytest.fixture(scope="module")
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index ef003d05dc..0175a2c334 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -670,9 +670,10 @@ class TestDatasetApi(SupersetTestCase):
             return
 
         example_db = get_example_database()
-        example_db.get_sqla_engine().execute(
-            f"CREATE TABLE {CTAS_SCHEMA_NAME}.birth_names AS SELECT 2 as two"
-        )
+        with example_db.get_sqla_engine_with_context() as engine:
+            engine.execute(
+                f"CREATE TABLE {CTAS_SCHEMA_NAME}.birth_names AS SELECT 2 as two"
+            )
 
         self.login(username="admin")
         table_data = {
@@ -690,9 +691,8 @@ class TestDatasetApi(SupersetTestCase):
         uri = f'api/v1/dataset/{data.get("id")}'
         rv = self.client.delete(uri)
         assert rv.status_code == 200
-        example_db.get_sqla_engine().execute(
-            f"DROP TABLE {CTAS_SCHEMA_NAME}.birth_names"
-        )
+        with example_db.get_sqla_engine_with_context() as engine:
+            engine.execute(f"DROP TABLE {CTAS_SCHEMA_NAME}.birth_names")
 
     def test_create_dataset_validate_database(self):
         """
@@ -758,13 +758,14 @@ class TestDatasetApi(SupersetTestCase):
         mock_get_table.return_value = None
 
         example_db = get_example_database()
-        engine = example_db.get_sqla_engine()
-        dialect = engine.dialect
-
-        with patch.object(
-            dialect, "get_view_names", wraps=dialect.get_view_names
-        ) as patch_get_view_names:
-            patch_get_view_names.return_value = ["test_case_view"]
+        with example_db.get_sqla_engine_with_context() as engine:
+            engine = engine
+            dialect = engine.dialect
+
+            with patch.object(
+                dialect, "get_view_names", wraps=dialect.get_view_names
+            ) as patch_get_view_names:
+                patch_get_view_names.return_value = ["test_case_view"]
 
             self.login(username="admin")
             table_data = {
diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py
index ef3ba0c69d..0896971743 100644
--- a/tests/integration_tests/datasource_tests.py
+++ b/tests/integration_tests/datasource_tests.py
@@ -45,18 +45,17 @@ def create_test_table_context(database: Database):
     schema = get_example_default_schema()
     full_table_name = f"{schema}.test_table" if schema else "test_table"
 
-    database.get_sqla_engine().execute(
-        f"CREATE TABLE IF NOT EXISTS {full_table_name} AS SELECT 1 as first, 2 as second"
-    )
-    database.get_sqla_engine().execute(
-        f"INSERT INTO {full_table_name} (first, second) VALUES (1, 2)"
-    )
-    database.get_sqla_engine().execute(
-        f"INSERT INTO {full_table_name} (first, second) VALUES (3, 4)"
-    )
+    with database.get_sqla_engine_with_context() as engine:
+        engine.execute(
+            f"CREATE TABLE IF NOT EXISTS {full_table_name} AS SELECT 1 as first, 2 as second"
+        )
+        engine.execute(f"INSERT INTO {full_table_name} (first, second) VALUES (1, 2)")
+        engine.execute(f"INSERT INTO {full_table_name} (first, second) VALUES (3, 4)")
 
     yield db.session
-    database.get_sqla_engine().execute(f"DROP TABLE {full_table_name}")
+
+    with database.get_sqla_engine_with_context() as engine:
+        engine.execute(f"DROP TABLE {full_table_name}")
 
 
 class TestDatasource(SupersetTestCase):
diff --git a/tests/integration_tests/fixtures/energy_dashboard.py b/tests/integration_tests/fixtures/energy_dashboard.py
index 436ba1ce55..202f494aa2 100644
--- a/tests/integration_tests/fixtures/energy_dashboard.py
+++ b/tests/integration_tests/fixtures/energy_dashboard.py
@@ -39,21 +39,22 @@ ENERGY_USAGE_TBL_NAME = "energy_usage"
 def load_energy_table_data():
     with app.app_context():
         database = get_example_database()
-        df = _get_dataframe()
-        df.to_sql(
-            ENERGY_USAGE_TBL_NAME,
-            database.get_sqla_engine(),
-            if_exists="replace",
-            chunksize=500,
-            index=False,
-            dtype={"source": String(255), "target": String(255), "value": Float()},
-            method="multi",
-            schema=get_example_default_schema(),
-        )
+        with database.get_sqla_engine_with_context() as engine:
+            df = _get_dataframe()
+            df.to_sql(
+                ENERGY_USAGE_TBL_NAME,
+                engine,
+                if_exists="replace",
+                chunksize=500,
+                index=False,
+                dtype={"source": String(255), "target": String(255), "value": Float()},
+                method="multi",
+                schema=get_example_default_schema(),
+            )
     yield
     with app.app_context():
-        engine = get_example_database().get_sqla_engine()
-        engine.execute("DROP TABLE IF EXISTS energy_usage")
+        with get_example_database().get_sqla_engine_with_context() as engine:
+            engine.execute("DROP TABLE IF EXISTS energy_usage")
 
 
 @pytest.fixture()
diff --git a/tests/integration_tests/fixtures/unicode_dashboard.py b/tests/integration_tests/fixtures/unicode_dashboard.py
index c7b828176f..9368df7614 100644
--- a/tests/integration_tests/fixtures/unicode_dashboard.py
+++ b/tests/integration_tests/fixtures/unicode_dashboard.py
@@ -37,16 +37,17 @@ UNICODE_TBL_NAME = "unicode_test"
 @pytest.fixture(scope="session")
 def load_unicode_data():
     with app.app_context():
-        _get_dataframe().to_sql(
-            UNICODE_TBL_NAME,
-            get_example_database().get_sqla_engine(),
-            if_exists="replace",
-            chunksize=500,
-            dtype={"phrase": String(500)},
-            index=False,
-            method="multi",
-            schema=get_example_default_schema(),
-        )
+        with get_example_database().get_sqla_engine_with_context() as engine:
+            _get_dataframe().to_sql(
+                UNICODE_TBL_NAME,
+                engine,
+                if_exists="replace",
+                chunksize=500,
+                dtype={"phrase": String(500)},
+                index=False,
+                method="multi",
+                schema=get_example_default_schema(),
+            )
 
     yield
     with app.app_context():
diff --git a/tests/integration_tests/fixtures/world_bank_dashboard.py b/tests/integration_tests/fixtures/world_bank_dashboard.py
index 2c6fb2c3e2..e29962a8c9 100644
--- a/tests/integration_tests/fixtures/world_bank_dashboard.py
+++ b/tests/integration_tests/fixtures/world_bank_dashboard.py
@@ -50,16 +50,17 @@ def load_world_bank_data():
             "country_name": String(255),
             "region": String(255),
         }
-        _get_dataframe(database).to_sql(
-            WB_HEALTH_POPULATION,
-            get_example_database().get_sqla_engine(),
-            if_exists="replace",
-            chunksize=500,
-            dtype=dtype,
-            index=False,
-            method="multi",
-            schema=get_example_default_schema(),
-        )
+        with database.get_sqla_engine_with_context() as engine:
+            _get_dataframe(database).to_sql(
+                WB_HEALTH_POPULATION,
+                engine,
+                if_exists="replace",
+                chunksize=500,
+                dtype=dtype,
+                index=False,
+                method="multi",
+                schema=get_example_default_schema(),
+            )
 
     yield
     with app.app_context():
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index c92de47f03..3e13664b63 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -57,30 +57,36 @@ class TestDatabaseModel(SupersetTestCase):
         sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive/default"
         model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
 
-        db = make_url(model.get_sqla_engine().url).database
-        self.assertEqual("hive/default", db)
+        with model.get_sqla_engine_with_context() as engine:
+            db = make_url(engine.url).database
+            self.assertEqual("hive/default", db)
 
-        db = make_url(model.get_sqla_engine(schema="core_db").url).database
-        self.assertEqual("hive/core_db", db)
+        with model.get_sqla_engine_with_context(schema="core_db") as engine:
+            db = make_url(engine.url).database
+            self.assertEqual("hive/core_db", db)
 
         sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive"
         model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
 
-        db = make_url(model.get_sqla_engine().url).database
-        self.assertEqual("hive", db)
+        with model.get_sqla_engine_with_context() as engine:
+            db = make_url(engine.url).database
+            self.assertEqual("hive", db)
 
-        db = make_url(model.get_sqla_engine(schema="core_db").url).database
-        self.assertEqual("hive/core_db", db)
+        with model.get_sqla_engine_with_context(schema="core_db") as engine:
+            db = make_url(engine.url).database
+            self.assertEqual("hive/core_db", db)
 
     def test_database_schema_postgres(self):
         sqlalchemy_uri = "postgresql+psycopg2://postgres.airbnb.io:5439/prod"
         model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
 
-        db = make_url(model.get_sqla_engine().url).database
-        self.assertEqual("prod", db)
+        with model.get_sqla_engine_with_context() as engine:
+            db = make_url(engine.url).database
+            self.assertEqual("prod", db)
 
-        db = make_url(model.get_sqla_engine(schema="foo").url).database
-        self.assertEqual("prod", db)
+        with model.get_sqla_engine_with_context(schema="foo") as engine:
+            db = make_url(engine.url).database
+            self.assertEqual("prod", db)
 
     @unittest.skipUnless(
         SupersetTestCase.is_module_installed("thrift"), "thrift not installed"
@@ -91,11 +97,14 @@ class TestDatabaseModel(SupersetTestCase):
     def test_database_schema_hive(self):
         sqlalchemy_uri = "hive://hive@hive.airbnb.io:10000/default?auth=NOSASL"
         model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
-        db = make_url(model.get_sqla_engine().url).database
-        self.assertEqual("default", db)
 
-        db = make_url(model.get_sqla_engine(schema="core_db").url).database
-        self.assertEqual("core_db", db)
+        with model.get_sqla_engine_with_context() as engine:
+            db = make_url(engine.url).database
+            self.assertEqual("default", db)
+
+        with model.get_sqla_engine_with_context(schema="core_db") as engine:
+            db = make_url(engine.url).database
+            self.assertEqual("core_db", db)
 
     @unittest.skipUnless(
         SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
@@ -104,11 +113,13 @@ class TestDatabaseModel(SupersetTestCase):
         sqlalchemy_uri = "mysql://root@localhost/superset"
         model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)
 
-        db = make_url(model.get_sqla_engine().url).database
-        self.assertEqual("superset", db)
+        with model.get_sqla_engine_with_context() as engine:
+            db = make_url(engine.url).database
+            self.assertEqual("superset", db)
 
-        db = make_url(model.get_sqla_engine(schema="staging").url).database
-        self.assertEqual("staging", db)
+        with model.get_sqla_engine_with_context(schema="staging") as engine:
+            db = make_url(engine.url).database
+            self.assertEqual("staging", db)
 
     @unittest.skipUnless(
         SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
@@ -120,12 +131,14 @@ class TestDatabaseModel(SupersetTestCase):
 
         with override_user(example_user):
             model.impersonate_user = True
-            username = make_url(model.get_sqla_engine().url).username
-            self.assertEqual(example_user.username, username)
+            with model.get_sqla_engine_with_context() as engine:
+                username = make_url(engine.url).username
+                self.assertEqual(example_user.username, username)
 
             model.impersonate_user = False
-            username = make_url(model.get_sqla_engine().url).username
-            self.assertNotEqual(example_user.username, username)
+            with model.get_sqla_engine_with_context() as engine:
+                username = make_url(engine.url).username
+                self.assertNotEqual(example_user.username, username)
 
     @mock.patch("superset.models.core.create_engine")
     def test_impersonate_user_presto(self, mocked_create_engine):
@@ -369,6 +382,20 @@ class TestDatabaseModel(SupersetTestCase):
         with self.assertRaises(SupersetException):
             model.get_sqla_engine()
 
+    # todo(hughhh): update this test
+    # @mock.patch("superset.models.core.create_engine")
+    # def test_get_sqla_engine_with_context(self, mocked_create_engine):
+    #     model = Database(
+    #         database_name="test_database",
+    #         sqlalchemy_uri="mysql://root@localhost",
+    #     )
+    #     model.db_engine_spec.get_dbapi_exception_mapping = mock.Mock(
+    #         return_value={Exception: SupersetException}
+    #     )
+    #     mocked_create_engine.side_effect = Exception()
+    #     with self.assertRaises(SupersetException):
+    #         model.get_sqla_engine()
+
 
 class TestSqlaTableModel(SupersetTestCase):
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
diff --git a/tests/integration_tests/reports/commands_tests.py b/tests/integration_tests/reports/commands_tests.py
index c82c1b5fdb..b3ef86c5e3 100644
--- a/tests/integration_tests/reports/commands_tests.py
+++ b/tests/integration_tests/reports/commands_tests.py
@@ -130,18 +130,14 @@ def assert_log(state: str, error_message: Optional[str] = None):
 
 @contextmanager
 def create_test_table_context(database: Database):
-    database.get_sqla_engine().execute(
-        "CREATE TABLE test_table AS SELECT 1 as first, 2 as second"
-    )
-    database.get_sqla_engine().execute(
-        "INSERT INTO test_table (first, second) VALUES (1, 2)"
-    )
-    database.get_sqla_engine().execute(
-        "INSERT INTO test_table (first, second) VALUES (3, 4)"
-    )
+    with database.get_sqla_engine_with_context() as engine:
+        engine.execute("CREATE TABLE test_table AS SELECT 1 as first, 2 as second")
+        engine.execute("INSERT INTO test_table (first, second) VALUES (1, 2)")
+        engine.execute("INSERT INTO test_table (first, second) VALUES (3, 4)")
 
     yield db.session
-    database.get_sqla_engine().execute("DROP TABLE test_table")
+    with database.get_sqla_engine_with_context() as engine:
+        engine.execute("DROP TABLE test_table")
 
 
 @pytest.fixture()
diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py
index 0c4019e7d9..bee9b08114 100644
--- a/tests/integration_tests/sqllab_tests.py
+++ b/tests/integration_tests/sqllab_tests.py
@@ -207,19 +207,21 @@ class TestSqlLab(SupersetTestCase):
             # assertions
             db.session.commit()
             examples_db = get_example_database()
-            engine = examples_db.get_sqla_engine()
-            data = engine.execute(
-                f"SELECT * FROM admin_database.{tmp_table_name}"
-            ).fetchall()
-            names_count = engine.execute(f"SELECT COUNT(*) FROM birth_names").first()
-            self.assertEqual(
-                names_count[0], len(data)
-            )  # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True
-
-            # cleanup
-            engine.execute(f"DROP {ctas_method} admin_database.{tmp_table_name}")
-            examples_db.allow_ctas = old_allow_ctas
-            db.session.commit()
+            with examples_db.get_sqla_engine_with_context() as engine:
+                data = engine.execute(
+                    f"SELECT * FROM admin_database.{tmp_table_name}"
+                ).fetchall()
+                names_count = engine.execute(
+                    f"SELECT COUNT(*) FROM birth_names"
+                ).first()
+                self.assertEqual(
+                    names_count[0], len(data)
+                )  # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True
+
+                # cleanup
+                engine.execute(f"DROP {ctas_method} admin_database.{tmp_table_name}")
+                examples_db.allow_ctas = old_allow_ctas
+                db.session.commit()
 
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_multi_sql(self):
@@ -275,9 +277,10 @@ class TestSqlLab(SupersetTestCase):
             "SchemaUser", ["SchemaPermission", "Gamma", "sql_lab"]
         )
 
-        examples_db.get_sqla_engine().execute(
-            f"CREATE TABLE IF NOT EXISTS {CTAS_SCHEMA_NAME}.test_table AS SELECT 1 as c1, 2 as c2"
-        )
+        with examples_db.get_sqla_engine_with_context() as engine:
+            engine.execute(
+                f"CREATE TABLE IF NOT EXISTS {CTAS_SCHEMA_NAME}.test_table AS SELECT 1 as c1, 2 as c2"
+            )
 
         data = self.run_sql(
             f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table", "3", username="SchemaUser"
@@ -303,9 +306,8 @@ class TestSqlLab(SupersetTestCase):
             self.assertEqual(1, len(data["data"]))
 
         db.session.query(Query).delete()
-        get_example_database().get_sqla_engine().execute(
-            f"DROP TABLE IF EXISTS {CTAS_SCHEMA_NAME}.test_table"
-        )
+        with get_example_database().get_sqla_engine_with_context() as engine:
+            engine.execute(f"DROP TABLE IF EXISTS {CTAS_SCHEMA_NAME}.test_table")
         db.session.commit()
 
     def test_queries_endpoint(self):
@@ -520,12 +522,10 @@ class TestSqlLab(SupersetTestCase):
     def test_sqllab_table_viz(self):
         self.login("admin")
         examples_db = get_example_database()
-        examples_db.get_sqla_engine().execute(
-            "DROP TABLE IF EXISTS test_sqllab_table_viz"
-        )
-        examples_db.get_sqla_engine().execute(
-            "CREATE TABLE test_sqllab_table_viz AS SELECT 2 as col"
-        )
+        with examples_db.get_sqla_engine_with_context() as engine:
+            engine.execute("DROP TABLE IF EXISTS test_sqllab_table_viz")
+            engine.execute("CREATE TABLE test_sqllab_table_viz AS SELECT 2 as col")
+
         examples_dbid = examples_db.id
 
         payload = {
@@ -543,9 +543,9 @@ class TestSqlLab(SupersetTestCase):
         table = db.session.query(SqlaTable).filter_by(id=table_id).one()
         self.assertEqual([owner.username for owner in table.owners], ["admin"])
         db.session.delete(table)
-        get_example_database().get_sqla_engine().execute(
-            "DROP TABLE test_sqllab_table_viz"
-        )
+
+        with get_example_database().get_sqla_engine_with_context() as engine:
+            engine.execute("DROP TABLE test_sqllab_table_viz")
         db.session.commit()
 
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")