You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by jo...@apache.org on 2023/03/07 22:22:11 UTC
[airflow] branch main updated: Add `EC2CreateInstanceOperator`, `EC2TerminateInstanceOperator` (#29548)
This is an automated email from the ASF dual-hosted git repository.
joshfell pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d2cc9df82c Add `EC2CreateInstanceOperator`, `EC2TerminateInstanceOperator` (#29548)
d2cc9df82c is described below
commit d2cc9df82c8b6ae6cccb51462b8b5a37155666a7
Author: Syed Hussaain <10...@users.noreply.github.com>
AuthorDate: Tue Mar 7 14:22:02 2023 -0800
Add `EC2CreateInstanceOperator`, `EC2TerminateInstanceOperator` (#29548)
* Add EC2CreateInstanceOperator and EC2TerminateInstanceOperator
Change system test to use the new operators
Add unit tests for new operators
* Add support for multiple ids to EC2TerminateInstanceOperator
Change system test to terminate without stopping instances
* Fix failing tests for terminate operator
* Update doc strings to add that the operators can create/terminate multiple instances
Add tests for creating/terminating multiple instances
* Fix system test so it passes
Fix doc string on EC2TerminateInstanceOperator
---------
Co-authored-by: syedahsn <sy...@ud74d1a752d7e5b.ant.amazon.com>
---
airflow/providers/amazon/aws/operators/ec2.py | 138 +++++++++++++++++++++
.../operators/ec2.rst | 28 +++++
tests/providers/amazon/aws/operators/test_ec2.py | 130 +++++++++++++++++--
tests/system/providers/amazon/aws/example_ec2.py | 104 +++++++++-------
4 files changed, 345 insertions(+), 55 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/ec2.py b/airflow/providers/amazon/aws/operators/ec2.py
index 60cb43a32d..5f6b76ce15 100644
--- a/airflow/providers/amazon/aws/operators/ec2.py
+++ b/airflow/providers/amazon/aws/operators/ec2.py
@@ -116,3 +116,141 @@ class EC2StopInstanceOperator(BaseOperator):
target_state="stopped",
check_interval=self.check_interval,
)
+
+
+class EC2CreateInstanceOperator(BaseOperator):
+    """
+    Create and start a specified number of EC2 Instances using boto3.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:EC2CreateInstanceOperator`
+
+    :param image_id: ID of the AMI used to create the instance.
+    :param max_count: Maximum number of instances to launch. Defaults to 1.
+    :param min_count: Minimum number of instances to launch. Defaults to 1.
+    :param aws_conn_id: AWS connection to use
+    :param region_name: AWS region name associated with the client.
+    :param poll_interval: Number of seconds to wait before attempting to
+        check state of instance. Only used if wait_for_completion is True. Default is 20.
+    :param max_attempts: Maximum number of attempts when checking state of instance.
+        Only used if wait_for_completion is True. Default is 20.
+    :param config: Dictionary for arbitrary parameters to the boto3 run_instances call.
+    :param wait_for_completion: If True, the operator will wait for all created instances
+        to be in the `running` state before returning.
+    """
+
+    template_fields: Sequence[str] = (
+        "image_id",
+        "max_count",
+        "min_count",
+        "aws_conn_id",
+        "region_name",
+        "config",
+        "wait_for_completion",
+    )
+
+    def __init__(
+        self,
+        image_id: str,
+        max_count: int = 1,
+        min_count: int = 1,
+        aws_conn_id: str = "aws_default",
+        region_name: str | None = None,
+        poll_interval: int = 20,
+        max_attempts: int = 20,
+        config: dict | None = None,
+        wait_for_completion: bool = False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.image_id = image_id
+        self.max_count = max_count
+        self.min_count = min_count
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.poll_interval = poll_interval
+        self.max_attempts = max_attempts
+        # Normalise a missing config to an empty dict so it can be splatted into run_instances.
+        self.config = config or {}
+        self.wait_for_completion = wait_for_completion
+
+    def execute(self, context: Context) -> list[str]:
+        """
+        Launch the instances and optionally wait until they are running.
+
+        :param context: Airflow task context (not used by this operator).
+        :return: The ``InstanceId`` of every instance in the ``run_instances`` response.
+        """
+        ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
+        instances = ec2_hook.conn.run_instances(
+            ImageId=self.image_id,
+            MinCount=self.min_count,
+            MaxCount=self.max_count,
+            **self.config,
+        )["Instances"]
+        instance_ids = [instance["InstanceId"] for instance in instances]
+        for instance_id in instance_ids:
+            self.log.info("Created EC2 instance %s", instance_id)
+
+        # A single waiter call covers every new instance, so the waits are not
+        # chained one instance after another. Skipped when there is nothing to wait for.
+        if self.wait_for_completion and instance_ids:
+            ec2_hook.get_waiter("instance_running").wait(
+                InstanceIds=instance_ids,
+                WaiterConfig={
+                    "Delay": self.poll_interval,
+                    "MaxAttempts": self.max_attempts,
+                },
+            )
+
+        return instance_ids
+
+
+class EC2TerminateInstanceOperator(BaseOperator):
+    """
+    Terminate one or more EC2 Instances using boto3.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:EC2TerminateInstanceOperator`
+
+    :param instance_ids: ID of the instance to be terminated, or a list of instance IDs.
+    :param aws_conn_id: AWS connection to use
+    :param region_name: AWS region name associated with the client.
+    :param poll_interval: Number of seconds to wait before attempting to
+        check state of instance. Only used if wait_for_completion is True. Default is 20.
+    :param max_attempts: Maximum number of attempts when checking state of instance.
+        Only used if wait_for_completion is True. Default is 20.
+    :param wait_for_completion: If True, the operator will wait for all instances to be
+        in the `terminated` state before returning.
+    """
+
+    template_fields: Sequence[str] = ("instance_ids", "region_name", "aws_conn_id", "wait_for_completion")
+
+    def __init__(
+        self,
+        instance_ids: str | list[str],
+        aws_conn_id: str = "aws_default",
+        region_name: str | None = None,
+        poll_interval: int = 20,
+        max_attempts: int = 20,
+        wait_for_completion: bool = False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.instance_ids = instance_ids
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.poll_interval = poll_interval
+        self.max_attempts = max_attempts
+        self.wait_for_completion = wait_for_completion
+
+    def execute(self, context: Context) -> None:
+        """
+        Terminate the configured instances and optionally wait until they are terminated.
+
+        :param context: Airflow task context (not used by this operator).
+        """
+        # A single ID may be passed as a bare string; the boto3 calls below take a list.
+        if isinstance(self.instance_ids, str):
+            self.instance_ids = [self.instance_ids]
+        ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
+        ec2_hook.conn.terminate_instances(InstanceIds=self.instance_ids)
+
+        for instance_id in self.instance_ids:
+            self.log.info("Terminating EC2 instance %s", instance_id)
+
+        # A single waiter call covers every instance, so the waits are not
+        # chained one instance after another. Skipped when there is nothing to wait for.
+        if self.wait_for_completion and self.instance_ids:
+            ec2_hook.get_waiter("instance_terminated").wait(
+                InstanceIds=self.instance_ids,
+                WaiterConfig={
+                    "Delay": self.poll_interval,
+                    "MaxAttempts": self.max_attempts,
+                },
+            )
diff --git a/docs/apache-airflow-providers-amazon/operators/ec2.rst b/docs/apache-airflow-providers-amazon/operators/ec2.rst
index 37e5d804ee..5796c514e0 100644
--- a/docs/apache-airflow-providers-amazon/operators/ec2.rst
+++ b/docs/apache-airflow-providers-amazon/operators/ec2.rst
@@ -58,6 +58,34 @@ To stop an Amazon EC2 instance you can use
:start-after: [START howto_operator_ec2_stop_instance]
:end-before: [END howto_operator_ec2_stop_instance]
+.. _howto/operator:EC2CreateInstanceOperator:
+
+Create and start an Amazon EC2 instance
+=======================================
+
+To create and start an Amazon EC2 instance you can use
+:class:`~airflow.providers.amazon.aws.operators.ec2.EC2CreateInstanceOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ec2.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_ec2_create_instance]
+ :end-before: [END howto_operator_ec2_create_instance]
+
+.. _howto/operator:EC2TerminateInstanceOperator:
+
+Terminate an Amazon EC2 instance
+================================
+
+To terminate an Amazon EC2 instance you can use
+:class:`~airflow.providers.amazon.aws.operators.ec2.EC2TerminateInstanceOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ec2.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_ec2_terminate_instance]
+ :end-before: [END howto_operator_ec2_terminate_instance]
+
Sensors
-------
diff --git a/tests/providers/amazon/aws/operators/test_ec2.py b/tests/providers/amazon/aws/operators/test_ec2.py
index 8fd4c535a1..adf3ffeb91 100644
--- a/tests/providers/amazon/aws/operators/test_ec2.py
+++ b/tests/providers/amazon/aws/operators/test_ec2.py
@@ -20,23 +20,121 @@ from __future__ import annotations
from moto import mock_ec2
from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
-from airflow.providers.amazon.aws.operators.ec2 import EC2StartInstanceOperator, EC2StopInstanceOperator
+from airflow.providers.amazon.aws.operators.ec2 import (
+ EC2CreateInstanceOperator,
+ EC2StartInstanceOperator,
+ EC2StopInstanceOperator,
+ EC2TerminateInstanceOperator,
+)
class BaseEc2TestClass:
@classmethod
- def _create_instance(cls, hook: EC2Hook):
- """Create Instance and return instance id."""
+ def _get_image_id(cls, hook):
+ """Get a valid image id to create an instance."""
conn = hook.get_conn()
try:
ec2_client = conn.meta.client
except AttributeError:
ec2_client = conn
- # We need existed AMI Image ID otherwise `moto` will raise DeprecationWarning.
+ # We need an existing AMI Image ID otherwise `moto` will raise DeprecationWarning.
images = ec2_client.describe_images()["Images"]
- response = ec2_client.run_instances(MaxCount=1, MinCount=1, ImageId=images[0]["ImageId"])
- return response["Instances"][0]["InstanceId"]
+ return images[0]["ImageId"]
+
+
+class TestEC2CreateInstanceOperator(BaseEc2TestClass):
+    """Unit tests for ``EC2CreateInstanceOperator``."""
+
+    def test_init(self):
+        """Constructor stores the given values and falls back to the documented defaults."""
+        ec2_operator = EC2CreateInstanceOperator(
+            task_id="test_create_instance",
+            image_id="test_image_id",
+        )
+
+        assert ec2_operator.task_id == "test_create_instance"
+        assert ec2_operator.image_id == "test_image_id"
+        assert ec2_operator.max_count == 1
+        assert ec2_operator.min_count == 1
+        assert ec2_operator.max_attempts == 20
+        assert ec2_operator.poll_interval == 20
+
+    @mock_ec2
+    def test_create_instance(self):
+        """A default operator launches exactly one instance that ends up running."""
+        ec2_hook = EC2Hook()
+        create_instance = EC2CreateInstanceOperator(
+            image_id=self._get_image_id(ec2_hook),
+            task_id="test_create_instance",
+        )
+        # execute() returns a list of IDs even when a single instance is launched.
+        instance_ids = create_instance.execute(None)
+
+        assert len(instance_ids) == 1
+        assert ec2_hook.get_instance_state(instance_id=instance_ids[0]) == "running"
+
+    @mock_ec2
+    def test_create_multiple_instances(self):
+        """min_count/max_count of 5 launches five running instances."""
+        ec2_hook = EC2Hook()
+        create_instances = EC2CreateInstanceOperator(
+            task_id="test_create_multiple_instances",
+            image_id=self._get_image_id(hook=ec2_hook),
+            min_count=5,
+            max_count=5,
+        )
+        instance_ids = create_instances.execute(None)
+        assert len(instance_ids) == 5
+
+        for instance_id in instance_ids:
+            assert ec2_hook.get_instance_state(instance_id=instance_id) == "running"
+
+
+class TestEC2TerminateInstanceOperator(BaseEc2TestClass):
+    """Unit tests for ``EC2TerminateInstanceOperator``."""
+
+    def test_init(self):
+        """Constructor stores the given values and falls back to the documented defaults."""
+        ec2_operator = EC2TerminateInstanceOperator(
+            task_id="test_terminate_instance",
+            instance_ids="test_instance_id",
+        )
+
+        assert ec2_operator.task_id == "test_terminate_instance"
+        assert ec2_operator.instance_ids == "test_instance_id"
+        assert ec2_operator.max_attempts == 20
+        assert ec2_operator.poll_interval == 20
+
+    @mock_ec2
+    def test_terminate_instance(self):
+        """A single created instance is reported as terminated after the operator runs."""
+        ec2_hook = EC2Hook()
+
+        create_instance = EC2CreateInstanceOperator(
+            image_id=self._get_image_id(ec2_hook),
+            task_id="test_create_instance",
+        )
+        # execute() returns a list of IDs even when a single instance is launched.
+        instance_ids = create_instance.execute(None)
+
+        assert ec2_hook.get_instance_state(instance_id=instance_ids[0]) == "running"
+
+        terminate_instance = EC2TerminateInstanceOperator(
+            task_id="test_terminate_instance", instance_ids=instance_ids
+        )
+        terminate_instance.execute(None)
+
+        assert ec2_hook.get_instance_state(instance_id=instance_ids[0]) == "terminated"
+
+    @mock_ec2
+    def test_terminate_multiple_instances(self):
+        """Every instance in a list of five is reported as terminated after the operator runs."""
+        ec2_hook = EC2Hook()
+        create_instances = EC2CreateInstanceOperator(
+            task_id="test_create_multiple_instances",
+            image_id=self._get_image_id(hook=ec2_hook),
+            min_count=5,
+            max_count=5,
+        )
+        instance_ids = create_instances.execute(None)
+        assert len(instance_ids) == 5
+
+        for instance_id in instance_ids:
+            assert ec2_hook.get_instance_state(instance_id=instance_id) == "running"
+
+        terminate_instance = EC2TerminateInstanceOperator(
+            task_id="test_terminate_instance", instance_ids=instance_ids
+        )
+        terminate_instance.execute(None)
+        for instance_id in instance_ids:
+            assert ec2_hook.get_instance_state(instance_id=instance_id) == "terminated"
class TestEC2StartInstanceOperator(BaseEc2TestClass):
@@ -58,16 +156,20 @@ class TestEC2StartInstanceOperator(BaseEc2TestClass):
def test_start_instance(self):
# create instance
ec2_hook = EC2Hook()
- instance_id = self._create_instance(ec2_hook)
+ create_instance = EC2CreateInstanceOperator(
+ image_id=self._get_image_id(ec2_hook),
+ task_id="test_create_instance",
+ )
+ instance_id = create_instance.execute(None)
# start instance
start_test = EC2StartInstanceOperator(
task_id="start_test",
- instance_id=instance_id,
+ instance_id=instance_id[0],
)
start_test.execute(None)
# assert instance state is running
- assert ec2_hook.get_instance_state(instance_id=instance_id) == "running"
+ assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "running"
class TestEC2StopInstanceOperator(BaseEc2TestClass):
@@ -89,13 +191,17 @@ class TestEC2StopInstanceOperator(BaseEc2TestClass):
def test_stop_instance(self):
# create instance
ec2_hook = EC2Hook()
- instance_id = self._create_instance(ec2_hook)
+ create_instance = EC2CreateInstanceOperator(
+ image_id=self._get_image_id(ec2_hook),
+ task_id="test_create_instance",
+ )
+ instance_id = create_instance.execute(None)
# stop instance
stop_test = EC2StopInstanceOperator(
task_id="stop_test",
- instance_id=instance_id,
+ instance_id=instance_id[0],
)
stop_test.execute(None)
# assert instance state is running
- assert ec2_hook.get_instance_state(instance_id=instance_id) == "stopped"
+ assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "stopped"
diff --git a/tests/system/providers/amazon/aws/example_ec2.py b/tests/system/providers/amazon/aws/example_ec2.py
index 36bddc5548..1e58933dd0 100644
--- a/tests/system/providers/amazon/aws/example_ec2.py
+++ b/tests/system/providers/amazon/aws/example_ec2.py
@@ -24,7 +24,12 @@ import boto3
from airflow import DAG
from airflow.decorators import task
from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.ec2 import EC2StartInstanceOperator, EC2StopInstanceOperator
+from airflow.providers.amazon.aws.operators.ec2 import (
+ EC2CreateInstanceOperator,
+ EC2StartInstanceOperator,
+ EC2StopInstanceOperator,
+ EC2TerminateInstanceOperator,
+)
from airflow.providers.amazon.aws.sensors.ec2 import EC2InstanceStateSensor
from airflow.utils.trigger_rule import TriggerRule
from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
@@ -34,7 +39,8 @@ DAG_ID = "example_ec2"
sys_test_context_task = SystemTestContextBuilder().build()
-def _get_latest_ami_id():
+@task
+def get_latest_ami_id():
"""Returns the AMI ID of the most recently-created Amazon Linux image"""
# Amazon is retiring AL2 in 2023 and replacing it with Amazon Linux 2022.
@@ -62,40 +68,16 @@ def create_key_pair(key_name: str):
return key_pair_id
-@task
-def create_instance(instance_name: str, key_pair_id: str):
- client = boto3.client("ec2")
-
- # Create the instance
- instance_id = client.run_instances(
- ImageId=_get_latest_ami_id(),
- MinCount=1,
- MaxCount=1,
- InstanceType="t2.micro",
- KeyName=key_pair_id,
- TagSpecifications=[{"ResourceType": "instance", "Tags": [{"Key": "Name", "Value": instance_name}]}],
- # Use IMDSv2 for greater security, see the following doc for more details:
- # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html
- MetadataOptions={"HttpEndpoint": "enabled", "HttpTokens": "required"},
- )["Instances"][0]["InstanceId"]
-
- # Wait for it to exist
- waiter = client.get_waiter("instance_status_ok")
- waiter.wait(InstanceIds=[instance_id])
-
- return instance_id
-
-
-@task(trigger_rule=TriggerRule.ALL_DONE)
-def terminate_instance(instance: str):
- boto3.client("ec2").terminate_instances(InstanceIds=[instance])
-
-
@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_key_pair(key_pair_id: str):
boto3.client("ec2").delete_key_pair(KeyName=key_pair_id)
+@task
+def parse_response(instance_ids: list):
+ """Return the first element of ``instance_ids``."""
+ return instance_ids[0]
+
+
with DAG(
dag_id=DAG_ID,
schedule="@once",
@@ -105,9 +87,43 @@ with DAG(
) as dag:
test_context = sys_test_context_task()
env_id = test_context[ENV_ID_KEY]
-
+ instance_name = f"{env_id}-instance"
key_name = create_key_pair(key_name=f"{env_id}_key_pair")
- instance_id = create_instance(instance_name=f"{env_id}-instance", key_pair_id=key_name)
+ image_id = get_latest_ami_id()
+
+ config = {
+ "InstanceType": "t2.micro",
+ "KeyName": key_name,
+ "TagSpecifications": [
+ {"ResourceType": "instance", "Tags": [{"Key": "Name", "Value": instance_name}]}
+ ],
+ # Use IMDSv2 for greater security, see the following doc for more details:
+ # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html
+ "MetadataOptions": {"HttpEndpoint": "enabled", "HttpTokens": "required"},
+ }
+
+ # EC2CreateInstanceOperator creates and starts the EC2 instances. To test the EC2StartInstanceOperator,
+ # we will stop the instances, then start them again before terminating them.
+
+ # [START howto_operator_ec2_create_instance]
+ create_instance = EC2CreateInstanceOperator(
+ task_id="create_instance",
+ image_id=image_id,
+ max_count=1,
+ min_count=1,
+ config=config,
+ )
+ # [END howto_operator_ec2_create_instance]
+ create_instance.wait_for_completion = True
+ instance_id = parse_response(create_instance.output)
+ # [START howto_operator_ec2_stop_instance]
+ stop_instance = EC2StopInstanceOperator(
+ task_id="stop_instance",
+ instance_id=instance_id,
+ )
+ # [END howto_operator_ec2_stop_instance]
+ stop_instance.trigger_rule = TriggerRule.ALL_DONE
+
# [START howto_operator_ec2_start_instance]
start_instance = EC2StartInstanceOperator(
task_id="start_instance",
@@ -123,25 +139,27 @@ with DAG(
)
# [END howto_sensor_ec2_instance_state]
- # [START howto_operator_ec2_stop_instance]
- stop_instance = EC2StopInstanceOperator(
- task_id="stop_instance",
- instance_id=instance_id,
+ # [START howto_operator_ec2_terminate_instance]
+ terminate_instance = EC2TerminateInstanceOperator(
+ task_id="terminate_instance",
+ instance_ids=instance_id,
+ wait_for_completion=True,
)
- # [END howto_operator_ec2_stop_instance]
- stop_instance.trigger_rule = TriggerRule.ALL_DONE
-
+ # [END howto_operator_ec2_terminate_instance]
+ terminate_instance.trigger_rule = TriggerRule.ALL_DONE
chain(
# TEST SETUP
test_context,
key_name,
- instance_id,
+ image_id,
# TEST BODY
+ create_instance,
+ instance_id,
+ stop_instance,
start_instance,
await_instance,
- stop_instance,
+ terminate_instance,
# TEST TEARDOWN
- terminate_instance(instance_id),
delete_key_pair(key_name),
)