You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/03/06 15:56:31 UTC
[airflow] branch main updated: Added AWS RDS operators (#20907)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f968eba Added AWS RDS operators (#20907)
f968eba is described below
commit f968eba4700e8963ddc3bebcabf959d2f2adaadd
Author: Dmytro Kazanzhy <dk...@gmail.com>
AuthorDate: Sun Mar 6 17:55:52 2022 +0200
Added AWS RDS operators (#20907)
---
.../amazon/aws/example_dags/example_rds.py | 127 +++++
airflow/providers/amazon/aws/operators/rds.py | 560 +++++++++++++++++++++
airflow/providers/amazon/aws/utils/rds.py | 25 +
airflow/providers/amazon/provider.yaml | 5 +
.../operators/rds.rst | 163 ++++++
tests/providers/amazon/aws/operators/test_rds.py | 452 +++++++++++++++++
6 files changed, 1332 insertions(+)
diff --git a/airflow/providers/amazon/aws/example_dags/example_rds.py b/airflow/providers/amazon/aws/example_dags/example_rds.py
new file mode 100644
index 0000000..58a1927
--- /dev/null
+++ b/airflow/providers/amazon/aws/example_dags/example_rds.py
@@ -0,0 +1,127 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
"""
Example DAGs for the Amazon RDS operators: managing DB snapshots (create, copy, delete),
exporting snapshots to Amazon S3 (start, cancel) and managing RDS event notification
subscriptions (create, delete).
"""
+
+from datetime import datetime
+
+from airflow import DAG
+from airflow.providers.amazon.aws.operators.rds import (
+ RdsCancelExportTaskOperator,
+ RdsCopyDbSnapshotOperator,
+ RdsCreateDbSnapshotOperator,
+ RdsCreateEventSubscriptionOperator,
+ RdsDeleteDbSnapshotOperator,
+ RdsDeleteEventSubscriptionOperator,
+ RdsStartExportTaskOperator,
+)
+
# [START rds_snapshots_howto_guide]
# Each DAG in this file is bound to its own module-level name: Airflow discovers DAGs by
# scanning module globals, so reusing one variable name would hide all but the last DAG.
with DAG(
    dag_id='rds_snapshots', start_date=datetime(2021, 1, 1), schedule_interval=None, catchup=False
) as snapshots_dag:
    # [START howto_guide_rds_create_snapshot]
    create_snapshot = RdsCreateDbSnapshotOperator(
        task_id='create_snapshot',
        db_type='instance',
        db_identifier='auth-db',
        db_snapshot_identifier='auth-db-snap',
        aws_conn_id='aws_default',
        hook_params={'region_name': 'us-east-1'},
    )
    # [END howto_guide_rds_create_snapshot]

    # [START howto_guide_rds_copy_snapshot]
    copy_snapshot = RdsCopyDbSnapshotOperator(
        task_id='copy_snapshot',
        db_type='instance',
        target_db_snapshot_identifier='auth-db-snap-backup',
        source_db_snapshot_identifier='auth-db-snap',
        aws_conn_id='aws_default',
        hook_params={'region_name': 'us-east-1'},
    )
    # [END howto_guide_rds_copy_snapshot]

    # [START howto_guide_rds_delete_snapshot]
    delete_snapshot = RdsDeleteDbSnapshotOperator(
        task_id='delete_snapshot',
        db_type='instance',
        db_snapshot_identifier='auth-db-snap-backup',
        aws_conn_id='aws_default',
        hook_params={'region_name': 'us-east-1'},
    )
    # [END howto_guide_rds_delete_snapshot]

    create_snapshot >> copy_snapshot >> delete_snapshot
# [END rds_snapshots_howto_guide]
+
# [START rds_exports_howto_guide]
# Distinct module-level name so Airflow's DAG discovery (which scans module globals) finds this DAG.
with DAG(
    dag_id='rds_exports', start_date=datetime(2021, 1, 1), schedule_interval=None, catchup=False
) as exports_dag:
    # [START howto_guide_rds_start_export]
    start_export = RdsStartExportTaskOperator(
        task_id='start_export',
        export_task_identifier='export-auth-db-snap-{{ ds }}',
        source_arn='arn:aws:rds:<region>:<account number>:snapshot:auth-db-snap',
        # S3 bucket names may not contain underscores.
        s3_bucket_name='my-s3-bucket',
        # IAM is a global service: its ARNs have an empty region field.
        iam_role_arn='arn:aws:iam::<account number>:role/MyRole',
        kms_key_id='arn:aws:kms:<region>:<account number>:key/*****-****-****-****-********',
        aws_conn_id='aws_default',
        hook_params={'region_name': 'us-east-1'},
    )
    # [END howto_guide_rds_start_export]

    # [START howto_guide_rds_cancel_export]
    cancel_export = RdsCancelExportTaskOperator(
        task_id='cancel_export',
        export_task_identifier='export-auth-db-snap-{{ ds }}',
        aws_conn_id='aws_default',
        hook_params={'region_name': 'us-east-1'},
    )
    # [END howto_guide_rds_cancel_export]

    start_export >> cancel_export
# [END rds_exports_howto_guide]
+
# [START rds_events_howto_guide]
# Distinct module-level name so Airflow's DAG discovery (which scans module globals) finds this DAG.
with DAG(
    dag_id='rds_events', start_date=datetime(2021, 1, 1), schedule_interval=None, catchup=False
) as events_dag:
    # [START howto_guide_rds_create_subscription]
    create_subscription = RdsCreateEventSubscriptionOperator(
        task_id='create_subscription',
        subscription_name='my_topic_subscription',
        sns_topic_arn='arn:aws:sns:<region>:<account number>:MyTopic',
        source_type='db-instance',
        source_ids=['auth-db'],
        event_categories=['Availability', 'Backup'],
        aws_conn_id='aws_default',
        hook_params={'region_name': 'us-east-1'},
    )
    # [END howto_guide_rds_create_subscription]

    # [START howto_guide_rds_delete_subscription]
    delete_subscription = RdsDeleteEventSubscriptionOperator(
        task_id='delete_subscription',
        subscription_name='my_topic_subscription',
        aws_conn_id='aws_default',
        hook_params={'region_name': 'us-east-1'},
    )
    # [END howto_guide_rds_delete_subscription]

    create_subscription >> delete_subscription
# [END rds_events_howto_guide]
diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py
new file mode 100644
index 0000000..e14df92
--- /dev/null
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -0,0 +1,560 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+import time
+from typing import TYPE_CHECKING, List, Optional, Sequence
+
+from mypy_boto3_rds.type_defs import TagTypeDef
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.utils.rds import RdsDbType
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
class RdsBaseOperator(BaseOperator):
    """
    Base operator that implements common functions for all RDS operators.

    :param aws_conn_id: The Airflow connection used for AWS credentials (default ``aws_default``)
    :param hook_params: Extra keyword arguments forwarded to :class:`RdsHook`
        (for example ``{"region_name": "us-east-1"}``)
    """

    ui_color = "#eeaa88"
    ui_fgcolor = "#ffffff"

    def __init__(self, *args, aws_conn_id: str = "aws_default", hook_params: Optional[dict] = None, **kwargs):
        hook_params = hook_params or {}
        self.hook = RdsHook(aws_conn_id=aws_conn_id, **hook_params)
        super().__init__(*args, **kwargs)

        # Delay between two consecutive status polls in ``_await_status``.
        self._await_interval = 60  # seconds

    def _describe_item(self, item_type: str, item_name: str) -> list:
        """
        Return the AWS descriptions matching ``item_name`` for the given ``item_type``.

        :param item_type: One of ``instance_snapshot``, ``cluster_snapshot``, ``export_task``
            or ``event_subscription``
        :param item_name: Identifier (or subscription name) of the item to describe
        :return: The list found under the corresponding key of the boto3 describe response
        :raises AirflowException: if ``item_type`` is not supported
        """
        if item_type == 'instance_snapshot':
            db_snaps = self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=item_name)
            return db_snaps['DBSnapshots']
        elif item_type == 'cluster_snapshot':
            cl_snaps = self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=item_name)
            return cl_snaps['DBClusterSnapshots']
        elif item_type == 'export_task':
            exports = self.hook.conn.describe_export_tasks(ExportTaskIdentifier=item_name)
            return exports['ExportTasks']
        elif item_type == 'event_subscription':
            subscriptions = self.hook.conn.describe_event_subscriptions(SubscriptionName=item_name)
            return subscriptions['EventSubscriptionsList']
        else:
            raise AirflowException(f"Method for {item_type} is not implemented")

    def _await_status(
        self,
        item_type: str,
        item_name: str,
        wait_statuses: Optional[List[str]] = None,
        ok_statuses: Optional[List[str]] = None,
        error_statuses: Optional[List[str]] = None,
    ) -> None:
        """
        Poll ``_describe_item()`` until the item reaches an ok or error status.

        On every poll the item's ``Status`` is checked:

        - in ``wait_statuses``: sleep ``_await_interval`` seconds and poll again
        - in ``ok_statuses``: return
        - in ``error_statuses``: raise
        - anything else: sleep and poll again

        Statuses are compared case-insensitively. AWS reports some of them in upper case
        (export tasks use e.g. ``COMPLETE``) and others in lower case.

        :raises AirflowException: if no item or more than one item matches ``item_name``,
            or if the item reaches one of ``error_statuses``
        """
        wait = {status.lower() for status in wait_statuses or []}
        ok = {status.lower() for status in ok_statuses or []}
        error = {status.lower() for status in error_statuses or []}

        while True:
            items = self._describe_item(item_type, item_name)

            if len(items) == 0:
                raise AirflowException(f"There is no {item_type} with identifier {item_name}")
            if len(items) > 1:
                raise AirflowException(f"There are {len(items)} {item_type} with identifier {item_name}")

            status = items[0]['Status'].lower()
            if status not in wait:
                if status in ok:
                    return None
                if status in error:
                    raise AirflowException(f"Item has error status ({error_statuses}): {items[0]}")

            # Sleep before every re-poll, including in-progress states, so the AWS API is not hammered.
            time.sleep(self._await_interval)

    def execute(self, context: 'Context') -> str:
        """Different implementations for snapshots, tasks and events"""
        raise NotImplementedError

    def on_kill(self) -> None:
        """
        Do nothing when the task is killed.

        ``BaseOperator.on_kill`` is a no-op. Raising ``NotImplementedError`` here made every kill
        of a running RDS task fail, because no subclass overrides this method.
        """
