You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by "josh-fell (via GitHub)" <gi...@apache.org> on 2023/03/03 19:11:02 UTC

[GitHub] [airflow] josh-fell commented on a diff in pull request #29801: Add deferrable `AzureDataFactoryPipelineRunStatusSensor`

josh-fell commented on code in PR #29801:
URL: https://github.com/apache/airflow/pull/29801#discussion_r1124899383


##########
airflow/providers/microsoft/azure/hooks/data_factory.py:
##########
@@ -1039,3 +1048,120 @@ def test_connection(self) -> tuple[bool, str]:
             return success
         except Exception as e:
             return False, str(e)
+
+
+def provide_targeted_factory_async(func: T) -> T:
+    """
+    Provide the targeted factory to the async decorated function in case it isn't specified.
+
+    If ``resource_group_name`` or ``factory_name`` is not provided it defaults to the value specified in
+    the connection extras.
+    """
+    signature = inspect.signature(func)
+
+    @wraps(func)
+    async def wrapper(*args: Any, **kwargs: Any) -> Any:
+        bound_args = signature.bind(*args, **kwargs)
+
+        async def bind_argument(arg: Any, default_key: str) -> None:
+            # Check if arg was not included in the function signature or, if it is, the value is not provided.
+            if arg not in bound_args.arguments or bound_args.arguments[arg] is None:
+                self = args[0]
+                conn = await sync_to_async(self.get_connection)(self.conn_id)
+                default_value = conn.extra_dejson.get(default_key)

Review Comment:
   Can you update this to match the sync version? Prefixing connection extras with `extra__...` is no longer needed, but the sync version still checks for the prefixed key as a fallback for backwards compatibility.



##########
airflow/providers/microsoft/azure/hooks/data_factory.py:
##########
@@ -1039,3 +1048,120 @@ def test_connection(self) -> tuple[bool, str]:
             return success
         except Exception as e:
             return False, str(e)
+
+
+def provide_targeted_factory_async(func: T) -> T:
+    """
+    Provide the targeted factory to the async decorated function in case it isn't specified.
+
+    If ``resource_group_name`` or ``factory_name`` is not provided it defaults to the value specified in
+    the connection extras.
+    """
+    signature = inspect.signature(func)
+
+    @wraps(func)
+    async def wrapper(*args: Any, **kwargs: Any) -> Any:
+        bound_args = signature.bind(*args, **kwargs)
+
+        async def bind_argument(arg: Any, default_key: str) -> None:
+            # Check if arg was not included in the function signature or, if it is, the value is not provided.
+            if arg not in bound_args.arguments or bound_args.arguments[arg] is None:
+                self = args[0]
+                conn = await sync_to_async(self.get_connection)(self.conn_id)
+                default_value = conn.extra_dejson.get(default_key)
+                if not default_value:
+                    raise AirflowException("Could not determine the targeted data factory.")
+
+                bound_args.arguments[arg] = conn.extra_dejson[default_key]
+
+        await bind_argument("resource_group_name", "extra__azure_data_factory__resource_group_name")
+        await bind_argument("factory_name", "extra__azure_data_factory__factory_name")
+
+        return await func(*bound_args.args, **bound_args.kwargs)
+
+    return cast(T, wrapper)
+
+
+class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
+    """
+    An Async Hook that connects to Azure DataFactory to perform pipeline operations
+
+    :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection id<howto/connection:adf>`.
+    """
+
+    def __init__(self, azure_data_factory_conn_id: str):
+        self._async_conn: AsyncDataFactoryManagementClient = None
+        self.conn_id = azure_data_factory_conn_id
+        super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)
+
+    async def get_async_conn(self) -> AsyncDataFactoryManagementClient:
+        """Get async connection and connect to azure data factory"""
+        if self._conn is not None:
+            return self._conn
+
+        conn = await sync_to_async(self.get_connection)(self.conn_id)
+        tenant = conn.extra_dejson.get("extra__azure_data_factory__tenantId")

Review Comment:
   Same idea here: this should use the `get_field()` function to retrieve the extras, so both prefixed and non-prefixed keys are handled.



##########
tests/providers/microsoft/azure/hooks/test_azure_data_factory.py:
##########
@@ -708,3 +723,214 @@ def test_backcompat_prefix_both_prefers_short(mock_connect):
         hook = AzureDataFactoryHook("my_conn")
         hook.delete_factory(factory_name="n/a")
         mock_connect.return_value.factories.delete.assert_called_with("non-prefixed", "n/a")
+
+
+class TestAzureDataFactoryAsyncHook:
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(

Review Comment:
   Why parametrize if each status is its own explicit test? If you keep the parametrization, you can label each case by passing `ids`, e.g. `pytest.mark.parametrize(..., ids=["test_status1", "test_status2"])`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org