You are viewing a plain-text version of this content. The canonical link to it (originally a hyperlink on the word "here") was not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by vi...@apache.org on 2023/09/25 17:54:25 UTC
[airflow] branch main updated: Respect `soft_fail` argument when running `BatchSensors` (#34592)
This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5a133e8b52 Respect `soft_fail` argument when running `BatchSensors` (#34592)
5a133e8b52 is described below
commit 5a133e8b52618262eb8d49a45172f0f1ea7c8c1f
Author: Utkarsh Sharma <ut...@gmail.com>
AuthorDate: Mon Sep 25 23:24:19 2023 +0530
Respect `soft_fail` argument when running `BatchSensors` (#34592)
---
airflow/providers/amazon/aws/sensors/batch.py | 37 +++++++++---
tests/providers/amazon/aws/sensors/test_batch.py | 76 ++++++++++++++++++++++++
2 files changed, 105 insertions(+), 8 deletions(-)
diff --git a/airflow/providers/amazon/aws/sensors/batch.py b/airflow/providers/amazon/aws/sensors/batch.py
index cc4f0d0c67..4788049e1f 100644
--- a/airflow/providers/amazon/aws/sensors/batch.py
+++ b/airflow/providers/amazon/aws/sensors/batch.py
@@ -83,9 +83,17 @@ class BatchSensor(BaseSensorOperator):
return False
if state == BatchClientHook.FAILURE_STATE:
- raise AirflowException(f"Batch sensor failed. AWS Batch job status: {state}")
+ # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+ message = f"Batch sensor failed. AWS Batch job status: {state}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
- raise AirflowException(f"Batch sensor failed. Unknown AWS Batch job status: {state}")
+ # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+ message = f"Batch sensor failed. Unknown AWS Batch job status: {state}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
def execute(self, context: Context) -> None:
if not self.deferrable:
@@ -182,7 +190,11 @@ class BatchComputeEnvironmentSensor(BaseSensorOperator):
)
if not response["computeEnvironments"]:
- raise AirflowException(f"AWS Batch compute environment {self.compute_environment} not found")
+ message = f"AWS Batch compute environment {self.compute_environment} not found"
+ # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
status = response["computeEnvironments"][0]["status"]
@@ -192,9 +204,11 @@ class BatchComputeEnvironmentSensor(BaseSensorOperator):
if status in BatchClientHook.COMPUTE_ENVIRONMENT_INTERMEDIATE_STATUS:
return False
- raise AirflowException(
- f"AWS Batch compute environment failed. AWS Batch compute environment status: {status}"
- )
+ # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+ message = f"AWS Batch compute environment failed. AWS Batch compute environment status: {status}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
class BatchJobQueueSensor(BaseSensorOperator):
@@ -250,7 +264,11 @@ class BatchJobQueueSensor(BaseSensorOperator):
if self.treat_non_existing_as_deleted:
return True
else:
- raise AirflowException(f"AWS Batch job queue {self.job_queue} not found")
+ # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+ message = f"AWS Batch job queue {self.job_queue} not found"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
status = response["jobQueues"][0]["status"]
@@ -260,4 +278,7 @@ class BatchJobQueueSensor(BaseSensorOperator):
if status in BatchClientHook.JOB_QUEUE_INTERMEDIATE_STATUS:
return False
- raise AirflowException(f"AWS Batch job queue failed. AWS Batch job queue status: {status}")
+ message = f"AWS Batch job queue failed. AWS Batch job queue status: {status}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
diff --git a/tests/providers/amazon/aws/sensors/test_batch.py b/tests/providers/amazon/aws/sensors/test_batch.py
index a7d0cb6974..267aeb998f 100644
--- a/tests/providers/amazon/aws/sensors/test_batch.py
+++ b/tests/providers/amazon/aws/sensors/test_batch.py
@@ -106,6 +106,34 @@ class TestBatchSensor:
with pytest.raises(AirflowSkipException):
deferrable_batch_sensor.execute_complete(context={}, event={"status": "failure"})
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
+ )
+ @pytest.mark.parametrize(
+ "state, error_message",
+ (
+ (
+ BatchClientHook.FAILURE_STATE,
+ f"Batch sensor failed. AWS Batch job status: {BatchClientHook.FAILURE_STATE}",
+ ),
+ ("unknown_state", "Batch sensor failed. Unknown AWS Batch job status: unknown_state"),
+ ),
+ )
+ @mock.patch.object(BatchClientHook, "get_job_description")
+ def test_fail_poke(
+ self,
+ mock_get_job_description,
+ batch_sensor: BatchSensor,
+ state,
+ error_message,
+ soft_fail,
+ expected_exception,
+ ):
+ mock_get_job_description.return_value = {"status": state}
+ batch_sensor.soft_fail = soft_fail
+ with pytest.raises(expected_exception, match=error_message):
+ batch_sensor.poke({})
+
@pytest.fixture(scope="module")
def batch_compute_environment_sensor() -> BatchComputeEnvironmentSensor:
@@ -174,6 +202,34 @@ class TestBatchComputeEnvironmentSensor:
)
assert "AWS Batch compute environment failed" in str(ctx.value)
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
+ )
+ @pytest.mark.parametrize(
+ "compute_env, error_message",
+ (
+ (
+ [{"status": "unknown_status"}],
+ "AWS Batch compute environment failed. AWS Batch compute environment status:",
+ ),
+ ([], "AWS Batch compute environment"),
+ ),
+ )
+ @mock.patch.object(BatchClientHook, "client")
+ def test_fail_poke(
+ self,
+ mock_batch_client,
+ batch_compute_environment_sensor: BatchComputeEnvironmentSensor,
+ compute_env,
+ error_message,
+ soft_fail,
+ expected_exception,
+ ):
+ mock_batch_client.describe_compute_environments.return_value = {"computeEnvironments": compute_env}
+ batch_compute_environment_sensor.soft_fail = soft_fail
+ with pytest.raises(expected_exception, match=error_message):
+ batch_compute_environment_sensor.poke({})
+
@pytest.fixture(scope="module")
def batch_job_queue_sensor() -> BatchJobQueueSensor:
@@ -242,3 +298,23 @@ class TestBatchJobQueueSensor:
jobQueues=[JOB_QUEUE],
)
assert "AWS Batch job queue failed" in str(ctx.value)
+
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
+ )
+ @pytest.mark.parametrize("job_queue", ([], [{"status": "UNKNOWN_STATUS"}]))
+ @mock.patch.object(BatchClientHook, "client")
+ def test_fail_poke(
+ self,
+ mock_batch_client,
+ batch_job_queue_sensor: BatchJobQueueSensor,
+ job_queue,
+ soft_fail,
+ expected_exception,
+ ):
+ mock_batch_client.describe_job_queues.return_value = {"jobQueues": job_queue}
+ batch_job_queue_sensor.treat_non_existing_as_deleted = False
+ batch_job_queue_sensor.soft_fail = soft_fail
+ message = "AWS Batch job queue"
+ with pytest.raises(expected_exception, match=message):
+ batch_job_queue_sensor.poke({})