You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by fe...@apache.org on 2020/06/22 08:19:04 UTC

[airflow] branch master updated: Add AWS ECS system test (#8888)

This is an automated email from the ASF dual-hosted git repository.

feluelle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new c7a454a  Add AWS ECS system test (#8888)
c7a454a is described below

commit c7a454aa32bf33133d042e8438ac259b32144b21
Author: Mustafa Gök <sd...@gmail.com>
AuthorDate: Mon Jun 22 11:18:13 2020 +0300

    Add AWS ECS system test (#8888)
---
 .../amazon/aws/example_dags/example_ecs_fargate.py |  11 +-
 .../amazon/aws/operators/test_ecs_system.py        |  99 ++++++++++
 tests/test_utils/amazon_system_helpers.py          | 211 +++++++++++++++++++++
 3 files changed, 316 insertions(+), 5 deletions(-)

diff --git a/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py b/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py
index be55e67..4c75d8f 100644
--- a/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py
+++ b/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py
@@ -23,6 +23,7 @@ It overrides the command in the `hello-world-container` container.
 """
 
 import datetime
+import os
 
 from airflow import DAG
 from airflow.providers.amazon.aws.operators.ecs import ECSOperator
@@ -50,7 +51,7 @@ dag.doc_md = __doc__
 hello_world = ECSOperator(
     task_id="hello_world",
     dag=dag,
-    aws_conn_id="aws_default",
+    aws_conn_id="aws_ecs",
     cluster="c",
     task_definition="hello-world",
     launch_type="FARGATE",
@@ -64,8 +65,8 @@ hello_world = ECSOperator(
     },
     network_configuration={
         "awsvpcConfiguration": {
-            "securityGroups": ["sg-123abc"],
-            "subnets": ["subnet-123456ab"],
+            "securityGroups": [os.environ.get("SECURITY_GROUP_ID", "sg-123abc")],
+            "subnets": [os.environ.get("SUBNET_ID", "subnet-123456ab")],
         },
     },
     tags={
@@ -75,7 +76,7 @@ hello_world = ECSOperator(
         "Version": "0.0.1",
         "Environment": "Development",
     },
-    awslogs_group="/ecs_logs/group_a",
-    awslogs_stream_prefix="prefix_b",
+    awslogs_group="/ecs/hello-world",
+    awslogs_stream_prefix="prefix_b/hello-world-container",  # prefix with container name
 )
 # [END howto_operator_ecs]
diff --git a/tests/providers/amazon/aws/operators/test_ecs_system.py b/tests/providers/amazon/aws/operators/test_ecs_system.py
new file mode 100644
index 0000000..1a6eec7
--- /dev/null
+++ b/tests/providers/amazon/aws/operators/test_ecs_system.py
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import pytest
+
+from tests.test_utils.amazon_system_helpers import AWS_DAG_FOLDER, AmazonSystemTest
+
+
+@pytest.mark.backend("postgres", "mysql")
+class ECSSystemTest(AmazonSystemTest):
+    """
+    ECS System Test to run and test example ECS dags
+
+    Required variables.env file content (from your account):
+        # Auto-export all variables
+        set -a
+
+        # aws parameters
+        REGION_NAME="eu-west-1"
+        REGISTRY_ID="123456789012"
+        IMAGE="alpine:3.9"
+        SUBNET_ID="subnet-068e9654a3c357a"
+        SECURITY_GROUP_ID="sg-054dc69874a651"
+        EXECUTION_ROLE_ARN="arn:aws:iam::123456789012:role/FooBarRole"
+
+        # flag to remove all created/existing resources on teardown
+        # leave it commented out (or set it to an empty string) to keep resources
+        # REMOVE_RESOURCES="True"
+    """
+
+    # should be same as in the example dag
+    aws_conn_id = "aws_ecs"
+    cluster = "c"
+    task_definition = "hello-world"
+    container = "hello-world-container"
+    awslogs_group = "/ecs/hello-world"
+    awslogs_stream_prefix = "prefix_b"  # only prefix without container name
+
+    @classmethod
+    def setup_class(cls):
+        cls.create_connection(
+            aws_conn_id=cls.aws_conn_id,
+            region=cls._region_name(),
+        )
+
+        # create ecs cluster if it does not exist
+        cls.create_ecs_cluster(
+            aws_conn_id=cls.aws_conn_id,
+            cluster_name=cls.cluster,
+        )
+
+        # create task_definition if it does not exist
+        task_definition_exists = cls.is_ecs_task_definition_exists(
+            aws_conn_id=cls.aws_conn_id,
+            task_definition=cls.task_definition,
+        )
+        if not task_definition_exists:
+            cls.create_ecs_task_definition(
+                aws_conn_id=cls.aws_conn_id,
+                task_definition=cls.task_definition,
+                container=cls.container,
+                image=cls._image(),
+                execution_role_arn=cls._execution_role_arn(),
+                awslogs_group=cls.awslogs_group,
+                awslogs_region=cls._region_name(),
+                awslogs_stream_prefix=cls.awslogs_stream_prefix,
+            )
+
+    @classmethod
+    def teardown_class(cls):
+        # remove all created/existing resources in tear down
+        if cls._remove_resources():
+            cls.delete_ecs_cluster(
+                aws_conn_id=cls.aws_conn_id,
+                cluster_name=cls.cluster,
+            )
+            cls.delete_ecs_task_definition(
+                aws_conn_id=cls.aws_conn_id,
+                task_definition=cls.task_definition,
+            )
+
+    def test_run_example_dag_ecs_fargate_dag(self):
+        self.run_dag("ecs_fargate_dag", AWS_DAG_FOLDER)
diff --git a/tests/test_utils/amazon_system_helpers.py b/tests/test_utils/amazon_system_helpers.py
index b0c9710..cd5d961 100644
--- a/tests/test_utils/amazon_system_helpers.py
+++ b/tests/test_utils/amazon_system_helpers.py
@@ -21,6 +21,9 @@ from typing import List
 
 import pytest
 
+from airflow.models import Connection
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.utils import db
 from tests.test_utils import AIRFLOW_MAIN_FOLDER
 from tests.test_utils.system_tests_class import SystemTest
 from tests.utils.logging_command_executor import get_executor
@@ -51,6 +54,27 @@ def provide_aws_s3_bucket(name):
 @pytest.mark.system("amazon")
 class AmazonSystemTest(SystemTest):
 
+    @staticmethod
+    def _region_name():
+        return os.environ.get("REGION_NAME")
+
+    @staticmethod
+    def _registry_id():
+        return os.environ.get("REGISTRY_ID")
+
+    @staticmethod
+    def _image():
+        return os.environ.get("IMAGE")
+
+    @staticmethod
+    def _execution_role_arn():
+        return os.environ.get("EXECUTION_ROLE_ARN")
+
+    @staticmethod
+    def _remove_resources():
+        # flag to remove all created/existing resources; any non-empty value enables it
+        return os.environ.get("REMOVE_RESOURCES", False)
+
     @classmethod
     def execute_with_ctx(cls, cmd: List[str]):
         """
@@ -60,6 +84,25 @@ class AmazonSystemTest(SystemTest):
         with provide_aws_context():
             executor.execute_cmd(cmd=cmd)
 
+    @staticmethod
+    def create_connection(aws_conn_id: str,
+                          region: str) -> None:
+        """
+        Create aws connection with region
+
+        :param aws_conn_id: id of the aws connection to create
+        :type aws_conn_id: str
+        :param region: aws region name to use in extra field of the aws connection
+        :type region: str
+        """
+        db.merge_conn(
+            Connection(
+                conn_id=aws_conn_id,
+                conn_type="aws",
+                extra=f'{{"region_name": "{region}"}}',
+            ),
+        )
+
     @classmethod
     def create_aws_s3_bucket(cls, name: str) -> None:
         """
@@ -92,3 +135,171 @@ class AmazonSystemTest(SystemTest):
         """
         cmd = ["aws", "emr", "create-default-roles"]
         cls.execute_with_ctx(cmd)
+
+    @staticmethod
+    def create_ecs_cluster(aws_conn_id: str,
+                           cluster_name: str) -> None:
+        """
+        Create ecs cluster with given name
+
+        If the specified cluster already exists, it is left unchanged and no new cluster is created.
+
+        :param aws_conn_id: id of the aws connection to use when creating boto3 client/resource
+        :type aws_conn_id: str
+        :param cluster_name: name of the cluster to create in aws ecs
+        :type cluster_name: str
+        """
+        hook = AwsBaseHook(
+            aws_conn_id=aws_conn_id,
+            client_type="ecs",
+        )
+        hook.conn.create_cluster(
+            clusterName=cluster_name,
+            capacityProviders=[
+                "FARGATE_SPOT",
+                "FARGATE",
+            ],
+            defaultCapacityProviderStrategy=[
+                {
+                    "capacityProvider": "FARGATE_SPOT",
+                    "weight": 1,
+                    "base": 0,
+                },
+                {
+                    "capacityProvider": "FARGATE",
+                    "weight": 1,
+                    "base": 0,
+                },
+            ],
+        )
+
+    @staticmethod
+    def delete_ecs_cluster(aws_conn_id: str,
+                           cluster_name: str) -> None:
+        """
+        Delete ecs cluster with given short name or full Amazon Resource Name (ARN)
+
+        :param aws_conn_id: id of the aws connection to use when creating boto3 client/resource
+        :type aws_conn_id: str
+        :param cluster_name: name of the cluster to delete in aws ecs
+        :type cluster_name: str
+        """
+        hook = AwsBaseHook(
+            aws_conn_id=aws_conn_id,
+            client_type="ecs",
+        )
+        hook.conn.delete_cluster(
+            cluster=cluster_name,
+        )
+
+    @staticmethod
+    def create_ecs_task_definition(aws_conn_id: str,
+                                   task_definition: str,
+                                   container: str,
+                                   image: str,
+                                   execution_role_arn: str,
+                                   awslogs_group: str,
+                                   awslogs_region: str,
+                                   awslogs_stream_prefix: str) -> None:
+        """
+        Create ecs task definition with given name
+
+        :param aws_conn_id: id of the aws connection to use when creating boto3 client/resource
+        :type aws_conn_id: str
+        :param task_definition: family name for task definition to create in aws ecs
+        :type task_definition: str
+        :param container: name of the container
+        :type container: str
+        :param image: image used to start a container,
+            format: `registry_id`.dkr.ecr.`region`.amazonaws.com/`repository_name`:`tag`
+        :type image: str
+        :param execution_role_arn: task execution role that the Amazon ECS container agent can assume,
+            format: arn:aws:iam::`registry_id`:role/`role_name`
+        :type execution_role_arn: str
+        :param awslogs_group: awslogs group option in log configuration
+        :type awslogs_group: str
+        :param awslogs_region: awslogs region option in log configuration
+        :type awslogs_region: str
+        :param awslogs_stream_prefix: awslogs stream prefix option in log configuration
+        :type awslogs_stream_prefix: str
+        """
+        hook = AwsBaseHook(
+            aws_conn_id=aws_conn_id,
+            client_type="ecs",
+        )
+        hook.conn.register_task_definition(
+            family=task_definition,
+            executionRoleArn=execution_role_arn,
+            networkMode="awsvpc",
+            containerDefinitions=[
+                {
+                    "name": container,
+                    "image": image,
+                    "cpu": 256,
+                    "memory": 512,  # hard limit
+                    "memoryReservation": 512,  # soft limit
+                    "logConfiguration": {
+                        "logDriver": "awslogs",
+                        "options": {
+                            "awslogs-group": awslogs_group,
+                            "awslogs-region": awslogs_region,
+                            "awslogs-stream-prefix": awslogs_stream_prefix,
+                        },
+                    },
+                },
+            ],
+            requiresCompatibilities=[
+                "FARGATE",
+            ],
+            cpu="256",  # task cpu limit (total of all containers)
+            memory="512",  # task memory limit (total of all containers)
+        )
+
+    @staticmethod
+    def delete_ecs_task_definition(aws_conn_id: str,
+                                   task_definition: str) -> None:
+        """
+        Deregister the ACTIVE revisions (at most 100) of given ecs task definition
+
+        :param aws_conn_id: id of the aws connection to use when creating boto3 client/resource
+        :type aws_conn_id: str
+        :param task_definition: family prefix for task definition to delete in aws ecs
+        :type task_definition: str
+        """
+        hook = AwsBaseHook(
+            aws_conn_id=aws_conn_id,
+            client_type="ecs",
+        )
+        response = hook.conn.list_task_definitions(
+            familyPrefix=task_definition,
+            status="ACTIVE",
+            sort="ASC",
+            maxResults=100,
+        )
+        revisions = [arn.split(":")[-1] for arn in response["taskDefinitionArns"]]
+        for revision in revisions:
+            hook.conn.deregister_task_definition(
+                taskDefinition=f"{task_definition}:{revision}",
+            )
+
+    @staticmethod
+    def is_ecs_task_definition_exists(aws_conn_id: str,
+                                      task_definition: str) -> bool:
+        """
+        Check whether given task definition exists in ecs
+
+        :param aws_conn_id: id of the aws connection to use when creating boto3 client/resource
+        :type aws_conn_id: str
+        :param task_definition: family prefix for task definition to check in aws ecs
+        :type task_definition: str
+        """
+        hook = AwsBaseHook(
+            aws_conn_id=aws_conn_id,
+            client_type="ecs",
+        )
+        response = hook.conn.list_task_definition_families(
+            familyPrefix=task_definition,
+            status="ACTIVE",
+            maxResults=100,
+        )
+        return task_definition in response["families"]