You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2023/03/16 22:05:06 UTC

[superset] 01/01: feat(postgresql): dynamic schema

This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch postgres_set_schema
in repository https://gitbox.apache.org/repos/asf/superset.git

commit cc516c27ee88548a867d29d23b59ac3c06706206
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Thu Mar 16 15:04:46 2023 -0700

    feat(postgresql): dynamic schema
---
 superset/db_engine_specs/base.py      |  7 +++--
 superset/db_engine_specs/drill.py     | 17 ++++++-----
 superset/db_engine_specs/hive.py      | 13 ++++----
 superset/db_engine_specs/mysql.py     | 11 +++----
 superset/db_engine_specs/postgres.py  | 56 +++++++++++++++++++++++++++--------
 superset/db_engine_specs/presto.py    | 17 ++++++-----
 superset/db_engine_specs/snowflake.py | 15 ++++++----
 superset/models/core.py               | 32 +++++++++++++-------
 8 files changed, 112 insertions(+), 56 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 26dd169dc0..f4ef9bdd4c 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1060,8 +1060,9 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     def adjust_database_uri(  # pylint: disable=unused-argument
         cls,
         uri: URL,
-        selected_schema: Optional[str],
-    ) -> URL:
+        connect_args: Dict[str, Any],
+        schema: Optional[str],
+    ) -> Tuple[URL, Dict[str, Any]]:
         """
         Return a modified URL with a new database component.
 
@@ -1080,7 +1081,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         Some database drivers like Presto accept '{catalog}/{schema}' in
         the database component of the URL, that can be handled here.
         """
-        return uri
+        return uri, connect_args
 
     @classmethod
     def patch(cls) -> None:
diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py
index 4ae5ae59b3..cb90e84715 100644
--- a/superset/db_engine_specs/drill.py
+++ b/superset/db_engine_specs/drill.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 from urllib import parse
 
 from sqlalchemy import types
@@ -71,13 +71,16 @@ class DrillEngineSpec(BaseEngineSpec):
         return None
 
     @classmethod
