You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2023/07/19 18:34:03 UTC

[superset] branch fix_rds updated (870d4c0fce -> 538e150059)

This is an automated email from the ASF dual-hosted git repository.

beto pushed a change to branch fix_rds
in repository https://gitbox.apache.org/repos/asf/superset.git


    omit 870d4c0fce fix: search_path in RDS
     new 538e150059 fix: search_path in RDS

This update added new revisions after undoing existing revisions.
That is to say, some revisions that were in the old version of the
branch are not in the new version.  This situation occurs
when a user --force pushes a change and generates a repository
containing something like this:

 * -- * -- B -- O -- O -- O   (870d4c0fce)
            \
             N -- N -- N   refs/heads/fix_rds (538e150059)

You should already have received notification emails for all of the O
revisions, and so the following emails describe only the N revisions
from the common base, B.

Any revisions marked "omit" are not gone; other references still
refer to them.  Any revisions marked "discard" are gone forever.

The 1 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 superset/models/core.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)


[superset] 01/01: fix: search_path in RDS

Posted by be...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch fix_rds
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 538e15005926f8b09415072cd21231bd26eb880d
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Tue Jul 18 17:25:51 2023 -0700

    fix: search_path in RDS
---
 superset/db_engine_specs/base.py     | 33 +++++++++++---
 superset/db_engine_specs/postgres.py | 54 ++++++++++++-----------
 superset/models/core.py              | 83 ++++++++++++++++++++----------------
 3 files changed, 102 insertions(+), 68 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 0d778de439..a74319ccc2 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1082,22 +1082,45 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
 
         For some databases (like MySQL, Presto, Snowflake) this requires modifying the
         SQLAlchemy URI before creating the connection. For others (like Postgres), it
-        requires additional parameters in ``connect_args``.
+        requires additional parameters in ``connect_args`` or running pre-session
+        queries with ``set`` parameters.
 
-        When a DB engine spec implements this method it should also have the attribute
-        ``supports_dynamic_schema`` set to true, so that Superset knows in which schema a
-        given query is running in order to enforce permissions (see #23385 and #23401).
+        When a DB engine spec implements this method or ``get_prequeries`` (see below) it
+        should also have the attribute ``supports_dynamic_schema`` set to true, so that
+        Superset knows in which schema a given query is running in order to enforce
+        permissions (see #23385 and #23401).
 
         Currently, changing the catalog is not supported. The method accepts a catalog so
         that when catalog support is added to Superset the interface remains the same.
         This is important because DB engine specs can be installed from 3rd party
-        packages.
+        packages, so we want to keep these methods as stable as possible.
         """
         return uri, {
             **connect_args,
             **cls.enforce_uri_query_params.get(uri.get_driver_name(), {}),
         }
 
+    @classmethod
+    def get_prequeries(
+        cls,
+        catalog: str | None = None,
+        schema: str | None = None,
+    ) -> list[str]:
+        """
+        Return pre-session queries.
+
+        These are currently used as an alternative to ``adjust_engine_params`` for
+        databases where the selected schema cannot be specified in the SQLAlchemy URI or
+        connection arguments.
+
+        For example, in order to specify a default schema in RDS we need to run a query
+        at the beginning of the session:
+
+            sql> set search_path = my_schema;
+
+        """
+        return []
+
     @classmethod
     def patch(cls) -> None:
         """
diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index cdd71fdfcc..3210a6c9aa 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -14,12 +14,15 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+from __future__ import annotations
+
 import json
 import logging
 import re
 from datetime import datetime
 from re import Pattern
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
 
 from flask_babel import gettext as __
 from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
@@ -169,9 +172,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
     }
 
     @classmethod
-    def fetch_data(
-        cls, cursor: Any, limit: Optional[int] = None
-    ) -> list[tuple[Any, ...]]:
+    def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]:
         if not cursor.description:
             return []
         return super().fetch_data(cursor, limit)
@@ -224,7 +225,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
         cls,
         sqlalchemy_uri: URL,
         connect_args: dict[str, Any],
-    ) -> Optional[str]:
+    ) -> str | None:
         """
         Return the configured schema.
 
@@ -252,23 +253,24 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
         return None
 
     @classmethod
-    def adjust_engine_params(
+    def get_prequeries(
         cls,
-        uri: URL,
-        connect_args: dict[str, Any],
-        catalog: Optional[str] = None,
-        schema: Optional[str] = None,
-    ) -> tuple[URL, dict[str, Any]]:
-        if not schema:
-            return uri, connect_args
+        catalog: str | None = None,
+        schema: str | None = None,
+    ) -> list[str]:
+        """
+        Set the search path to the specified schema.
 
