You are viewing a plain text version of this content. (The canonical link to the original message was not preserved in this plain-text copy.)
Posted to commits@airflow.apache.org by po...@apache.org on 2022/06/03 17:46:04 UTC

[airflow] branch main updated: Pass connection extra parameters to wasb BlobServiceClient (#24154)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 389e858d93 Pass connection extra parameters to wasb BlobServiceClient (#24154)
389e858d93 is described below

commit 389e858d934a7813c7f15ab4e46df33c5720e415
Author: Tanel Kiis <ta...@users.noreply.github.com>
AuthorDate: Fri Jun 3 20:45:57 2022 +0300

    Pass connection extra parameters to wasb BlobServiceClient (#24154)
---
 airflow/providers/microsoft/azure/hooks/wasb.py    | 28 ++++++-----
 tests/providers/microsoft/azure/hooks/test_wasb.py | 55 ++++++++++++++++++----
 2 files changed, 61 insertions(+), 22 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py
index 0bcd9952fc..44d6c569a1 100644
--- a/airflow/providers/microsoft/azure/hooks/wasb.py
+++ b/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -121,31 +121,33 @@ class WasbHook(BaseHook):
             # Here we use anonymous public read
             # more info
             # https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
-            return BlobServiceClient(account_url=conn.host)
+            return BlobServiceClient(account_url=conn.host, **extra)
 
-        if extra.get('connection_string') or extra.get('extra__wasb__connection_string'):
+        connection_string = extra.pop('connection_string', extra.pop('extra__wasb__connection_string', None))
+        if connection_string:
             # connection_string auth takes priority
-            connection_string = extra.get('connection_string') or extra.get('extra__wasb__connection_string')
-            return BlobServiceClient.from_connection_string(connection_string)
-        if extra.get('shared_access_key') or extra.get('extra__wasb__shared_access_key'):
-            shared_access_key = extra.get('shared_access_key') or extra.get('extra__wasb__shared_access_key')
+            return BlobServiceClient.from_connection_string(connection_string, **extra)
+
+        shared_access_key = extra.pop('shared_access_key', extra.pop('extra__wasb__shared_access_key', None))
+        if shared_access_key:
             # using shared access key
-            return BlobServiceClient(account_url=conn.host, credential=shared_access_key)
-        if extra.get('tenant_id') or extra.get('extra__wasb__tenant_id'):
+            return BlobServiceClient(account_url=conn.host, credential=shared_access_key, **extra)
+
+        tenant = extra.pop('tenant_id', extra.pop('extra__wasb__tenant_id', None))
+        if tenant:
             # use Active Directory auth
             app_id = conn.login
             app_secret = conn.password
-            tenant = extra.get('tenant_id', extra.get('extra__wasb__tenant_id'))
             token_credential = ClientSecretCredential(tenant, app_id, app_secret)
-            return BlobServiceClient(account_url=conn.host, credential=token_credential)
+            return BlobServiceClient(account_url=conn.host, credential=token_credential, **extra)
 
-        sas_token = extra.get('sas_token') or extra.get('extra__wasb__sas_token')
+        sas_token = extra.pop('sas_token', extra.pop('extra__wasb__sas_token', None))
         if sas_token:
             if sas_token.startswith('https'):
-                return BlobServiceClient(account_url=sas_token)
+                return BlobServiceClient(account_url=sas_token, **extra)
             else:
                 return BlobServiceClient(
-                    account_url=f'https://{conn.login}.blob.core.windows.net/{sas_token}'
+                    account_url=f'https://{conn.login}.blob.core.windows.net/{sas_token}', **extra
                 )
 
         # Fall back to old auth (password) or use managed identity if not provided.
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py
index b258c50ced..e31a560ff0 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -52,11 +52,14 @@ class TestWasbHook:
         self.public_read_conn_id = 'pub_read_id'
         self.managed_identity_conn_id = 'managed_identity'
 
+        self.proxies = {'http': 'http_proxy_uri', 'https': 'https_proxy_uri'}
+
         db.merge_conn(
             Connection(
                 conn_id=self.public_read_conn_id,
                 conn_type=self.connection_type,
                 host='https://accountname.blob.core.windows.net',
+                extra=json.dumps({'proxies': self.proxies}),
             )
         )
 
@@ -64,7 +67,7 @@ class TestWasbHook:
             Connection(
                 conn_id=self.connection_string_id,
                 conn_type=self.connection_type,
-                extra=json.dumps({'connection_string': CONN_STRING}),
+                extra=json.dumps({'connection_string': CONN_STRING, 'proxies': self.proxies}),
             )
         )
         db.merge_conn(
@@ -72,50 +75,59 @@ class TestWasbHook:
                 conn_id=self.shared_key_conn_id,
                 conn_type=self.connection_type,
                 host='https://accountname.blob.core.windows.net',
-                extra=json.dumps({'shared_access_key': 'token'}),
+                extra=json.dumps({'shared_access_key': 'token', 'proxies': self.proxies}),
             )
         )
         db.merge_conn(
             Connection(
                 conn_id=self.ad_conn_id,
                 conn_type=self.connection_type,
-                extra=json.dumps(
-                    {'tenant_id': 'token', 'application_id': 'appID', 'application_secret': "appsecret"}
-                ),
+                host='conn_host',
+                login='appID',
+                password='appsecret',
+                extra=json.dumps({'tenant_id': 'token', 'proxies': self.proxies}),
             )
         )
         db.merge_conn(
             Connection(
                 conn_id=self.managed_identity_conn_id,
                 conn_type=self.connection_type,
+                extra=json.dumps({'proxies': self.proxies}),
             )
         )
         db.merge_conn(
             Connection(
                 conn_id=self.sas_conn_id,
                 conn_type=self.connection_type,
-                extra=json.dumps({'sas_token': 'token'}),
+                extra=json.dumps({'sas_token': 'token', 'proxies': self.proxies}),
             )
         )
         db.merge_conn(
             Connection(
                 conn_id=self.extra__wasb__sas_conn_id,
                 conn_type=self.connection_type,
-                extra=json.dumps({'extra__wasb__sas_token': 'token'}),
+                extra=json.dumps({'extra__wasb__sas_token': 'token', 'proxies': self.proxies}),
             )
         )
         db.merge_conn(
             Connection(
                 conn_id=self.http_sas_conn_id,
                 conn_type=self.connection_type,
-                extra=json.dumps({'sas_token': 'https://login.blob.core.windows.net/token'}),
+                extra=json.dumps(
+                    {'sas_token': 'https://login.blob.core.windows.net/token', 'proxies': self.proxies}
+                ),
             )
         )
         db.merge_conn(
             Connection(
                 conn_id=self.extra__wasb__http_sas_conn_id,
                 conn_type=self.connection_type,
-                extra=json.dumps({'extra__wasb__sas_token': 'https://login.blob.core.windows.net/token'}),
+                extra=json.dumps(
+                    {
+                        'extra__wasb__sas_token': 'https://login.blob.core.windows.net/token',
+                        'proxies': self.proxies,
+                    }
+                ),
             )
         )
 
@@ -160,6 +172,31 @@ class TestWasbHook:
         assert isinstance(conn, BlobServiceClient)
         assert conn.url.endswith(sas_token + '/')
 
+    @pytest.mark.parametrize(
+        argnames="conn_id_str",
+        argvalues=[
+            'connection_string_id',
+            'shared_key_conn_id',
+            'ad_conn_id',
+            'managed_identity_conn_id',
+            'sas_conn_id',
+            'extra__wasb__sas_conn_id',
+            'http_sas_conn_id',
+            'extra__wasb__http_sas_conn_id',
+        ],
+    )
+    def test_connection_extra_arguments(self, conn_id_str):
+        conn_id = self.__getattribute__(conn_id_str)
+        hook = WasbHook(wasb_conn_id=conn_id)
+        conn = hook.get_conn()
+        assert conn._config.proxy_policy.proxies == self.proxies
+
+    def test_connection_extra_arguments_public_read(self):
+        conn_id = self.public_read_conn_id
+        hook = WasbHook(wasb_conn_id=conn_id, public_read=True)
+        conn = hook.get_conn()
+        assert conn._config.proxy_policy.proxies == self.proxies
+
     @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
     def test_check_for_blob(self, mock_service):
         hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)