+
+
class RdsCreateDbSnapshotOperator(RdsBaseOperator):
    """
    Creates a snapshot of a DB instance or DB cluster.
    The source DB instance or cluster must be in the available or storage-optimization state.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RdsCreateDbSnapshotOperator`

    :param db_type: Type of the DB - either "instance" or "cluster"
    :param db_identifier: The identifier of the instance or cluster that you want to create the snapshot of
    :param db_snapshot_identifier: The identifier for the DB snapshot
    :param tags: A list of tags in format ``[{"Key": "something", "Value": "something"},]``
        `USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
    :param aws_conn_id: The Airflow connection used for AWS credentials (default ``aws_default``)
    """

    # Must name attributes that exist on the instance: Airflow reads each one when rendering templates.
    template_fields = ("db_snapshot_identifier", "db_identifier", "tags")

    def __init__(
        self,
        *,
        db_type: str,
        db_identifier: str,
        db_snapshot_identifier: str,
        tags: Optional[Sequence[TagTypeDef]] = None,
        aws_conn_id: str = "aws_default",
        **kwargs,
    ):
        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
        self.db_type = RdsDbType(db_type)
        self.db_identifier = db_identifier
        self.db_snapshot_identifier = db_snapshot_identifier
        self.tags = tags or []

    def execute(self, context: 'Context') -> str:
        """
        Create the snapshot and block until it becomes ``available``.

        :return: The boto3 create response serialized as JSON
        """
        self.log.info(
            "Starting to create snapshot of RDS %s '%s': %s",
            self.db_type,
            self.db_identifier,
            self.db_snapshot_identifier,
        )

        if self.db_type.value == "instance":
            create_instance_snap = self.hook.conn.create_db_snapshot(
                DBInstanceIdentifier=self.db_identifier,
                DBSnapshotIdentifier=self.db_snapshot_identifier,
                Tags=self.tags,
            )
            create_response = json.dumps(create_instance_snap, default=str)
            self._await_status(
                'instance_snapshot',
                self.db_snapshot_identifier,
                wait_statuses=['creating'],
                ok_statuses=['available'],
            )
        else:
            create_cluster_snap = self.hook.conn.create_db_cluster_snapshot(
                DBClusterIdentifier=self.db_identifier,
                DBClusterSnapshotIdentifier=self.db_snapshot_identifier,
                Tags=self.tags,
            )
            create_response = json.dumps(create_cluster_snap, default=str)
            self._await_status(
                'cluster_snapshot',
                self.db_snapshot_identifier,
                wait_statuses=['creating'],
                ok_statuses=['available'],
            )

        return create_response
