You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by bo...@apache.org on 2017/11/26 18:23:48 UTC

incubator-airflow git commit: [AIRFLOW-1790] Add support for AWS Batch operator

Repository: incubator-airflow
Updated Branches:
  refs/heads/master 2728cde34 -> 68d3a80dc


[AIRFLOW-1790] Add support for AWS Batch operator

Closes #2762 from hprudent/aws-batch


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/68d3a80d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/68d3a80d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/68d3a80d

Branch: refs/heads/master
Commit: 68d3a80dcb427bcb9e3bee354fbff3d0379d9c1c
Parents: 2728cde
Author: Hugo Prudente <hp...@amazon.com>
Authored: Sun Nov 26 19:23:26 2017 +0100
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Sun Nov 26 19:23:29 2017 +0100

----------------------------------------------------------------------
 airflow/contrib/operators/awsbatch_operator.py  | 162 ++++++++++++++
 docs/integration.rst                            |  13 ++
 .../contrib/operators/test_awsbatch_operator.py | 214 +++++++++++++++++++
 3 files changed, 389 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/68d3a80d/airflow/contrib/operators/awsbatch_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/awsbatch_operator.py b/airflow/contrib/operators/awsbatch_operator.py
new file mode 100644
index 0000000..25262dd
--- /dev/null
+++ b/airflow/contrib/operators/awsbatch_operator.py
@@ -0,0 +1,162 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import sys
+
+from math import pow
+from time import sleep
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.utils import apply_defaults
+
+from airflow.contrib.hooks.aws_hook import AwsHook
+
+
+class AWSBatchOperator(BaseOperator):
+    """
+    Execute a job on AWS Batch Service and wait until it has finished.
+
+    :param job_name: the name for the job that will run on AWS Batch
+    :type job_name: str
+    :param job_definition: the job definition name on AWS Batch
+    :type job_definition: str
+    :param queue: the queue name on AWS Batch
+    :type queue: str
+    :param overrides: the same parameter that boto3 will receive on containerOverrides:
+            http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job
+    :type overrides: dict
+    :param max_retries: maximum number of status polls, with a growing delay between
+            them, used when botocore has no ``job_execution_complete`` waiter
+    :type max_retries: int
+    :param aws_conn_id: connection id of AWS credentials / region name. If None,
+            credential boto3 strategy will be used (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
+    :type aws_conn_id: str
+    :param region_name: region name to use in AWS Hook. Override the region_name in connection (if provided)
+    :type region_name: str
+    """
+
+    ui_color = '#c3dae0'
+    client = None
+    arn = None
+    template_fields = ('overrides',)
+
+    @apply_defaults
+    def __init__(self, job_name, job_definition, queue, overrides, max_retries=288,
+                 aws_conn_id=None, region_name=None, **kwargs):
+        super(AWSBatchOperator, self).__init__(**kwargs)
+
+        self.job_name = job_name
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.job_definition = job_definition
+        self.queue = queue
+        self.overrides = overrides
+        self.max_retries = max_retries
+
+        # Filled in by execute() from the submit_job response.
+        self.jobId = None
+        self.jobName = None
+        self.hook = self.get_hook()
+
+    def execute(self, context):
+        """
+        Submit the job, block until it ends and fail unless it succeeded.
+
+        :raises AirflowException: if submitting, waiting or the final check fails
+        """
+        self.log.info(
+            'Running AWS Batch Job - Job definition: %s - on queue %s',
+            self.job_definition, self.queue
+        )
+        self.log.info('AWSBatchOperator overrides: %s', self.overrides)
+
+        self.client = self.hook.get_client_type('batch', region_name=self.region_name)
+
+        try:
+            response = self.client.submit_job(
+                jobName=self.job_name,
+                jobQueue=self.queue,
+                jobDefinition=self.job_definition,
+                containerOverrides=self.overrides)
+            self.log.info('AWS Batch Job started: %s', response)
+
+            self.jobId = response['jobId']
+            self.jobName = response['jobName']
+
+            self._wait_for_task_ended()
+            self._check_success_task()
+
+            self.log.info('AWS Batch Job has been successfully executed: %s', response)
+        except Exception as e:
+            self.log.info('AWS Batch Job has failed executed')
+            raise AirflowException(e)
+
+    def _wait_for_task_ended(self):
+        """
+        Block until the submitted job reaches a terminal state.
+
+        Uses the waiter proposed in https://github.com/boto/botocore/pull/1307
+        when botocore provides it. Otherwise polls describe_jobs at most
+        ``max_retries`` times with a growing delay, in the spirit of
+        docs.aws.amazon.com/general/latest/gr/api-retries.html
+        """
+        try:
+            waiter = self.client.get_waiter('job_execution_complete')
+            waiter.config.max_attempts = sys.maxsize  # timeout is managed by airflow
+            waiter.wait(jobs=[self.jobId])
+        except ValueError:
+            # botocore raises ValueError for an unknown waiter name: poll instead.
+            for retries in range(self.max_retries):
+                response = self.client.describe_jobs(jobs=[self.jobId])
+                jobs = response.get('jobs')
+                if jobs and jobs[-1]['status'] in ('SUCCEEDED', 'FAILED'):
+                    return
+                # Delay grows quadratically from 1s instead of doubling from 100s.
+                sleep(1 + pow(retries * 0.1, 2))
+
+    def _check_success_task(self):
+        """
+        Raise unless AWS Batch reports the submitted job as SUCCEEDED.
+
+        The job status decides the outcome: an attempt may fail and be retried
+        by AWS Batch before the job as a whole succeeds.
+
+        :raises AirflowException: if the job is unknown, FAILED or not finished
+        """
+        response = self.client.describe_jobs(
+            jobs=[self.jobId],
+        )
+
+        self.log.info('AWS Batch stopped, check status: %s', response)
+        if not response.get('jobs'):
+            raise AirflowException('No job found for {}'.format(response))
+
+        for job in response['jobs']:
+            if job['status'] == 'FAILED':
+                raise AirflowException('This containers encounter an error during execution {}'.format(job))
+            elif job['status'] != 'SUCCEEDED':
+                raise AirflowException('This task is still pending {}'.format(job['status']))
+
+    def get_hook(self):
+        """Return the AwsHook for ``aws_conn_id``."""
+        return AwsHook(aws_conn_id=self.aws_conn_id)
+
+    def on_kill(self):
+        """Terminate the submitted AWS Batch job when the task is killed."""
+        if self.client is None or self.jobId is None:
+            # Killed before execute() submitted a job: nothing to terminate.
+            return
+        response = self.client.terminate_job(jobId=self.jobId, reason='Task killed by the user')
+        self.log.info(response)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/68d3a80d/docs/integration.rst
----------------------------------------------------------------------
diff --git a/docs/integration.rst b/docs/integration.rst
index 4887486..20f0583 100644
--- a/docs/integration.rst
+++ b/docs/integration.rst
@@ -145,6 +145,19 @@ ECSOperator
 .. autoclass:: airflow.contrib.operators.ecs_operator.ECSOperator
 
 
