Posted to commits@airflow.apache.org by po...@apache.org on 2022/04/25 22:10:11 UTC

[airflow] branch main updated: Add RedshiftCreateClusterOperator

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 754e293c54 Add RedshiftCreateClusterOperator
754e293c54 is described below

commit 754e293c546ebffc32422ff8883db57755f8518b
Author: Pankaj <pa...@astronomer.io>
AuthorDate: Sat Apr 16 02:31:32 2022 +0530

    Add RedshiftCreateClusterOperator
---
 .../aws/example_dags/example_redshift_cluster.py   |  20 +-
 .../providers/amazon/aws/hooks/redshift_cluster.py |  37 +++-
 .../amazon/aws/operators/redshift_cluster.py       | 216 ++++++++++++++++++++-
 .../operators/redshift_cluster.rst                 |  14 ++
 .../amazon/aws/operators/test_redshift_cluster.py  |  46 +++++
 5 files changed, 330 insertions(+), 3 deletions(-)
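
A note for readers: the new hook method is a thin wrapper around the boto3
Redshift client, so the operator boils down to a single API call. A minimal
sketch of the equivalent direct call (region, identifier, and credentials
below are placeholder assumptions, not part of this commit):

    import boto3

    # The hook's create_cluster ultimately issues this call, forwarding all
    # optional settings as extra keyword arguments.
    client = boto3.client("redshift", region_name="us-east-1")
    response = client.create_cluster(
        ClusterIdentifier="example-cluster",
        NodeType="dc2.large",
        MasterUsername="adminuser",
        MasterUserPassword="dummypass",
        ClusterType="single-node",
    )
    print(response["Cluster"]["ClusterStatus"])  # typically "creating"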

diff --git a/airflow/providers/amazon/aws/example_dags/example_redshift_cluster.py b/airflow/providers/amazon/aws/example_dags/example_redshift_cluster.py
index 1469688266..cbeed2da34 100644
--- a/airflow/providers/amazon/aws/example_dags/example_redshift_cluster.py
+++ b/airflow/providers/amazon/aws/example_dags/example_redshift_cluster.py
@@ -22,6 +22,7 @@ from os import getenv
 from airflow import DAG
 from airflow.models.baseoperator import chain
 from airflow.providers.amazon.aws.operators.redshift_cluster import (
+    RedshiftCreateClusterOperator,
     RedshiftPauseClusterOperator,
     RedshiftResumeClusterOperator,
 )
@@ -36,6 +37,17 @@ with DAG(
     catchup=False,
     tags=['example'],
 ) as dag:
+    # [START howto_operator_redshift_cluster]
+    task_create_cluster = RedshiftCreateClusterOperator(
+        task_id="redshift_create_cluster",
+        cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER,
+        cluster_type="single-node",
+        node_type="dc2.large",
+        master_username="adminuser",
+        master_user_password="dummypass",
+    )
+    # [END howto_operator_redshift_cluster]
+
     # [START howto_sensor_redshift_cluster]
     task_wait_cluster_available = RedshiftClusterSensor(
         task_id='sensor_redshift_cluster_available',
@@ -68,4 +80,10 @@ with DAG(
     )
     # [END howto_operator_redshift_resume_cluster]
 
-    chain(task_wait_cluster_available, task_pause_cluster, task_wait_cluster_paused, task_resume_cluster)
+    chain(
+        task_create_cluster,
+        task_wait_cluster_available,
+        task_pause_cluster,
+        task_wait_cluster_paused,
+        task_resume_cluster,
+    )
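
A quick aside on ``chain``: for single (non-list) arguments it is just
shorthand for the ``>>`` dependency operator, so the call above is
equivalent to:

    task_create_cluster >> task_wait_cluster_available >> task_pause_cluster \
        >> task_wait_cluster_paused >> task_resume_cluster
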
diff --git a/airflow/providers/amazon/aws/hooks/redshift_cluster.py b/airflow/providers/amazon/aws/hooks/redshift_cluster.py
index 88b0e85308..80fbe0da31 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_cluster.py
@@ -15,7 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
+
+from botocore.exceptions import ClientError
 
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 
@@ -37,6 +39,39 @@ class RedshiftHook(AwsBaseHook):
         kwargs["client_type"] = "redshift"
         super().__init__(*args, **kwargs)
 
+    def create_cluster(
+        self,
+        cluster_identifier: str,
+        node_type: str,
+        master_username: str,
+        master_user_password: str,
+        params: Dict[str, Any],
+    ) -> Dict[str, Any]:
+        """
+        Creates a new cluster with the specified parameters.
+
+        :param cluster_identifier: A unique identifier for the cluster.
+        :param node_type: The node type to be provisioned for the cluster.
+            Valid Values: ``ds2.xlarge | ds2.8xlarge | dc1.large | dc1.8xlarge
+            | dc2.large | dc2.8xlarge | ra3.xlplus | ra3.4xlarge | ra3.16xlarge``
+        :param master_username: The username associated with the admin user account
+            for the cluster that is being created.
+        :param master_user_password: The password associated with the admin user account
+            for the cluster that is being created.
+        :param params: Remaining parameters for the AWS ``CreateCluster`` API call.
+        """
+        try:
+            response = self.get_conn().create_cluster(
+                ClusterIdentifier=cluster_identifier,
+                NodeType=node_type,
+                MasterUsername=master_username,
+                MasterUserPassword=master_user_password,
+                **params,
+            )
+            return response
+        except ClientError as e:
+            self.log.error("Failed to create Redshift cluster %s: %s", cluster_identifier, e)
+            raise  # re-raise without discarding the original traceback
+
     # TODO: Wrap create_cluster_snapshot
     def cluster_status(self, cluster_identifier: str) -> str:
         """
diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py
index ec4a88f498..616cf7c909 100644
--- a/airflow/providers/amazon/aws/operators/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
 
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
@@ -23,6 +23,220 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
+class RedshiftCreateClusterOperator(BaseOperator):
+    """
+    Creates a new cluster with the specified parameters.
+
+    :param cluster_identifier: A unique identifier for the cluster.
+    :param node_type: The node type to be provisioned for the cluster.
+        Valid Values: ``ds2.xlarge | ds2.8xlarge | dc1.large | dc1.8xlarge
+        | dc2.large | dc2.8xlarge | ra3.xlplus | ra3.4xlarge | ra3.16xlarge``
+    :param master_username: The username associated with the admin user account for
+        the cluster that is being created.
+    :param master_user_password: The password associated with the admin user account for
+        the cluster that is being created.
+    :param cluster_type: The type of the cluster, either ``single-node`` or ``multi-node``.
+        The default value is ``multi-node``.
+    :param db_name: The name of the first database to be created when the cluster is created.
+    :param number_of_nodes: The number of compute nodes in the cluster.
+        This parameter is required when ``cluster_type`` is ``multi-node``.
+    :param cluster_security_groups: A list of security groups to be associated with this cluster.
+    :param vpc_security_group_ids: A list of VPC security groups to be associated with the cluster.
+    :param cluster_subnet_group_name: The name of a cluster subnet group to be associated with this cluster.
+    :param availability_zone: The EC2 Availability Zone (AZ).
+    :param preferred_maintenance_window: The time range (in UTC) during which automated cluster
+        maintenance can occur.
+    :param cluster_parameter_group_name: The name of the parameter group to be associated with this cluster.
+    :param automated_snapshot_retention_period: The number of days that automated snapshots are retained.
+        The default value is ``1``.
+    :param manual_snapshot_retention_period: The default number of days to retain a manual snapshot.
+    :param port: The port number on which the cluster accepts incoming connections.
+        The default value is ``5439``.
+    :param cluster_version: The version of the Redshift engine software to deploy on the cluster.
+    :param allow_version_upgrade: Whether major version upgrades can be applied during the maintenance window.
+        The default value is ``True``.
+    :param publicly_accessible: Whether the cluster can be accessed from a public network.
+        The default value is ``True``.
+    :param encrypted: Whether data in the cluster is encrypted at rest.
+        The default value is ``False``.
+    :param hsm_client_certificate_identifier: Name of the HSM client certificate
+        the Amazon Redshift cluster uses to retrieve the data.
+    :param hsm_configuration_identifier: Name of the HSM configuration.
+    :param elastic_ip: The Elastic IP (EIP) address for the cluster.
+    :param tags: A list of tag instances.
+    :param kms_key_id: The KMS key ID of the encryption key used to encrypt data in the cluster.
+    :param enhanced_vpc_routing: Whether to create the cluster with enhanced VPC routing enabled.
+        The default value is ``False``.
+    :param additional_info: Reserved.
+    :param iam_roles: A list of IAM roles that can be used by the cluster to access other AWS services.
+    :param maintenance_track_name: Name of the maintenance track for the cluster.
+    :param snapshot_schedule_identifier: A unique identifier for the snapshot schedule.
+    :param availability_zone_relocation: Whether to enable relocation for the Redshift cluster
+        between Availability Zones after the cluster is created.
+    :param aqua_configuration_status: Whether the cluster is configured to use AQUA
+        (Advanced Query Accelerator).
+    :param default_iam_role_arn: The Amazon Resource Name (ARN) of the IAM role set as default
+        for the cluster.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        The default connection id is ``aws_default``.
+    """
+
+    template_fields: Sequence[str] = (
+        "cluster_identifier",
+        "cluster_type",
+        "node_type",
+        "number_of_nodes",
+    )
+    ui_color = "#eeaa11"
+    ui_fgcolor = "#ffffff"
+
+    def __init__(
+        self,
+        *,
+        cluster_identifier: str,
+        node_type: str,
+        master_username: str,
+        master_user_password: str,
+        cluster_type: str = "multi-node",
+        db_name: str = "dev",
+        number_of_nodes: int = 1,
+        cluster_security_groups: Optional[List[str]] = None,
+        vpc_security_group_ids: Optional[List[str]] = None,
+        cluster_subnet_group_name: Optional[str] = None,
+        availability_zone: Optional[str] = None,
+        preferred_maintenance_window: Optional[str] = None,
+        cluster_parameter_group_name: Optional[str] = None,
+        automated_snapshot_retention_period: int = 1,
+        manual_snapshot_retention_period: Optional[int] = None,
+        port: int = 5439,
+        cluster_version: str = "1.0",
+        allow_version_upgrade: bool = True,
+        publicly_accessible: bool = True,
+        encrypted: bool = False,
+        hsm_client_certificate_identifier: Optional[str] = None,
+        hsm_configuration_identifier: Optional[str] = None,
+        elastic_ip: Optional[str] = None,
+        tags: Optional[List[Any]] = None,
+        kms_key_id: Optional[str] = None,
+        enhanced_vpc_routing: bool = False,
+        additional_info: Optional[str] = None,
+        iam_roles: Optional[List[str]] = None,
+        maintenance_track_name: Optional[str] = None,
+        snapshot_schedule_identifier: Optional[str] = None,
+        availability_zone_relocation: Optional[bool] = None,
+        aqua_configuration_status: Optional[str] = None,
+        default_iam_role_arn: Optional[str] = None,
+        aws_conn_id: str = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.cluster_identifier = cluster_identifier
+        self.node_type = node_type
+        self.master_username = master_username
+        self.master_user_password = master_user_password
+        self.cluster_type = cluster_type
+        self.db_name = db_name
+        self.number_of_nodes = number_of_nodes
+        self.cluster_security_groups = cluster_security_groups
+        self.vpc_security_group_ids = vpc_security_group_ids
+        self.cluster_subnet_group_name = cluster_subnet_group_name
+        self.availability_zone = availability_zone
+        self.preferred_maintenance_window = preferred_maintenance_window
+        self.cluster_parameter_group_name = cluster_parameter_group_name
+        self.automated_snapshot_retention_period = automated_snapshot_retention_period
+        self.manual_snapshot_retention_period = manual_snapshot_retention_period
+        self.port = port
+        self.cluster_version = cluster_version
+        self.allow_version_upgrade = allow_version_upgrade
+        self.publicly_accessible = publicly_accessible
+        self.encrypted = encrypted
+        self.hsm_client_certificate_identifier = hsm_client_certificate_identifier
+        self.hsm_configuration_identifier = hsm_configuration_identifier
+        self.elastic_ip = elastic_ip
+        self.tags = tags
+        self.kms_key_id = kms_key_id
+        self.enhanced_vpc_routing = enhanced_vpc_routing
+        self.additional_info = additional_info
+        self.iam_roles = iam_roles
+        self.maintenance_track_name = maintenance_track_name
+        self.snapshot_schedule_identifier = snapshot_schedule_identifier
+        self.availability_zone_relocation = availability_zone_relocation
+        self.aqua_configuration_status = aqua_configuration_status
+        self.default_iam_role_arn = default_iam_role_arn
+        self.aws_conn_id = aws_conn_id
+        self.kwargs = kwargs
+
+    def execute(self, context: 'Context'):
+        redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
+        self.log.info("Creating Redshift cluster %s", self.cluster_identifier)
+        params: Dict[str, Any] = {}
+        if self.db_name:
+            params["DBName"] = self.db_name
+        if self.cluster_type:
+            params["ClusterType"] = self.cluster_type
+        if self.number_of_nodes:
+            params["NumberOfNodes"] = self.number_of_nodes
+        if self.cluster_security_groups:
+            params["ClusterSecurityGroups"] = self.cluster_security_groups
+        if self.vpc_security_group_ids:
+            params["VpcSecurityGroupIds"] = self.vpc_security_group_ids
+        if self.cluster_subnet_group_name:
+            params["ClusterSubnetGroupName"] = self.cluster_subnet_group_name
+        if self.availability_zone:
+            params["AvailabilityZone"] = self.availability_zone
+        if self.preferred_maintenance_window:
+            params["PreferredMaintenanceWindow"] = self.preferred_maintenance_window
+        if self.cluster_parameter_group_name:
+            params["ClusterParameterGroupName"] = self.cluster_parameter_group_name
+        if self.automated_snapshot_retention_period:
+            params["AutomatedSnapshotRetentionPeriod"] = self.automated_snapshot_retention_period
+        if self.manual_snapshot_retention_period:
+            params["ManualSnapshotRetentionPeriod"] = self.manual_snapshot_retention_period
+        if self.port:
+            params["Port"] = self.port
+        if self.cluster_version:
+            params["ClusterVersion"] = self.cluster_version
+        if self.allow_version_upgrade:
+            params["AllowVersionUpgrade"] = self.allow_version_upgrade
+        if self.publicly_accessible:
+            params["PubliclyAccessible"] = self.publicly_accessible
+        if self.encrypted:
+            params["Encrypted"] = self.encrypted
+        if self.hsm_client_certificate_identifier:
+            params["HsmClientCertificateIdentifier"] = self.hsm_client_certificate_identifier
+        if self.hsm_configuration_identifier:
+            params["HsmConfigurationIdentifier"] = self.hsm_configuration_identifier
+        if self.elastic_ip:
+            params["ElasticIp"] = self.elastic_ip
+        if self.tags:
+            params["Tags"] = self.tags
+        if self.kms_key_id:
+            params["KmsKeyId"] = self.kms_key_id
+        if self.enhanced_vpc_routing:
+            params["EnhancedVpcRouting"] = self.enhanced_vpc_routing
+        if self.additional_info:
+            params["AdditionalInfo"] = self.additional_info
+        if self.iam_roles:
+            params["IamRoles"] = self.iam_roles
+        if self.maintenance_track_name:
+            params["MaintenanceTrackName"] = self.maintenance_track_name
+        if self.snapshot_schedule_identifier:
+            params["SnapshotScheduleIdentifier"] = self.snapshot_schedule_identifier
+        if self.availability_zone_relocation:
+            params["AvailabilityZoneRelocation"] = self.availability_zone_relocation
+        if self.aqua_configuration_status:
+            params["AquaConfigurationStatus"] = self.aqua_configuration_status
+        if self.default_iam_role_arn:
+            params["DefaultIamRoleArn"] = self.default_iam_role_arn
+
+        cluster = redshift_hook.create_cluster(
+            self.cluster_identifier,
+            self.node_type,
+            self.master_username,
+            self.master_user_password,
+            params,
+        )
+        self.log.info("Created Redshift cluster %s", self.cluster_identifier)
+        self.log.info(cluster)
+
+
 class RedshiftResumeClusterOperator(BaseOperator):
     """
     Resume a paused AWS Redshift Cluster
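
One usage note on the new operator: ``number_of_nodes`` only matters when
``cluster_type`` is ``multi-node``. A hypothetical multi-node variant of the
example task (all values are placeholders, not part of this commit):

    create_multi_node_cluster = RedshiftCreateClusterOperator(
        task_id="redshift_create_multi_node_cluster",
        cluster_identifier="example-multi-node",
        cluster_type="multi-node",
        number_of_nodes=2,  # required by the API for multi-node clusters
        node_type="ra3.xlplus",
        master_username="adminuser",
        master_user_password="dummypass",
    )
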
diff --git a/docs/apache-airflow-providers-amazon/operators/redshift_cluster.rst b/docs/apache-airflow-providers-amazon/operators/redshift_cluster.rst
index f7a771c7d2..0e27fee65f 100644
--- a/docs/apache-airflow-providers-amazon/operators/redshift_cluster.rst
+++ b/docs/apache-airflow-providers-amazon/operators/redshift_cluster.rst
@@ -33,6 +33,20 @@ Prerequisite Tasks
 Manage Amazon Redshift Clusters
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
+.. _howto/operator:RedshiftCreateClusterOperator:
+
+Amazon Redshift Cluster Operator
+""""""""""""""""""""""""""""""""
+
+To create an Amazon Redshift Cluster with the specified parameters, use
+:class:`~airflow.providers.amazon.aws.operators.redshift_cluster.RedshiftCreateClusterOperator`.
+
+.. exampleinclude:: /../../../airflow/providers/amazon/aws/example_dags/example_redshift_cluster.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_redshift_cluster]
+    :end-before: [END howto_operator_redshift_cluster]
+
 .. _howto/sensor:RedshiftClusterSensor:
 
 Amazon Redshift Cluster Sensor
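
Worth noting: the create operator returns as soon as the ``CreateCluster``
API call is accepted, while the cluster itself typically spends several
minutes in the ``creating`` state. As the example DAG above does, pair it
with the cluster sensor to block until the cluster is usable (the sensor's
module path and the values below are assumptions, not shown in this diff):

    from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor

    wait_cluster_available = RedshiftClusterSensor(
        task_id="sensor_redshift_cluster_available",
        cluster_identifier="example-cluster",
        target_status="available",
        poke_interval=60,   # poll once a minute
        timeout=60 * 30,    # fail the task after 30 minutes of waiting
    )
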
diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
index a1944a8d38..494d0281b3 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
@@ -18,11 +18,57 @@
 from unittest import mock
 
 from airflow.providers.amazon.aws.operators.redshift_cluster import (
+    RedshiftCreateClusterOperator,
     RedshiftPauseClusterOperator,
     RedshiftResumeClusterOperator,
 )
 
 
+class TestRedshiftCreateClusterOperator:
+    def test_init(self):
+        redshift_operator = RedshiftCreateClusterOperator(
+            task_id="task_test",
+            cluster_identifier="test_cluster",
+            node_type="dc2.large",
+            master_username="adminuser",
+            master_user_password="Test123$",
+        )
+        assert redshift_operator.task_id == "task_test"
+        assert redshift_operator.cluster_identifier == "test_cluster"
+        assert redshift_operator.node_type == "dc2.large"
+        assert redshift_operator.master_username == "adminuser"
+        assert redshift_operator.master_user_password == "Test123$"
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
+    def test_create_cluster(self, mock_get_conn):
+        redshift_operator = RedshiftCreateClusterOperator(
+            task_id="task_test",
+            cluster_identifier="test-cluster",
+            node_type="dc2.large",
+            master_username="adminuser",
+            master_user_password="Test123$",
+            cluster_type="single-node",
+        )
+        redshift_operator.execute(None)
+        params = {
+            "DBName": "dev",
+            "ClusterType": "single-node",
+            "NumberOfNodes": 1,
+            "AutomatedSnapshotRetentionPeriod": 1,
+            "ClusterVersion": "1.0",
+            "AllowVersionUpgrade": True,
+            "PubliclyAccessible": True,
+            "Port": 5439,
+        }
+        mock_get_conn.return_value.create_cluster.assert_called_once_with(
+            ClusterIdentifier='test-cluster',
+            NodeType="dc2.large",
+            MasterUsername="adminuser",
+            MasterUserPassword="Test123$",
+            **params,
+        )
+
+
 class TestResumeClusterOperator:
     def test_init(self):
         redshift_operator = RedshiftResumeClusterOperator(