You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by on...@apache.org on 2023/03/02 22:48:15 UTC
[airflow] branch main updated: Fix Amazon ECS Enums (#29871)
This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 76d8aaa836 Fix Amazon ECS Enums (#29871)
76d8aaa836 is described below
commit 76d8aaa8362ba199d98680d71ccb3a800cbc4d38
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Fri Mar 3 02:48:05 2023 +0400
Fix Amazon ECS Enums (#29871)
* Add missing tests
* Apply suggestions from code review
Co-authored-by: Niko Oliveira <on...@amazon.com>
---
airflow/providers/amazon/aws/hooks/ecs.py | 23 +-
airflow/providers/amazon/aws/sensors/ecs.py | 2 +-
tests/providers/amazon/aws/sensors/test_ecs.py | 261 +++++++++++++++++++++
.../amazon/aws/waiters/test_custom_waiters.py | 19 +-
4 files changed, 297 insertions(+), 8 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/ecs.py b/airflow/providers/amazon/aws/hooks/ecs.py
index ccaea145dd..b0d26eb224 100644
--- a/airflow/providers/amazon/aws/hooks/ecs.py
+++ b/airflow/providers/amazon/aws/hooks/ecs.py
@@ -55,7 +55,24 @@ def should_retry_eni(exception: Exception):
return False
-class EcsClusterStates(str, Enum):
+class _StringCompareEnum(Enum):
+ """
+ Enum which can be compared with regular `str` and subclasses.
+
+ This class avoids multiple inheritance such as AwesomeEnum(str, Enum)
+ which does not work well with template_fields and Jinja templates.
+ """
+
+ def __eq__(self, other):
+ if isinstance(other, str):
+ return self.value == other
+ return super().__eq__(other)
+
+ def __hash__(self):
+ return super().__hash__() # Need to set because we redefine __eq__
+
+
+class EcsClusterStates(_StringCompareEnum):
"""Contains the possible State values of an ECS Cluster."""
ACTIVE = "ACTIVE"
@@ -65,7 +82,7 @@ class EcsClusterStates(str, Enum):
INACTIVE = "INACTIVE"
-class EcsTaskDefinitionStates(str, Enum):
+class EcsTaskDefinitionStates(_StringCompareEnum):
"""Contains the possible State values of an ECS Task Definition."""
ACTIVE = "ACTIVE"
@@ -73,7 +90,7 @@ class EcsTaskDefinitionStates(str, Enum):
DELETE_IN_PROGRESS = "DELETE_IN_PROGRESS"
-class EcsTaskStates(str, Enum):
+class EcsTaskStates(_StringCompareEnum):
"""Contains the possible State values of an ECS Task."""
PROVISIONING = "PROVISIONING"
diff --git a/airflow/providers/amazon/aws/sensors/ecs.py b/airflow/providers/amazon/aws/sensors/ecs.py
index c1151b8f9a..d3cfacbd41 100644
--- a/airflow/providers/amazon/aws/sensors/ecs.py
+++ b/airflow/providers/amazon/aws/sensors/ecs.py
@@ -20,8 +20,8 @@ from typing import TYPE_CHECKING, Sequence
import boto3
-from airflow import AirflowException
from airflow.compat.functools import cached_property
+from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.ecs import (
EcsClusterStates,
EcsHook,
diff --git a/tests/providers/amazon/aws/sensors/test_ecs.py b/tests/providers/amazon/aws/sensors/test_ecs.py
new file mode 100644
index 0000000000..66d8bc5c54
--- /dev/null
+++ b/tests/providers/amazon/aws/sensors/test_ecs.py
@@ -0,0 +1,261 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TypeVar
+from unittest import mock
+
+import boto3
+import pytest
+from slugify import slugify
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.sensors.ecs import (
+ DEFAULT_CONN_ID,
+ EcsBaseSensor,
+ EcsClusterStates,
+ EcsClusterStateSensor,
+ EcsHook,
+ EcsTaskDefinitionStates,
+ EcsTaskDefinitionStateSensor,
+ EcsTaskStates,
+ EcsTaskStateSensor,
+)
+from airflow.utils import timezone
+from airflow.utils.types import NOTSET
+
+_Operator = TypeVar("_Operator")
+TEST_CLUSTER_NAME = "fake-cluster"
+TEST_TASK_ARN = "arn:aws:ecs:us-east-1:012345678910:task/spam-egg"
+TEST_TASK_DEFINITION_ARN = "arn:aws:ecs:us-east-1:012345678910:task-definition/foo-bar:42"
+
+
+class EcsBaseTestCase:
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, monkeypatch, request, create_task_instance_of_operator):
+ self.dag_id = f"dag-{slugify(request.cls.__name__)}"
+ self.task_id = f"task-{slugify(request.node.name, max_length=40)}"
+ self.fake_client = boto3.client("ecs", region_name="eu-west-3")
+ monkeypatch.setattr(EcsHook, "conn", self.fake_client)
+ self.ti_maker = create_task_instance_of_operator
+
+ def create_rendered_task(self, operator_class: type[_Operator], **kwargs) -> _Operator:
+ """
+ Create operator from given class and render fields.
+
+ This might help to prevent unexpected behaviour in Jinja/task field serialisation.
+ """
+ return self.ti_maker(
+ operator_class,
+ dag_id=self.dag_id,
+ task_id=self.task_id,
+ execution_date=timezone.datetime(2021, 12, 21),
+ **kwargs,
+ ).render_templates()
+
+
+class TestEcsBaseSensor(EcsBaseTestCase):
+ @pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])
+ @pytest.mark.parametrize("region_name", [None, NOTSET, "ca-central-1"])
+ def test_initialise_operator(self, aws_conn_id, region_name):
+ """Test sensor initialisation."""
+ op_kw = {"aws_conn_id": aws_conn_id, "region": region_name}
+ op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
+ op = EcsBaseSensor(task_id="test_ecs_base", **op_kw)
+
+ assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else DEFAULT_CONN_ID)
+ assert op.region == (region_name if region_name is not NOTSET else None)
+
+ @pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])
+ @pytest.mark.parametrize("region_name", [None, NOTSET, "ca-central-1"])
+ def test_hook_and_client(self, aws_conn_id, region_name):
+ """Test initialisation of ``EcsHook`` and ``boto3.client``."""
+ op_kw = {"aws_conn_id": aws_conn_id, "region": region_name}
+ op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
+ op = EcsBaseSensor(task_id="test_ecs_base_hook_client", **op_kw)
+
+ hook = op.hook
+ assert op.hook is hook, "'hook' property should be cached."
+ assert isinstance(op.hook, EcsHook)
+
+ client = op.client
+ assert op.client is client, "'client' property should be cached."
+ assert client is self.fake_client
+
+
+class TestEcsClusterStateSensor(EcsBaseTestCase):
+ @pytest.mark.parametrize(
+ "return_state, expected", [("ACTIVE", True), ("PROVISIONING", False), ("DEPROVISIONING", False)]
+ )
+ def test_default_values_poke(self, return_state, expected):
+ task = self.create_rendered_task(EcsClusterStateSensor, cluster_name=TEST_CLUSTER_NAME)
+ with mock.patch.object(task.hook, "get_cluster_state") as m:
+ m.return_value = return_state
+ assert task.poke({}) == expected
+ m.assert_called_once_with(cluster_name=TEST_CLUSTER_NAME)
+
+ @pytest.mark.parametrize("return_state", ["FAILED", "INACTIVE"])
+ def test_default_values_terminal_state(self, create_task_of_operator, return_state):
+ task = self.create_rendered_task(EcsClusterStateSensor, cluster_name=TEST_CLUSTER_NAME)
+ with mock.patch.object(task.hook, "get_cluster_state") as m:
+ m.return_value = return_state
+ with pytest.raises(AirflowException, match="Terminal state reached"):
+ task.poke({})
+ m.assert_called_once_with(cluster_name=TEST_CLUSTER_NAME)
+
+ @pytest.mark.parametrize(
+ "target_state, return_state, expected",
+ [
+ (EcsClusterStates.ACTIVE, "ACTIVE", True),
+ (EcsClusterStates.ACTIVE, "DEPROVISIONING", False),
+ (EcsClusterStates.DEPROVISIONING, "ACTIVE", False),
+ (EcsClusterStates.DEPROVISIONING, "DEPROVISIONING", True),
+ ],
+ )
+ def test_custom_values_poke(self, target_state, return_state, expected):
+ task = self.create_rendered_task(
+ EcsClusterStateSensor, cluster_name=TEST_CLUSTER_NAME, target_state=target_state
+ )
+ with mock.patch.object(task.hook, "get_cluster_state") as m:
+ m.return_value = return_state
+ assert task.poke({}) == expected
+ m.assert_called_once_with(cluster_name=TEST_CLUSTER_NAME)
+
+ @pytest.mark.parametrize(
+ "failure_states, return_state",
+ [
+ ({EcsClusterStates.ACTIVE}, "ACTIVE"),
+ ({EcsClusterStates.PROVISIONING, EcsClusterStates.DEPROVISIONING}, "DEPROVISIONING"),
+ ({EcsClusterStates.PROVISIONING, EcsClusterStates.DEPROVISIONING}, "PROVISIONING"),
+ ],
+ )
+ def test_custom_values_terminal_state(self, failure_states, return_state):
+ task = self.create_rendered_task(
+ EcsClusterStateSensor,
+ cluster_name=TEST_CLUSTER_NAME,
+ target_state=EcsClusterStates.FAILED,
+ failure_states=failure_states,
+ )
+ with mock.patch.object(task.hook, "get_cluster_state") as m:
+ m.return_value = return_state
+ with pytest.raises(AirflowException, match="Terminal state reached"):
+ task.poke({})
+ m.assert_called_once_with(cluster_name=TEST_CLUSTER_NAME)
+
+
+class TestEcsTaskDefinitionStateSensor(EcsBaseTestCase):
+ @pytest.mark.parametrize(
+ "return_state, expected", [("ACTIVE", True), ("INACTIVE", False), ("DELETE_IN_PROGRESS", False)]
+ )
+ def test_default_values_poke(self, return_state, expected):
+ task = self.create_rendered_task(
+ EcsTaskDefinitionStateSensor, task_definition=TEST_TASK_DEFINITION_ARN
+ )
+ with mock.patch.object(task.hook, "get_task_definition_state") as m:
+ m.return_value = return_state
+ assert task.poke({}) == expected
+ m.assert_called_once_with(task_definition=TEST_TASK_DEFINITION_ARN)
+
+ @pytest.mark.parametrize(
+ "target_state, return_state, expected",
+ [
+ (EcsTaskDefinitionStates.INACTIVE, "ACTIVE", False),
+ (EcsTaskDefinitionStates.INACTIVE, "INACTIVE", True),
+ (EcsTaskDefinitionStates.ACTIVE, "INACTIVE", False),
+ (EcsTaskDefinitionStates.ACTIVE, "ACTIVE", True),
+ ],
+ )
+ def test_custom_values_poke(self, create_task_of_operator, target_state, return_state, expected):
+ task = self.create_rendered_task(
+ EcsTaskDefinitionStateSensor, task_definition=TEST_TASK_DEFINITION_ARN, target_state=target_state
+ )
+ with mock.patch.object(task.hook, "get_task_definition_state") as m:
+ m.return_value = return_state
+ assert task.poke({}) == expected
+ m.assert_called_once_with(task_definition=TEST_TASK_DEFINITION_ARN)
+
+
+class TestEcsTaskStateSensor(EcsBaseTestCase):
+ @pytest.mark.parametrize(
+ "return_state, expected",
+ [
+ ("PROVISIONING", False),
+ ("PENDING", False),
+ ("ACTIVATING", False),
+ ("RUNNING", True),
+ ("DEACTIVATING", False),
+ ("STOPPING", False),
+ ("DEPROVISIONING", False),
+ ("NONE", False),
+ ],
+ )
+ def test_default_values_poke(self, return_state, expected):
+ task = self.create_rendered_task(EcsTaskStateSensor, cluster=TEST_CLUSTER_NAME, task=TEST_TASK_ARN)
+ with mock.patch.object(task.hook, "get_task_state") as m:
+ m.return_value = return_state
+ assert task.poke({}) == expected
+ m.assert_called_once_with(cluster=TEST_CLUSTER_NAME, task=TEST_TASK_ARN)
+
+ @pytest.mark.parametrize("return_state", ["STOPPED"])
+ def test_default_values_terminal_state(self, return_state):
+ task = self.create_rendered_task(EcsTaskStateSensor, cluster=TEST_CLUSTER_NAME, task=TEST_TASK_ARN)
+ with mock.patch.object(task.hook, "get_task_state") as m:
+ m.return_value = return_state
+ with pytest.raises(AirflowException, match="Terminal state reached"):
+ task.poke({})
+ m.assert_called_once_with(cluster=TEST_CLUSTER_NAME, task=TEST_TASK_ARN)
+
+ @pytest.mark.parametrize(
+ "target_state, return_state, expected",
+ [
+ (EcsTaskStates.RUNNING, "RUNNING", True),
+ (EcsTaskStates.DEACTIVATING, "DEACTIVATING", True),
+ (EcsTaskStates.NONE, "PENDING", False),
+ (EcsTaskStates.STOPPING, "NONE", False),
+ ],
+ )
+ def test_custom_values_poke(self, target_state, return_state, expected):
+ task = self.create_rendered_task(
+ EcsTaskStateSensor, cluster=TEST_CLUSTER_NAME, task=TEST_TASK_ARN, target_state=target_state
+ )
+ with mock.patch.object(task.hook, "get_task_state") as m:
+ m.return_value = return_state
+ assert task.poke({}) == expected
+ m.assert_called_once_with(cluster=TEST_CLUSTER_NAME, task=TEST_TASK_ARN)
+
+ @pytest.mark.parametrize(
+ "failure_states, return_state",
+ [
+ ({EcsTaskStates.RUNNING}, "RUNNING"),
+ ({EcsTaskStates.RUNNING, EcsTaskStates.DEACTIVATING}, "DEACTIVATING"),
+ ({EcsTaskStates.RUNNING, EcsTaskStates.DEACTIVATING}, "RUNNING"),
+ ],
+ )
+ def test_custom_values_terminal_state(self, failure_states, return_state):
+ task = self.create_rendered_task(
+ EcsTaskStateSensor,
+ cluster=TEST_CLUSTER_NAME,
+ task=TEST_TASK_ARN,
+ target_state=EcsTaskStates.NONE,
+ failure_states=failure_states,
+ )
+ with mock.patch.object(task.hook, "get_task_state") as m:
+ m.return_value = return_state
+ with pytest.raises(AirflowException, match="Terminal state reached"):
+ task.poke({})
+ m.assert_called_once_with(cluster=TEST_CLUSTER_NAME, task=TEST_TASK_ARN)
diff --git a/tests/providers/amazon/aws/waiters/test_custom_waiters.py b/tests/providers/amazon/aws/waiters/test_custom_waiters.py
index 69f1b99b13..5531b6a801 100644
--- a/tests/providers/amazon/aws/waiters/test_custom_waiters.py
+++ b/tests/providers/amazon/aws/waiters/test_custom_waiters.py
@@ -132,12 +132,18 @@ class TestCustomECSServiceWaiters:
assert "task_definition_inactive" in hook_waiters
@staticmethod
- def describe_clusters(status: str, cluster_name: str = "spam-egg", failures: dict | list | None = None):
+ def describe_clusters(
+ status: str | EcsClusterStates, cluster_name: str = "spam-egg", failures: dict | list | None = None
+ ):
"""
Helper function to generate a minimal DescribeClusters response for a single cluster.
https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_DescribeClusters.html
"""
- assert status in EcsClusterStates.__members__.values()
+ if isinstance(status, EcsClusterStates):
+ status = status.value
+ else:
+ assert status in EcsClusterStates.__members__.values()
+
failures = failures or []
if isinstance(failures, dict):
failures = [failures]
@@ -154,7 +160,7 @@ class TestCustomECSServiceWaiters:
waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_active")
waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01, "MaxAttempts": 3})
- @pytest.mark.parametrize("state", [EcsClusterStates.FAILED, EcsClusterStates.INACTIVE])
+ @pytest.mark.parametrize("state", ["FAILED", "INACTIVE"])
def test_cluster_active_failure_states(self, mock_describe_clusters, state):
"""Test cluster reach inactive state during creation."""
mock_describe_clusters.side_effect = [
@@ -196,11 +202,16 @@ class TestCustomECSServiceWaiters:
waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01, "MaxAttempts": 3})
@staticmethod
- def describe_task_definition(status: str, task_definition: str = "spam-egg"):
+ def describe_task_definition(status: str | EcsTaskDefinitionStates, task_definition: str = "spam-egg"):
"""
Helper function to generate a minimal DescribeTaskDefinition response for a single task definition.
https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_DescribeTaskDefinition.html
"""
+ if isinstance(status, EcsTaskDefinitionStates):
+ status = status.value
+ else:
+ assert status in EcsTaskDefinitionStates.__members__.values()
+
return {
"taskDefinition": {
"taskDefinitionArn": (