Posted to commits@airflow.apache.org by bo...@apache.org on 2017/12/08 09:16:51 UTC

incubator-airflow git commit: [AIRFLOW-1888] Add AWS Redshift Cluster Sensor

Repository: incubator-airflow
Updated Branches:
  refs/heads/master 9ad6d1202 -> 4936a8077


[AIRFLOW-1888] Add AWS Redshift Cluster Sensor

Add an AWS Redshift Cluster Sensor to contrib, along with corresponding
unit tests. Additionally, update the RedshiftHook cluster_status method
to handle the ClusterNotFoundFault exception gracefully, add unit
tests, and fix linting errors.

Closes #2849 from andyxhadji/AIRFLOW-1888


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/4936a807
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/4936a807
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/4936a807

Branch: refs/heads/master
Commit: 4936a807736557718dbc0690b92240806de5f3a9
Parents: 9ad6d12
Author: Andy Hadjigeorgiou <ah...@columbia.edu>
Authored: Fri Dec 8 10:16:44 2017 +0100
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Fri Dec 8 10:16:44 2017 +0100

----------------------------------------------------------------------
 airflow/contrib/hooks/redshift_hook.py          | 42 ++++++-----
 .../sensors/aws_redshift_cluster_sensor.py      | 46 ++++++++++++
 tests/contrib/hooks/test_redshift_hook.py       | 22 +++++-
 .../sensors/test_aws_redshift_cluster_sensor.py | 76 ++++++++++++++++++++
 4 files changed, 167 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4936a807/airflow/contrib/hooks/redshift_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/redshift_hook.py b/airflow/contrib/hooks/redshift_hook.py
index 071caf2..70a4854 100644
--- a/airflow/contrib/hooks/redshift_hook.py
+++ b/airflow/contrib/hooks/redshift_hook.py
@@ -14,6 +14,7 @@
 
 from airflow.contrib.hooks.aws_hook import AwsHook
 
+
 class RedshiftHook(AwsHook):
     """
     Interact with AWS Redshift, using the boto3 library
@@ -26,29 +27,36 @@ class RedshiftHook(AwsHook):
         """
         Return status of a cluster
 
-        :param cluster_identifier: unique identifier of a cluster whose properties you are requesting
+        :param cluster_identifier: unique identifier of a cluster
         :type cluster_identifier: str
         """
-        # Use describe clusters
-        response = self.get_conn().describe_clusters(ClusterIdentifier=cluster_identifier)
-        # Possibly return error if cluster does not exist
-        return response['Clusters'][0]['ClusterStatus'] if response['Clusters'] else None
+        conn = self.get_conn()
+        try:
+            response = conn.describe_clusters(
+                ClusterIdentifier=cluster_identifier)['Clusters']
+            return response[0]['ClusterStatus'] if response else None
+        except conn.exceptions.ClusterNotFoundFault:
+            return 'cluster_not_found'
 
-    def delete_cluster(self, cluster_identifier, skip_final_cluster_snapshot=True, final_cluster_snapshot_identifier=''):
+    def delete_cluster(
+            self,
+            cluster_identifier,
+            skip_final_cluster_snapshot=True,
+            final_cluster_snapshot_identifier=''):
         """
         Delete a cluster and optionally create a snapshot
 
-        :param cluster_identifier: unique identifier of a cluster whose properties you are requesting
+        :param cluster_identifier: unique identifier of a cluster
         :type cluster_identifier: str
-        :param skip_final_cluster_snapshot: determines if a final cluster snapshot is made before shut-down
+        :param skip_final_cluster_snapshot: determines cluster snapshot creation
         :type skip_final_cluster_snapshot: bool
         :param final_cluster_snapshot_identifier: name of final cluster snapshot
         :type final_cluster_snapshot_identifier: str
         """
         response = self.get_conn().delete_cluster(
-            ClusterIdentifier = cluster_identifier,
-            SkipFinalClusterSnapshot = skip_final_cluster_snapshot,
-            FinalClusterSnapshotIdentifier = final_cluster_snapshot_identifier
+            ClusterIdentifier=cluster_identifier,
+            SkipFinalClusterSnapshot=skip_final_cluster_snapshot,
+            FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier
         )
         return response['Cluster'] if response['Cluster'] else None
 
@@ -56,11 +64,11 @@ class RedshiftHook(AwsHook):
         """
         Gets a list of snapshots for a cluster
 
-        :param cluster_identifier: unique identifier of a cluster whose properties you are requesting
+        :param cluster_identifier: unique identifier of a cluster
         :type cluster_identifier: str
         """
         response = self.get_conn().describe_cluster_snapshots(
-            ClusterIdentifier = cluster_identifier
+            ClusterIdentifier=cluster_identifier
         )
         if 'Snapshots' not in response:
             return None
@@ -73,14 +81,14 @@ class RedshiftHook(AwsHook):
         """
         Restores a cluster from its snapshot
 
-        :param cluster_identifier: unique identifier of a cluster whose properties you are requesting
+        :param cluster_identifier: unique identifier of a cluster
         :type cluster_identifier: str
         :param snapshot_identifier: unique identifier for a snapshot of a cluster
         :type snapshot_identifier: str
         """
         response = self.get_conn().restore_from_cluster_snapshot(
-            ClusterIdentifier = cluster_identifier,
-            SnapshotIdentifier = snapshot_identifier
+            ClusterIdentifier=cluster_identifier,
+            SnapshotIdentifier=snapshot_identifier
         )
         return response['Cluster'] if response['Cluster'] else None
 
@@ -90,7 +98,7 @@ class RedshiftHook(AwsHook):
 
         :param snapshot_identifier: unique identifier for a snapshot of a cluster
         :type snapshot_identifier: str
-        :param cluster_identifier: unique identifier of a cluster whose properties you are requesting
+        :param cluster_identifier: unique identifier of a cluster
         :type cluster_identifier: str
         """
         response = self.get_conn().create_cluster_snapshot(

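For reference, a minimal sketch of how the updated cluster_status behaviour
might be used from a script or operator (the connection id and cluster
identifier below are illustrative assumptions, not taken from this commit):

    from airflow.contrib.hooks.redshift_hook import RedshiftHook

    # 'aws_default' and 'my_cluster' are placeholder values for illustration.
    hook = RedshiftHook(aws_conn_id='aws_default')
    status = hook.cluster_status('my_cluster')

    if status == 'cluster_not_found':
        # With this change the hook returns the sentinel string
        # 'cluster_not_found' instead of letting ClusterNotFoundFault propagate.
        print('cluster does not exist')
    elif status == 'available':
        print('cluster is available')
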
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4936a807/airflow/contrib/sensors/aws_redshift_cluster_sensor.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/sensors/aws_redshift_cluster_sensor.py b/airflow/contrib/sensors/aws_redshift_cluster_sensor.py
new file mode 100644
index 0000000..8db85e6
--- /dev/null
+++ b/airflow/contrib/sensors/aws_redshift_cluster_sensor.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from airflow.operators.sensors import BaseSensorOperator
+from airflow.contrib.hooks.redshift_hook import RedshiftHook
+from airflow.utils.decorators import apply_defaults
+
+
+class AwsRedshiftClusterSensor(BaseSensorOperator):
+    """
+    Waits for a Redshift cluster to reach a specific status.
+
+    :param cluster_identifier: The identifier for the cluster being pinged.
+    :type cluster_identifier: str
+    :param target_status: The cluster status desired.
+    :type target_status: str
+    """
+    template_fields = ('cluster_identifier', 'target_status')
+
+    @apply_defaults
+    def __init__(
+            self, cluster_identifier,
+            target_status='available',
+            aws_conn_id='aws_default',
+            *args, **kwargs):
+        super(AwsRedshiftClusterSensor, self).__init__(*args, **kwargs)
+        self.cluster_identifier = cluster_identifier
+        self.target_status = target_status
+        self.aws_conn_id = aws_conn_id
+
+    def poke(self, context):
+        self.log.info('Poking for status : {self.target_status}\n'
+                      'for cluster {self.cluster_identifier}'.format(**locals()))
+        hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
+        return hook.cluster_status(self.cluster_identifier) == self.target_status

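A minimal sketch of wiring the new sensor into a DAG (the dag id, schedule,
task id and cluster identifier below are illustrative assumptions, not part
of this commit):

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.sensors.aws_redshift_cluster_sensor import AwsRedshiftClusterSensor

    # Hypothetical DAG; only the sensor arguments mirror the code added above.
    dag = DAG('example_redshift_cluster_check',
              start_date=datetime(2017, 12, 1),
              schedule_interval=None)

    wait_for_cluster = AwsRedshiftClusterSensor(
        task_id='wait_for_redshift_cluster',
        cluster_identifier='my_cluster',   # placeholder cluster identifier
        target_status='available',
        aws_conn_id='aws_default',
        poke_interval=60,
        timeout=60 * 30,
        dag=dag)

The sensor's poke() simply compares RedshiftHook.cluster_status() against
target_status, so it can also be pointed at 'cluster_not_found' to wait for a
cluster to disappear, as the new test_poke_cluster_not_found test exercises.
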
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4936a807/tests/contrib/hooks/test_redshift_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_redshift_hook.py b/tests/contrib/hooks/test_redshift_hook.py
index 185be5e..c7884a3 100644
--- a/tests/contrib/hooks/test_redshift_hook.py
+++ b/tests/contrib/hooks/test_redshift_hook.py
@@ -25,6 +25,7 @@ try:
 except ImportError:
     mock_redshift = None
 
+
 @mock_redshift
 class TestRedshiftHook(unittest.TestCase):
     def setUp(self):
@@ -56,8 +57,12 @@ class TestRedshiftHook(unittest.TestCase):
     @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
     def test_restore_from_cluster_snapshot_returns_dict_with_cluster_data(self):
         hook = RedshiftHook(aws_conn_id='aws_default')
-        snapshot = hook.create_cluster_snapshot('test_snapshot', 'test_cluster')
-        self.assertEqual(hook.restore_from_cluster_snapshot('test_cluster_3', 'test_snapshot')['ClusterIdentifier'], 'test_cluster_3')
+        hook.create_cluster_snapshot('test_snapshot', 'test_cluster')
+        self.assertEqual(
+            hook.restore_from_cluster_snapshot(
+                'test_cluster_3', 'test_snapshot'
+            )['ClusterIdentifier'],
+            'test_cluster_3')
 
     @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
     def test_delete_cluster_returns_a_dict_with_cluster_data(self):
@@ -73,5 +78,18 @@ class TestRedshiftHook(unittest.TestCase):
         snapshot = hook.create_cluster_snapshot('test_snapshot_2', 'test_cluster')
         self.assertNotEqual(snapshot, None)
 
+    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
+    def test_cluster_status_returns_cluster_not_found(self):
+        hook = RedshiftHook(aws_conn_id='aws_default')
+        status = hook.cluster_status('test_cluster_not_here')
+        self.assertEqual(status, 'cluster_not_found')
+
+    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
+    def test_cluster_status_returns_available_cluster(self):
+        hook = RedshiftHook(aws_conn_id='aws_default')
+        status = hook.cluster_status('test_cluster')
+        self.assertEqual(status, 'available')
+
+
 if __name__ == '__main__':
     unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4936a807/tests/contrib/sensors/test_aws_redshift_cluster_sensor.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/test_aws_redshift_cluster_sensor.py b/tests/contrib/sensors/test_aws_redshift_cluster_sensor.py
new file mode 100644
index 0000000..a5c9e66
--- /dev/null
+++ b/tests/contrib/sensors/test_aws_redshift_cluster_sensor.py
@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import boto3
+
+from airflow import configuration
+from airflow.contrib.sensors.aws_redshift_cluster_sensor import AwsRedshiftClusterSensor
+
+try:
+    from moto import mock_redshift
+except ImportError:
+    mock_redshift = None
+
+
+@mock_redshift
+class TestAwsRedshiftClusterSensor(unittest.TestCase):
+    def setUp(self):
+        configuration.load_test_config()
+        client = boto3.client('redshift', region_name='us-east-1')
+        client.create_cluster(
+            ClusterIdentifier='test_cluster',
+            NodeType='dc1.large',
+            MasterUsername='admin',
+            MasterUserPassword='mock_password'
+        )
+        if len(client.describe_clusters()['Clusters']) == 0:
+            raise ValueError('AWS not properly mocked')
+
+    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
+    def test_poke(self):
+        op = AwsRedshiftClusterSensor(task_id='test_cluster_sensor',
+                                      poke_interval=1,
+                                      timeout=5,
+                                      aws_conn_id='aws_default',
+                                      cluster_identifier='test_cluster',
+                                      target_status='available')
+        self.assertTrue(op.poke(None))
+
+    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
+    def test_poke_false(self):
+        op = AwsRedshiftClusterSensor(task_id='test_cluster_sensor',
+                                      poke_interval=1,
+                                      timeout=5,
+                                      aws_conn_id='aws_default',
+                                      cluster_identifier='test_cluster_not_found',
+                                      target_status='available')
+
+        self.assertFalse(op.poke(None))
+
+    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
+    def test_poke_cluster_not_found(self):
+        op = AwsRedshiftClusterSensor(task_id='test_cluster_sensor',
+                                      poke_interval=1,
+                                      timeout=5,
+                                      aws_conn_id='aws_default',
+                                      cluster_identifier='test_cluster_not_found',
+                                      target_status='cluster_not_found')
+
+        self.assertTrue(op.poke(None))
+
+
+if __name__ == '__main__':
+    unittest.main()