Posted to commits@airflow.apache.org by po...@apache.org on 2023/06/27 22:52:47 UTC
[airflow] branch main updated: Add a deferrable mode to `BatchCreateComputeEnvironmentOperator` (#32036)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1fb2831239 Add a deferrable mode to `BatchCreateComputeEnvironmentOperator` (#32036)
1fb2831239 is described below
commit 1fb28312393a59bb064e4a1cade59de5c86ef16a
Author: Raphaël Vandon <va...@amazon.com>
AuthorDate: Tue Jun 27 15:52:41 2023 -0700
Add a deferrable mode to `BatchCreateComputeEnvironmentOperator` (#32036)
---
airflow/providers/amazon/aws/operators/batch.py | 52 +++++++++++++++-----
airflow/providers/amazon/aws/triggers/batch.py | 57 ++++++++++++++++++++++
airflow/providers/amazon/aws/waiters/batch.json | 26 ++++++++++
tests/providers/amazon/aws/operators/test_batch.py | 28 ++++++++++-
tests/providers/amazon/aws/triggers/test_batch.py | 43 +++++++++++++++-
5 files changed, 193 insertions(+), 13 deletions(-)
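Before the diff, here is a minimal sketch of how the new deferrable mode can be used from a DAG. The DAG id, schedule, and compute-resource values are illustrative placeholders, not part of this commit:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.batch import (
        BatchCreateComputeEnvironmentOperator,
    )

    with DAG("example_batch_deferrable", start_date=datetime(2023, 1, 1), schedule=None):
        # With deferrable=True the operator submits the CreateComputeEnvironment
        # call, then defers: the worker slot is freed and a triggerer process
        # polls until the environment reaches the VALID status.
        BatchCreateComputeEnvironmentOperator(
            task_id="create_compute_env",
            compute_environment_name="my-compute-env",  # placeholder name
            environment_type="MANAGED",
            state="ENABLED",
            compute_resources={  # placeholder Fargate resources
                "type": "FARGATE",
                "maxvCpus": 4,
                "subnets": ["subnet-0123456789"],
                "securityGroupIds": ["sg-0123456789"],
            },
            poll_interval=30,  # seconds between status polls (deferrable mode only)
            max_retries=120,   # number of polls before giving up (deferrable mode only)
            deferrable=True,
        )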
diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index 88feb01311..b9b3322c49 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -38,7 +38,10 @@ from airflow.providers.amazon.aws.links.batch import (
BatchJobQueueLink,
)
from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
-from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
+from airflow.providers.amazon.aws.triggers.batch import (
+ BatchCreateComputeEnvironmentTrigger,
+ BatchOperatorTrigger,
+)
from airflow.providers.amazon.aws.utils import trim_none_values
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
@@ -402,14 +405,16 @@ class BatchCreateComputeEnvironmentOperator(BaseOperator):
services on your behalf (templated).
:param tags: Tags that you apply to the compute-environment to help you
categorize and organize your resources.
- :param max_retries: Exponential back-off retries, 4200 = 48 hours; polling
- is only used when waiters is None.
- :param status_retries: Number of HTTP retries to get job status, 10; polling
- is only used when waiters is None.
+ :param poll_interval: How long to wait, in seconds, between two polls of the environment status.
+ Only used when deferrable is True.
+ :param max_retries: How many times to poll the environment status before giving up.
+ Only used when deferrable is True.
:param aws_conn_id: Connection ID of AWS credentials / region name. If None,
credential boto3 strategy will be used.
:param region_name: Region name to use in AWS Hook. Overrides the
``region_name`` in connection if provided.
+ :param deferrable: If True, the operator will wait asynchronously for the environment to be created.
+ This mode requires the aiobotocore module to be installed. (default: False)
"""
template_fields: Sequence[str] = (
@@ -428,13 +433,24 @@ class BatchCreateComputeEnvironmentOperator(BaseOperator):
unmanaged_v_cpus: int | None = None,
service_role: str | None = None,
tags: dict | None = None,
+ poll_interval: int = 30,
max_retries: int | None = None,
- status_retries: int | None = None,
aws_conn_id: str | None = None,
region_name: str | None = None,
+ deferrable: bool = False,
**kwargs,
):
+ if "status_retries" in kwargs:
+ warnings.warn(
+ "The `status_retries` parameter is unused and should be removed. "
+ "It'll be deleted in a future version.",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ kwargs.pop("status_retries") # remove before calling super() to prevent unexpected arg error
+
super().__init__(**kwargs)
+
self.compute_environment_name = compute_environment_name
self.environment_type = environment_type
self.state = state
@@ -442,17 +458,16 @@ class BatchCreateComputeEnvironmentOperator(BaseOperator):
self.compute_resources = compute_resources
self.service_role = service_role
self.tags = tags or {}
- self.max_retries = max_retries
- self.status_retries = status_retries
+ self.poll_interval = poll_interval
+ self.max_retries = max_retries or 120
self.aws_conn_id = aws_conn_id
self.region_name = region_name
+ self.deferrable = deferrable
@cached_property
def hook(self):
"""Create and return a BatchClientHook."""
return BatchClientHook(
- max_retries=self.max_retries,
- status_retries=self.status_retries,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
)
@@ -468,6 +483,21 @@ class BatchCreateComputeEnvironmentOperator(BaseOperator):
"serviceRole": self.service_role,
"tags": self.tags,
}
- self.hook.client.create_compute_environment(**trim_none_values(kwargs))
+ response = self.hook.client.create_compute_environment(**trim_none_values(kwargs))
+ arn = response["computeEnvironmentArn"]
+
+ if self.deferrable:
+ self.defer(
+ trigger=BatchCreateComputeEnvironmentTrigger(
+ arn, self.poll_interval, self.max_retries, self.aws_conn_id, self.region_name
+ ),
+ method_name="execute_complete",
+ )
self.log.info("AWS Batch compute environment created successfully")
+ return arn
+
+ def execute_complete(self, context, event=None):
+ if event["status"] != "success":
+ raise AirflowException(f"Error while waiting for the compute environment to be ready: {event}")
+ return event["value"]
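The defer()/execute_complete() pair above follows Airflow's standard deferral contract: defer() raises TaskDeferred, the triggerer runs the trigger to completion, and the resulting event payload is handed to execute_complete(). A small sketch of the callback's two outcomes, with placeholder values:

    from airflow.exceptions import AirflowException
    from airflow.providers.amazon.aws.operators.batch import (
        BatchCreateComputeEnvironmentOperator,
    )

    op = BatchCreateComputeEnvironmentOperator(
        task_id="create_env",
        compute_environment_name="env",  # placeholder
        environment_type="MANAGED",
        state="ENABLED",
        compute_resources={},
    )

    # On a success event, the callback returns the ARN carried in the payload:
    assert op.execute_complete(None, {"status": "success", "value": "my_arn"}) == "my_arn"

    # Any other status raises AirflowException, which fails the task:
    try:
        op.execute_complete(None, {"status": "failure"})
    except AirflowException as exc:
        print(exc)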
diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py
index f4a5de1525..b0bdbc0d45 100644
--- a/airflow/providers/amazon/aws/triggers/batch.py
+++ b/airflow/providers/amazon/aws/triggers/batch.py
@@ -23,6 +23,7 @@ from typing import Any
from botocore.exceptions import WaiterError
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -188,3 +189,59 @@ class BatchSensorTrigger(BaseTrigger):
"message": f"Job {self.job_id} Succeeded",
}
)
+
+
+class BatchCreateComputeEnvironmentTrigger(BaseTrigger):
+ """
+ Trigger for BatchCreateComputeEnvironmentOperator.
+ The trigger will asynchronously poll the boto3 API and wait for the compute environment to be ready.
+
+ :param compute_env_arn: The ARN of the compute environment to monitor.
+ :param poll_interval: The amount of time in seconds to wait between attempts.
+ :param max_retries: The maximum number of attempts to be made.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param region_name: Region name to use in the AWS Hook.
+ """
+
+ def __init__(
+ self,
+ compute_env_arn: str | None = None,
+ poll_interval: int = 30,
+ max_retries: int = 10,
+ aws_conn_id: str | None = "aws_default",
+ region_name: str | None = None,
+ ):
+ super().__init__()
+ self.compute_env_arn = compute_env_arn
+ self.max_retries = max_retries
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+ self.poll_interval = poll_interval
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serializes BatchOperatorTrigger arguments and classpath."""
+ return (
+ self.__class__.__module__ + "." + self.__class__.__qualname__,
+ {
+ "compute_env_arn": self.compute_env_arn,
+ "max_retries": self.max_retries,
+ "aws_conn_id": self.aws_conn_id,
+ "region_name": self.region_name,
+ "poll_interval": self.poll_interval,
+ },
+ )
+
+ async def run(self):
+ hook = BatchClientHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+ async with hook.async_conn as client:
+ waiter = hook.get_waiter("compute_env_ready", deferrable=True, client=client)
+ await async_wait(
+ waiter=waiter,
+ waiter_delay=self.poll_interval,
+ waiter_max_attempts=self.max_retries,
+ args={"computeEnvironments": [self.compute_env_arn]},
+ failure_message="Failure while creating Compute Environment",
+ status_message="Compute Environment not ready yet",
+ status_args=["computeEnvironments[].status", "computeEnvironments[].statusReason"],
+ )
+ yield TriggerEvent({"status": "success", "value": self.compute_env_arn})
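serialize() is what lets the deferral cross processes: the scheduler persists the classpath and kwargs, and the triggerer re-instantiates the trigger from them. A simplified sketch of that round-trip (Airflow's actual loader does more validation):

    from importlib import import_module

    from airflow.providers.amazon.aws.triggers.batch import (
        BatchCreateComputeEnvironmentTrigger,
    )

    trigger = BatchCreateComputeEnvironmentTrigger("my_arn", poll_interval=10, max_retries=5)
    classpath, kwargs = trigger.serialize()

    # Roughly what the triggerer does to rebuild the trigger on its side:
    module_path, _, class_name = classpath.rpartition(".")
    rebuilt = getattr(import_module(module_path), class_name)(**kwargs)
    assert rebuilt.compute_env_arn == "my_arn"
    assert rebuilt.poll_interval == 10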
diff --git a/airflow/providers/amazon/aws/waiters/batch.json b/airflow/providers/amazon/aws/waiters/batch.json
index fa9752ea14..3fbdd43377 100644
--- a/airflow/providers/amazon/aws/waiters/batch.json
+++ b/airflow/providers/amazon/aws/waiters/batch.json
@@ -20,6 +20,32 @@
"state": "failed"
}
]
+ },
+
+ "compute_env_ready": {
+ "delay": 30,
+ "operation": "DescribeComputeEnvironments",
+ "maxAttempts": 100,
+ "acceptors": [
+ {
+ "argument": "computeEnvironments[].status",
+ "expected": "VALID",
+ "matcher": "pathAll",
+ "state": "success"
+ },
+ {
+ "argument": "computeEnvironments[].status",
+ "expected": "INVALID",
+ "matcher": "pathAny",
+ "state": "failed"
+ },
+ {
+ "argument": "computeEnvironments[].status",
+ "expected": "DELETED",
+ "matcher": "pathAny",
+ "state": "failed"
+ }
+ ]
}
}
}
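The compute_env_ready waiter defined above succeeds only when every described environment is VALID (pathAll), and fails fast if any environment turns INVALID or DELETED (pathAny) rather than polling until maxAttempts. A synchronous sketch of driving it through the hook; the connection id, region, and ARN are placeholders:

    from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook

    hook = BatchClientHook(aws_conn_id="aws_default", region_name="eu-west-1")
    # get_waiter() resolves "compute_env_ready" from the provider's batch.json,
    # not from botocore's built-in waiter models.
    waiter = hook.get_waiter("compute_env_ready")
    waiter.wait(
        computeEnvironments=["arn:aws:batch:eu-west-1:123456789012:compute-environment/my-env"],
        WaiterConfig={"Delay": 30, "MaxAttempts": 100},  # overrides the JSON defaults
    )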
diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py
index a65e00d8db..3aace0bb3e 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -27,7 +27,10 @@ from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.aws.operators.batch import BatchCreateComputeEnvironmentOperator, BatchOperator
# Use dummy AWS credentials
-from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
+from airflow.providers.amazon.aws.triggers.batch import (
+ BatchCreateComputeEnvironmentTrigger,
+ BatchOperatorTrigger,
+)
AWS_REGION = "eu-west-1"
AWS_ACCESS_KEY_ID = "airflow_dummy_key"
@@ -326,3 +329,26 @@ class TestBatchCreateComputeEnvironmentOperator:
computeResources=compute_resources,
tags=tags,
)
+
+ @mock.patch.object(BatchClientHook, "client")
+ def test_defer(self, client_mock):
+ client_mock.create_compute_environment.return_value = {"computeEnvironmentArn": "my_arn"}
+
+ operator = BatchCreateComputeEnvironmentOperator(
+ task_id="task",
+ compute_environment_name="my_env_name",
+ environment_type="my_env_type",
+ state="my_state",
+ compute_resources={},
+ max_retries=123456,
+ poll_interval=456789,
+ deferrable=True,
+ )
+
+ with pytest.raises(TaskDeferred) as deferred:
+ operator.execute(None)
+
+ assert isinstance(deferred.value.trigger, BatchCreateComputeEnvironmentTrigger)
+ assert deferred.value.trigger.compute_env_arn == "my_arn"
+ assert deferred.value.trigger.poll_interval == 456789
+ assert deferred.value.trigger.max_retries == 123456
diff --git a/tests/providers/amazon/aws/triggers/test_batch.py b/tests/providers/amazon/aws/triggers/test_batch.py
index 5cf125f828..e337360762 100644
--- a/tests/providers/amazon/aws/triggers/test_batch.py
+++ b/tests/providers/amazon/aws/triggers/test_batch.py
@@ -22,7 +22,13 @@ from unittest.mock import AsyncMock
import pytest
from botocore.exceptions import WaiterError
-from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger, BatchSensorTrigger
+from airflow import AirflowException
+from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.providers.amazon.aws.triggers.batch import (
+ BatchCreateComputeEnvironmentTrigger,
+ BatchOperatorTrigger,
+ BatchSensorTrigger,
+)
from airflow.triggers.base import TriggerEvent
BATCH_JOB_ID = "job_id"
@@ -181,3 +187,38 @@ class TestBatchSensorTrigger:
assert actual_response == TriggerEvent(
{"status": "failure", "message": f"Job Failed: Waiter {name} failed: {reason}"}
)
+
+
+class TestBatchCreateComputeEnvironmentTrigger:
+ @pytest.mark.asyncio
+ @mock.patch.object(BatchClientHook, "async_conn")
+ @mock.patch.object(BatchClientHook, "get_waiter")
+ async def test_success(self, get_waiter_mock, conn_mock):
+ get_waiter_mock().wait = AsyncMock(
+ side_effect=[
+ WaiterError(
+ "situation normal", "first try", {"computeEnvironments": [{"status": "my_status"}]}
+ ),
+ {},
+ ]
+ )
+ trigger = BatchCreateComputeEnvironmentTrigger("my_arn", poll_interval=0, max_retries=3)
+
+ generator = trigger.run()
+ response: TriggerEvent = await generator.asend(None)
+
+ assert response.payload["status"] == "success"
+ assert response.payload["value"] == "my_arn"
+
+ @pytest.mark.asyncio
+ @mock.patch.object(BatchClientHook, "async_conn")
+ @mock.patch.object(BatchClientHook, "get_waiter")
+ async def test_failure(self, get_waiter_mock, conn_mock):
+ get_waiter_mock().wait = AsyncMock(
+ side_effect=[WaiterError("terminal failure", "terminal failure reason", {})]
+ )
+ trigger = BatchCreateComputeEnvironmentTrigger("my_arn", poll_interval=0, max_retries=3)
+
+ with pytest.raises(AirflowException):
+ generator = trigger.run()
+ await generator.asend(None)