You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by el...@apache.org on 2023/06/14 15:51:51 UTC

[superset] 01/18: fix: allow db driver distinction on enforced URI params (#23769)

This is an automated email from the ASF dual-hosted git repository.

elizabeth pushed a commit to tag 2.1.1rc1
in repository https://gitbox.apache.org/repos/asf/superset.git

commit b26901cb05d62637abd2aaa7144378516f4b7e0f
Author: Daniel Vaz Gaspar <da...@gmail.com>
AuthorDate: Sun Apr 23 15:44:21 2023 +0100

    fix: allow db driver distinction on enforced URI params (#23769)
---
 superset/db_engine_specs/base.py               | 19 +++++++++------
 superset/db_engine_specs/drill.py              | 11 ++++++---
 superset/db_engine_specs/hive.py               | 10 ++++----
 superset/db_engine_specs/mysql.py              | 17 ++++++++++----
 superset/db_engine_specs/presto.py             |  9 +++++---
 superset/db_engine_specs/snowflake.py          | 10 ++++----
 superset/models/core.py                        | 15 +++++++-----
 tests/integration_tests/model_tests.py         | 12 +++++++++-
 tests/unit_tests/db_engine_specs/test_mysql.py | 32 +++++++++++++++++++++++++-
 9 files changed, 102 insertions(+), 33 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index af2699a6dd..27dd34a802 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -354,10 +354,11 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     # This set will give the keywords for data limit statements
     # to consider for the engines with TOP SQL parsing
     top_keywords: Set[str] = {"TOP"}
-    # A set of disallowed connection query parameters
-    disallow_uri_query_params: Set[str] = set()
+    # A mapping of driver name to the set of disallowed connection query parameters
+    disallow_uri_query_params: Dict[str, Set[str]] = {}
     # A Dict of query parameters that will always be used on every connection
-    enforce_uri_query_params: Dict[str, Any] = {}
+    # by driver name
+    enforce_uri_query_params: Dict[str, Dict[str, Any]] = {}
 
     force_column_alias_quotes = False
     arraysize = 0
@@ -999,6 +1000,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     def adjust_database_uri(  # pylint: disable=unused-argument
         cls,
         uri: URL,
+        connect_args: Dict[str, Any],
         selected_schema: Optional[str] = None,
     ) -> Tuple[URL, Dict[str, Any]]:
         """
@@ -1024,7 +1026,10 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         This is important because DB engine specs can be installed from 3rd party
         packages.
         """
-        return uri, {**cls.enforce_uri_query_params}
+        return uri, {
+            **connect_args,
+            **cls.enforce_uri_query_params.get(uri.get_driver_name(), {}),
+        }
 
     @classmethod
     def patch(cls) -> None:
@@ -1744,9 +1749,9 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
 
         :param sqlalchemy_uri:
         """
-        if existing_disallowed := cls.disallow_uri_query_params.intersection(
-            sqlalchemy_uri.query
-        ):
+        if existing_disallowed := cls.disallow_uri_query_params.get(
+            sqlalchemy_uri.get_driver_name(), set()
+        ).intersection(sqlalchemy_uri.query):
             raise ValueError(f"Forbidden query parameter(s): {existing_disallowed}")
 
 
diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py
index 756f74e82a..d8a1940007 100644
--- a/superset/db_engine_specs/drill.py
+++ b/superset/db_engine_specs/drill.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 from urllib import parse
 
 from sqlalchemy import types
@@ -69,11 +69,16 @@ class DrillEngineSpec(BaseEngineSpec):
         return None
 
     @classmethod
-    def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> URL:
+    def adjust_database_uri(
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        selected_schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
         if selected_schema:
             uri = uri.set(database=parse.quote(selected_schema, safe=""))
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def get_url_for_impersonation(
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index c049ee652e..f07d53518c 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -191,7 +191,6 @@ class HiveEngineSpec(PrestoEngineSpec):
             raise SupersetException("Append operation not currently supported")
 
         if to_sql_kwargs["if_exists"] == "fail":
-
             # Ensure table doesn't already exist.
             if table.schema:
                 table_exists = not database.get_df(
@@ -260,12 +259,15 @@ class HiveEngineSpec(PrestoEngineSpec):
 
     @classmethod
     def adjust_database_uri(
-        cls, uri: URL, selected_schema: Optional[str] = None
-    ) -> URL:
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        selected_schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
         if selected_schema:
             uri = uri.set(database=parse.quote(selected_schema, safe=""))
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def _extract_error_message(cls, ex: Exception) -> str:
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index 457509f7a7..622e6c985c 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -173,8 +173,14 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
             {},
         ),
     }
-    disallow_uri_query_params = {"local_infile"}
-    enforce_uri_query_params = {"local_infile": 0}
+    disallow_uri_query_params = {
+        "mysqldb": {"local_infile"},
+        "mysqlconnector": {"allow_local_infile"},
+    }
+    enforce_uri_query_params = {
+        "mysqldb": {"local_infile": 0},
+        "mysqlconnector": {"allow_local_infile": 0},
+    }
 
     @classmethod
     def convert_dttm(
@@ -191,11 +197,14 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
 
     @classmethod
     def adjust_database_uri(
-        cls, uri: URL, selected_schema: Optional[str] = None
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        selected_schema: Optional[str] = None,
     ) -> Tuple[URL, Dict[str, Any]]:
         uri, new_connect_args = super(
             MySQLEngineSpec, MySQLEngineSpec
-        ).adjust_database_uri(uri)
+        ).adjust_database_uri(uri, connect_args)
         if selected_schema:
             uri = uri.set(database=parse.quote(selected_schema, safe=""))
 
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 72931a85b4..6bd556b79e 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -300,8 +300,11 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
 
     @classmethod
     def adjust_database_uri(
-        cls, uri: URL, selected_schema: Optional[str] = None
-    ) -> URL:
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        selected_schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
         database = uri.database
         if selected_schema and database:
             selected_schema = parse.quote(selected_schema, safe="")
@@ -311,7 +314,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
                 database += "/" + selected_schema
             uri = uri.set(database=database)
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py
index 419e0a0655..35801fa768 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -134,8 +134,11 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
 
     @classmethod
     def adjust_database_uri(
-        cls, uri: URL, selected_schema: Optional[str] = None
-    ) -> URL:
+        cls,
+        uri: URL,
+        connect_args: Dict[str, Any],
+        selected_schema: Optional[str] = None,
+    ) -> Tuple[URL, Dict[str, Any]]:
         database = uri.database
         if "/" in uri.database:
             database = uri.database.split("/")[0]
@@ -143,7 +146,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
             selected_schema = parse.quote(selected_schema, safe="")
             uri = uri.set(database=f"{database}/{selected_schema}")
 
-        return uri
+        return uri, connect_args
 
     @classmethod
     def epoch_to_dttm(cls) -> str:
@@ -222,7 +225,6 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
             Dict[str, Any]
         ] = None,
     ) -> str:
-
         return str(
             URL(
                 "snowflake",
diff --git a/superset/models/core.py b/superset/models/core.py
index 9c67a2efa6..fce323b13c 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -426,7 +426,15 @@ class Database(
         )
         self.db_engine_spec.validate_database_uri(sqlalchemy_url)
 
-        sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
+        params = extra.get("engine_params", {})
+        if nullpool:
+            params["poolclass"] = NullPool
+
+        connect_args = params.get("connect_args", {})
+
+        sqlalchemy_url, connect_args = self.db_engine_spec.adjust_database_uri(
+            sqlalchemy_url, connect_args, schema
+        )
         effective_username = self.get_effective_user(sqlalchemy_url)
         # If using MySQL or Presto for example, will set url.username
         # If using Hive, will not do anything yet since that relies on a
@@ -438,11 +446,6 @@ class Database(
         masked_url = self.get_password_masked_url(sqlalchemy_url)
         logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
 
-        params = extra.get("engine_params", {})
-        if nullpool:
-            params["poolclass"] = NullPool
-
-        connect_args = params.get("connect_args", {})
         if self.impersonate_user:
             self.db_engine_spec.update_impersonation_config(
                 connect_args, str(sqlalchemy_url), effective_username
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index 35dbcc0a6b..d5684b1b62 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -194,7 +194,7 @@ class TestDatabaseModel(SupersetTestCase):
     @mock.patch("superset.models.core.create_engine")
     def test_adjust_engine_params_mysql(self, mocked_create_engine):
         model = Database(
-            database_name="test_database",
+            database_name="test_database1",
             sqlalchemy_uri="mysql://user:password@localhost",
         )
         model._get_sqla_engine()
@@ -203,6 +203,16 @@ class TestDatabaseModel(SupersetTestCase):
         assert str(call_args[0][0]) == "mysql://user:password@localhost"
         assert call_args[1]["connect_args"]["local_infile"] == 0
 
+        model = Database(
+            database_name="test_database2",
+            sqlalchemy_uri="mysql+mysqlconnector://user:password@localhost",
+        )
+        model._get_sqla_engine()
+        call_args = mocked_create_engine.call_args
+
+        assert str(call_args[0][0]) == "mysql+mysqlconnector://user:password@localhost"
+        assert call_args[1]["connect_args"]["allow_local_infile"] == 0
+
     @mock.patch("superset.models.core.create_engine")
     def test_impersonate_user_trino(self, mocked_create_engine):
         principal_user = security_manager.find_user(username="gamma")
diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py
index 3a24e1c2dc..a6f0d99e04 100644
--- a/tests/unit_tests/db_engine_specs/test_mysql.py
+++ b/tests/unit_tests/db_engine_specs/test_mysql.py
@@ -104,8 +104,11 @@ def test_convert_dttm(
     "sqlalchemy_uri,error",
     [
         ("mysql://user:password@host/db1?local_infile=1", True),
+        ("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=1", True),
         ("mysql://user:password@host/db1?local_infile=0", True),
+        ("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=0", True),
         ("mysql://user:password@host/db1", False),
+        ("mysql+mysqlconnector://user:password@host/db1", False),
     ],
 )
 def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
@@ -123,18 +126,43 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
     "sqlalchemy_uri,connect_args,returns",
     [
         ("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}),
+        (
+            "mysql+mysqlconnector://user:password@host/db1",
+            {"allow_local_infile": 1},
+            {"allow_local_infile": 0},
+        ),
         ("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}),
+        (
+            "mysql+mysqlconnector://user:password@host/db1",
+            {"allow_local_infile": -1},
+            {"allow_local_infile": 0},
+        ),
         ("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}),
+        (
+            "mysql+mysqlconnector://user:password@host/db1",
+            {"allow_local_infile": 0},
+            {"allow_local_infile": 0},
+        ),
         (
             "mysql://user:password@host/db1",
             {"param1": "some_value"},
             {"local_infile": 0, "param1": "some_value"},
         ),
+        (
+            "mysql+mysqlconnector://user:password@host/db1",
+            {"param1": "some_value"},
+            {"allow_local_infile": 0, "param1": "some_value"},
+        ),
         (
             "mysql://user:password@host/db1",
             {"local_infile": 1, "param1": "some_value"},
             {"local_infile": 0, "param1": "some_value"},
         ),
+        (
+            "mysql+mysqlconnector://user:password@host/db1",
+            {"allow_local_infile": 1, "param1": "some_value"},
+            {"allow_local_infile": 0, "param1": "some_value"},
+        ),
     ],
 )
 def test_adjust_database_uri(
@@ -143,7 +171,9 @@ def test_adjust_database_uri(
     from superset.db_engine_specs.mysql import MySQLEngineSpec
 
     url = make_url(sqlalchemy_uri)
-    returned_url, returned_connect_args = MySQLEngineSpec.adjust_database_uri(url)
+    returned_url, returned_connect_args = MySQLEngineSpec.adjust_database_uri(
+        url, connect_args
+    )
     assert returned_connect_args == returns