You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by el...@apache.org on 2023/01/06 16:30:33 UTC

[airflow] branch main updated: Refactor BatchWaitersHook tests (#28758)

This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e9e052bdc7 Refactor BatchWaitersHook tests (#28758)
e9e052bdc7 is described below

commit e9e052bdc782a938606b31ee485fc246a37e27e3
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Fri Jan 6 20:30:16 2023 +0400

    Refactor BatchWaitersHook tests (#28758)
---
 .../amazon/aws/hooks/test_batch_waiters.py         | 441 +++++++--------------
 1 file changed, 154 insertions(+), 287 deletions(-)

diff --git a/tests/providers/amazon/aws/hooks/test_batch_waiters.py b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
index c245ac4da4..285784d184 100644
--- a/tests/providers/amazon/aws/hooks/test_batch_waiters.py
+++ b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
@@ -15,38 +15,27 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""
-Test BatchWaiters
-
-This test suite uses a large suite of moto mocks for the
-AWS Batch infrastructure.  These infrastructure mocks are
-derived from the moto test suite for testing the Batch client.
-
-.. seealso::
-
-    - https://github.com/spulec/moto/pull/1197/files
-    - https://github.com/spulec/moto/blob/master/tests/test_batch/test_batch.py
-"""
 from __future__ import annotations
 
 import inspect
-from typing import NamedTuple
+import itertools
 from unittest import mock
 
 import boto3
-import botocore.client
-import botocore.exceptions
-import botocore.waiter
 import pytest
-from moto import mock_batch, mock_ec2, mock_ecs, mock_iam, mock_logs
+from botocore.exceptions import ClientError, WaiterError
+from botocore.waiter import SingleWaiterConfig, WaiterModel
+from moto import mock_batch
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.batch_waiters import BatchWaitersHook
 
-# Use dummy AWS credentials
+INTERMEDIATE_STATES = ("SUBMITTED", "PENDING", "RUNNABLE", "STARTING")  # pre-RUNNING statuses, in order
+RUNNING_STATE = "RUNNING"
+SUCCESS_STATE = "SUCCEEDED"
+FAILED_STATE = "FAILED"
+ALL_STATES = {*INTERMEDIATE_STATES, RUNNING_STATE, SUCCESS_STATE, FAILED_STATE}  # every status (unordered)
 AWS_REGION = "eu-west-1"
-AWS_ACCESS_KEY_ID = "airflow_dummy_key"
-AWS_SECRET_ACCESS_KEY = "airflow_dummy_secret"
 
 
 @pytest.fixture(scope="module")
@@ -54,164 +43,13 @@ def aws_region():
     return AWS_REGION
 
 
-@pytest.fixture(scope="module")
-def job_queue_name():
-    return "moto_test_job_queue"
-
-
-@pytest.fixture(scope="module")
-def job_definition_name():
-    return "moto_test_job_definition"
-
-
-#
-# AWS Clients
-#
-
-
-class AwsClients(NamedTuple):
-    batch: botocore.client.Batch
-    ec2: botocore.client.EC2
-    ecs: botocore.client.ECS
-    iam: botocore.client.IAM
-    log: botocore.client.CloudWatchLogs
-
-
-@pytest.fixture(scope="module")
-def batch_client(aws_region):
-    with mock_batch():
-        yield boto3.client("batch", region_name=aws_region)
-
-
-@pytest.fixture(scope="module")
-def ec2_client(aws_region):
-    with mock_ec2():
-        yield boto3.client("ec2", region_name=aws_region)
-
-
-@pytest.fixture(scope="module")
-def ecs_client(aws_region):
-    with mock_ecs():
-        yield boto3.client("ecs", region_name=aws_region)
-
-
-@pytest.fixture(scope="module")
-def iam_client(aws_region):
-    with mock_iam():
-        yield boto3.client("iam", region_name=aws_region)
-
-
-@pytest.fixture(scope="module")
-def logs_client(aws_region):
-    with mock_logs():
-        yield boto3.client("logs", region_name=aws_region)
-
-
-@pytest.fixture(scope="module")
-def aws_clients(batch_client, ec2_client, ecs_client, iam_client, logs_client):
-    return AwsClients(batch=batch_client, ec2=ec2_client, ecs=ecs_client, iam=iam_client, log=logs_client)
-
-
-#
-# Batch Infrastructure
-#
-
-
-class Infrastructure:
-    aws_region: str
-    aws_clients: AwsClients
-    vpc_id: str | None = None
-    subnet_id: str | None = None
-    security_group_id: str | None = None
-    iam_arn: str | None = None
-    compute_env_name: str | None = None
-    compute_env_arn: str | None = None
-    job_queue_name: str | None = None
-    job_queue_arn: str | None = None
-    job_definition_name: str | None = None
-    job_definition_arn: str | None = None
-
-
-def batch_infrastructure(
-    aws_clients: AwsClients, aws_region: str, job_queue_name: str, job_definition_name: str
-) -> Infrastructure:
-    """
-    This function is not a fixture so that tests can pass the AWS clients to it and then
-    continue to use the infrastructure created by it while the client fixtures are in-tact for
-    the duration of a test.
-    """
-
-    infrastructure = Infrastructure()
-    infrastructure.aws_region = aws_region
-    infrastructure.aws_clients = aws_clients
-
-    resp = aws_clients.ec2.create_vpc(CidrBlock="172.30.0.0/24")
-    vpc_id = resp["Vpc"]["VpcId"]
-
-    resp = aws_clients.ec2.create_subnet(
-        AvailabilityZone=f"{aws_region}a", CidrBlock="172.30.0.0/25", VpcId=vpc_id
-    )
-    subnet_id = resp["Subnet"]["SubnetId"]
-
-    resp = aws_clients.ec2.create_security_group(
-        Description="moto_test_sg_desc", GroupName="moto_test_sg", VpcId=vpc_id
-    )
-    sg_id = resp["GroupId"]
-
-    resp = aws_clients.iam.create_role(RoleName="MotoTestRole", AssumeRolePolicyDocument="moto_test_policy")
-    iam_arn = resp["Role"]["Arn"]
-
-    compute_env_name = "moto_test_compute_env"
-    resp = aws_clients.batch.create_compute_environment(
-        computeEnvironmentName=compute_env_name,
-        type="UNMANAGED",
-        state="ENABLED",
-        serviceRole=iam_arn,
-    )
-    compute_env_arn = resp["computeEnvironmentArn"]
-
-    resp = aws_clients.batch.create_job_queue(
-        jobQueueName=job_queue_name,
-        state="ENABLED",
-        priority=123,
-        computeEnvironmentOrder=[{"order": 123, "computeEnvironment": compute_env_arn}],
-    )
-    assert resp["jobQueueName"] == job_queue_name
-    assert resp["jobQueueArn"]
-    job_queue_arn = resp["jobQueueArn"]
-
-    resp = aws_clients.batch.register_job_definition(
-        jobDefinitionName=job_definition_name,
-        type="container",
-        containerProperties={
-            "image": "busybox",
-            "vcpus": 1,
-            "memory": 64,
-            "command": ["sleep", "10"],
-        },
-    )
-    assert resp["jobDefinitionName"] == job_definition_name
-    assert resp["jobDefinitionArn"]
-    job_definition_arn = resp["jobDefinitionArn"]
-    assert resp["revision"]
-    assert resp["jobDefinitionArn"].endswith(f"{resp['jobDefinitionName']}:{resp['revision']}")
-
-    infrastructure.vpc_id = vpc_id
-    infrastructure.subnet_id = subnet_id
-    infrastructure.security_group_id = sg_id
-    infrastructure.iam_arn = iam_arn
-    infrastructure.compute_env_name = compute_env_name
-    infrastructure.compute_env_arn = compute_env_arn
-    infrastructure.job_queue_name = job_queue_name
-    infrastructure.job_queue_arn = job_queue_arn
-    infrastructure.job_definition_name = job_definition_name
-    infrastructure.job_definition_arn = job_definition_arn
-    return infrastructure
-
-
-#
-# pytest tests
-#
+@pytest.fixture
+def patch_hook(monkeypatch, aws_region):
+    """Patch ``BatchWaitersHook.conn`` with a boto3 Batch client created under the moto mock."""
+    with mock_batch():  # context manager keeps the mock active until the test using the fixture ends
+        batch_client = boto3.client("batch", region_name=aws_region)
+        monkeypatch.setattr(BatchWaitersHook, "conn", batch_client)
+        yield
 
 
 def test_batch_waiters(aws_region):
@@ -220,108 +58,9 @@ def test_batch_waiters(aws_region):
     assert isinstance(batch_waiters, BatchWaitersHook)
 
 
-@mock_batch
-@mock_ec2
-@mock_ecs
-@mock_iam
-@mock_logs
-@pytest.mark.xfail(condition=True, reason="Inexplicable timeout issue when running this test. See PR 11020")
-def test_batch_job_waiting(aws_clients, aws_region, job_queue_name, job_definition_name):
-    """
-    Submit Batch jobs and wait for various job status indicators or errors.
-    These Batch job waiter tests can be slow and might need to be marked
-    for conditional skips if they take too long, although it seems to
-    run in about 30 sec to a minute.
-
-    .. note::
-        These tests have no control over how moto transitions the Batch job status.
-
-    .. seealso::
-        - https://github.com/boto/botocore/blob/develop/botocore/waiter.py
-        - https://github.com/spulec/moto/blob/master/moto/batch/models.py#L360
-        - https://github.com/spulec/moto/blob/master/tests/test_batch/test_batch.py
-    """
-
-    aws_resources = batch_infrastructure(aws_clients, aws_region, job_queue_name, job_definition_name)
-    batch_waiters = BatchWaitersHook(region_name=aws_resources.aws_region)
-
-    job_exists_waiter = batch_waiters.get_waiter("JobExists")
-    assert job_exists_waiter
-    assert isinstance(job_exists_waiter, botocore.waiter.Waiter)
-    assert job_exists_waiter.__class__.__name__ == "Batch.Waiter.JobExists"
-
-    job_running_waiter = batch_waiters.get_waiter("JobRunning")
-    assert job_running_waiter
-    assert isinstance(job_running_waiter, botocore.waiter.Waiter)
-    assert job_running_waiter.__class__.__name__ == "Batch.Waiter.JobRunning"
-
-    job_complete_waiter = batch_waiters.get_waiter("JobComplete")
-    assert job_complete_waiter
-    assert isinstance(job_complete_waiter, botocore.waiter.Waiter)
-    assert job_complete_waiter.__class__.__name__ == "Batch.Waiter.JobComplete"
-
-    # test waiting on a jobId that does not exist (this throws immediately)
-    with pytest.raises(botocore.exceptions.WaiterError) as ctx:
-        job_exists_waiter.config.delay = 0.2
-        job_exists_waiter.config.max_attempts = 2
-        job_exists_waiter.wait(jobs=["missing-job"])
-    assert isinstance(ctx.value, botocore.exceptions.WaiterError)
-    assert "Waiter JobExists failed" in str(ctx.value)
-
-    # Submit a job and wait for various job status indicators;
-    # moto transitions the Batch job status automatically.
-
-    job_name = "test-job"
-    job_cmd = ['/bin/sh -c "for a in `seq 1 2`; do echo Hello World; sleep 0.25; done"']
-
-    job_response = aws_clients.batch.submit_job(
-        jobName=job_name,
-        jobQueue=aws_resources.job_queue_arn,
-        jobDefinition=aws_resources.job_definition_arn,
-        containerOverrides={"command": job_cmd},
-    )
-    job_id = job_response["jobId"]
-
-    job_description = aws_clients.batch.describe_jobs(jobs=[job_id])
-    job_status = [job for job in job_description["jobs"] if job["jobId"] == job_id][0]["status"]
-    assert job_status == "PENDING"
-
-    # this should not raise a WaiterError and note there is no 'state' maintained in
-    # the waiter that can be checked after calling wait() and it has no return value;
-    # see https://github.com/boto/botocore/blob/develop/botocore/waiter.py#L287
-    job_exists_waiter.config.delay = 0.2
-    job_exists_waiter.config.max_attempts = 20
-    job_exists_waiter.wait(jobs=[job_id])
-
-    # test waiting for job completion with too few attempts (possibly before job is running)
-    job_complete_waiter.config.delay = 0.1
-    job_complete_waiter.config.max_attempts = 1
-    with pytest.raises(botocore.exceptions.WaiterError) as ctx:
-        job_complete_waiter.wait(jobs=[job_id])
-    assert isinstance(ctx.value, botocore.exceptions.WaiterError)
-    assert "Waiter JobComplete failed: Max attempts exceeded" in str(ctx.value)
-
-    # wait for job to be running (or complete)
-    job_running_waiter.config.delay = 0.25  # sec delays between status checks
-    job_running_waiter.config.max_attempts = 50
-    job_running_waiter.wait(jobs=[job_id])
-
-    # wait for job completion
-    job_complete_waiter.config.delay = 0.25
-    job_complete_waiter.config.max_attempts = 50
-    job_complete_waiter.wait(jobs=[job_id])
-
-    job_description = aws_clients.batch.describe_jobs(jobs=[job_id])
-    job_status = [job for job in job_description["jobs"] if job["jobId"] == job_id][0]["status"]
-    assert job_status == "SUCCEEDED"
-
-
 class TestBatchWaiters:
-    @mock.patch.dict("os.environ", AWS_DEFAULT_REGION=AWS_REGION)
-    @mock.patch.dict("os.environ", AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID)
-    @mock.patch.dict("os.environ", AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY)
-    @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
-    def setup_method(self, method, get_client_type_mock):
+    @pytest.fixture(autouse=True)
+    def setup_tests(self, patch_hook):
         self.job_id = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
         self.region_name = AWS_REGION
 
@@ -329,10 +68,6 @@ class TestBatchWaiters:
         assert self.batch_waiters.aws_conn_id == "aws_default"
         assert self.batch_waiters.region_name == self.region_name
 
-        # init the mock client
-        self.client_mock = self.batch_waiters.client
-        get_client_type_mock.assert_called_once_with(region_name=self.region_name)
-
         # don't pause in these unit tests
         self.mock_delay = mock.Mock(return_value=None)
         self.batch_waiters.delay = self.mock_delay
@@ -362,7 +97,7 @@ class TestBatchWaiters:
 
     def test_waiter_model(self):
         model = self.batch_waiters.waiter_model
-        assert isinstance(model, botocore.waiter.WaiterModel)
+        assert isinstance(model, WaiterModel)
 
         # test some of the default config
         assert model.version == 2
@@ -376,7 +111,7 @@ class TestBatchWaiters:
 
         # test some default waiter properties
         waiter = model.get_waiter("JobExists")
-        assert isinstance(waiter, botocore.waiter.SingleWaiterConfig)
+        assert isinstance(waiter, SingleWaiterConfig)
         assert waiter.max_attempts == 100
         waiter.max_attempts = 200
         assert waiter.max_attempts == 200
@@ -417,7 +152,7 @@ class TestBatchWaiters:
 
         with mock.patch.object(self.batch_waiters, "get_waiter") as get_waiter:
             mock_waiter = get_waiter.return_value
-            mock_waiter.wait.side_effect = botocore.exceptions.ClientError(
+            mock_waiter.wait.side_effect = ClientError(
                 error_response={"Error": {"Code": "TooManyRequestsException"}},
                 operation_name="get job description",
             )
@@ -435,7 +170,7 @@ class TestBatchWaiters:
 
         with mock.patch.object(self.batch_waiters, "get_waiter") as get_waiter:
             mock_waiter = get_waiter.return_value
-            mock_waiter.wait.side_effect = botocore.exceptions.WaiterError(
+            mock_waiter.wait.side_effect = WaiterError(
                 name="JobExists", reason="unit test error", last_response={}
             )
             with pytest.raises(AirflowException):
@@ -444,3 +179,135 @@ class TestBatchWaiters:
             assert get_waiter.call_args_list == [mock.call("JobExists")]
             mock_waiter.wait.assert_called_with(jobs=[self.job_id])
             assert mock_waiter.wait.call_count == 1
+
+
+class TestBatchJobWaiters:
+    """Test the default `JobExists`, `JobRunning` and `JobComplete` waiters."""
+
+    @pytest.fixture(autouse=True)
+    def setup_tests(self, patch_hook):
+        """Create the hook and mock its client's `describe_jobs` method for each test."""
+        self.batch_waiters = BatchWaitersHook(region_name=AWS_REGION)
+        self.client = self.batch_waiters.client
+
+        with mock.patch.object(self.client, "describe_jobs") as m:
+            self.mock_describe_jobs = m
+            yield
+
+    @staticmethod
+    def describe_jobs_response(job_id: str = "mock-job-id", status: str = INTERMEDIATE_STATES[0]):
+        """
+        Helper function to generate a minimal DescribeJobs response for a single job.
+        https://docs.aws.amazon.com/batch/latest/APIReference/API_DescribeJobs.html
+        """
+        assert job_id
+        assert status in ALL_STATES
+
+        return {"jobs": [{"jobId": job_id, "status": status}]}
+
+    @pytest.mark.parametrize("status", sorted(ALL_STATES))  # sorted: a set has no stable order across runs
+    def test_job_exists_waiter_exists(self, status: str):
+        """Test `JobExists` waiter succeeds when the job is returned, regardless of its state."""
+        self.mock_describe_jobs.return_value = self.describe_jobs_response(
+            job_id="job-exist-success", status=status
+        )
+        job_exists_waiter = self.batch_waiters.get_waiter("JobExists")
+        job_exists_waiter.config.delay = 0.01
+        job_exists_waiter.config.max_attempts = 5
+        job_exists_waiter.wait(jobs=["job-exist-success"])
+        assert self.mock_describe_jobs.called
+
+    def test_job_exists_waiter_missing(self):
+        """Test `JobExists` waiter fails when the response contains no jobs."""
+        self.mock_describe_jobs.return_value = {"jobs": []}
+
+        job_exists_waiter = self.batch_waiters.get_waiter("JobExists")
+        job_exists_waiter.config.delay = 0.01
+        job_exists_waiter.config.max_attempts = 20
+        with pytest.raises(WaiterError, match="Waiter JobExists failed"):
+            job_exists_waiter.wait(jobs=["job-missing"])
+        assert self.mock_describe_jobs.called
+
+    @pytest.mark.parametrize("status", [RUNNING_STATE, SUCCESS_STATE, FAILED_STATE])
+    def test_job_running_waiter_change_to_waited_state(self, status):
+        """Test `JobRunning` waiter succeeds once the job reaches an expected state."""
+        job_id = "job-running"
+        self.mock_describe_jobs.side_effect = [
+            # Emulate the job passing through each intermediate status (3 polls each) first:
+            # SUBMITTED -> PENDING -> RUNNABLE -> STARTING
+            *itertools.chain(
+                *[
+                    itertools.repeat(self.describe_jobs_response(job_id=job_id, status=inter_status), 3)
+                    for inter_status in INTERMEDIATE_STATES
+                ]
+            ),
+            # Expected status
+            self.describe_jobs_response(job_id=job_id, status=status),
+            RuntimeError("This should not raise"),
+        ]
+
+        job_running_waiter = self.batch_waiters.get_waiter("JobRunning")
+        job_running_waiter.config.delay = 0.01
+        job_running_waiter.config.max_attempts = 20
+        job_running_waiter.wait(jobs=[job_id])
+        assert self.mock_describe_jobs.called
+
+    @pytest.mark.parametrize("status", INTERMEDIATE_STATES)
+    def test_job_running_waiter_max_attempt_exceeded(self, status):
+        """Test `JobRunning` waiter runs out of attempts."""
+        job_id = "job-running-inf"
+        self.mock_describe_jobs.side_effect = itertools.repeat(
+            self.describe_jobs_response(job_id=job_id, status=status)
+        )
+        job_running_waiter = self.batch_waiters.get_waiter("JobRunning")
+        job_running_waiter.config.delay = 0.01
+        job_running_waiter.config.max_attempts = 20
+        with pytest.raises(WaiterError, match="Waiter JobRunning failed: Max attempts exceeded"):
+            job_running_waiter.wait(jobs=[job_id])
+        assert self.mock_describe_jobs.called
+
+    def test_job_complete_waiter_succeeded(self):
+        """Test `JobComplete` waiter succeeds when the job reaches `SUCCEEDED` status."""
+        job_id = "job-succeeded"
+        self.mock_describe_jobs.side_effect = [
+            *itertools.repeat(self.describe_jobs_response(job_id=job_id, status=RUNNING_STATE), 10),
+            self.describe_jobs_response(job_id=job_id, status=SUCCESS_STATE),
+            RuntimeError("This should not raise"),
+        ]
+
+        job_complete_waiter = self.batch_waiters.get_waiter("JobComplete")
+        job_complete_waiter.config.delay = 0.01
+        job_complete_waiter.config.max_attempts = 20
+        job_complete_waiter.wait(jobs=[job_id])
+        assert self.mock_describe_jobs.called
+
+    def test_job_complete_waiter_failed(self):
+        """Test `JobComplete` waiter fails when the job reaches `FAILED` status."""
+        job_id = "job-failed"
+        self.mock_describe_jobs.side_effect = [
+            *itertools.repeat(self.describe_jobs_response(job_id=job_id, status=RUNNING_STATE), 10),
+            self.describe_jobs_response(job_id=job_id, status=FAILED_STATE),
+            RuntimeError("This should not raise"),
+        ]
+
+        job_complete_waiter = self.batch_waiters.get_waiter("JobComplete")
+        job_complete_waiter.config.delay = 0.01
+        job_complete_waiter.config.max_attempts = 20
+        with pytest.raises(
+            WaiterError, match="Waiter JobComplete failed: Waiter encountered a terminal failure state"
+        ):
+            job_complete_waiter.wait(jobs=[job_id])
+        assert self.mock_describe_jobs.called
+
+    def test_job_complete_waiter_max_attempt_exceeded(self):
+        """Test `JobComplete` waiter runs out of attempts."""
+        job_id = "job-running-inf"
+        self.mock_describe_jobs.side_effect = itertools.repeat(
+            self.describe_jobs_response(job_id=job_id, status=RUNNING_STATE)
+        )
+        job_complete_waiter = self.batch_waiters.get_waiter("JobComplete")
+        job_complete_waiter.config.delay = 0.01
+        job_complete_waiter.config.max_attempts = 20
+        with pytest.raises(WaiterError, match="Waiter JobComplete failed: Max attempts exceeded"):
+            job_complete_waiter.wait(jobs=[job_id])
+        assert self.mock_describe_jobs.called