You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@superset.apache.org by el...@apache.org on 2023/06/14 15:51:51 UTC
[superset] 01/18: fix: allow db driver distinction on enforced URI params (#23769)
This is an automated email from the ASF dual-hosted git repository.
elizabeth pushed a commit to tag 2.1.1rc1
in repository https://gitbox.apache.org/repos/asf/superset.git
commit b26901cb05d62637abd2aaa7144378516f4b7e0f
Author: Daniel Vaz Gaspar <da...@gmail.com>
AuthorDate: Sun Apr 23 15:44:21 2023 +0100
fix: allow db driver distinction on enforced URI params (#23769)
---
superset/db_engine_specs/base.py | 19 +++++++++------
superset/db_engine_specs/drill.py | 11 ++++++---
superset/db_engine_specs/hive.py | 10 ++++----
superset/db_engine_specs/mysql.py | 17 ++++++++++----
superset/db_engine_specs/presto.py | 9 +++++---
superset/db_engine_specs/snowflake.py | 10 ++++----
superset/models/core.py | 15 +++++++-----
tests/integration_tests/model_tests.py | 12 +++++++++-
tests/unit_tests/db_engine_specs/test_mysql.py | 32 +++++++++++++++++++++++++-
9 files changed, 102 insertions(+), 33 deletions(-)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index af2699a6dd..27dd34a802 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -354,10 +354,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# This set will give the keywords for data limit statements
# to consider for the engines with TOP SQL parsing
top_keywords: Set[str] = {"TOP"}
- # A set of disallowed connection query parameters
- disallow_uri_query_params: Set[str] = set()
+ # A set of disallowed connection query parameters by driver name
+ disallow_uri_query_params: Dict[str, Set[str]] = {}
# A Dict of query parameters that will always be used on every connection
- enforce_uri_query_params: Dict[str, Any] = {}
+ # by driver name
+ enforce_uri_query_params: Dict[str, Dict[str, Any]] = {}
force_column_alias_quotes = False
arraysize = 0
@@ -999,6 +1000,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def adjust_database_uri( # pylint: disable=unused-argument
cls,
uri: URL,
+ connect_args: Dict[str, Any],
selected_schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
"""
@@ -1024,7 +1026,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
This is important because DB engine specs can be installed from 3rd party
packages.
"""
- return uri, {**cls.enforce_uri_query_params}
+ return uri, {
+ **connect_args,
+ **cls.enforce_uri_query_params.get(uri.get_driver_name(), {}),
+ }
@classmethod
def patch(cls) -> None:
@@ -1744,9 +1749,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param sqlalchemy_uri:
"""
- if existing_disallowed := cls.disallow_uri_query_params.intersection(
- sqlalchemy_uri.query
- ):
+ if existing_disallowed := cls.disallow_uri_query_params.get(
+ sqlalchemy_uri.get_driver_name(), set()
+ ).intersection(sqlalchemy_uri.query):
raise ValueError(f"Forbidden query parameter(s): {existing_disallowed}")
diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py
index 756f74e82a..d8a1940007 100644
--- a/superset/db_engine_specs/drill.py
+++ b/superset/db_engine_specs/drill.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
from urllib import parse
from sqlalchemy import types
@@ -69,11 +69,16 @@ class DrillEngineSpec(BaseEngineSpec):
return None
@classmethod
- def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> URL:
+ def adjust_database_uri(
+ cls,
+ uri: URL,
+ connect_args: Dict[str, Any],
+ selected_schema: Optional[str] = None,
+ ) -> Tuple[URL, Dict[str, Any]]:
if selected_schema:
uri = uri.set(database=parse.quote(selected_schema, safe=""))
- return uri
+ return uri, connect_args
@classmethod
def get_url_for_impersonation(
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index c049ee652e..f07d53518c 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -191,7 +191,6 @@ class HiveEngineSpec(PrestoEngineSpec):
raise SupersetException("Append operation not currently supported")
if to_sql_kwargs["if_exists"] == "fail":
-
# Ensure table doesn't already exist.
if table.schema:
table_exists = not database.get_df(
@@ -260,12 +259,15 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def adjust_database_uri(
- cls, uri: URL, selected_schema: Optional[str] = None
- ) -> URL:
+ cls,
+ uri: URL,
+ connect_args: Dict[str, Any],
+ selected_schema: Optional[str] = None,
+ ) -> Tuple[URL, Dict[str, Any]]:
if selected_schema:
uri = uri.set(database=parse.quote(selected_schema, safe=""))
- return uri
+ return uri, connect_args
@classmethod
def _extract_error_message(cls, ex: Exception) -> str:
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index 457509f7a7..622e6c985c 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -173,8 +173,14 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
{},
),
}
- disallow_uri_query_params = {"local_infile"}
- enforce_uri_query_params = {"local_infile": 0}
+ disallow_uri_query_params = {
+ "mysqldb": {"local_infile"},
+ "mysqlconnector": {"allow_local_infile"},
+ }
+ enforce_uri_query_params = {
+ "mysqldb": {"local_infile": 0},
+ "mysqlconnector": {"allow_local_infile": 0},
+ }
@classmethod
def convert_dttm(
@@ -191,11 +197,14 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
@classmethod
def adjust_database_uri(
- cls, uri: URL, selected_schema: Optional[str] = None
+ cls,
+ uri: URL,
+ connect_args: Dict[str, Any],
+ selected_schema: Optional[str] = None,
) -> Tuple[URL, Dict[str, Any]]:
uri, new_connect_args = super(
MySQLEngineSpec, MySQLEngineSpec
- ).adjust_database_uri(uri)
+ ).adjust_database_uri(uri, connect_args)
if selected_schema:
uri = uri.set(database=parse.quote(selected_schema, safe=""))
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 72931a85b4..6bd556b79e 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -300,8 +300,11 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod
def adjust_database_uri(
- cls, uri: URL, selected_schema: Optional[str] = None
- ) -> URL:
+ cls,
+ uri: URL,
+ connect_args: Dict[str, Any],
+ selected_schema: Optional[str] = None,
+ ) -> Tuple[URL, Dict[str, Any]]:
database = uri.database
if selected_schema and database:
selected_schema = parse.quote(selected_schema, safe="")
@@ -311,7 +314,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
database += "/" + selected_schema
uri = uri.set(database=database)
- return uri
+ return uri, connect_args
@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py
index 419e0a0655..35801fa768 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -134,8 +134,11 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
@classmethod
def adjust_database_uri(
- cls, uri: URL, selected_schema: Optional[str] = None
- ) -> URL:
+ cls,
+ uri: URL,
+ connect_args: Dict[str, Any],
+ selected_schema: Optional[str] = None,
+ ) -> Tuple[URL, Dict[str, Any]]:
database = uri.database
if "/" in uri.database:
database = uri.database.split("/")[0]
@@ -143,7 +146,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
selected_schema = parse.quote(selected_schema, safe="")
uri = uri.set(database=f"{database}/{selected_schema}")
- return uri
+ return uri, connect_args
@classmethod
def epoch_to_dttm(cls) -> str:
@@ -222,7 +225,6 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
Dict[str, Any]
] = None,
) -> str:
-
return str(
URL(
"snowflake",
diff --git a/superset/models/core.py b/superset/models/core.py
index 9c67a2efa6..fce323b13c 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -426,7 +426,15 @@ class Database(
)
self.db_engine_spec.validate_database_uri(sqlalchemy_url)
- sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
+ params = extra.get("engine_params", {})
+ if nullpool:
+ params["poolclass"] = NullPool
+
+ connect_args = params.get("connect_args", {})
+
+ sqlalchemy_url, connect_args = self.db_engine_spec.adjust_database_uri(
+ sqlalchemy_url, connect_args, schema
+ )
effective_username = self.get_effective_user(sqlalchemy_url)
# If using MySQL or Presto for example, will set url.username
# If using Hive, will not do anything yet since that relies on a
@@ -438,11 +446,6 @@ class Database(
masked_url = self.get_password_masked_url(sqlalchemy_url)
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
- params = extra.get("engine_params", {})
- if nullpool:
- params["poolclass"] = NullPool
-
- connect_args = params.get("connect_args", {})
if self.impersonate_user:
self.db_engine_spec.update_impersonation_config(
connect_args, str(sqlalchemy_url), effective_username
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index 35dbcc0a6b..d5684b1b62 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -194,7 +194,7 @@ class TestDatabaseModel(SupersetTestCase):
@mock.patch("superset.models.core.create_engine")
def test_adjust_engine_params_mysql(self, mocked_create_engine):
model = Database(
- database_name="test_database",
+ database_name="test_database1",
sqlalchemy_uri="mysql://user:password@localhost",
)
model._get_sqla_engine()
@@ -203,6 +203,16 @@ class TestDatabaseModel(SupersetTestCase):
assert str(call_args[0][0]) == "mysql://user:password@localhost"
assert call_args[1]["connect_args"]["local_infile"] == 0
+ model = Database(
+ database_name="test_database2",
+ sqlalchemy_uri="mysql+mysqlconnector://user:password@localhost",
+ )
+ model._get_sqla_engine()
+ call_args = mocked_create_engine.call_args
+
+ assert str(call_args[0][0]) == "mysql+mysqlconnector://user:password@localhost"
+ assert call_args[1]["connect_args"]["allow_local_infile"] == 0
+
@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_trino(self, mocked_create_engine):
principal_user = security_manager.find_user(username="gamma")
diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py
index 3a24e1c2dc..a6f0d99e04 100644
--- a/tests/unit_tests/db_engine_specs/test_mysql.py
+++ b/tests/unit_tests/db_engine_specs/test_mysql.py
@@ -104,8 +104,11 @@ def test_convert_dttm(
"sqlalchemy_uri,error",
[
("mysql://user:password@host/db1?local_infile=1", True),
+ ("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=1", True),
("mysql://user:password@host/db1?local_infile=0", True),
+ ("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=0", True),
("mysql://user:password@host/db1", False),
+ ("mysql+mysqlconnector://user:password@host/db1", False),
],
)
def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
@@ -123,18 +126,43 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
"sqlalchemy_uri,connect_args,returns",
[
("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}),
+ (
+ "mysql+mysqlconnector://user:password@host/db1",
+ {"allow_local_infile": 1},
+ {"allow_local_infile": 0},
+ ),
("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}),
+ (
+ "mysql+mysqlconnector://user:password@host/db1",
+ {"allow_local_infile": -1},
+ {"allow_local_infile": 0},
+ ),
("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}),
+ (
+ "mysql+mysqlconnector://user:password@host/db1",
+ {"allow_local_infile": 0},
+ {"allow_local_infile": 0},
+ ),
(
"mysql://user:password@host/db1",
{"param1": "some_value"},
{"local_infile": 0, "param1": "some_value"},
),
+ (
+ "mysql+mysqlconnector://user:password@host/db1",
+ {"param1": "some_value"},
+ {"allow_local_infile": 0, "param1": "some_value"},
+ ),
(
"mysql://user:password@host/db1",
{"local_infile": 1, "param1": "some_value"},
{"local_infile": 0, "param1": "some_value"},
),
+ (
+ "mysql+mysqlconnector://user:password@host/db1",
+ {"allow_local_infile": 1, "param1": "some_value"},
+ {"allow_local_infile": 0, "param1": "some_value"},
+ ),
],
)
def test_adjust_database_uri(
@@ -143,7 +171,9 @@ def test_adjust_database_uri(
from superset.db_engine_specs.mysql import MySQLEngineSpec
url = make_url(sqlalchemy_uri)
- returned_url, returned_connect_args = MySQLEngineSpec.adjust_database_uri(url)
+ returned_url, returned_connect_args = MySQLEngineSpec.adjust_database_uri(
+ url, connect_args
+ )
assert returned_connect_args == returns