You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@superset.apache.org by el...@apache.org on 2023/06/13 23:39:16 UTC
[superset] 05/07: feat: add enforce URI query params with a specific for MySQL (#23723)
This is an automated email from the ASF dual-hosted git repository.
elizabeth pushed a commit to branch elizabeth/test-2.1.1
in repository https://gitbox.apache.org/repos/asf/superset.git
commit 882117492117378bce0c002c7e250322ed560931
Author: Daniel Vaz Gaspar <da...@gmail.com>
AuthorDate: Tue Apr 18 17:07:37 2023 +0100
feat: add enforce URI query params with a specific for MySQL (#23723)
---
superset/db_engine_specs/base.py | 9 +++++++-
superset/db_engine_specs/mysql.py | 6 ++++-
tests/integration_tests/model_tests.py | 15 ++++++++++++
tests/unit_tests/db_engine_specs/test_mysql.py | 32 ++++++++++++++++++++++++--
4 files changed, 58 insertions(+), 4 deletions(-)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 5243b4660d..21aa171323 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -356,6 +356,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
top_keywords: Set[str] = {"TOP"}
# A set of disallowed connection query parameters
disallow_uri_query_params: Set[str] = set()
+ # A Dict of query parameters that will always be used on every connection
+ enforce_uri_query_params: Dict[str, Any] = {}
force_column_alias_quotes = False
arraysize = 0
@@ -1016,8 +1018,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
Some database drivers like Presto accept '{catalog}/{schema}' in
the database component of the URL, that can be handled here.
+
+ Currently, changing the catalog is not supported. The method accepts a catalog so
+ that when catalog support is added to Superset the interface remains the same.
+ This is important because DB engine specs can be installed from 3rd party
+ packages.
"""
- return uri
+ return uri, {**cls.enforce_uri_query_params}
@classmethod
def patch(cls) -> None:
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index 348b3287e3..28ef442319 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -174,6 +174,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
),
}
disallow_uri_query_params = {"local_infile"}
+ enforce_uri_query_params = {"local_infile": 0}
@classmethod
def convert_dttm(
@@ -192,10 +193,13 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> URL:
+ uri, new_connect_args = super(
+ MySQLEngineSpec, MySQLEngineSpec
+ ).adjust_database_uri(uri)
if selected_schema:
uri = uri.set(database=parse.quote(selected_schema, safe=""))
- return uri
+ return uri, new_connect_args
@classmethod
def get_datatype(cls, type_code: Any) -> Optional[str]:
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index da6c5e6a3c..35dbcc0a6b 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -188,6 +188,21 @@ class TestDatabaseModel(SupersetTestCase):
"password": "original_user_password",
}
+ @unittest.skipUnless(
+ SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
+ )
+ @mock.patch("superset.models.core.create_engine")
+ def test_adjust_engine_params_mysql(self, mocked_create_engine):
+ model = Database(
+ database_name="test_database",
+ sqlalchemy_uri="mysql://user:password@localhost",
+ )
+ model._get_sqla_engine()
+ call_args = mocked_create_engine.call_args
+
+ assert str(call_args[0][0]) == "mysql://user:password@localhost"
+ assert call_args[1]["connect_args"]["local_infile"] == 0
+
@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_trino(self, mocked_create_engine):
principal_user = security_manager.find_user(username="gamma")
diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py
index a512e71a97..3a24e1c2dc 100644
--- a/tests/unit_tests/db_engine_specs/test_mysql.py
+++ b/tests/unit_tests/db_engine_specs/test_mysql.py
@@ -16,7 +16,7 @@
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional, Type
+from typing import Any, Dict, Optional, Tuple, Type
from unittest.mock import Mock, patch
import pytest
@@ -33,7 +33,7 @@ from sqlalchemy.dialects.mysql import (
TINYINT,
TINYTEXT,
)
-from sqlalchemy.engine.url import make_url
+from sqlalchemy.engine.url import make_url, URL
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
@@ -119,6 +119,34 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
MySQLEngineSpec.validate_database_uri(url)
+@pytest.mark.parametrize(
+ "sqlalchemy_uri,connect_args,returns",
+ [
+ ("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}),
+ ("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}),
+ ("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}),
+ (
+ "mysql://user:password@host/db1",
+ {"param1": "some_value"},
+ {"local_infile": 0, "param1": "some_value"},
+ ),
+ (
+ "mysql://user:password@host/db1",
+ {"local_infile": 1, "param1": "some_value"},
+ {"local_infile": 0, "param1": "some_value"},
+ ),
+ ],
+)
+def test_adjust_database_uri(
+ sqlalchemy_uri: str, connect_args: Dict[str, Any], returns: Dict[str, Any]
+) -> None:
+ from superset.db_engine_specs.mysql import MySQLEngineSpec
+
+ url = make_url(sqlalchemy_uri)
+ returned_url, returned_connect_args = MySQLEngineSpec.adjust_database_uri(url)
+ assert returned_connect_args == returns
+
+
@patch("sqlalchemy.engine.Engine.connect")
def test_get_cancel_query_id(engine_mock: Mock) -> None:
from superset.db_engine_specs.mysql import MySQLEngineSpec