You are viewing a plain-text version of this content; the canonical link to it was not preserved in this text.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/06/03 17:46:04 UTC
[airflow] branch main updated: Pass connection extra parameters to wasb BlobServiceClient (#24154)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 389e858d93 Pass connection extra parameters to wasb BlobServiceClient (#24154)
389e858d93 is described below
commit 389e858d934a7813c7f15ab4e46df33c5720e415
Author: Tanel Kiis <ta...@users.noreply.github.com>
AuthorDate: Fri Jun 3 20:45:57 2022 +0300
Pass connection extra parameters to wasb BlobServiceClient (#24154)
---
airflow/providers/microsoft/azure/hooks/wasb.py | 28 ++++++-----
tests/providers/microsoft/azure/hooks/test_wasb.py | 55 ++++++++++++++++++----
2 files changed, 61 insertions(+), 22 deletions(-)
diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py
index 0bcd9952fc..44d6c569a1 100644
--- a/airflow/providers/microsoft/azure/hooks/wasb.py
+++ b/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -121,31 +121,33 @@ class WasbHook(BaseHook):
# Here we use anonymous public read
# more info
# https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
- return BlobServiceClient(account_url=conn.host)
+ return BlobServiceClient(account_url=conn.host, **extra)
- if extra.get('connection_string') or extra.get('extra__wasb__connection_string'):
+ connection_string = extra.pop('connection_string', extra.pop('extra__wasb__connection_string', None))
+ if connection_string:
# connection_string auth takes priority
- connection_string = extra.get('connection_string') or extra.get('extra__wasb__connection_string')
- return BlobServiceClient.from_connection_string(connection_string)
- if extra.get('shared_access_key') or extra.get('extra__wasb__shared_access_key'):
- shared_access_key = extra.get('shared_access_key') or extra.get('extra__wasb__shared_access_key')
+ return BlobServiceClient.from_connection_string(connection_string, **extra)
+
+ shared_access_key = extra.pop('shared_access_key', extra.pop('extra__wasb__shared_access_key', None))
+ if shared_access_key:
# using shared access key
- return BlobServiceClient(account_url=conn.host, credential=shared_access_key)
- if extra.get('tenant_id') or extra.get('extra__wasb__tenant_id'):
+ return BlobServiceClient(account_url=conn.host, credential=shared_access_key, **extra)
+
+ tenant = extra.pop('tenant_id', extra.pop('extra__wasb__tenant_id', None))
+ if tenant:
# use Active Directory auth
app_id = conn.login
app_secret = conn.password
- tenant = extra.get('tenant_id', extra.get('extra__wasb__tenant_id'))
token_credential = ClientSecretCredential(tenant, app_id, app_secret)
- return BlobServiceClient(account_url=conn.host, credential=token_credential)
+ return BlobServiceClient(account_url=conn.host, credential=token_credential, **extra)
- sas_token = extra.get('sas_token') or extra.get('extra__wasb__sas_token')
+ sas_token = extra.pop('sas_token', extra.pop('extra__wasb__sas_token', None))
if sas_token:
if sas_token.startswith('https'):
- return BlobServiceClient(account_url=sas_token)
+ return BlobServiceClient(account_url=sas_token, **extra)
else:
return BlobServiceClient(
- account_url=f'https://{conn.login}.blob.core.windows.net/{sas_token}'
+ account_url=f'https://{conn.login}.blob.core.windows.net/{sas_token}', **extra
)
# Fall back to old auth (password) or use managed identity if not provided.
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py
index b258c50ced..e31a560ff0 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -52,11 +52,14 @@ class TestWasbHook:
self.public_read_conn_id = 'pub_read_id'
self.managed_identity_conn_id = 'managed_identity'
+ self.proxies = {'http': 'http_proxy_uri', 'https': 'https_proxy_uri'}
+
db.merge_conn(
Connection(
conn_id=self.public_read_conn_id,
conn_type=self.connection_type,
host='https://accountname.blob.core.windows.net',
+ extra=json.dumps({'proxies': self.proxies}),
)
)
@@ -64,7 +67,7 @@ class TestWasbHook:
Connection(
conn_id=self.connection_string_id,
conn_type=self.connection_type,
- extra=json.dumps({'connection_string': CONN_STRING}),
+ extra=json.dumps({'connection_string': CONN_STRING, 'proxies': self.proxies}),
)
)
db.merge_conn(
@@ -72,50 +75,59 @@ class TestWasbHook:
conn_id=self.shared_key_conn_id,
conn_type=self.connection_type,
host='https://accountname.blob.core.windows.net',
- extra=json.dumps({'shared_access_key': 'token'}),
+ extra=json.dumps({'shared_access_key': 'token', 'proxies': self.proxies}),
)
)
db.merge_conn(
Connection(
conn_id=self.ad_conn_id,
conn_type=self.connection_type,
- extra=json.dumps(
- {'tenant_id': 'token', 'application_id': 'appID', 'application_secret': "appsecret"}
- ),
+ host='conn_host',
+ login='appID',
+ password='appsecret',
+ extra=json.dumps({'tenant_id': 'token', 'proxies': self.proxies}),
)
)
db.merge_conn(
Connection(
conn_id=self.managed_identity_conn_id,
conn_type=self.connection_type,
+ extra=json.dumps({'proxies': self.proxies}),
)
)
db.merge_conn(
Connection(
conn_id=self.sas_conn_id,
conn_type=self.connection_type,
- extra=json.dumps({'sas_token': 'token'}),
+ extra=json.dumps({'sas_token': 'token', 'proxies': self.proxies}),
)
)
db.merge_conn(
Connection(
conn_id=self.extra__wasb__sas_conn_id,
conn_type=self.connection_type,
- extra=json.dumps({'extra__wasb__sas_token': 'token'}),
+ extra=json.dumps({'extra__wasb__sas_token': 'token', 'proxies': self.proxies}),
)
)
db.merge_conn(
Connection(
conn_id=self.http_sas_conn_id,
conn_type=self.connection_type,
- extra=json.dumps({'sas_token': 'https://login.blob.core.windows.net/token'}),
+ extra=json.dumps(
+ {'sas_token': 'https://login.blob.core.windows.net/token', 'proxies': self.proxies}
+ ),
)
)
db.merge_conn(
Connection(
conn_id=self.extra__wasb__http_sas_conn_id,
conn_type=self.connection_type,
- extra=json.dumps({'extra__wasb__sas_token': 'https://login.blob.core.windows.net/token'}),
+ extra=json.dumps(
+ {
+ 'extra__wasb__sas_token': 'https://login.blob.core.windows.net/token',
+ 'proxies': self.proxies,
+ }
+ ),
)
)
@@ -160,6 +172,31 @@ class TestWasbHook:
assert isinstance(conn, BlobServiceClient)
assert conn.url.endswith(sas_token + '/')
+ @pytest.mark.parametrize(
+ argnames="conn_id_str",
+ argvalues=[
+ 'connection_string_id',
+ 'shared_key_conn_id',
+ 'ad_conn_id',
+ 'managed_identity_conn_id',
+ 'sas_conn_id',
+ 'extra__wasb__sas_conn_id',
+ 'http_sas_conn_id',
+ 'extra__wasb__http_sas_conn_id',
+ ],
+ )
+ def test_connection_extra_arguments(self, conn_id_str):
+ conn_id = self.__getattribute__(conn_id_str)
+ hook = WasbHook(wasb_conn_id=conn_id)
+ conn = hook.get_conn()
+ assert conn._config.proxy_policy.proxies == self.proxies
+
+ def test_connection_extra_arguments_public_read(self):
+ conn_id = self.public_read_conn_id
+ hook = WasbHook(wasb_conn_id=conn_id, public_read=True)
+ conn = hook.get_conn()
+ assert conn._config.proxy_policy.proxies == self.proxies
+
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
def test_check_for_blob(self, mock_service):
hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)