You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2023/08/25 14:21:09 UTC

[airflow] branch main updated: Fix Azure Batch Hook instantation (#33731)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 38f2737153 Fix Azure Batch Hook instantation (#33731)
38f2737153 is described below

commit 38f27371532b9f906bdeff0251d1c35956daf05c
Author: Jarek Potiuk <ja...@potiuk.com>
AuthorDate: Fri Aug 25 16:21:01 2023 +0200

    Fix Azure Batch Hook instantation (#33731)
    
    The Hook for Azure Batch was instantiated in the operator's
    constructor, which is wrong: instantiating the hook requires the
    connection to exist. This was detected when #33716 added an
    example DAG and provider imports started failing because the
    connection needed to instantiate the hook was missing.
    
    The hook instantiation is now moved to a cached property, so the
    hook is only created when it is first accessed.
---
 airflow/providers/microsoft/azure/operators/batch.py          | 6 +++++-
 tests/providers/microsoft/azure/operators/test_azure_batch.py | 1 +
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/microsoft/azure/operators/batch.py b/airflow/providers/microsoft/azure/operators/batch.py
index 63b925a981..e26f56dd6e 100644
--- a/airflow/providers/microsoft/azure/operators/batch.py
+++ b/airflow/providers/microsoft/azure/operators/batch.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
 from azure.batch import models as batch_models
@@ -176,7 +177,10 @@ class AzureBatchOperator(BaseOperator):
         self.timeout = timeout
         self.should_delete_job = should_delete_job
         self.should_delete_pool = should_delete_pool
-        self.hook = self.get_hook()
+
+    @cached_property
+    def hook(self):
+        return self.get_hook()
 
     def _check_inputs(self) -> Any:
         if not self.os_family and not self.vm_publisher:
diff --git a/tests/providers/microsoft/azure/operators/test_azure_batch.py b/tests/providers/microsoft/azure/operators/test_azure_batch.py
index 0e3947732b..e920f7c1d9 100644
--- a/tests/providers/microsoft/azure/operators/test_azure_batch.py
+++ b/tests/providers/microsoft/azure/operators/test_azure_batch.py
@@ -162,6 +162,7 @@ class TestAzureBatchOperator:
         self.batch_client = mock_batch.return_value
         self.mock_instance = mock_hook.return_value
         assert self.batch_client == self.operator.hook.connection
+        assert self.batch_client == self.operator2_pass.hook.connection
 
     @mock.patch.object(AzureBatchHook, "wait_for_all_node_state")
     def test_execute_without_failures(self, wait_mock):