You are viewing a plain-text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text version.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/06/12 23:55:11 UTC

[airflow] branch main updated: Add support of capacity provider strategy for ECSOperator (#15848)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 30708b5  Add support of capacity provider strategy for ECSOperator (#15848)
30708b5 is described below

commit 30708b5b254960395d8061e8c403294b93900c4d
Author: Pavel Hlushchanka <co...@users.noreply.github.com>
AuthorDate: Sun Jun 13 01:54:55 2021 +0200

    Add support of capacity provider strategy for ECSOperator (#15848)
---
 airflow/providers/amazon/aws/operators/ecs.py    | 12 ++++-
 tests/providers/amazon/aws/operators/test_ecs.py | 66 +++++++++++++++++++-----
 2 files changed, 63 insertions(+), 15 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py
index 1580e35..0c89317 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -106,6 +106,11 @@ class ECSOperator(BaseOperator):  # pylint: disable=too-many-instance-attributes
     :type region_name: str
     :param launch_type: the launch type on which to run your task ('EC2' or 'FARGATE')
     :type launch_type: str
+    :param capacity_provider_strategy: the capacity provider strategy to use for the task.
+        When capacity_provider_strategy is specified, the launch_type parameter is omitted.
+        If no capacity_provider_strategy or launch_type is specified,
+        the default capacity provider strategy for the cluster is used.
+    :type capacity_provider_strategy: list
     :param group: the name of the task group associated with the task
     :type group: str
     :param placement_constraints: an array of placement constraint objects to use for
@@ -153,6 +158,7 @@ class ECSOperator(BaseOperator):  # pylint: disable=too-many-instance-attributes
         aws_conn_id: Optional[str] = None,
         region_name: Optional[str] = None,
         launch_type: str = 'EC2',
+        capacity_provider_strategy: Optional[list] = None,
         group: Optional[str] = None,
         placement_constraints: Optional[list] = None,
         placement_strategy: Optional[list] = None,
@@ -175,6 +181,7 @@ class ECSOperator(BaseOperator):  # pylint: disable=too-many-instance-attributes
         self.cluster = cluster
         self.overrides = overrides
         self.launch_type = launch_type
+        self.capacity_provider_strategy = capacity_provider_strategy
         self.group = group
         self.placement_constraints = placement_constraints
         self.placement_strategy = placement_strategy
@@ -229,7 +236,10 @@ class ECSOperator(BaseOperator):  # pylint: disable=too-many-instance-attributes
             'startedBy': self.owner,
         }
 
-        if self.launch_type:
+        if self.capacity_provider_strategy:
+            run_opts['capacityProviderStrategy'] = self.capacity_provider_strategy
+            run_opts['platformVersion'] = self.platform_version
+        elif self.launch_type:
             run_opts['launchType'] = self.launch_type
             if self.launch_type == 'FARGATE':
                 run_opts['platformVersion'] = self.platform_version
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py
index 96717c3..cf4049e 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -96,30 +96,68 @@ class TestECSOperator(unittest.TestCase):
 
     @parameterized.expand(
         [
-            ['EC2', None],
-            ['FARGATE', None],
-            ['EC2', {'testTagKey': 'testTagValue'}],
-            ['', {'testTagKey': 'testTagValue'}],
+            ['EC2', None, None, {'launchType': 'EC2'}],
+            ['FARGATE', None, None, {'launchType': 'FARGATE', 'platformVersion': 'LATEST'}],
+            [
+                'EC2',
+                None,
+                {'testTagKey': 'testTagValue'},
+                {'launchType': 'EC2', 'tags': [{'key': 'testTagKey', 'value': 'testTagValue'}]},
+            ],
+            [
+                '',
+                None,
+                {'testTagKey': 'testTagValue'},
+                {'tags': [{'key': 'testTagKey', 'value': 'testTagValue'}]},
+            ],
+            [
+                None,
+                {'capacityProvider': 'FARGATE_SPOT'},
+                None,
+                {
+                    'capacityProviderStrategy': {'capacityProvider': 'FARGATE_SPOT'},
+                    'platformVersion': 'LATEST',
+                },
+            ],
+            [
+                'FARGATE',
+                {'capacityProvider': 'FARGATE_SPOT', 'weight': 123, 'base': 123},
+                None,
+                {
+                    'capacityProviderStrategy': {
+                        'capacityProvider': 'FARGATE_SPOT',
+                        'weight': 123,
+                        'base': 123,
+                    },
+                    'platformVersion': 'LATEST',
+                },
+            ],
+            [
+                'EC2',
+                {'capacityProvider': 'FARGATE_SPOT'},
+                None,
+                {
+                    'capacityProviderStrategy': {'capacityProvider': 'FARGATE_SPOT'},
+                    'platformVersion': 'LATEST',
+                },
+            ],
         ]
     )
     @mock.patch.object(ECSOperator, '_wait_for_task_ended')
     @mock.patch.object(ECSOperator, '_check_success_task')
-    def test_execute_without_failures(self, launch_type, tags, check_mock, wait_mock):
+    def test_execute_without_failures(
+        self, launch_type, capacity_provider_strategy, tags, expected_args, check_mock, wait_mock
+    ):
 
-        self.set_up_operator(launch_type=launch_type, tags=tags)  # pylint: disable=no-value-for-parameter
+        self.set_up_operator(  # pylint: disable=no-value-for-parameter
+            launch_type=launch_type, capacity_provider_strategy=capacity_provider_strategy, tags=tags
+        )
         client_mock = self.aws_hook_mock.return_value.get_conn.return_value
         client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES
 
         self.ecs.execute(None)
 
         self.aws_hook_mock.return_value.get_conn.assert_called_once()
-        extend_args = {}
-        if launch_type:
-            extend_args['launchType'] = launch_type
-        if launch_type == 'FARGATE':
-            extend_args['platformVersion'] = 'LATEST'
-        if tags:
-            extend_args['tags'] = [{'key': k, 'value': v} for (k, v) in tags.items()]
 
         client_mock.run_task.assert_called_once_with(
             cluster='c',
@@ -133,7 +171,7 @@ class TestECSOperator(unittest.TestCase):
                 'awsvpcConfiguration': {'securityGroups': ['sg-123abc'], 'subnets': ['subnet-123456ab']}
             },
             propagateTags='TASK_DEFINITION',
-            **extend_args,
+            **expected_args,
         )
 
         wait_mock.assert_called_once_with()