You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2022/02/25 10:35:30 UTC

[airflow] branch main updated: Make DbApiHook use get_uri from Connection (#21764)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 59c450e  Make DbApiHook use get_uri from Connection (#21764)
59c450e is described below

commit 59c450ee5425a2d23ef813dbf219cde14df7c85c
Author: bolkedebruin <bo...@users.noreply.github.com>
AuthorDate: Fri Feb 25 11:34:29 2022 +0100

    Make DbApiHook use get_uri from Connection (#21764)
    
    DbApiHook has its own get_uri method, which handles neither
    quoting nor empty passwords correctly. Connection also has a
    get_uri method that handles both cases properly, so DbApiHook
    now delegates to it.
    
    This also fixes RFC 3986 compliance issues.
---
 airflow/hooks/dbapi.py                            | 11 +---
 airflow/models/connection.py                      |  6 +++
 airflow/providers/mysql/hooks/mysql.py            |  8 ---
 airflow/providers/postgres/hooks/postgres.py      |  8 +--
 tests/hooks/test_dbapi.py                         | 61 +++++++++++++++++++----
 tests/providers/amazon/aws/hooks/test_base_aws.py |  1 +
 tests/providers/snowflake/hooks/test_snowflake.py |  1 +
 7 files changed, 65 insertions(+), 31 deletions(-)

diff --git a/airflow/hooks/dbapi.py b/airflow/hooks/dbapi.py
index f86fe26..933d4f2 100644
--- a/airflow/hooks/dbapi.py
+++ b/airflow/hooks/dbapi.py
@@ -18,7 +18,6 @@
 from contextlib import closing
 from datetime import datetime
 from typing import Any, Optional
-from urllib.parse import quote_plus, urlunsplit
 
 from sqlalchemy import create_engine
 
@@ -96,14 +95,8 @@ class DbApiHook(BaseHook):
         :return: the extracted uri.
         """
         conn = self.get_connection(getattr(self, self.conn_name_attr))
-        login = ''
-        if conn.login:
-            login = f'{quote_plus(conn.login)}:{quote_plus(conn.password)}@'
-        host = conn.host
-        if conn.port is not None:
-            host += f':{conn.port}'
-        schema = self.__schema or conn.schema or ''
-        return urlunsplit((conn.conn_type, f'{login}{host}', schema, '', ''))
+        conn.schema = self.__schema or conn.schema
+        return conn.get_uri()
 
     def get_sqlalchemy_engine(self, engine_kwargs=None):
         """
diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index b6cde8d..6ef071f 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -173,6 +173,12 @@ class Connection(Base, LoggingMixin):
 
     def get_uri(self) -> str:
         """Return connection in URI format"""
+        # conn_type may be None (`'_' in None` would raise TypeError).
+        if self.conn_type and '_' in self.conn_type:
+            self.log.warning(
+                "Connection schemes (type: %s) shall not contain '_' according to RFC3986.", self.conn_type
+            )
+
         uri = f"{str(self.conn_type).lower().replace('_', '-')}://"
 
         authority_block = ''
diff --git a/airflow/providers/mysql/hooks/mysql.py b/airflow/providers/mysql/hooks/mysql.py
index 47f8cb7..1f8513b 100644
--- a/airflow/providers/mysql/hooks/mysql.py
+++ b/airflow/providers/mysql/hooks/mysql.py
@@ -173,14 +173,6 @@ class MySqlHook(DbApiHook):
 
         raise ValueError('Unknown MySQL client name provided!')
 
-    def get_uri(self) -> str:
-        conn = self.get_connection(getattr(self, self.conn_name_attr))
-        uri = super().get_uri()
-        if conn.extra_dejson.get('charset', False):
-            charset = conn.extra_dejson["charset"]
-            return f"{uri}?charset={charset}"
-        return uri
-
     def bulk_load(self, table: str, tmp_file: str) -> None:
         """Loads a tab-delimited file into a database table"""
         conn = self.get_conn()
diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py
index a707f9c..884a1c9 100644
--- a/airflow/providers/postgres/hooks/postgres.py
+++ b/airflow/providers/postgres/hooks/postgres.py
@@ -137,11 +137,11 @@ class PostgresHook(DbApiHook):
                     conn.commit()
 
     def get_uri(self) -> str:
-        conn = self.get_connection(getattr(self, self.conn_name_attr))
+        """Extract the URI from the connection, with ``postgres://`` rewritten to ``postgresql://``.
+
+        :return: the extracted uri.
+        """
         uri = super().get_uri().replace("postgres://", "postgresql://")
-        if conn.extra_dejson.get('client_encoding', False):
-            charset = conn.extra_dejson["client_encoding"]
-            return f"{uri}?client_encoding={charset}"
         return uri
 
     def bulk_load(self, table: str, tmp_file: str) -> None:
diff --git a/tests/hooks/test_dbapi.py b/tests/hooks/test_dbapi.py
index 97e2c4a..81a63de 100644
--- a/tests/hooks/test_dbapi.py
+++ b/tests/hooks/test_dbapi.py
@@ -150,7 +150,7 @@ class TestDbApiHook(unittest.TestCase):
     def test_get_uri_schema_not_none(self):
         self.db_hook.get_connection = mock.MagicMock(
             return_value=Connection(
-                conn_type="conn_type",
+                conn_type="conn-type",
                 host="host",
                 login="login",
                 password="password",
@@ -158,12 +158,12 @@ class TestDbApiHook(unittest.TestCase):
                 port=1,
             )
         )
-        assert "conn_type://login:password@host:1/schema" == self.db_hook.get_uri()
+        assert "conn-type://login:password@host:1/schema" == self.db_hook.get_uri()
 
     def test_get_uri_schema_override(self):
         self.db_hook_schema_override.get_connection = mock.MagicMock(
             return_value=Connection(
-                conn_type="conn_type",
+                conn_type="conn-type",
                 host="host",
                 login="login",
                 password="password",
@@ -171,28 +171,69 @@ class TestDbApiHook(unittest.TestCase):
                 port=1,
             )
         )
-        assert "conn_type://login:password@host:1/schema-override" == self.db_hook_schema_override.get_uri()
+        assert "conn-type://login:password@host:1/schema-override" == self.db_hook_schema_override.get_uri()
 
     def test_get_uri_schema_none(self):
         self.db_hook.get_connection = mock.MagicMock(
             return_value=Connection(
-                conn_type="conn_type", host="host", login="login", password="password", schema=None, port=1
+                conn_type="conn-type", host="host", login="login", password="password", schema=None, port=1
             )
         )
-        assert "conn_type://login:password@host:1" == self.db_hook.get_uri()
+        assert "conn-type://login:password@host:1" == self.db_hook.get_uri()
 
     def test_get_uri_special_characters(self):
         self.db_hook.get_connection = mock.MagicMock(
             return_value=Connection(
-                conn_type="conn_type",
+                conn_type="conn-type",
+                host="host/",
+                login="lo/gi#! n",
+                password="pass*! word/",
+                schema="schema/",
+                port=1,
+            )
+        )
+        assert (
+            "conn-type://lo%2Fgi%23%21%20n:pass%2A%21%20word%2F@host%2F:1/schema%2F" == self.db_hook.get_uri()
+        )
+
+    def test_get_uri_login_none(self):
+        self.db_hook.get_connection = mock.MagicMock(
+            return_value=Connection(
+                conn_type="conn-type",
+                host="host",
+                login=None,
+                password="password",
+                schema="schema",
+                port=1,
+            )
+        )
+        assert "conn-type://:password@host:1/schema" == self.db_hook.get_uri()
+
+    def test_get_uri_password_none(self):
+        self.db_hook.get_connection = mock.MagicMock(
+            return_value=Connection(
+                conn_type="conn-type",
+                host="host",
+                login="login",
+                password=None,
+                schema="schema",
+                port=1,
+            )
+        )
+        assert "conn-type://login@host:1/schema" == self.db_hook.get_uri()
+
+    def test_get_uri_authority_none(self):
+        self.db_hook.get_connection = mock.MagicMock(
+            return_value=Connection(
+                conn_type="conn-type",
                 host="host",
-                login="logi#! n",
-                password="pass*! word",
+                login=None,
+                password=None,
                 schema="schema",
                 port=1,
             )
         )
-        assert "conn_type://logi%23%21+n:pass%2A%21+word@host:1/schema" == self.db_hook.get_uri()
+        assert "conn-type://host:1/schema" == self.db_hook.get_uri()
 
     def test_run_log(self):
         statement = 'SQL'
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 0000136..7c9512f 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -391,6 +391,7 @@ class TestAwsBaseHook(unittest.TestCase):
                 }
             )
         )
+        mock_connection.conn_type = 'aws'
 
         # Store original __import__
         orig_import = __import__
diff --git a/tests/providers/snowflake/hooks/test_snowflake.py b/tests/providers/snowflake/hooks/test_snowflake.py
index 650fb00..ec7fb2d 100644
--- a/tests/providers/snowflake/hooks/test_snowflake.py
+++ b/tests/providers/snowflake/hooks/test_snowflake.py
@@ -34,6 +34,7 @@ _PASSWORD = 'snowflake42'
 
 BASE_CONNECTION_KWARGS: Dict = {
     'login': 'user',
+    'conn_type': 'snowflake',
     'password': 'pw',
     'schema': 'public',
     'extra': {