+
+
class RdsCopyDbSnapshotOperator(RdsBaseOperator):
    """
    Copies the specified DB instance or DB cluster snapshot and waits for the copy to become available.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RdsCopyDbSnapshotOperator`

    :param db_type: Type of the DB - either "instance" or "cluster"
    :param source_db_snapshot_identifier: The identifier of the source snapshot
    :param target_db_snapshot_identifier: The identifier of the target snapshot
    :param kms_key_id: The AWS KMS key identifier for an encrypted DB snapshot
    :param tags: A list of tags in format ``[{"Key": "something", "Value": "something"},]``
        `USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
    :param copy_tags: Whether to copy all tags from the source snapshot to the target snapshot (default False)
    :param pre_signed_url: The URL that contains a Signature Version 4 signed request
    :param option_group_name: The name of an option group to associate with the copy of the snapshot.
        Only used when db_type='instance'
    :param target_custom_availability_zone: The external custom Availability Zone identifier for the target.
        Only used when db_type='instance'
    :param source_region: The ID of the region that contains the snapshot to be copied
    :param aws_conn_id: The Airflow connection used for AWS credentials (default ``aws_default``)
    """

    template_fields = (
        "source_db_snapshot_identifier",
        "target_db_snapshot_identifier",
        "tags",
        "pre_signed_url",
        "option_group_name",
    )

    def __init__(
        self,
        *,
        db_type: str,
        source_db_snapshot_identifier: str,
        target_db_snapshot_identifier: str,
        kms_key_id: str = "",
        tags: Optional[Sequence[TagTypeDef]] = None,
        copy_tags: bool = False,
        pre_signed_url: str = "",
        option_group_name: str = "",
        target_custom_availability_zone: str = "",
        source_region: str = "",
        aws_conn_id: str = "aws_default",
        **kwargs,
    ):
        super().__init__(aws_conn_id=aws_conn_id, **kwargs)

        self.db_type = RdsDbType(db_type)
        self.source_db_snapshot_identifier = source_db_snapshot_identifier
        self.target_db_snapshot_identifier = target_db_snapshot_identifier
        # Options shared by the instance and the cluster copy calls.
        self.kms_key_id = kms_key_id
        self.tags = tags or []
        self.copy_tags = copy_tags
        self.pre_signed_url = pre_signed_url
        self.source_region = source_region
        # Options that only the instance copy call accepts.
        self.option_group_name = option_group_name
        self.target_custom_availability_zone = target_custom_availability_zone

    def execute(self, context: 'Context') -> str:
        """
        Start the copy, then block until the target snapshot leaves ``copying`` for ``available``.

        :return: The boto3 copy response serialized as JSON
        """
        self.log.info(
            "Starting to copy snapshot '%s' as '%s'",
            self.source_db_snapshot_identifier,
            self.target_db_snapshot_identifier,
        )

        common_args = {
            "KmsKeyId": self.kms_key_id,
            "Tags": self.tags,
            "CopyTags": self.copy_tags,
            "PreSignedUrl": self.pre_signed_url,
            "SourceRegion": self.source_region,
        }

        if self.db_type.value != "instance":
            response = self.hook.conn.copy_db_cluster_snapshot(
                SourceDBClusterSnapshotIdentifier=self.source_db_snapshot_identifier,
                TargetDBClusterSnapshotIdentifier=self.target_db_snapshot_identifier,
                **common_args,
            )
            snapshot_kind = 'cluster_snapshot'
        else:
            response = self.hook.conn.copy_db_snapshot(
                SourceDBSnapshotIdentifier=self.source_db_snapshot_identifier,
                TargetDBSnapshotIdentifier=self.target_db_snapshot_identifier,
                OptionGroupName=self.option_group_name,
                TargetCustomAvailabilityZone=self.target_custom_availability_zone,
                **common_args,
            )
            snapshot_kind = 'instance_snapshot'

        serialized = json.dumps(response, default=str)
        self._await_status(
            snapshot_kind,
            self.target_db_snapshot_identifier,
            wait_statuses=['copying'],
            ok_statuses=['available'],
        )
        return serialized
+
+
class RdsDeleteDbSnapshotOperator(RdsBaseOperator):
    """
    Deletes a DB instance or DB cluster snapshot, or terminates a snapshot copy in progress.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RdsDeleteDbSnapshotOperator`

    :param db_type: Type of the DB - either "instance" or "cluster"
    :param db_snapshot_identifier: The identifier for the DB instance or DB cluster snapshot
    :param aws_conn_id: The Airflow connection used for AWS credentials (default ``aws_default``)
    """

    template_fields = ("db_snapshot_identifier",)

    def __init__(
        self,
        *,
        db_type: str,
        db_snapshot_identifier: str,
        aws_conn_id: str = "aws_default",
        **kwargs,
    ):
        super().__init__(aws_conn_id=aws_conn_id, **kwargs)

        self.db_type = RdsDbType(db_type)
        self.db_snapshot_identifier = db_snapshot_identifier

    def execute(self, context: 'Context') -> str:
        """
        Issue the delete call; this does not wait for the deletion to finish.

        :return: The boto3 delete response serialized as JSON
        """
        self.log.info("Starting to delete snapshot '%s'", self.db_snapshot_identifier)

        if self.db_type.value != "instance":
            response = self.hook.conn.delete_db_cluster_snapshot(
                DBClusterSnapshotIdentifier=self.db_snapshot_identifier,
            )
        else:
            response = self.hook.conn.delete_db_snapshot(
                DBSnapshotIdentifier=self.db_snapshot_identifier,
            )

        return json.dumps(response, default=str)
