You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2023/11/05 15:07:22 UTC

(airflow) branch main updated: fix(providers/microsoft): setting use_async=True for get_async_default_azure_credential (#35432)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2980eb137d fix(providers/microsoft): setting use_async=True for get_async_default_azure_credential (#35432)
2980eb137d is described below

commit 2980eb137d518d071aaec4f849a6dbbe5e1724cb
Author: Wei Lee <we...@gmail.com>
AuthorDate: Sun Nov 5 23:07:13 2023 +0800

    fix(providers/microsoft): setting use_async=True for get_async_default_azure_credential (#35432)
---
 airflow/providers/microsoft/azure/utils.py    |  6 +++---
 tests/providers/microsoft/azure/test_utils.py | 16 ++++++++++++++++
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/microsoft/azure/utils.py b/airflow/providers/microsoft/azure/utils.py
index 1b738ed957..a7a7e38966 100644
--- a/airflow/providers/microsoft/azure/utils.py
+++ b/airflow/providers/microsoft/azure/utils.py
@@ -59,8 +59,8 @@ def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str):
 
 def _get_default_azure_credential(
     *,
-    managed_identity_client_id: str | None,
-    workload_identity_tenant_id: str | None,
+    managed_identity_client_id: str | None = None,
+    workload_identity_tenant_id: str | None = None,
     use_async: bool = False,
 ) -> DefaultAzureCredential | AsyncDefaultAzureCredential:
     """Get DefaultAzureCredential based on provided arguments.
@@ -88,7 +88,7 @@ get_sync_default_azure_credential: partial[DefaultAzureCredential] = partial(
 
 get_async_default_azure_credential: partial[AsyncDefaultAzureCredential] = partial(
     _get_default_azure_credential,  #  type: ignore[arg-type]
-    use_async=False,
+    use_async=True,
 )
 
 
diff --git a/tests/providers/microsoft/azure/test_utils.py b/tests/providers/microsoft/azure/test_utils.py
index 5a081441ca..f04acaab13 100644
--- a/tests/providers/microsoft/azure/test_utils.py
+++ b/tests/providers/microsoft/azure/test_utils.py
@@ -25,7 +25,10 @@ import pytest
 from airflow.providers.microsoft.azure.utils import (
     AzureIdentityCredentialAdapter,
     add_managed_identity_connection_widgets,
+    get_async_default_azure_credential,
     get_field,
+    # _get_default_azure_credential
+    get_sync_default_azure_credential,
 )
 
 MODULE = "airflow.providers.microsoft.azure.utils"
@@ -77,6 +80,19 @@ def test_add_managed_identity_connection_widgets():
     assert "workload_identity_tenant_id" in widgets
 
 
+@mock.patch(f"{MODULE}.DefaultAzureCredential")
+def test_get_sync_default_azure_credential(mock_default_azure_credential):
+    get_sync_default_azure_credential()
+
+    assert mock_default_azure_credential.called
+
+
+@mock.patch(f"{MODULE}.AsyncDefaultAzureCredential")
+def test_get_async_default_azure_credential(mock_default_azure_credential):
+    get_async_default_azure_credential()
+    assert mock_default_azure_credential.called
+
+
 class TestAzureIdentityCredentialAdapter:
     @mock.patch(f"{MODULE}.PipelineRequest")
     @mock.patch(f"{MODULE}.BearerTokenCredentialPolicy")