You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2022/02/25 10:35:30 UTC
[airflow] branch main updated: Make DbApiHook use get_uri from Connection (#21764)
This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 59c450e Make DbApiHook use get_uri from Connection (#21764)
59c450e is described below
commit 59c450ee5425a2d23ef813dbf219cde14df7c85c
Author: bolkedebruin <bo...@users.noreply.github.com>
AuthorDate: Fri Feb 25 11:34:29 2022 +0100
Make DbApiHook use get_uri from Connection (#21764)
DbApiHook has its own get_uri method, which handles neither
quoting nor empty passwords properly. Connection also has a
get_uri method that handles both of these cases correctly.
This also fixes RFC 3986 compliance issues.
---
airflow/hooks/dbapi.py | 11 +---
airflow/models/connection.py | 6 +++
airflow/providers/mysql/hooks/mysql.py | 8 ---
airflow/providers/postgres/hooks/postgres.py | 8 +--
tests/hooks/test_dbapi.py | 61 +++++++++++++++++++----
tests/providers/amazon/aws/hooks/test_base_aws.py | 1 +
tests/providers/snowflake/hooks/test_snowflake.py | 1 +
7 files changed, 65 insertions(+), 31 deletions(-)
diff --git a/airflow/hooks/dbapi.py b/airflow/hooks/dbapi.py
index f86fe26..933d4f2 100644
--- a/airflow/hooks/dbapi.py
+++ b/airflow/hooks/dbapi.py
@@ -18,7 +18,6 @@
from contextlib import closing
from datetime import datetime
from typing import Any, Optional
-from urllib.parse import quote_plus, urlunsplit
from sqlalchemy import create_engine
@@ -96,14 +95,8 @@ class DbApiHook(BaseHook):
:return: the extracted uri.
"""
conn = self.get_connection(getattr(self, self.conn_name_attr))
- login = ''
- if conn.login:
- login = f'{quote_plus(conn.login)}:{quote_plus(conn.password)}@'
- host = conn.host
- if conn.port is not None:
- host += f':{conn.port}'
- schema = self.__schema or conn.schema or ''
- return urlunsplit((conn.conn_type, f'{login}{host}', schema, '', ''))
+ conn.schema = self.__schema or conn.schema
+ return conn.get_uri()
def get_sqlalchemy_engine(self, engine_kwargs=None):
"""
diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index b6cde8d..6ef071f 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -173,6 +173,12 @@ class Connection(Base, LoggingMixin):
def get_uri(self) -> str:
"""Return connection in URI format"""
+ if '_' in self.conn_type:
+ self.log.warning(
+ f"Connection schemes (type: {str(self.conn_type)}) "
+ f"shall not contain '_' according to RFC3986."
+ )
+
uri = f"{str(self.conn_type).lower().replace('_', '-')}://"
authority_block = ''
diff --git a/airflow/providers/mysql/hooks/mysql.py b/airflow/providers/mysql/hooks/mysql.py
index 47f8cb7..1f8513b 100644
--- a/airflow/providers/mysql/hooks/mysql.py
+++ b/airflow/providers/mysql/hooks/mysql.py
@@ -173,14 +173,6 @@ class MySqlHook(DbApiHook):
raise ValueError('Unknown MySQL client name provided!')
- def get_uri(self) -> str:
- conn = self.get_connection(getattr(self, self.conn_name_attr))
- uri = super().get_uri()
- if conn.extra_dejson.get('charset', False):
- charset = conn.extra_dejson["charset"]
- return f"{uri}?charset={charset}"
- return uri
-
def bulk_load(self, table: str, tmp_file: str) -> None:
"""Loads a tab-delimited file into a database table"""
conn = self.get_conn()
diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py
index a707f9c..884a1c9 100644
--- a/airflow/providers/postgres/hooks/postgres.py
+++ b/airflow/providers/postgres/hooks/postgres.py
@@ -137,11 +137,11 @@ class PostgresHook(DbApiHook):
conn.commit()
def get_uri(self) -> str:
- conn = self.get_connection(getattr(self, self.conn_name_attr))
+ """
+ Extract the URI from the connection.
+ :return: the extracted uri.
+ """
uri = super().get_uri().replace("postgres://", "postgresql://")
- if conn.extra_dejson.get('client_encoding', False):
- charset = conn.extra_dejson["client_encoding"]
- return f"{uri}?client_encoding={charset}"
return uri
def bulk_load(self, table: str, tmp_file: str) -> None:
diff --git a/tests/hooks/test_dbapi.py b/tests/hooks/test_dbapi.py
index 97e2c4a..81a63de 100644
--- a/tests/hooks/test_dbapi.py
+++ b/tests/hooks/test_dbapi.py
@@ -150,7 +150,7 @@ class TestDbApiHook(unittest.TestCase):
def test_get_uri_schema_not_none(self):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
- conn_type="conn_type",
+ conn_type="conn-type",
host="host",
login="login",
password="password",
@@ -158,12 +158,12 @@ class TestDbApiHook(unittest.TestCase):
port=1,
)
)
- assert "conn_type://login:password@host:1/schema" == self.db_hook.get_uri()
+ assert "conn-type://login:password@host:1/schema" == self.db_hook.get_uri()
def test_get_uri_schema_override(self):
self.db_hook_schema_override.get_connection = mock.MagicMock(
return_value=Connection(
- conn_type="conn_type",
+ conn_type="conn-type",
host="host",
login="login",
password="password",
@@ -171,28 +171,69 @@ class TestDbApiHook(unittest.TestCase):
port=1,
)
)
- assert "conn_type://login:password@host:1/schema-override" == self.db_hook_schema_override.get_uri()
+ assert "conn-type://login:password@host:1/schema-override" == self.db_hook_schema_override.get_uri()
def test_get_uri_schema_none(self):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
- conn_type="conn_type", host="host", login="login", password="password", schema=None, port=1
+ conn_type="conn-type", host="host", login="login", password="password", schema=None, port=1
)
)
- assert "conn_type://login:password@host:1" == self.db_hook.get_uri()
+ assert "conn-type://login:password@host:1" == self.db_hook.get_uri()
def test_get_uri_special_characters(self):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
- conn_type="conn_type",
+ conn_type="conn-type",
+ host="host/",
+ login="lo/gi#! n",
+ password="pass*! word/",
+ schema="schema/",
+ port=1,
+ )
+ )
+ assert (
+ "conn-type://lo%2Fgi%23%21%20n:pass%2A%21%20word%2F@host%2F:1/schema%2F" == self.db_hook.get_uri()
+ )
+
+ def test_get_uri_login_none(self):
+ self.db_hook.get_connection = mock.MagicMock(
+ return_value=Connection(
+ conn_type="conn-type",
+ host="host",
+ login=None,
+ password="password",
+ schema="schema",
+ port=1,
+ )
+ )
+ assert "conn-type://:password@host:1/schema" == self.db_hook.get_uri()
+
+ def test_get_uri_password_none(self):
+ self.db_hook.get_connection = mock.MagicMock(
+ return_value=Connection(
+ conn_type="conn-type",
+ host="host",
+ login="login",
+ password=None,
+ schema="schema",
+ port=1,
+ )
+ )
+ assert "conn-type://login@host:1/schema" == self.db_hook.get_uri()
+
+ def test_get_uri_authority_none(self):
+ self.db_hook.get_connection = mock.MagicMock(
+ return_value=Connection(
+ conn_type="conn-type",
host="host",
- login="logi#! n",
- password="pass*! word",
+ login=None,
+ password=None,
schema="schema",
port=1,
)
)
- assert "conn_type://logi%23%21+n:pass%2A%21+word@host:1/schema" == self.db_hook.get_uri()
+ assert "conn-type://host:1/schema" == self.db_hook.get_uri()
def test_run_log(self):
statement = 'SQL'
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 0000136..7c9512f 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -391,6 +391,7 @@ class TestAwsBaseHook(unittest.TestCase):
}
)
)
+ mock_connection.conn_type = 'aws'
# Store original __import__
orig_import = __import__
diff --git a/tests/providers/snowflake/hooks/test_snowflake.py b/tests/providers/snowflake/hooks/test_snowflake.py
index 650fb00..ec7fb2d 100644
--- a/tests/providers/snowflake/hooks/test_snowflake.py
+++ b/tests/providers/snowflake/hooks/test_snowflake.py
@@ -34,6 +34,7 @@ _PASSWORD = 'snowflake42'
BASE_CONNECTION_KWARGS: Dict = {
'login': 'user',
+ 'conn_type': 'snowflake',
'password': 'pw',
'schema': 'public',
'extra': {