You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2023/07/20 19:57:55 UTC

[superset] branch master updated: fix: `search_path` in RDS (#24739)

This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 7675e0db10 fix: `search_path` in RDS (#24739)
7675e0db10 is described below

commit 7675e0db10f42dbb76f908e9bc70906da204c98d
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Thu Jul 20 12:57:48 2023 -0700

    fix: `search_path` in RDS (#24739)
---
 superset/db_engine_specs/base.py                  | 33 +++++++--
 superset/db_engine_specs/postgres.py              | 88 +++++++++++++++--------
 superset/models/core.py                           | 83 +++++++++++----------
 tests/unit_tests/db_engine_specs/test_postgres.py | 48 +++++++++----
 tests/unit_tests/models/core_test.py              | 18 +++++
 5 files changed, 185 insertions(+), 85 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 0d778de439..af24e54790 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1082,22 +1082,45 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
 
         For some databases (like MySQL, Presto, Snowflake) this requires modifying the
         SQLAlchemy URI before creating the connection. For others (like Postgres), it
-        requires additional parameters in ``connect_args``.
+        requires additional parameters in ``connect_args`` or running pre-session
+        queries with ``set`` parameters.
 
-        When a DB engine spec implements this method it should also have the attribute
-        ``supports_dynamic_schema`` set to true, so that Superset knows in which schema a
-        given query is running in order to enforce permissions (see #23385 and #23401).
+        When a DB engine spec implements this method or ``get_prequeries`` (see below) it
+        should also have the attribute ``supports_dynamic_schema`` set to true, so that
+        Superset knows in which schema a given query is running in order to enforce
+        permissions (see #23385 and #23401).
 
         Currently, changing the catalog is not supported. The method accepts a catalog so
         that when catalog support is added to Superset the interface remains the same.
         This is important because DB engine specs can be installed from 3rd party
-        packages.
+        packages, so we want to keep these methods as stable as possible.
         """
         return uri, {
             **connect_args,
             **cls.enforce_uri_query_params.get(uri.get_driver_name(), {}),
         }
 
+    @classmethod
+    def get_prequeries(
+        cls,
+        catalog: str | None = None,  # pylint: disable=unused-argument
+        schema: str | None = None,  # pylint: disable=unused-argument
+    ) -> list[str]:
+        """
+        Return pre-session queries.
+
+        These are currently used as an alternative to ``adjust_engine_params`` for
+        databases where the selected schema cannot be specified in the SQLAlchemy URI or
+        connection arguments.
+
+        For example, in order to specify a default schema in RDS we need to run a query
+        at the beginning of the session:
+
+            sql> set search_path = my_schema;
+
+        """
+        return []
+
     @classmethod
     def patch(cls) -> None:
         """
diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index cdd71fdfcc..642f84f58c 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -14,13 +14,17 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+from __future__ import annotations
+
 import json
 import logging
 import re
 from datetime import datetime
 from re import Pattern
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
 
+import sqlparse
 from flask_babel import gettext as __
 from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
 from sqlalchemy.dialects.postgresql.base import PGInspector
@@ -30,8 +34,8 @@ from sqlalchemy.types import Date, DateTime, String
 
 from superset.constants import TimeGrain
 from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
-from superset.errors import SupersetErrorType
-from superset.exceptions import SupersetException
+from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.exceptions import SupersetException, SupersetSecurityException
 from superset.models.sql_lab import Query
 from superset.utils import core as utils
 from superset.utils.core import GenericDataType
@@ -169,9 +173,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
     }
 
     @classmethod
-    def fetch_data(
-        cls, cursor: Any, limit: Optional[int] = None
-    ) -> list[tuple[Any, ...]]:
+    def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]:
         if not cursor.description:
             return []
         return super().fetch_data(cursor, limit)
@@ -224,7 +226,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
         cls,
         sqlalchemy_uri: URL,
         connect_args: dict[str, Any],
-    ) -> Optional[str]:
+    ) -> str | None:
         """
         Return the configured schema.
 
@@ -237,6 +239,9 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
         in Superset because it breaks schema-level permissions, since it's impossible
         to determine the schema for a non-qualified table in a query. In cases like
         that we raise an exception.
+
+        Note that because the DB engine supports dynamic schema this method is never
+        called. It's left here as an implementation reference.
         """
         options = parse_options(connect_args)
         if search_path := options.get("search_path"):
@@ -252,23 +257,50 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
         return None
 
     @classmethod
-    def adjust_engine_params(
+    def get_default_schema_for_query(
         cls,
-        uri: URL,
-        connect_args: dict[str, Any],
-        catalog: Optional[str] = None,
-        schema: Optional[str] = None,
-    ) -> tuple[URL, dict[str, Any]]:
-        if not schema:
-            return uri, connect_args
+        database: Database,
+        query: Query,
+    ) -> str | None:
+        """
+        Return the default schema for a given query.
 
-        options = parse_options(connect_args)
-        options["search_path"] = schema
-        connect_args["options"] = " ".join(
-            f"-c{key}={value}" for key, value in options.items()
-        )
+        Defers to the parent method after rejecting queries that set the search path.
+        """
+        sql = sqlparse.format(query.sql, strip_comments=True)
+        # matches SET [SESSION|LOCAL] search_path (= or TO), RESET and set_config()
+        pattern = r"set(\s+((session|local)\s+)?\"?|_config\s*\([^,]*)search_path\b"
+        if re.search(pattern, sql, re.IGNORECASE):
+            raise SupersetSecurityException(
+                SupersetError(
+                    error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
+                    message=__(
+                        "Users are not allowed to set a search path for security reasons."
+                    ),
+                    level=ErrorLevel.ERROR,
+                )
+            )
+        return super().get_default_schema_for_query(database, query)
 
-        return uri, connect_args
+    @classmethod
+    def get_prequeries(
+        cls,
+        catalog: str | None = None,
+        schema: str | None = None,
+    ) -> list[str]:
+        """
+        Set the search path to the specified schema.
+
+        In SQL Lab this makes queries run in the schema selected in the dropdown, and,
+        more importantly, lets Superset check access to tables with unqualified names;
+        without it the schema could be anything and such queries would be blocked.
+        Double quotes in the schema name are doubled so it stays a single quoted
+        identifier and cannot be used to inject additional SQL.
+        """
+        if not schema:
+            return []
+        quoted_schema = schema.replace('"', '""')
+        return [f'set search_path = "{quoted_schema}"']
 
     @classmethod
     def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
@@ -298,7 +330,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
     @classmethod
     def get_catalog_names(
         cls,
-        database: "Database",
+        database: Database,
         inspector: Inspector,
     ) -> list[str]:
         """
@@ -318,7 +350,7 @@ WHERE datistemplate = false;
 
     @classmethod
     def get_table_names(
-        cls, database: "Database", inspector: PGInspector, schema: Optional[str]
+        cls, database: Database, inspector: PGInspector, schema: str | None
     ) -> set[str]:
         """Need to consider foreign tables for PostgreSQL"""
         return set(inspector.get_table_names(schema)) | set(
@@ -327,8 +359,8 @@ WHERE datistemplate = false;
 
     @classmethod
     def convert_dttm(
-        cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
-    ) -> Optional[str]:
+        cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+    ) -> str | None:
         sqla_type = cls.get_sqla_column_type(target_type)
 
         if isinstance(sqla_type, Date):
@@ -339,7 +371,7 @@ WHERE datistemplate = false;
         return None
 
     @staticmethod
-    def get_extra_params(database: "Database") -> dict[str, Any]:
+    def get_extra_params(database: Database) -> dict[str, Any]:
         """
         For Postgres, the path to a SSL certificate is placed in `connect_args`.
 
@@ -363,7 +395,7 @@ WHERE datistemplate = false;
         return extra
 
     @classmethod
-    def get_datatype(cls, type_code: Any) -> Optional[str]:
+    def get_datatype(cls, type_code: Any) -> str | None:
         # pylint: disable=import-outside-toplevel
         from psycopg2.extensions import binary_types, string_types
 
@@ -374,7 +406,7 @@ WHERE datistemplate = false;
         return None
 
     @classmethod
-    def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]:
+    def get_cancel_query_id(cls, cursor: Any, query: Query) -> str | None:
         """
         Get Postgres PID that will be used to cancel all other running
         queries in the same session.
diff --git a/superset/models/core.py b/superset/models/core.py
index 4ff56145e1..a8fe8b5411 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -16,6 +16,9 @@
 # under the License.
 # pylint: disable=line-too-long,too-many-lines
 """A collection of ORM sqlalchemy models for Superset"""
+
+from __future__ import annotations
+
 import builtins
 import enum
 import json
@@ -26,7 +29,7 @@ from contextlib import closing, contextmanager, nullcontext
 from copy import deepcopy
 from datetime import datetime
 from functools import lru_cache
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from typing import Any, Callable, TYPE_CHECKING
 
 import numpy
 import pandas as pd
@@ -270,7 +273,7 @@ class Database(
         return self.url_object.get_driver_name()
 
     @property
-    def masked_encrypted_extra(self) -> Optional[str]:
+    def masked_encrypted_extra(self) -> str | None:
         return self.db_engine_spec.mask_encrypted_extra(self.encrypted_extra)
 
     @property
@@ -315,7 +318,7 @@ class Database(
         return "schema_cache_timeout" in self.metadata_cache_timeout
 
     @property
-    def schema_cache_timeout(self) -> Optional[int]:
+    def schema_cache_timeout(self) -> int | None:
         return self.metadata_cache_timeout.get("schema_cache_timeout")
 
     @property
@@ -323,7 +326,7 @@ class Database(
         return "table_cache_timeout" in self.metadata_cache_timeout
 
     @property
-    def table_cache_timeout(self) -> Optional[int]:
+    def table_cache_timeout(self) -> int | None:
         return self.metadata_cache_timeout.get("table_cache_timeout")
 
     @property
@@ -364,7 +367,7 @@ class Database(
         conn = conn.set(password=PASSWORD_MASK if conn.password else None)
         self.sqlalchemy_uri = str(conn)  # hides the password
 
-    def get_effective_user(self, object_url: URL) -> Optional[str]:
+    def get_effective_user(self, object_url: URL) -> str | None:
         """
         Get the effective user, especially during impersonation.
 