+
+
class RdsStartExportTaskOperator(RdsBaseOperator):
    """
    Starts an export of a snapshot to Amazon S3. The provided IAM role must have access to the S3 bucket.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RdsStartExportTaskOperator`

    :param export_task_identifier: A unique identifier for the snapshot export task.
    :param source_arn: The Amazon Resource Name (ARN) of the snapshot to export to Amazon S3.
    :param s3_bucket_name: The name of the Amazon S3 bucket to export the snapshot to.
    :param iam_role_arn: The ARN of the IAM role to use for writing to the Amazon S3 bucket.
    :param kms_key_id: The ID of the Amazon Web Services KMS key to use to encrypt the snapshot.
    :param s3_prefix: The Amazon S3 bucket prefix to use as the file name and path of the exported snapshot.
    :param export_only: The data to be exported from the snapshot.
    :param aws_conn_id: The Airflow connection used for AWS credentials (default ``aws_default``)
    """

    template_fields = (
        "export_task_identifier",
        "source_arn",
        "s3_bucket_name",
        "iam_role_arn",
        "kms_key_id",
        "s3_prefix",
        "export_only",
    )

    def __init__(
        self,
        *,
        export_task_identifier: str,
        source_arn: str,
        s3_bucket_name: str,
        iam_role_arn: str,
        kms_key_id: str,
        s3_prefix: str = '',
        export_only: Optional[List[str]] = None,
        aws_conn_id: str = "aws_default",
        **kwargs,
    ):
        super().__init__(aws_conn_id=aws_conn_id, **kwargs)

        self.export_task_identifier = export_task_identifier
        self.source_arn = source_arn
        self.s3_bucket_name = s3_bucket_name
        self.iam_role_arn = iam_role_arn
        self.kms_key_id = kms_key_id
        self.s3_prefix = s3_prefix
        self.export_only = export_only or []

    def execute(self, context: 'Context') -> str:
        """
        Start the export task and block until it completes.

        :return: The boto3 start response serialized as JSON
        :raises AirflowException: if the export fails or is canceled while waiting
        """
        self.log.info("Starting export task %s for snapshot %s", self.export_task_identifier, self.source_arn)

        start_export = self.hook.conn.start_export_task(
            ExportTaskIdentifier=self.export_task_identifier,
            SourceArn=self.source_arn,
            S3BucketName=self.s3_bucket_name,
            IamRoleArn=self.iam_role_arn,
            KmsKeyId=self.kms_key_id,
            S3Prefix=self.s3_prefix,
            ExportOnly=self.export_only,
        )

        self._await_status(
            'export_task',
            self.export_task_identifier,
            wait_statuses=['starting', 'in_progress'],
            ok_statuses=['available', 'complete'],
            # Terminal failures. Without these, a failed or canceled export matches no list
            # and the operator polls forever.
            error_statuses=['canceling', 'canceled', 'failed'],
        )

        return json.dumps(start_export, default=str)
+
+
class RdsCancelExportTaskOperator(RdsBaseOperator):
    """
    Cancels an export task in progress that is exporting a snapshot to Amazon S3

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RdsCancelExportTaskOperator`

    :param export_task_identifier: The identifier of the snapshot export task to cancel
    :param aws_conn_id: The Airflow connection used for AWS credentials (default ``aws_default``)
    """

    template_fields = ("export_task_identifier",)

    def __init__(
        self,
        *,
        export_task_identifier: str,
        aws_conn_id: str = "aws_default",
        **kwargs,
    ):
        super().__init__(aws_conn_id=aws_conn_id, **kwargs)

        self.export_task_identifier = export_task_identifier

    def execute(self, context: 'Context') -> str:
        """
        Request cancellation of the export task and block until it is ``canceled``.

        :return: The boto3 cancel response serialized as JSON
        :raises AirflowException: if the task ends as ``complete`` or ``failed`` instead of ``canceled``
        """
        self.log.info("Canceling export task %s", self.export_task_identifier)

        cancel_export = self.hook.conn.cancel_export_task(
            ExportTaskIdentifier=self.export_task_identifier,
        )
        self._await_status(
            'export_task',
            self.export_task_identifier,
            wait_statuses=['canceling'],
            ok_statuses=['canceled'],
            # Terminal states other than "canceled" would otherwise never end the wait loop.
            error_statuses=['complete', 'failed'],
        )

        return json.dumps(cancel_export, default=str)
+
+
class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
    """
    Creates an RDS event notification subscription

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RdsCreateEventSubscriptionOperator`

    :param subscription_name: The name of the subscription (must be less than 255 characters)
    :param sns_topic_arn: The ARN of the SNS topic created for event notification
    :param source_type: The type of source that is generating the events. Valid values: db-instance |
        db-cluster | db-parameter-group | db-security-group | db-snapshot | db-cluster-snapshot | db-proxy
    :param event_categories: A list of event categories for a source type that you want to subscribe to
        `USER Events <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Events.Messages.html>`__
    :param source_ids: The list of identifiers of the event sources for which events are returned
    :param enabled: A value that indicates whether to activate the subscription (default True)
    :param tags: A list of tags in format ``[{"Key": "something", "Value": "something"},]``
        `USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
    :param aws_conn_id: The Airflow connection used for AWS credentials (default ``aws_default``)
    """

    template_fields = (
        "subscription_name",
        "sns_topic_arn",
        "source_type",
        "event_categories",
        "source_ids",
        "tags",
    )

    def __init__(
        self,
        *,
        subscription_name: str,
        sns_topic_arn: str,
        source_type: str = "",
        event_categories: Optional[Sequence[str]] = None,
        source_ids: Optional[Sequence[str]] = None,
        enabled: bool = True,
        tags: Optional[Sequence[TagTypeDef]] = None,
        aws_conn_id: str = "aws_default",
        **kwargs,
    ):
        super().__init__(aws_conn_id=aws_conn_id, **kwargs)

        self.subscription_name = subscription_name
        self.sns_topic_arn = sns_topic_arn
        self.source_type = source_type
        self.event_categories = event_categories or []
        self.source_ids = source_ids or []
        self.enabled = enabled
        self.tags = tags or []

    def execute(self, context: 'Context') -> str:
        """
        Create the subscription and block until it becomes active.

        :return: The boto3 create response serialized as JSON
        :raises AirflowException: if AWS reports the subscription as ``no-permission`` or ``topic-not-exist``
        """
        self.log.info("Creating event subscription '%s' to '%s'", self.subscription_name, self.sns_topic_arn)

        create_subscription = self.hook.conn.create_event_subscription(
            SubscriptionName=self.subscription_name,
            SnsTopicArn=self.sns_topic_arn,
            SourceType=self.source_type,
            EventCategories=self.event_categories,
            SourceIds=self.source_ids,
            Enabled=self.enabled,
            Tags=self.tags,
        )
        self._await_status(
            'event_subscription',
            self.subscription_name,
            wait_statuses=['creating', 'modifying'],
            # AWS reports a ready subscription as "active". The other two values were kept for compatibility.
            ok_statuses=['created', 'available', 'active'],
            error_statuses=['no-permission', 'topic-not-exist'],
        )

        return json.dumps(create_subscription, default=str)
