You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by bo...@apache.org on 2017/11/26 18:23:48 UTC
incubator-airflow git commit: [AIRFLOW-1790] Add support for AWS
Batch operator
Repository: incubator-airflow
Updated Branches:
refs/heads/master 2728cde34 -> 68d3a80dc
[AIRFLOW-1790] Add support for AWS Batch operator
Closes #2762 from hprudent/aws-batch
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/68d3a80d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/68d3a80d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/68d3a80d
Branch: refs/heads/master
Commit: 68d3a80dcb427bcb9e3bee354fbff3d0379d9c1c
Parents: 2728cde
Author: Hugo Prudente <hp...@amazon.com>
Authored: Sun Nov 26 19:23:26 2017 +0100
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Sun Nov 26 19:23:29 2017 +0100
----------------------------------------------------------------------
airflow/contrib/operators/awsbatch_operator.py | 162 ++++++++++++++
docs/integration.rst | 13 ++
.../contrib/operators/test_awsbatch_operator.py | 214 +++++++++++++++++++
3 files changed, 389 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/68d3a80d/airflow/contrib/operators/awsbatch_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/awsbatch_operator.py b/airflow/contrib/operators/awsbatch_operator.py
new file mode 100644
index 0000000..25262dd
--- /dev/null
+++ b/airflow/contrib/operators/awsbatch_operator.py
@@ -0,0 +1,162 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import sys
+
+from math import pow
+from time import sleep
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.utils import apply_defaults
+
+from airflow.contrib.hooks.aws_hook import AwsHook
+
+
+class AWSBatchOperator(BaseOperator):
+    """
+    Execute a job on AWS Batch Service.
+
+    The job is submitted through the boto3 ``batch`` client and the operator
+    then blocks until the job reaches a terminal state. An
+    ``AirflowException`` is raised unless that state is ``SUCCEEDED``.
+
+    :param job_name: the name for the job that will run on AWS Batch
+    :type job_name: str
+    :param job_definition: the job definition name on AWS Batch
+    :type job_definition: str
+    :param queue: the queue name on AWS Batch
+    :type queue: str
+    :param overrides: the same parameter that boto3 will receive on containerOverrides:
+        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job
+    :type overrides: dict
+    :param max_retries: maximum number of status polls performed by the
+        exponential backoff fallback, used only while the boto3 waiter is not available
+    :type max_retries: int
+    :param aws_conn_id: connection id of AWS credentials / region name. If None,
+        credential boto3 strategy will be used (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
+    :type aws_conn_id: str
+    :param region_name: region name to use in AWS Hook. Override the region_name in connection (if provided)
+    :type region_name: str
+    """
+
+    ui_color = '#c3dae0'
+    client = None
+    arn = None
+    template_fields = ('overrides',)
+
+    # Backoff settings for the polling fallback. The delay starts at 100 ms
+    # (the base suggested by the AWS retry guidance referenced in
+    # ``_wait_for_task_ended``) and doubles up to a 5 minute ceiling, so the
+    # default of 288 polls covers roughly one day.
+    BACKOFF_BASE_SECONDS = 0.1
+    MAX_BACKOFF_SECONDS = 300
+
+    # Job states after which AWS Batch will not change the job any more.
+    TERMINAL_STATES = ('SUCCEEDED', 'FAILED')
+
+    @apply_defaults
+    def __init__(self, job_name, job_definition, queue, overrides, max_retries=288,
+                 aws_conn_id=None, region_name=None, **kwargs):
+        super(AWSBatchOperator, self).__init__(**kwargs)
+
+        self.job_name = job_name
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.job_definition = job_definition
+        self.queue = queue
+        self.overrides = overrides
+        self.max_retries = max_retries
+
+        # Filled in from the submit_job response once the job is submitted.
+        self.jobId = None
+        self.jobName = None
+
+        self.hook = self.get_hook()
+
+    def execute(self, context):
+        """
+        Submit the job, wait for it to finish and verify that it succeeded.
+
+        :param context: Airflow task context (not used by this operator)
+        :raises AirflowException: if submission, waiting or the final status
+            check fails; the underlying error is wrapped
+        """
+        self.log.info(
+            'Running AWS Batch Job - Job definition: %s - on queue %s',
+            self.job_definition, self.queue
+        )
+        self.log.info('AWSBatchOperator overrides: %s', self.overrides)
+
+        self.client = self.hook.get_client_type(
+            'batch',
+            region_name=self.region_name
+        )
+
+        try:
+            response = self.client.submit_job(
+                jobName=self.job_name,
+                jobQueue=self.queue,
+                jobDefinition=self.job_definition,
+                containerOverrides=self.overrides)
+
+            self.log.info('AWS Batch Job started: %s', response)
+
+            self.jobId = response['jobId']
+            self.jobName = response['jobName']
+
+            self._wait_for_task_ended()
+
+            self._check_success_task()
+
+            self.log.info('AWS Batch Job has been successfully executed: %s', response)
+        except Exception as e:
+            # Every failure surfaces as AirflowException so the task is
+            # marked failed; log it at error level so it is not missed.
+            self.log.error('AWS Batch Job has failed executed')
+            raise AirflowException(e)
+
+    def _backoff_delay(self, retries):
+        """
+        Return the number of seconds to sleep before poll number ``retries + 1``.
+
+        The delay doubles on every retry and is capped at
+        ``MAX_BACKOFF_SECONDS``. The exponent is clamped because ``math.pow``
+        raises ``OverflowError`` for very large exponents.
+        """
+        exponent = min(retries, 32)
+        return min(self.MAX_BACKOFF_SECONDS,
+                   self.BACKOFF_BASE_SECONDS * pow(2, exponent))
+
+    def _wait_for_task_ended(self):
+        """
+        Block until the submitted job reaches a terminal state.
+
+        Try to use a waiter from the below pull request
+
+            * https://github.com/boto/botocore/pull/1307
+
+        If the waiter is not available apply a capped exponential backoff
+
+            * docs.aws.amazon.com/general/latest/gr/api-retries.html
+
+        When ``max_retries`` polls pass without the job finishing, the method
+        returns anyway; ``_check_success_task`` then reports the job as
+        still pending.
+        """
+        try:
+            waiter = self.client.get_waiter('job_execution_complete')
+            waiter.config.max_attempts = sys.maxsize  # timeout is managed by airflow
+            waiter.wait(jobs=[self.jobId])
+        except ValueError:
+            # The waiter is not known to the installed botocore: poll
+            # describe_jobs ourselves, at most max_retries times.
+            for retries in range(self.max_retries):
+                response = self.client.describe_jobs(
+                    jobs=[self.jobId]
+                )
+                jobs = response.get('jobs') or []
+                if jobs and jobs[-1]['status'] in self.TERMINAL_STATES:
+                    # Stop immediately; there is nothing left to wait for.
+                    return
+
+                sleep(self._backoff_delay(retries))
+
+    def _check_success_task(self):
+        """
+        Verify that the job finished with status ``SUCCEEDED``.
+
+        The job level status is treated as authoritative: a job that AWS
+        Batch retried may list earlier failed attempts and still succeed.
+
+        :raises AirflowException: if no job is returned, if the job is
+            ``FAILED``, or if it is in any state other than ``SUCCEEDED``
+        """
+        response = self.client.describe_jobs(
+            jobs=[self.jobId],
+        )
+
+        self.log.info('AWS Batch stopped, check status: %s', response)
+        jobs = response.get('jobs')
+        if not jobs:
+            raise AirflowException('No job found for {}'.format(response))
+
+        for job in jobs:
+            status = job['status']
+            if status == 'FAILED':
+                raise AirflowException('This containers encounter an error during execution {}'.format(job))
+            elif status != 'SUCCEEDED':
+                raise AirflowException('This task is still pending {}'.format(status))
+
+    def get_hook(self):
+        """Return the AwsHook built from ``aws_conn_id``."""
+        return AwsHook(
+            aws_conn_id=self.aws_conn_id
+        )
+
+    def on_kill(self):
+        """Terminate the AWS Batch job when the Airflow task is killed."""
+        if self.client is None or self.jobId is None:
+            # Killed before a job was submitted: nothing to terminate.
+            self.log.info('No AWS Batch job has been submitted, nothing to terminate')
+            return
+
+        response = self.client.terminate_job(
+            jobId=self.jobId,
+            reason='Task killed by the user')
+
+        self.log.info(response)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/68d3a80d/docs/integration.rst
----------------------------------------------------------------------
diff --git a/docs/integration.rst b/docs/integration.rst
index 4887486..20f0583 100644
--- a/docs/integration.rst
+++ b/docs/integration.rst
@@ -145,6 +145,19 @@ ECSOperator
.. autoclass:: airflow.contrib.operators.ecs_operator.ECSOperator
+AWS Batch Service
+''''''''''''''''''''''''''
+
+- :ref:`AWSBatchOperator` : Execute a task on AWS Batch Service.
+
+.. _AWSBatchOperator:
+
+AWSBatchOperator
+""""""""""""""""
+
+.. autoclass:: airflow.contrib.operators.awsbatch_operator.AWSBatchOperator
+
+
AWS RedShift
'''''''''''''
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/68d3a80d/tests/contrib/operators/test_awsbatch_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_awsbatch_operator.py b/tests/contrib/operators/test_awsbatch_operator.py
new file mode 100644
index 0000000..d19daa5
--- /dev/null
+++ b/tests/contrib/operators/test_awsbatch_operator.py
@@ -0,0 +1,214 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+import unittest
+
+from airflow import configuration
+from airflow.exceptions import AirflowException
+from airflow.contrib.operators.awsbatch_operator import AWSBatchOperator
+
+try:
+ from unittest import mock
+except ImportError:
+ try:
+ import mock
+ except ImportError:
+ mock = None
+
+
+# Canned ``submit_job`` response returned by the mocked batch client in
+# test_execute_without_failures; only ``jobName`` and ``jobId`` are read
+# by the operator.
+RESPONSE_WITHOUT_FAILURES = {
+ "jobName": "51455483-c62c-48ac-9b88-53a6a725baa3",
+ "jobId": "8ba9d676-4108-4474-9dca-8bbac1da9b19"
+}
+
+
+class TestAWSBatchOperator(unittest.TestCase):
+    """Unit tests for AWSBatchOperator; AwsHook and the batch client are mocked."""
+
+    @mock.patch('airflow.contrib.operators.awsbatch_operator.AwsHook')
+    def setUp(self, aws_hook_mock):
+        """Build an operator whose hook is the patched AwsHook mock."""
+        configuration.load_test_config()
+
+        self.aws_hook_mock = aws_hook_mock
+        self.batch = AWSBatchOperator(
+            task_id='task',
+            job_name='51455483-c62c-48ac-9b88-53a6a725baa3',
+            queue='queue',
+            job_definition='hello-world',
+            max_retries=5,
+            overrides={},
+            aws_conn_id=None,
+            region_name='eu-west-1')
+
+    def test_init(self):
+        """Constructor arguments are stored and the hook is created once."""
+        self.assertEqual(self.batch.job_name, '51455483-c62c-48ac-9b88-53a6a725baa3')
+        self.assertEqual(self.batch.queue, 'queue')
+        self.assertEqual(self.batch.job_definition, 'hello-world')
+        self.assertEqual(self.batch.max_retries, 5)
+        self.assertEqual(self.batch.overrides, {})
+        self.assertEqual(self.batch.region_name, 'eu-west-1')
+        self.assertIsNone(self.batch.aws_conn_id)
+        self.assertEqual(self.batch.hook, self.aws_hook_mock.return_value)
+
+        self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)
+
+    def test_template_fields_overrides(self):
+        """Only ``overrides`` is exposed as a templated field."""
+        self.assertEqual(self.batch.template_fields, ('overrides',))
+
+    @mock.patch.object(AWSBatchOperator, '_wait_for_task_ended')
+    @mock.patch.object(AWSBatchOperator, '_check_success_task')
+    def test_execute_without_failures(self, check_mock, wait_mock):
+        """execute() submits the job, waits for it and checks the result."""
+        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
+        client_mock.submit_job.return_value = RESPONSE_WITHOUT_FAILURES
+
+        self.batch.execute(None)
+
+        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('batch', region_name='eu-west-1')
+        client_mock.submit_job.assert_called_once_with(
+            jobQueue='queue',
+            jobName='51455483-c62c-48ac-9b88-53a6a725baa3',
+            containerOverrides={},
+            jobDefinition='hello-world'
+        )
+
+        wait_mock.assert_called_once_with()
+        check_mock.assert_called_once_with()
+        self.assertEqual(self.batch.jobId, '8ba9d676-4108-4474-9dca-8bbac1da9b19')
+
+    def test_execute_with_failures(self):
+        """A malformed submit_job response surfaces as AirflowException."""
+        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
+        client_mock.submit_job.return_value = ""
+
+        with self.assertRaises(AirflowException):
+            self.batch.execute(None)
+
+        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('batch', region_name='eu-west-1')
+        client_mock.submit_job.assert_called_once_with(
+            jobQueue='queue',
+            jobName='51455483-c62c-48ac-9b88-53a6a725baa3',
+            containerOverrides={},
+            jobDefinition='hello-world'
+        )
+
+    def test_wait_end_tasks(self):
+        """When the waiter exists it is used, with an unbounded attempt count."""
+        client_mock = mock.Mock()
+        self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
+        self.batch.client = client_mock
+
+        self.batch._wait_for_task_ended()
+
+        client_mock.get_waiter.assert_called_once_with('job_execution_complete')
+        client_mock.get_waiter.return_value.wait.assert_called_once_with(
+            jobs=['8ba9d676-4108-4474-9dca-8bbac1da9b19']
+        )
+        self.assertEqual(sys.maxsize, client_mock.get_waiter.return_value.config.max_attempts)
+
+    def test_check_success_tasks_raises(self):
+        """An empty ``jobs`` list is reported as 'No job found'."""
+        client_mock = mock.Mock()
+        self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
+        self.batch.client = client_mock
+
+        client_mock.describe_jobs.return_value = {
+            'jobs': []
+        }
+
+        with self.assertRaises(Exception) as e:
+            self.batch._check_success_task()
+
+        # Ordering of str(dict) is not guaranteed, so only match the prefix.
+        self.assertIn('No job found for ', str(e.exception))
+
+    def test_check_success_tasks_raises_failed(self):
+        """A FAILED job with one attempt raises the execution error."""
+        client_mock = mock.Mock()
+        self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
+        self.batch.client = client_mock
+
+        # describe_jobs reports the exit code under attempts[].container.
+        client_mock.describe_jobs.return_value = {
+            'jobs': [{
+                'status': 'FAILED',
+                'attempts': [{
+                    'container': {'exitCode': 1}
+                }]
+            }]
+        }
+
+        with self.assertRaises(Exception) as e:
+            self.batch._check_success_task()
+
+        # Ordering of str(dict) is not guaranteed, so only match the prefix.
+        self.assertIn('This containers encounter an error during execution ', str(e.exception))
+
+    def test_check_success_tasks_raises_pending(self):
+        """A job that has not reached a terminal state is reported as pending."""
+        client_mock = mock.Mock()
+        self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
+        self.batch.client = client_mock
+
+        client_mock.describe_jobs.return_value = {
+            'jobs': [{
+                'status': 'RUNNABLE'
+            }]
+        }
+
+        with self.assertRaises(Exception) as e:
+            self.batch._check_success_task()
+
+        self.assertIn('This task is still pending ', str(e.exception))
+
+    def test_check_success_tasks_raises_multiple(self):
+        """A FAILED job with several attempts raises the execution error."""
+        client_mock = mock.Mock()
+        self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
+        self.batch.client = client_mock
+
+        client_mock.describe_jobs.return_value = {
+            'jobs': [{
+                'status': 'FAILED',
+                'attempts': [{
+                    'container': {'exitCode': 1}
+                }, {
+                    'container': {'exitCode': 10}
+                }]
+            }]
+        }
+
+        with self.assertRaises(Exception) as e:
+            self.batch._check_success_task()
+
+        # Ordering of str(dict) is not guaranteed, so only match the prefix.
+        self.assertIn('This containers encounter an error during execution ', str(e.exception))
+
+    def test_check_success_task_not_raises(self):
+        """A SUCCEEDED job passes the check after a single describe_jobs call."""
+        client_mock = mock.Mock()
+        self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
+        self.batch.client = client_mock
+
+        client_mock.describe_jobs.return_value = {
+            'jobs': [{
+                'status': 'SUCCEEDED'
+            }]
+        }
+
+        self.batch._check_success_task()
+
+        client_mock.describe_jobs.assert_called_once_with(jobs=['8ba9d676-4108-4474-9dca-8bbac1da9b19'])
+
+
+# Run the test suite when this module is executed directly as a script.
+if __name__ == '__main__':
+ unittest.main()