@@ -383,10 +386,10 @@ class Database(
     @contextmanager
     def get_sqla_engine_with_context(
         self,
-        schema: Optional[str] = None,
+        schema: str | None = None,
         nullpool: bool = True,
-        source: Optional[utils.QuerySource] = None,
-        override_ssh_tunnel: Optional["SSHTunnel"] = None,
+        source: utils.QuerySource | None = None,
+        override_ssh_tunnel: SSHTunnel | None = None,
     ) -> Engine:
         from superset.daos.database import (  # pylint: disable=import-outside-toplevel
             DatabaseDAO,
@@ -425,10 +428,10 @@ class Database(
 
     def _get_sqla_engine(
         self,
-        schema: Optional[str] = None,
+        schema: str | None = None,
         nullpool: bool = True,
-        source: Optional[utils.QuerySource] = None,
-        sqlalchemy_uri: Optional[str] = None,
+        source: utils.QuerySource | None = None,
+        sqlalchemy_uri: str | None = None,
     ) -> Engine:
         sqlalchemy_url = make_url_safe(
             sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
@@ -513,17 +516,23 @@ class Database(
     @contextmanager
     def get_raw_connection(
         self,
-        schema: Optional[str] = None,
+        schema: str | None = None,
         nullpool: bool = True,
-        source: Optional[utils.QuerySource] = None,
+        source: utils.QuerySource | None = None,
     ) -> Connection:
         with self.get_sqla_engine_with_context(
             schema=schema, nullpool=nullpool, source=source
         ) as engine:
             with closing(engine.raw_connection()) as conn:
+                # pre-session queries are used to set the selected schema and, in the
+                # future, the selected catalog
+                for prequery in self.db_engine_spec.get_prequeries(schema=schema):
+                    with closing(conn.cursor()) as cursor:
+                        cursor.execute(prequery)
+
                 yield conn
 
-    def get_default_schema_for_query(self, query: "Query") -> Optional[str]:
+    def get_default_schema_for_query(self, query: Query) -> str | None:
         """
         Return the default schema for a given query.
 
@@ -550,8 +559,8 @@ class Database(
     def get_df(  # pylint: disable=too-many-locals
         self,
         sql: str,
-        schema: Optional[str] = None,
-        mutator: Optional[Callable[[pd.DataFrame], None]] = None,
+        schema: str | None = None,
+        mutator: Callable[[pd.DataFrame], None] | None = None,
     ) -> pd.DataFrame:
         sqls = self.db_engine_spec.parse_sql(sql)
         engine = self._get_sqla_engine(schema)
@@ -614,7 +623,7 @@ class Database(
 
             return df
 
-    def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str:
+    def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str:
         engine = self._get_sqla_engine(schema=schema)
 
         sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
@@ -628,12 +637,12 @@ class Database(
     def select_star(  # pylint: disable=too-many-arguments
         self,
         table_name: str,
-        schema: Optional[str] = None,
+        schema: str | None = None,
         limit: int = 100,
         show_cols: bool = False,
         indent: bool = True,
         latest_partition: bool = False,
-        cols: Optional[list[ResultSetColumnType]] = None,
+        cols: list[ResultSetColumnType] | None = None,
     ) -> str:
         """Generates a ``select *`` statement in the proper dialect"""
         eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
@@ -672,7 +681,7 @@ class Database(
         self,
         schema: str,
         cache: bool = False,
-        cache_timeout: Optional[int] = None,
+        cache_timeout: int | None = None,
         force: bool = False,
     ) -> set[tuple[str, str]]:
         """Parameters need to be passed as keyword arguments.
@@ -708,7 +717,7 @@ class Database(
         self,
         schema: str,
         cache: bool = False,
-        cache_timeout: Optional[int] = None,
+        cache_timeout: int | None = None,
         force: bool = False,
     ) -> set[tuple[str, str]]:
         """Parameters need to be passed as keyword arguments.
@@ -737,7 +746,7 @@ class Database(
 
     @contextmanager
     def get_inspector_with_context(
-        self, ssh_tunnel: Optional["SSHTunnel"] = None
+        self, ssh_tunnel: SSHTunnel | None = None
     ) -> Inspector:
         with self.get_sqla_engine_with_context(
             override_ssh_tunnel=ssh_tunnel
@@ -751,9 +760,9 @@ class Database(
     def get_all_schema_names(  # pylint: disable=unused-argument
         self,
         cache: bool = False,
-        cache_timeout: Optional[int] = None,
+        cache_timeout: int | None = None,
         force: bool = False,
-        ssh_tunnel: Optional["SSHTunnel"] = None,
+        ssh_tunnel: SSHTunnel | None = None,
     ) -> list[str]:
         """Parameters need to be passed as keyword arguments.
 
@@ -818,7 +827,7 @@ class Database(
     def update_params_from_encrypted_extra(self, params: dict[str, Any]) -> None:
         self.db_engine_spec.update_params_from_encrypted_extra(self, params)
 
-    def get_table(self, table_name: str, schema: Optional[str] = None) -> Table:
+    def get_table(self, table_name: str, schema: str | None = None) -> Table:
         extra = self.get_extra()
         meta = MetaData(**extra.get("metadata_params", {}))
         with self.get_sqla_engine_with_context() as engine:
@@ -831,13 +840,13 @@ class Database(
             )
 
     def get_table_comment(
-        self, table_name: str, schema: Optional[str] = None
-    ) -> Optional[str]:
+        self, table_name: str, schema: str | None = None
+    ) -> str | None:
         with self.get_inspector_with_context() as inspector:
             return self.db_engine_spec.get_table_comment(inspector, table_name, schema)
 
     def get_columns(
-        self, table_name: str, schema: Optional[str] = None
+        self, table_name: str, schema: str | None = None
     ) -> list[ResultSetColumnType]:
         with self.get_inspector_with_context() as inspector:
             return self.db_engine_spec.get_columns(inspector, table_name, schema)
@@ -845,19 +854,19 @@ class Database(
     def get_metrics(
         self,
         table_name: str,
-        schema: Optional[str] = None,
+        schema: str | None = None,
     ) -> list[MetricType]:
         with self.get_inspector_with_context() as inspector:
             return self.db_engine_spec.get_metrics(self, inspector, table_name, schema)
 
     def get_indexes(
-        self, table_name: str, schema: Optional[str] = None
+        self, table_name: str, schema: str | None = None
     ) -> list[dict[str, Any]]:
         with self.get_inspector_with_context() as inspector:
             return self.db_engine_spec.get_indexes(self, inspector, table_name, schema)
 
     def get_pk_constraint(
-        self, table_name: str, schema: Optional[str] = None
+        self, table_name: str, schema: str | None = None
     ) -> dict[str, Any]:
         with self.get_inspector_with_context() as inspector:
             pk_constraint = inspector.get_pk_constraint(table_name, schema) or {}
@@ -871,7 +880,7 @@ class Database(
             return {key: _convert(value) for key, value in pk_constraint.items()}
 
     def get_foreign_keys(
-        self, table_name: str, schema: Optional[str] = None
+        self, table_name: str, schema: str | None = None
     ) -> list[dict[str, Any]]:
         with self.get_inspector_with_context() as inspector:
             return inspector.get_foreign_keys(table_name, schema)
@@ -926,7 +935,7 @@ class Database(
         with self.get_sqla_engine_with_context() as engine:
             return engine.has_table(table.table_name, table.schema or None)
 
-    def has_table_by_name(self, table_name: str, schema: Optional[str] = None) -> bool:
+    def has_table_by_name(self, table_name: str, schema: str | None = None) -> bool:
         with self.get_sqla_engine_with_context() as engine:
             return engine.has_table(table_name, schema)
 
@@ -936,7 +945,7 @@ class Database(
         conn: Connection,
         dialect: Dialect,
         view_name: str,
-        schema: Optional[str] = None,
+        schema: str | None = None,
     ) -> bool:
         view_names: list[str] = []
         try:
@@ -945,11 +954,11 @@ class Database(
             logger.warning("Has view failed", exc_info=True)
         return view_name in view_names
 
-    def has_view(self, view_name: str, schema: Optional[str] = None) -> bool:
+    def has_view(self, view_name: str, schema: str | None = None) -> bool:
         engine = self._get_sqla_engine()
         return engine.run_callable(self._has_view, engine.dialect, view_name, schema)
 
-    def has_view_by_name(self, view_name: str, schema: Optional[str] = None) -> bool:
+    def has_view_by_name(self, view_name: str, schema: str | None = None) -> bool:
         return self.has_view(view_name=view_name, schema=schema)
 
     def get_dialect(self) -> Dialect:
@@ -957,7 +966,7 @@ class Database(
         return sqla_url.get_dialect()()
 
     def make_sqla_column_compatible(
-        self, sqla_col: ColumnElement, label: Optional[str] = None
+        self, sqla_col: ColumnElement, label: str | None = None
     ) -> ColumnElement:
         """Takes a sqlalchemy column object and adds label info if supported by engine.
         :param sqla_col: sqlalchemy column instance
diff --git a/tests/unit_tests/db_engine_specs/test_postgres.py b/tests/unit_tests/db_engine_specs/test_postgres.py
index 145d398898..59d1829f14 100644
--- a/tests/unit_tests/db_engine_specs/test_postgres.py
+++ b/tests/unit_tests/db_engine_specs/test_postgres.py
@@ -19,10 +19,12 @@ from datetime import datetime
 from typing import Any, Optional
 
 import pytest
+from pytest_mock import MockFixture
 from sqlalchemy import types
 from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
 from sqlalchemy.engine.url import make_url
 
+from superset.exceptions import SupersetSecurityException
 from superset.utils.core import GenericDataType
 from tests.unit_tests.db_engine_specs.utils import (
     assert_column_spec,
@@ -133,25 +135,41 @@ def test_get_schema_from_engine_params() -> None:
     )
 
 
-def test_adjust_engine_params() -> None:
+def test_get_prequeries() -> None:
     """
-    Test the ``adjust_engine_params`` method.
+    Test that ``get_prequeries`` emits a ``set search_path`` only for a schema.
     """
     from superset.db_engine_specs.postgres import PostgresEngineSpec
 
-    uri = make_url("postgres://user:password@host/catalog")
+    assert PostgresEngineSpec.get_prequeries() == []
+    assert PostgresEngineSpec.get_prequeries(schema="test") == [
+        'set search_path = "test"'
+    ]
 
-    assert PostgresEngineSpec.adjust_engine_params(uri, {}, None, "secret") == (
-        uri,
-        {"options": "-csearch_path=secret"},
-    )
 
-    assert PostgresEngineSpec.adjust_engine_params(
-        uri,
-        {"foo": "bar", "options": "-csearch_path=default -c debug=1"},
-        None,
-        "secret",
-    ) == (
-        uri,
-        {"foo": "bar", "options": "-csearch_path=secret -cdebug=1"},
+def test_get_default_schema_for_query(mocker: MockFixture) -> None:
+    """
+    Test that ``get_default_schema_for_query`` rejects queries setting search_path.
+    """
+    from superset.db_engine_specs.postgres import PostgresEngineSpec
+
+    database = mocker.MagicMock()
+    query = mocker.MagicMock()
+
+    query.sql = "SELECT * FROM some_table"
+    query.schema = "foo"
+    assert PostgresEngineSpec.get_default_schema_for_query(database, query) == "foo"
+
+    query.sql = """
+set
+-- this is a tricky comment
+search_path -- another one
+= bar;
+SELECT * FROM some_table;
+    """
+    with pytest.raises(SupersetSecurityException) as excinfo:
+        PostgresEngineSpec.get_default_schema_for_query(database, query)
+    assert (
+        str(excinfo.value)
+        == "Users are not allowed to set a search path for security reasons."
     )
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index 267b7c024a..5d6c1fcbcc 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -212,3 +212,21 @@ def test_dttm_sql_literal(
 def test_table_column_database() -> None:
     database = Database(database_name="db")
     assert TableColumn(database=database).database is database
+
+
+def test_get_prequeries(mocker: MockFixture) -> None:
+    """
+    Test that ``get_raw_connection`` executes the engine spec's prequeries in order.
+    """
+    mocker.patch.object(
+        Database,
+        "get_sqla_engine_with_context",
+    )
+    db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
+    db_engine_spec.get_prequeries.return_value = ["set a=1", "set b=2"]
+
+    database = Database(database_name="db")
+    with database.get_raw_connection() as conn:
+        conn.cursor().execute.assert_has_calls(
+            [mocker.call("set a=1"), mocker.call("set b=2")]
+        )