Posted to commits@airflow.apache.org by on...@apache.org on 2023/03/02 22:48:15 UTC

[airflow] branch main updated: Fix Amazon ECS Enums (#29871)

This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 76d8aaa836 Fix Amazon ECS Enums (#29871)
76d8aaa836 is described below

commit 76d8aaa8362ba199d98680d71ccb3a800cbc4d38
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Fri Mar 3 02:48:05 2023 +0400

    Fix Amazon ECS Enums (#29871)
    
    * Add missing tests
    
    * Apply suggestions from code review
    
    Co-authored-by: Niko Oliveira <on...@amazon.com>
---
 airflow/providers/amazon/aws/hooks/ecs.py          |  23 +-
 airflow/providers/amazon/aws/sensors/ecs.py        |   2 +-
 tests/providers/amazon/aws/sensors/test_ecs.py     | 261 +++++++++++++++++++++
 .../amazon/aws/waiters/test_custom_waiters.py      |  19 +-
 4 files changed, 297 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/ecs.py b/airflow/providers/amazon/aws/hooks/ecs.py
index ccaea145dd..b0d26eb224 100644
--- a/airflow/providers/amazon/aws/hooks/ecs.py
+++ b/airflow/providers/amazon/aws/hooks/ecs.py
@@ -55,7 +55,24 @@ def should_retry_eni(exception: Exception):
     return False
 
 
-class EcsClusterStates(str, Enum):
+class _StringCompareEnum(Enum):
+    """
+    Enum which can be compared with regular `str` values and `str` subclasses.
+
+    This class avoids multiple inheritance such as AwesomeEnum(str, Enum),
+    which does not work well with templated_fields and Jinja templates.
+    """
+
+    def __eq__(self, other):
+        if isinstance(other, str):
+            return self.value == other
+        return super().__eq__(other)
+
+    def __hash__(self):
+        return super().__hash__()  # Need to set because we redefine __eq__
+
+
+class EcsClusterStates(_StringCompareEnum):
     """Contains the possible State values of an ECS Cluster."""
 
     ACTIVE = "ACTIVE"
@@ -65,7 +82,7 @@ class EcsClusterStates(str, Enum):
     INACTIVE = "INACTIVE"
 
 
-class EcsTaskDefinitionStates(str, Enum):
+class EcsTaskDefinitionStates(_StringCompareEnum):
     """Contains the possible State values of an ECS Task Definition."""
 
     ACTIVE = "ACTIVE"
@@ -73,7 +90,7 @@ class EcsTaskDefinitionStates(str, Enum):
     DELETE_IN_PROGRESS = "DELETE_IN_PROGRESS"
 
 
-class EcsTaskStates(str, Enum):
+class EcsTaskStates(_StringCompareEnum):
     """Contains the possible State values of an ECS Task."""
 
     PROVISIONING = "PROVISIONING"
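
For illustration (not part of the committed change): the practical effect of the new
_StringCompareEnum base is that enum members still compare equal to the plain strings
returned by the ECS API, while __hash__ keeps the default Enum behaviour so members can
still be collected in sets such as the sensors' failure_states. A minimal sketch,
assuming the Amazon provider package is installed:

    from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsTaskStates

    # Members compare equal to the raw ECS API strings, in either direction.
    assert EcsClusterStates.ACTIVE == "ACTIVE"
    assert "RUNNING" == EcsTaskStates.RUNNING

    # Hashing falls back to the Enum default, so members remain usable as
    # set/dict keys, e.g. for a sensor's failure_states argument.
    failure_states = {EcsClusterStates.FAILED, EcsClusterStates.INACTIVE}
    assert EcsClusterStates.FAILED in failure_states
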
diff --git a/airflow/providers/amazon/aws/sensors/ecs.py b/airflow/providers/amazon/aws/sensors/ecs.py
index c1151b8f9a..d3cfacbd41 100644
--- a/airflow/providers/amazon/aws/sensors/ecs.py
+++ b/airflow/providers/amazon/aws/sensors/ecs.py
@@ -20,8 +20,8 @@ from typing import TYPE_CHECKING, Sequence
 
 import boto3
 
-from airflow import AirflowException
 from airflow.compat.functools import cached_property
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.ecs import (
     EcsClusterStates,
     EcsHook,
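
Side note (illustration only, not part of the hunk above): AirflowException is defined
in airflow.exceptions, which is where the provider now imports it from; user code that
catches sensor failures can follow the same pattern:

    from airflow.exceptions import AirflowException  # rather than "from airflow import AirflowException"
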
diff --git a/tests/providers/amazon/aws/sensors/test_ecs.py b/tests/providers/amazon/aws/sensors/test_ecs.py
new file mode 100644
index 0000000000..66d8bc5c54
--- /dev/null
+++ b/tests/providers/amazon/aws/sensors/test_ecs.py
@@ -0,0 +1,261 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TypeVar
+from unittest import mock
+
+import boto3
+import pytest
+from slugify import slugify
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.sensors.ecs import (
+    DEFAULT_CONN_ID,
+    EcsBaseSensor,
+    EcsClusterStates,
+    EcsClusterStateSensor,
+    EcsHook,
+    EcsTaskDefinitionStates,
+    EcsTaskDefinitionStateSensor,
+    EcsTaskStates,
+    EcsTaskStateSensor,
+)
+from airflow.utils import timezone
+from airflow.utils.types import NOTSET
+
+_Operator = TypeVar("_Operator")
+TEST_CLUSTER_NAME = "fake-cluster"
+TEST_TASK_ARN = "arn:aws:ecs:us-east-1:012345678910:task/spam-egg"
+TEST_TASK_DEFINITION_ARN = "arn:aws:ecs:us-east-1:012345678910:task-definition/foo-bar:42"
+
+
+class EcsBaseTestCase:
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, monkeypatch, request, create_task_instance_of_operator):
+        self.dag_id = f"dag-{slugify(request.cls.__name__)}"
+        self.task_id = f"task-{slugify(request.node.name, max_length=40)}"
+        self.fake_client = boto3.client("ecs", region_name="eu-west-3")
+        monkeypatch.setattr(EcsHook, "conn", self.fake_client)
+        self.ti_maker = create_task_instance_of_operator
+
+    def create_rendered_task(self, operator_class: type[_Operator], **kwargs) -> _Operator:
+        """
+        Create an operator from the given class and render its templated fields.
+
+        This might help prevent unexpected behaviour in Jinja/task field serialisation.
+        """
+        return self.ti_maker(
+            operator_class,
+            dag_id=self.dag_id,
+            task_id=self.task_id,
+            execution_date=timezone.datetime(2021, 12, 21),
+            **kwargs,
+        ).render_templates()
+
+
+class TestEcsBaseSensor(EcsBaseTestCase):
+    @pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])
+    @pytest.mark.parametrize("region_name", [None, NOTSET, "ca-central-1"])
+    def test_initialise_operator(self, aws_conn_id, region_name):
+        """Test sensor initialize."""
+        op_kw = {"aws_conn_id": aws_conn_id, "region": region_name}
+        op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
+        op = EcsBaseSensor(task_id="test_ecs_base", **op_kw)
+
+        assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else DEFAULT_CONN_ID)
+        assert op.region == (region_name if region_name is not NOTSET else None)
+
+    @pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])
+    @pytest.mark.parametrize("region_name", [None, NOTSET, "ca-central-1"])
+    def test_hook_and_client(self, aws_conn_id, region_name):
+        """Test initialize ``EcsHook`` and ``boto3.client``."""
+        op_kw = {"aws_conn_id": aws_conn_id, "region": region_name}
+        op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
+        op = EcsBaseSensor(task_id="test_ecs_base_hook_client", **op_kw)
+
+        hook = op.hook
+        assert op.hook is hook, "'hook' property should be cached."
+        assert isinstance(op.hook, EcsHook)
+
+        client = op.client
+        assert op.client is client, "'client' property should be cached."
+        assert client is self.fake_client
+
+
+class TestEcsClusterStateSensor(EcsBaseTestCase):
+    @pytest.mark.parametrize(
+        "return_state, expected", [("ACTIVE", True), ("PROVISIONING", False), ("DEPROVISIONING", False)]
+    )
+    def test_default_values_poke(self, return_state, expected):
+        task = self.create_rendered_task(EcsClusterStateSensor, cluster_name=TEST_CLUSTER_NAME)
+        with mock.patch.object(task.hook, "get_cluster_state") as m:
+            m.return_value = return_state
+            assert task.poke({}) == expected
+            m.assert_called_once_with(cluster_name=TEST_CLUSTER_NAME)
+
+    @pytest.mark.parametrize("return_state", ["FAILED", "INACTIVE"])
+    def test_default_values_terminal_state(self, create_task_of_operator, return_state):
+        task = self.create_rendered_task(EcsClusterStateSensor, cluster_name=TEST_CLUSTER_NAME)
+        with mock.patch.object(task.hook, "get_cluster_state") as m:
+            m.return_value = return_state
+            with pytest.raises(AirflowException, match="Terminal state reached"):
+                task.poke({})
+            m.assert_called_once_with(cluster_name=TEST_CLUSTER_NAME)
+
+    @pytest.mark.parametrize(
+        "target_state, return_state, expected",
+        [
+            (EcsClusterStates.ACTIVE, "ACTIVE", True),
+            (EcsClusterStates.ACTIVE, "DEPROVISIONING", False),
+            (EcsClusterStates.DEPROVISIONING, "ACTIVE", False),
+            (EcsClusterStates.DEPROVISIONING, "DEPROVISIONING", True),
+        ],
+    )
+    def test_custom_values_poke(self, target_state, return_state, expected):
+        task = self.create_rendered_task(
+            EcsClusterStateSensor, cluster_name=TEST_CLUSTER_NAME, target_state=target_state
+        )
+        with mock.patch.object(task.hook, "get_cluster_state") as m:
+            m.return_value = return_state
+            assert task.poke({}) == expected
+            m.assert_called_once_with(cluster_name=TEST_CLUSTER_NAME)
+
+    @pytest.mark.parametrize(
+        "failure_states, return_state",
+        [
+            ({EcsClusterStates.ACTIVE}, "ACTIVE"),
+            ({EcsClusterStates.PROVISIONING, EcsClusterStates.DEPROVISIONING}, "DEPROVISIONING"),
+            ({EcsClusterStates.PROVISIONING, EcsClusterStates.DEPROVISIONING}, "PROVISIONING"),
+        ],
+    )
+    def test_custom_values_terminal_state(self, failure_states, return_state):
+        task = self.create_rendered_task(
+            EcsClusterStateSensor,
+            cluster_name=TEST_CLUSTER_NAME,
+            target_state=EcsClusterStates.FAILED,
+            failure_states=failure_states,
+        )
+        with mock.patch.object(task.hook, "get_cluster_state") as m:
+            m.return_value = return_state
+            with pytest.raises(AirflowException, match="Terminal state reached"):
+                task.poke({})
+            m.assert_called_once_with(cluster_name=TEST_CLUSTER_NAME)
+
+
+class TestEcsTaskDefinitionStateSensor(EcsBaseTestCase):
+    @pytest.mark.parametrize(
+        "return_state, expected", [("ACTIVE", True), ("INACTIVE", False), ("DELETE_IN_PROGRESS", False)]
+    )
+    def test_default_values_poke(self, return_state, expected):
+        task = self.create_rendered_task(
+            EcsTaskDefinitionStateSensor, task_definition=TEST_TASK_DEFINITION_ARN
+        )
+        with mock.patch.object(task.hook, "get_task_definition_state") as m:
+            m.return_value = return_state
+            assert task.poke({}) == expected
+            m.assert_called_once_with(task_definition=TEST_TASK_DEFINITION_ARN)
+
+    @pytest.mark.parametrize(
+        "target_state, return_state, expected",
+        [
+            (EcsTaskDefinitionStates.INACTIVE, "ACTIVE", False),
+            (EcsTaskDefinitionStates.INACTIVE, "INACTIVE", True),
+            (EcsTaskDefinitionStates.ACTIVE, "INACTIVE", False),
+            (EcsTaskDefinitionStates.ACTIVE, "ACTIVE", True),
+        ],
+    )
+    def test_custom_values_poke(self, create_task_of_operator, target_state, return_state, expected):
+        task = self.create_rendered_task(
+            EcsTaskDefinitionStateSensor, task_definition=TEST_TASK_DEFINITION_ARN, target_state=target_state
+        )
+        with mock.patch.object(task.hook, "get_task_definition_state") as m:
+            m.return_value = return_state
+            assert task.poke({}) == expected
+            m.assert_called_once_with(task_definition=TEST_TASK_DEFINITION_ARN)
+
+
+class TestEcsTaskStateSensor(EcsBaseTestCase):
+    @pytest.mark.parametrize(
+        "return_state, expected",
+        [
+            ("PROVISIONING", False),
+            ("PENDING", False),
+            ("ACTIVATING", False),
+            ("RUNNING", True),
+            ("DEACTIVATING", False),
+            ("STOPPING", False),
+            ("DEPROVISIONING", False),
+            ("NONE", False),
+        ],
+    )
+    def test_default_values_poke(self, return_state, expected):
+        task = self.create_rendered_task(EcsTaskStateSensor, cluster=TEST_CLUSTER_NAME, task=TEST_TASK_ARN)
+        with mock.patch.object(task.hook, "get_task_state") as m:
+            m.return_value = return_state
+            assert task.poke({}) == expected
+            m.assert_called_once_with(cluster=TEST_CLUSTER_NAME, task=TEST_TASK_ARN)
+
+    @pytest.mark.parametrize("return_state", ["STOPPED"])
+    def test_default_values_terminal_state(self, return_state):
+        task = self.create_rendered_task(EcsTaskStateSensor, cluster=TEST_CLUSTER_NAME, task=TEST_TASK_ARN)
+        with mock.patch.object(task.hook, "get_task_state") as m:
+            m.return_value = return_state
+            with pytest.raises(AirflowException, match="Terminal state reached"):
+                task.poke({})
+            m.assert_called_once_with(cluster=TEST_CLUSTER_NAME, task=TEST_TASK_ARN)
+
+    @pytest.mark.parametrize(
+        "target_state, return_state, expected",
+        [
+            (EcsTaskStates.RUNNING, "RUNNING", True),
+            (EcsTaskStates.DEACTIVATING, "DEACTIVATING", True),
+            (EcsTaskStates.NONE, "PENDING", False),
+            (EcsTaskStates.STOPPING, "NONE", False),
+        ],
+    )
+    def test_custom_values_poke(self, target_state, return_state, expected):
+        task = self.create_rendered_task(
+            EcsTaskStateSensor, cluster=TEST_CLUSTER_NAME, task=TEST_TASK_ARN, target_state=target_state
+        )
+        with mock.patch.object(task.hook, "get_task_state") as m:
+            m.return_value = return_state
+            assert task.poke({}) == expected
+            m.assert_called_once_with(cluster=TEST_CLUSTER_NAME, task=TEST_TASK_ARN)
+
+    @pytest.mark.parametrize(
+        "failure_states, return_state",
+        [
+            ({EcsTaskStates.RUNNING}, "RUNNING"),
+            ({EcsTaskStates.RUNNING, EcsTaskStates.DEACTIVATING}, "DEACTIVATING"),
+            ({EcsTaskStates.RUNNING, EcsTaskStates.DEACTIVATING}, "RUNNING"),
+        ],
+    )
+    def test_custom_values_terminal_state(self, failure_states, return_state):
+        task = self.create_rendered_task(
+            EcsTaskStateSensor,
+            cluster=TEST_CLUSTER_NAME,
+            task=TEST_TASK_ARN,
+            target_state=EcsTaskStates.NONE,
+            failure_states=failure_states,
+        )
+        with mock.patch.object(task.hook, "get_task_state") as m:
+            m.return_value = return_state
+            with pytest.raises(AirflowException, match="Terminal state reached"):
+                task.poke({})
+            m.assert_called_once_with(cluster=TEST_CLUSTER_NAME, task=TEST_TASK_ARN)
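
For readers wondering how the sensors exercised above are used in practice, here is a
hypothetical DAG snippet (not part of the commit): the DAG id and cluster name are made
up, the schedule argument assumes Airflow 2.4+, and the keyword arguments mirror the
tests in this file.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates
    from airflow.providers.amazon.aws.sensors.ecs import EcsClusterStateSensor

    with DAG(dag_id="example_ecs_cluster_state", start_date=datetime(2023, 3, 1), schedule=None):
        # Poke until the cluster reports ACTIVE (the default target_state);
        # by default FAILED or INACTIVE are terminal and raise AirflowException.
        wait_for_cluster = EcsClusterStateSensor(
            task_id="wait_for_cluster",
            cluster_name="my-ecs-cluster",             # hypothetical cluster name
            target_state=EcsClusterStates.ACTIVE,      # default, shown for clarity
            failure_states={EcsClusterStates.FAILED},  # optional override of the failure set
        )
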
diff --git a/tests/providers/amazon/aws/waiters/test_custom_waiters.py b/tests/providers/amazon/aws/waiters/test_custom_waiters.py
index 69f1b99b13..5531b6a801 100644
--- a/tests/providers/amazon/aws/waiters/test_custom_waiters.py
+++ b/tests/providers/amazon/aws/waiters/test_custom_waiters.py
@@ -132,12 +132,18 @@ class TestCustomECSServiceWaiters:
         assert "task_definition_inactive" in hook_waiters
 
     @staticmethod
-    def describe_clusters(status: str, cluster_name: str = "spam-egg", failures: dict | list | None = None):
+    def describe_clusters(
+        status: str | EcsClusterStates, cluster_name: str = "spam-egg", failures: dict | list | None = None
+    ):
         """
         Helper function to generate a minimal DescribeClusters response for a single cluster.
         https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_DescribeClusters.html
         """
-        assert status in EcsClusterStates.__members__.values()
+        if isinstance(status, EcsClusterStates):
+            status = status.value
+        else:
+            assert status in EcsClusterStates.__members__.values()
+
         failures = failures or []
         if isinstance(failures, dict):
             failures = [failures]
@@ -154,7 +160,7 @@ class TestCustomECSServiceWaiters:
         waiter = EcsHook(aws_conn_id=None).get_waiter("cluster_active")
         waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01, "MaxAttempts": 3})
 
-    @pytest.mark.parametrize("state", [EcsClusterStates.FAILED, EcsClusterStates.INACTIVE])
+    @pytest.mark.parametrize("state", ["FAILED", "INACTIVE"])
     def test_cluster_active_failure_states(self, mock_describe_clusters, state):
         """Test cluster reach inactive state during creation."""
         mock_describe_clusters.side_effect = [
@@ -196,11 +202,16 @@ class TestCustomECSServiceWaiters:
         waiter.wait(clusters=["spam-egg"], WaiterConfig={"Delay": 0.01, "MaxAttempts": 3})
 
     @staticmethod
-    def describe_task_definition(status: str, task_definition: str = "spam-egg"):
+    def describe_task_definition(status: str | EcsTaskDefinitionStates, task_definition: str = "spam-egg"):
         """
         Helper function to generate a minimal DescribeTaskDefinition response for a single task definition.
         https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_DescribeTaskDefinition.html
         """
+        if isinstance(status, EcsTaskDefinitionStates):
+            status = status.value
+        else:
+            assert status in EcsTaskDefinitionStates.__members__.values()
+
         return {
             "taskDefinition": {
                 "taskDefinitionArn": (