-        options = parse_options(connect_args)
-        options["search_path"] = schema
-        connect_args["options"] = " ".join(
-            f"-c{key}={value}" for key, value in options.items()
-        )
+        This is important for two reasons: in SQL Lab it will allow queries to run in
+        the schema selected in the dropdown, resolving unqualified table names to the
+        expected schema.
 
-        return uri, connect_args
+        But more importantly, in SQL Lab this is used to check if the user has access to
+        any tables with unqualified names. If the schema is not set by SQL Lab it could
+        be anything, and we would have to block users from running any queries
+        referencing tables without an explicit schema.
+        """
+        return [f'set search_path = "{schema}"']
 
     @classmethod
     def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
@@ -298,7 +300,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
     @classmethod
     def get_catalog_names(
         cls,
-        database: "Database",
+        database: Database,
         inspector: Inspector,
     ) -> list[str]:
         """
@@ -318,7 +320,7 @@ WHERE datistemplate = false;
 
     @classmethod
     def get_table_names(
-        cls, database: "Database", inspector: PGInspector, schema: Optional[str]
+        cls, database: Database, inspector: PGInspector, schema: str | None
     ) -> set[str]:
         """Need to consider foreign tables for PostgreSQL"""
         return set(inspector.get_table_names(schema)) | set(
@@ -327,8 +329,8 @@ WHERE datistemplate = false;
 
     @classmethod
     def convert_dttm(
-        cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
-    ) -> Optional[str]:
+        cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+    ) -> str | None:
         sqla_type = cls.get_sqla_column_type(target_type)
 
         if isinstance(sqla_type, Date):
@@ -339,7 +341,7 @@ WHERE datistemplate = false;
         return None
 
     @staticmethod
-    def get_extra_params(database: "Database") -> dict[str, Any]:
+    def get_extra_params(database: Database) -> dict[str, Any]:
         """
         For Postgres, the path to a SSL certificate is placed in `connect_args`.
 
@@ -363,7 +365,7 @@ WHERE datistemplate = false;
         return extra
 
     @classmethod
-    def get_datatype(cls, type_code: Any) -> Optional[str]:
+    def get_datatype(cls, type_code: Any) -> str | None:
         # pylint: disable=import-outside-toplevel
         from psycopg2.extensions import binary_types, string_types
 
@@ -374,7 +376,7 @@ WHERE datistemplate = false;
         return None
 
     @classmethod
-    def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]:
+    def get_cancel_query_id(cls, cursor: Any, query: Query) -> str | None:
         """
         Get Postgres PID that will be used to cancel all other running
         queries in the same session.
diff --git a/superset/models/core.py b/superset/models/core.py
index 4ff56145e1..a8fe8b5411 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -16,6 +16,9 @@
 # under the License.
 # pylint: disable=line-too-long,too-many-lines
 """A collection of ORM sqlalchemy models for Superset"""
+
+from __future__ import annotations
+
 import builtins
 import enum
 import json
@@ -26,7 +29,7 @@ from contextlib import closing, contextmanager, nullcontext
 from copy import deepcopy
 from datetime import datetime
 from functools import lru_cache
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from typing import Any, Callable, TYPE_CHECKING
 
 import numpy
 import pandas as pd
