You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text rendering.)
Posted to commits@superset.apache.org by dp...@apache.org on 2023/04/18 16:07:46 UTC

[superset] branch master updated: feat: add enforce URI query params with a specific for MySQL (#23723)

This is an automated email from the ASF dual-hosted git repository.

dpgaspar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 0ad6c879b3 feat: add enforce URI query params with a specific for MySQL (#23723)
0ad6c879b3 is described below

commit 0ad6c879b3be44b6cb220dd1a03a541d2fe65d9b
Author: Daniel Vaz Gaspar <da...@gmail.com>
AuthorDate: Tue Apr 18 17:07:37 2023 +0100

    feat: add enforce URI query params with a specific for MySQL (#23723)
---
 superset/db_engine_specs/base.py               | 11 ++++++---
 superset/db_engine_specs/mysql.py              |  6 ++++-
 tests/integration_tests/model_tests.py         | 15 ++++++++++++
 tests/unit_tests/db_engine_specs/test_mysql.py | 34 ++++++++++++++++++++++++--
 4 files changed, 59 insertions(+), 7 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index ed58e8cb86..93df7c7216 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -357,6 +357,8 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     top_keywords: Set[str] = {"TOP"}
     # A set of disallowed connection query parameters
     disallow_uri_query_params: Set[str] = set()
+    # A Dict of query parameters that will always be used on every connection
+    enforce_uri_query_params: Dict[str, Any] = {}
 
     force_column_alias_quotes = False
     arraysize = 0
@@ -1089,11 +1091,12 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         ``supports_dynamic_schema`` set to true, so that Superset knows in which schema a
         given query is running in order to enforce permissions (see #23385 and #23401).
 
-        Currently, changing the catalog is not supported. The method acceps a catalog so
-        that when catalog support is added to Superse the interface remains the same. This
-        is important because DB engine specs can be installed from 3rd party packages.
+        Currently, changing the catalog is not supported. The method accepts a catalog so
+        that when catalog support is added to Superset the interface remains the same.
+        This is important because DB engine specs can be installed from 3rd party
+        packages.
         """
-        return uri, connect_args
+        return uri, {**connect_args, **cls.enforce_uri_query_params}
 
     @classmethod
     def patch(cls) -> None:
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index e5ff964f86..07d2aea362 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -176,6 +176,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
         ),
     }
     disallow_uri_query_params = {"local_infile"}
+    enforce_uri_query_params = {"local_infile": 0}
 
     @classmethod
     def convert_dttm(
@@ -198,10 +199,13 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
         catalog: Optional[str] = None,
         schema: Optional[str] = None,
     ) -> Tuple[URL, Dict[str, Any]]:
+        uri, new_connect_args = super(
+            MySQLEngineSpec, MySQLEngineSpec
+        ).adjust_engine_params(uri, connect_args, catalog, schema)
         if schema:
             uri = uri.set(database=parse.quote(schema, safe=""))
 
-        return uri, connect_args
+        return uri, new_connect_args
 
     @classmethod
     def get_schema_from_engine_params(
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index da6c5e6a3c..35dbcc0a6b 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -188,6 +188,21 @@ class TestDatabaseModel(SupersetTestCase):
                 "password": "original_user_password",
             }
 
+    @unittest.skipUnless(
+        SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
+    )
+    @mock.patch("superset.models.core.create_engine")
+    def test_adjust_engine_params_mysql(self, mocked_create_engine):
+        model = Database(
+            database_name="test_database",
+            sqlalchemy_uri="mysql://user:password@localhost",
+        )
+        model._get_sqla_engine()
+        call_args = mocked_create_engine.call_args
+
+        assert str(call_args[0][0]) == "mysql://user:password@localhost"
+        assert call_args[1]["connect_args"]["local_infile"] == 0
+
     @mock.patch("superset.models.core.create_engine")
     def test_impersonate_user_trino(self, mocked_create_engine):
         principal_user = security_manager.find_user(username="gamma")
diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py
index 091cdb3b46..31e01ace58 100644
--- a/tests/unit_tests/db_engine_specs/test_mysql.py
+++ b/tests/unit_tests/db_engine_specs/test_mysql.py
@@ -16,7 +16,7 @@
 # under the License.
 
 from datetime import datetime
-from typing import Any, Dict, Optional, Type
+from typing import Any, Dict, Optional, Tuple, Type
 from unittest.mock import Mock, patch
 
 import pytest
@@ -33,7 +33,7 @@ from sqlalchemy.dialects.mysql import (
     TINYINT,
     TINYTEXT,
 )
-from sqlalchemy.engine.url import make_url
+from sqlalchemy.engine.url import make_url, URL
 
 from superset.utils.core import GenericDataType
 from tests.unit_tests.db_engine_specs.utils import (
@@ -119,6 +119,36 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
     MySQLEngineSpec.validate_database_uri(url)
 
 
+@pytest.mark.parametrize(
+    "sqlalchemy_uri,connect_args,returns",
+    [
+        ("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}),
+        ("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}),
+        ("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}),
+        (
+            "mysql://user:password@host/db1",
+            {"param1": "some_value"},
+            {"local_infile": 0, "param1": "some_value"},
+        ),
+        (
+            "mysql://user:password@host/db1",
+            {"local_infile": 1, "param1": "some_value"},
+            {"local_infile": 0, "param1": "some_value"},
+        ),
+    ],
+)
+def test_adjust_engine_params(
+    sqlalchemy_uri: str, connect_args: Dict[str, Any], returns: Dict[str, Any]
+) -> None:
+    from superset.db_engine_specs.mysql import MySQLEngineSpec
+
+    url = make_url(sqlalchemy_uri)
+    returned_url, returned_connect_args = MySQLEngineSpec.adjust_engine_params(
+        url, connect_args
+    )
+    assert returned_connect_args == returns
+
+
 @patch("sqlalchemy.engine.Engine.connect")
 def test_get_cancel_query_id(engine_mock: Mock) -> None:
     from superset.db_engine_specs.mysql import MySQLEngineSpec