You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/04/25 22:10:11 UTC
[airflow] branch main updated: Add RedshiftCreateClusterOperator
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 754e293c54 Add RedshiftCreateClusterOperator
754e293c54 is described below
commit 754e293c546ebffc32422ff8883db57755f8518b
Author: Pankaj <pa...@astronomer.io>
AuthorDate: Sat Apr 16 02:31:32 2022 +0530
Add RedshiftCreateClusterOperator
---
.../aws/example_dags/example_redshift_cluster.py | 20 +-
.../providers/amazon/aws/hooks/redshift_cluster.py | 37 +++-
.../amazon/aws/operators/redshift_cluster.py | 216 ++++++++++++++++++++-
.../operators/redshift_cluster.rst | 14 ++
.../amazon/aws/operators/test_redshift_cluster.py | 46 +++++
5 files changed, 330 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/amazon/aws/example_dags/example_redshift_cluster.py b/airflow/providers/amazon/aws/example_dags/example_redshift_cluster.py
index 1469688266..cbeed2da34 100644
--- a/airflow/providers/amazon/aws/example_dags/example_redshift_cluster.py
+++ b/airflow/providers/amazon/aws/example_dags/example_redshift_cluster.py
@@ -22,6 +22,7 @@ from os import getenv
from airflow import DAG
from airflow.models.baseoperator import chain
from airflow.providers.amazon.aws.operators.redshift_cluster import (
+ RedshiftCreateClusterOperator,
RedshiftPauseClusterOperator,
RedshiftResumeClusterOperator,
)
@@ -36,6 +37,17 @@ with DAG(
catchup=False,
tags=['example'],
) as dag:
+ # [START howto_operator_redshift_cluster]
+ task_create_cluster = RedshiftCreateClusterOperator(
+ task_id="redshift_create_cluster",
+ cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER,
+ cluster_type="single-node",
+ node_type="dc2.large",
+ master_username="adminuser",
+ master_user_password="dummypass",
+ )
+ # [END howto_operator_redshift_cluster]
+
# [START howto_sensor_redshift_cluster]
task_wait_cluster_available = RedshiftClusterSensor(
task_id='sensor_redshift_cluster_available',
@@ -68,4 +80,10 @@ with DAG(
)
# [END howto_operator_redshift_resume_cluster]
- chain(task_wait_cluster_available, task_pause_cluster, task_wait_cluster_paused, task_resume_cluster)
+ chain(
+ task_create_cluster,
+ task_wait_cluster_available,
+ task_pause_cluster,
+ task_wait_cluster_paused,
+ task_resume_cluster,
+ )
diff --git a/airflow/providers/amazon/aws/hooks/redshift_cluster.py b/airflow/providers/amazon/aws/hooks/redshift_cluster.py
index 88b0e85308..80fbe0da31 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_cluster.py
@@ -15,7 +15,9 @@
# specific language governing permissions and limitations
# under the License.
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
+
+from botocore.exceptions import ClientError
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -37,6 +39,39 @@ class RedshiftHook(AwsBaseHook):
kwargs["client_type"] = "redshift"
super().__init__(*args, **kwargs)
+ def create_cluster(
+ self,
+ cluster_identifier: str,
+ node_type: str,
+ master_username: str,
+ master_user_password: str,
+ params: Dict[str, Any],
+ ) -> Dict[str, Any]:
+ """
+ Creates a new cluster with the specified parameters
+
+ :param cluster_identifier: A unique identifier for the cluster.
+ :param node_type: The node type to be provisioned for the cluster.
+ Valid Values: ``ds2.xlarge | ds2.8xlarge | dc1.large | dc1.8xlarge
+ | dc2.large | dc2.8xlarge | ra3.xlplus | ra3.4xlarge | ra3.16xlarge``
+ :param master_username: The username associated with the admin user account
+ for the cluster that is being created.
+ :param master_user_password: password associated with the admin user account
+ for the cluster that is being created.
+ :param params: Remaining AWS Create cluster API params.
+ """
+ try:
+ response = self.get_conn().create_cluster(
+ ClusterIdentifier=cluster_identifier,
+ NodeType=node_type,
+ MasterUsername=master_username,
+ MasterUserPassword=master_user_password,
+ **params,
+ )
+ return response
+ except ClientError as e:
+ raise e
+
# TODO: Wrap create_cluster_snapshot
def cluster_status(self, cluster_identifier: str) -> str:
"""
diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py
index ec4a88f498..616cf7c909 100644
--- a/airflow/providers/amazon/aws/operators/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
@@ -23,6 +23,220 @@ if TYPE_CHECKING:
from airflow.utils.context import Context
+class RedshiftCreateClusterOperator(BaseOperator):
+ """
+ Creates a new cluster with the specified parameters.
+
+ :param cluster_identifier: A unique identifier for the cluster.
+ :param node_type: The node type to be provisioned for the cluster.
+ Valid Values: ``ds2.xlarge | ds2.8xlarge | dc1.large | dc1.8xlarge
+ | dc2.large | dc2.8xlarge | ra3.xlplus | ra3.4xlarge | ra3.16xlarge``
+ :param master_username: The username associated with the admin user account for
+ the cluster that is being created.
+ :param master_user_password: The password associated with the admin user account for
+ the cluster that is being created.
+ :param cluster_type: The type of the cluster ``single-node`` or ``multi-node``.
+ The default value is ``multi-node``.
+ :param db_name: The name of the first database to be created when the cluster is created.
+ :param number_of_nodes: The number of compute nodes in the cluster.
+ This parameter is required when ``cluster_type`` is ``multi-node``.
+ :param cluster_security_groups: A list of security groups to be associated with this cluster.
+ :param vpc_security_group_ids: A list of VPC security groups to be associated with the cluster.
+ :param cluster_subnet_group_name: The name of a cluster subnet group to be associated with this cluster.
+ :param availability_zone: The EC2 Availability Zone (AZ).
+ :param preferred_maintenance_window: The time range (in UTC) during which automated cluster
+ maintenance can occur.
+ :param cluster_parameter_group_name: The name of the parameter group to be associated with this cluster.
+ :param automated_snapshot_retention_period: The number of days that automated snapshots are retained.
+ The default value is ``1``.
+ :param manual_snapshot_retention_period: The default number of days to retain a manual snapshot.
+ :param port: The port number on which the cluster accepts incoming connections.
+ The Default value is ``5439``.
+ :param cluster_version: The version of a Redshift engine software that you want to deploy on the cluster.
+ :param allow_version_upgrade: Whether major version upgrades can be applied during the maintenance window.
+ The Default value is ``True``.
+ :param publicly_accessible: Whether cluster can be accessed from a public network.
+ :param encrypted: Whether data in the cluster is encrypted at rest.
+ The default value is ``False``.
+ :param hsm_client_certificate_identifier: Name of the HSM client certificate
+ the Amazon Redshift cluster uses to retrieve the data.
+ :param hsm_configuration_identifier: Name of the HSM configuration
+ :param elastic_ip: The Elastic IP (EIP) address for the cluster.
+ :param tags: A list of tag instances
+ :param kms_key_id: KMS key id of encryption key.
+ :param enhanced_vpc_routing: Whether to create the cluster with enhanced VPC routing enabled
+ Default value is ``False``.
+ :param additional_info: Reserved
+ :param iam_roles: A list of IAM roles that can be used by the cluster to access other AWS services.
+ :param maintenance_track_name: Name of the maintenance track for the cluster.
+ :param snapshot_schedule_identifier: A unique identifier for the snapshot schedule.
+ :param availability_zone_relocation: Enable relocation for a Redshift cluster
+ between Availability Zones after the cluster is created.
+ :param aqua_configuration_status: Whether the cluster is configured to use AQUA (Advanced Query Accelerator).
+ :param default_iam_role_arn: ARN for the IAM role.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ The default connection id is ``aws_default``.
+ """
+
+ template_fields: Sequence[str] = (
+ "cluster_identifier",
+ "cluster_type",
+ "node_type",
+ "number_of_nodes",
+ )
+ ui_color = "#eeaa11"
+ ui_fgcolor = "#ffffff"
+
+ def __init__(
+ self,
+ *,
+ cluster_identifier: str,
+ node_type: str,
+ master_username: str,
+ master_user_password: str,
+ cluster_type: str = "multi-node",
+ db_name: str = "dev",
+ number_of_nodes: int = 1,
+ cluster_security_groups: Optional[List[str]] = None,
+ vpc_security_group_ids: Optional[List[str]] = None,
+ cluster_subnet_group_name: Optional[str] = None,
+ availability_zone: Optional[str] = None,
+ preferred_maintenance_window: Optional[str] = None,
+ cluster_parameter_group_name: Optional[str] = None,
+ automated_snapshot_retention_period: int = 1,
+ manual_snapshot_retention_period: Optional[int] = None,
+ port: int = 5439,
+ cluster_version: str = "1.0",
+ allow_version_upgrade: bool = True,
+ publicly_accessible: bool = True,
+ encrypted: bool = False,
+ hsm_client_certificate_identifier: Optional[str] = None,
+ hsm_configuration_identifier: Optional[str] = None,
+ elastic_ip: Optional[str] = None,
+ tags: Optional[List[Any]] = None,
+ kms_key_id: Optional[str] = None,
+ enhanced_vpc_routing: bool = False,
+ additional_info: Optional[str] = None,
+ iam_roles: Optional[List[str]] = None,
+ maintenance_track_name: Optional[str] = None,
+ snapshot_schedule_identifier: Optional[str] = None,
+ availability_zone_relocation: Optional[bool] = None,
+ aqua_configuration_status: Optional[str] = None,
+ default_iam_role_arn: Optional[str] = None,
+ aws_conn_id: str = "aws_default",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.cluster_identifier = cluster_identifier
+ self.node_type = node_type
+ self.master_username = master_username
+ self.master_user_password = master_user_password
+ self.cluster_type = cluster_type
+ self.db_name = db_name
+ self.number_of_nodes = number_of_nodes
+ self.cluster_security_groups = cluster_security_groups
+ self.vpc_security_group_ids = vpc_security_group_ids
+ self.cluster_subnet_group_name = cluster_subnet_group_name
+ self.availability_zone = availability_zone
+ self.preferred_maintenance_window = preferred_maintenance_window
+ self.cluster_parameter_group_name = cluster_parameter_group_name
+ self.automated_snapshot_retention_period = automated_snapshot_retention_period
+ self.manual_snapshot_retention_period = manual_snapshot_retention_period
+ self.port = port
+ self.cluster_version = cluster_version
+ self.allow_version_upgrade = allow_version_upgrade
+ self.publicly_accessible = publicly_accessible
+ self.encrypted = encrypted
+ self.hsm_client_certificate_identifier = hsm_client_certificate_identifier
+ self.hsm_configuration_identifier = hsm_configuration_identifier
+ self.elastic_ip = elastic_ip
+ self.tags = tags
+ self.kms_key_id = kms_key_id
+ self.enhanced_vpc_routing = enhanced_vpc_routing
+ self.additional_info = additional_info
+ self.iam_roles = iam_roles
+ self.maintenance_track_name = maintenance_track_name
+ self.snapshot_schedule_identifier = snapshot_schedule_identifier
+ self.availability_zone_relocation = availability_zone_relocation
+ self.aqua_configuration_status = aqua_configuration_status
+ self.default_iam_role_arn = default_iam_role_arn
+ self.aws_conn_id = aws_conn_id
+ self.kwargs = kwargs
+
+ def execute(self, context: 'Context'):
+ redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
+ self.log.info("Creating Redshift cluster %s", self.cluster_identifier)
+ params: Dict[str, Any] = {}
+ if self.db_name:
+ params["DBName"] = self.db_name
+ if self.cluster_type:
+ params["ClusterType"] = self.cluster_type
+ if self.number_of_nodes:
+ params["NumberOfNodes"] = self.number_of_nodes
+ if self.cluster_security_groups:
+ params["ClusterSecurityGroups"] = self.cluster_security_groups
+ if self.vpc_security_group_ids:
+ params["VpcSecurityGroupIds"] = self.vpc_security_group_ids
+ if self.cluster_subnet_group_name:
+ params["ClusterSubnetGroupName"] = self.cluster_subnet_group_name
+ if self.availability_zone:
+ params["AvailabilityZone"] = self.availability_zone
+ if self.preferred_maintenance_window:
+ params["PreferredMaintenanceWindow"] = self.preferred_maintenance_window
+ if self.cluster_parameter_group_name:
+ params["ClusterParameterGroupName"] = self.cluster_parameter_group_name
+ if self.automated_snapshot_retention_period:
+ params["AutomatedSnapshotRetentionPeriod"] = self.automated_snapshot_retention_period
+ if self.manual_snapshot_retention_period:
+ params["ManualSnapshotRetentionPeriod"] = self.manual_snapshot_retention_period
+ if self.port:
+ params["Port"] = self.port
+ if self.cluster_version:
+ params["ClusterVersion"] = self.cluster_version
+ if self.allow_version_upgrade:
+ params["AllowVersionUpgrade"] = self.allow_version_upgrade
+ if self.publicly_accessible:
+ params["PubliclyAccessible"] = self.publicly_accessible
+ if self.encrypted:
+ params["Encrypted"] = self.encrypted
+ if self.hsm_client_certificate_identifier:
+ params["HsmClientCertificateIdentifier"] = self.hsm_client_certificate_identifier
+ if self.hsm_configuration_identifier:
+ params["HsmConfigurationIdentifier"] = self.hsm_configuration_identifier
+ if self.elastic_ip:
+ params["ElasticIp"] = self.elastic_ip
+ if self.tags:
+ params["Tags"] = self.tags
+ if self.kms_key_id:
+ params["KmsKeyId"] = self.kms_key_id
+ if self.enhanced_vpc_routing:
+ params["EnhancedVpcRouting"] = self.enhanced_vpc_routing
+ if self.additional_info:
+ params["AdditionalInfo"] = self.additional_info
+ if self.iam_roles:
+ params["IamRoles"] = self.iam_roles
+ if self.maintenance_track_name:
+ params["MaintenanceTrackName"] = self.maintenance_track_name
+ if self.snapshot_schedule_identifier:
+ params["SnapshotScheduleIdentifier"] = self.snapshot_schedule_identifier
+ if self.availability_zone_relocation:
+ params["AvailabilityZoneRelocation"] = self.availability_zone_relocation
+ if self.aqua_configuration_status:
+ params["AquaConfigurationStatus"] = self.aqua_configuration_status
+ if self.default_iam_role_arn:
+ params["DefaultIamRoleArn"] = self.default_iam_role_arn
+
+ cluster = redshift_hook.create_cluster(
+ self.cluster_identifier,
+ self.node_type,
+ self.master_username,
+ self.master_user_password,
+ params,
+ )
+ self.log.info("Created Redshift cluster %s", self.cluster_identifier)
+ self.log.info(cluster)
+
+
class RedshiftResumeClusterOperator(BaseOperator):
"""
Resume a paused AWS Redshift Cluster
diff --git a/docs/apache-airflow-providers-amazon/operators/redshift_cluster.rst b/docs/apache-airflow-providers-amazon/operators/redshift_cluster.rst
index f7a771c7d2..0e27fee65f 100644
--- a/docs/apache-airflow-providers-amazon/operators/redshift_cluster.rst
+++ b/docs/apache-airflow-providers-amazon/operators/redshift_cluster.rst
@@ -33,6 +33,20 @@ Prerequisite Tasks
Manage Amazon Redshift Clusters
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. _howto/operator:RedshiftCreateClusterOperator:
+
+Amazon Redshift Cluster Operator
+""""""""""""""""""""""""""""""
+
+To create an Amazon Redshift Cluster with the specified parameters, you can use
+:class:`~airflow.providers.amazon.aws.operators.redshift_cluster.RedshiftCreateClusterOperator`.
+
+.. exampleinclude:: /../../../airflow/providers/amazon/aws/example_dags/example_redshift_cluster.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_redshift_cluster]
+ :end-before: [END howto_operator_redshift_cluster]
+
.. _howto/sensor:RedshiftClusterSensor:
Amazon Redshift Cluster Sensor
diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
index a1944a8d38..494d0281b3 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
@@ -18,11 +18,57 @@
from unittest import mock
from airflow.providers.amazon.aws.operators.redshift_cluster import (
+ RedshiftCreateClusterOperator,
RedshiftPauseClusterOperator,
RedshiftResumeClusterOperator,
)
+class TestRedshiftCreateClusterOperator:
+ def test_init(self):
+ redshift_operator = RedshiftCreateClusterOperator(
+ task_id="task_test",
+ cluster_identifier="test_cluster",
+ node_type="dc2.large",
+ master_username="adminuser",
+ master_user_password="Test123$",
+ )
+ assert redshift_operator.task_id == "task_test"
+ assert redshift_operator.cluster_identifier == "test_cluster"
+ assert redshift_operator.node_type == "dc2.large"
+ assert redshift_operator.master_username == "adminuser"
+ assert redshift_operator.master_user_password == "Test123$"
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
+ def test_create_cluster(self, mock_get_conn):
+ redshift_operator = RedshiftCreateClusterOperator(
+ task_id="task_test",
+ cluster_identifier="test-cluster",
+ node_type="dc2.large",
+ master_username="adminuser",
+ master_user_password="Test123$",
+ cluster_type="single-node",
+ )
+ redshift_operator.execute(None)
+ params = {
+ "DBName": "dev",
+ "ClusterType": "single-node",
+ "NumberOfNodes": 1,
+ "AutomatedSnapshotRetentionPeriod": 1,
+ "ClusterVersion": "1.0",
+ "AllowVersionUpgrade": True,
+ "PubliclyAccessible": True,
+ "Port": 5439,
+ }
+ mock_get_conn.return_value.create_cluster.assert_called_once_with(
+ ClusterIdentifier='test-cluster',
+ NodeType="dc2.large",
+ MasterUsername="adminuser",
+ MasterUserPassword="Test123$",
+ **params,
+ )
+
+
class TestResumeClusterOperator:
def test_init(self):
redshift_operator = RedshiftResumeClusterOperator(