+AWS Batch Service
+''''''''''''''''''''''''''
+
+- :ref:`AWSBatchOperator` : Execute a task on AWS Batch Service.
+
+.. _AWSBatchOperator:
+
+AWSBatchOperator
+""""""""""""""""
+
+.. autoclass:: airflow.contrib.operators.awsbatch_operator.AWSBatchOperator
+
+
 AWS RedShift
 '''''''''''''
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/68d3a80d/tests/contrib/operators/test_awsbatch_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_awsbatch_operator.py b/tests/contrib/operators/test_awsbatch_operator.py
new file mode 100644
index 0000000..d19daa5
--- /dev/null
+++ b/tests/contrib/operators/test_awsbatch_operator.py
@@ -0,0 +1,214 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+import unittest
+
+from airflow import configuration
+from airflow.exceptions import AirflowException
+from airflow.contrib.operators.awsbatch_operator import AWSBatchOperator
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+# Canned value returned by the mocked submit_job in test_execute_without_failures.
+RESPONSE_WITHOUT_FAILURES = {
+    "jobName": "51455483-c62c-48ac-9b88-53a6a725baa3",
+    "jobId": "8ba9d676-4108-4474-9dca-8bbac1da9b19"
+}
+
+
+class TestAWSBatchOperator(unittest.TestCase):
+    """Tests for AWSBatchOperator driven through a mocked AwsHook and Batch client."""
+
+    @mock.patch('airflow.contrib.operators.awsbatch_operator.AwsHook')
+    def setUp(self, aws_hook_mock):
+        # AwsHook is only patched while the operator is constructed; the tests
+        # reach the mocked hook afterwards through self.aws_hook_mock.
+        configuration.load_test_config()
+
+        self.aws_hook_mock = aws_hook_mock
+        self.batch = AWSBatchOperator(
+            task_id='task',
+            job_name='51455483-c62c-48ac-9b88-53a6a725baa3',
+            queue='queue',
+            job_definition='hello-world',
+            max_retries=5,
+            overrides={},
+            aws_conn_id=None,
+            region_name='eu-west-1')
+
+    def test_init(self):
+        """Constructor arguments are stored and the hook is created exactly once."""
+        self.assertEqual(self.batch.job_name, '51455483-c62c-48ac-9b88-53a6a725baa3')
+        self.assertEqual(self.batch.queue, 'queue')
+        self.assertEqual(self.batch.job_definition, 'hello-world')
+        self.assertEqual(self.batch.max_retries, 5)
+        self.assertEqual(self.batch.overrides, {})
+        self.assertEqual(self.batch.region_name, 'eu-west-1')
+        self.assertIsNone(self.batch.aws_conn_id)
+        self.assertEqual(self.batch.hook, self.aws_hook_mock.return_value)
+
+        self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)
+
+    def test_template_fields_overrides(self):
+        self.assertEqual(self.batch.template_fields, ('overrides',))
+
+    @mock.patch.object(AWSBatchOperator, '_wait_for_task_ended')
+    @mock.patch.object(AWSBatchOperator, '_check_success_task')
+    def test_execute_without_failures(self, check_mock, wait_mock):
+        """A successful submit is followed by one wait and one success check."""
+        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
+        client_mock.submit_job.return_value = RESPONSE_WITHOUT_FAILURES
+
+        self.batch.execute(None)
+
+        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('batch', region_name='eu-west-1')
+        client_mock.submit_job.assert_called_once_with(
+            jobQueue='queue',
+            jobName='51455483-c62c-48ac-9b88-53a6a725baa3',
+            containerOverrides={},
+            jobDefinition='hello-world'
+        )
+
+        wait_mock.assert_called_once_with()
+        check_mock.assert_called_once_with()
+        self.assertEqual(self.batch.jobId, '8ba9d676-4108-4474-9dca-8bbac1da9b19')
+
+    def test_execute_with_failures(self):
+        """An unusable submit_job response surfaces as an AirflowException."""
+        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
+        client_mock.submit_job.return_value = ""
+
+        with self.assertRaises(AirflowException):
+            self.batch.execute(None)
+
+        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('batch', region_name='eu-west-1')
+        client_mock.submit_job.assert_called_once_with(
+            jobQueue='queue',
+            jobName='51455483-c62c-48ac-9b88-53a6a725baa3',
+            containerOverrides={},
+            jobDefinition='hello-world'
+        )
+
+    def test_wait_end_tasks(self):
+        """The botocore waiter is used, with an unbounded number of attempts."""
+        client_mock = mock.Mock()
+        self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
+        self.batch.client = client_mock
+
+        self.batch._wait_for_task_ended()
+
+        client_mock.get_waiter.assert_called_once_with('job_execution_complete')
+        client_mock.get_waiter.return_value.wait.assert_called_once_with(
+            jobs=['8ba9d676-4108-4474-9dca-8bbac1da9b19']
+        )
+        self.assertEqual(sys.maxsize, client_mock.get_waiter.return_value.config.max_attempts)
+
+    def test_check_success_tasks_raises(self):
+        """An empty describe_jobs result is reported as a missing job."""
+        client_mock = mock.Mock()
+        self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
+        self.batch.client = client_mock
+
+        client_mock.describe_jobs.return_value = {'jobs': []}
+
+        with self.assertRaises(AirflowException) as e:
+            self.batch._check_success_task()
+
+        # The message embeds str(response); only its fixed prefix is checked.
+        self.assertIn('No job found for ', str(e.exception))
+
+    def test_check_success_tasks_raises_failed(self):
+        """A FAILED job with a single failed attempt raises."""
+        client_mock = mock.Mock()
+        self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
+        self.batch.client = client_mock
+
+        client_mock.describe_jobs.return_value = {
+            'jobs': [{
+                'status': 'FAILED',
+                'attempts': [{'container': {'exitCode': 1}}]
+            }]
+        }
+
+        with self.assertRaises(AirflowException) as e:
+            self.batch._check_success_task()
+
+        # The message embeds str(job); only its fixed prefix is checked.
+        self.assertIn('This containers encounter an error during execution ', str(e.exception))
+
+    def test_check_success_tasks_raises_pending(self):
+        """A job that has not reached a terminal state raises."""
+        client_mock = mock.Mock()
+        self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
+        self.batch.client = client_mock
+
+        client_mock.describe_jobs.return_value = {
+            'jobs': [{'status': 'RUNNABLE'}]
+        }
+
+        with self.assertRaises(AirflowException) as e:
+            self.batch._check_success_task()
+
+        # The job status is appended to this fixed prefix.
+        self.assertIn('This task is still pending ', str(e.exception))
+
+    def test_check_success_tasks_raises_multiple(self):
+        """A FAILED job with several failed attempts raises."""
+        client_mock = mock.Mock()
+        self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
+        self.batch.client = client_mock
+
+        client_mock.describe_jobs.return_value = {
+            'jobs': [{
+                'status': 'FAILED',
+                'attempts': [
+                    {'container': {'exitCode': 1}},
+                    {'container': {'exitCode': 10}},
+                ]
+            }]
+        }
+
+        with self.assertRaises(AirflowException) as e:
+            self.batch._check_success_task()
+
+        # The message embeds str(job); only its fixed prefix is checked.
+        self.assertIn('This containers encounter an error during execution ', str(e.exception))
+
+    def test_check_success_task_not_raises(self):
+        """A SUCCEEDED job passes the check after a single describe_jobs call."""
+        client_mock = mock.Mock()
+        self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
+        self.batch.client = client_mock
+
+        client_mock.describe_jobs.return_value = {
+            'jobs': [{
+                'status': 'SUCCEEDED'
+            }]
+        }
+
+        self.batch._check_success_task()
+
+        client_mock.describe_jobs.assert_called_once_with(jobs=['8ba9d676-4108-4474-9dca-8bbac1da9b19'])
+
+# Run the tests in this module when the file is executed directly.
+if __name__ == '__main__':
+    unittest.main()