You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by el...@apache.org on 2022/09/19 18:49:11 UTC
[airflow] branch main updated: EMR Serverless Fix for Jobs marked as success even on failure (#26218)
This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8f1c78f6e0 EMR Serverless Fix for Jobs marked as success even on failure (#26218)
8f1c78f6e0 is described below
commit 8f1c78f6e08b184004b5b4f1f4b0eafa1d08aef3
Author: syedahsn <10...@users.noreply.github.com>
AuthorDate: Mon Sep 19 12:48:39 2022 -0600
EMR Serverless Fix for Jobs marked as success even on failure (#26218)
* Change desired state in EmrServerlessStartJobOperator to SUCCESS_STATES rather than TERMINAL_STATES. This ensures the task is marked as failed if the job does not complete successfully.
* Define states in hook instead of Sensor.
Add test to cover job failure exception.
* Use APPLICATION_SUCCESS_STATES instead of JOB_SUCCESS_STATES to allow tests to pass.
---
airflow/providers/amazon/aws/hooks/emr.py | 9 ++++++
airflow/providers/amazon/aws/operators/emr.py | 17 ++++++------
airflow/providers/amazon/aws/sensors/emr.py | 17 +++---------
.../amazon/aws/operators/test_emr_serverless.py | 32 ++++++++++++++++++++++
4 files changed, 53 insertions(+), 22 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py
index ebb2a79ef8..fb18c32394 100644
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -105,6 +105,15 @@ class EmrServerlessHook(AwsBaseHook):
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
+ JOB_INTERMEDIATE_STATES = {'PENDING', 'RUNNING', 'SCHEDULED', 'SUBMITTED'}
+ JOB_FAILURE_STATES = {'FAILED', 'CANCELLING', 'CANCELLED'}
+ JOB_SUCCESS_STATES = {'SUCCESS'}
+ JOB_TERMINAL_STATES = JOB_SUCCESS_STATES.union(JOB_FAILURE_STATES)
+
+ APPLICATION_INTERMEDIATE_STATES = {'CREATING', 'STARTING', 'STOPPING'}
+ APPLICATION_FAILURE_STATES = {'STOPPED', 'TERMINATED'}
+ APPLICATION_SUCCESS_STATES = {'CREATED', 'STARTED'}
+
def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["client_type"] = "emr-serverless"
super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 5028dfed86..5ccff487e9 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -26,7 +26,6 @@ from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import EmrClusterLink
-from airflow.providers.amazon.aws.sensors.emr import EmrServerlessApplicationSensor, EmrServerlessJobSensor
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -552,7 +551,7 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
get_state_args={'applicationId': application_id},
parse_response=['application', 'state'],
desired_state={'CREATED'},
- failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
+ failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
object_type='application',
action='created',
)
@@ -567,7 +566,7 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
get_state_args={'applicationId': application_id},
parse_response=['application', 'state'],
desired_state={'STARTED'},
- failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
+ failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
object_type='application',
action='started',
)
@@ -633,7 +632,7 @@ class EmrServerlessStartJobOperator(BaseOperator):
self.log.info('Starting job on Application: %s', self.application_id)
app_state = self.hook.conn.get_application(applicationId=self.application_id)['application']['state']
- if app_state not in EmrServerlessApplicationSensor.SUCCESS_STATES:
+ if app_state not in EmrServerlessHook.APPLICATION_SUCCESS_STATES:
self.hook.conn.start_application(applicationId=self.application_id)
self.hook.waiter(
@@ -641,7 +640,7 @@ class EmrServerlessStartJobOperator(BaseOperator):
get_state_args={'applicationId': self.application_id},
parse_response=['application', 'state'],
desired_state={'STARTED'},
- failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
+ failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
object_type='application',
action='started',
)
@@ -668,8 +667,8 @@ class EmrServerlessStartJobOperator(BaseOperator):
'jobRunId': response['jobRunId'],
},
parse_response=['jobRun', 'state'],
- desired_state=EmrServerlessJobSensor.TERMINAL_STATES,
- failure_states=EmrServerlessJobSensor.FAILURE_STATES,
+ desired_state=EmrServerlessHook.JOB_SUCCESS_STATES,
+ failure_states=EmrServerlessHook.JOB_FAILURE_STATES,
object_type='job',
action='run',
)
@@ -719,7 +718,7 @@ class EmrServerlessDeleteApplicationOperator(BaseOperator):
'applicationId': self.application_id,
},
parse_response=['application', 'state'],
- desired_state=EmrServerlessApplicationSensor.FAILURE_STATES,
+ desired_state=EmrServerlessHook.APPLICATION_FAILURE_STATES,
failure_states=set(),
object_type='application',
action='stopped',
@@ -738,7 +737,7 @@ class EmrServerlessDeleteApplicationOperator(BaseOperator):
get_state_args={'applicationId': self.application_id},
parse_response=['application', 'state'],
desired_state={'TERMINATED'},
- failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
+ failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
object_type='application',
action='deleted',
)
diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py
index 0b1c5c686b..4759b3d838 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -129,11 +129,6 @@ class EmrServerlessJobSensor(BaseSensorOperator):
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
"""
- INTERMEDIATE_STATES = {'PENDING', 'RUNNING', 'SCHEDULED', 'SUBMITTED'}
- FAILURE_STATES = {'FAILED', 'CANCELLING', 'CANCELLED'}
- SUCCESS_STATES = {'SUCCESS'}
- TERMINAL_STATES = SUCCESS_STATES.union(FAILURE_STATES)
-
template_fields: Sequence[str] = (
'application_id',
'job_run_id',
@@ -144,7 +139,7 @@ class EmrServerlessJobSensor(BaseSensorOperator):
*,
application_id: str,
job_run_id: str,
- target_states: set | frozenset = frozenset(SUCCESS_STATES),
+ target_states: set | frozenset = frozenset(EmrServerlessHook.JOB_SUCCESS_STATES),
aws_conn_id: str = 'aws_default',
**kwargs: Any,
) -> None:
@@ -159,7 +154,7 @@ class EmrServerlessJobSensor(BaseSensorOperator):
state = response['jobRun']['state']
- if state in self.FAILURE_STATES:
+ if state in EmrServerlessHook.JOB_FAILURE_STATES:
failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}"
raise AirflowException(failure_message)
@@ -198,15 +193,11 @@ class EmrServerlessApplicationSensor(BaseSensorOperator):
template_fields: Sequence[str] = ('application_id',)
- INTERMEDIATE_STATES = {'CREATING', 'STARTING', 'STOPPING'}
- FAILURE_STATES = {'STOPPED', 'TERMINATED'}
- SUCCESS_STATES = {'CREATED', 'STARTED'}
-
def __init__(
self,
*,
application_id: str,
- target_states: set | frozenset = frozenset(SUCCESS_STATES),
+ target_states: set | frozenset = frozenset(EmrServerlessHook.APPLICATION_SUCCESS_STATES),
aws_conn_id: str = 'aws_default',
**kwargs: Any,
) -> None:
@@ -220,7 +211,7 @@ class EmrServerlessApplicationSensor(BaseSensorOperator):
state = response['application']['state']
- if state in self.FAILURE_STATES:
+ if state in EmrServerlessHook.APPLICATION_FAILURE_STATES:
failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}"
raise AirflowException(failure_message)
diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py
index 688cc2711b..a175bc8d10 100644
--- a/tests/providers/amazon/aws/operators/test_emr_serverless.py
+++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py
@@ -220,6 +220,38 @@ class TestEmrServerlessStartJobOperator:
configurationOverrides=configuration_overrides,
)
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+    def test_job_run_job_failed(self, mock_conn):
+        mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
+        mock_conn.start_job_run.return_value = {
+            'jobRunId': job_run_id,
+            'ResponseMetadata': {'HTTPStatusCode': 200},
+        }
+
+        mock_conn.get_job_run.return_value = {'jobRun': {'state': 'FAILED'}}
+
+        operator = EmrServerlessStartJobOperator(
+            task_id=task_id,
+            client_request_token=client_request_token,
+            application_id=application_id,
+            execution_role_arn=execution_role_arn,
+            job_driver=job_driver,
+            configuration_overrides=configuration_overrides,
+        )
+        with pytest.raises(AirflowException) as ex_message:
+            operator.execute(None)  # must raise: the job run ends in FAILED
+        assert "Job reached failure state FAILED." in str(ex_message.value)
+        mock_conn.get_application.assert_called_once_with(applicationId=application_id)
+        mock_conn.start_application.assert_not_called()  # app is already STARTED
+        mock_conn.get_job_run.assert_called_once_with(applicationId=application_id, jobRunId=job_run_id)
+        mock_conn.start_job_run.assert_called_once_with(
+            clientToken=client_request_token,
+            applicationId=application_id,
+            executionRoleArn=execution_role_arn,
+            jobDriver=job_driver,
+            configurationOverrides=configuration_overrides,
+        )
+
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
def test_job_run_app_not_started(self, mock_conn, mock_waiter):