Posted to commits@airflow.apache.org by po...@apache.org on 2023/06/27 22:52:47 UTC

[airflow] branch main updated: Add a deferrable mode to `BatchCreateComputeEnvironmentOperator` (#32036)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1fb2831239 Add a deferrable mode to `BatchCreateComputeEnvironmentOperator` (#32036)
1fb2831239 is described below

commit 1fb28312393a59bb064e4a1cade59de5c86ef16a
Author: Raphaël Vandon <va...@amazon.com>
AuthorDate: Tue Jun 27 15:52:41 2023 -0700

    Add a deferrable mode to `BatchCreateComputeEnvironmentOperator` (#32036)
---
 airflow/providers/amazon/aws/operators/batch.py    | 52 +++++++++++++++-----
 airflow/providers/amazon/aws/triggers/batch.py     | 57 ++++++++++++++++++++++
 airflow/providers/amazon/aws/waiters/batch.json    | 26 ++++++++++
 tests/providers/amazon/aws/operators/test_batch.py | 28 ++++++++++-
 tests/providers/amazon/aws/triggers/test_batch.py  | 43 +++++++++++++++-
 5 files changed, 193 insertions(+), 13 deletions(-)
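
For orientation before the diff: the new flag is opt-in and defaults to False. A minimal usage sketch follows; the DAG name, schedule, and the compute_resources payload are illustrative assumptions, not part of this commit.

    # Hypothetical DAG snippet exercising the new deferrable flag.
    import pendulum

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.batch import (
        BatchCreateComputeEnvironmentOperator,
    )

    with DAG(
        dag_id="batch_env_example",
        start_date=pendulum.datetime(2023, 6, 1, tz="UTC"),
        schedule=None,
    ) as dag:
        create_env = BatchCreateComputeEnvironmentOperator(
            task_id="create_compute_env",
            compute_environment_name="my-env",  # illustrative name
            environment_type="MANAGED",
            state="ENABLED",
            compute_resources={  # illustrative Fargate config
                "type": "FARGATE",
                "maxvCpus": 10,
                "subnets": ["subnet-00000000"],
                "securityGroupIds": ["sg-00000000"],
            },
            deferrable=True,   # frees the worker slot while waiting
            poll_interval=30,  # seconds between status polls
            max_retries=120,   # stop polling after 120 attempts
        )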

diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index 88feb01311..b9b3322c49 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -38,7 +38,10 @@ from airflow.providers.amazon.aws.links.batch import (
     BatchJobQueueLink,
 )
 from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
-from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
+from airflow.providers.amazon.aws.triggers.batch import (
+    BatchCreateComputeEnvironmentTrigger,
+    BatchOperatorTrigger,
+)
 from airflow.providers.amazon.aws.utils import trim_none_values
 from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
 
@@ -402,14 +405,16 @@ class BatchCreateComputeEnvironmentOperator(BaseOperator):
         services on your behalf (templated).
     :param tags: Tags that you apply to the compute-environment to help you
         categorize and organize your resources.
-    :param max_retries: Exponential back-off retries, 4200 = 48 hours; polling
-        is only used when waiters is None.
-    :param status_retries: Number of HTTP retries to get job status, 10; polling
-        is only used when waiters is None.
+    :param poll_interval: How long to wait, in seconds, between two polls of the environment status.
+        Only used when ``deferrable`` is True.
+    :param max_retries: How many times to poll the environment status before giving up.
+        Only used when ``deferrable`` is True.
     :param aws_conn_id: Connection ID of AWS credentials / region name. If None,
         credential boto3 strategy will be used.
     :param region_name: Region name to use in AWS Hook. Overrides the
         ``region_name`` in connection if provided.
+    :param deferrable: If True, the operator will wait asynchronously for the environment to be created.
+        This mode requires the ``aiobotocore`` module to be installed. (default: False)
     """
 
     template_fields: Sequence[str] = (
@@ -428,13 +433,24 @@ class BatchCreateComputeEnvironmentOperator(BaseOperator):
         unmanaged_v_cpus: int | None = None,
         service_role: str | None = None,
         tags: dict | None = None,
+        poll_interval: int = 30,
         max_retries: int | None = None,
-        status_retries: int | None = None,
         aws_conn_id: str | None = None,
         region_name: str | None = None,
+        deferrable: bool = False,
         **kwargs,
     ):
+        if "status_retries" in kwargs:
+            warnings.warn(
+                "The `status_retries` parameter is unused and should be removed. "
+                "It'll be deleted in a future version.",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+            kwargs.pop("status_retries")  # remove before calling super() to prevent unexpected arg error
+
         super().__init__(**kwargs)
+
         self.compute_environment_name = compute_environment_name
         self.environment_type = environment_type
         self.state = state
@@ -442,17 +458,16 @@ class BatchCreateComputeEnvironmentOperator(BaseOperator):
         self.compute_resources = compute_resources
         self.service_role = service_role
         self.tags = tags or {}
-        self.max_retries = max_retries
-        self.status_retries = status_retries
+        self.poll_interval = poll_interval
+        self.max_retries = max_retries or 120
         self.aws_conn_id = aws_conn_id
         self.region_name = region_name
+        self.deferrable = deferrable
 
     @cached_property
     def hook(self):
         """Create and return a BatchClientHook."""
         return BatchClientHook(
-            max_retries=self.max_retries,
-            status_retries=self.status_retries,
             aws_conn_id=self.aws_conn_id,
             region_name=self.region_name,
         )
@@ -468,6 +483,21 @@ class BatchCreateComputeEnvironmentOperator(BaseOperator):
             "serviceRole": self.service_role,
             "tags": self.tags,
         }
-        self.hook.client.create_compute_environment(**trim_none_values(kwargs))
+        response = self.hook.client.create_compute_environment(**trim_none_values(kwargs))
+        arn = response["computeEnvironmentArn"]
+
+        if self.deferrable:
+            self.defer(
+                trigger=BatchCreateComputeEnvironmentTrigger(
+                    arn, self.poll_interval, self.max_retries, self.aws_conn_id, self.region_name
+                ),
+                method_name="execute_complete",
+            )
 
         self.log.info("AWS Batch compute environment created successfully")
+        return arn
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while waiting for the compute environment to be ready: {event}")
+        return event["value"]
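
A note on the control flow above: self.defer() raises TaskDeferred, so in deferrable mode execute() stops at the defer call and the trailing log/return lines only run on the synchronous path; the task is later resumed at execute_complete() with the trigger's event payload. A simplified sketch of that round trip (Airflow's real scheduler/triggerer machinery is more involved; operator and context here are assumed to exist):

    # Simplified illustration of the deferral round trip, not Airflow's
    # actual scheduler/triggerer code.
    import asyncio

    from airflow.exceptions import TaskDeferred

    try:
        operator.execute(context)  # raises TaskDeferred when deferrable=True
    except TaskDeferred as deferral:
        async def first_event():
            # The triggerer runs the trigger and forwards its first event.
            async for event in deferral.trigger.run():
                return event

        event = asyncio.run(first_event())
        # The worker resumes the task at the method named in the deferral.
        result = getattr(operator, deferral.method_name)(context, event.payload)
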
diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py
index f4a5de1525..b0bdbc0d45 100644
--- a/airflow/providers/amazon/aws/triggers/batch.py
+++ b/airflow/providers/amazon/aws/triggers/batch.py
@@ -23,6 +23,7 @@ from typing import Any
 from botocore.exceptions import WaiterError
 
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
@@ -188,3 +189,59 @@ class BatchSensorTrigger(BaseTrigger):
                     "message": f"Job {self.job_id} Succeeded",
                 }
             )
+
+
+class BatchCreateComputeEnvironmentTrigger(BaseTrigger):
+    """
+    Trigger for BatchCreateComputeEnvironmentOperator.
+
+    The trigger will asynchronously poll the boto3 API and wait for the compute environment to be ready.
+
+    :param compute_env_arn: The ARN of the compute environment to monitor.
+    :param max_retries: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: The region name to use in the AWS Hook.
+    :param poll_interval: The amount of time in seconds to wait between attempts.
+    """
+
+    def __init__(
+        self,
+        compute_env_arn: str | None = None,
+        poll_interval: int = 30,
+        max_retries: int = 10,
+        aws_conn_id: str | None = "aws_default",
+        region_name: str | None = None,
+    ):
+        super().__init__()
+        self.compute_env_arn = compute_env_arn
+        self.max_retries = max_retries
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.poll_interval = poll_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes BatchOperatorTrigger arguments and classpath."""
+        return (
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "compute_env_arn": self.compute_env_arn,
+                "max_retries": self.max_retries,
+                "aws_conn_id": self.aws_conn_id,
+                "region_name": self.region_name,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    async def run(self):
+        hook = BatchClientHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+        async with hook.async_conn as client:
+            waiter = hook.get_waiter("compute_env_ready", deferrable=True, client=client)
+            await async_wait(
+                waiter=waiter,
+                waiter_delay=self.poll_interval,
+                waiter_max_attempts=self.max_retries,
+                args={"computeEnvironments": [self.compute_env_arn]},
+                failure_message="Failure while creating Compute Environment",
+                status_message="Compute Environment not ready yet",
+                status_args=["computeEnvironments[].status", "computeEnvironments[].statusReason"],
+            )
+            yield TriggerEvent({"status": "success", "value": self.compute_env_arn})
diff --git a/airflow/providers/amazon/aws/waiters/batch.json b/airflow/providers/amazon/aws/waiters/batch.json
index fa9752ea14..3fbdd43377 100644
--- a/airflow/providers/amazon/aws/waiters/batch.json
+++ b/airflow/providers/amazon/aws/waiters/batch.json
@@ -20,6 +20,32 @@
           "state": "failed"
         }
       ]
+    },
+
+    "compute_env_ready": {
+      "delay": 30,
+      "operation": "DescribeComputeEnvironments",
+      "maxAttempts": 100,
+      "acceptors": [
+        {
+          "argument": "computeEnvironments[].status",
+          "expected": "VALID",
+          "matcher": "pathAll",
+          "state": "success"
+        },
+        {
+          "argument": "computeEnvironments[].status",
+          "expected": "INVALID",
+          "matcher": "pathAny",
+          "state": "failed"
+        },
+        {
+          "argument": "computeEnvironments[].status",
+          "expected": "DELETED",
+          "matcher": "pathAny",
+          "state": "failed"
+        }
+      ]
     }
   }
 }
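
Since compute_env_ready is a standard botocore waiter model, the same definition can also be driven synchronously through the hook, outside the trigger; a minimal sketch (connection ID, region, and ARN are illustrative):

    # Sketch: driving the new custom waiter synchronously via the hook.
    from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook

    hook = BatchClientHook(aws_conn_id="aws_default", region_name="us-east-1")
    waiter = hook.get_waiter("compute_env_ready")
    # Blocks until every listed environment reports status VALID, or fails
    # fast on INVALID / DELETED per the acceptors above.
    waiter.wait(
        computeEnvironments=[
            "arn:aws:batch:us-east-1:111122223333:compute-environment/my-env"
        ],
        WaiterConfig={"Delay": 30, "MaxAttempts": 100},
    )
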
diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py
index a65e00d8db..3aace0bb3e 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -27,7 +27,10 @@ from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
 from airflow.providers.amazon.aws.operators.batch import BatchCreateComputeEnvironmentOperator, BatchOperator
 
 # Use dummy AWS credentials
-from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
+from airflow.providers.amazon.aws.triggers.batch import (
+    BatchCreateComputeEnvironmentTrigger,
+    BatchOperatorTrigger,
+)
 
 AWS_REGION = "eu-west-1"
 AWS_ACCESS_KEY_ID = "airflow_dummy_key"
@@ -326,3 +329,26 @@ class TestBatchCreateComputeEnvironmentOperator:
             computeResources=compute_resources,
             tags=tags,
         )
+
+    @mock.patch.object(BatchClientHook, "client")
+    def test_defer(self, client_mock):
+        client_mock.create_compute_environment.return_value = {"computeEnvironmentArn": "my_arn"}
+
+        operator = BatchCreateComputeEnvironmentOperator(
+            task_id="task",
+            compute_environment_name="my_env_name",
+            environment_type="my_env_type",
+            state="my_state",
+            compute_resources={},
+            max_retries=123456,
+            poll_interval=456789,
+            deferrable=True,
+        )
+
+        with pytest.raises(TaskDeferred) as deferred:
+            operator.execute(None)
+
+        assert isinstance(deferred.value.trigger, BatchCreateComputeEnvironmentTrigger)
+        assert deferred.value.trigger.compute_env_arn == "my_arn"
+        assert deferred.value.trigger.poll_interval == 456789
+        assert deferred.value.trigger.max_retries == 123456
diff --git a/tests/providers/amazon/aws/triggers/test_batch.py b/tests/providers/amazon/aws/triggers/test_batch.py
index 5cf125f828..e337360762 100644
--- a/tests/providers/amazon/aws/triggers/test_batch.py
+++ b/tests/providers/amazon/aws/triggers/test_batch.py
@@ -22,7 +22,13 @@ from unittest.mock import AsyncMock
 import pytest
 from botocore.exceptions import WaiterError
 
-from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger, BatchSensorTrigger
+from airflow import AirflowException
+from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.providers.amazon.aws.triggers.batch import (
+    BatchCreateComputeEnvironmentTrigger,
+    BatchOperatorTrigger,
+    BatchSensorTrigger,
+)
 from airflow.triggers.base import TriggerEvent
 
 BATCH_JOB_ID = "job_id"
@@ -181,3 +187,38 @@ class TestBatchSensorTrigger:
         assert actual_response == TriggerEvent(
             {"status": "failure", "message": f"Job Failed: Waiter {name} failed: {reason}"}
         )
+
+
+class TestBatchCreateComputeEnvironmentTrigger:
+    @pytest.mark.asyncio
+    @mock.patch.object(BatchClientHook, "async_conn")
+    @mock.patch.object(BatchClientHook, "get_waiter")
+    async def test_success(self, get_waiter_mock, conn_mock):
+        get_waiter_mock().wait = AsyncMock(
+            side_effect=[
+                WaiterError(
+                    "situation normal", "first try", {"computeEnvironments": [{"status": "my_status"}]}
+                ),
+                {},
+            ]
+        )
+        trigger = BatchCreateComputeEnvironmentTrigger("my_arn", poll_interval=0, max_retries=3)
+
+        generator = trigger.run()
+        response: TriggerEvent = await generator.asend(None)
+
+        assert response.payload["status"] == "success"
+        assert response.payload["value"] == "my_arn"
+
+    @pytest.mark.asyncio
+    @mock.patch.object(BatchClientHook, "async_conn")
+    @mock.patch.object(BatchClientHook, "get_waiter")
+    async def test_failure(self, get_waiter_mock, conn_mock):
+        get_waiter_mock().wait = AsyncMock(
+            side_effect=[WaiterError("terminal failure", "terminal failure reason", {})]
+        )
+        trigger = BatchCreateComputeEnvironmentTrigger("my_arn", poll_interval=0, max_retries=3)
+
+        with pytest.raises(AirflowException):
+            generator = trigger.run()
+            await generator.asend(None)