Posted to commits@airflow.apache.org by el...@apache.org on 2022/06/28 19:32:54 UTC

[airflow] branch main updated: Add AWS operators to create and delete RDS Database (#24099)

This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new bf727525e1 Add AWS operators to create and delete RDS Database (#24099)
bf727525e1 is described below

commit bf727525e1fd777e51cc8bc17285f6093277fdef
Author: Eugene Karimov <13...@users.noreply.github.com>
AuthorDate: Tue Jun 28 21:32:17 2022 +0200

    Add AWS operators to create and delete RDS Database (#24099)
    
    * Add RdsCreateDbInstanceOperator
    
    * Add RdsDeleteDbInstanceOperator
---
 .../amazon/aws/example_dags/example_dms.py         | 129 +++++++++------------
 airflow/providers/amazon/aws/operators/rds.py      | 106 ++++++++++++++++-
 .../operators/rds.rst                              |  28 +++++
 tests/providers/amazon/aws/operators/test_rds.py   |  68 +++++++++++
 tests/system/providers/amazon/aws/rds/__init__.py  |  17 +++
 .../amazon/aws/rds/example_rds_instance.py         |  74 ++++++++++++
 6 files changed, 350 insertions(+), 72 deletions(-)
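
In short, the two new operators can bracket a DAG's setup and teardown. A minimal usage
sketch, mirroring the system test added in this commit (the identifier and credentials
below are placeholders):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.rds import (
        RdsCreateDbInstanceOperator,
        RdsDeleteDbInstanceOperator,
    )

    with DAG(dag_id='rds_instance_sketch', schedule_interval=None,
             start_date=datetime(2021, 1, 1), catchup=False) as dag:
        create_db_instance = RdsCreateDbInstanceOperator(
            task_id='create_db_instance',
            db_instance_identifier='my-database',  # placeholder identifier
            db_instance_class='db.m5.large',
            engine='postgres',
            rds_kwargs={
                'MasterUsername': 'database_username',      # placeholders; use a
                'MasterUserPassword': 'database_password',  # secrets backend in production
                'AllocatedStorage': 20,
            },
        )
        delete_db_instance = RdsDeleteDbInstanceOperator(
            task_id='delete_db_instance',
            db_instance_identifier='my-database',
            rds_kwargs={'SkipFinalSnapshot': True},
        )
        create_db_instance >> delete_db_instance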

diff --git a/airflow/providers/amazon/aws/example_dags/example_dms.py b/airflow/providers/amazon/aws/example_dags/example_dms.py
index caffe44353..46e97d92a5 100644
--- a/airflow/providers/amazon/aws/example_dags/example_dms.py
+++ b/airflow/providers/amazon/aws/example_dags/example_dms.py
@@ -38,6 +38,10 @@ from airflow.providers.amazon.aws.operators.dms import (
     DmsStartTaskOperator,
     DmsStopTaskOperator,
 )
+from airflow.providers.amazon.aws.operators.rds import (
+    RdsCreateDbInstanceOperator,
+    RdsDeleteDbInstanceOperator,
+)
 from airflow.providers.amazon.aws.sensors.dms import DmsTaskBaseSensor, DmsTaskCompletedSensor
 
 S3_BUCKET = os.getenv('S3_BUCKET', 's3_bucket_name')
@@ -109,29 +113,20 @@ TABLE_MAPPINGS = {
 }
 
 
-def _create_rds_instance():
-    print('Creating RDS Instance.')
-
+def _get_rds_instance_endpoint():
+    print('Retrieving RDS instance endpoint.')
     rds_client = boto3.client('rds')
-    rds_client.create_db_instance(
-        DBName=RDS_DB_NAME,
-        DBInstanceIdentifier=RDS_INSTANCE_NAME,
-        AllocatedStorage=20,
-        DBInstanceClass='db.t3.micro',
-        Engine=RDS_ENGINE,
-        MasterUsername=RDS_USERNAME,
-        MasterUserPassword=RDS_PASSWORD,
-    )
-
-    rds_client.get_waiter('db_instance_available').wait(DBInstanceIdentifier=RDS_INSTANCE_NAME)
 
     response = rds_client.describe_db_instances(DBInstanceIdentifier=RDS_INSTANCE_NAME)
-    return response['DBInstances'][0]['Endpoint']
+    rds_instance_endpoint = response['DBInstances'][0]['Endpoint']
+    return rds_instance_endpoint
 
 
-def _create_rds_table(rds_endpoint):
-    print('Creating table.')
+@task
+def create_sample_table():
+    print('Creating sample table.')
 
+    rds_endpoint = _get_rds_instance_endpoint()
     hostname = rds_endpoint['Address']
     port = rds_endpoint['Port']
     rds_url = f'{RDS_PROTOCOL}://{RDS_USERNAME}:{RDS_PASSWORD}@{hostname}:{port}/{RDS_DB_NAME}'
@@ -154,7 +149,13 @@ def _create_rds_table(rds_endpoint):
         connection.execute(table.select())
 
 
-def _create_dms_replication_instance(ti, dms_client):
+@task
+def create_dms_assets():
+    print('Creating DMS assets.')
+    ti = get_current_context()['ti']
+    dms_client = boto3.client('dms')
+    rds_instance_endpoint = _get_rds_instance_endpoint()
+
     print('Creating replication instance.')
     instance_arn = dms_client.create_replication_instance(
         ReplicationInstanceIdentifier=DMS_REPLICATION_INSTANCE_NAME,
@@ -162,10 +163,7 @@ def _create_dms_replication_instance(ti, dms_client):
     )['ReplicationInstance']['ReplicationInstanceArn']
 
     ti.xcom_push(key='replication_instance_arn', value=instance_arn)
-    return instance_arn
-
 
-def _create_dms_endpoints(ti, dms_client, rds_instance_endpoint):
     print('Creating DMS source endpoint.')
     source_endpoint_arn = dms_client.create_endpoint(
         EndpointIdentifier=SOURCE_ENDPOINT_IDENTIFIER,
@@ -194,28 +192,16 @@ def _create_dms_endpoints(ti, dms_client, rds_instance_endpoint):
     ti.xcom_push(key='source_endpoint_arn', value=source_endpoint_arn)
     ti.xcom_push(key='target_endpoint_arn', value=target_endpoint_arn)
 
-
-def _await_setup_assets(dms_client, instance_arn):
-    print("Awaiting asset provisioning.")
+    print("Awaiting replication instance provisioning.")
     dms_client.get_waiter('replication_instance_available').wait(
         Filters=[{'Name': 'replication-instance-arn', 'Values': [instance_arn]}]
     )
 
 
-def _delete_rds_instance():
-    print('Deleting RDS Instance.')
-
-    rds_client = boto3.client('rds')
-    rds_client.delete_db_instance(
-        DBInstanceIdentifier=RDS_INSTANCE_NAME,
-        SkipFinalSnapshot=True,
-    )
-
-    rds_client.get_waiter('db_instance_deleted').wait(DBInstanceIdentifier=RDS_INSTANCE_NAME)
-
-
-def _delete_dms_assets(dms_client):
+@task(trigger_rule='all_done')
+def delete_dms_assets():
     ti = get_current_context()['ti']
+    dms_client = boto3.client('dms')
     replication_instance_arn = ti.xcom_pull(key='replication_instance_arn')
     source_arn = ti.xcom_pull(key='source_endpoint_arn')
     target_arn = ti.xcom_pull(key='target_endpoint_arn')
@@ -225,13 +211,10 @@ def _delete_dms_assets(dms_client):
     dms_client.delete_endpoint(EndpointArn=source_arn)
     dms_client.delete_endpoint(EndpointArn=target_arn)
 
-
-def _await_all_teardowns(dms_client):
-    print('Awaiting tear-down.')
+    print('Awaiting tear-down of DMS assets.')
     dms_client.get_waiter('replication_instance_deleted').wait(
         Filters=[{'Name': 'replication-instance-id', 'Values': [DMS_REPLICATION_INSTANCE_NAME]}]
     )
-
     dms_client.get_waiter('endpoint_deleted').wait(
         Filters=[
             {
@@ -242,27 +225,6 @@ def _await_all_teardowns(dms_client):
     )
 
 
-@task
-def set_up():
-    ti = get_current_context()['ti']
-    dms_client = boto3.client('dms')
-
-    rds_instance_endpoint = _create_rds_instance()
-    _create_rds_table(rds_instance_endpoint)
-    instance_arn = _create_dms_replication_instance(ti, dms_client)
-    _create_dms_endpoints(ti, dms_client, rds_instance_endpoint)
-    _await_setup_assets(dms_client, instance_arn)
-
-
-@task(trigger_rule='all_done')
-def clean_up():
-    dms_client = boto3.client('dms')
-
-    _delete_rds_instance()
-    _delete_dms_assets(dms_client)
-    _await_all_teardowns(dms_client)
-
-
 with DAG(
     dag_id='example_dms',
     schedule_interval=None,
@@ -271,6 +233,19 @@ with DAG(
     catchup=False,
 ) as dag:
 
+    create_db_instance = RdsCreateDbInstanceOperator(
+        task_id="create_db_instance",
+        db_instance_identifier=RDS_INSTANCE_NAME,
+        db_instance_class='db.t3.micro',
+        engine=RDS_ENGINE,
+        rds_kwargs={
+            "DBName": RDS_DB_NAME,
+            "AllocatedStorage": 20,
+            "MasterUsername": RDS_USERNAME,
+            "MasterUserPassword": RDS_PASSWORD,
+        },
+    )
+
     # [START howto_operator_dms_create_task]
     create_task = DmsCreateTaskOperator(
         task_id='create_task',
@@ -334,14 +309,26 @@ with DAG(
     )
     # [END howto_operator_dms_delete_task]
 
+    delete_db_instance = RdsDeleteDbInstanceOperator(
+        task_id='delete_db_instance',
+        db_instance_identifier=RDS_INSTANCE_NAME,
+        rds_kwargs={
+            "SkipFinalSnapshot": True,
+        },
+        trigger_rule='all_done',
+    )
+
     chain(
-        set_up()
-        >> create_task
-        >> start_task
-        >> describe_tasks
-        >> await_task_start
-        >> stop_task
-        >> await_task_stop
-        >> delete_task
-        >> clean_up()
+        create_db_instance,
+        create_sample_table(),
+        create_dms_assets(),
+        create_task,
+        start_task,
+        describe_tasks,
+        await_task_start,
+        stop_task,
+        await_task_stop,
+        delete_task,
+        delete_dms_assets(),
+        delete_db_instance,
     )
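
The refactor above leans on the TaskFlow API: plain functions become tasks via @task, and
values move between them through XComs. A minimal self-contained sketch of that pattern
(the DAG id, key, and value here are illustrative, not part of the commit):

    from datetime import datetime

    from airflow import DAG
    from airflow.decorators import task
    from airflow.operators.python import get_current_context

    with DAG(
        dag_id='taskflow_xcom_sketch',
        schedule_interval=None,
        start_date=datetime(2021, 1, 1),
        catchup=False,
    ) as dag:

        @task
        def produce():
            # Push a value under an explicit key via the current task instance.
            ti = get_current_context()['ti']
            ti.xcom_push(key='endpoint_arn', value='arn:aws:dms:example')

        @task(trigger_rule='all_done')
        def consume():
            # 'all_done' lets tear-down style tasks run even if upstream tasks failed.
            ti = get_current_context()['ti']
            print(ti.xcom_pull(key='endpoint_arn'))

        produce() >> consume()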
diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py
index a527107e80..fe38bfed69 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -18,7 +18,7 @@
 
 import json
 import time
-from typing import TYPE_CHECKING, List, Optional, Sequence
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
 
 from mypy_boto3_rds.type_defs import TagTypeDef
 
@@ -551,6 +551,108 @@ class RdsDeleteEventSubscriptionOperator(RdsBaseOperator):
         return json.dumps(delete_subscription, default=str)
 
 
+class RdsCreateDbInstanceOperator(RdsBaseOperator):
+    """
+    Creates an RDS DB instance
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:RdsCreateDbInstanceOperator`
+
+    :param db_instance_identifier: The DB instance identifier; must start with a letter and
+        contain 1 to 63 letters, numbers, or hyphens
+    :param db_instance_class: The compute and memory capacity of the DB instance, for example db.m5.large
+    :param engine: The name of the database engine to be used for this instance
+    :param rds_kwargs: Named arguments to pass to boto3 RDS client function ``create_db_instance``
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_instance
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param wait_for_completion: Whether or not to wait for the creation of the DB instance to
+        complete. (default: True)
+    """
+
+    def __init__(
+        self,
+        *,
+        db_instance_identifier: str,
+        db_instance_class: str,
+        engine: str,
+        rds_kwargs: Optional[Dict] = None,
+        aws_conn_id: str = "aws_default",
+        wait_for_completion: bool = True,
+        **kwargs,
+    ):
+        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+
+        self.db_instance_identifier = db_instance_identifier
+        self.db_instance_class = db_instance_class
+        self.engine = engine
+        self.rds_kwargs = rds_kwargs or {}
+        self.wait_for_completion = wait_for_completion
+
+    def execute(self, context: 'Context') -> str:
+        self.log.info("Creating new DB instance %s", self.db_instance_identifier)
+
+        create_db_instance = self.hook.conn.create_db_instance(
+            DBInstanceIdentifier=self.db_instance_identifier,
+            DBInstanceClass=self.db_instance_class,
+            Engine=self.engine,
+            **self.rds_kwargs,
+        )
+
+        if self.wait_for_completion:
+            self.hook.conn.get_waiter("db_instance_available").wait(
+                DBInstanceIdentifier=self.db_instance_identifier
+            )
+
+        return json.dumps(create_db_instance, default=str)
+
+
+class RdsDeleteDbInstanceOperator(RdsBaseOperator):
+    """
+    Deletes an RDS DB instance
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:RdsDeleteDbInstanceOperator`
+
+    :param db_instance_identifier: The DB instance identifier for the DB instance to be deleted
+    :param rds_kwargs: Named arguments to pass to boto3 RDS client function ``delete_db_instance``
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.delete_db_instance
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param wait_for_completion: Whether or not to wait for the deletion of the DB instance to
+        complete. (default: True)
+    """
+
+    def __init__(
+        self,
+        *,
+        db_instance_identifier: str,
+        rds_kwargs: Optional[Dict] = None,
+        aws_conn_id: str = "aws_default",
+        wait_for_completion: bool = True,
+        **kwargs,
+    ):
+        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        self.db_instance_identifier = db_instance_identifier
+        self.rds_kwargs = rds_kwargs or {}
+        self.wait_for_completion = wait_for_completion
+
+    def execute(self, context: 'Context') -> str:
+        self.log.info("Deleting DB instance %s", self.db_instance_identifier)
+
+        delete_db_instance = self.hook.conn.delete_db_instance(
+            DBInstanceIdentifier=self.db_instance_identifier,
+            **self.rds_kwargs,
+        )
+
+        if self.wait_for_completion:
+            self.hook.conn.get_waiter("db_instance_deleted").wait(
+                DBInstanceIdentifier=self.db_instance_identifier
+            )
+
+        return json.dumps(delete_db_instance, default=str)
+
+
 __all__ = [
     "RdsCreateDbSnapshotOperator",
     "RdsCopyDbSnapshotOperator",
@@ -559,4 +661,6 @@ __all__ = [
     "RdsDeleteEventSubscriptionOperator",
     "RdsStartExportTaskOperator",
     "RdsCancelExportTaskOperator",
+    "RdsCreateDbInstanceOperator",
+    "RdsDeleteDbInstanceOperator",
 ]
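
Both operators ultimately delegate to the boto3 RDS client, with rds_kwargs passed through
verbatim. As a point of reference, a standalone sketch of the equivalent boto3 calls
(assumes AWS credentials are configured in the environment; the identifier is a placeholder):

    import boto3

    rds = boto3.client('rds')

    # What RdsCreateDbInstanceOperator does with wait_for_completion=True
    rds.create_db_instance(
        DBInstanceIdentifier='my-database',
        DBInstanceClass='db.t3.micro',
        Engine='postgres',
        AllocatedStorage=20,
        MasterUsername='database_username',      # placeholder credentials
        MasterUserPassword='database_password',
    )
    rds.get_waiter('db_instance_available').wait(DBInstanceIdentifier='my-database')

    # What RdsDeleteDbInstanceOperator does with wait_for_completion=True
    rds.delete_db_instance(DBInstanceIdentifier='my-database', SkipFinalSnapshot=True)
    rds.get_waiter('db_instance_deleted').wait(DBInstanceIdentifier='my-database')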
diff --git a/docs/apache-airflow-providers-amazon/operators/rds.rst b/docs/apache-airflow-providers-amazon/operators/rds.rst
index 7e48a983f4..434804d910 100644
--- a/docs/apache-airflow-providers-amazon/operators/rds.rst
+++ b/docs/apache-airflow-providers-amazon/operators/rds.rst
@@ -138,6 +138,34 @@ To delete an Amazon RDS event subscription you can use
     :start-after: [START howto_operator_rds_delete_event_subscription]
     :end-before: [END howto_operator_rds_delete_event_subscription]
 
+.. _howto/operator:RdsCreateDbInstanceOperator:
+
+Create a database instance
+==========================
+
+To create an AWS DB instance you can use
+:class:`~airflow.providers.amazon.aws.operators.rds.RdsCreateDbInstanceOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/rds/example_rds_instance.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_rds_create_db_instance]
+    :end-before: [END howto_operator_rds_create_db_instance]
+
+.. _howto/operator:RdsDeleteDbInstanceOperator:
+
+Delete a database instance
+==========================
+
+To delete an AWS DB instance you can use
+:class:`~airflow.providers.amazon.aws.operators.rds.RdsDeleteDbInstanceOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/rds/example_rds_instance.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_rds_delete_db_instance]
+    :end-before: [END howto_operator_rds_delete_db_instance]
+
 Sensors
 -------
 
diff --git a/tests/providers/amazon/aws/operators/test_rds.py b/tests/providers/amazon/aws/operators/test_rds.py
index d952fbc11a..355ff529c9 100644
--- a/tests/providers/amazon/aws/operators/test_rds.py
+++ b/tests/providers/amazon/aws/operators/test_rds.py
@@ -26,8 +26,10 @@ from airflow.providers.amazon.aws.operators.rds import (
     RdsBaseOperator,
     RdsCancelExportTaskOperator,
     RdsCopyDbSnapshotOperator,
+    RdsCreateDbInstanceOperator,
     RdsCreateDbSnapshotOperator,
     RdsCreateEventSubscriptionOperator,
+    RdsDeleteDbInstanceOperator,
     RdsDeleteDbSnapshotOperator,
     RdsDeleteEventSubscriptionOperator,
     RdsStartExportTaskOperator,
@@ -450,3 +452,69 @@ class TestRdsDeleteEventSubscriptionOperator:
 
         with pytest.raises(self.hook.conn.exceptions.ClientError):
             self.hook.conn.describe_event_subscriptions(SubscriptionName=EXPORT_TASK_NAME)
+
+
+@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present')
+class TestRdsCreateDbInstanceOperator:
+    @classmethod
+    def setup_class(cls):
+        cls.dag = DAG('test_dag', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE})
+        cls.hook = RdsHook(aws_conn_id=AWS_CONN, region_name='us-east-1')
+
+    @classmethod
+    def teardown_class(cls):
+        del cls.dag
+        del cls.hook
+
+    @mock_rds
+    def test_create_db_instance(self):
+        create_db_instance_operator = RdsCreateDbInstanceOperator(
+            task_id='test_create_db_instance',
+            db_instance_identifier=DB_INSTANCE_NAME,
+            db_instance_class="db.m5.large",
+            engine="postgres",
+            rds_kwargs={
+                "DBName": DB_INSTANCE_NAME,
+            },
+            aws_conn_id=AWS_CONN,
+            dag=self.dag,
+        )
+        create_db_instance_operator.execute(None)
+
+        result = self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
+        db_instances = result.get("DBInstances")
+
+        assert db_instances
+        assert len(db_instances) == 1
+        assert db_instances[0]['DBInstanceStatus'] == 'available'
+
+
+@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present')
+class TestRdsDeleteDbInstanceOperator:
+    @classmethod
+    def setup_class(cls):
+        cls.dag = DAG('test_dag', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE})
+        cls.hook = RdsHook(aws_conn_id=AWS_CONN, region_name='us-east-1')
+
+    @classmethod
+    def teardown_class(cls):
+        del cls.dag
+        del cls.hook
+
+    @mock_rds
+    def test_delete_db_instance(self):
+        _create_db_instance(self.hook)
+
+        delete_db_instance_operator = RdsDeleteDbInstanceOperator(
+            task_id='test_delete_db_instance',
+            db_instance_identifier=DB_INSTANCE_NAME,
+            rds_kwargs={
+                "SkipFinalSnapshot": True,
+            },
+            aws_conn_id=AWS_CONN,
+            dag=self.dag,
+        )
+        delete_db_instance_operator.execute(None)
+
+        with pytest.raises(self.hook.conn.exceptions.ClientError):
+            self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
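
The tests above rely on moto to mock the RDS API in-process, which is why a mocked instance
can be asserted as 'available' right after creation. A minimal sketch of that pattern outside
the operator machinery (assumes the moto package is installed; names are illustrative):

    import boto3
    from moto import mock_rds

    @mock_rds
    def test_create_db_instance_sketch():
        client = boto3.client('rds', region_name='us-east-1')
        client.create_db_instance(
            DBInstanceIdentifier='test-db',
            DBInstanceClass='db.m5.large',
            Engine='postgres',
        )
        # moto reports mocked instances as immediately available.
        response = client.describe_db_instances(DBInstanceIdentifier='test-db')
        assert response['DBInstances'][0]['DBInstanceStatus'] == 'available'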
diff --git a/tests/system/providers/amazon/aws/rds/__init__.py b/tests/system/providers/amazon/aws/rds/__init__.py
new file mode 100644
index 0000000000..217e5db960
--- /dev/null
+++ b/tests/system/providers/amazon/aws/rds/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/system/providers/amazon/aws/rds/example_rds_instance.py b/tests/system/providers/amazon/aws/rds/example_rds_instance.py
new file mode 100644
index 0000000000..da157dbdf9
--- /dev/null
+++ b/tests/system/providers/amazon/aws/rds/example_rds_instance.py
@@ -0,0 +1,74 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime
+
+from airflow import DAG
+from airflow.models.baseoperator import chain
+from airflow.providers.amazon.aws.operators.rds import (
+    RdsCreateDbInstanceOperator,
+    RdsDeleteDbInstanceOperator,
+)
+from tests.system.providers.amazon.aws.utils import set_env_id
+
+ENV_ID = set_env_id()
+DAG_ID = "example_rds_instance"
+
+RDS_DB_IDENTIFIER = f'{ENV_ID}-database'
+RDS_USERNAME = 'database_username'
+# NEVER store your production password in plaintext in a DAG like this.
+# Use Airflow Secrets or a secret manager for this in production.
+RDS_PASSWORD = 'database_password'
+
+with DAG(
+    dag_id=DAG_ID,
+    schedule_interval=None,
+    start_date=datetime(2021, 1, 1),
+    tags=['example'],
+    catchup=False,
+) as dag:
+    # [START howto_operator_rds_create_db_instance]
+    create_db_instance = RdsCreateDbInstanceOperator(
+        task_id='create_db_instance',
+        db_instance_identifier=RDS_DB_IDENTIFIER,
+        db_instance_class="db.m5.large",
+        engine="postgres",
+        rds_kwargs={
+            "MasterUsername": RDS_USERNAME,
+            "MasterUserPassword": RDS_PASSWORD,
+            "AllocatedStorage": 20,
+        },
+    )
+    # [END howto_operator_rds_create_db_instance]
+
+    # [START howto_operator_rds_delete_db_instance]
+    delete_db_instance = RdsDeleteDbInstanceOperator(
+        task_id='delete_db_instance',
+        db_instance_identifier=RDS_DB_IDENTIFIER,
+        rds_kwargs={
+            "SkipFinalSnapshot": True,
+        },
+    )
+    # [END howto_operator_rds_delete_db_instance]
+
+    chain(create_db_instance, delete_db_instance)
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)