@@ -270,7 +273,7 @@ class Database(
         return self.url_object.get_driver_name()
 
     @property
-    def masked_encrypted_extra(self) -> Optional[str]:
+    def masked_encrypted_extra(self) -> str | None:
         return self.db_engine_spec.mask_encrypted_extra(self.encrypted_extra)
 
     @property
@@ -315,7 +318,7 @@ class Database(
         return "schema_cache_timeout" in self.metadata_cache_timeout
 
     @property
-    def schema_cache_timeout(self) -> Optional[int]:
+    def schema_cache_timeout(self) -> int | None:
         return self.metadata_cache_timeout.get("schema_cache_timeout")
 
     @property
@@ -323,7 +326,7 @@ class Database(
         return "table_cache_timeout" in self.metadata_cache_timeout
 
     @property
-    def table_cache_timeout(self) -> Optional[int]:
+    def table_cache_timeout(self) -> int | None:
         return self.metadata_cache_timeout.get("table_cache_timeout")
 
     @property
@@ -364,7 +367,7 @@ class Database(
         conn = conn.set(password=PASSWORD_MASK if conn.password else None)
         self.sqlalchemy_uri = str(conn)  # hides the password
 
-    def get_effective_user(self, object_url: URL) -> Optional[str]:
+    def get_effective_user(self, object_url: URL) -> str | None:
         """
         Get the effective user, especially during impersonation.
 
@@ -383,10 +386,10 @@ class Database(
     @contextmanager
     def get_sqla_engine_with_context(
         self,
-        schema: Optional[str] = None,
+        schema: str | None = None,
         nullpool: bool = True,
-        source: Optional[utils.QuerySource] = None,
-        override_ssh_tunnel: Optional["SSHTunnel"] = None,
+        source: utils.QuerySource | None = None,
+        override_ssh_tunnel: SSHTunnel | None = None,
     ) -> Engine:
         from superset.daos.database import (  # pylint: disable=import-outside-toplevel
             DatabaseDAO,
@@ -425,10 +428,10 @@ class Database(
 
     def _get_sqla_engine(
         self,
-        schema: Optional[str] = None,
+        schema: str | None = None,
         nullpool: bool = True,
-        source: Optional[utils.QuerySource] = None,
-        sqlalchemy_uri: Optional[str] = None,
+        source: utils.QuerySource | None = None,
+        sqlalchemy_uri: str | None = None,
     ) -> Engine:
         sqlalchemy_url = make_url_safe(
             sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
@@ -513,17 +516,23 @@ class Database(
     @contextmanager
     def get_raw_connection(
         self,
-        schema: Optional[str] = None,
+        schema: str | None = None,
         nullpool: bool = True,
-        source: Optional[utils.QuerySource] = None,
+        source: utils.QuerySource | None = None,
     ) -> Connection:
         with self.get_sqla_engine_with_context(
             schema=schema, nullpool=nullpool, source=source
         ) as engine:
             with closing(engine.raw_connection()) as conn:
+                # pre-session queries are used to set the selected schema and, in the
+                # future, the selected catalog
+                for prequery in self.db_engine_spec.get_prequeries(schema=schema):
+                    cursor = conn.cursor()
+                    cursor.execute(prequery)
+
                 yield conn
 
-    def get_default_schema_for_query(self, query: "Query") -> Optional[str]:
+    def get_default_schema_for_query(self, query: Query) -> str | None:
         """
         Return the default schema for a given query.
 
@@ -550,8 +559,8 @@ class Database(
     def get_df(  # pylint: disable=too-many-locals
         self,
         sql: str,
-        schema: Optional[str] = None,
-        mutator: Optional[Callable[[pd.DataFrame], None]] = None,
+        schema: str | None = None,
+        mutator: Callable[[pd.DataFrame], None] | None = None,
     ) -> pd.DataFrame:
         sqls = self.db_engine_spec.parse_sql(sql)
         engine = self._get_sqla_engine(schema)
@@ -614,7 +623,7 @@ class Database(
 
             return df
 
-    def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str:
+    def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str:
         engine = self._get_sqla_engine(schema=schema)
 
         sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
@@ -628,12 +637,12 @@ class Database(
     def select_star(  # pylint: disable=too-many-arguments
         self,
         table_name: str,
-        schema: Optional[str] = None,
+        schema: str | None = None,
         limit: int = 100,
         show_cols: bool = False,
         indent: bool = True,
         latest_partition: bool = False,
-        cols: Optional[list[ResultSetColumnType]] = None,
+        cols: list[ResultSetColumnType] | None = None,
     ) -> str:
         """Generates a ``select *`` statement in the proper dialect"""
         eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
@@ -672,7 +681,7 @@ class Database(
         self,
         schema: str,
         cache: bool = False,
-        cache_timeout: Optional[int] = None,
+        cache_timeout: int | None = None,
         force: bool = False,
     ) -> set[tuple[str, str]]:
         """Parameters need to be passed as keyword arguments.
@@ -708,7 +717,7 @@ class Database(
         self,
         schema: str,
         cache: bool = False,
-        cache_timeout: Optional[int] = None,
+        cache_timeout: int | None = None,
         force: bool = False,
     ) -> set[tuple[str, str]]:
         """Parameters need to be passed as keyword arguments.
@@ -737,7 +746,7 @@ class Database(
 
     @contextmanager
     def get_inspector_with_context(
-        self, ssh_tunnel: Optional["SSHTunnel"] = None
+        self, ssh_tunnel: SSHTunnel | None = None
     ) -> Inspector:
         with self.get_sqla_engine_with_context(
             override_ssh_tunnel=ssh_tunnel
@@ -751,9 +760,9 @@ class Database(
     def get_all_schema_names(  # pylint: disable=unused-argument
         self,
         cache: bool = False,
-        cache_timeout: Optional[int] = None,
+        cache_timeout: int | None = None,
         force: bool = False,
-        ssh_tunnel: Optional["SSHTunnel"] = None,
+        ssh_tunnel: SSHTunnel | None = None,
     ) -> list[str]:
         """Parameters need to be passed as keyword arguments.
 
@@ -818,7 +827,7 @@ class Database(
     def update_params_from_encrypted_extra(self, params: dict[str, Any]) -> None:
         self.db_engine_spec.update_params_from_encrypted_extra(self, params)
 
-    def get_table(self, table_name: str, schema: Optional[str] = None) -> Table:
+    def get_table(self, table_name: str, schema: str | None = None) -> Table:
         extra = self.get_extra()
         meta = MetaData(**extra.get("metadata_params", {}))
         with self.get_sqla_engine_with_context() as engine:
@@ -831,13 +840,13 @@ class Database(
             )
 
     def get_table_comment(
-        self, table_name: str, schema: Optional[str] = None
-    ) -> Optional[str]:
+        self, table_name: str, schema: str | None = None
+    ) -> str | None:
         with self.get_inspector_with_context() as inspector:
             return self.db_engine_spec.get_table_comment(inspector, table_name, schema)
 
     def get_columns(
-        self, table_name: str, schema: Optional[str] = None
+        self, table_name: str, schema: str | None = None
     ) -> list[ResultSetColumnType]:
         with self.get_inspector_with_context() as inspector:
             return self.db_engine_spec.get_columns(inspector, table_name, schema)
@@ -845,19 +854,19 @@ class Database(
     def get_metrics(
         self,
         table_name: str,
-        schema: Optional[str] = None,
+        schema: str | None = None,
     ) -> list[MetricType]:
         with self.get_inspector_with_context() as inspector:
             return self.db_engine_spec.get_metrics(self, inspector, table_name, schema)
 
     def get_indexes(
-        self, table_name: str, schema: Optional[str] = None
+        self, table_name: str, schema: str | None = None
     ) -> list[dict[str, Any]]:
         with self.get_inspector_with_context() as inspector:
             return self.db_engine_spec.get_indexes(self, inspector, table_name, schema)
 
     def get_pk_constraint(
-        self, table_name: str, schema: Optional[str] = None
+        self, table_name: str, schema: str | None = None
     ) -> dict[str, Any]:
         with self.get_inspector_with_context() as inspector:
             pk_constraint = inspector.get_pk_constraint(table_name, schema) or {}
@@ -871,7 +880,7 @@ class Database(
             return {key: _convert(value) for key, value in pk_constraint.items()}
 
     def get_foreign_keys(
-        self, table_name: str, schema: Optional[str] = None
+        self, table_name: str, schema: str | None = None
     ) -> list[dict[str, Any]]:
         with self.get_inspector_with_context() as inspector:
             return inspector.get_foreign_keys(table_name, schema)
@@ -926,7 +935,7 @@ class Database(
         with self.get_sqla_engine_with_context() as engine:
             return engine.has_table(table.table_name, table.schema or None)
 
-    def has_table_by_name(self, table_name: str, schema: Optional[str] = None) -> bool:
+    def has_table_by_name(self, table_name: str, schema: str | None = None) -> bool:
         with self.get_sqla_engine_with_context() as engine:
             return engine.has_table(table_name, schema)
 
@@ -936,7 +945,7 @@ class Database(
         conn: Connection,
         dialect: Dialect,
         view_name: str,
-        schema: Optional[str] = None,
+        schema: str | None = None,
     ) -> bool:
         view_names: list[str] = []
         try:
@@ -945,11 +954,11 @@ class Database(
             logger.warning("Has view failed", exc_info=True)
         return view_name in view_names
 
-    def has_view(self, view_name: str, schema: Optional[str] = None) -> bool:
+    def has_view(self, view_name: str, schema: str | None = None) -> bool:
         engine = self._get_sqla_engine()
         return engine.run_callable(self._has_view, engine.dialect, view_name, schema)
 
-    def has_view_by_name(self, view_name: str, schema: Optional[str] = None) -> bool:
+    def has_view_by_name(self, view_name: str, schema: str | None = None) -> bool:
         return self.has_view(view_name=view_name, schema=schema)
 
     def get_dialect(self) -> Dialect:
@@ -957,7 +966,7 @@ class Database(
         return sqla_url.get_dialect()()
 
     def make_sqla_column_compatible(
-        self, sqla_col: ColumnElement, label: Optional[str] = None
+        self, sqla_col: ColumnElement, label: str | None = None
     ) -> ColumnElement:
         """Takes a sqlalchemy column object and adds label info if supported by engine.
         :param sqla_col: sqlalchemy column instance