You are viewing a plain-text version of this content. The canonical link to the original archive page was not preserved in this plain-text rendering.
Posted to commits@superset.apache.org by be...@apache.org on 2022/12/16 01:08:48 UTC

[superset] branch master updated: chore: set Snowflake user agent (#22432)

This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new bdeedaaf80 chore: set Snowflake user agent (#22432)
bdeedaaf80 is described below

commit bdeedaaf80deb5785d82b786e713c8a3cb579ee3
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Thu Dec 15 17:08:34 2022 -0800

    chore: set Snowflake user agent (#22432)
---
 superset/db_engine_specs/databricks.py             | 12 +++---
 superset/db_engine_specs/snowflake.py              | 16 +++++++-
 .../db_engine_specs/databricks_tests.py            | 21 ++++++-----
 .../unit_tests/db_engine_specs/test_databricks.py  | 44 +++++++++++++++++++++-
 tests/unit_tests/db_engine_specs/test_snowflake.py | 31 +++++++++++++++
 5 files changed, 107 insertions(+), 17 deletions(-)

diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py
index 7ebe6ab1ab..131679359c 100644
--- a/superset/db_engine_specs/databricks.py
+++ b/superset/db_engine_specs/databricks.py
@@ -163,11 +163,13 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin)
         """
         Add a user agent to be used in the requests.
         """
-        extra = {
-            "http_headers": [("User-Agent", USER_AGENT)],
-            "_user_agent_entry": USER_AGENT,
-        }
-        extra.update(BaseEngineSpec.get_extra_params(database))
+        extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database)
+        engine_params: Dict[str, Any] = extra.setdefault("engine_params", {})
+        connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {})
+
+        connect_args.setdefault("http_headers", [("User-Agent", USER_AGENT)])
+        connect_args.setdefault("_user_agent_entry", USER_AGENT)
+
         return extra
 
     @classmethod
diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py
index 0704712d65..578ded965b 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -31,8 +31,9 @@ from marshmallow import fields, Schema
 from sqlalchemy.engine.url import URL
 from typing_extensions import TypedDict
 
+from superset.constants import USER_AGENT
 from superset.databases.utils import make_url_safe
-from superset.db_engine_specs.base import BasicPropertiesType
+from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
 from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.models.sql_lab import Query
@@ -118,6 +119,19 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
         ),
     }
 