-    def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> URL:
-        if selected_schema:
-            uri = uri.set(
-                database=parse.quote(selected_schema.replace(".", "/"), safe="")
-            )
+    def adjust_database_uri(
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        schema: Optional[str],
+    ) -> Tuple[URL, Dict[str, Any]]:
+        if schema:
+            uri = uri.set(database=parse.quote(schema.replace(".", "/"), safe=""))
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def get_schema_from_engine_params(
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index f90d889f8c..aac98085eb 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -261,12 +261,15 @@ class HiveEngineSpec(PrestoEngineSpec):
 
     @classmethod
     def adjust_database_uri(
-        cls, uri: URL, selected_schema: Optional[str] = None
-    ) -> URL:
-        if selected_schema:
-            uri = uri.set(database=parse.quote(selected_schema, safe=""))
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
+        if schema:
+            uri = uri.set(database=parse.quote(schema, safe=""))
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def get_schema_from_engine_params(
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index 04b8c68dd7..008fc6ba6b 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -194,12 +194,13 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
     def adjust_database_uri(
         cls,
         uri: URL,
-        selected_schema: Optional[str] = None,
-    ) -> URL:
-        if selected_schema:
-            uri = uri.set(database=parse.quote(selected_schema, safe=""))
+        connect_args: Dict[str, Any],
+        schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
+        if schema:
+            uri = uri.set(database=parse.quote(schema, safe=""))
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def get_schema_from_engine_params(
diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index 84ddf56e10..2607066305 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -78,6 +78,8 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
     engine = ""
     engine_name = "PostgreSQL"
 
+    supports_dynamic_schema = True
+
     _time_grain_expressions = {
         None: "{col}",
         "PT1S": "DATE_TRUNC('second', {col})",
@@ -147,6 +149,30 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
         ),
     }
 
+    @classmethod
+    def adjust_database_uri(
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
+        if not schema:
+            return uri, connect_args
+
+        options = dict(
+            [
+                tuple(token.strip() for token in option.strip().split("=", 1))
+                for option in re.split(r"-c\s?", connect_args.get("options", ""))
+                if "=" in option
+            ]
+        )
+        options["search_path"] = schema
+        connect_args["options"] = " ".join(
+            f"-c{key}={value}" for key, value in options.items()
+        )
+
+        return uri, connect_args
+
     @classmethod
     def get_schema_from_engine_params(
         cls,
@@ -166,19 +192,23 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
         to determine the schema for a non-qualified table in a query. In cases like
         that we raise an exception.
         """
-        options = re.split(r"-c\s?", connect_args.get("options", ""))
-        for option in options:
-            if "=" not in option:
-                continue
-            key, value = option.strip().split("=", 1)
-            if key.strip() == "search_path":
-                if "," in value:
-                    raise Exception(
-                        "Multiple schemas are configured in the search path, which means "
-                        "Superset is unable to determine the schema of unqualified table "
-                        "names and enforce permissions."
-                    )
-                return value.strip()
+        options = dict(
+            [
+                tuple(token.strip() for token in option.strip().split("=", 1))
+                for option in re.split(r"-c\s?", connect_args.get("options", ""))
+                if "=" in option
+            ]
+        )
+
+        if search_path := options.get("search_path"):
+            schemas = search_path.split(",")
+            if len(schemas) > 1:
+                raise Exception(
+                    "Multiple schemas are configured in the search path, which means "
+                    "Superset is unable to determine the schema of unqualified table "
+                    "names and enforce permissions."
+                )
+            return schemas[0]
 
         return None
 
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index dd7bd88cdb..8bac325d41 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -302,18 +302,21 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
 
     @classmethod
     def adjust_database_uri(
-        cls, uri: URL, selected_schema: Optional[str] = None
-    ) -> URL:
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
         database = uri.database
-        if selected_schema and database:
-            selected_schema = parse.quote(selected_schema, safe="")
+        if schema and database:
+            schema = parse.quote(schema, safe="")
             if "/" in database:
-                database = database.split("/")[0] + "/" + selected_schema
+                database = database.split("/")[0] + "/" + schema
             else:
-                database += "/" + selected_schema
+                database += "/" + schema
             uri = uri.set(database=database)
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def get_schema_from_engine_params(
diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py
index ba15eea7fb..caf1cca560 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -136,16 +136,19 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
 
     @classmethod
     def adjust_database_uri(
-        cls, uri: URL, selected_schema: Optional[str] = None
-    ) -> URL:
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
         database = uri.database
         if "/" in uri.database:
             database = uri.database.split("/")[0]
-        if selected_schema:
-            selected_schema = parse.quote(selected_schema, safe="")
-            uri = uri.set(database=f"{database}/{selected_schema}")
+        if schema:
+            schema = parse.quote(schema, safe="")
+            uri = uri.set(database=f"{database}/{schema}")
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def get_schema_from_engine_params(
diff --git a/superset/models/core.py b/superset/models/core.py
index 5717726edc..d063f7a0eb 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -421,32 +421,40 @@ class Database(
         source: Optional[utils.QuerySource] = None,
         sqlalchemy_uri: Optional[str] = None,
     ) -> Engine:
-        extra = self.get_extra()
         sqlalchemy_url = make_url_safe(
             sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
         )
         self.db_engine_spec.validate_database_uri(sqlalchemy_url)
 
-        sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
+        extra = self.get_extra()
+        params = extra.get("engine_params", {})
+        if nullpool:
+            params["poolclass"] = NullPool
+        connect_args = params.get("connect_args", {})
+
+        sqlalchemy_url, connect_args = self.db_engine_spec.adjust_database_uri(
+            sqlalchemy_url,
+            connect_args,
+            schema,
+        )
         effective_username = self.get_effective_user(sqlalchemy_url)
         # If using MySQL or Presto for example, will set url.username
         # If using Hive, will not do anything yet since that relies on a
         # configuration parameter instead.
         sqlalchemy_url = self.db_engine_spec.get_url_for_impersonation(
-            sqlalchemy_url, self.impersonate_user, effective_username
+            sqlalchemy_url,
+            self.impersonate_user,
+            effective_username,
         )
 
         masked_url = self.get_password_masked_url(sqlalchemy_url)
         logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
 
-        params = extra.get("engine_params", {})
-        if nullpool:
-            params["poolclass"] = NullPool
-
-        connect_args = params.get("connect_args", {})
         if self.impersonate_user:
             self.db_engine_spec.update_impersonation_config(
-                connect_args, str(sqlalchemy_url), effective_username
+                connect_args,
+                str(sqlalchemy_url),
+                effective_username,
             )
 
         if connect_args:
@@ -464,7 +472,11 @@ class Database(
                     source = utils.QuerySource.SQL_LAB
 
             sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
-                sqlalchemy_url, params, effective_username, security_manager, source
+                sqlalchemy_url,
+                params,
+                effective_username,
+                security_manager,
+                source,
             )
         try:
             return create_engine(sqlalchemy_url, **params)