You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2023/11/01 09:13:55 UTC

(airflow) branch main updated: feat(provider/azure): add managed identity support to container_registry hook (#35320)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f84c45827e feat(provider/azure): add managed identity support to container_registry hook (#35320)
f84c45827e is described below

commit f84c45827e6d743d58fd01b0511bcd1b3be85f5a
Author: Wei Lee <we...@gmail.com>
AuthorDate: Wed Nov 1 17:13:42 2023 +0800

    feat(provider/azure): add managed identity support to container_registry hook (#35320)
---
 .../microsoft/azure/hooks/container_registry.py      | 18 +++++++++++++++---
 .../connections/acr.rst                              | 20 ++++++++++++++++----
 .../azure/hooks/test_azure_container_registry.py     |  4 ++--
 3 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/container_registry.py b/airflow/providers/microsoft/azure/hooks/container_registry.py
index 2b9383e5d3..ea7b129a46 100644
--- a/airflow/providers/microsoft/azure/hooks/container_registry.py
+++ b/airflow/providers/microsoft/azure/hooks/container_registry.py
@@ -21,12 +21,11 @@ from __future__ import annotations
 from functools import cached_property
 from typing import Any
 
-from azure.identity import DefaultAzureCredential
 from azure.mgmt.containerinstance.models import ImageRegistryCredential
 from azure.mgmt.containerregistry import ContainerRegistryManagementClient
 
 from airflow.hooks.base import BaseHook
-from airflow.providers.microsoft.azure.utils import get_field
+from airflow.providers.microsoft.azure.utils import get_default_azure_credential, get_field
 
 
 class AzureContainerRegistryHook(BaseHook):
@@ -59,6 +58,12 @@ class AzureContainerRegistryHook(BaseHook):
                 lazy_gettext("Resource group name (optional)"),
                 widget=BS3TextFieldWidget(),
             ),
+            "managed_identity_client_id": StringField(
+                lazy_gettext("Managed Identity Client ID"), widget=BS3TextFieldWidget()
+            ),
+            "workload_identity_tenant_id": StringField(
+                lazy_gettext("Workload Identity Tenant ID"), widget=BS3TextFieldWidget()
+            ),
         }
 
     @classmethod
@@ -77,6 +82,8 @@ class AzureContainerRegistryHook(BaseHook):
                 "host": "docker image registry server",
                 "subscription_id": "Subscription id (required for Azure AD authentication)",
                 "resource_group": "Resource group name (required for Azure AD authentication)",
+                "managed_identity_client_id": "Managed Identity Client ID",
+                "workload_identity_tenant_id": "Workload Identity Tenant ID",
             },
         }
 
@@ -103,8 +110,13 @@ class AzureContainerRegistryHook(BaseHook):
             extras = conn.extra_dejson
             subscription_id = self._get_field(extras, "subscription_id")
             resource_group = self._get_field(extras, "resource_group")
+            managed_identity_client_id = self._get_field(extras, "managed_identity_client_id")
+            workload_identity_tenant_id = self._get_field(extras, "workload_identity_tenant_id")
             client = ContainerRegistryManagementClient(
-                credential=DefaultAzureCredential(), subscription_id=subscription_id
+                credential=get_default_azure_credential(
+                    managed_identity_client_id, workload_identity_tenant_id
+                ),
+                subscription_id=subscription_id,
             )
             credentials = client.registries.list_credentials(resource_group, conn.login).as_dict()
             password = credentials["passwords"][0]["value"]
diff --git a/docs/apache-airflow-providers-microsoft-azure/connections/acr.rst b/docs/apache-airflow-providers-microsoft-azure/connections/acr.rst
index 913fcf7956..c539d8bfef 100644
--- a/docs/apache-airflow-providers-microsoft-azure/connections/acr.rst
+++ b/docs/apache-airflow-providers-microsoft-azure/connections/acr.rst
@@ -27,13 +27,13 @@ The Microsoft Azure Container Registry connection type enables the Azure Contain
 Authenticating to Azure Container Registry
 ------------------------------------------
 
-There is one way to connect to Azure Container Registry using Airflow.
+There are three ways to connect to Azure Container Registry using Airflow.
 
 1. Use `Individual login with Azure AD
    <https://docs.microsoft.com/en-us/azure/container-registry/container-registry-authentication#individual-login-with-azure-ad>`_
    i.e. add specific credentials to the Airflow connection.
-2. Fallback on `DefaultAzureCredential
-   <https://docs.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#defaultazurecredential>`_.
+2. Use managed identity by setting ``managed_identity_client_id`` and ``workload_identity_tenant_id`` (under the hood, it uses DefaultAzureCredential_ with these arguments).
+3. Fall back on DefaultAzureCredential_.
    This includes a mechanism to try different options to authenticate: Managed System Identity, environment variables, authentication through Azure CLI...
 
 Default Connection IDs
@@ -48,7 +48,7 @@ Login
     Specify the Image Registry Username used for the initial connection.
 
 Password (optional)
-    Specify the Image Registry Password used for the initial connection. It can be left out to fall back on ``DefaultAzureCredential``.
+    Specify the Image Registry Password used for the initial connection. It can be left out to fall back on DefaultAzureCredential_.
 
 Host
     Specify the Image Registry Server used for the initial connection.
@@ -63,6 +63,13 @@ Resource Group Name (optional)
     This is needed for Azure Active Directory (Azure AD) authentication.
     Use extra param ``resource_group`` to pass in the resource group name.
 
+Managed Identity Client ID (optional)
+    The client ID of a user-assigned managed identity. If provided together with ``workload_identity_tenant_id``, both values are passed to DefaultAzureCredential_.
+
+Workload Identity Tenant ID (optional)
+    ID of the application's Microsoft Entra tenant, also called its "directory" ID. If provided together with ``managed_identity_client_id``, both values are passed to DefaultAzureCredential_.
+
+
 When specifying the connection in environment variable you should specify
 it using URI syntax.
 
@@ -73,3 +80,8 @@ For example:
 .. code-block:: bash
 
     export AIRFLOW_CONN_AZURE_CONTAINER_REGISTRY_DEFAULT='azure-container-registry://username:password@myregistry.com?tenant=tenant+id&account_name=store+name'
+
+.. _DefaultAzureCredential: https://docs.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#defaultazurecredential
+
+.. spelling:word-list::
+    Entra
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
index a2b0635749..063a1290d2 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py
@@ -63,7 +63,7 @@ class TestAzureContainerRegistryHook:
     @mock.patch(
         "airflow.providers.microsoft.azure.hooks.container_registry.ContainerRegistryManagementClient"
     )
-    @mock.patch("airflow.providers.microsoft.azure.hooks.container_registry.DefaultAzureCredential")
+    @mock.patch("airflow.providers.microsoft.azure.hooks.container_registry.get_default_azure_credential")
     def test_get_conn_with_default_azure_credential(
         self, mocked_default_azure_credential, mocked_client, mocked_connection
     ):
@@ -80,4 +80,4 @@ class TestAzureContainerRegistryHook:
         assert hook.connection.password == "password"
         assert hook.connection.server == "test.cr"
 
-        mocked_default_azure_credential.assert_called_with()
+        mocked_default_azure_credential.assert_called_with(None, None)