You are viewing a plain-text version of this content. The hyperlink to the canonical version was not preserved in this plain-text conversion.
Posted to commits@superset.apache.org by be...@apache.org on 2023/03/16 22:05:06 UTC
[superset] 01/01: feat(postgresql): dynamic schema
This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch postgres_set_schema
in repository https://gitbox.apache.org/repos/asf/superset.git
commit cc516c27ee88548a867d29d23b59ac3c06706206
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Thu Mar 16 15:04:46 2023 -0700
feat(postgresql): dynamic schema
---
superset/db_engine_specs/base.py | 7 +++--
superset/db_engine_specs/drill.py | 17 ++++++-----
superset/db_engine_specs/hive.py | 13 ++++----
superset/db_engine_specs/mysql.py | 11 +++----
superset/db_engine_specs/postgres.py | 56 +++++++++++++++++++++++++++--------
superset/db_engine_specs/presto.py | 17 ++++++-----
superset/db_engine_specs/snowflake.py | 15 ++++++----
superset/models/core.py | 32 +++++++++++++-------
8 files changed, 112 insertions(+), 56 deletions(-)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 26dd169dc0..f4ef9bdd4c 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1060,8 +1060,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def adjust_database_uri( # pylint: disable=unused-argument
cls,
uri: URL,
- selected_schema: Optional[str],
- ) -> URL:
+ connect_args: Dict[str, Any],
+ schema: Optional[str],
+ ) -> Tuple[URL, Dict[str, Any]]:
"""
Return a modified URL with a new database component.
@@ -1080,7 +1081,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
Some database drivers like Presto accept '{catalog}/{schema}' in
the database component of the URL, that can be handled here.
"""
- return uri
+ return uri, connect_args
@classmethod
def patch(cls) -> None:
diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py
index 4ae5ae59b3..cb90e84715 100644
--- a/superset/db_engine_specs/drill.py
+++ b/superset/db_engine_specs/drill.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
from urllib import parse
from sqlalchemy import types
@@ -71,13 +71,16 @@ class DrillEngineSpec(BaseEngineSpec):
return None
@classmethod
- def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> URL:
- if selected_schema:
- uri = uri.set(
- database=parse.quote(selected_schema.replace(".", "/"), safe="")
- )
+ def adjust_database_uri(
+ cls,
+ uri: URL,
+ connect_args: Dict[str, Any],
+ schema: Optional[str],
+ ) -> Tuple[URL, Dict[str, Any]]:
+ if schema:
+ uri = uri.set(database=parse.quote(schema.replace(".", "/"), safe=""))
- return uri
+ return uri, connect_args
@classmethod
def get_schema_from_engine_params(
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index f90d889f8c..aac98085eb 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -261,12 +261,15 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def adjust_database_uri(
- cls, uri: URL, selected_schema: Optional[str] = None
- ) -> URL:
- if selected_schema:
- uri = uri.set(database=parse.quote(selected_schema, safe=""))
+ cls,
+ uri: URL,
+ connect_args: Dict[str, Any],
+ schema: Optional[str] = None,
+ ) -> Tuple[URL, Dict[str, Any]]:
+ if schema:
+ uri = uri.set(database=parse.quote(schema, safe=""))
- return uri
+ return uri, connect_args
@classmethod
def get_schema_from_engine_params(
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index 04b8c68dd7..008fc6ba6b 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -194,12 +194,13 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
def adjust_database_uri(
cls,
uri: URL,
- selected_schema: Optional[str] = None,
- ) -> URL:
- if selected_schema:
- uri = uri.set(database=parse.quote(selected_schema, safe=""))
+ connect_args: Dict[str, Any],
+ schema: Optional[str] = None,
+ ) -> Tuple[URL, Dict[str, Any]]:
+ if schema:
+ uri = uri.set(database=parse.quote(schema, safe=""))
- return uri
+ return uri, connect_args
@classmethod
def get_schema_from_engine_params(
diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index 84ddf56e10..2607066305 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -78,6 +78,8 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
engine = ""
engine_name = "PostgreSQL"
+ supports_dynamic_schema = True
+
_time_grain_expressions = {
None: "{col}",
"PT1S": "DATE_TRUNC('second', {col})",
@@ -147,6 +149,30 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
),
}
+ @classmethod
+ def adjust_database_uri(
+ cls,
+ uri: URL,
+ connect_args: Dict[str, Any],
+ schema: Optional[str] = None,
+ ) -> Tuple[URL, Dict[str, Any]]:
+ if not schema:
+ return uri, connect_args
+
+ options = dict(
+ [
+ tuple(token.strip() for token in option.strip().split("=", 1))
+ for option in re.split(r"-c\s?", connect_args.get("options", ""))
+ if "=" in option
+ ]
+ )
+ options["search_path"] = schema
+ connect_args["options"] = " ".join(
+ f"-c{key}={value}" for key, value in options.items()
+ )
+
+ return uri, connect_args
+
@classmethod
def get_schema_from_engine_params(
cls,
@@ -166,19 +192,23 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
to determine the schema for a non-qualified table in a query. In cases like
that we raise an exception.
"""
- options = re.split(r"-c\s?", connect_args.get("options", ""))
- for option in options:
- if "=" not in option:
- continue
- key, value = option.strip().split("=", 1)
- if key.strip() == "search_path":
- if "," in value:
- raise Exception(
- "Multiple schemas are configured in the search path, which means "
- "Superset is unable to determine the schema of unqualified table "
- "names and enforce permissions."
- )
- return value.strip()
+ options = dict(
+ [
+ tuple(token.strip() for token in option.strip().split("=", 1))
+ for option in re.split(r"-c\s?", connect_args.get("options", ""))
+ if "=" in option
+ ]
+ )
+
+ if search_path := options.get("search_path"):
+ schemas = search_path.split(",")
+ if len(schemas) > 1:
+ raise Exception(
+ "Multiple schemas are configured in the search path, which means "
+ "Superset is unable to determine the schema of unqualified table "
+ "names and enforce permissions."
+ )
+ return schemas[0]
return None
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index dd7bd88cdb..8bac325d41 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -302,18 +302,21 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod
def adjust_database_uri(
- cls, uri: URL, selected_schema: Optional[str] = None
- ) -> URL:
+ cls,
+ uri: URL,
+ connect_args: Dict[str, Any],
+ schema: Optional[str] = None,
+ ) -> Tuple[URL, Dict[str, Any]]:
database = uri.database
- if selected_schema and database:
- selected_schema = parse.quote(selected_schema, safe="")
+ if schema and database:
+ schema = parse.quote(schema, safe="")
if "/" in database:
- database = database.split("/")[0] + "/" + selected_schema
+ database = database.split("/")[0] + "/" + schema
else:
- database += "/" + selected_schema
+ database += "/" + schema
uri = uri.set(database=database)
- return uri
+ return uri, connect_args
@classmethod
def get_schema_from_engine_params(
diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py
index ba15eea7fb..caf1cca560 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -136,16 +136,19 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
@classmethod
def adjust_database_uri(
- cls, uri: URL, selected_schema: Optional[str] = None
- ) -> URL:
+ cls,
+ uri: URL,
+ connect_args: Dict[str, Any],
+ schema: Optional[str] = None,
+ ) -> Tuple[URL, Dict[str, Any]]:
database = uri.database
if "/" in uri.database:
database = uri.database.split("/")[0]
- if selected_schema:
- selected_schema = parse.quote(selected_schema, safe="")
- uri = uri.set(database=f"{database}/{selected_schema}")
+ if schema:
+ schema = parse.quote(schema, safe="")
+ uri = uri.set(database=f"{database}/{schema}")
- return uri
+ return uri, connect_args
@classmethod
def get_schema_from_engine_params(
diff --git a/superset/models/core.py b/superset/models/core.py
index 5717726edc..d063f7a0eb 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -421,32 +421,40 @@ class Database(
source: Optional[utils.QuerySource] = None,
sqlalchemy_uri: Optional[str] = None,
) -> Engine:
- extra = self.get_extra()
sqlalchemy_url = make_url_safe(
sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
)
self.db_engine_spec.validate_database_uri(sqlalchemy_url)
- sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
+ extra = self.get_extra()
+ params = extra.get("engine_params", {})
+ if nullpool:
+ params["poolclass"] = NullPool
+ connect_args = params.get("connect_args", {})
+
+ sqlalchemy_url, connect_args = self.db_engine_spec.adjust_database_uri(
+ sqlalchemy_url,
+ connect_args,
+ schema,
+ )
effective_username = self.get_effective_user(sqlalchemy_url)
# If using MySQL or Presto for example, will set url.username
# If using Hive, will not do anything yet since that relies on a
# configuration parameter instead.
sqlalchemy_url = self.db_engine_spec.get_url_for_impersonation(
- sqlalchemy_url, self.impersonate_user, effective_username
+ sqlalchemy_url,
+ self.impersonate_user,
+ effective_username,
)
masked_url = self.get_password_masked_url(sqlalchemy_url)
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
- params = extra.get("engine_params", {})
- if nullpool:
- params["poolclass"] = NullPool
-
- connect_args = params.get("connect_args", {})
if self.impersonate_user:
self.db_engine_spec.update_impersonation_config(
- connect_args, str(sqlalchemy_url), effective_username
+ connect_args,
+ str(sqlalchemy_url),
+ effective_username,
)
if connect_args:
@@ -464,7 +472,11 @@ class Database(
source = utils.QuerySource.SQL_LAB
sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
- sqlalchemy_url, params, effective_username, security_manager, source
+ sqlalchemy_url,
+ params,
+ effective_username,
+ security_manager,
+ source,
)
try:
return create_engine(sqlalchemy_url, **params)