You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by el...@apache.org on 2022/09/19 18:49:11 UTC

[airflow] branch main updated: EMR Serverless Fix for Jobs marked as success even on failure (#26218)

This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8f1c78f6e0 EMR Serverless Fix for Jobs marked as success even on failure (#26218)
8f1c78f6e0 is described below

commit 8f1c78f6e08b184004b5b4f1f4b0eafa1d08aef3
Author: syedahsn <10...@users.noreply.github.com>
AuthorDate: Mon Sep 19 12:48:39 2022 -0600

    EMR Serverless Fix for Jobs marked as success even on failure (#26218)
    
    * Change desired state in EmrServerlessStartJobOperator to SUCCESS_STATES rather than TERMINAL_STATES. This makes it so that the task is marked as a failure if the job doesn't run successfully.
    
    * Define states in hook instead of Sensor.
    Add test to cover job failure exception.
    
    * Use APPLICATION_SUCCESS_STATES instead of JOB_SUCCESS_STATES to allow tests to pass.
---
 airflow/providers/amazon/aws/hooks/emr.py          |  9 ++++++
 airflow/providers/amazon/aws/operators/emr.py      | 17 ++++++------
 airflow/providers/amazon/aws/sensors/emr.py        | 17 +++---------
 .../amazon/aws/operators/test_emr_serverless.py    | 32 ++++++++++++++++++++++
 4 files changed, 53 insertions(+), 22 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py
index ebb2a79ef8..fb18c32394 100644
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -105,6 +105,15 @@ class EmrServerlessHook(AwsBaseHook):
         :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
     """
 
+    JOB_INTERMEDIATE_STATES = {'PENDING', 'RUNNING', 'SCHEDULED', 'SUBMITTED'}
+    JOB_FAILURE_STATES = {'FAILED', 'CANCELLING', 'CANCELLED'}
+    JOB_SUCCESS_STATES = {'SUCCESS'}
+    JOB_TERMINAL_STATES = JOB_SUCCESS_STATES.union(JOB_FAILURE_STATES)
+
+    APPLICATION_INTERMEDIATE_STATES = {'CREATING', 'STARTING', 'STOPPING'}
+    APPLICATION_FAILURE_STATES = {'STOPPED', 'TERMINATED'}
+    APPLICATION_SUCCESS_STATES = {'CREATED', 'STARTED'}
+
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         kwargs["client_type"] = "emr-serverless"
         super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 5028dfed86..5ccff487e9 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -26,7 +26,6 @@ from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
 from airflow.providers.amazon.aws.links.emr import EmrClusterLink
-from airflow.providers.amazon.aws.sensors.emr import EmrServerlessApplicationSensor, EmrServerlessJobSensor
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -552,7 +551,7 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
             get_state_args={'applicationId': application_id},
             parse_response=['application', 'state'],
             desired_state={'CREATED'},
-            failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
+            failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
             object_type='application',
             action='created',
         )
@@ -567,7 +566,7 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
                 get_state_args={'applicationId': application_id},
                 parse_response=['application', 'state'],
                 desired_state={'STARTED'},
-                failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
+                failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
                 object_type='application',
                 action='started',
             )
@@ -633,7 +632,7 @@ class EmrServerlessStartJobOperator(BaseOperator):
         self.log.info('Starting job on Application: %s', self.application_id)
 
         app_state = self.hook.conn.get_application(applicationId=self.application_id)['application']['state']
-        if app_state not in EmrServerlessApplicationSensor.SUCCESS_STATES:
+        if app_state not in EmrServerlessHook.APPLICATION_SUCCESS_STATES:
             self.hook.conn.start_application(applicationId=self.application_id)
 
             self.hook.waiter(
@@ -641,7 +640,7 @@ class EmrServerlessStartJobOperator(BaseOperator):
                 get_state_args={'applicationId': self.application_id},
                 parse_response=['application', 'state'],
                 desired_state={'STARTED'},
-                failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
+                failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
                 object_type='application',
                 action='started',
             )
@@ -668,8 +667,8 @@ class EmrServerlessStartJobOperator(BaseOperator):
                     'jobRunId': response['jobRunId'],
                 },
                 parse_response=['jobRun', 'state'],
-                desired_state=EmrServerlessJobSensor.TERMINAL_STATES,
-                failure_states=EmrServerlessJobSensor.FAILURE_STATES,
+                desired_state=EmrServerlessHook.JOB_SUCCESS_STATES,
+                failure_states=EmrServerlessHook.JOB_FAILURE_STATES,
                 object_type='job',
                 action='run',
             )
@@ -719,7 +718,7 @@ class EmrServerlessDeleteApplicationOperator(BaseOperator):
                 'applicationId': self.application_id,
             },
             parse_response=['application', 'state'],
-            desired_state=EmrServerlessApplicationSensor.FAILURE_STATES,
+            desired_state=EmrServerlessHook.APPLICATION_FAILURE_STATES,
             failure_states=set(),
             object_type='application',
             action='stopped',
@@ -738,7 +737,7 @@ class EmrServerlessDeleteApplicationOperator(BaseOperator):
                 get_state_args={'applicationId': self.application_id},
                 parse_response=['application', 'state'],
                 desired_state={'TERMINATED'},
-                failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
+                failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
                 object_type='application',
                 action='deleted',
             )
diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py
index 0b1c5c686b..4759b3d838 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -129,11 +129,6 @@ class EmrServerlessJobSensor(BaseSensorOperator):
     :param aws_conn_id: aws connection to use, defaults to 'aws_default'
     """
 
-    INTERMEDIATE_STATES = {'PENDING', 'RUNNING', 'SCHEDULED', 'SUBMITTED'}
-    FAILURE_STATES = {'FAILED', 'CANCELLING', 'CANCELLED'}
-    SUCCESS_STATES = {'SUCCESS'}
-    TERMINAL_STATES = SUCCESS_STATES.union(FAILURE_STATES)
-
     template_fields: Sequence[str] = (
         'application_id',
         'job_run_id',
@@ -144,7 +139,7 @@ class EmrServerlessJobSensor(BaseSensorOperator):
         *,
         application_id: str,
         job_run_id: str,
-        target_states: set | frozenset = frozenset(SUCCESS_STATES),
+        target_states: set | frozenset = frozenset(EmrServerlessHook.JOB_SUCCESS_STATES),
         aws_conn_id: str = 'aws_default',
         **kwargs: Any,
     ) -> None:
@@ -159,7 +154,7 @@ class EmrServerlessJobSensor(BaseSensorOperator):
 
         state = response['jobRun']['state']
 
-        if state in self.FAILURE_STATES:
+        if state in EmrServerlessHook.JOB_FAILURE_STATES:
             failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}"
             raise AirflowException(failure_message)
 
@@ -198,15 +193,11 @@ class EmrServerlessApplicationSensor(BaseSensorOperator):
 
     template_fields: Sequence[str] = ('application_id',)
 
-    INTERMEDIATE_STATES = {'CREATING', 'STARTING', 'STOPPING'}
-    FAILURE_STATES = {'STOPPED', 'TERMINATED'}
-    SUCCESS_STATES = {'CREATED', 'STARTED'}
-
     def __init__(
         self,
         *,
         application_id: str,
-        target_states: set | frozenset = frozenset(SUCCESS_STATES),
+        target_states: set | frozenset = frozenset(EmrServerlessHook.APPLICATION_SUCCESS_STATES),
         aws_conn_id: str = 'aws_default',
         **kwargs: Any,
     ) -> None:
@@ -220,7 +211,7 @@ class EmrServerlessApplicationSensor(BaseSensorOperator):
 
         state = response['application']['state']
 
-        if state in self.FAILURE_STATES:
+        if state in EmrServerlessHook.APPLICATION_FAILURE_STATES:
             failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}"
             raise AirflowException(failure_message)
 
diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py
index 688cc2711b..a175bc8d10 100644
--- a/tests/providers/amazon/aws/operators/test_emr_serverless.py
+++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py
@@ -220,6 +220,38 @@ class TestEmrServerlessStartJobOperator:
             configurationOverrides=configuration_overrides,
         )
 
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+    def test_job_run_job_failed(self, mock_conn):
+        mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
+        mock_conn.start_job_run.return_value = {
+            'jobRunId': job_run_id,
+            'ResponseMetadata': {'HTTPStatusCode': 200},
+        }
+
+        mock_conn.get_job_run.return_value = {'jobRun': {'state': 'FAILED'}}
+
+        operator = EmrServerlessStartJobOperator(
+            task_id=task_id,
+            client_request_token=client_request_token,
+            application_id=application_id,
+            execution_role_arn=execution_role_arn,
+            job_driver=job_driver,
+            configuration_overrides=configuration_overrides,
+        )
+        with pytest.raises(AirflowException) as ex_message:
+            # execute() must raise on a FAILED job, so nothing after this call would run.
+            operator.execute(None)
+        assert "Job reached failure state FAILED." in str(ex_message.value)
+        mock_conn.get_application.assert_called_once_with(applicationId=application_id)
+        mock_conn.get_job_run.assert_called_once_with(applicationId=application_id, jobRunId=job_run_id)
+        mock_conn.start_job_run.assert_called_once_with(
+            clientToken=client_request_token,
+            applicationId=application_id,
+            executionRoleArn=execution_role_arn,
+            jobDriver=job_driver,
+            configurationOverrides=configuration_overrides,
+        )
+
     @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
     @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
     def test_job_run_app_not_started(self, mock_conn, mock_waiter):