Posted to commits@airflow.apache.org by vi...@apache.org on 2023/09/25 17:54:25 UTC

[airflow] branch main updated: Respect `soft_fail` argument when running `BatchSensors` (#34592)

This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5a133e8b52 Respect `soft_fail` argument when running `BatchSensors` (#34592)
5a133e8b52 is described below

commit 5a133e8b52618262eb8d49a45172f0f1ea7c8c1f
Author: Utkarsh Sharma <ut...@gmail.com>
AuthorDate: Mon Sep 25 23:24:19 2023 +0530

    Respect `soft_fail` argument when running `BatchSensors` (#34592)
---
 airflow/providers/amazon/aws/sensors/batch.py    | 37 +++++++++---
 tests/providers/amazon/aws/sensors/test_batch.py | 76 ++++++++++++++++++++++++
 2 files changed, 105 insertions(+), 8 deletions(-)
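
The change applies one pattern at every failure path in the Batch sensors: when the sensor's `soft_fail` flag is set, raise `AirflowSkipException` (marking the task instance as skipped) instead of `AirflowException` (marking it as failed). A minimal sketch of that pattern, pulled out into a standalone helper purely for illustration (the helper name `_handle_failure` is hypothetical; the commit inlines the check at each raise site, with a TODO to drop it once the provider's minimum supported Airflow version is above 2.7.1):

from airflow.exceptions import AirflowException, AirflowSkipException


def _handle_failure(soft_fail: bool, message: str) -> None:
    # When soft_fail is enabled, downgrade the hard failure to a skip so the
    # task instance ends up in the "skipped" state instead of "failed".
    if soft_fail:
        raise AirflowSkipException(message)
    raise AirflowException(message)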

diff --git a/airflow/providers/amazon/aws/sensors/batch.py b/airflow/providers/amazon/aws/sensors/batch.py
index cc4f0d0c67..4788049e1f 100644
--- a/airflow/providers/amazon/aws/sensors/batch.py
+++ b/airflow/providers/amazon/aws/sensors/batch.py
@@ -83,9 +83,17 @@ class BatchSensor(BaseSensorOperator):
             return False
 
         if state == BatchClientHook.FAILURE_STATE:
-            raise AirflowException(f"Batch sensor failed. AWS Batch job status: {state}")
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            message = f"Batch sensor failed. AWS Batch job status: {state}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
 
-        raise AirflowException(f"Batch sensor failed. Unknown AWS Batch job status: {state}")
+        # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+        message = f"Batch sensor failed. Unknown AWS Batch job status: {state}"
+        if self.soft_fail:
+            raise AirflowSkipException(message)
+        raise AirflowException(message)
 
     def execute(self, context: Context) -> None:
         if not self.deferrable:
@@ -182,7 +190,11 @@ class BatchComputeEnvironmentSensor(BaseSensorOperator):
         )
 
         if not response["computeEnvironments"]:
-            raise AirflowException(f"AWS Batch compute environment {self.compute_environment} not found")
+            message = f"AWS Batch compute environment {self.compute_environment} not found"
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
 
         status = response["computeEnvironments"][0]["status"]
 
@@ -192,9 +204,11 @@ class BatchComputeEnvironmentSensor(BaseSensorOperator):
         if status in BatchClientHook.COMPUTE_ENVIRONMENT_INTERMEDIATE_STATUS:
             return False
 
-        raise AirflowException(
-            f"AWS Batch compute environment failed. AWS Batch compute environment status: {status}"
-        )
+        # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+        message = f"AWS Batch compute environment failed. AWS Batch compute environment status: {status}"
+        if self.soft_fail:
+            raise AirflowSkipException(message)
+        raise AirflowException(message)
 
 
 class BatchJobQueueSensor(BaseSensorOperator):
@@ -250,7 +264,11 @@ class BatchJobQueueSensor(BaseSensorOperator):
             if self.treat_non_existing_as_deleted:
                 return True
             else:
-                raise AirflowException(f"AWS Batch job queue {self.job_queue} not found")
+                # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+                message = f"AWS Batch job queue {self.job_queue} not found"
+                if self.soft_fail:
+                    raise AirflowSkipException(message)
+                raise AirflowException(message)
 
         status = response["jobQueues"][0]["status"]
 
@@ -260,4 +278,7 @@ class BatchJobQueueSensor(BaseSensorOperator):
         if status in BatchClientHook.JOB_QUEUE_INTERMEDIATE_STATUS:
             return False
 
-        raise AirflowException(f"AWS Batch job queue failed. AWS Batch job queue status: {status}")
+        message = f"AWS Batch job queue failed. AWS Batch job queue status: {status}"
+        if self.soft_fail:
+            raise AirflowSkipException(message)
+        raise AirflowException(message)
diff --git a/tests/providers/amazon/aws/sensors/test_batch.py b/tests/providers/amazon/aws/sensors/test_batch.py
index a7d0cb6974..267aeb998f 100644
--- a/tests/providers/amazon/aws/sensors/test_batch.py
+++ b/tests/providers/amazon/aws/sensors/test_batch.py
@@ -106,6 +106,34 @@ class TestBatchSensor:
         with pytest.raises(AirflowSkipException):
             deferrable_batch_sensor.execute_complete(context={}, event={"status": "failure"})
 
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
+    )
+    @pytest.mark.parametrize(
+        "state, error_message",
+        (
+            (
+                BatchClientHook.FAILURE_STATE,
+                f"Batch sensor failed. AWS Batch job status: {BatchClientHook.FAILURE_STATE}",
+            ),
+            ("unknown_state", "Batch sensor failed. Unknown AWS Batch job status: unknown_state"),
+        ),
+    )
+    @mock.patch.object(BatchClientHook, "get_job_description")
+    def test_fail_poke(
+        self,
+        mock_get_job_description,
+        batch_sensor: BatchSensor,
+        state,
+        error_message,
+        soft_fail,
+        expected_exception,
+    ):
+        mock_get_job_description.return_value = {"status": state}
+        batch_sensor.soft_fail = soft_fail
+        with pytest.raises(expected_exception, match=error_message):
+            batch_sensor.poke({})
+
 
 @pytest.fixture(scope="module")
 def batch_compute_environment_sensor() -> BatchComputeEnvironmentSensor:
@@ -174,6 +202,34 @@ class TestBatchComputeEnvironmentSensor:
         )
         assert "AWS Batch compute environment failed" in str(ctx.value)
 
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
+    )
+    @pytest.mark.parametrize(
+        "compute_env, error_message",
+        (
+            (
+                [{"status": "unknown_status"}],
+                "AWS Batch compute environment failed. AWS Batch compute environment status:",
+            ),
+            ([], "AWS Batch compute environment"),
+        ),
+    )
+    @mock.patch.object(BatchClientHook, "client")
+    def test_fail_poke(
+        self,
+        mock_batch_client,
+        batch_compute_environment_sensor: BatchComputeEnvironmentSensor,
+        compute_env,
+        error_message,
+        soft_fail,
+        expected_exception,
+    ):
+        mock_batch_client.describe_compute_environments.return_value = {"computeEnvironments": compute_env}
+        batch_compute_environment_sensor.soft_fail = soft_fail
+        with pytest.raises(expected_exception, match=error_message):
+            batch_compute_environment_sensor.poke({})
+
 
 @pytest.fixture(scope="module")
 def batch_job_queue_sensor() -> BatchJobQueueSensor:
@@ -242,3 +298,23 @@ class TestBatchJobQueueSensor:
             jobQueues=[JOB_QUEUE],
         )
         assert "AWS Batch job queue failed" in str(ctx.value)
+
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
+    )
+    @pytest.mark.parametrize("job_queue", ([], [{"status": "UNKNOWN_STATUS"}]))
+    @mock.patch.object(BatchClientHook, "client")
+    def test_fail_poke(
+        self,
+        mock_batch_client,
+        batch_job_queue_sensor: BatchJobQueueSensor,
+        job_queue,
+        soft_fail,
+        expected_exception,
+    ):
+        mock_batch_client.describe_job_queues.return_value = {"jobQueues": job_queue}
+        batch_job_queue_sensor.treat_non_existing_as_deleted = False
+        batch_job_queue_sensor.soft_fail = soft_fail
+        message = "AWS Batch job queue"
+        with pytest.raises(expected_exception, match=message):
+            batch_job_queue_sensor.poke({})
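
From a DAG author's point of view, the effect of this commit is that `soft_fail` now behaves consistently for `BatchSensor`, `BatchComputeEnvironmentSensor`, and `BatchJobQueueSensor`: a failed or unrecognized status skips the task instead of failing it. A minimal usage sketch, assuming placeholder values (the `job_id` is fabricated and the sensor would normally be declared inside a DAG):

from airflow.providers.amazon.aws.sensors.batch import BatchSensor

# Hypothetical task definition; job_id is a placeholder. With soft_fail=True,
# a FAILED or unknown AWS Batch job status now marks this task as skipped
# rather than failed.
wait_for_batch_job = BatchSensor(
    task_id="wait_for_batch_job",
    job_id="00000000-0000-0000-0000-000000000000",
    soft_fail=True,
)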