You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/08/05 15:41:52 UTC
[airflow] branch main updated: Fix MsSqlHook.get_uri: pymssql driver to scheme (25092) (#25185)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new df5a54d21d Fix MsSqlHook.get_uri: pymssql driver to scheme (25092) (#25185)
df5a54d21d is described below
commit df5a54d21d6991d6cae05c38e1562da2196e76aa
Author: gebo <bo...@gmail.com>
AuthorDate: Fri Aug 5 17:41:43 2022 +0200
Fix MsSqlHook.get_uri: pymssql driver to scheme (25092) (#25185)
---
airflow/providers/microsoft/mssql/hooks/mssql.py | 56 ++++++++++-
.../providers/microsoft/mssql/hooks/test_mssql.py | 102 ++++++++++++++++++++-
2 files changed, 156 insertions(+), 2 deletions(-)
diff --git a/airflow/providers/microsoft/mssql/hooks/mssql.py b/airflow/providers/microsoft/mssql/hooks/mssql.py
index ba236efbc2..a54372adfd 100644
--- a/airflow/providers/microsoft/mssql/hooks/mssql.py
+++ b/airflow/providers/microsoft/mssql/hooks/mssql.py
@@ -18,6 +18,8 @@
"""Microsoft SQLServer hook module"""
+from typing import Any, Optional
+
import pymssql
from airflow.providers.common.sql.hooks.sql import DbApiHook
@@ -31,10 +33,62 @@ class MsSqlHook(DbApiHook):
conn_type = 'mssql'
hook_name = 'Microsoft SQL Server'
supports_autocommit = True
+ DEFAULT_SQLALCHEMY_SCHEME = 'mssql+pymssql'
- def __init__(self, *args, **kwargs) -> None:
+ def __init__(
+ self,
+ *args,
+ sqlalchemy_scheme: Optional[str] = None,
+ **kwargs,
+ ) -> None:
+ """
+ :param args: passed to DbApiHook
+ :param sqlalchemy_scheme: Scheme for the sqlalchemy connection. Default is ``mssql+pymssql``.
+ Only used for ``get_sqlalchemy_engine`` and ``get_sqlalchemy_connection`` methods.
+ :param kwargs: passed to DbApiHook
+ """
super().__init__(*args, **kwargs)
self.schema = kwargs.pop("schema", None)
+ self._sqlalchemy_scheme = sqlalchemy_scheme
+
+ @property
+ def connection_extra_lower(self) -> dict:
+ """
+ ``connection.extra_dejson`` but where keys are converted to lower case.
+ This is used internally for case-insensitive access of mssql params.
+ """
+ conn = self.get_connection(self.mssql_conn_id) # type: ignore[attr-defined]
+ return {k.lower(): v for k, v in conn.extra_dejson.items()}
+
+ @property
+ def sqlalchemy_scheme(self) -> str:
+ """Sqlalchemy scheme either from constructor, connection extras or default."""
+ return (
+ self._sqlalchemy_scheme
+ or self.connection_extra_lower.get('sqlalchemy_scheme')
+ or self.DEFAULT_SQLALCHEMY_SCHEME
+ )
+
+ def get_uri(self) -> str:
+ from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
+
+ r = list(urlsplit(super().get_uri()))
+ # replace the URI scheme with the configured sqlalchemy scheme (default ``mssql+pymssql``):
+ r[0] = self.sqlalchemy_scheme
+ # remove any 'sqlalchemy_scheme' parameter (matched case-insensitively) from the query string:
+ qs = parse_qs(r[3], keep_blank_values=True)
+ for k in list(qs.keys()):
+ if k.lower() == 'sqlalchemy_scheme':
+ qs.pop(k, None)
+ r[3] = urlencode(qs, doseq=True)
+ return urlunsplit(r)
+
+ def get_sqlalchemy_connection(
+ self, connect_kwargs: Optional[dict] = None, engine_kwargs: Optional[dict] = None
+ ) -> Any:
+ """Sqlalchemy connection object"""
+ engine = self.get_sqlalchemy_engine(engine_kwargs=engine_kwargs)
+ return engine.connect(**(connect_kwargs or {}))
def get_conn(
self,
diff --git a/tests/providers/microsoft/mssql/hooks/test_mssql.py b/tests/providers/microsoft/mssql/hooks/test_mssql.py
index cb3add3e94..82ce184225 100644
--- a/tests/providers/microsoft/mssql/hooks/test_mssql.py
+++ b/tests/providers/microsoft/mssql/hooks/test_mssql.py
@@ -18,11 +18,37 @@
import unittest
from unittest import mock
+from urllib.parse import quote_plus
+
+from parameterized import parameterized
from airflow.models import Connection
from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
-PYMSSQL_CONN = Connection(host='ip', schema='share', login='username', password='password', port=8081)
+PYMSSQL_CONN = Connection(
+ conn_type='mssql', host='ip', schema='share', login='username', password='password', port=8081
+)
+PYMSSQL_CONN_ALT = Connection(
+ conn_type='mssql', host='ip', schema='', login='username', password='password', port=8081
+)
+PYMSSQL_CONN_ALT_1 = Connection(
+ conn_type='mssql',
+ host='ip',
+ schema='',
+ login='username',
+ password='password',
+ port=8081,
+ extra={"SQlalchemy_Scheme": "mssql+testdriver"},
+)
+PYMSSQL_CONN_ALT_2 = Connection(
+ conn_type='mssql',
+ host='ip',
+ schema='',
+ login='username',
+ password='password',
+ port=8081,
+ extra={"SQlalchemy_Scheme": "mssql+testdriver", "myparam": "5@-//*"},
+)
class TestMsSqlHook(unittest.TestCase):
@@ -64,3 +90,77 @@ class TestMsSqlHook(unittest.TestCase):
mssql_get_conn.assert_called_once()
assert hook.get_autocommit(conn) == 'autocommit_state'
+
+ @parameterized.expand(
+ [
+ (
+ PYMSSQL_CONN,
+ (
+ "mssql+pymssql://"
+ f"{quote_plus(PYMSSQL_CONN.login)}:{quote_plus(PYMSSQL_CONN.password)}"
+ f"@{PYMSSQL_CONN.host}:{PYMSSQL_CONN.port}/{PYMSSQL_CONN.schema}"
+ ),
+ ),
+ (
+ PYMSSQL_CONN_ALT,
+ (
+ "mssql+pymssql://"
+ f"{quote_plus(PYMSSQL_CONN_ALT.login)}:{quote_plus(PYMSSQL_CONN_ALT.password)}"
+ f"@{PYMSSQL_CONN_ALT.host}:{PYMSSQL_CONN_ALT.port}"
+ ),
+ ),
+ (
+ PYMSSQL_CONN_ALT_1,
+ (
+ f"{PYMSSQL_CONN_ALT_1.extra_dejson['SQlalchemy_Scheme']}://"
+ f"{quote_plus(PYMSSQL_CONN_ALT.login)}:{quote_plus(PYMSSQL_CONN_ALT.password)}"
+ f"@{PYMSSQL_CONN_ALT.host}:{PYMSSQL_CONN_ALT.port}/"
+ ),
+ ),
+ (
+ PYMSSQL_CONN_ALT_2,
+ (
+ f"{PYMSSQL_CONN_ALT_2.extra_dejson['SQlalchemy_Scheme']}://"
+ f"{quote_plus(PYMSSQL_CONN_ALT_2.login)}:{quote_plus(PYMSSQL_CONN_ALT_2.password)}"
+ f"@{PYMSSQL_CONN_ALT_2.host}:{PYMSSQL_CONN_ALT_2.port}/"
+ f"?myparam={quote_plus(PYMSSQL_CONN_ALT_2.extra_dejson['myparam'])}"
+ ),
+ ),
+ ],
+ )
+ @mock.patch('airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection')
+ def test_get_uri_driver_rewrite(self, conn, exp_uri, get_connection):
+ get_connection.return_value = conn
+
+ hook = MsSqlHook()
+ res_uri = hook.get_uri()
+
+ get_connection.assert_called()
+ assert res_uri == exp_uri
+
+ @mock.patch('airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection')
+ def test_sqlalchemy_scheme_is_default(self, get_connection):
+ get_connection.return_value = PYMSSQL_CONN
+
+ hook = MsSqlHook()
+ assert hook.sqlalchemy_scheme == hook.DEFAULT_SQLALCHEMY_SCHEME
+
+ def test_sqlalchemy_scheme_is_from_hook(self):
+ hook = MsSqlHook(sqlalchemy_scheme="mssql+mytestdriver")
+ assert hook.sqlalchemy_scheme == "mssql+mytestdriver"
+
+ @mock.patch('airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection')
+ def test_sqlalchemy_scheme_is_from_conn_extra(self, get_connection):
+ get_connection.return_value = PYMSSQL_CONN_ALT_1
+
+ hook = MsSqlHook()
+ scheme = hook.sqlalchemy_scheme
+ get_connection.assert_called()
+ assert scheme == PYMSSQL_CONN_ALT_1.extra_dejson["SQlalchemy_Scheme"]
+
+ @mock.patch('airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection')
+ def test_get_sqlalchemy_engine(self, get_connection):
+ get_connection.return_value = PYMSSQL_CONN
+
+ hook = MsSqlHook()
+ hook.get_sqlalchemy_engine()