You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2023/03/18 00:53:51 UTC

[superset] branch master updated: feat(postgresql): dynamic schema (#23401)

This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 2c6f581fa6 feat(postgresql): dynamic schema (#23401)
2c6f581fa6 is described below

commit 2c6f581fa621033efc7d1c8699dd386539a03db8
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Fri Mar 17 17:53:42 2023 -0700

    feat(postgresql): dynamic schema (#23401)
---
 superset/db_engine_specs/base.py                  | 41 +++++++++-------
 superset/db_engine_specs/drill.py                 | 18 ++++---
 superset/db_engine_specs/hive.py                  | 16 +++---
 superset/db_engine_specs/mysql.py                 | 14 +++---
 superset/db_engine_specs/postgres.py              | 60 ++++++++++++++++++-----
 superset/db_engine_specs/presto.py                | 20 +++++---
 superset/db_engine_specs/snowflake.py             | 22 +++++----
 superset/models/core.py                           | 50 +++++++++++++++----
 tests/unit_tests/db_engine_specs/test_postgres.py | 24 +++++++++
 9 files changed, 187 insertions(+), 78 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 26dd169dc0..b8b1662057 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -371,7 +371,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     supports_file_upload = True
 
     # Is the DB engine spec able to change the default schema? This requires implementing
-    # a custom `adjust_database_uri` method.
+    # a custom `adjust_engine_params` method.
     supports_dynamic_schema = False
 
     @classmethod
@@ -472,7 +472,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         Determining the correct schema is crucial for managing access to data, so please
         make sure you understand this logic when working on a new DB engine spec.
         """
-        # default schema varies on a per-query basis
+        # dynamic schema varies on a per-query basis
         if cls.supports_dynamic_schema:
             return query.schema
 
@@ -1057,30 +1057,33 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         ]
 
     @classmethod
-    def adjust_database_uri(  # pylint: disable=unused-argument
+    def adjust_engine_params(  # pylint: disable=unused-argument
         cls,
         uri: URL,
-        selected_schema: Optional[str],
-    ) -> URL:
+        connect_args: Dict[str, Any],
+        catalog: Optional[str] = None,
+        schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
         """
-        Return a modified URL with a new database component.
+        Return a new URL and ``connect_args`` for a specific catalog/schema.
+
+        This is used in SQL Lab, allowing users to select a schema from the list of
+        schemas available in a given database, and have the query run with that schema as
+        the default one.
 
-        The URI here represents the URI as entered when saving the database,
-        ``selected_schema`` is the schema currently active presumably in
-        the SQL Lab dropdown. Based on that, for some database engine,
-        we can return a new altered URI that connects straight to the
-        active schema, meaning the users won't have to prefix the object
-        names by the schema name.
+        For some databases (like MySQL, Presto, Snowflake) this requires modifying the
+        SQLAlchemy URI before creating the connection. For others (like Postgres), it
+        requires additional parameters in ``connect_args``.
 
-        Some databases engines have 2 level of namespacing: database and
-        schema (postgres, oracle, mssql, ...)
-        For those it's probably better to not alter the database
-        component of the URI with the schema name, it won't work.
+        When a DB engine spec implements this method it should also have the attribute
+        ``supports_dynamic_schema`` set to true, so that Superset knows in which schema a
+        given query is running in order to enforce permissions (see #23385 and #23401).
 
-        Some database drivers like Presto accept '{catalog}/{schema}' in
-        the database component of the URL, that can be handled here.
+        Currently, changing the catalog is not supported. The method accepts a catalog so
+        that when catalog support is added to Superset the interface remains the same. This
+        is important because DB engine specs can be installed from 3rd party packages.
         """
-        return uri
+        return uri, connect_args
 
     @classmethod
     def patch(cls) -> None:
diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py
index 4ae5ae59b3..16ac89212a 100644
--- a/superset/db_engine_specs/drill.py
+++ b/superset/db_engine_specs/drill.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 from urllib import parse
 
 from sqlalchemy import types
@@ -71,13 +71,17 @@ class DrillEngineSpec(BaseEngineSpec):
         return None
 
     @classmethod
-    def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> URL:
-        if selected_schema:
-            uri = uri.set(
-                database=parse.quote(selected_schema.replace(".", "/"), safe="")
-            )
+    def adjust_engine_params(
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        catalog: Optional[str] = None,
+        schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
+        if schema:
+            uri = uri.set(database=parse.quote(schema.replace(".", "/"), safe=""))
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def get_schema_from_engine_params(
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index f90d889f8c..792ef94735 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -260,13 +260,17 @@ class HiveEngineSpec(PrestoEngineSpec):
         return None
 
     @classmethod
-    def adjust_database_uri(
-        cls, uri: URL, selected_schema: Optional[str] = None
-    ) -> URL:
-        if selected_schema:
-            uri = uri.set(database=parse.quote(selected_schema, safe=""))
+    def adjust_engine_params(
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        catalog: Optional[str] = None,
+        schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
+        if schema:
+            uri = uri.set(database=parse.quote(schema, safe=""))
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def get_schema_from_engine_params(
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index 04b8c68dd7..e5ff964f86 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -191,15 +191,17 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
         return None
 
     @classmethod
-    def adjust_database_uri(
+    def adjust_engine_params(
         cls,
         uri: URL,
-        selected_schema: Optional[str] = None,
-    ) -> URL:
-        if selected_schema:
-            uri = uri.set(database=parse.quote(selected_schema, safe=""))
+        connect_args: Dict[str, Any],
+        catalog: Optional[str] = None,
+        schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
+        if schema:
+            uri = uri.set(database=parse.quote(schema, safe=""))
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def get_schema_from_engine_params(
diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index 84ddf56e10..fac0b1b1d0 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -72,12 +72,30 @@ COLUMN_DOES_NOT_EXIST_REGEX = re.compile(
 SYNTAX_ERROR_REGEX = re.compile('syntax error at or near "(?P<syntax_error>.*?)"')
 
 
+def parse_options(connect_args: Dict[str, Any]) -> Dict[str, str]:
+    """
+    Parse ``options`` from ``connect_args`` into a dictionary.
+    """
+    if not isinstance(connect_args.get("options"), str):
+        return {}
+
+    tokens = (
+        tuple(token.strip() for token in option.strip().split("=", 1))
+        for option in re.split(r"-c\s?", connect_args["options"])
+        if "=" in option
+    )
+
+    return {token[0]: token[1] for token in tokens}
+
+
 class PostgresBaseEngineSpec(BaseEngineSpec):
     """Abstract class for Postgres 'like' databases"""
 
     engine = ""
     engine_name = "PostgreSQL"
 
+    supports_dynamic_schema = True
+
     _time_grain_expressions = {
         None: "{col}",
         "PT1S": "DATE_TRUNC('second', {col})",
@@ -147,6 +165,25 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
         ),
     }
 
+    @classmethod
+    def adjust_engine_params(
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        catalog: Optional[str] = None,
+        schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
+        if not schema:
+            return uri, connect_args
+
+        options = parse_options(connect_args)
+        options["search_path"] = schema
+        connect_args["options"] = " ".join(
+            f"-c{key}={value}" for key, value in options.items()
+        )
+
+        return uri, connect_args
+
     @classmethod
     def get_schema_from_engine_params(
         cls,
@@ -166,19 +203,16 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
         to determine the schema for a non-qualified table in a query. In cases like
         that we raise an exception.
         """
-        options = re.split(r"-c\s?", connect_args.get("options", ""))
-        for option in options:
-            if "=" not in option:
-                continue
-            key, value = option.strip().split("=", 1)
-            if key.strip() == "search_path":
-                if "," in value:
-                    raise Exception(
-                        "Multiple schemas are configured in the search path, which means "
-                        "Superset is unable to determine the schema of unqualified table "
-                        "names and enforce permissions."
-                    )
-                return value.strip()
+        options = parse_options(connect_args)
+        if search_path := options.get("search_path"):
+            schemas = search_path.split(",")
+            if len(schemas) > 1:
+                raise Exception(
+                    "Multiple schemas are configured in the search path, which means "
+                    "Superset is unable to determine the schema of unqualified table "
+                    "names and enforce permissions."
+                )
+            return schemas[0]
 
         return None
 
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index dd7bd88cdb..c0b4f2c6dd 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -301,19 +301,23 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
         return "from_unixtime({col})"
 
     @classmethod
-    def adjust_database_uri(
-        cls, uri: URL, selected_schema: Optional[str] = None
-    ) -> URL:
+    def adjust_engine_params(
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        catalog: Optional[str] = None,
+        schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
         database = uri.database
-        if selected_schema and database:
-            selected_schema = parse.quote(selected_schema, safe="")
+        if schema and database:
+            schema = parse.quote(schema, safe="")
             if "/" in database:
-                database = database.split("/")[0] + "/" + selected_schema
+                database = database.split("/")[0] + "/" + schema
             else:
-                database += "/" + selected_schema
+                database += "/" + schema
             uri = uri.set(database=database)
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def get_schema_from_engine_params(
diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py
index ba15eea7fb..033b637e48 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -135,17 +135,21 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
         return extra
 
     @classmethod
-    def adjust_database_uri(
-        cls, uri: URL, selected_schema: Optional[str] = None
-    ) -> URL:
+    def adjust_engine_params(
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        catalog: Optional[str] = None,
+        schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
         database = uri.database
-        if "/" in uri.database:
-            database = uri.database.split("/")[0]
-        if selected_schema:
-            selected_schema = parse.quote(selected_schema, safe="")
-            uri = uri.set(database=f"{database}/{selected_schema}")
+        if "/" in database:
+            database = database.split("/")[0]
+        if schema:
+            schema = parse.quote(schema, safe="")
+            uri = uri.set(database=f"{database}/{schema}")
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def get_schema_from_engine_params(
diff --git a/superset/models/core.py b/superset/models/core.py
index 5717726edc..d7a38cdc03 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -421,32 +421,58 @@ class Database(
         source: Optional[utils.QuerySource] = None,
         sqlalchemy_uri: Optional[str] = None,
     ) -> Engine:
-        extra = self.get_extra()
         sqlalchemy_url = make_url_safe(
             sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
         )
         self.db_engine_spec.validate_database_uri(sqlalchemy_url)
 
-        sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
+        extra = self.get_extra()
+        params = extra.get("engine_params", {})
+        if nullpool:
+            params["poolclass"] = NullPool
+        connect_args = params.get("connect_args", {})
+
+        # The ``adjust_database_uri`` method was renamed to ``adjust_engine_params`` and
+        # had its signature changed in order to support more DB engine specs. Since DB
+        # engine specs can be released as 3rd party modules we want to make sure the old
+        # method is still supported so we don't introduce a breaking change.
+        if hasattr(self.db_engine_spec, "adjust_database_uri"):
+            sqlalchemy_url = self.db_engine_spec.adjust_database_uri(
+                sqlalchemy_url,
+                schema,
+            )
+            logger.warning(
+                "DB engine spec %s implements the method `adjust_database_uri`, which is "
+                "deprecated and will be removed in version 3.0. Please update it to "
+                "implement `adjust_engine_params` instead.",
+                self.db_engine_spec,
+            )
+
+        sqlalchemy_url, connect_args = self.db_engine_spec.adjust_engine_params(
+            uri=sqlalchemy_url,
+            connect_args=connect_args,
+            catalog=None,
+            schema=schema,
+        )
+
         effective_username = self.get_effective_user(sqlalchemy_url)
         # If using MySQL or Presto for example, will set url.username
         # If using Hive, will not do anything yet since that relies on a
         # configuration parameter instead.
         sqlalchemy_url = self.db_engine_spec.get_url_for_impersonation(
-            sqlalchemy_url, self.impersonate_user, effective_username
+            sqlalchemy_url,
+            self.impersonate_user,
+            effective_username,
         )
 
         masked_url = self.get_password_masked_url(sqlalchemy_url)
         logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
 
-        params = extra.get("engine_params", {})
-        if nullpool:
-            params["poolclass"] = NullPool
-
-        connect_args = params.get("connect_args", {})
         if self.impersonate_user:
             self.db_engine_spec.update_impersonation_config(
-                connect_args, str(sqlalchemy_url), effective_username
+                connect_args,
+                str(sqlalchemy_url),
+                effective_username,
             )
 
         if connect_args:
@@ -464,7 +490,11 @@ class Database(
                     source = utils.QuerySource.SQL_LAB
 
             sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
-                sqlalchemy_url, params, effective_username, security_manager, source
+                sqlalchemy_url,
+                params,
+                effective_username,
+                security_manager,
+                source,
             )
         try:
             return create_engine(sqlalchemy_url, **params)
diff --git a/tests/unit_tests/db_engine_specs/test_postgres.py b/tests/unit_tests/db_engine_specs/test_postgres.py
index e57e6a6f8e..fef8647959 100644
--- a/tests/unit_tests/db_engine_specs/test_postgres.py
+++ b/tests/unit_tests/db_engine_specs/test_postgres.py
@@ -131,3 +131,27 @@ def test_get_schema_from_engine_params() -> None:
         "Superset is unable to determine the schema of unqualified table "
         "names and enforce permissions."
     )
+
+
+def test_adjust_engine_params() -> None:
+    """
+    Test the ``adjust_engine_params`` method.
+    """
+    from superset.db_engine_specs.postgres import PostgresEngineSpec
+
+    uri = make_url("postgres://user:password@host/catalog")
+
+    assert PostgresEngineSpec.adjust_engine_params(uri, {}, None, "secret") == (
+        uri,
+        {"options": "-csearch_path=secret"},
+    )
+
+    assert PostgresEngineSpec.adjust_engine_params(
+        uri,
+        {"foo": "bar", "options": "-csearch_path=default -c debug=1"},
+        None,
+        "secret",
+    ) == (
+        uri,
+        {"foo": "bar", "options": "-csearch_path=secret -cdebug=1"},
+    )