You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@superset.apache.org by be...@apache.org on 2023/03/16 22:05:05 UTC

[superset] branch postgres_set_schema created (now cc516c27ee)

This is an automated email from the ASF dual-hosted git repository.

beto pushed a change to branch postgres_set_schema
in repository https://gitbox.apache.org/repos/asf/superset.git


      at cc516c27ee feat(postgresql): dynamic schema

This branch includes the following new commits:

     new cc516c27ee feat(postgresql): dynamic schema

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email. Any revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[superset] 01/01: feat(postgresql): dynamic schema

Posted by be...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch postgres_set_schema
in repository https://gitbox.apache.org/repos/asf/superset.git

commit cc516c27ee88548a867d29d23b59ac3c06706206
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Thu Mar 16 15:04:46 2023 -0700

    feat(postgresql): dynamic schema
---
 superset/db_engine_specs/base.py      |  7 +++--
 superset/db_engine_specs/drill.py     | 17 ++++++-----
 superset/db_engine_specs/hive.py      | 13 ++++----
 superset/db_engine_specs/mysql.py     | 11 +++----
 superset/db_engine_specs/postgres.py  | 56 +++++++++++++++++++++++++++--------
 superset/db_engine_specs/presto.py    | 17 ++++++-----
 superset/db_engine_specs/snowflake.py | 15 ++++++----
 superset/models/core.py               | 32 +++++++++++++-------
 8 files changed, 112 insertions(+), 56 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 26dd169dc0..f4ef9bdd4c 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1060,8 +1060,9 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     def adjust_database_uri(  # pylint: disable=unused-argument
         cls,
         uri: URL,
-        selected_schema: Optional[str],
-    ) -> URL:
+        connect_args: Dict[str, Any],
+        schema: Optional[str],
+    ) -> Tuple[URL, Dict[str, Any]]:
         """
         Return a modified URL with a new database component.
 
@@ -1080,7 +1081,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         Some database drivers like Presto accept '{catalog}/{schema}' in
         the database component of the URL, that can be handled here.
         """
-        return uri
+        return uri, connect_args
 
     @classmethod
     def patch(cls) -> None:
diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py
index 4ae5ae59b3..cb90e84715 100644
--- a/superset/db_engine_specs/drill.py
+++ b/superset/db_engine_specs/drill.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 from urllib import parse
 
 from sqlalchemy import types
@@ -71,13 +71,16 @@ class DrillEngineSpec(BaseEngineSpec):
         return None
 
     @classmethod
