You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by el...@apache.org on 2023/06/13 23:39:16 UTC

[superset] 05/07: feat: add enforce URI query params with a specific for MySQL (#23723)

This is an automated email from the ASF dual-hosted git repository.

elizabeth pushed a commit to branch elizabeth/test-2.1.1
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 882117492117378bce0c002c7e250322ed560931
Author: Daniel Vaz Gaspar <da...@gmail.com>
AuthorDate: Tue Apr 18 17:07:37 2023 +0100

    feat: add enforce URI query params with a specific for MySQL (#23723)
---
 superset/db_engine_specs/base.py               |  9 +++++++-
 superset/db_engine_specs/mysql.py              |  6 ++++-
 tests/integration_tests/model_tests.py         | 15 ++++++++++++
 tests/unit_tests/db_engine_specs/test_mysql.py | 32 ++++++++++++++++++++++++--
 4 files changed, 58 insertions(+), 4 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 5243b4660d..21aa171323 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -356,6 +356,8 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     top_keywords: Set[str] = {"TOP"}
     # A set of disallowed connection query parameters
     disallow_uri_query_params: Set[str] = set()
+    # A Dict of query parameters that will always be used on every connection
+    enforce_uri_query_params: Dict[str, Any] = {}
 
     force_column_alias_quotes = False
     arraysize = 0
@@ -1016,8 +1018,13 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
 
         Some database drivers like Presto accept '{catalog}/{schema}' in
         the database component of the URL, that can be handled here.
+
+        Returns a ``(uri, connect_args)`` tuple: ``connect_args`` is a fresh copy
+        of ``enforce_uri_query_params``, meant to be merged into the engine's
+        connect args so engine specs can force parameters on every connection
+        (e.g. MySQL pins ``local_infile`` to 0).
         """
-        return uri
+        return uri, {**cls.enforce_uri_query_params}
 
     @classmethod
     def patch(cls) -> None:
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index 348b3287e3..28ef442319 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -174,6 +174,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
         ),
     }
     disallow_uri_query_params = {"local_infile"}
+    enforce_uri_query_params = {"local_infile": 0}
 
     @classmethod
     def convert_dttm(
@@ -192,10 +193,13 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
     def adjust_database_uri(
         cls, uri: URL, selected_schema: Optional[str] = None
     ) -> URL:
+        # super() keeps ``cls`` bound to the caller, honoring subclass overrides.
+        uri, new_connect_args = super().adjust_database_uri(uri, selected_schema)
+
         if selected_schema:
             uri = uri.set(database=parse.quote(selected_schema, safe=""))
 
-        return uri
+        return uri, new_connect_args
 
     @classmethod
     def get_datatype(cls, type_code: Any) -> Optional[str]:
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index da6c5e6a3c..35dbcc0a6b 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -188,6 +188,21 @@ class TestDatabaseModel(SupersetTestCase):
                 "password": "original_user_password",
             }
 
+    @unittest.skipUnless(
+        SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
+    )
+    @mock.patch("superset.models.core.create_engine")
+    def test_adjust_database_uri_mysql(self, mocked_create_engine):
+        model = Database(
+            database_name="test_database",
+            sqlalchemy_uri="mysql://user:password@localhost",
+        )
+        model._get_sqla_engine()
+        call_args = mocked_create_engine.call_args
+
+        assert str(call_args[0][0]) == "mysql://user:password@localhost"
+        assert call_args[1]["connect_args"]["local_infile"] == 0
+
     @mock.patch("superset.models.core.create_engine")
     def test_impersonate_user_trino(self, mocked_create_engine):
         principal_user = security_manager.find_user(username="gamma")
diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py
index a512e71a97..3a24e1c2dc 100644
--- a/tests/unit_tests/db_engine_specs/test_mysql.py
+++ b/tests/unit_tests/db_engine_specs/test_mysql.py
@@ -16,7 +16,7 @@
 # under the License.
 
 from datetime import datetime
-from typing import Any, Dict, Optional, Type
+from typing import Any, Dict, Optional, Tuple, Type
 from unittest.mock import Mock, patch
 
 import pytest
@@ -33,7 +33,7 @@ from sqlalchemy.dialects.mysql import (
     TINYINT,
     TINYTEXT,
 )
-from sqlalchemy.engine.url import make_url
+from sqlalchemy.engine.url import make_url, URL
 
 from superset.utils.core import GenericDataType
 from tests.unit_tests.db_engine_specs.utils import (
@@ -119,6 +119,34 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
     MySQLEngineSpec.validate_database_uri(url)
 
 
+@pytest.mark.parametrize(
+    "sqlalchemy_uri,selected_schema,expected_database",
+    [
+        ("mysql://user:password@host/db1", None, "db1"),
+        ("mysql://user:password@host/db1", "", "db1"),
+        ("mysql://user:password@host/db1", "db2", "db2"),
+        ("mysql://user:password@host", "db2", "db2"),
+        ("mysql://user:password@host/db1", "my schema", "my%20schema"),
+    ],
+)
+def test_adjust_database_uri(
+    sqlalchemy_uri: str, selected_schema: Optional[str], expected_database: str
+) -> None:
+    """Connect args always pin ``local_infile`` to 0 and are a fresh copy."""
+    from superset.db_engine_specs.mysql import MySQLEngineSpec
+
+    url = make_url(sqlalchemy_uri)
+    returned_url, returned_connect_args = MySQLEngineSpec.adjust_database_uri(
+        url, selected_schema
+    )
+    assert isinstance(returned_url, URL)
+    assert returned_url.database == expected_database
+    assert returned_connect_args == {"local_infile": 0}
+    # The result must be a fresh copy: mutating it must not leak into the spec.
+    returned_connect_args["local_infile"] = 1
+    assert MySQLEngineSpec.enforce_uri_query_params == {"local_infile": 0}
+
+
 @patch("sqlalchemy.engine.Engine.connect")
 def test_get_cancel_query_id(engine_mock: Mock) -> None:
     from superset.db_engine_specs.mysql import MySQLEngineSpec