You are viewing a plain-text version of this content. The canonical link to it ("here" in the original page) was not preserved in this text version.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/08/05 15:41:52 UTC

[airflow] branch main updated: Fix MsSqlHook.get_uri: pymssql driver to scheme (25092) (#25185)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new df5a54d21d Fix MsSqlHook.get_uri: pymssql driver to scheme (25092) (#25185)
df5a54d21d is described below

commit df5a54d21d6991d6cae05c38e1562da2196e76aa
Author: gebo <bo...@gmail.com>
AuthorDate: Fri Aug 5 17:41:43 2022 +0200

    Fix MsSqlHook.get_uri: pymssql driver to scheme (25092) (#25185)
---
 airflow/providers/microsoft/mssql/hooks/mssql.py   |  56 ++++++++++-
 .../providers/microsoft/mssql/hooks/test_mssql.py  | 102 ++++++++++++++++++++-
 2 files changed, 156 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/microsoft/mssql/hooks/mssql.py b/airflow/providers/microsoft/mssql/hooks/mssql.py
index ba236efbc2..a54372adfd 100644
--- a/airflow/providers/microsoft/mssql/hooks/mssql.py
+++ b/airflow/providers/microsoft/mssql/hooks/mssql.py
@@ -18,6 +18,8 @@
 
 """Microsoft SQLServer hook module"""
 
+from typing import Any, Optional
+
 import pymssql
 
 from airflow.providers.common.sql.hooks.sql import DbApiHook
@@ -31,10 +33,62 @@ class MsSqlHook(DbApiHook):
     conn_type = 'mssql'
     hook_name = 'Microsoft SQL Server'
     supports_autocommit = True
+    DEFAULT_SQLALCHEMY_SCHEME = 'mssql+pymssql'
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(
+        self,
+        *args,
+        sqlalchemy_scheme: Optional[str] = None,
+        **kwargs,
+    ) -> None:
+        """
+        :param args: passed to DbApiHook
+        :param sqlalchemy_scheme: SQLAlchemy scheme used by ``get_uri``. Defaults to the connection
+          extra ``sqlalchemy_scheme`` (any letter case), else ``mssql+pymssql``.
+        :param kwargs: passed to DbApiHook
+        """
         super().__init__(*args, **kwargs)
         self.schema = kwargs.pop("schema", None)
+        self._sqlalchemy_scheme = sqlalchemy_scheme
+
+    @property
+    def connection_extra_lower(self) -> dict:
+        """
+        ``connection.extra_dejson`` with every key lower-cased (fetched anew on each access).
+        Used internally for case-insensitive lookup of extras such as ``sqlalchemy_scheme``.
+        """
+        conn = self.get_connection(self.mssql_conn_id)  # type: ignore[attr-defined]
+        return {k.lower(): v for k, v in conn.extra_dejson.items()}
+
+    @property
+    def sqlalchemy_scheme(self) -> str:
+        """SQLAlchemy scheme: constructor argument, else ``sqlalchemy_scheme`` extra, else the default."""
+        return (
+            self._sqlalchemy_scheme
+            or self.connection_extra_lower.get('sqlalchemy_scheme')
+            or self.DEFAULT_SQLALCHEMY_SCHEME
+        )
+
+    def get_uri(self) -> str:
+        from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
+
+        r = list(urlsplit(super().get_uri()))
+        # overwrite the URI scheme with the configured SQLAlchemy dialect+driver:
+        r[0] = self.sqlalchemy_scheme
+        # drop 'sqlalchemy_scheme' query params (any letter case); they only select the scheme above:
+        qs = parse_qs(r[3], keep_blank_values=True)
+        for k in list(qs.keys()):
+            if k.lower() == 'sqlalchemy_scheme':
+                qs.pop(k, None)
+        r[3] = urlencode(qs, doseq=True)
+        return urlunsplit(r)
+
+    def get_sqlalchemy_connection(
+        self, connect_kwargs: Optional[dict] = None, engine_kwargs: Optional[dict] = None
+    ) -> Any:
+        """Return ``engine.connect(**connect_kwargs)`` on an engine from ``get_sqlalchemy_engine``."""
+        engine = self.get_sqlalchemy_engine(engine_kwargs=engine_kwargs)
+        return engine.connect(**(connect_kwargs or {}))
 
     def get_conn(
         self,
diff --git a/tests/providers/microsoft/mssql/hooks/test_mssql.py b/tests/providers/microsoft/mssql/hooks/test_mssql.py
index cb3add3e94..82ce184225 100644
--- a/tests/providers/microsoft/mssql/hooks/test_mssql.py
+++ b/tests/providers/microsoft/mssql/hooks/test_mssql.py
@@ -18,11 +18,37 @@
 
 import unittest
 from unittest import mock
+from urllib.parse import quote_plus
+
+from parameterized import parameterized
 
 from airflow.models import Connection
 from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
 
-PYMSSQL_CONN = Connection(host='ip', schema='share', login='username', password='password', port=8081)
+PYMSSQL_CONN = Connection(
+    conn_type='mssql', host='ip', schema='share', login='username', password='password', port=8081
+)
+PYMSSQL_CONN_ALT = Connection(
+    conn_type='mssql', host='ip', schema='', login='username', password='password', port=8081
+)
+PYMSSQL_CONN_ALT_1 = Connection(
+    conn_type='mssql',
+    host='ip',
+    schema='',
+    login='username',
+    password='password',
+    port=8081,
+    extra={"SQlalchemy_Scheme": "mssql+testdriver"},
+)
+PYMSSQL_CONN_ALT_2 = Connection(
+    conn_type='mssql',
+    host='ip',
+    schema='',
+    login='username',
+    password='password',
+    port=8081,
+    extra={"SQlalchemy_Scheme": "mssql+testdriver", "myparam": "5@-//*"},
+)
 
 
 class TestMsSqlHook(unittest.TestCase):
@@ -64,3 +90,77 @@ class TestMsSqlHook(unittest.TestCase):
 
         mssql_get_conn.assert_called_once()
         assert hook.get_autocommit(conn) == 'autocommit_state'
+
+    @parameterized.expand(
+        [
+            (
+                PYMSSQL_CONN,
+                (
+                    "mssql+pymssql://"
+                    f"{quote_plus(PYMSSQL_CONN.login)}:{quote_plus(PYMSSQL_CONN.password)}"
+                    f"@{PYMSSQL_CONN.host}:{PYMSSQL_CONN.port}/{PYMSSQL_CONN.schema}"
+                ),
+            ),
+            (
+                PYMSSQL_CONN_ALT,
+                (
+                    "mssql+pymssql://"
+                    f"{quote_plus(PYMSSQL_CONN_ALT.login)}:{quote_plus(PYMSSQL_CONN_ALT.password)}"
+                    f"@{PYMSSQL_CONN_ALT.host}:{PYMSSQL_CONN_ALT.port}"
+                ),
+            ),
+            (
+                PYMSSQL_CONN_ALT_1,
+                (
+                    f"{PYMSSQL_CONN_ALT_1.extra_dejson['SQlalchemy_Scheme']}://"
+                    f"{quote_plus(PYMSSQL_CONN_ALT_1.login)}:{quote_plus(PYMSSQL_CONN_ALT_1.password)}"
+                    f"@{PYMSSQL_CONN_ALT_1.host}:{PYMSSQL_CONN_ALT_1.port}/"
+                ),
+            ),
+            (
+                PYMSSQL_CONN_ALT_2,
+                (
+                    f"{PYMSSQL_CONN_ALT_2.extra_dejson['SQlalchemy_Scheme']}://"
+                    f"{quote_plus(PYMSSQL_CONN_ALT_2.login)}:{quote_plus(PYMSSQL_CONN_ALT_2.password)}"
+                    f"@{PYMSSQL_CONN_ALT_2.host}:{PYMSSQL_CONN_ALT_2.port}/"
+                    f"?myparam={quote_plus(PYMSSQL_CONN_ALT_2.extra_dejson['myparam'])}"
+                ),
+            ),
+        ],
+    )
+    @mock.patch('airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection')
+    def test_get_uri_driver_rewrite(self, conn, exp_uri, get_connection):
+        get_connection.return_value = conn
+
+        hook = MsSqlHook()
+        res_uri = hook.get_uri()
+
+        get_connection.assert_called()
+        assert res_uri == exp_uri
+
+    @mock.patch('airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection')
+    def test_sqlalchemy_scheme_is_default(self, get_connection):
+        get_connection.return_value = PYMSSQL_CONN
+
+        hook = MsSqlHook()
+        assert hook.sqlalchemy_scheme == hook.DEFAULT_SQLALCHEMY_SCHEME
+
+    def test_sqlalchemy_scheme_is_from_hook(self):
+        hook = MsSqlHook(sqlalchemy_scheme="mssql+mytestdriver")
+        assert hook.sqlalchemy_scheme == "mssql+mytestdriver"
+
+    @mock.patch('airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection')
+    def test_sqlalchemy_scheme_is_from_conn_extra(self, get_connection):
+        get_connection.return_value = PYMSSQL_CONN_ALT_1
+
+        hook = MsSqlHook()
+        scheme = hook.sqlalchemy_scheme
+        get_connection.assert_called()
+        assert scheme == PYMSSQL_CONN_ALT_1.extra_dejson["SQlalchemy_Scheme"]
+
+    @mock.patch('airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection')
+    def test_get_sqlalchemy_engine(self, get_connection):
+        get_connection.return_value = PYMSSQL_CONN
+
+        hook = MsSqlHook()
+        assert hook.get_sqlalchemy_engine().url.drivername == hook.DEFAULT_SQLALCHEMY_SCHEME