You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by jo...@apache.org on 2023/03/07 22:22:11 UTC

[airflow] branch main updated: Add `EC2CreateInstanceOperator`, `EC2TerminateInstanceOperator` (#29548)

This is an automated email from the ASF dual-hosted git repository.

joshfell pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d2cc9df82c Add `EC2CreateInstanceOperator`, `EC2TerminateInstanceOperator` (#29548)
d2cc9df82c is described below

commit d2cc9df82c8b6ae6cccb51462b8b5a37155666a7
Author: Syed Hussaain <10...@users.noreply.github.com>
AuthorDate: Tue Mar 7 14:22:02 2023 -0800

    Add `EC2CreateInstanceOperator`, `EC2TerminateInstanceOperator` (#29548)
    
    * Add EC2CreateInstanceOperator and EC2TerminateInstanceOperator
    Change system test to use the new operators
    Add unit tests for new operators
    
    * Add support for multiple ids to EC2TerminateInstanceOperator
    Change system test to terminate without stopping instances
    
    * Fix failing tests for terminate operator
    
    * Update doc strings to add that the operators can create/terminate multiple instances
    Add tests for creating/terminating multiple instances
    
    * Fix system test so it passes
    Fix doc string on EC2TerminateInstanceOperator
    
    ---------
    
    Co-authored-by: syedahsn <sy...@ud74d1a752d7e5b.ant.amazon.com>
---
 airflow/providers/amazon/aws/operators/ec2.py      | 138 +++++++++++++++++++++
 .../operators/ec2.rst                              |  28 +++++
 tests/providers/amazon/aws/operators/test_ec2.py   | 130 +++++++++++++++++--
 tests/system/providers/amazon/aws/example_ec2.py   | 104 +++++++++-------
 4 files changed, 345 insertions(+), 55 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/ec2.py b/airflow/providers/amazon/aws/operators/ec2.py
index 60cb43a32d..5f6b76ce15 100644
--- a/airflow/providers/amazon/aws/operators/ec2.py
+++ b/airflow/providers/amazon/aws/operators/ec2.py
@@ -116,3 +116,141 @@ class EC2StopInstanceOperator(BaseOperator):
             target_state="stopped",
             check_interval=self.check_interval,
         )
+
+
+class EC2CreateInstanceOperator(BaseOperator):
+    """
+    Create and start a specified number of EC2 instances using boto3 and return their IDs.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:EC2CreateInstanceOperator`
+
+    :param image_id: ID of the AMI used to create the instance.
+    :param max_count: Maximum number of instances to launch. Defaults to 1.
+    :param min_count: Minimum number of instances to launch. Defaults to 1.
+    :param aws_conn_id: AWS connection to use
+    :param region_name: AWS region name associated with the client.
+    :param poll_interval: Number of seconds to wait before attempting to
+        check state of instance. Only used if wait_for_completion is True. Default is 20.
+    :param max_attempts: Maximum number of attempts when checking state of instance.
+        Only used if wait_for_completion is True. Default is 20.
+    :param config: Dictionary for arbitrary parameters to the boto3 run_instances call.
+    :param wait_for_completion: If True, the operator will wait for all created instances
+        to be in the `running` state before returning.
+    """
+
+    template_fields: Sequence[str] = (
+        "image_id",
+        "max_count",
+        "min_count",
+        "aws_conn_id",
+        "region_name",
+        "config",
+        "wait_for_completion",
+    )
+
+    def __init__(
+        self,
+        image_id: str,
+        max_count: int = 1,
+        min_count: int = 1,
+        aws_conn_id: str = "aws_default",
+        region_name: str | None = None,
+        poll_interval: int = 20,
+        max_attempts: int = 20,
+        config: dict | None = None,
+        wait_for_completion: bool = False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.image_id = image_id
+        self.max_count = max_count
+        self.min_count = min_count
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.poll_interval = poll_interval
+        self.max_attempts = max_attempts
+        self.config = config or {}
+        self.wait_for_completion = wait_for_completion
+
+    def execute(self, context: Context):
+        ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
+        instances = ec2_hook.conn.run_instances(
+            ImageId=self.image_id,
+            MinCount=self.min_count,
+            MaxCount=self.max_count,
+            **self.config,
+        )["Instances"]
+        instance_ids = [instance["InstanceId"] for instance in instances]
+        for instance_id in instance_ids:
+            self.log.info("Created EC2 instance %s", instance_id)
+
+        if self.wait_for_completion:
+            # One waiter call covers all instances, so the timeout is not multiplied by their count.
+            ec2_hook.get_waiter("instance_running").wait(
+                InstanceIds=instance_ids,
+                WaiterConfig={
+                    "Delay": self.poll_interval,
+                    "MaxAttempts": self.max_attempts,
+                },
+            )
+
+        return instance_ids
+
+
+class EC2TerminateInstanceOperator(BaseOperator):
+    """
+    Terminate one or more EC2 instances using boto3
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:EC2TerminateInstanceOperator`
+
+    :param instance_ids: ID of the instance, or list of instance IDs, to be terminated.
+    :param aws_conn_id: AWS connection to use
+    :param region_name: AWS region name associated with the client.
+    :param poll_interval: Number of seconds to wait before attempting to
+        check state of instance. Only used if wait_for_completion is True. Default is 20.
+    :param max_attempts: Maximum number of attempts when checking state of instance.
+        Only used if wait_for_completion is True. Default is 20.
+    :param wait_for_completion: If True, the operator will wait for all instances to be
+        in the `terminated` state before returning.
+    """
+
+    template_fields: Sequence[str] = ("instance_ids", "region_name", "aws_conn_id", "wait_for_completion")
+
+    def __init__(
+        self,
+        instance_ids: str | list[str],
+        aws_conn_id: str = "aws_default",
+        region_name: str | None = None,
+        poll_interval: int = 20,
+        max_attempts: int = 20,
+        wait_for_completion: bool = False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.instance_ids = instance_ids
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.poll_interval = poll_interval
+        self.max_attempts = max_attempts
+        self.wait_for_completion = wait_for_completion
+
+    def execute(self, context: Context):
+        # Normalize into a local list rather than mutating the templated task attribute.
+        instance_ids = [self.instance_ids] if isinstance(self.instance_ids, str) else list(self.instance_ids)
+        ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
+        for instance_id in instance_ids:
+            self.log.info("Terminating EC2 instance %s", instance_id)
+        ec2_hook.conn.terminate_instances(InstanceIds=instance_ids)
+        # One waiter call covers all instances, so the timeout is not multiplied by their count.
+        if self.wait_for_completion:
+            ec2_hook.get_waiter("instance_terminated").wait(
+                InstanceIds=instance_ids,
+                WaiterConfig={
+                    "Delay": self.poll_interval,
+                    "MaxAttempts": self.max_attempts,
+                },
+            )
diff --git a/docs/apache-airflow-providers-amazon/operators/ec2.rst b/docs/apache-airflow-providers-amazon/operators/ec2.rst
index 37e5d804ee..5796c514e0 100644
--- a/docs/apache-airflow-providers-amazon/operators/ec2.rst
+++ b/docs/apache-airflow-providers-amazon/operators/ec2.rst
@@ -58,6 +58,34 @@ To stop an Amazon EC2 instance you can use
     :start-after: [START howto_operator_ec2_stop_instance]
     :end-before: [END howto_operator_ec2_stop_instance]
 
+.. _howto/operator:EC2CreateInstanceOperator:
+
+Create and start an Amazon EC2 instance
+=======================================
+
+To create and start one or more Amazon EC2 instances you can use
+:class:`~airflow.providers.amazon.aws.operators.ec2.EC2CreateInstanceOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ec2.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_ec2_create_instance]
+    :end-before: [END howto_operator_ec2_create_instance]
+
+.. _howto/operator:EC2TerminateInstanceOperator:
+
+Terminate an Amazon EC2 instance
+================================
+
+To terminate one or more Amazon EC2 instances you can use
+:class:`~airflow.providers.amazon.aws.operators.ec2.EC2TerminateInstanceOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ec2.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_ec2_terminate_instance]
+    :end-before: [END howto_operator_ec2_terminate_instance]
+
 Sensors
 -------
 
diff --git a/tests/providers/amazon/aws/operators/test_ec2.py b/tests/providers/amazon/aws/operators/test_ec2.py
index 8fd4c535a1..adf3ffeb91 100644
--- a/tests/providers/amazon/aws/operators/test_ec2.py
+++ b/tests/providers/amazon/aws/operators/test_ec2.py
@@ -20,23 +20,121 @@ from __future__ import annotations
 from moto import mock_ec2
 
 from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
-from airflow.providers.amazon.aws.operators.ec2 import EC2StartInstanceOperator, EC2StopInstanceOperator
+from airflow.providers.amazon.aws.operators.ec2 import (
+    EC2CreateInstanceOperator,
+    EC2StartInstanceOperator,
+    EC2StopInstanceOperator,
+    EC2TerminateInstanceOperator,
+)
 
 
 class BaseEc2TestClass:
     @classmethod
-    def _create_instance(cls, hook: EC2Hook):
-        """Create Instance and return instance id."""
+    def _get_image_id(cls, hook):
+        """Get a valid image id to create an instance."""
         conn = hook.get_conn()
         try:
             ec2_client = conn.meta.client
         except AttributeError:
             ec2_client = conn
 
-        # We need existed AMI Image ID otherwise `moto` will raise DeprecationWarning.
+        # We need an existing AMI Image ID otherwise `moto` will raise DeprecationWarning.
         images = ec2_client.describe_images()["Images"]
-        response = ec2_client.run_instances(MaxCount=1, MinCount=1, ImageId=images[0]["ImageId"])
-        return response["Instances"][0]["InstanceId"]
+        return images[0]["ImageId"]
+
+
+class TestEC2CreateInstanceOperator(BaseEc2TestClass):
+    def test_init(self):
+        ec2_operator = EC2CreateInstanceOperator(task_id="test_create_instance", image_id="test_image_id")
+
+        assert ec2_operator.task_id == "test_create_instance"
+        assert ec2_operator.image_id == "test_image_id"
+        assert ec2_operator.max_count == 1
+        assert ec2_operator.min_count == 1
+        assert ec2_operator.max_attempts == 20
+        assert ec2_operator.poll_interval == 20
+        assert ec2_operator.config == {}
+        assert ec2_operator.wait_for_completion is False
+        assert ec2_operator.region_name is None
+
+    @mock_ec2
+    def test_create_instance(self):
+        ec2_hook = EC2Hook()
+        create_instance = EC2CreateInstanceOperator(
+            image_id=self._get_image_id(ec2_hook),
+            task_id="test_create_instance",
+        )
+        instance_ids = create_instance.execute(None)
+        assert len(instance_ids) == 1
+        assert ec2_hook.get_instance_state(instance_id=instance_ids[0]) == "running"
+
+    @mock_ec2
+    def test_create_multiple_instances(self):
+        ec2_hook = EC2Hook()
+        create_instances = EC2CreateInstanceOperator(
+            task_id="test_create_multiple_instances",
+            image_id=self._get_image_id(hook=ec2_hook),
+            min_count=5,
+            max_count=5,
+        )
+        instance_ids = create_instances.execute(None)
+        assert len(instance_ids) == 5
+
+        for instance_id in instance_ids:
+            assert ec2_hook.get_instance_state(instance_id=instance_id) == "running"
+
+
+class TestEC2TerminateInstanceOperator(BaseEc2TestClass):
+    def test_init(self):
+        ec2_operator = EC2TerminateInstanceOperator(
+            task_id="test_terminate_instance", instance_ids="test_instance_id"
+        )
+
+        assert ec2_operator.task_id == "test_terminate_instance"
+        assert ec2_operator.instance_ids == "test_instance_id"
+        assert ec2_operator.max_attempts == 20
+        assert ec2_operator.poll_interval == 20
+
+    @mock_ec2
+    def test_terminate_instance(self):
+        ec2_hook = EC2Hook()
+
+        create_instance = EC2CreateInstanceOperator(
+            image_id=self._get_image_id(ec2_hook),
+            task_id="test_create_instance",
+        )
+        instance_id = create_instance.execute(None)[0]
+
+        assert ec2_hook.get_instance_state(instance_id=instance_id) == "running"
+        # Pass a single ID string (not a list) so the str-normalization path is exercised.
+        terminate_instance = EC2TerminateInstanceOperator(
+            task_id="test_terminate_instance", instance_ids=instance_id
+        )
+        terminate_instance.execute(None)
+
+        assert ec2_hook.get_instance_state(instance_id=instance_id) == "terminated"
+
+    @mock_ec2
+    def test_terminate_multiple_instances(self):
+        ec2_hook = EC2Hook()
+        create_instances = EC2CreateInstanceOperator(
+            task_id="test_create_multiple_instances",
+            image_id=self._get_image_id(hook=ec2_hook),
+            min_count=5,
+            max_count=5,
+        )
+        instance_ids = create_instances.execute(None)
+        assert len(instance_ids) == 5
+
+        for instance_id in instance_ids:
+            assert ec2_hook.get_instance_state(instance_id=instance_id) == "running"
+
+        terminate_instance = EC2TerminateInstanceOperator(
+            task_id="test_terminate_instance", instance_ids=instance_ids
+        )
+        terminate_instance.execute(None)
+        for instance_id in instance_ids:
+            assert ec2_hook.get_instance_state(instance_id=instance_id) == "terminated"
 
 
 class TestEC2StartInstanceOperator(BaseEc2TestClass):
@@ -58,16 +156,20 @@ class TestEC2StartInstanceOperator(BaseEc2TestClass):
     def test_start_instance(self):
         # create instance
         ec2_hook = EC2Hook()
-        instance_id = self._create_instance(ec2_hook)
+        create_instance = EC2CreateInstanceOperator(
+            image_id=self._get_image_id(ec2_hook),
+            task_id="test_create_instance",
+        )
+        instance_id = create_instance.execute(None)
 
         # start instance
         start_test = EC2StartInstanceOperator(
             task_id="start_test",
-            instance_id=instance_id,
+            instance_id=instance_id[0],
         )
         start_test.execute(None)
         # assert instance state is running
-        assert ec2_hook.get_instance_state(instance_id=instance_id) == "running"
+        assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "running"
 
 
 class TestEC2StopInstanceOperator(BaseEc2TestClass):
@@ -89,13 +191,17 @@ class TestEC2StopInstanceOperator(BaseEc2TestClass):
     def test_stop_instance(self):
         # create instance
         ec2_hook = EC2Hook()
-        instance_id = self._create_instance(ec2_hook)
+        create_instance = EC2CreateInstanceOperator(
+            image_id=self._get_image_id(ec2_hook),
+            task_id="test_create_instance",
+        )
+        instance_id = create_instance.execute(None)
 
         # stop instance
         stop_test = EC2StopInstanceOperator(
             task_id="stop_test",
-            instance_id=instance_id,
+            instance_id=instance_id[0],
         )
         stop_test.execute(None)
         # assert instance state is running
-        assert ec2_hook.get_instance_state(instance_id=instance_id) == "stopped"
+        assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "stopped"
diff --git a/tests/system/providers/amazon/aws/example_ec2.py b/tests/system/providers/amazon/aws/example_ec2.py
index 36bddc5548..1e58933dd0 100644
--- a/tests/system/providers/amazon/aws/example_ec2.py
+++ b/tests/system/providers/amazon/aws/example_ec2.py
@@ -24,7 +24,12 @@ import boto3
 from airflow import DAG
 from airflow.decorators import task
 from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.ec2 import EC2StartInstanceOperator, EC2StopInstanceOperator
+from airflow.providers.amazon.aws.operators.ec2 import (
+    EC2CreateInstanceOperator,
+    EC2StartInstanceOperator,
+    EC2StopInstanceOperator,
+    EC2TerminateInstanceOperator,
+)
 from airflow.providers.amazon.aws.sensors.ec2 import EC2InstanceStateSensor
 from airflow.utils.trigger_rule import TriggerRule
 from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
@@ -34,7 +39,8 @@ DAG_ID = "example_ec2"
 sys_test_context_task = SystemTestContextBuilder().build()
 
 
-def _get_latest_ami_id():
+@task
+def get_latest_ami_id():
     """Returns the AMI ID of the most recently-created Amazon Linux image"""
 
     # Amazon is retiring AL2 in 2023 and replacing it with Amazon Linux 2022.
@@ -62,40 +68,16 @@ def create_key_pair(key_name: str):
     return key_pair_id
 
 
-@task
-def create_instance(instance_name: str, key_pair_id: str):
-    client = boto3.client("ec2")
-
-    # Create the instance
-    instance_id = client.run_instances(
-        ImageId=_get_latest_ami_id(),
-        MinCount=1,
-        MaxCount=1,
-        InstanceType="t2.micro",
-        KeyName=key_pair_id,
-        TagSpecifications=[{"ResourceType": "instance", "Tags": [{"Key": "Name", "Value": instance_name}]}],
-        # Use IMDSv2 for greater security, see the following doc for more details:
-        # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html
-        MetadataOptions={"HttpEndpoint": "enabled", "HttpTokens": "required"},
-    )["Instances"][0]["InstanceId"]
-
-    # Wait for it to exist
-    waiter = client.get_waiter("instance_status_ok")
-    waiter.wait(InstanceIds=[instance_id])
-
-    return instance_id
-
-
-@task(trigger_rule=TriggerRule.ALL_DONE)
-def terminate_instance(instance: str):
-    boto3.client("ec2").terminate_instances(InstanceIds=[instance])
-
-
 @task(trigger_rule=TriggerRule.ALL_DONE)
 def delete_key_pair(key_pair_id: str):
     boto3.client("ec2").delete_key_pair(KeyName=key_pair_id)
 
 
+@task
+def parse_response(instance_ids: list):
+    return instance_ids[0]
+
+
 with DAG(
     dag_id=DAG_ID,
     schedule="@once",
@@ -105,9 +87,43 @@ with DAG(
 ) as dag:
     test_context = sys_test_context_task()
     env_id = test_context[ENV_ID_KEY]
-
+    instance_name = f"{env_id}-instance"
     key_name = create_key_pair(key_name=f"{env_id}_key_pair")
-    instance_id = create_instance(instance_name=f"{env_id}-instance", key_pair_id=key_name)
+    image_id = get_latest_ami_id()
+
+    config = {
+        "InstanceType": "t2.micro",
+        "KeyName": key_name,
+        "TagSpecifications": [
+            {"ResourceType": "instance", "Tags": [{"Key": "Name", "Value": instance_name}]}
+        ],
+        # Use IMDSv2 for greater security, see the following doc for more details:
+        # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html
+        "MetadataOptions": {"HttpEndpoint": "enabled", "HttpTokens": "required"},
+    }
+
+    # EC2CreateInstanceOperator creates and starts the EC2 instance. To test EC2StopInstanceOperator and
+    # EC2StartInstanceOperator, we stop the instance, then start it again before terminating it.
+
+    # [START howto_operator_ec2_create_instance]
+    create_instance = EC2CreateInstanceOperator(
+        task_id="create_instance",
+        image_id=image_id,
+        max_count=1,
+        min_count=1,
+        config=config,
+    )
+    # [END howto_operator_ec2_create_instance]
+    create_instance.wait_for_completion = True
+    instance_id = parse_response(create_instance.output)
+    # [START howto_operator_ec2_stop_instance]
+    stop_instance = EC2StopInstanceOperator(
+        task_id="stop_instance",
+        instance_id=instance_id,
+    )
+    # [END howto_operator_ec2_stop_instance]
+    stop_instance.trigger_rule = TriggerRule.ALL_DONE
+
     # [START howto_operator_ec2_start_instance]
     start_instance = EC2StartInstanceOperator(
         task_id="start_instance",
@@ -123,25 +139,27 @@ with DAG(
     )
     # [END howto_sensor_ec2_instance_state]
 
-    # [START howto_operator_ec2_stop_instance]
-    stop_instance = EC2StopInstanceOperator(
-        task_id="stop_instance",
-        instance_id=instance_id,
+    # [START howto_operator_ec2_terminate_instance]
+    terminate_instance = EC2TerminateInstanceOperator(
+        task_id="terminate_instance",
+        instance_ids=instance_id,
+        wait_for_completion=True,
     )
-    # [END howto_operator_ec2_stop_instance]
-    stop_instance.trigger_rule = TriggerRule.ALL_DONE
-
+    # [END howto_operator_ec2_terminate_instance]
+    terminate_instance.trigger_rule = TriggerRule.ALL_DONE
     chain(
         # TEST SETUP
         test_context,
         key_name,
-        instance_id,
+        image_id,
         # TEST BODY
+        create_instance,
+        instance_id,
+        stop_instance,
         start_instance,
         await_instance,
-        stop_instance,
+        terminate_instance,
         # TEST TEARDOWN
-        terminate_instance(instance_id),
         delete_key_pair(key_name),
     )