You are viewing a plain-text version of this content. The canonical link for it was a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by el...@apache.org on 2023/01/06 16:30:33 UTC
[airflow] branch main updated: Refactor BatchWaitersHook tests (#28758)
This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e9e052bdc7 Refactor BatchWaitersHook tests (#28758)
e9e052bdc7 is described below
commit e9e052bdc782a938606b31ee485fc246a37e27e3
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Fri Jan 6 20:30:16 2023 +0400
Refactor BatchWaitersHook tests (#28758)
---
.../amazon/aws/hooks/test_batch_waiters.py | 441 +++++++--------------
1 file changed, 154 insertions(+), 287 deletions(-)
diff --git a/tests/providers/amazon/aws/hooks/test_batch_waiters.py b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
index c245ac4da4..285784d184 100644
--- a/tests/providers/amazon/aws/hooks/test_batch_waiters.py
+++ b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
@@ -15,38 +15,27 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""
-Test BatchWaiters
-
-This test suite uses a large suite of moto mocks for the
-AWS Batch infrastructure. These infrastructure mocks are
-derived from the moto test suite for testing the Batch client.
-
-.. seealso::
-
- - https://github.com/spulec/moto/pull/1197/files
- - https://github.com/spulec/moto/blob/master/tests/test_batch/test_batch.py
-"""
from __future__ import annotations
import inspect
-from typing import NamedTuple
+import itertools
from unittest import mock
import boto3
-import botocore.client
-import botocore.exceptions
-import botocore.waiter
import pytest
-from moto import mock_batch, mock_ec2, mock_ecs, mock_iam, mock_logs
+from botocore.exceptions import ClientError, WaiterError
+from botocore.waiter import SingleWaiterConfig, WaiterModel
+from moto import mock_batch
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_waiters import BatchWaitersHook
-# Use dummy AWS credentials
+INTERMEDIATE_STATES = ("SUBMITTED", "PENDING", "RUNNABLE", "STARTING")
+RUNNING_STATE = "RUNNING"
+SUCCESS_STATE = "SUCCEEDED"
+FAILED_STATE = "FAILED"
+ALL_STATES = {*INTERMEDIATE_STATES, RUNNING_STATE, SUCCESS_STATE, FAILED_STATE}
AWS_REGION = "eu-west-1"
-AWS_ACCESS_KEY_ID = "airflow_dummy_key"
-AWS_SECRET_ACCESS_KEY = "airflow_dummy_secret"
@pytest.fixture(scope="module")
@@ -54,164 +43,13 @@ def aws_region():
return AWS_REGION
-@pytest.fixture(scope="module")
-def job_queue_name():
- return "moto_test_job_queue"
-
-
-@pytest.fixture(scope="module")
-def job_definition_name():
- return "moto_test_job_definition"
-
-
-#
-# AWS Clients
-#
-
-
-class AwsClients(NamedTuple):
- batch: botocore.client.Batch
- ec2: botocore.client.EC2
- ecs: botocore.client.ECS
- iam: botocore.client.IAM
- log: botocore.client.CloudWatchLogs
-
-
-@pytest.fixture(scope="module")
-def batch_client(aws_region):
- with mock_batch():
- yield boto3.client("batch", region_name=aws_region)
-
-
-@pytest.fixture(scope="module")
-def ec2_client(aws_region):
- with mock_ec2():
- yield boto3.client("ec2", region_name=aws_region)
-
-
-@pytest.fixture(scope="module")
-def ecs_client(aws_region):
- with mock_ecs():
- yield boto3.client("ecs", region_name=aws_region)
-
-
-@pytest.fixture(scope="module")
-def iam_client(aws_region):
- with mock_iam():
- yield boto3.client("iam", region_name=aws_region)
-
-
-@pytest.fixture(scope="module")
-def logs_client(aws_region):
- with mock_logs():
- yield boto3.client("logs", region_name=aws_region)
-
-
-@pytest.fixture(scope="module")
-def aws_clients(batch_client, ec2_client, ecs_client, iam_client, logs_client):
- return AwsClients(batch=batch_client, ec2=ec2_client, ecs=ecs_client, iam=iam_client, log=logs_client)
-
-
-#
-# Batch Infrastructure
-#
-
-
-class Infrastructure:
- aws_region: str
- aws_clients: AwsClients
- vpc_id: str | None = None
- subnet_id: str | None = None
- security_group_id: str | None = None
- iam_arn: str | None = None
- compute_env_name: str | None = None
- compute_env_arn: str | None = None
- job_queue_name: str | None = None
- job_queue_arn: str | None = None
- job_definition_name: str | None = None
- job_definition_arn: str | None = None
-
-
-def batch_infrastructure(
- aws_clients: AwsClients, aws_region: str, job_queue_name: str, job_definition_name: str
-) -> Infrastructure:
- """
- This function is not a fixture so that tests can pass the AWS clients to it and then
- continue to use the infrastructure created by it while the client fixtures are in-tact for
- the duration of a test.
- """
-
- infrastructure = Infrastructure()
- infrastructure.aws_region = aws_region
- infrastructure.aws_clients = aws_clients
-
- resp = aws_clients.ec2.create_vpc(CidrBlock="172.30.0.0/24")
- vpc_id = resp["Vpc"]["VpcId"]
-
- resp = aws_clients.ec2.create_subnet(
- AvailabilityZone=f"{aws_region}a", CidrBlock="172.30.0.0/25", VpcId=vpc_id
- )
- subnet_id = resp["Subnet"]["SubnetId"]
-
- resp = aws_clients.ec2.create_security_group(
- Description="moto_test_sg_desc", GroupName="moto_test_sg", VpcId=vpc_id
- )
- sg_id = resp["GroupId"]
-
- resp = aws_clients.iam.create_role(RoleName="MotoTestRole", AssumeRolePolicyDocument="moto_test_policy")
- iam_arn = resp["Role"]["Arn"]
-
- compute_env_name = "moto_test_compute_env"
- resp = aws_clients.batch.create_compute_environment(
- computeEnvironmentName=compute_env_name,
- type="UNMANAGED",
- state="ENABLED",
- serviceRole=iam_arn,
- )
- compute_env_arn = resp["computeEnvironmentArn"]
-
- resp = aws_clients.batch.create_job_queue(
- jobQueueName=job_queue_name,
- state="ENABLED",
- priority=123,
- computeEnvironmentOrder=[{"order": 123, "computeEnvironment": compute_env_arn}],
- )
- assert resp["jobQueueName"] == job_queue_name
- assert resp["jobQueueArn"]
- job_queue_arn = resp["jobQueueArn"]
-
- resp = aws_clients.batch.register_job_definition(
- jobDefinitionName=job_definition_name,
- type="container",
- containerProperties={
- "image": "busybox",
- "vcpus": 1,
- "memory": 64,
- "command": ["sleep", "10"],
- },
- )
- assert resp["jobDefinitionName"] == job_definition_name
- assert resp["jobDefinitionArn"]
- job_definition_arn = resp["jobDefinitionArn"]
- assert resp["revision"]
- assert resp["jobDefinitionArn"].endswith(f"{resp['jobDefinitionName']}:{resp['revision']}")
-
- infrastructure.vpc_id = vpc_id
- infrastructure.subnet_id = subnet_id
- infrastructure.security_group_id = sg_id
- infrastructure.iam_arn = iam_arn
- infrastructure.compute_env_name = compute_env_name
- infrastructure.compute_env_arn = compute_env_arn
- infrastructure.job_queue_name = job_queue_name
- infrastructure.job_queue_arn = job_queue_arn
- infrastructure.job_definition_name = job_definition_name
- infrastructure.job_definition_arn = job_definition_arn
- return infrastructure
-
-
-#
-# pytest tests
-#
+@mock_batch
+@pytest.fixture
+def patch_hook(monkeypatch, aws_region):
+ """Patch hook object by dummy boto3 Batch client."""
+ batch_client = boto3.client("batch", region_name=aws_region)
+ monkeypatch.setattr(BatchWaitersHook, "conn", batch_client)
+ yield
def test_batch_waiters(aws_region):
@@ -220,108 +58,9 @@ def test_batch_waiters(aws_region):
assert isinstance(batch_waiters, BatchWaitersHook)
-@mock_batch
-@mock_ec2
-@mock_ecs
-@mock_iam
-@mock_logs
-@pytest.mark.xfail(condition=True, reason="Inexplicable timeout issue when running this test. See PR 11020")
-def test_batch_job_waiting(aws_clients, aws_region, job_queue_name, job_definition_name):
- """
- Submit Batch jobs and wait for various job status indicators or errors.
- These Batch job waiter tests can be slow and might need to be marked
- for conditional skips if they take too long, although it seems to
- run in about 30 sec to a minute.
-
- .. note::
- These tests have no control over how moto transitions the Batch job status.
-
- .. seealso::
- - https://github.com/boto/botocore/blob/develop/botocore/waiter.py
- - https://github.com/spulec/moto/blob/master/moto/batch/models.py#L360
- - https://github.com/spulec/moto/blob/master/tests/test_batch/test_batch.py
- """
-
- aws_resources = batch_infrastructure(aws_clients, aws_region, job_queue_name, job_definition_name)
- batch_waiters = BatchWaitersHook(region_name=aws_resources.aws_region)
-
- job_exists_waiter = batch_waiters.get_waiter("JobExists")
- assert job_exists_waiter
- assert isinstance(job_exists_waiter, botocore.waiter.Waiter)
- assert job_exists_waiter.__class__.__name__ == "Batch.Waiter.JobExists"
-
- job_running_waiter = batch_waiters.get_waiter("JobRunning")
- assert job_running_waiter
- assert isinstance(job_running_waiter, botocore.waiter.Waiter)
- assert job_running_waiter.__class__.__name__ == "Batch.Waiter.JobRunning"
-
- job_complete_waiter = batch_waiters.get_waiter("JobComplete")
- assert job_complete_waiter
- assert isinstance(job_complete_waiter, botocore.waiter.Waiter)
- assert job_complete_waiter.__class__.__name__ == "Batch.Waiter.JobComplete"
-
- # test waiting on a jobId that does not exist (this throws immediately)
- with pytest.raises(botocore.exceptions.WaiterError) as ctx:
- job_exists_waiter.config.delay = 0.2
- job_exists_waiter.config.max_attempts = 2
- job_exists_waiter.wait(jobs=["missing-job"])
- assert isinstance(ctx.value, botocore.exceptions.WaiterError)
- assert "Waiter JobExists failed" in str(ctx.value)
-
- # Submit a job and wait for various job status indicators;
- # moto transitions the Batch job status automatically.
-
- job_name = "test-job"
- job_cmd = ['/bin/sh -c "for a in `seq 1 2`; do echo Hello World; sleep 0.25; done"']
-
- job_response = aws_clients.batch.submit_job(
- jobName=job_name,
- jobQueue=aws_resources.job_queue_arn,
- jobDefinition=aws_resources.job_definition_arn,
- containerOverrides={"command": job_cmd},
- )
- job_id = job_response["jobId"]
-
- job_description = aws_clients.batch.describe_jobs(jobs=[job_id])
- job_status = [job for job in job_description["jobs"] if job["jobId"] == job_id][0]["status"]
- assert job_status == "PENDING"
-
- # this should not raise a WaiterError and note there is no 'state' maintained in
- # the waiter that can be checked after calling wait() and it has no return value;
- # see https://github.com/boto/botocore/blob/develop/botocore/waiter.py#L287
- job_exists_waiter.config.delay = 0.2
- job_exists_waiter.config.max_attempts = 20
- job_exists_waiter.wait(jobs=[job_id])
-
- # test waiting for job completion with too few attempts (possibly before job is running)
- job_complete_waiter.config.delay = 0.1
- job_complete_waiter.config.max_attempts = 1
- with pytest.raises(botocore.exceptions.WaiterError) as ctx:
- job_complete_waiter.wait(jobs=[job_id])
- assert isinstance(ctx.value, botocore.exceptions.WaiterError)
- assert "Waiter JobComplete failed: Max attempts exceeded" in str(ctx.value)
-
- # wait for job to be running (or complete)
- job_running_waiter.config.delay = 0.25 # sec delays between status checks
- job_running_waiter.config.max_attempts = 50
- job_running_waiter.wait(jobs=[job_id])
-
- # wait for job completion
- job_complete_waiter.config.delay = 0.25
- job_complete_waiter.config.max_attempts = 50
- job_complete_waiter.wait(jobs=[job_id])
-
- job_description = aws_clients.batch.describe_jobs(jobs=[job_id])
- job_status = [job for job in job_description["jobs"] if job["jobId"] == job_id][0]["status"]
- assert job_status == "SUCCEEDED"
-
-
class TestBatchWaiters:
- @mock.patch.dict("os.environ", AWS_DEFAULT_REGION=AWS_REGION)
- @mock.patch.dict("os.environ", AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID)
- @mock.patch.dict("os.environ", AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY)
- @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
- def setup_method(self, method, get_client_type_mock):
+ @pytest.fixture(autouse=True)
+ def setup_tests(self, patch_hook):
self.job_id = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
self.region_name = AWS_REGION
@@ -329,10 +68,6 @@ class TestBatchWaiters:
assert self.batch_waiters.aws_conn_id == "aws_default"
assert self.batch_waiters.region_name == self.region_name
- # init the mock client
- self.client_mock = self.batch_waiters.client
- get_client_type_mock.assert_called_once_with(region_name=self.region_name)
-
# don't pause in these unit tests
self.mock_delay = mock.Mock(return_value=None)
self.batch_waiters.delay = self.mock_delay
@@ -362,7 +97,7 @@ class TestBatchWaiters:
def test_waiter_model(self):
model = self.batch_waiters.waiter_model
- assert isinstance(model, botocore.waiter.WaiterModel)
+ assert isinstance(model, WaiterModel)
# test some of the default config
assert model.version == 2
@@ -376,7 +111,7 @@ class TestBatchWaiters:
# test some default waiter properties
waiter = model.get_waiter("JobExists")
- assert isinstance(waiter, botocore.waiter.SingleWaiterConfig)
+ assert isinstance(waiter, SingleWaiterConfig)
assert waiter.max_attempts == 100
waiter.max_attempts = 200
assert waiter.max_attempts == 200
@@ -417,7 +152,7 @@ class TestBatchWaiters:
with mock.patch.object(self.batch_waiters, "get_waiter") as get_waiter:
mock_waiter = get_waiter.return_value
- mock_waiter.wait.side_effect = botocore.exceptions.ClientError(
+ mock_waiter.wait.side_effect = ClientError(
error_response={"Error": {"Code": "TooManyRequestsException"}},
operation_name="get job description",
)
@@ -435,7 +170,7 @@ class TestBatchWaiters:
with mock.patch.object(self.batch_waiters, "get_waiter") as get_waiter:
mock_waiter = get_waiter.return_value
- mock_waiter.wait.side_effect = botocore.exceptions.WaiterError(
+ mock_waiter.wait.side_effect = WaiterError(
name="JobExists", reason="unit test error", last_response={}
)
with pytest.raises(AirflowException):
@@ -444,3 +179,135 @@ class TestBatchWaiters:
assert get_waiter.call_args_list == [mock.call("JobExists")]
mock_waiter.wait.assert_called_with(jobs=[self.job_id])
assert mock_waiter.wait.call_count == 1
+
+
+class TestBatchJobWaiters:
+ """Test default waiters."""
+
+ @pytest.fixture(autouse=True)
+ def setup_tests(self, patch_hook):
+ """Mock `describe_jobs` method before each test run."""
+ self.batch_waiters = BatchWaitersHook(region_name=AWS_REGION)
+ self.client = self.batch_waiters.client
+
+ with mock.patch.object(self.client, "describe_jobs") as m:
+ self.mock_describe_jobs = m
+ yield
+
+ @staticmethod
+ def describe_jobs_response(job_id: str = "mock-job-id", status: str = INTERMEDIATE_STATES[0]):
+ """
+ Helper function for generate minimal DescribeJobs response for single job.
+ https://docs.aws.amazon.com/batch/latest/APIReference/API_DescribeJobs.html
+ """
+ assert job_id
+ assert status in ALL_STATES
+
+ return {"jobs": [{"jobId": job_id, "status": status}]}
+
+ @pytest.mark.parametrize("status", ALL_STATES)
+ def test_job_exists_waiter_exists(self, status: str):
+ """Test `JobExists` when response return dictionary regardless state."""
+ self.mock_describe_jobs.return_value = self.describe_jobs_response(
+ job_id="job-exist-success", status=status
+ )
+ job_exists_waiter = self.batch_waiters.get_waiter("JobExists")
+ job_exists_waiter.config.delay = 0.01
+ job_exists_waiter.config.max_attempts = 5
+ job_exists_waiter.wait(jobs=["job-exist-success"])
+ assert self.mock_describe_jobs.called
+
+ def test_job_exists_waiter_missing(self):
+ """Test `JobExists` waiter when response return empty dictionary."""
+ self.mock_describe_jobs.return_value = {"jobs": []}
+
+ job_exists_waiter = self.batch_waiters.get_waiter("JobExists")
+ job_exists_waiter.config.delay = 0.01
+ job_exists_waiter.config.max_attempts = 20
+ with pytest.raises(WaiterError, match="Waiter JobExists failed"):
+ job_exists_waiter.wait(jobs=["job-missing"])
+ assert self.mock_describe_jobs.called
+
+ @pytest.mark.parametrize("status", [RUNNING_STATE, SUCCESS_STATE, FAILED_STATE])
+ def test_job_running_waiter_change_to_waited_state(self, status):
+ """Test `JobRunning` waiter reach expected state."""
+ job_id = "job-running"
+ self.mock_describe_jobs.side_effect = [
+ # Emulate change job status before one of expected states.
+ # SUBMITTED -> PENDING -> RUNNABLE -> STARTING
+ *itertools.chain(
+ *[
+ itertools.repeat(self.describe_jobs_response(job_id=job_id, status=inter_status), 3)
+ for inter_status in INTERMEDIATE_STATES
+ ]
+ ),
+ # Expected status
+ self.describe_jobs_response(job_id=job_id, status=status),
+ RuntimeError("This should not raise"),
+ ]
+
+ job_running_waiter = self.batch_waiters.get_waiter("JobRunning")
+ job_running_waiter.config.delay = 0.01
+ job_running_waiter.config.max_attempts = 20
+ job_running_waiter.wait(jobs=[job_id])
+ assert self.mock_describe_jobs.called
+
+ @pytest.mark.parametrize("status", INTERMEDIATE_STATES)
+ def test_job_running_waiter_max_attempt_exceeded(self, status):
+ """Test `JobRunning` waiter run out of attempts."""
+ job_id = "job-running-inf"
+ self.mock_describe_jobs.side_effect = itertools.repeat(
+ self.describe_jobs_response(job_id=job_id, status=status)
+ )
+ job_running_waiter = self.batch_waiters.get_waiter("JobRunning")
+ job_running_waiter.config.delay = 0.01
+ job_running_waiter.config.max_attempts = 20
+ with pytest.raises(WaiterError, match="Waiter JobRunning failed: Max attempts exceeded"):
+ job_running_waiter.wait(jobs=[job_id])
+ assert self.mock_describe_jobs.called
+
+ def test_job_complete_waiter_succeeded(self):
+ """Test `JobComplete` waiter reach `SUCCEEDED` status."""
+ job_id = "job-succeeded"
+ self.mock_describe_jobs.side_effect = [
+ *itertools.repeat(self.describe_jobs_response(job_id=job_id, status=RUNNING_STATE), 10),
+ self.describe_jobs_response(job_id=job_id, status=SUCCESS_STATE),
+ RuntimeError("This should not raise"),
+ ]
+
+ job_complete_waiter = self.batch_waiters.get_waiter("JobComplete")
+ job_complete_waiter.config.delay = 0.01
+ job_complete_waiter.config.max_attempts = 20
+ job_complete_waiter.wait(jobs=[job_id])
+ assert self.mock_describe_jobs.called
+
+ def test_job_complete_waiter_failed(self):
+ """Test `JobComplete` waiter reach `FAILED` status."""
+ job_id = "job-failed"
+ self.mock_describe_jobs.side_effect = [
+ *itertools.repeat(self.describe_jobs_response(job_id=job_id, status=RUNNING_STATE), 10),
+ self.describe_jobs_response(job_id=job_id, status=FAILED_STATE),
+ RuntimeError("This should not raise"),
+ ]
+
+ job_complete_waiter = self.batch_waiters.get_waiter("JobComplete")
+ job_complete_waiter.config.delay = 0.01
+ job_complete_waiter.config.max_attempts = 20
+ with pytest.raises(
+ WaiterError, match="Waiter JobComplete failed: Waiter encountered a terminal failure state"
+ ):
+ job_complete_waiter.wait(jobs=[job_id])
+ assert self.mock_describe_jobs.called
+
+ def test_job_complete_waiter_max_attempt_exceeded(self):
+ """Test `JobComplete` waiter run out of attempts."""
+ job_id = "job-running-inf"
+ self.mock_describe_jobs.side_effect = itertools.repeat(
+ self.describe_jobs_response(job_id=job_id, status=RUNNING_STATE)
+ )
+ job_running_waiter = self.batch_waiters.get_waiter("JobComplete")
+ job_running_waiter.config.delay = 0.01
+ job_running_waiter.config.max_attempts = 20
+ with pytest.raises(WaiterError, match="Waiter JobComplete failed: Max attempts exceeded"):
+ job_running_waiter.wait(jobs=[job_id])
+ assert self.mock_describe_jobs.called