You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by sa...@apache.org on 2016/11/23 18:50:08 UTC
incubator-airflow git commit: [AIRFLOW-345] Add contrib ECSOperator
Repository: incubator-airflow
Updated Branches:
refs/heads/master 41490f9c4 -> 98197d956
[AIRFLOW-345] Add contrib ECSOperator
Closes #1894 from poulainv/ecs_operator
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/98197d95
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/98197d95
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/98197d95
Branch: refs/heads/master
Commit: 98197d95681abaae0ec8f928e0147a8b32132ecb
Parents: 41490f9
Author: Vincent Poulain <vi...@tinyclues.com>
Authored: Wed Nov 23 10:49:57 2016 -0800
Committer: Siddharth Anand <si...@yahoo.com>
Committed: Wed Nov 23 10:49:57 2016 -0800
----------------------------------------------------------------------
airflow/contrib/hooks/aws_hook.py | 27 +++-
airflow/contrib/operators/ecs_operator.py | 127 +++++++++++++++
docs/code.rst | 1 +
tests/contrib/operators/ecs_operator.py | 207 +++++++++++++++++++++++++
4 files changed, 356 insertions(+), 6 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98197d95/airflow/contrib/hooks/aws_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py
index 37a02ee..3eced28 100644
--- a/airflow/contrib/hooks/aws_hook.py
+++ b/airflow/contrib/hooks/aws_hook.py
@@ -12,24 +12,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
import boto3
+
+from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
class AwsHook(BaseHook):
    """
    Interact with AWS.

    This class is a thin wrapper around the boto3 python library. Credentials
    (connection login / password) and, optionally, the region (connection
    extra ``region_name``) are read from the Airflow connection
    ``aws_conn_id``. When ``aws_conn_id`` is None or the connection is not
    defined, credentials are left to boto3's own lookup strategy:
    http://boto3.readthedocs.io/en/latest/guide/configuration.html

    :param aws_conn_id: id of the Airflow connection holding AWS credentials
    :type aws_conn_id: str
    """

    def __init__(self, aws_conn_id='aws_default'):
        self.aws_conn_id = aws_conn_id

    def get_client_type(self, client_type, region_name=None):
        """
        Return a boto3 client for the given AWS service.

        :param client_type: boto3 service name, e.g. 'ecs'
        :type client_type: str
        :param region_name: AWS region; when provided it takes precedence over
            the ``region_name`` extra of the connection
        :type region_name: str
        """
        connection_object = None
        # An explicit None connection id means "use boto3 defaults": don't ask
        # the base hook to resolve an id that cannot exist.
        if self.aws_conn_id is not None:
            try:
                connection_object = self.get_connection(self.aws_conn_id)
            except AirflowException:
                # No connection found: fallback on boto3 credential strategy
                # http://boto3.readthedocs.io/en/latest/guide/configuration.html
                connection_object = None

        aws_access_key_id = None
        aws_secret_access_key = None
        if connection_object is not None:
            aws_access_key_id = connection_object.login
            aws_secret_access_key = connection_object.password
            if region_name is None:
                region_name = connection_object.extra_dejson.get('region_name')

        # Values still None here are resolved by boto3 itself.
        return boto3.client(
            client_type,
            region_name=region_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98197d95/airflow/contrib/operators/ecs_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/ecs_operator.py b/airflow/contrib/operators/ecs_operator.py
new file mode 100644
index 0000000..7415d32
--- /dev/null
+++ b/airflow/contrib/operators/ecs_operator.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import logging
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.utils import apply_defaults
+
+from airflow.contrib.hooks.aws_hook import AwsHook
+
+
class ECSOperator(BaseOperator):

    """
    Execute a task on AWS EC2 Container Service

    :param task_definition: the task definition name on EC2 Container Service
    :type task_definition: str
    :param cluster: the cluster name on EC2 Container Service
    :type cluster: str
    :param overrides: the same parameter that boto3 will receive:
        http://boto3.readthedocs.org/en/latest/reference/services/ecs.html#ECS.Client.run_task
    :type overrides: dict
    :param aws_conn_id: connection id of AWS credentials / region name. If None,
        credential boto3 strategy will be used
        (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
    :type aws_conn_id: str
    :param region_name: region name to use in AWS Hook. Overrides the
        region_name in connection (if provided)
    :type region_name: str
    """

    ui_color = '#f0ede4'
    # boto3 ECS client; created in execute()
    client = None
    # ARN of the task started by execute(); None until run_task succeeds
    arn = None
    template_fields = ('overrides',)

    @apply_defaults
    def __init__(self, task_definition, cluster, overrides,
                 aws_conn_id=None, region_name=None, **kwargs):
        super(ECSOperator, self).__init__(**kwargs)

        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        self.task_definition = task_definition
        self.cluster = cluster
        self.overrides = overrides

        self.hook = self.get_hook()

    def execute(self, context):
        """
        Start the ECS task, wait until it stops, then check it succeeded.

        :raises AirflowException: if run_task reports failures or the stopped
            task is not in a success state
        """
        logging.info('Running ECS Task - Task definition: {} - on cluster {}'.format(
            self.task_definition,
            self.cluster
        ))
        logging.info('ECSOperator overrides: {}'.format(self.overrides))

        self.client = self.hook.get_client_type(
            'ecs',
            region_name=self.region_name
        )

        response = self.client.run_task(
            cluster=self.cluster,
            taskDefinition=self.task_definition,
            overrides=self.overrides,
            startedBy=self.owner
        )

        failures = response['failures']
        if failures:
            raise AirflowException(response)
        logging.info('ECS Task started: {}'.format(response))

        self.arn = response['tasks'][0]['taskArn']
        self._wait_for_task_ended()

        self._check_success_task()
        logging.info('ECS Task has been successfully executed: {}'.format(response))

    def _wait_for_task_ended(self):
        """Block on the boto3 'tasks_stopped' waiter for the started task."""
        waiter = self.client.get_waiter('tasks_stopped')
        # Timeout is managed by airflow, so never give up on our side.
        # sys.maxsize exists on Python 2 and 3 (sys.maxint is Python 2 only).
        waiter.config.max_attempts = sys.maxsize
        waiter.wait(
            cluster=self.cluster,
            tasks=[self.arn]
        )

    def _check_success_task(self):
        """
        Describe the stopped task and raise unless every container succeeded.

        :raises AirflowException: on describe_tasks failures, a STOPPED
            container with a non-zero (or missing) exit code, a still PENDING
            container, or a container whose reason mentions an error
        """
        response = self.client.describe_tasks(
            cluster=self.cluster,
            tasks=[self.arn]
        )
        logging.info('ECS Task stopped, check status: {}'.format(response))

        if response.get('failures', []):
            raise AirflowException(response)

        for task in response['tasks']:
            for container in task['containers']:
                status = container.get('lastStatus')
                # A STOPPED container may carry no exitCode (e.g. it never
                # started); treat that as a failure instead of a KeyError.
                if status == 'STOPPED' and container.get('exitCode') != 0:
                    raise AirflowException('This task is not in success state {}'.format(task))
                elif status == 'PENDING':
                    raise AirflowException('This task is still pending {}'.format(task))
                elif 'error' in container.get('reason', '').lower():
                    raise AirflowException('This containers encounter an error during launching : {}'.format(container.get('reason', '').lower()))

    def get_hook(self):
        """Build the AwsHook used to create the ECS client."""
        return AwsHook(
            aws_conn_id=self.aws_conn_id
        )

    def on_kill(self):
        """Stop the running ECS task, if execute() managed to start one."""
        if self.client is None or self.arn is None:
            logging.info('No ECS Task to stop')
            return
        response = self.client.stop_task(
            cluster=self.cluster,
            task=self.arn,
            reason='Task killed by the user')
        logging.info(response)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98197d95/docs/code.rst
----------------------------------------------------------------------
diff --git a/docs/code.rst b/docs/code.rst
index 8548120..0e1993e 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -97,6 +97,7 @@ Community-contributed Operators
.. autoclass:: airflow.contrib.operators.bigquery_operator.BigQueryOperator
.. autoclass:: airflow.contrib.operators.bigquery_to_gcs.BigQueryToCloudStorageOperator
+.. autoclass:: airflow.contrib.operators.ecs_operator.ECSOperator
.. autoclass:: airflow.contrib.operators.gcs_download_operator.GoogleCloudStorageDownloadOperator
.. autoclass:: airflow.contrib.operators.QuboleOperator
.. autoclass:: airflow.contrib.operators.hipchat_operator.HipChatAPIOperator
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98197d95/tests/contrib/operators/ecs_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/ecs_operator.py b/tests/contrib/operators/ecs_operator.py
new file mode 100644
index 0000000..5a593a6
--- /dev/null
+++ b/tests/contrib/operators/ecs_operator.py
@@ -0,0 +1,207 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+import unittest
+from copy import deepcopy
+
+from airflow import configuration
+from airflow.exceptions import AirflowException
+from airflow.contrib.operators.ecs_operator import ECSOperator
+
+try:
+ from unittest import mock
+except ImportError:
+ try:
+ import mock
+ except ImportError:
+ mock = None
+
+
# Canned ECS ``run_task`` response used by the tests: an empty ``failures``
# list and a single PENDING task running one PENDING container.
RESPONSE_WITHOUT_FAILURES = dict(
    failures=[],
    tasks=[
        dict(
            containers=[
                dict(
                    containerArn="arn:aws:ecs:us-east-1:012345678910:container/e1ed7aac-d9b2-4315-8726-d2432bf11868",
                    lastStatus="PENDING",
                    name="wordpress",
                    taskArn="arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55",
                ),
            ],
            desiredStatus="RUNNING",
            lastStatus="PENDING",
            taskArn="arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55",
            taskDefinitionArn="arn:aws:ecs:us-east-1:012345678910:task-definition/hello_world:11",
        ),
    ],
)
+
+
class TestECSOperator(unittest.TestCase):
    """Unit tests for ECSOperator with the AWS hook and boto3 client mocked."""

    @mock.patch('airflow.contrib.operators.ecs_operator.AwsHook')
    def setUp(self, aws_hook_mock):
        configuration.load_test_config()

        # The patch is only active during setUp, which is enough: the hook is
        # created in ECSOperator.__init__.
        self.aws_hook_mock = aws_hook_mock
        self.ecs = ECSOperator(
            task_id='task',
            task_definition='t',
            cluster='c',
            overrides={},
            aws_conn_id=None,
            region_name='eu-west-1')

    def _mock_client(self):
        """Attach a fresh mock boto3 client and a fake task ARN to the operator."""
        client_mock = mock.Mock()
        self.ecs.client = client_mock
        self.ecs.arn = 'arn'
        return client_mock

    def test_init(self):
        self.assertEqual(self.ecs.region_name, 'eu-west-1')
        self.assertEqual(self.ecs.task_definition, 't')
        self.assertIsNone(self.ecs.aws_conn_id)
        self.assertEqual(self.ecs.cluster, 'c')
        self.assertEqual(self.ecs.overrides, {})
        self.assertEqual(self.ecs.hook, self.aws_hook_mock.return_value)

        self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)

    def test_template_fields_overrides(self):
        self.assertEqual(self.ecs.template_fields, ('overrides',))

    @mock.patch.object(ECSOperator, '_wait_for_task_ended')
    @mock.patch.object(ECSOperator, '_check_success_task')
    def test_execute_without_failures(self, check_mock, wait_mock):
        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
        client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES

        self.ecs.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('ecs', region_name='eu-west-1')
        client_mock.run_task.assert_called_once_with(
            cluster='c',
            overrides={},
            startedBy='Airflow',
            taskDefinition='t'
        )

        wait_mock.assert_called_once_with()
        check_mock.assert_called_once_with()
        self.assertEqual(self.ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55')

    def test_execute_with_failures(self):
        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
        resp_failures = deepcopy(RESPONSE_WITHOUT_FAILURES)
        resp_failures['failures'].append('dummy error')
        client_mock.run_task.return_value = resp_failures

        with self.assertRaises(AirflowException):
            self.ecs.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('ecs', region_name='eu-west-1')
        client_mock.run_task.assert_called_once_with(
            cluster='c',
            overrides={},
            startedBy='Airflow',
            taskDefinition='t'
        )

    def test_wait_end_tasks(self):
        client_mock = self._mock_client()

        self.ecs._wait_for_task_ended()
        client_mock.get_waiter.assert_called_once_with('tasks_stopped')
        client_mock.get_waiter.return_value.wait.assert_called_once_with(cluster='c', tasks=['arn'])
        # sys.maxsize (not the Python-2-only sys.maxint) so this runs on py3.
        self.assertEqual(sys.maxsize, client_mock.get_waiter.return_value.config.max_attempts)

    def test_check_success_tasks_raises(self):
        client_mock = self._mock_client()
        task = {
            'containers': [{
                'name': 'foo',
                'lastStatus': 'STOPPED',
                'exitCode': 1
            }]
        }
        client_mock.describe_tasks.return_value = {'tasks': [task]}

        with self.assertRaises(AirflowException) as e:
            self.ecs._check_success_task()

        # Format the expected message from the same dict so the assertion does
        # not depend on dict key ordering, which differs between Pythons.
        self.assertEqual(str(e.exception), 'This task is not in success state {}'.format(task))
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def test_check_success_tasks_raises_pending(self):
        client_mock = self._mock_client()
        task = {
            'containers': [{
                'name': 'container-name',
                'lastStatus': 'PENDING'
            }]
        }
        client_mock.describe_tasks.return_value = {'tasks': [task]}

        with self.assertRaises(AirflowException) as e:
            self.ecs._check_success_task()

        self.assertEqual(str(e.exception), 'This task is still pending {}'.format(task))
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def test_check_success_tasks_raises_container_error(self):
        client_mock = self._mock_client()
        client_mock.describe_tasks.return_value = {
            'tasks': [{
                'containers': [{
                    'name': 'foo',
                    'reason': 'CannotPullContainerError: image not found'
                }]
            }]
        }

        with self.assertRaises(AirflowException) as e:
            self.ecs._check_success_task()

        self.assertIn('cannotpullcontainererror', str(e.exception))
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def test_check_success_tasks_multiple_containers(self):
        client_mock = self._mock_client()
        client_mock.describe_tasks.return_value = {
            'tasks': [{
                'containers': [{
                    'name': 'foo',
                    'exitCode': 1
                }, {
                    'name': 'bar',
                    'lastStatus': 'STOPPED',
                    'exitCode': 0
                }]
            }]
        }
        self.ecs._check_success_task()
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def test_check_success_task_not_raises(self):
        client_mock = self._mock_client()
        client_mock.describe_tasks.return_value = {
            'tasks': [{
                'containers': [{
                    'name': 'container-name',
                    'lastStatus': 'STOPPED',
                    'exitCode': 0
                }]
            }]
        }
        self.ecs._check_success_task()
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def test_on_kill_stops_task(self):
        client_mock = self._mock_client()

        self.ecs.on_kill()

        client_mock.stop_task.assert_called_once_with(
            cluster='c',
            task='arn',
            reason='Task killed by the user')


if __name__ == '__main__':
    unittest.main()