Posted to commits@airflow.apache.org by el...@apache.org on 2022/06/28 19:32:54 UTC
[airflow] branch main updated: Add AWS operators to create and delete RDS Database (#24099)
This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new bf727525e1 Add AWS operators to create and delete RDS Database (#24099)
bf727525e1 is described below
commit bf727525e1fd777e51cc8bc17285f6093277fdef
Author: Eugene Karimov <13...@users.noreply.github.com>
AuthorDate: Tue Jun 28 21:32:17 2022 +0200
Add AWS operators to create and delete RDS Database (#24099)
* Add RdsCreateDbInstanceOperator
* Add RdsDeleteDbInstanceOperator
---
.../amazon/aws/example_dags/example_dms.py | 129 +++++++++------------
airflow/providers/amazon/aws/operators/rds.py | 106 ++++++++++++++++-
.../operators/rds.rst | 28 +++++
tests/providers/amazon/aws/operators/test_rds.py | 68 +++++++++++
tests/system/providers/amazon/aws/rds/__init__.py | 17 +++
.../amazon/aws/rds/example_rds_instance.py | 74 ++++++++++++
6 files changed, 350 insertions(+), 72 deletions(-)
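For context, a minimal sketch of wiring the two new operators into a DAG, mirroring the example DAG and system test changed below; the DAG id, instance identifier, and credentials are placeholders, not part of this commit:

    from datetime import datetime

    from airflow import DAG
    from airflow.models.baseoperator import chain
    from airflow.providers.amazon.aws.operators.rds import (
        RdsCreateDbInstanceOperator,
        RdsDeleteDbInstanceOperator,
    )

    with DAG(
        dag_id='example_rds_instance_sketch',  # placeholder DAG id
        schedule_interval=None,
        start_date=datetime(2021, 1, 1),
        catchup=False,
    ) as dag:
        # Create the DB instance; wait_for_completion defaults to True,
        # so the task blocks until the instance is available.
        create_db_instance = RdsCreateDbInstanceOperator(
            task_id='create_db_instance',
            db_instance_identifier='my-database',  # placeholder identifier
            db_instance_class='db.t3.micro',
            engine='postgres',
            rds_kwargs={
                'MasterUsername': 'database_username',      # placeholder credentials;
                'MasterUserPassword': 'database_password',  # use a secrets backend in production
                'AllocatedStorage': 20,
            },
        )

        # Delete the instance even if upstream tasks failed (trigger_rule='all_done').
        delete_db_instance = RdsDeleteDbInstanceOperator(
            task_id='delete_db_instance',
            db_instance_identifier='my-database',
            rds_kwargs={'SkipFinalSnapshot': True},
            trigger_rule='all_done',
        )

        chain(create_db_instance, delete_db_instance)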
diff --git a/airflow/providers/amazon/aws/example_dags/example_dms.py b/airflow/providers/amazon/aws/example_dags/example_dms.py
index caffe44353..46e97d92a5 100644
--- a/airflow/providers/amazon/aws/example_dags/example_dms.py
+++ b/airflow/providers/amazon/aws/example_dags/example_dms.py
@@ -38,6 +38,10 @@ from airflow.providers.amazon.aws.operators.dms import (
DmsStartTaskOperator,
DmsStopTaskOperator,
)
+from airflow.providers.amazon.aws.operators.rds import (
+ RdsCreateDbInstanceOperator,
+ RdsDeleteDbInstanceOperator,
+)
from airflow.providers.amazon.aws.sensors.dms import DmsTaskBaseSensor, DmsTaskCompletedSensor
S3_BUCKET = os.getenv('S3_BUCKET', 's3_bucket_name')
@@ -109,29 +113,20 @@ TABLE_MAPPINGS = {
}
-def _create_rds_instance():
- print('Creating RDS Instance.')
-
+def _get_rds_instance_endpoint():
+ print('Retrieving RDS instance endpoint.')
rds_client = boto3.client('rds')
- rds_client.create_db_instance(
- DBName=RDS_DB_NAME,
- DBInstanceIdentifier=RDS_INSTANCE_NAME,
- AllocatedStorage=20,
- DBInstanceClass='db.t3.micro',
- Engine=RDS_ENGINE,
- MasterUsername=RDS_USERNAME,
- MasterUserPassword=RDS_PASSWORD,
- )
-
- rds_client.get_waiter('db_instance_available').wait(DBInstanceIdentifier=RDS_INSTANCE_NAME)
response = rds_client.describe_db_instances(DBInstanceIdentifier=RDS_INSTANCE_NAME)
- return response['DBInstances'][0]['Endpoint']
+ rds_instance_endpoint = response['DBInstances'][0]['Endpoint']
+ return rds_instance_endpoint
-def _create_rds_table(rds_endpoint):
- print('Creating table.')
+@task
+def create_sample_table():
+ print('Creating sample table.')
+ rds_endpoint = _get_rds_instance_endpoint()
hostname = rds_endpoint['Address']
port = rds_endpoint['Port']
rds_url = f'{RDS_PROTOCOL}://{RDS_USERNAME}:{RDS_PASSWORD}@{hostname}:{port}/{RDS_DB_NAME}'
@@ -154,7 +149,13 @@ def _create_rds_table(rds_endpoint):
connection.execute(table.select())
-def _create_dms_replication_instance(ti, dms_client):
+@task
+def create_dms_assets():
+ print('Creating DMS assets.')
+ ti = get_current_context()['ti']
+ dms_client = boto3.client('dms')
+ rds_instance_endpoint = _get_rds_instance_endpoint()
+
print('Creating replication instance.')
instance_arn = dms_client.create_replication_instance(
ReplicationInstanceIdentifier=DMS_REPLICATION_INSTANCE_NAME,
@@ -162,10 +163,7 @@ def _create_dms_replication_instance(ti, dms_client):
)['ReplicationInstance']['ReplicationInstanceArn']
ti.xcom_push(key='replication_instance_arn', value=instance_arn)
- return instance_arn
-
-def _create_dms_endpoints(ti, dms_client, rds_instance_endpoint):
print('Creating DMS source endpoint.')
source_endpoint_arn = dms_client.create_endpoint(
EndpointIdentifier=SOURCE_ENDPOINT_IDENTIFIER,
@@ -194,28 +192,16 @@ def _create_dms_endpoints(ti, dms_client, rds_instance_endpoint):
ti.xcom_push(key='source_endpoint_arn', value=source_endpoint_arn)
ti.xcom_push(key='target_endpoint_arn', value=target_endpoint_arn)
-
-def _await_setup_assets(dms_client, instance_arn):
- print("Awaiting asset provisioning.")
+ print("Awaiting replication instance provisioning.")
dms_client.get_waiter('replication_instance_available').wait(
Filters=[{'Name': 'replication-instance-arn', 'Values': [instance_arn]}]
)
-def _delete_rds_instance():
- print('Deleting RDS Instance.')
-
- rds_client = boto3.client('rds')
- rds_client.delete_db_instance(
- DBInstanceIdentifier=RDS_INSTANCE_NAME,
- SkipFinalSnapshot=True,
- )
-
- rds_client.get_waiter('db_instance_deleted').wait(DBInstanceIdentifier=RDS_INSTANCE_NAME)
-
-
-def _delete_dms_assets(dms_client):
+@task(trigger_rule='all_done')
+def delete_dms_assets():
ti = get_current_context()['ti']
+ dms_client = boto3.client('dms')
replication_instance_arn = ti.xcom_pull(key='replication_instance_arn')
source_arn = ti.xcom_pull(key='source_endpoint_arn')
target_arn = ti.xcom_pull(key='target_endpoint_arn')
@@ -225,13 +211,10 @@ def _delete_dms_assets(dms_client):
dms_client.delete_endpoint(EndpointArn=source_arn)
dms_client.delete_endpoint(EndpointArn=target_arn)
-
-def _await_all_teardowns(dms_client):
- print('Awaiting tear-down.')
+ print('Awaiting DMS assets tear-down.')
dms_client.get_waiter('replication_instance_deleted').wait(
Filters=[{'Name': 'replication-instance-id', 'Values': [DMS_REPLICATION_INSTANCE_NAME]}]
)
-
dms_client.get_waiter('endpoint_deleted').wait(
Filters=[
{
@@ -242,27 +225,6 @@ def _await_all_teardowns(dms_client):
)
-@task
-def set_up():
- ti = get_current_context()['ti']
- dms_client = boto3.client('dms')
-
- rds_instance_endpoint = _create_rds_instance()
- _create_rds_table(rds_instance_endpoint)
- instance_arn = _create_dms_replication_instance(ti, dms_client)
- _create_dms_endpoints(ti, dms_client, rds_instance_endpoint)
- _await_setup_assets(dms_client, instance_arn)
-
-
-@task(trigger_rule='all_done')
-def clean_up():
- dms_client = boto3.client('dms')
-
- _delete_rds_instance()
- _delete_dms_assets(dms_client)
- _await_all_teardowns(dms_client)
-
-
with DAG(
dag_id='example_dms',
schedule_interval=None,
@@ -271,6 +233,19 @@ with DAG(
catchup=False,
) as dag:
+ create_db_instance = RdsCreateDbInstanceOperator(
+ task_id="create_db_instance",
+ db_instance_identifier=RDS_INSTANCE_NAME,
+ db_instance_class='db.t3.micro',
+ engine=RDS_ENGINE,
+ rds_kwargs={
+ "DBName": RDS_DB_NAME,
+ "AllocatedStorage": 20,
+ "MasterUsername": RDS_USERNAME,
+ "MasterUserPassword": RDS_PASSWORD,
+ },
+ )
+
# [START howto_operator_dms_create_task]
create_task = DmsCreateTaskOperator(
task_id='create_task',
@@ -334,14 +309,26 @@ with DAG(
)
# [END howto_operator_dms_delete_task]
+ delete_db_instance = RdsDeleteDbInstanceOperator(
+ task_id='delete_db_instance',
+ db_instance_identifier=RDS_INSTANCE_NAME,
+ rds_kwargs={
+ "SkipFinalSnapshot": True,
+ },
+ trigger_rule='all_done',
+ )
+
chain(
- set_up()
- >> create_task
- >> start_task
- >> describe_tasks
- >> await_task_start
- >> stop_task
- >> await_task_stop
- >> delete_task
- >> clean_up()
+ create_db_instance,
+ create_sample_table(),
+ create_dms_assets(),
+ create_task,
+ start_task,
+ describe_tasks,
+ await_task_start,
+ stop_task,
+ await_task_stop,
+ delete_task,
+ delete_dms_assets(),
+ delete_db_instance,
)
diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py
index a527107e80..fe38bfed69 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -18,7 +18,7 @@
import json
import time
-from typing import TYPE_CHECKING, List, Optional, Sequence
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
from mypy_boto3_rds.type_defs import TagTypeDef
@@ -551,6 +551,108 @@ class RdsDeleteEventSubscriptionOperator(RdsBaseOperator):
return json.dumps(delete_subscription, default=str)
+class RdsCreateDbInstanceOperator(RdsBaseOperator):
+ """
+ Creates an RDS DB instance
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:RdsCreateDbInstanceOperator`
+
+ :param db_instance_identifier: The DB instance identifier, must start with a letter and
+ contain from 1 to 63 letters, numbers, or hyphens
+ :param db_instance_class: The compute and memory capacity of the DB instance, for example db.m5.large
+ :param engine: The name of the database engine to be used for this instance
+ :param rds_kwargs: Named arguments to pass to boto3 RDS client function ``create_db_instance``
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_instance
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param wait_for_completion: Whether or not to wait for creation of the DB instance to
+ complete. (default: True)
+ """
+
+ def __init__(
+ self,
+ *,
+ db_instance_identifier: str,
+ db_instance_class: str,
+ engine: str,
+ rds_kwargs: Optional[Dict] = None,
+ aws_conn_id: str = "aws_default",
+ wait_for_completion: bool = True,
+ **kwargs,
+ ):
+ super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+
+ self.db_instance_identifier = db_instance_identifier
+ self.db_instance_class = db_instance_class
+ self.engine = engine
+ self.rds_kwargs = rds_kwargs or {}
+ self.wait_for_completion = wait_for_completion
+
+ def execute(self, context: 'Context') -> str:
+ self.log.info("Creating new DB instance %s", self.db_instance_identifier)
+
+ create_db_instance = self.hook.conn.create_db_instance(
+ DBInstanceIdentifier=self.db_instance_identifier,
+ DBInstanceClass=self.db_instance_class,
+ Engine=self.engine,
+ **self.rds_kwargs,
+ )
+
+ if self.wait_for_completion:
+ self.hook.conn.get_waiter("db_instance_available").wait(
+ DBInstanceIdentifier=self.db_instance_identifier
+ )
+
+ return json.dumps(create_db_instance, default=str)
+
+
+class RdsDeleteDbInstanceOperator(RdsBaseOperator):
+ """
+ Deletes an RDS DB instance
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:RdsDeleteDbInstanceOperator`
+
+ :param db_instance_identifier: The DB instance identifier for the DB instance to be deleted
+ :param rds_kwargs: Named arguments to pass to boto3 RDS client function ``delete_db_instance``
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.delete_db_instance
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param wait_for_completion: Whether or not to wait for deletion of the DB instance to
+ complete. (default: True)
+ """
+
+ def __init__(
+ self,
+ *,
+ db_instance_identifier: str,
+ rds_kwargs: Optional[Dict] = None,
+ aws_conn_id: str = "aws_default",
+ wait_for_completion: bool = True,
+ **kwargs,
+ ):
+ super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ self.db_instance_identifier = db_instance_identifier
+ self.rds_kwargs = rds_kwargs or {}
+ self.wait_for_completion = wait_for_completion
+
+ def execute(self, context: 'Context') -> str:
+ self.log.info("Deleting DB instance %s", self.db_instance_identifier)
+
+ delete_db_instance = self.hook.conn.delete_db_instance(
+ DBInstanceIdentifier=self.db_instance_identifier,
+ **self.rds_kwargs,
+ )
+
+ if self.wait_for_completion:
+ self.hook.conn.get_waiter("db_instance_deleted").wait(
+ DBInstanceIdentifier=self.db_instance_identifier
+ )
+
+ return json.dumps(delete_db_instance, default=str)
+
+
__all__ = [
"RdsCreateDbSnapshotOperator",
"RdsCopyDbSnapshotOperator",
@@ -559,4 +661,6 @@ __all__ = [
"RdsDeleteEventSubscriptionOperator",
"RdsStartExportTaskOperator",
"RdsCancelExportTaskOperator",
+ "RdsCreateDbInstanceOperator",
+ "RdsDeleteDbInstanceOperator",
]
diff --git a/docs/apache-airflow-providers-amazon/operators/rds.rst b/docs/apache-airflow-providers-amazon/operators/rds.rst
index 7e48a983f4..434804d910 100644
--- a/docs/apache-airflow-providers-amazon/operators/rds.rst
+++ b/docs/apache-airflow-providers-amazon/operators/rds.rst
@@ -138,6 +138,34 @@ To delete an Amazon RDS event subscription you can use
:start-after: [START howto_operator_rds_delete_event_subscription]
:end-before: [END howto_operator_rds_delete_event_subscription]
+.. _howto/operator:RdsCreateDbInstanceOperator:
+
+Create a database instance
+==========================
+
+To create an AWS DB instance you can use
+:class:`~airflow.providers.amazon.aws.operators.rds.RdsCreateDbInstanceOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/rds/example_rds_instance.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_rds_create_db_instance]
+ :end-before: [END howto_operator_rds_create_db_instance]
+
+.. _howto/operator:RdsDeleteDbInstanceOperator:
+
+Delete a database instance
+==========================
+
+To delete an AWS DB instance you can use
+:class:`~airflow.providers.amazon.aws.operators.rds.RdsDeleteDbInstanceOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/rds/example_rds_instance.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_rds_delete_db_instance]
+ :end-before: [END howto_operator_rds_delete_db_instance]
+
Sensors
-------
diff --git a/tests/providers/amazon/aws/operators/test_rds.py b/tests/providers/amazon/aws/operators/test_rds.py
index d952fbc11a..355ff529c9 100644
--- a/tests/providers/amazon/aws/operators/test_rds.py
+++ b/tests/providers/amazon/aws/operators/test_rds.py
@@ -26,8 +26,10 @@ from airflow.providers.amazon.aws.operators.rds import (
RdsBaseOperator,
RdsCancelExportTaskOperator,
RdsCopyDbSnapshotOperator,
+ RdsCreateDbInstanceOperator,
RdsCreateDbSnapshotOperator,
RdsCreateEventSubscriptionOperator,
+ RdsDeleteDbInstanceOperator,
RdsDeleteDbSnapshotOperator,
RdsDeleteEventSubscriptionOperator,
RdsStartExportTaskOperator,
@@ -450,3 +452,69 @@ class TestRdsDeleteEventSubscriptionOperator:
with pytest.raises(self.hook.conn.exceptions.ClientError):
self.hook.conn.describe_event_subscriptions(SubscriptionName=EXPORT_TASK_NAME)
+
+
+@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present')
+class TestRdsCreateDbInstanceOperator:
+ @classmethod
+ def setup_class(cls):
+ cls.dag = DAG('test_dag', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE})
+ cls.hook = RdsHook(aws_conn_id=AWS_CONN, region_name='us-east-1')
+
+ @classmethod
+ def teardown_class(cls):
+ del cls.dag
+ del cls.hook
+
+ @mock_rds
+ def test_create_db_instance(self):
+ create_db_instance_operator = RdsCreateDbInstanceOperator(
+ task_id='test_create_db_instance',
+ db_instance_identifier=DB_INSTANCE_NAME,
+ db_instance_class="db.m5.large",
+ engine="postgres",
+ rds_kwargs={
+ "DBName": DB_INSTANCE_NAME,
+ },
+ aws_conn_id=AWS_CONN,
+ dag=self.dag,
+ )
+ create_db_instance_operator.execute(None)
+
+ result = self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
+ db_instances = result.get("DBInstances")
+
+ assert db_instances
+ assert len(db_instances) == 1
+ assert db_instances[0]['DBInstanceStatus'] == 'available'
+
+
+@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present')
+class TestRdsDeleteDbInstanceOperator:
+ @classmethod
+ def setup_class(cls):
+ cls.dag = DAG('test_dag', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE})
+ cls.hook = RdsHook(aws_conn_id=AWS_CONN, region_name='us-east-1')
+
+ @classmethod
+ def teardown_class(cls):
+ del cls.dag
+ del cls.hook
+
+ @mock_rds
+ def test_delete_db_instance(self):
+ _create_db_instance(self.hook)
+
+ delete_db_instance_operator = RdsDeleteDbInstanceOperator(
+ task_id='test_delete_db_instance',
+ db_instance_identifier=DB_INSTANCE_NAME,
+ rds_kwargs={
+ "SkipFinalSnapshot": True,
+ },
+ aws_conn_id=AWS_CONN,
+ dag=self.dag,
+ )
+ delete_db_instance_operator.execute(None)
+
+ with pytest.raises(self.hook.conn.exceptions.ClientError):
+ self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
diff --git a/tests/system/providers/amazon/aws/rds/__init__.py b/tests/system/providers/amazon/aws/rds/__init__.py
new file mode 100644
index 0000000000..217e5db960
--- /dev/null
+++ b/tests/system/providers/amazon/aws/rds/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/system/providers/amazon/aws/rds/example_rds_instance.py b/tests/system/providers/amazon/aws/rds/example_rds_instance.py
new file mode 100644
index 0000000000..da157dbdf9
--- /dev/null
+++ b/tests/system/providers/amazon/aws/rds/example_rds_instance.py
@@ -0,0 +1,74 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime
+
+from airflow import DAG
+from airflow.models.baseoperator import chain
+from airflow.providers.amazon.aws.operators.rds import (
+ RdsCreateDbInstanceOperator,
+ RdsDeleteDbInstanceOperator,
+)
+from tests.system.providers.amazon.aws.utils import set_env_id
+
+ENV_ID = set_env_id()
+DAG_ID = "example_rds_instance"
+
+RDS_DB_IDENTIFIER = f'{ENV_ID}-database'
+RDS_USERNAME = 'database_username'
+# NEVER store your production password in plaintext in a DAG like this.
+# Use Airflow Secrets or a secret manager for this in production.
+RDS_PASSWORD = 'database_password'
+
+with DAG(
+ dag_id=DAG_ID,
+ schedule_interval=None,
+ start_date=datetime(2021, 1, 1),
+ tags=['example'],
+ catchup=False,
+) as dag:
+ # [START howto_operator_rds_create_db_instance]
+ create_db_instance = RdsCreateDbInstanceOperator(
+ task_id='create_db_instance',
+ db_instance_identifier=RDS_DB_IDENTIFIER,
+ db_instance_class="db.m5.large",
+ engine="postgres",
+ rds_kwargs={
+ "MasterUsername": RDS_USERNAME,
+ "MasterUserPassword": RDS_PASSWORD,
+ "AllocatedStorage": 20,
+ },
+ )
+ # [END howto_operator_rds_create_db_instance]
+
+ # [START howto_operator_rds_delete_db_instance]
+ delete_db_instance = RdsDeleteDbInstanceOperator(
+ task_id='delete_db_instance',
+ db_instance_identifier=RDS_DB_IDENTIFIER,
+ rds_kwargs={
+ "SkipFinalSnapshot": True,
+ },
+ )
+ # [END howto_operator_rds_delete_db_instance]
+
+ chain(create_db_instance, delete_db_instance)
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)