-    def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> URL:
-        if selected_schema:
-            uri = uri.set(
-                database=parse.quote(selected_schema.replace(".", "/"), safe="")
-            )
+    def adjust_database_uri(
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        schema: Optional[str],
+    ) -> Tuple[URL, Dict[str, Any]]:
+        if schema:
+            uri = uri.set(database=parse.quote(schema.replace(".", "/"), safe=""))
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def get_schema_from_engine_params(
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index f90d889f8c..aac98085eb 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -261,12 +261,15 @@ class HiveEngineSpec(PrestoEngineSpec):
 
     @classmethod
     def adjust_database_uri(
-        cls, uri: URL, selected_schema: Optional[str] = None
-    ) -> URL:
-        if selected_schema:
-            uri = uri.set(database=parse.quote(selected_schema, safe=""))
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
+        if schema:
+            uri = uri.set(database=parse.quote(schema, safe=""))
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def get_schema_from_engine_params(
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index 04b8c68dd7..008fc6ba6b 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -194,12 +194,13 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
     def adjust_database_uri(
         cls,
         uri: URL,
-        selected_schema: Optional[str] = None,
-    ) -> URL:
-        if selected_schema:
-            uri = uri.set(database=parse.quote(selected_schema, safe=""))
+        connect_args: Dict[str, Any],
+        schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
+        if schema:
+            uri = uri.set(database=parse.quote(schema, safe=""))
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def get_schema_from_engine_params(
diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index 84ddf56e10..2607066305 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -78,6 +78,8 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
     engine = ""
     engine_name = "PostgreSQL"
 
+    supports_dynamic_schema = True
+
     _time_grain_expressions = {
         None: "{col}",
         "PT1S": "DATE_TRUNC('second', {col})",
@@ -147,6 +149,30 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
         ),
     }
 
+    @classmethod
+    def adjust_database_uri(
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
+        if not schema:
+            return uri, connect_args
+
+        options = dict(
+            [
+                tuple(token.strip() for token in option.strip().split("=", 1))
+                for option in re.split(r"-c\s?", connect_args.get("options", ""))
+                if "=" in option
+            ]
+        )
+        options["search_path"] = schema
+        connect_args["options"] = " ".join(
+            f"-c{key}={value}" for key, value in options.items()
+        )
+
+        return uri, connect_args
+
     @classmethod
     def get_schema_from_engine_params(
         cls,
@@ -166,19 +192,23 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
         to determine the schema for a non-qualified table in a query. In cases like
         that we raise an exception.
         """
-        options = re.split(r"-c\s?", connect_args.get("options", ""))
-        for option in options:
-            if "=" not in option:
-                continue
-            key, value = option.strip().split("=", 1)
-            if key.strip() == "search_path":
-                if "," in value:
-                    raise Exception(
-                        "Multiple schemas are configured in the search path, which means "
-                        "Superset is unable to determine the schema of unqualified table "
-                        "names and enforce permissions."
-                    )
-                return value.strip()
+        options = dict(
+            [
+                tuple(token.strip() for token in option.strip().split("=", 1))
+                for option in re.split(r"-c\s?", connect_args.get("options", ""))
+                if "=" in option
+            ]
+        )
+
+        if search_path := options.get("search_path"):
+            schemas = search_path.split(",")
+            if len(schemas) > 1:
+                raise Exception(
+                    "Multiple schemas are configured in the search path, which means "
+                    "Superset is unable to determine the schema of unqualified table "
+                    "names and enforce permissions."
+                )
+            return schemas[0]
 
         return None
 
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index dd7bd88cdb..8bac325d41 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -302,18 +302,21 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
 
     @classmethod
     def adjust_database_uri(
-        cls, uri: URL, selected_schema: Optional[str] = None
-    ) -> URL:
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
         database = uri.database
-        if selected_schema and database:
-            selected_schema = parse.quote(selected_schema, safe="")
+        if schema and database:
+            schema = parse.quote(schema, safe="")
             if "/" in database:
-                database = database.split("/")[0] + "/" + selected_schema
+                database = database.split("/")[0] + "/" + schema
             else:
-                database += "/" + selected_schema
+                database += "/" + schema
             uri = uri.set(database=database)
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def get_schema_from_engine_params(
diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py
index ba15eea7fb..caf1cca560 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -136,16 +136,19 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
 
     @classmethod
     def adjust_database_uri(
-        cls, uri: URL, selected_schema: Optional[str] = None
-    ) -> URL:
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
         database = uri.database
         if "/" in uri.database:
             database = uri.database.split("/")[0]
-        if selected_schema:
-            selected_schema = parse.quote(selected_schema, safe="")
-            uri = uri.set(database=f"{database}/{selected_schema}")
+        if schema:
+            schema = parse.quote(schema, safe="")
+            uri = uri.set(database=f"{database}/{schema}")
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def get_schema_from_engine_params(
diff --git a/superset/models/core.py b/superset/models/core.py
index 5717726edc..d063f7a0eb 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -421,32 +421,40 @@ class Database(
         source: Optional[utils.QuerySource] = None,
         sqlalchemy_uri: Optional[str] = None,
     ) -> Engine:
-        extra = self.get_extra()
         sqlalchemy_url = make_url_safe(
             sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
         )
         self.db_engine_spec.validate_database_uri(sqlalchemy_url)
 
-        sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
+        extra = self.get_extra()
+        params = extra.get("engine_params", {})
+        if nullpool:
+            params["poolclass"] = NullPool
+        connect_args = params.get("connect_args", {})
+
+        sqlalchemy_url, connect_args = self.db_engine_spec.adjust_database_uri(
+            sqlalchemy_url,
+            connect_args,
+            schema,
+        )
         effective_username = self.get_effective_user(sqlalchemy_url)
         # If using MySQL or Presto for example, will set url.username
         # If using Hive, will not do anything yet since that relies on a
         # configuration parameter instead.
         sqlalchemy_url = self.db_engine_spec.get_url_for_impersonation(
-            sqlalchemy_url, self.impersonate_user, effective_username
+            sqlalchemy_url,
+            self.impersonate_user,
+            effective_username,
         )
 
         masked_url = self.get_password_masked_url(sqlalchemy_url)
         logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
 
-        params = extra.get("engine_params", {})
-        if nullpool:
-            params["poolclass"] = NullPool
-
-        connect_args = params.get("connect_args", {})
         if self.impersonate_user:
             self.db_engine_spec.update_impersonation_config(
-                connect_args, str(sqlalchemy_url), effective_username
+                connect_args,
+                str(sqlalchemy_url),
+                effective_username,
             )
 
         if connect_args:
@@ -464,7 +472,11 @@ class Database(
                     source = utils.QuerySource.SQL_LAB
 
             sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
-                sqlalchemy_url, params, effective_username, security_manager, source
+                sqlalchemy_url,
+                params,
+                effective_username,
+                security_manager,
+                source,
             )
         try:
             return create_engine(sqlalchemy_url, **params)