You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original) is not reproduced in this version.
Posted to commits@airflow.apache.org by po...@apache.org on 2023/11/05 15:07:22 UTC
(airflow) branch main updated: fix(providers/microsoft): setting use_async=True for get_async_default_azure_credential (#35432)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2980eb137d fix(providers/microsoft): setting use_async=True for get_async_default_azure_credential (#35432)
2980eb137d is described below
commit 2980eb137d518d071aaec4f849a6dbbe5e1724cb
Author: Wei Lee <we...@gmail.com>
AuthorDate: Sun Nov 5 23:07:13 2023 +0800
fix(providers/microsoft): setting use_async=True for get_async_default_azure_credential (#35432)
---
airflow/providers/microsoft/azure/utils.py | 6 +++---
tests/providers/microsoft/azure/test_utils.py | 16 ++++++++++++++++
2 files changed, 19 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/microsoft/azure/utils.py b/airflow/providers/microsoft/azure/utils.py
index 1b738ed957..a7a7e38966 100644
--- a/airflow/providers/microsoft/azure/utils.py
+++ b/airflow/providers/microsoft/azure/utils.py
@@ -59,8 +59,8 @@ def get_field(*, conn_id: str, conn_type: str, extras: dict, field_name: str):
def _get_default_azure_credential(
*,
- managed_identity_client_id: str | None,
- workload_identity_tenant_id: str | None,
+ managed_identity_client_id: str | None = None,
+ workload_identity_tenant_id: str | None = None,
use_async: bool = False,
) -> DefaultAzureCredential | AsyncDefaultAzureCredential:
"""Get DefaultAzureCredential based on provided arguments.
@@ -88,7 +88,7 @@ get_sync_default_azure_credential: partial[DefaultAzureCredential] = partial(
get_async_default_azure_credential: partial[AsyncDefaultAzureCredential] = partial(
_get_default_azure_credential, # type: ignore[arg-type]
- use_async=False,
+ use_async=True,
)
diff --git a/tests/providers/microsoft/azure/test_utils.py b/tests/providers/microsoft/azure/test_utils.py
index 5a081441ca..f04acaab13 100644
--- a/tests/providers/microsoft/azure/test_utils.py
+++ b/tests/providers/microsoft/azure/test_utils.py
@@ -25,7 +25,10 @@ import pytest
from airflow.providers.microsoft.azure.utils import (
AzureIdentityCredentialAdapter,
add_managed_identity_connection_widgets,
+ get_async_default_azure_credential,
get_field,
+ # _get_default_azure_credential
+ get_sync_default_azure_credential,
)
MODULE = "airflow.providers.microsoft.azure.utils"
@@ -77,6 +80,19 @@ def test_add_managed_identity_connection_widgets():
assert "workload_identity_tenant_id" in widgets
+@mock.patch(f"{MODULE}.DefaultAzureCredential")
+def test_get_sync_default_azure_credential(mock_default_azure_credential):
+ get_sync_default_azure_credential()
+
+ assert mock_default_azure_credential.called
+
+
+@mock.patch(f"{MODULE}.AsyncDefaultAzureCredential")
+def test_get_async_default_azure_credential(mock_default_azure_credential):
+ get_async_default_azure_credential()
+ assert mock_default_azure_credential.called
+
+
class TestAzureIdentityCredentialAdapter:
@mock.patch(f"{MODULE}.PipelineRequest")
@mock.patch(f"{MODULE}.BearerTokenCredentialPolicy")