+
+
class RdsDeleteEventSubscriptionOperator(RdsBaseOperator):
    """
    Deletes an RDS event notification subscription.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RdsDeleteEventSubscriptionOperator`

    :param subscription_name: The name of the RDS event notification subscription you want to delete
    :param aws_conn_id: The Airflow connection used for AWS credentials (default ``aws_default``)
    """

    template_fields = ("subscription_name",)

    def __init__(
        self,
        *,
        subscription_name: str,
        aws_conn_id: str = "aws_default",
        **kwargs,
    ):
        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
        self.subscription_name = subscription_name

    def execute(self, context: 'Context') -> str:
        """
        Issue the delete call; this does not wait for the deletion to finish.

        :return: The boto3 delete response serialized as JSON
        """
        name = self.subscription_name
        self.log.info("Deleting event subscription %s", name)
        response = self.hook.conn.delete_event_subscription(SubscriptionName=name)
        return json.dumps(response, default=str)
+
+
# Public API of this module for ``from ... import *``. ``RdsBaseOperator`` is not listed,
# so a star-import does not export it. Import it explicitly if needed.
__all__ = [
    "RdsCreateDbSnapshotOperator",
    "RdsCopyDbSnapshotOperator",
    "RdsDeleteDbSnapshotOperator",
    "RdsCreateEventSubscriptionOperator",
    "RdsDeleteEventSubscriptionOperator",
    "RdsStartExportTaskOperator",
    "RdsCancelExportTaskOperator",
]
diff --git a/airflow/providers/amazon/aws/utils/rds.py b/airflow/providers/amazon/aws/utils/rds.py
new file mode 100644
index 0000000..154f65b
--- /dev/null
+++ b/airflow/providers/amazon/aws/utils/rds.py
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from enum import Enum
+
+
class RdsDbType(Enum):
    """Kind of RDS database a snapshot operator targets: a single DB instance or a DB cluster."""

    INSTANCE = "instance"
    CLUSTER = "cluster"
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 678e037..167e09c 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -106,6 +106,8 @@ integrations:
- integration-name: Amazon RDS
external-doc-url: https://aws.amazon.com/rds/
logo: /integration-logos/aws/Amazon-RDS_light-bg@4x.png
+ how-to-guide:
+ - /docs/apache-airflow-providers-amazon/operators/rds.rst
tags: [aws]
- integration-name: Amazon Redshift
external-doc-url: https://aws.amazon.com/redshift/
@@ -272,6 +274,9 @@ operators:
- airflow.providers.amazon.aws.operators.step_function_get_execution_output
- airflow.providers.amazon.aws.operators.step_function_start_execution
- airflow.providers.amazon.aws.operators.step_function
+ - integration-name: Amazon RDS
+ python-modules:
+ - airflow.providers.amazon.aws.operators.rds
- integration-name: Amazon Redshift
python-modules:
- airflow.providers.amazon.aws.operators.redshift
diff --git a/docs/apache-airflow-providers-amazon/operators/rds.rst b/docs/apache-airflow-providers-amazon/operators/rds.rst
new file mode 100644
index 0000000..cb28b4d
--- /dev/null
+++ b/docs/apache-airflow-providers-amazon/operators/rds.rst
@@ -0,0 +1,163 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+RDS management operators
+=====================================
+
+.. contents::
+ :depth: 1
+ :local:
+
+
+.. _howto/operator:RDSCreateDBSnapshotOperator:
+
+Create DB snapshot
+""""""""""""""""""
+
+To create a snapshot of an AWS RDS DB instance or DB cluster, you can use
+:class:`~airflow.providers.amazon.aws.operators.rds.RdsCreateDbSnapshotOperator`.
+The source DB instance must be in the ``available`` or ``storage-optimization`` state.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_rds.py
+ :language: python
+ :start-after: [START rds_snapshots_howto_guide]
+ :end-before: [END rds_snapshots_howto_guide]
+
+
+This Operator leverages the AWS CLI
+`create-db-snapshot <https://docs.aws.amazon.com/cli/latest/reference/rds/create-db-snapshot.html>`__ API and the
+`create-db-cluster-snapshot <https://docs.aws.amazon.com/cli/latest/reference/rds/create-db-cluster-snapshot.html>`__ API.
+
+
+.. _howto/operator:RDSCopyDBSnapshotOperator:
+
+Copy DB snapshot
+""""""""""""""""
+
+To copy an AWS RDS DB instance or DB cluster snapshot, you can use
+:class:`~airflow.providers.amazon.aws.operators.rds.RdsCopyDbSnapshotOperator`.
+The source DB snapshot must be in the ``available`` state.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_rds.py
+ :language: python
+ :start-after: [START howto_guide_rds_copy_snapshot]
+ :end-before: [END howto_guide_rds_copy_snapshot]
+
+This Operator leverages the AWS CLI
+`copy-db-snapshot <https://docs.aws.amazon.com/cli/latest/reference/rds/copy-db-snapshot.html>`__ API and the
+`copy-db-cluster-snapshot <https://docs.aws.amazon.com/cli/latest/reference/rds/copy-db-cluster-snapshot.html>`__ API.
+
+
+.. _howto/operator:RDSDeleteDBSnapshotOperator:
+
+Delete DB snapshot
+""""""""""""""""""
+
+To delete an AWS RDS DB instance or DB cluster snapshot, you can use
+:class:`~airflow.providers.amazon.aws.operators.rds.RdsDeleteDbSnapshotOperator`.
+The DB snapshot must be in the ``available`` state to be deleted.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_rds.py
+ :language: python
+ :start-after: [START howto_guide_rds_delete_snapshot]
+ :end-before: [END howto_guide_rds_delete_snapshot]
+
+This Operator leverages the AWS CLI
+`delete-db-snapshot <https://docs.aws.amazon.com/cli/latest/reference/rds/delete-db-snapshot.html>`__ API and the
+`delete-db-cluster-snapshot <https://docs.aws.amazon.com/cli/latest/reference/rds/delete-db-cluster-snapshot.html>`__ API.
+
+
+.. _howto/operator:RDSStartExportTaskOperator:
+
+Start export task
+"""""""""""""""""
+
+To start a task that exports an RDS snapshot to S3, you can use
+:class:`~airflow.providers.amazon.aws.operators.rds.RdsStartExportTaskOperator`.
+The provided IAM role must have access to the S3 bucket.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_rds.py
+ :language: python
+ :start-after: [START howto_guide_rds_start_export]
+ :end-before: [END howto_guide_rds_start_export]
+
+This Operator leverages the AWS CLI
+`start-export-task <https://docs.aws.amazon.com/cli/latest/reference/rds/start-export-task.html>`__ API
+
+
+.. _howto/operator:RDSCancelExportTaskOperator:
+
+Cancel export task
+""""""""""""""""""
+
+To cancel a task that exports an RDS snapshot to S3, you can use
+:class:`~airflow.providers.amazon.aws.operators.rds.RdsCancelExportTaskOperator`.
+Any data that has already been written to the S3 bucket isn't removed.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_rds.py
+ :language: python
+ :start-after: [START howto_guide_rds_cancel_export]
+ :end-before: [END howto_guide_rds_cancel_export]
+
+This Operator leverages the AWS CLI
+`cancel-export-task <https://docs.aws.amazon.com/cli/latest/reference/rds/cancel-export-task.html>`__ API
+
+
+.. _howto/operator:RDSCreateEventSubscriptionOperator:
+
+Create event subscription
+"""""""""""""""""""""""""
+
+To create an event subscription, you can use
+:class:`~airflow.providers.amazon.aws.operators.rds.RdsCreateEventSubscriptionOperator`.
+This action requires the Amazon Resource Name (ARN) of an SNS topic created with the RDS console, the SNS console, or the SNS API.
+To obtain an ARN with SNS, you must create a topic in Amazon SNS and subscribe to the topic.
+RDS event notifications are only available for unencrypted SNS topics.
+If you specify an encrypted SNS topic, event notifications are not sent for the topic.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_rds.py
+ :language: python
+ :start-after: [START howto_guide_rds_create_subscription]
+ :end-before: [END howto_guide_rds_create_subscription]
+
+This Operator leverages the AWS CLI
+`create-event-subscription <https://docs.aws.amazon.com/cli/latest/reference/rds/create-event-subscription.html>`__ API
+
+
+.. _howto/operator:RDSDeleteEventSubscriptionOperator:
+
+Delete event subscription
+"""""""""""""""""""""""""
+
+To delete an event subscription, you can use
+:class:`~airflow.providers.amazon.aws.operators.rds.RdsDeleteEventSubscriptionOperator`.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_rds.py
+ :language: python
+ :start-after: [START howto_guide_rds_delete_subscription]
+ :end-before: [END howto_guide_rds_delete_subscription]
+
+This Operator leverages the AWS CLI
+`delete-event-subscription <https://docs.aws.amazon.com/cli/latest/reference/rds/delete-event-subscription.html>`__ API
+
+
+Reference
+---------
+
+For further information, look at:
+
+* `Boto3 Library Documentation for RDS <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html>`__
diff --git a/tests/providers/amazon/aws/operators/test_rds.py b/tests/providers/amazon/aws/operators/test_rds.py
new file mode 100644
index 0000000..0989736
--- /dev/null
+++ b/tests/providers/amazon/aws/operators/test_rds.py
@@ -0,0 +1,452 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.models import DAG
+from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.operators.rds import (
+ RdsBaseOperator,
+ RdsCancelExportTaskOperator,
+ RdsCopyDbSnapshotOperator,
+ RdsCreateDbSnapshotOperator,
+ RdsCreateEventSubscriptionOperator,
+ RdsDeleteDbSnapshotOperator,
+ RdsDeleteEventSubscriptionOperator,
+ RdsStartExportTaskOperator,
+)
+from airflow.utils import timezone
+
+try:
+ from moto import mock_rds2
+except ImportError:
+ mock_rds2 = None
+
+
+DEFAULT_DATE = timezone.datetime(2019, 1, 1)
+
+AWS_CONN = 'amazon_default'
+
+DB_INSTANCE_NAME = 'my-db-instance'
+DB_CLUSTER_NAME = 'my-db-cluster'
+
+DB_INSTANCE_SNAPSHOT = 'my-db-instance-snap'
+DB_CLUSTER_SNAPSHOT = 'my-db-cluster-snap'
+
+DB_INSTANCE_SNAPSHOT_COPY = 'my-db-instance-snap-copy'
+DB_CLUSTER_SNAPSHOT_COPY = 'my-db-cluster-snap-copy'
+
+EXPORT_TASK_NAME = 'my-db-instance-snap-export'
+EXPORT_TASK_SOURCE = 'arn:aws:rds:es-east-1::snapshot:my-db-instance-snap'
+EXPORT_TASK_ROLE_NAME = 'MyRole'
+EXPORT_TASK_ROLE_ARN = 'arn:aws:iam:es-east-1::role/MyRole'
+EXPORT_TASK_KMS = 'arn:aws:kms:es-east-1::key/*****-****-****-****-********'
+EXPORT_TASK_BUCKET = 'my-exports-bucket'
+
+SUBSCRIPTION_NAME = 'my-db-instance-subscription'
+SUBSCRIPTION_TOPIC = 'arn:aws:sns:us-east-1::MyTopic'
+
+
+def _create_db_instance(hook: RdsHook):
+ hook.conn.create_db_instance(
+ DBInstanceIdentifier=DB_INSTANCE_NAME,
+ DBInstanceClass='db.m4.large',
+ Engine='postgres',
+ )
+ if not hook.conn.describe_db_instances()['DBInstances']:
+ raise ValueError('AWS not properly mocked')
+
+
+def _create_db_cluster(hook: RdsHook):
+ hook.conn.create_db_cluster(
+ DBClusterIdentifier=DB_CLUSTER_NAME,
+ Engine='mysql',
+ MasterUsername='admin',
+ MasterUserPassword='admin-pass',
+ )
+ if not hook.conn.describe_db_clusters()['DBClusters']:
+ raise ValueError('AWS not properly mocked')
+
+
+def _create_db_instance_snapshot(hook: RdsHook):
+ hook.conn.create_db_snapshot(
+ DBInstanceIdentifier=DB_INSTANCE_NAME,
+ DBSnapshotIdentifier=DB_INSTANCE_SNAPSHOT,
+ )
+ if not hook.conn.describe_db_snapshots()['DBSnapshots']:
+ raise ValueError('AWS not properly mocked')
+
+
+def _create_db_cluster_snapshot(hook: RdsHook):
+ hook.conn.create_db_cluster_snapshot(
+ DBClusterIdentifier=DB_CLUSTER_NAME,
+ DBClusterSnapshotIdentifier=DB_CLUSTER_SNAPSHOT,
+ )
+ if not hook.conn.describe_db_cluster_snapshots()['DBClusterSnapshots']:
+ raise ValueError('AWS not properly mocked')
+
+
+def _start_export_task(hook: RdsHook):
+ hook.conn.start_export_task(
+ ExportTaskIdentifier=EXPORT_TASK_NAME,
+ SourceArn=EXPORT_TASK_SOURCE,
+ IamRoleArn=EXPORT_TASK_ROLE_ARN,
+ KmsKeyId=EXPORT_TASK_KMS,
+ S3BucketName=EXPORT_TASK_BUCKET,
+ )
+ if not hook.conn.describe_export_tasks()['ExportTasks']:
+ raise ValueError('AWS not properly mocked')
+
+
+def _create_event_subscription(hook: RdsHook):
+ hook.conn.create_event_subscription(
+ SubscriptionName=SUBSCRIPTION_NAME,
+ SnsTopicArn=SUBSCRIPTION_TOPIC,
+ SourceType='db-instance',
+ SourceIds=[DB_INSTANCE_NAME],
+ )
+ if not hook.conn.describe_event_subscriptions()['EventSubscriptionsList']:
+ raise ValueError('AWS not properly mocked')
+
+
+class TestBaseRdsOperator:
+ dag = None
+ op = None
+
+ @classmethod
+ def setup_class(cls):
+ cls.dag = DAG('test_dag', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE})
+ cls.op = RdsBaseOperator(task_id='test_task', aws_conn_id='aws_default', dag=cls.dag)
+
+ @classmethod
+ def teardown_class(cls):
+ del cls.dag
+ del cls.op
+
+ def test_hook_attribute(self):
+ assert hasattr(self.op, 'hook')
+ assert self.op.hook.__class__.__name__ == 'RdsHook'
+
+ def test_describe_item_wrong_type(self):
+ with pytest.raises(AirflowException):
+ self.op._describe_item('database', 'auth-db')
+
+ def test_await_status_error(self):
+ self.op._describe_item = lambda item_type, item_name: [{'Status': 'error'}]
+ with pytest.raises(AirflowException):
+ self.op._await_status(
+ item_type='instance_snapshot',
+ item_name='',
+ wait_statuses=['wait'],
+ error_statuses=['error'],
+ )
+
+ def test_await_status_ok(self):
+ self.op._describe_item = lambda item_type, item_name: [{'Status': 'ok'}]
+ self.op._await_status(
+ item_type='instance_snapshot', item_name='', wait_statuses=['wait'], ok_statuses=['ok']
+ )
+
+
+@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present')
+class TestRdsCreateDbSnapshotOperator:
+ @classmethod
+ def setup_class(cls):
+ cls.dag = DAG('test_dag', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE})
+ cls.hook = RdsHook(aws_conn_id=AWS_CONN, region_name='us-east-1')
+
+ @classmethod
+ def teardown_class(cls):
+ del cls.dag
+ del cls.hook
+
+ @mock_rds2
+ def test_create_db_instance_snapshot(self):
+ _create_db_instance(self.hook)
+ instance_snapshot_operator = RdsCreateDbSnapshotOperator(
+ task_id='test_instance',
+ db_type='instance',
+ db_snapshot_identifier=DB_INSTANCE_SNAPSHOT,
+ db_identifier=DB_INSTANCE_NAME,
+ aws_conn_id=AWS_CONN,
+ dag=self.dag,
+ )
+ instance_snapshot_operator.execute(None)
+
+ result = self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=DB_INSTANCE_SNAPSHOT)
+ instance_snapshots = result.get("DBSnapshots")
+
+ assert instance_snapshots
+ assert len(instance_snapshots) == 1
+
+ @mock_rds2
+ def test_create_db_cluster_snapshot(self):
+ _create_db_cluster(self.hook)
+ cluster_snapshot_operator = RdsCreateDbSnapshotOperator(
+ task_id='test_cluster',
+ db_type='cluster',
+ db_snapshot_identifier=DB_CLUSTER_SNAPSHOT,
+ db_identifier=DB_CLUSTER_NAME,
+ aws_conn_id=AWS_CONN,
+ dag=self.dag,
+ )
+ cluster_snapshot_operator.execute(None)
+
+ result = self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=DB_CLUSTER_SNAPSHOT)
+ cluster_snapshots = result.get("DBClusterSnapshots")
+
+ assert cluster_snapshots
+ assert len(cluster_snapshots) == 1
+
+
+@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present')
+class TestRdsCopyDbSnapshotOperator:
+ @classmethod
+ def setup_class(cls):
+ cls.dag = DAG('test_dag', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE})
+ cls.hook = RdsHook(aws_conn_id=AWS_CONN, region_name='us-east-1')
+
+ @classmethod
+ def teardown_class(cls):
+ del cls.dag
+ del cls.hook
+
+ @mock_rds2
+ def test_copy_db_instance_snapshot(self):
+ _create_db_instance(self.hook)
+ _create_db_instance_snapshot(self.hook)
+
+ instance_snapshot_operator = RdsCopyDbSnapshotOperator(
+ task_id='test_instance',
+ db_type='instance',
+ source_db_snapshot_identifier=DB_INSTANCE_SNAPSHOT,
+ target_db_snapshot_identifier=DB_INSTANCE_SNAPSHOT_COPY,
+ aws_conn_id=AWS_CONN,
+ dag=self.dag,
+ )
+ instance_snapshot_operator.execute(None)
+ result = self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=DB_INSTANCE_SNAPSHOT_COPY)
+ instance_snapshots = result.get("DBSnapshots")
+
+ assert instance_snapshots
+ assert len(instance_snapshots) == 1
+
+ @mock_rds2
+ def test_copy_db_cluster_snapshot(self):
+ _create_db_cluster(self.hook)
+ _create_db_cluster_snapshot(self.hook)
+
+ cluster_snapshot_operator = RdsCopyDbSnapshotOperator(
+ task_id='test_cluster',
+ db_type='cluster',
+ source_db_snapshot_identifier=DB_CLUSTER_SNAPSHOT,
+ target_db_snapshot_identifier=DB_CLUSTER_SNAPSHOT_COPY,
+ aws_conn_id=AWS_CONN,
+ dag=self.dag,
+ )
+ cluster_snapshot_operator.execute(None)
+ result = self.hook.conn.describe_db_cluster_snapshots(
+ DBClusterSnapshotIdentifier=DB_CLUSTER_SNAPSHOT_COPY
+ )
+ cluster_snapshots = result.get("DBClusterSnapshots")
+
+ assert cluster_snapshots
+ assert len(cluster_snapshots) == 1
+
+
+@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present')
+class TestRdsDeleteDbSnapshotOperator:
+ @classmethod
+ def setup_class(cls):
+ cls.dag = DAG('test_dag', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE})
+ cls.hook = RdsHook(aws_conn_id=AWS_CONN, region_name='us-east-1')
+
+ @classmethod
+ def teardown_class(cls):
+ del cls.dag
+ del cls.hook
+
+ @mock_rds2
+ def test_delete_db_instance_snapshot(self):
+ _create_db_instance(self.hook)
+ _create_db_instance_snapshot(self.hook)
+
+ instance_snapshot_operator = RdsDeleteDbSnapshotOperator(
+ task_id='test_instance',
+ db_type='instance',
+ db_snapshot_identifier=DB_INSTANCE_SNAPSHOT,
+ aws_conn_id=AWS_CONN,
+ dag=self.dag,
+ )
+ instance_snapshot_operator.execute(None)
+
+ with pytest.raises(self.hook.conn.exceptions.ClientError):
+ self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=DB_CLUSTER_SNAPSHOT)
+
+ @mock_rds2
+ def test_delete_db_cluster_snapshot(self):
+ _create_db_cluster(self.hook)
+ _create_db_cluster_snapshot(self.hook)
+
+ cluster_snapshot_operator = RdsDeleteDbSnapshotOperator(
+ task_id='test_cluster',
+ db_type='cluster',
+ db_snapshot_identifier=DB_CLUSTER_SNAPSHOT,
+ aws_conn_id=AWS_CONN,
+ dag=self.dag,
+ )
+ cluster_snapshot_operator.execute(None)
+
+ with pytest.raises(self.hook.conn.exceptions.ClientError):
+ self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=DB_CLUSTER_SNAPSHOT)
+
+
+@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present')
+class TestRdsStartExportTaskOperator:
+ @classmethod
+ def setup_class(cls):
+ cls.dag = DAG('test_dag', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE})
+ cls.hook = RdsHook(aws_conn_id=AWS_CONN, region_name='us-east-1')
+
+ @classmethod
+ def teardown_class(cls):
+ del cls.dag
+ del cls.hook
+
+ @mock_rds2
+ def test_start_export_task(self):
+ _create_db_instance(self.hook)
+ _create_db_instance_snapshot(self.hook)
+
+ start_export_operator = RdsStartExportTaskOperator(
+ task_id='test_start',
+ export_task_identifier=EXPORT_TASK_NAME,
+ source_arn=EXPORT_TASK_SOURCE,
+ iam_role_arn=EXPORT_TASK_ROLE_ARN,
+ kms_key_id=EXPORT_TASK_KMS,
+ s3_bucket_name=EXPORT_TASK_BUCKET,
+ aws_conn_id=AWS_CONN,
+ dag=self.dag,
+ )
+ start_export_operator.execute(None)
+
+ result = self.hook.conn.describe_export_tasks(ExportTaskIdentifier=EXPORT_TASK_NAME)
+ export_tasks = result.get("ExportTasks")
+
+ assert export_tasks
+ assert len(export_tasks) == 1
+ assert export_tasks[0]['Status'] == 'available'
+
+
+@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present')
+class TestRdsCancelExportTaskOperator:
+ @classmethod
+ def setup_class(cls):
+ cls.dag = DAG('test_dag', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE})
+ cls.hook = RdsHook(aws_conn_id=AWS_CONN, region_name='us-east-1')
+
+ @classmethod
+ def teardown_class(cls):
+ del cls.dag
+ del cls.hook
+
+ @mock_rds2
+ def test_cancel_export_task(self):
+ _create_db_instance(self.hook)
+ _create_db_instance_snapshot(self.hook)
+ _start_export_task(self.hook)
+
+ cancel_export_operator = RdsCancelExportTaskOperator(
+ task_id='test_cancel',
+ export_task_identifier=EXPORT_TASK_NAME,
+ aws_conn_id=AWS_CONN,
+ dag=self.dag,
+ )
+ cancel_export_operator.execute(None)
+
+ result = self.hook.conn.describe_export_tasks(ExportTaskIdentifier=EXPORT_TASK_NAME)
+ export_tasks = result.get("ExportTasks")
+
+ assert export_tasks
+ assert len(export_tasks) == 1
+ assert export_tasks[0]['Status'] == 'canceled'
+
+
+@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present')
+class TestRdsCreateEventSubscriptionOperator:
+ @classmethod
+ def setup_class(cls):
+ cls.dag = DAG('test_dag', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE})
+ cls.hook = RdsHook(aws_conn_id=AWS_CONN, region_name='us-east-1')
+
+ @classmethod
+ def teardown_class(cls):
+ del cls.dag
+ del cls.hook
+
+ @mock_rds2
+ def test_create_event_subscription(self):
+ _create_db_instance(self.hook)
+
+ create_subscription_operator = RdsCreateEventSubscriptionOperator(
+ task_id='test_create',
+ subscription_name=SUBSCRIPTION_NAME,
+ sns_topic_arn=SUBSCRIPTION_TOPIC,
+ source_type='db-instance',
+ source_ids=[DB_INSTANCE_NAME],
+ aws_conn_id=AWS_CONN,
+ dag=self.dag,
+ )
+ create_subscription_operator.execute(None)
+
+ result = self.hook.conn.describe_event_subscriptions(SubscriptionName=SUBSCRIPTION_NAME)
+ subscriptions = result.get("EventSubscriptionsList")
+
+ assert subscriptions
+ assert len(subscriptions) == 1
+ assert subscriptions[0]['Status'] == 'available'
+
+
+@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present')
+class TestRdsDeleteEventSubscriptionOperator:
+ @classmethod
+ def setup_class(cls):
+ cls.dag = DAG('test_dag', default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE})
+ cls.hook = RdsHook(aws_conn_id=AWS_CONN, region_name='us-east-1')
+
+ @classmethod
+ def teardown_class(cls):
+ del cls.dag
+ del cls.hook
+
+ @mock_rds2
+ def test_delete_event_subscription(self):
+ _create_event_subscription(self.hook)
+
+ delete_subscription_operator = RdsDeleteEventSubscriptionOperator(
+ task_id='test_delete',
+ subscription_name=SUBSCRIPTION_NAME,
+ aws_conn_id=AWS_CONN,
+ dag=self.dag,
+ )
+ delete_subscription_operator.execute(None)
+
+ with pytest.raises(self.hook.conn.exceptions.ClientError):
+ self.hook.conn.describe_event_subscriptions(SubscriptionName=EXPORT_TASK_NAME)