+    @staticmethod
+    def get_extra_params(database: "Database") -> Dict[str, Any]:
+        """
+        Add a user agent to be used in the requests.
+        """
+        extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database)
+        engine_params: Dict[str, Any] = extra.setdefault("engine_params", {})
+        connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {})
+
+        connect_args.setdefault("application", USER_AGENT)
+
+        return extra
+
     @classmethod
     def adjust_database_uri(
         cls, uri: URL, selected_schema: Optional[str] = None
diff --git a/tests/integration_tests/db_engine_specs/databricks_tests.py b/tests/integration_tests/db_engine_specs/databricks_tests.py
index b399e41fd3..c2d57831a5 100644
--- a/tests/integration_tests/db_engine_specs/databricks_tests.py
+++ b/tests/integration_tests/db_engine_specs/databricks_tests.py
@@ -44,16 +44,17 @@ class TestDatabricksDbEngineSpec(TestDbEngineSpec):
         db.extra = default_db_extra
         db.server_cert = None
         extras = DatabricksNativeEngineSpec.get_extra_params(db)
-        assert "connect_args" not in extras["engine_params"]
-
-    def test_extras_with_user_agent(self):
-        db = mock.Mock()
-        db.extra = default_db_extra
-        extras = DatabricksNativeEngineSpec.get_extra_params(db)
-        _, user_agent = extras["http_headers"][0]
-        user_agent_entry = extras["_user_agent_entry"]
-        assert user_agent == USER_AGENT
-        assert user_agent_entry == USER_AGENT
+        assert extras == {
+            "engine_params": {
+                "connect_args": {
+                    "_user_agent_entry": "Apache Superset",
+                    "http_headers": [("User-Agent", "Apache Superset")],
+                },
+            },
+            "metadata_cache_timeout": {},
+            "metadata_params": {},
+            "schemas_allowed_for_file_upload": [],
+        }
 
     def test_extras_with_ssl_custom(self):
         db = mock.Mock()
diff --git a/tests/unit_tests/db_engine_specs/test_databricks.py b/tests/unit_tests/db_engine_specs/test_databricks.py
index 0cc0907f4d..50c7fd47a3 100644
--- a/tests/unit_tests/db_engine_specs/test_databricks.py
+++ b/tests/unit_tests/db_engine_specs/test_databricks.py
@@ -18,8 +18,9 @@
 
 import json
 
+from pytest_mock import MockerFixture
+
 from superset.utils.core import GenericDataType
-from tests.integration_tests.db_engine_specs.base_tests import assert_generic_types
 
 
 def test_get_parameters_from_uri() -> None:
@@ -110,6 +111,7 @@ def test_generic_type() -> None:
     assert that generic types match
     """
     from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
+    from tests.integration_tests.db_engine_specs.base_tests import assert_generic_types
 
     type_expectations = (
         # Numeric
@@ -133,3 +135,43 @@ def test_generic_type() -> None:
         ("BOOLEAN", GenericDataType.BOOLEAN),
     )
     assert_generic_types(DatabricksNativeEngineSpec, type_expectations)
+
+
+def test_get_extra_params(mocker: MockerFixture) -> None:
+    """
+    Test the ``get_extra_params`` method.
+    """
+    from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
+
+    database = mocker.MagicMock()
+
+    database.extra = {}
+    assert DatabricksNativeEngineSpec.get_extra_params(database) == {
+        "engine_params": {
+            "connect_args": {
+                "http_headers": [("User-Agent", "Apache Superset")],
+                "_user_agent_entry": "Apache Superset",
+            }
+        }
+    }
+
+    database.extra = json.dumps(
+        {
+            "engine_params": {
+                "connect_args": {
+                    "http_headers": [("User-Agent", "Custom user agent")],
+                    "_user_agent_entry": "Custom user agent",
+                    "foo": "bar",
+                }
+            }
+        }
+    )
+    assert DatabricksNativeEngineSpec.get_extra_params(database) == {
+        "engine_params": {
+            "connect_args": {
+                "http_headers": [["User-Agent", "Custom user agent"]],
+                "_user_agent_entry": "Custom user agent",
+                "foo": "bar",
+            }
+        }
+    }
diff --git a/tests/unit_tests/db_engine_specs/test_snowflake.py b/tests/unit_tests/db_engine_specs/test_snowflake.py
index 2479e071f2..3611c7214d 100644
--- a/tests/unit_tests/db_engine_specs/test_snowflake.py
+++ b/tests/unit_tests/db_engine_specs/test_snowflake.py
@@ -14,11 +14,15 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+# pylint: disable=import-outside-toplevel
+
 import json
 from datetime import datetime
 from unittest import mock
 
 import pytest
+from pytest_mock import MockerFixture
 
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from tests.unit_tests.fixtures.common import dttm
@@ -122,3 +126,30 @@ def test_cancel_query_failed(engine_mock: mock.Mock) -> None:
     query = Query()
     cursor_mock = engine_mock.raiseError.side_effect = Exception()
     assert SnowflakeEngineSpec.cancel_query(cursor_mock, query, "123") is False
+
+
+def test_get_extra_params(mocker: MockerFixture) -> None:
+    """
+    Test the ``get_extra_params`` method.
+    """
+    from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
+
+    database = mocker.MagicMock()
+
+    database.extra = {}
+    assert SnowflakeEngineSpec.get_extra_params(database) == {
+        "engine_params": {"connect_args": {"application": "Apache Superset"}}
+    }
+
+    database.extra = json.dumps(
+        {
+            "engine_params": {
+                "connect_args": {"application": "Custom user agent", "foo": "bar"}
+            }
+        }
+    )
+    assert SnowflakeEngineSpec.get_extra_params(database) == {
+        "engine_params": {
+            "connect_args": {"application": "Custom user agent", "foo": "bar"}
+        }
+    }