You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/03/07 01:48:33 UTC

[airflow] branch main updated: retry on very specific eni provision failures (#22002)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 01a1a26  retry on very specific eni provision failures (#22002)
01a1a26 is described below

commit 01a1a263fdac53f4cd9fa82e5ae89172cf68b91c
Author: Zach Liu <za...@users.noreply.github.com>
AuthorDate: Sun Mar 6 20:47:49 2022 -0500

    retry on very specific eni provision failures (#22002)
---
 airflow/providers/amazon/aws/exceptions.py       | 11 +++++++
 airflow/providers/amazon/aws/operators/ecs.py    | 40 +++++++++++++++---------
 tests/providers/amazon/aws/operators/test_ecs.py | 34 ++++++++++++++++++--
 3 files changed, 68 insertions(+), 17 deletions(-)

diff --git a/airflow/providers/amazon/aws/exceptions.py b/airflow/providers/amazon/aws/exceptions.py
index 6ae5dab..27f46df 100644
--- a/airflow/providers/amazon/aws/exceptions.py
+++ b/airflow/providers/amazon/aws/exceptions.py
@@ -21,6 +21,17 @@
 import warnings
 
 
+class EcsTaskFailToStart(Exception):
+    """Raise when ECS tasks fail to start AFTER processing the request; ``message`` holds the reason."""
+
+    def __init__(self, message: str):
+        self.message = message
+        super().__init__(message)
+
+    def __reduce__(self):
+        return EcsTaskFailToStart, (self.message,)  # pickle needs an args *tuple*; (x) is just x
+
+
 class EcsOperatorError(Exception):
     """Raise when ECS cannot handle the request."""
 
diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py
index f982084..8a4a039 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -30,7 +30,7 @@ from botocore.waiter import Waiter
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator, XCom
-from airflow.providers.amazon.aws.exceptions import EcsOperatorError
+from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 from airflow.typing_compat import Protocol, runtime_checkable
@@ -48,6 +48,13 @@ def should_retry(exception: Exception):
     return False
 
 
+def should_retry_eni(exception: Exception):
+    """Tell whether an ECS start failure message points at ENI (Elastic Network Interface) provisioning."""
+    if not isinstance(exception, EcsTaskFailToStart):
+        return False
+    return any(reason in exception.message for reason in ['network interface provisioning'])
+
+
 @runtime_checkable
 class EcsProtocol(Protocol):
     """
@@ -287,6 +294,23 @@ class EcsOperator(BaseOperator):
         if self.reattach:
             self._try_reattach_task(context)
 
+        self._start_wait_check_task(context)
+
+        self.log.info('ECS Task has been successfully executed')
+
+        if self.reattach:
+            # Clear the XCom value storing the ECS task ARN if the task has completed
+            # as we can't reattach it anymore
+            self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))
+
+        if self.do_xcom_push and self.task_log_fetcher:
+            return self.task_log_fetcher.get_last_log_message()
+
+        return None
+
+    @AwsBaseHook.retry(should_retry_eni)
+    def _start_wait_check_task(self, context):
+
         if not self.arn:
             self._start_task(context)
 
@@ -306,18 +330,6 @@ class EcsOperator(BaseOperator):
 
         self._check_success_task()
 
-        self.log.info('ECS Task has been successfully executed')
-
-        if self.reattach:
-            # Clear the XCom value storing the ECS task ARN if the task has completed
-            # as we can't reattach it anymore
-            self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))
-
-        if self.do_xcom_push and self.task_log_fetcher:
-            return self.task_log_fetcher.get_last_log_message()
-
-        return None
-
     def _xcom_del(self, session, task_id):
         session.query(XCom).filter(XCom.dag_id == self.dag_id, XCom.task_id == task_id).delete()
 
@@ -438,7 +450,7 @@ class EcsOperator(BaseOperator):
         for task in response['tasks']:
 
             if task.get('stopCode', '') == 'TaskFailedToStart':
-                raise AirflowException(f"The task failed to start due to: {task.get('stoppedReason', '')}")
+                raise EcsTaskFailToStart(f"The task failed to start due to: {task.get('stoppedReason', '')}")
 
             # This is a `stoppedReason` that indicates a task has not
             # successfully finished, but there is no other indication of failure
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py
index a29f2a8..3a2bcce 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -28,8 +28,13 @@ from botocore.exceptions import ClientError
 from parameterized import parameterized
 
 from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.exceptions import EcsOperatorError
-from airflow.providers.amazon.aws.operators.ecs import EcsOperator, EcsTaskLogFetcher, should_retry
+from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
+from airflow.providers.amazon.aws.operators.ecs import (
+    EcsOperator,
+    EcsTaskLogFetcher,
+    should_retry,
+    should_retry_eni,
+)
 
 # fmt: off
 RESPONSE_WITHOUT_FAILURES = {
@@ -261,7 +266,7 @@ class TestEcsOperator(unittest.TestCase):
             ]
         }
 
-        with pytest.raises(Exception) as ctx:
+        with pytest.raises(EcsTaskFailToStart) as ctx:
             self.ecs._check_success_task()
 
         assert str(ctx.value) == "The task failed to start due to: Task failed to start"
@@ -558,6 +563,29 @@ class TestShouldRetry(unittest.TestCase):
         self.assertFalse(should_retry(EcsOperatorError([{'reason': 'CLUSTER_NOT_FOUND'}], 'Foo')))
 
 
+class TestShouldRetryEni(unittest.TestCase):
+    """Check which ECS start-failure messages should_retry_eni accepts."""
+
+    def test_return_true_on_valid_reason(self):
+        # The stopped reason mentions network interface provisioning,
+        # so the failure is treated as retryable.
+        eni_timeout = EcsTaskFailToStart(
+            "The task failed to start due to: "
+            "Timeout waiting for network interface provisioning to complete."
+        )
+        self.assertTrue(should_retry_eni(eni_timeout))
+
+    def test_return_false_on_invalid_reason(self):
+        # The stopped reason is an image pull error with no mention of
+        # network interface provisioning, so it is not retried.
+        pull_failure = EcsTaskFailToStart(
+            "The task failed to start due to: "
+            "CannotPullContainerError: "
+            "ref pull has been retried 5 time(s): failed to resolve reference"
+        )
+        self.assertFalse(should_retry_eni(pull_failure))
+
+
 class TestEcsTaskLogFetcher(unittest.TestCase):
     @mock.patch('logging.Logger')
     def set_up_log_fetcher(self, logger_mock):