Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/18 18:10:19 UTC
[airflow] branch main updated: Convert the batch sample dag to system tests (AIP-47) (#24448)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b7f51b9156 Convert the batch sample dag to system tests (AIP-47) (#24448)
b7f51b9156 is described below
commit b7f51b9156b780ebf4ca57b9f10b820043f61651
Author: Vincent <97...@users.noreply.github.com>
AuthorDate: Mon Jul 18 14:10:07 2022 -0400
Convert the batch sample dag to system tests (AIP-47) (#24448)
---
.../amazon/aws/example_dags/example_batch.py | 66 ------
airflow/providers/amazon/aws/hooks/batch_client.py | 6 +
airflow/providers/amazon/aws/operators/batch.py | 125 ++++++++++-
.../amazon/aws/operators/redshift_data.py | 4 +-
airflow/providers/amazon/aws/sensors/batch.py | 132 +++++++++++
airflow/providers/amazon/aws/utils/__init__.py | 4 +
.../operators/batch.rst | 46 +++-
docs/spelling_wordlist.txt | 2 +
tests/providers/amazon/aws/operators/test_batch.py | 40 +++-
tests/providers/amazon/aws/sensors/test_batch.py | 132 ++++++++++-
tests/providers/amazon/aws/utils/test_utils.py | 18 +-
.../system/providers/amazon/aws/example_athena.py | 2 +-
tests/system/providers/amazon/aws/example_batch.py | 250 +++++++++++++++++++++
.../system/providers/amazon/aws/utils/__init__.py | 5 +
14 files changed, 755 insertions(+), 77 deletions(-)
diff --git a/airflow/providers/amazon/aws/example_dags/example_batch.py b/airflow/providers/amazon/aws/example_dags/example_batch.py
deleted file mode 100644
index 959e4de52d..0000000000
--- a/airflow/providers/amazon/aws/example_dags/example_batch.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from datetime import datetime
-from json import loads
-from os import environ
-
-from airflow import DAG
-from airflow.providers.amazon.aws.operators.batch import BatchOperator
-from airflow.providers.amazon.aws.sensors.batch import BatchSensor
-
-# The inputs below are required for the submit batch example DAG.
-JOB_NAME = environ.get('BATCH_JOB_NAME', 'example_job_name')
-JOB_DEFINITION = environ.get('BATCH_JOB_DEFINITION', 'example_job_definition')
-JOB_QUEUE = environ.get('BATCH_JOB_QUEUE', 'example_job_queue')
-JOB_OVERRIDES = loads(environ.get('BATCH_JOB_OVERRIDES', '{}'))
-
-# An existing (externally triggered) job id is required for the sensor example DAG.
-JOB_ID = environ.get('BATCH_JOB_ID', '00000000-0000-0000-0000-000000000000')
-
-
-with DAG(
- dag_id='example_batch_submit_job',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as submit_dag:
-
- # [START howto_operator_batch]
- submit_batch_job = BatchOperator(
- task_id='submit_batch_job',
- job_name=JOB_NAME,
- job_queue=JOB_QUEUE,
- job_definition=JOB_DEFINITION,
- overrides=JOB_OVERRIDES,
- )
- # [END howto_operator_batch]
-
-with DAG(
- dag_id='example_batch_wait_for_job_sensor',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as sensor_dag:
-
- # [START howto_sensor_batch]
- wait_for_batch_job = BatchSensor(
- task_id='wait_for_batch_job',
- job_id=JOB_ID,
- )
- # [END howto_sensor_batch]
diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py
index 873e5c5779..287e319f51 100644
--- a/airflow/providers/amazon/aws/hooks/batch_client.py
+++ b/airflow/providers/amazon/aws/hooks/batch_client.py
@@ -192,6 +192,12 @@ class BatchClientHook(AwsBaseHook):
RUNNING_STATE,
)
+ COMPUTE_ENVIRONMENT_TERMINAL_STATUS = ('VALID', 'DELETED')
+ COMPUTE_ENVIRONMENT_INTERMEDIATE_STATUS = ('CREATING', 'UPDATING', 'DELETING')
+
+ JOB_QUEUE_TERMINAL_STATUS = ('VALID', 'DELETED')
+ JOB_QUEUE_INTERMEDIATE_STATUS = ('CREATING', 'UPDATING', 'DELETING')
+
def __init__(
self, *args, max_retries: Optional[int] = None, status_retries: Optional[int] = None, **kwargs
) -> None:
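A note on the new status tuples: they mirror the existing job-state constants on BatchClientHook and let callers classify a compute environment or job queue status without hard-coding strings. A minimal sketch of that pattern, assuming plain boto3 wiring (the poll_environment helper is illustrative only; the new sensors later in this commit implement the same logic):

    import boto3

    from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook

    def poll_environment(name: str) -> bool:
        """Return True on a terminal status, False while still in progress."""
        response = boto3.client('batch').describe_compute_environments(
            computeEnvironments=[name],
        )
        status = response['computeEnvironments'][0]['status']
        if status in BatchClientHook.COMPUTE_ENVIRONMENT_TERMINAL_STATUS:
            return True
        if status in BatchClientHook.COMPUTE_ENVIRONMENT_INTERMEDIATE_STATUS:
            return False
        raise RuntimeError(f'Unexpected compute environment status: {status}')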
diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index 037e25c2ba..b74d47ea30 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -25,8 +25,16 @@ An Airflow operator for AWS Batch services
- http://boto3.readthedocs.io/en/latest/reference/services/batch.html
- https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html
"""
+import sys
import warnings
-from typing import TYPE_CHECKING, Any, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
+
+from airflow.providers.amazon.aws.utils import trim_none_values
+
+if sys.version_info >= (3, 8):
+ from functools import cached_property
+else:
+ from cached_property import cached_property
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -96,6 +104,8 @@ class BatchOperator(BaseOperator):
arn = None # type: Optional[str]
template_fields: Sequence[str] = (
"job_name",
+ "job_queue",
+ "job_definition",
"overrides",
"parameters",
)
@@ -103,7 +113,9 @@ class BatchOperator(BaseOperator):
@property
def operator_extra_links(self):
- op_extra_links = [BatchJobDetailsLink(), BatchJobDefinitionLink(), BatchJobQueueLink()]
+ op_extra_links = [BatchJobDetailsLink()]
+ if self.wait_for_completion:
+ op_extra_links.extend([BatchJobDefinitionLink(), BatchJobQueueLink()])
if not self.array_properties:
# There is no CloudWatch Link to the parent Batch Job available.
op_extra_links.append(CloudWatchEventsLink())
@@ -126,6 +138,7 @@ class BatchOperator(BaseOperator):
aws_conn_id: Optional[str] = None,
region_name: Optional[str] = None,
tags: Optional[dict] = None,
+ wait_for_completion: bool = True,
**kwargs,
):
@@ -139,6 +152,7 @@ class BatchOperator(BaseOperator):
self.parameters = parameters or {}
self.waiters = waiters
self.tags = tags or {}
+ self.wait_for_completion = wait_for_completion
self.hook = BatchClientHook(
max_retries=max_retries,
status_retries=status_retries,
@@ -153,7 +167,11 @@ class BatchOperator(BaseOperator):
:raises: AirflowException
"""
self.submit_job(context)
- self.monitor_job(context)
+
+ if self.wait_for_completion:
+ self.monitor_job(context)
+
+ return self.job_id
def on_kill(self):
response = self.hook.client.terminate_job(jobId=self.job_id, reason="Task killed by the user")
@@ -260,6 +278,107 @@ class BatchOperator(BaseOperator):
self.log.info("AWS Batch job (%s) succeeded", self.job_id)
+class BatchCreateComputeEnvironmentOperator(BaseOperator):
+ """
+ Create an AWS Batch compute environment
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:BatchCreateComputeEnvironmentOperator`
+
+ :param compute_environment_name: the name of the AWS Batch compute environment (templated)
+
+ :param environment_type: the type of the compute environment (``MANAGED`` or ``UNMANAGED``)
+
+ :param state: the state of the compute environment (``ENABLED`` or ``DISABLED``)
+
+ :param compute_resources: details of the resources managed by the compute environment (templated).
+ For more details, see
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html#Batch.Client.create_compute_environment
+
+ :param unmanaged_v_cpus: the maximum number of vCPUs for an unmanaged compute environment.
+ This parameter is only supported when the ``type`` parameter is set to ``UNMANAGED``.
+
+ :param service_role: the IAM role that allows Batch to make calls to other AWS services on your behalf
+ (templated)
+
+ :param tags: the tags that you apply to the compute environment to help you categorize and organize
+ your resources
+
+ :param max_retries: maximum number of exponential back-off retries (4200 by default, roughly 48 hours);
+ polling is only used when ``waiters`` is None
+
+ :param status_retries: number of HTTP retries to get the status (10 by default);
+ polling is only used when ``waiters`` is None
+
+ :param aws_conn_id: connection id of AWS credentials / region name. If None,
+ the default boto3 credential strategy will be used.
+
+ :param region_name: region name to use in the AWS Hook.
+ Overrides the ``region_name`` in the connection (if provided)
+ """
+
+ template_fields: Sequence[str] = (
+ "compute_environment_name",
+ "compute_resources",
+ "service_role",
+ )
+ template_fields_renderers = {"compute_resources": "json"}
+
+ def __init__(
+ self,
+ compute_environment_name: str,
+ environment_type: str,
+ state: str,
+ compute_resources: dict,
+ unmanaged_v_cpus: Optional[int] = None,
+ service_role: Optional[str] = None,
+ tags: Optional[dict] = None,
+ max_retries: Optional[int] = None,
+ status_retries: Optional[int] = None,
+ aws_conn_id: Optional[str] = None,
+ region_name: Optional[str] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.compute_environment_name = compute_environment_name
+ self.environment_type = environment_type
+ self.state = state
+ self.unmanaged_v_cpus = unmanaged_v_cpus
+ self.compute_resources = compute_resources
+ self.service_role = service_role
+ self.tags = tags or {}
+ self.max_retries = max_retries
+ self.status_retries = status_retries
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+
+ @cached_property
+ def hook(self):
+ """Create and return a BatchClientHook"""
+ return BatchClientHook(
+ max_retries=self.max_retries,
+ status_retries=self.status_retries,
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ )
+
+ def execute(self, context: 'Context'):
+ """Create an AWS batch compute environment"""
+ kwargs: Dict[str, Any] = {
+ 'computeEnvironmentName': self.compute_environment_name,
+ 'type': self.environment_type,
+ 'state': self.state,
+ 'unmanagedvCpus': self.unmanaged_v_cpus,
+ 'computeResources': self.compute_resources,
+ 'serviceRole': self.service_role,
+ 'tags': self.tags,
+ }
+ self.hook.client.create_compute_environment(**trim_none_values(kwargs))
+
+ self.log.info('AWS Batch compute environment created successfully')
+
+
class AwsBatchOperator(BatchOperator):
"""
This operator is deprecated.
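Taken together, ``wait_for_completion`` and the new ``return self.job_id`` let a DAG submit a job without blocking and hand the returned job id to a sensor via XCom. A minimal sketch, assuming a DAG context (task ids and the queue/definition names are illustrative); the new system test at the end of this commit uses the same pattern:

    from airflow.providers.amazon.aws.operators.batch import BatchOperator
    from airflow.providers.amazon.aws.sensors.batch import BatchSensor

    submit = BatchOperator(
        task_id='submit',
        job_name='my-job',                   # illustrative
        job_queue='my-queue',                # illustrative
        job_definition='my-job-definition',  # illustrative
        wait_for_completion=False,  # return right after submission
    )
    wait = BatchSensor(
        task_id='wait',
        job_id=submit.output,  # XCom'd return value of execute(), i.e. the job id
    )
    submit >> wait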
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index f23ca928eb..a2400f94bd 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional
from airflow.compat.functools import cached_property
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
+from airflow.providers.amazon.aws.utils import trim_none_values
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -115,8 +116,7 @@ class RedshiftDataOperator(BaseOperator):
"StatementName": self.statement_name,
}
- filter_values = {key: val for key, val in kwargs.items() if val is not None}
- resp = self.hook.conn.execute_statement(**filter_values)
+ resp = self.hook.conn.execute_statement(**trim_none_values(kwargs))
return resp['Id']
def wait_for_results(self, statement_id):
diff --git a/airflow/providers/amazon/aws/sensors/batch.py b/airflow/providers/amazon/aws/sensors/batch.py
index faab424e31..3ad288c904 100644
--- a/airflow/providers/amazon/aws/sensors/batch.py
+++ b/airflow/providers/amazon/aws/sensors/batch.py
@@ -15,8 +15,13 @@
# specific language governing permissions and limitations
# under the License.
+import sys
from typing import TYPE_CHECKING, Optional, Sequence
+if sys.version_info >= (3, 8):
+ from functools import cached_property
+else:
+ from cached_property import cached_property
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.sensors.base import BaseSensorOperator
@@ -36,6 +41,7 @@ class BatchSensor(BaseSensorOperator):
:param job_id: Batch job_id to check the state for
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
+ :param region_name: aws region name associated with the client
"""
template_fields: Sequence[str] = ('job_id',)
@@ -81,3 +87,129 @@ class BatchSensor(BaseSensorOperator):
region_name=self.region_name,
)
return self.hook
+
+
+class BatchComputeEnvironmentSensor(BaseSensorOperator):
+ """
+ Polls the state of the Batch compute environment until it reaches a failure or a success state.
+ If the environment fails, the task will fail.
+
+ .. seealso::
+ For more information on how to use this sensor, take a look at the guide:
+ :ref:`howto/sensor:BatchComputeEnvironmentSensor`
+
+ :param compute_environment: Batch compute environment name
+
+ :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+
+ :param region_name: aws region name associated with the client
+ """
+
+ template_fields: Sequence[str] = ('compute_environment',)
+ template_ext: Sequence[str] = ()
+ ui_color = '#66c3ff'
+
+ def __init__(
+ self,
+ compute_environment: str,
+ aws_conn_id: str = 'aws_default',
+ region_name: Optional[str] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.compute_environment = compute_environment
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+
+ @cached_property
+ def hook(self) -> BatchClientHook:
+ """Create and return a BatchClientHook"""
+ return BatchClientHook(
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ )
+
+ def poke(self, context: 'Context') -> bool:
+ response = self.hook.client.describe_compute_environments(
+ computeEnvironments=[self.compute_environment]
+ )
+
+ if len(response['computeEnvironments']) == 0:
+ raise AirflowException(f'AWS Batch compute environment {self.compute_environment} not found')
+
+ status = response['computeEnvironments'][0]['status']
+
+ if status in BatchClientHook.COMPUTE_ENVIRONMENT_TERMINAL_STATUS:
+ return True
+
+ if status in BatchClientHook.COMPUTE_ENVIRONMENT_INTERMEDIATE_STATUS:
+ return False
+
+ raise AirflowException(
+ f'AWS Batch compute environment failed. AWS Batch compute environment status: {status}'
+ )
+
+
+class BatchJobQueueSensor(BaseSensorOperator):
+ """
+ Polls the state of the Batch job queue until it reaches a failure or a success state.
+ If the queue fails, the task will fail.
+
+ .. seealso::
+ For more information on how to use this sensor, take a look at the guide:
+ :ref:`howto/sensor:BatchJobQueueSensor`
+
+ :param job_queue: Batch job queue name
+
+ :param treat_non_existing_as_deleted: If True, a non-existent Batch job queue is treated as a deleted
+ queue, and therefore as a success state.
+
+ :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+
+ :param region_name: aws region name associated with the client
+ """
+
+ template_fields: Sequence[str] = ('job_queue',)
+ template_ext: Sequence[str] = ()
+ ui_color = '#66c3ff'
+
+ def __init__(
+ self,
+ job_queue: str,
+ treat_non_existing_as_deleted: bool = False,
+ aws_conn_id: str = 'aws_default',
+ region_name: Optional[str] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.job_queue = job_queue
+ self.treat_non_existing_as_deleted = treat_non_existing_as_deleted
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+
+ @cached_property
+ def hook(self) -> BatchClientHook:
+ """Create and return a BatchClientHook"""
+ return BatchClientHook(
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ )
+
+ def poke(self, context: 'Context') -> bool:
+ response = self.hook.client.describe_job_queues(jobQueues=[self.job_queue])
+
+ if len(response['jobQueues']) == 0:
+ if self.treat_non_existing_as_deleted:
+ return True
+ else:
+ raise AirflowException(f'AWS Batch job queue {self.job_queue} not found')
+
+ status = response['jobQueues'][0]['status']
+
+ if status in BatchClientHook.JOB_QUEUE_TERMINAL_STATUS:
+ return True
+
+ if status in BatchClientHook.JOB_QUEUE_INTERMEDIATE_STATUS:
+ return False
+
+ raise AirflowException(f'AWS Batch job queue failed. AWS Batch job queue status: {status}')
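Both new sensors treat any terminal status ('VALID' or 'DELETED') as success, so the same sensor class can gate both creation and deletion. A minimal sketch, assuming a DAG context (task ids and the queue name are illustrative):

    from airflow.providers.amazon.aws.sensors.batch import BatchJobQueueSensor

    wait_queue_ready = BatchJobQueueSensor(
        task_id='wait_queue_ready',
        job_queue='my-queue',  # illustrative
    )
    wait_queue_gone = BatchJobQueueSensor(
        task_id='wait_queue_gone',
        job_queue='my-queue',  # illustrative
        # A queue that no longer exists counts as deleted, i.e. success:
        treat_non_existing_as_deleted=True,
    )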
diff --git a/airflow/providers/amazon/aws/utils/__init__.py b/airflow/providers/amazon/aws/utils/__init__.py
index 7f127f7178..251276e738 100644
--- a/airflow/providers/amazon/aws/utils/__init__.py
+++ b/airflow/providers/amazon/aws/utils/__init__.py
@@ -22,6 +22,10 @@ from typing import Tuple
from airflow.version import version
+def trim_none_values(obj: dict):
+ return {key: val for key, val in obj.items() if val is not None}
+
+
def datetime_to_epoch(date_time: datetime) -> int:
"""Convert a datetime object to an epoch integer (seconds)."""
return int(date_time.timestamp())
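``trim_none_values`` pulls the dict comprehension that RedshiftDataOperator already used into a shared helper, so optional operator arguments can be dropped before the kwargs dict is passed to boto3 (whose parameter validation would otherwise reject explicit None values). A quick illustration:

    from airflow.providers.amazon.aws.utils import trim_none_values

    kwargs = {'computeEnvironmentName': 'my-env', 'unmanagedvCpus': None}
    assert trim_none_values(kwargs) == {'computeEnvironmentName': 'my-env'}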
diff --git a/docs/apache-airflow-providers-amazon/operators/batch.rst b/docs/apache-airflow-providers-amazon/operators/batch.rst
index 663cf24cec..ba280cb38d 100644
--- a/docs/apache-airflow-providers-amazon/operators/batch.rst
+++ b/docs/apache-airflow-providers-amazon/operators/batch.rst
@@ -40,12 +40,26 @@ Submit a new AWS Batch job
To submit a new AWS Batch job and monitor it until it reaches a terminal state you can
use :class:`~airflow.providers.amazon.aws.operators.batch.BatchOperator`.
-.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_batch.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_batch.py
:language: python
:dedent: 4
:start-after: [START howto_operator_batch]
:end-before: [END howto_operator_batch]
+.. _howto/operator:BatchCreateComputeEnvironmentOperator:
+
+Create an AWS Batch compute environment
+=======================================
+
+To create a new AWS Batch compute environment you can
+use :class:`~airflow.providers.amazon.aws.operators.batch.BatchCreateComputeEnvironmentOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_batch.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_batch_create_compute_environment]
+ :end-before: [END howto_operator_batch_create_compute_environment]
+
Sensors
-------
@@ -57,12 +71,40 @@ Wait on an AWS Batch job state
To wait on the state of an AWS Batch Job until it reaches a terminal state you can
use :class:`~airflow.providers.amazon.aws.sensors.batch.BatchSensor`.
-.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_batch.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_batch.py
:language: python
:dedent: 4
:start-after: [START howto_sensor_batch]
:end-before: [END howto_sensor_batch]
+.. _howto/sensor:BatchComputeEnvironmentSensor:
+
+Wait on an AWS Batch compute environment status
+===============================================
+
+To wait on the status of an AWS Batch compute environment until it reaches a terminal status you can
+use :class:`~airflow.providers.amazon.aws.sensors.batch.BatchComputeEnvironmentSensor`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_batch.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_batch_compute_environment]
+ :end-before: [END howto_sensor_batch_compute_environment]
+
+.. _howto/sensor:BatchJobQueueSensor:
+
+Wait on an AWS Batch job queue status
+=====================================
+
+To wait on the status of an AWS Batch job queue until it reaches a terminal status you can
+use :class:`~airflow.providers.amazon.aws.sensors.batch.BatchJobQueueSensor`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_batch.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_batch_job_queue]
+ :end-before: [END howto_sensor_batch_job_queue]
+
Reference
---------
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 1d821140c5..7c4b2fa4ed 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1517,6 +1517,7 @@ unicode
unittest
unittests
unix
+unmanaged
unmappable
unmapped
unmapping
@@ -1551,6 +1552,7 @@ util
utilise
utils
uuid
+vCPU
validator
vals
ve
diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py
index 7b32858f90..b26b12c2ce 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -25,7 +25,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
-from airflow.providers.amazon.aws.operators.batch import BatchOperator
+from airflow.providers.amazon.aws.operators.batch import BatchCreateComputeEnvironmentOperator, BatchOperator
# Use dummy AWS credentials
AWS_REGION = "eu-west-1"
@@ -96,12 +96,15 @@ class TestBatchOperator(unittest.TestCase):
assert self.batch.hook.aws_conn_id == "airflow_test"
assert self.batch.hook.client == self.client_mock
assert self.batch.tags == {}
+ assert self.batch.wait_for_completion is True
self.get_client_type_mock.assert_called_once_with(region_name="eu-west-1")
def test_template_fields_overrides(self):
assert self.batch.template_fields == (
"job_name",
+ "job_queue",
+ "job_definition",
"overrides",
"parameters",
)
@@ -163,7 +166,42 @@ class TestBatchOperator(unittest.TestCase):
mock_waiters.wait_for_job.assert_called_once_with(JOB_ID)
check_mock.assert_called_once_with(JOB_ID)
+ @mock.patch.object(BatchClientHook, "check_job_success")
+ def test_do_not_wait_job_complete(self, check_mock):
+ self.batch.wait_for_completion = False
+
+ self.client_mock.submit_job.return_value = RESPONSE_WITHOUT_FAILURES
+ self.batch.execute(self.mock_context)
+
+ check_mock.assert_not_called()
+
def test_kill_job(self):
self.client_mock.terminate_job.return_value = {}
self.batch.on_kill()
self.client_mock.terminate_job.assert_called_once_with(jobId=JOB_ID, reason="Task killed by the user")
+
+
+class TestBatchCreateComputeEnvironmentOperator(unittest.TestCase):
+ @mock.patch.object(BatchClientHook, 'client')
+ def test_execute(self, mock_conn):
+ environment_name = 'environment_name'
+ environment_type = 'environment_type'
+ environment_state = 'environment_state'
+ compute_resources = {}
+ tags = {}
+ operator = BatchCreateComputeEnvironmentOperator(
+ task_id='task',
+ compute_environment_name=environment_name,
+ environment_type=environment_type,
+ state=environment_state,
+ compute_resources=compute_resources,
+ tags=tags,
+ )
+ operator.execute(None)
+ mock_conn.create_compute_environment.assert_called_once_with(
+ computeEnvironmentName=environment_name,
+ type=environment_type,
+ state=environment_state,
+ computeResources=compute_resources,
+ tags=tags,
+ )
diff --git a/tests/providers/amazon/aws/sensors/test_batch.py b/tests/providers/amazon/aws/sensors/test_batch.py
index e6c3bde7ee..259d5870d5 100644
--- a/tests/providers/amazon/aws/sensors/test_batch.py
+++ b/tests/providers/amazon/aws/sensors/test_batch.py
@@ -18,11 +18,16 @@
import unittest
from unittest import mock
+import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
-from airflow.providers.amazon.aws.sensors.batch import BatchSensor
+from airflow.providers.amazon.aws.sensors.batch import (
+ BatchComputeEnvironmentSensor,
+ BatchJobQueueSensor,
+ BatchSensor,
+)
TASK_ID = 'batch_job_sensor'
JOB_ID = '8222a1c2-b246-4e19-b1b8-0039bb4407c0'
@@ -73,3 +78,128 @@ class TestBatchSensor(unittest.TestCase):
mock_get_job_description.return_value = {'status': job_status}
self.assertFalse(self.batch_sensor.poke({}))
mock_get_job_description.assert_called_once_with(JOB_ID)
+
+
+class TestBatchComputeEnvironmentSensor(unittest.TestCase):
+ def setUp(self):
+ self.environment_name = 'environment_name'
+ self.sensor = BatchComputeEnvironmentSensor(
+ task_id='test_batch_compute_environment_sensor',
+ compute_environment=self.environment_name,
+ )
+
+ @mock.patch.object(BatchClientHook, 'client')
+ def test_poke_no_environment(self, mock_batch_client):
+ mock_batch_client.describe_compute_environments.return_value = {'computeEnvironments': []}
+ with pytest.raises(AirflowException) as ctx:
+ self.sensor.poke({})
+ mock_batch_client.describe_compute_environments.assert_called_once_with(
+ computeEnvironments=[self.environment_name],
+ )
+ assert 'not found' in str(ctx.value)
+
+ @mock.patch.object(BatchClientHook, 'client')
+ def test_poke_valid(self, mock_batch_client):
+ mock_batch_client.describe_compute_environments.return_value = {
+ 'computeEnvironments': [{'status': 'VALID'}]
+ }
+ assert self.sensor.poke({})
+ mock_batch_client.describe_compute_environments.assert_called_once_with(
+ computeEnvironments=[self.environment_name],
+ )
+
+ @mock.patch.object(BatchClientHook, 'client')
+ def test_poke_running(self, mock_batch_client):
+ mock_batch_client.describe_compute_environments.return_value = {
+ 'computeEnvironments': [
+ {
+ 'status': 'CREATING',
+ }
+ ]
+ }
+ assert not self.sensor.poke({})
+ mock_batch_client.describe_compute_environments.assert_called_once_with(
+ computeEnvironments=[self.environment_name],
+ )
+
+ @mock.patch.object(BatchClientHook, 'client')
+ def test_poke_invalid(self, mock_batch_client):
+ mock_batch_client.describe_compute_environments.return_value = {
+ 'computeEnvironments': [
+ {
+ 'status': 'INVALID',
+ }
+ ]
+ }
+ with pytest.raises(AirflowException) as ctx:
+ self.sensor.poke({})
+ mock_batch_client.describe_compute_environments.assert_called_once_with(
+ computeEnvironments=[self.environment_name],
+ )
+ assert 'AWS Batch compute environment failed' in str(ctx.value)
+
+
+class TestBatchJobQueueSensor(unittest.TestCase):
+ def setUp(self):
+ self.job_queue = 'job_queue'
+ self.sensor = BatchJobQueueSensor(
+ task_id='test_batch_job_queue_sensor',
+ job_queue=self.job_queue,
+ )
+
+ @mock.patch.object(BatchClientHook, 'client')
+ def test_poke_no_queue(self, mock_batch_client):
+ mock_batch_client.describe_job_queues.return_value = {'jobQueues': []}
+ with pytest.raises(AirflowException) as ctx:
+ self.sensor.poke({})
+ mock_batch_client.describe_job_queues.assert_called_once_with(
+ jobQueues=[self.job_queue],
+ )
+ assert 'not found' in str(ctx.value)
+
+ @mock.patch.object(BatchClientHook, 'client')
+ def test_poke_no_queue_with_treat_non_existing_as_deleted(self, mock_batch_client):
+ self.sensor.treat_non_existing_as_deleted = True
+ mock_batch_client.describe_job_queues.return_value = {'jobQueues': []}
+ assert self.sensor.poke({})
+ mock_batch_client.describe_job_queues.assert_called_once_with(
+ jobQueues=[self.job_queue],
+ )
+
+ @mock.patch.object(BatchClientHook, 'client')
+ def test_poke_valid(self, mock_batch_client):
+ mock_batch_client.describe_job_queues.return_value = {'jobQueues': [{'status': 'VALID'}]}
+ assert self.sensor.poke({})
+ mock_batch_client.describe_job_queues.assert_called_once_with(
+ jobQueues=[self.job_queue],
+ )
+
+ @mock.patch.object(BatchClientHook, 'client')
+ def test_poke_running(self, mock_batch_client):
+ mock_batch_client.describe_job_queues.return_value = {
+ 'jobQueues': [
+ {
+ 'status': 'CREATING',
+ }
+ ]
+ }
+ assert not self.sensor.poke({})
+ mock_batch_client.describe_job_queues.assert_called_once_with(
+ jobQueues=[self.job_queue],
+ )
+
+ @mock.patch.object(BatchClientHook, 'client')
+ def test_poke_invalid(self, mock_batch_client):
+ mock_batch_client.describe_job_queues.return_value = {
+ 'jobQueues': [
+ {
+ 'status': 'INVALID',
+ }
+ ]
+ }
+ with pytest.raises(AirflowException) as ctx:
+ self.sensor.poke({})
+ mock_batch_client.describe_job_queues.assert_called_once_with(
+ jobQueues=[self.job_queue],
+ )
+ assert 'AWS Batch job queue failed' in str(ctx.value)
diff --git a/tests/providers/amazon/aws/utils/test_utils.py b/tests/providers/amazon/aws/utils/test_utils.py
index 9a6bc198c1..e27dcbc667 100644
--- a/tests/providers/amazon/aws/utils/test_utils.py
+++ b/tests/providers/amazon/aws/utils/test_utils.py
@@ -16,18 +16,34 @@
# under the License.
from datetime import datetime
+from unittest import TestCase
+
+import pytz
from airflow.providers.amazon.aws.utils import (
datetime_to_epoch,
datetime_to_epoch_ms,
datetime_to_epoch_us,
get_airflow_version,
+ trim_none_values,
)
-DT = datetime(2000, 1, 1)
+DT = datetime(2000, 1, 1, tzinfo=pytz.UTC)
EPOCH = 946_684_800
+class TestUtils(TestCase):
+ def test_trim_none_values(self):
+ input_object = {
+ "test": "test",
+ "empty": None,
+ }
+ expected_output_object = {
+ "test": "test",
+ }
+ assert trim_none_values(input_object) == expected_output_object
+
+
def test_datetime_to_epoch():
assert datetime_to_epoch(DT) == EPOCH
diff --git a/tests/system/providers/amazon/aws/example_athena.py b/tests/system/providers/amazon/aws/example_athena.py
index 42d3abf55a..2a9a4d16a5 100644
--- a/tests/system/providers/amazon/aws/example_athena.py
+++ b/tests/system/providers/amazon/aws/example_athena.py
@@ -34,7 +34,7 @@ from tests.system.providers.amazon.aws.utils import set_env_id
ENV_ID = set_env_id()
DAG_ID = 'example_athena'
-S3_BUCKET = f'{ENV_ID.lower()}-athena-bucket'
+S3_BUCKET = f'{ENV_ID}-athena-bucket'
ATHENA_TABLE = f'{ENV_ID}_test_table'
ATHENA_DATABASE = f'{ENV_ID}_default'
diff --git a/tests/system/providers/amazon/aws/example_batch.py b/tests/system/providers/amazon/aws/example_batch.py
new file mode 100644
index 0000000000..9cf91b3475
--- /dev/null
+++ b/tests/system/providers/amazon/aws/example_batch.py
@@ -0,0 +1,250 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from datetime import datetime
+
+import boto3
+
+from airflow import DAG
+from airflow.decorators import task
+from airflow.models.baseoperator import chain
+from airflow.providers.amazon.aws.operators.batch import BatchCreateComputeEnvironmentOperator, BatchOperator
+from airflow.providers.amazon.aws.sensors.batch import (
+ BatchComputeEnvironmentSensor,
+ BatchJobQueueSensor,
+ BatchSensor,
+)
+from airflow.utils.trigger_rule import TriggerRule
+from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder, split_string
+
+DAG_ID = 'example_batch'
+
+# Externally fetched variables:
+ROLE_ARN_KEY = 'ROLE_ARN'
+SUBNETS_KEY = 'SUBNETS'
+SECURITY_GROUPS_KEY = 'SECURITY_GROUPS'
+
+sys_test_context_task = (
+ SystemTestContextBuilder()
+ .add_variable(ROLE_ARN_KEY)
+ .add_variable(SUBNETS_KEY)
+ .add_variable(SECURITY_GROUPS_KEY)
+ .build()
+)
+
+JOB_OVERRIDES: dict = {}
+
+
+@task
+def create_job_definition(role_arn, job_definition_name):
+ boto3.client('batch').register_job_definition(
+ type='container',
+ containerProperties={
+ 'command': [
+ 'sleep',
+ '2',
+ ],
+ 'executionRoleArn': role_arn,
+ 'image': 'busybox',
+ 'resourceRequirements': [
+ {'value': '1', 'type': 'VCPU'},
+ {'value': '2048', 'type': 'MEMORY'},
+ ],
+ 'networkConfiguration': {
+ 'assignPublicIp': 'ENABLED',
+ },
+ },
+ jobDefinitionName=job_definition_name,
+ platformCapabilities=['FARGATE'],
+ )
+
+
+@task
+def create_job_queue(job_compute_environment_name, job_queue_name):
+ boto3.client('batch').create_job_queue(
+ computeEnvironmentOrder=[
+ {
+ 'computeEnvironment': job_compute_environment_name,
+ 'order': 1,
+ },
+ ],
+ jobQueueName=job_queue_name,
+ priority=1,
+ state='ENABLED',
+ )
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def delete_job_definition(job_definition_name):
+ client = boto3.client('batch')
+
+ response = client.describe_job_definitions(
+ jobDefinitionName=job_definition_name,
+ status='ACTIVE',
+ )
+
+ for job_definition in response['jobDefinitions']:
+ client.deregister_job_definition(
+ jobDefinition=job_definition['jobDefinitionArn'],
+ )
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def disable_compute_environment(job_compute_environment_name):
+ boto3.client('batch').update_compute_environment(
+ computeEnvironment=job_compute_environment_name,
+ state='DISABLED',
+ )
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def delete_compute_environment(job_compute_environment_name):
+ boto3.client('batch').delete_compute_environment(
+ computeEnvironment=job_compute_environment_name,
+ )
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def disable_job_queue(job_queue_name):
+ boto3.client('batch').update_job_queue(
+ jobQueue=job_queue_name,
+ state='DISABLED',
+ )
+
+
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def delete_job_queue(job_queue_name):
+ boto3.client('batch').delete_job_queue(
+ jobQueue=job_queue_name,
+ )
+
+
+with DAG(
+ dag_id=DAG_ID,
+ schedule_interval='@once',
+ start_date=datetime(2021, 1, 1),
+ tags=['example'],
+ catchup=False,
+) as dag:
+ test_context = sys_test_context_task()
+
+ batch_job_name: str = f'{test_context[ENV_ID_KEY]}-test-job'
+ batch_job_definition_name: str = f'{test_context[ENV_ID_KEY]}-test-job-definition'
+ batch_job_compute_environment_name: str = f'{test_context[ENV_ID_KEY]}-test-job-compute-environment'
+ batch_job_queue_name: str = f'{test_context[ENV_ID_KEY]}-test-job-queue'
+
+ security_groups = split_string(test_context[SECURITY_GROUPS_KEY])
+ subnets = split_string(test_context[SUBNETS_KEY])
+
+ # [START howto_operator_batch_create_compute_environment]
+ create_compute_environment = BatchCreateComputeEnvironmentOperator(
+ task_id='create_compute_environment',
+ compute_environment_name=batch_job_compute_environment_name,
+ environment_type='MANAGED',
+ state='ENABLED',
+ compute_resources={
+ 'type': 'FARGATE',
+ 'maxvCpus': 10,
+ 'securityGroupIds': security_groups,
+ 'subnets': subnets,
+ },
+ )
+ # [END howto_operator_batch_create_compute_environment]
+
+ # [START howto_sensor_batch_compute_environment]
+ wait_for_compute_environment_valid = BatchComputeEnvironmentSensor(
+ task_id='wait_for_compute_environment_valid',
+ compute_environment=batch_job_compute_environment_name,
+ )
+ # [END howto_sensor_batch_compute_environment]
+
+ # [START howto_sensor_batch_job_queue]
+ wait_for_job_queue_valid = BatchJobQueueSensor(
+ task_id='wait_for_job_queue_valid',
+ job_queue=batch_job_queue_name,
+ )
+ # [END howto_sensor_batch_job_queue]
+
+ # [START howto_operator_batch]
+ submit_batch_job = BatchOperator(
+ task_id='submit_batch_job',
+ job_name=batch_job_name,
+ job_queue=batch_job_queue_name,
+ job_definition=batch_job_definition_name,
+ overrides=JOB_OVERRIDES,
+ # Set this flag to False so that we can test the sensor below
+ wait_for_completion=False,
+ )
+ # [END howto_operator_batch]
+
+ # [START howto_sensor_batch]
+ wait_for_batch_job = BatchSensor(
+ task_id='wait_for_batch_job',
+ job_id=submit_batch_job.output,
+ )
+ # [END howto_sensor_batch]
+
+ wait_for_compute_environment_disabled = BatchComputeEnvironmentSensor(
+ task_id='wait_for_compute_environment_disabled',
+ compute_environment=batch_job_compute_environment_name,
+ )
+
+ wait_for_job_queue_modified = BatchJobQueueSensor(
+ task_id='wait_for_job_queue_modified',
+ job_queue=batch_job_queue_name,
+ )
+
+ wait_for_job_queue_deleted = BatchJobQueueSensor(
+ task_id='wait_for_job_queue_deleted',
+ job_queue=batch_job_queue_name,
+ treat_non_existing_as_deleted=True,
+ )
+
+ chain(
+ # TEST SETUP
+ test_context,
+ security_groups,
+ subnets,
+ create_job_definition(test_context[ROLE_ARN_KEY], batch_job_definition_name),
+ # TEST BODY
+ create_compute_environment,
+ wait_for_compute_environment_valid,
+ # ``create_job_queue`` is part of the test setup but needs the compute environment to be created first
+ create_job_queue(batch_job_compute_environment_name, batch_job_queue_name),
+ wait_for_job_queue_valid,
+ submit_batch_job,
+ wait_for_batch_job,
+ # TEST TEARDOWN
+ disable_job_queue(batch_job_queue_name),
+ wait_for_job_queue_modified,
+ delete_job_queue(batch_job_queue_name),
+ wait_for_job_queue_deleted,
+ disable_compute_environment(batch_job_compute_environment_name),
+ wait_for_compute_environment_disabled,
+ delete_compute_environment(batch_job_compute_environment_name),
+ delete_job_definition(batch_job_definition_name),
+ )
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/amazon/aws/utils/__init__.py b/tests/system/providers/amazon/aws/utils/__init__.py
index 4228fa2986..0c2e16a40e 100644
--- a/tests/system/providers/amazon/aws/utils/__init__.py
+++ b/tests/system/providers/amazon/aws/utils/__init__.py
@@ -206,3 +206,8 @@ def purge_logs(test_logs: List[Tuple[str, Optional[str]]]) -> None:
if not client.describe_log_streams(logGroupName=group)['logStreams']:
client.delete_log_group(logGroupName=group)
+
+
+@task
+def split_string(string):
+ return string.split(',')
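Since ``split_string`` is a TaskFlow ``@task``, calling it inside a DAG returns an XComArg rather than a list, which is why example_batch.py can place ``security_groups`` and ``subnets`` directly into ``chain()``. A minimal sketch (the input string is illustrative):

    # Inside a DAG context:
    subnets = split_string('subnet-a,subnet-b')  # an XComArg, not a list
    # At run time the task resolves to ['subnet-a', 'subnet-b'].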