You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/08/05 16:55:03 UTC
[airflow] branch main updated: Add EMR Serverless Operators and Hooks (#25324)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8df84e99b7 Add EMR Serverless Operators and Hooks (#25324)
8df84e99b7 is described below
commit 8df84e99b7319740990124736d0fc545165e7114
Author: syedahsn <10...@users.noreply.github.com>
AuthorDate: Fri Aug 5 10:54:57 2022 -0600
Add EMR Serverless Operators and Hooks (#25324)
---
.../aws/example_dags/example_emr_serverless.py | 97 +++++
airflow/providers/amazon/aws/hooks/emr.py | 75 +++-
airflow/providers/amazon/aws/operators/emr.py | 262 ++++++++++++-
airflow/providers/amazon/aws/sensors/emr.py | 142 ++++++-
airflow/providers/amazon/provider.yaml | 6 +
.../operators/emr_serverless.rst | 113 ++++++
.../amazon/aws/hooks/test_emr_serverless.py | 136 +++++++
.../amazon/aws/operators/test_emr_serverless.py | 409 +++++++++++++++++++++
8 files changed, 1232 insertions(+), 8 deletions(-)
diff --git a/airflow/providers/amazon/aws/example_dags/example_emr_serverless.py b/airflow/providers/amazon/aws/example_dags/example_emr_serverless.py
new file mode 100644
index 0000000000..b8c0618014
--- /dev/null
+++ b/airflow/providers/amazon/aws/example_dags/example_emr_serverless.py
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime
+from os import getenv
+
+from airflow import DAG
+from airflow.models.baseoperator import chain
+from airflow.providers.amazon.aws.operators.emr import (
+ EmrServerlessCreateApplicationOperator,
+ EmrServerlessDeleteApplicationOperator,
+ EmrServerlessStartJobOperator,
+)
+from airflow.providers.amazon.aws.sensors.emr import EmrServerlessApplicationSensor, EmrServerlessJobSensor
+
# Environment-specific values for the example; override them with environment variables.
EXECUTION_ROLE_ARN = getenv('EXECUTION_ROLE_ARN', 'execution_role_arn')
EMR_EXAMPLE_BUCKET = getenv('EMR_EXAMPLE_BUCKET', 'emr_example_bucket')

# Job driver running the sample PySpark word-count script and writing its output to the bucket.
# ``sparkSubmitParameters`` is built with implicit string concatenation rather than a backslash
# continuation inside the literal: the latter embeds the continuation line's source indentation
# into the runtime value.
SPARK_JOB_DRIVER = {
    "sparkSubmit": {
        "entryPoint": "s3://us-east-1.elasticmapreduce/emr-containers/samples/wordcount/scripts/wordcount.py",
        "entryPointArguments": [f"s3://{EMR_EXAMPLE_BUCKET}/output"],
        "sparkSubmitParameters": (
            "--conf spark.executor.cores=1 --conf spark.executor.memory=4g "
            "--conf spark.driver.cores=1 --conf spark.driver.memory=4g --conf spark.executor.instances=1"
        ),
    }
}

# Configuration overrides sending the job's monitoring logs to the example bucket.
SPARK_CONFIGURATION_OVERRIDES = {
    "monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": f"s3://{EMR_EXAMPLE_BUCKET}/logs"}}
}
+
with DAG(
    dag_id='example_emr_serverless',
    schedule_interval=None,
    start_date=datetime(2021, 1, 1),
    tags=['example'],
    catchup=False,
) as emr_serverless_dag:

    # Create the application. The operator returns the new application ID, which the tasks
    # below consume through ``emr_serverless_app.output``.
    # [START howto_operator_emr_serverless_create_application]
    emr_serverless_app = EmrServerlessCreateApplicationOperator(
        task_id='create_emr_serverless_task',
        release_label='emr-6.6.0',
        job_type="SPARK",
        config={'name': 'new_application'},
    )
    # [END howto_operator_emr_serverless_create_application]

    # Demonstrates the application sensor: block until the application reaches one of the
    # sensor's target states.
    # [START howto_sensor_emr_serverless_application]
    wait_for_app_creation = EmrServerlessApplicationSensor(
        task_id='wait_for_app_creation',
        application_id=emr_serverless_app.output,
    )
    # [END howto_sensor_emr_serverless_application]

    # Submit the Spark job; the operator returns the job run ID, consumed below via ``start_job.output``.
    # [START howto_operator_emr_serverless_start_job]
    start_job = EmrServerlessStartJobOperator(
        task_id='start_emr_serverless_job',
        application_id=emr_serverless_app.output,
        execution_role_arn=EXECUTION_ROLE_ARN,
        job_driver=SPARK_JOB_DRIVER,
        configuration_overrides=SPARK_CONFIGURATION_OVERRIDES,
    )
    # [END howto_operator_emr_serverless_start_job]

    # Block until the job run reaches the sensor's target state.
    # [START howto_sensor_emr_serverless_job]
    wait_for_job = EmrServerlessJobSensor(
        task_id='wait_for_job', application_id=emr_serverless_app.output, job_run_id=start_job.output
    )
    # [END howto_sensor_emr_serverless_job]

    # Clean up the application. ``trigger_rule="all_done"`` lets this task run once its upstream
    # task has finished, even if something upstream failed.
    # [START howto_operator_emr_serverless_delete_application]
    delete_app = EmrServerlessDeleteApplicationOperator(
        task_id='delete_application', application_id=emr_serverless_app.output, trigger_rule="all_done"
    )
    # [END howto_operator_emr_serverless_delete_application]

    # Run the tasks strictly one after another in the listed order.
    chain(
        emr_serverless_app,
        wait_for_app_creation,
        start_job,
        wait_for_job,
        delete_app,
    )
diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py
index 2141b38ed8..e085bff899 100644
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -16,10 +16,11 @@
# specific language governing permissions and limitations
# under the License.
from time import sleep
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Set
from botocore.exceptions import ClientError
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -90,6 +91,78 @@ class EmrHook(AwsBaseHook):
return response
class EmrServerlessHook(AwsBaseHook):
    """
    Interact with EMR Serverless API.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    .. seealso::
        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # The client type is forced, overriding any value supplied by the caller.
        kwargs["client_type"] = "emr-serverless"
        super().__init__(*args, **kwargs)

    @cached_property
    def conn(self):
        """Get the underlying boto3 EmrServerlessAPIService client (cached)"""
        return super().conn

    # This method should be replaced with boto waiters which would implement timeouts and backoff nicely.
    def waiter(
        self,
        get_state_callable: Callable,
        get_state_args: Dict,
        parse_response: List,
        desired_state: Set,
        failure_states: Set,
        object_type: str,
        action: str,
        countdown: int = 25 * 60,
        check_interval_seconds: int = 60,
    ) -> None:
        """
        Poll ``get_state_callable`` until the state it reports is one of ``desired_state``.

        The desired-state check is made before the failure-state check, so a state that appears
        in both sets ends the wait successfully rather than raising.

        :param get_state_callable: Callable returning the API response to extract the state from
        :param get_state_args: Keyword arguments to pass to get_state_callable
        :param parse_response: Sequence of dictionary keys leading to the state inside the response
        :param desired_state: Set of states; polling stops once the extracted state is in it
        :param failure_states: Set of states which raise an exception if reached before
            any of the desired states
        :param object_type: Used for the reporting string. What are you waiting for? (application, job, etc)
        :param action: Used for the reporting string. What action are you waiting for? (created, deleted, etc)
        :param countdown: Total amount of time the waiter should wait for the desired state
            before timing out (in seconds). Defaults to 25 * 60 seconds. Another poll is only
            made while at least ``check_interval_seconds`` of the budget remains.
        :param check_interval_seconds: Number of seconds waiter should wait before attempting
            to retry get_state_callable. Defaults to 60 seconds.
        :raises AirflowException: if a state in ``failure_states`` is reached
        :raises RuntimeError: if no desired state is reached within ``countdown`` seconds
        """
        response = get_state_callable(**get_state_args)
        state: Optional[str] = self.get_state(response, parse_response)
        while state not in desired_state:
            if state in failure_states:
                raise AirflowException(f'{object_type.title()} reached failure state {state}.')
            if countdown >= check_interval_seconds:
                countdown -= check_interval_seconds
                self.log.info('Waiting for %s to be %s.', object_type.lower(), action.lower())
                sleep(check_interval_seconds)
                state = self.get_state(get_state_callable(**get_state_args), parse_response)
            else:
                message = f'{object_type.title()} still not {action.lower()} after the allocated time limit.'
                self.log.error(message)
                raise RuntimeError(message)

    def get_state(self, response, keys) -> Optional[str]:
        """
        Follow ``keys`` through nested dictionaries of ``response`` and return the value found.

        :param response: API response dictionary
        :param keys: Sequence of keys to look up one level at a time
        :return: The nested value, or None if any key along the path is missing
        """
        value = response
        for key in keys:
            if value is not None:
                value = value.get(key, None)
        return value
+
+
class EmrContainerHook(AwsBaseHook):
"""
Interact with AWS EMR Virtual Cluster to run, poll jobs and return job status
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 4e8a1d96c9..1cac1eb0f5 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -19,15 +19,17 @@ import ast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from uuid import uuid4
-from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
-from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook
+from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import EmrClusterLink
+from airflow.providers.amazon.aws.sensors.emr import EmrServerlessApplicationSensor, EmrServerlessJobSensor
if TYPE_CHECKING:
from airflow.utils.context import Context
+from airflow.compat.functools import cached_property
+
class EmrAddStepsOperator(BaseOperator):
"""
@@ -412,3 +414,259 @@ class EmrTerminateJobFlowOperator(BaseOperator):
raise AirflowException(f'JobFlow termination failed: {response}')
else:
self.log.info('JobFlow with id %s terminated', self.job_flow_id)
+
+
class EmrServerlessCreateApplicationOperator(BaseOperator):
    """
    Operator to create Serverless EMR Application

    The operator always waits for the application to reach ``CREATED`` and then requests it to
    start; ``wait_for_completion`` only controls whether it also waits for ``STARTED``.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:EmrServerlessCreateApplicationOperator`

    :param release_label: The EMR release version associated with the application.
    :param job_type: The type of application you want to start, such as Spark or Hive.
    :param wait_for_completion: If true, wait for the Application to start before returning. Default to True
    :param client_request_token: The client idempotency token of the application to create.
      Its value must be unique for each request. A random UUID is generated when empty.
    :param config: Optional dictionary for arbitrary parameters to the boto API create_application call.
    :param aws_conn_id: AWS connection to use
    :return: The ID of the created application.
    """

    def __init__(
        self,
        release_label: str,
        job_type: str,
        client_request_token: str = '',
        config: Optional[dict] = None,
        wait_for_completion: bool = True,
        aws_conn_id: str = 'aws_default',
        **kwargs,
    ):
        self.aws_conn_id = aws_conn_id
        self.release_label = release_label
        self.job_type = job_type
        self.wait_for_completion = wait_for_completion
        self.kwargs = kwargs
        self.config = config or {}
        super().__init__(**kwargs)

        self.client_request_token = client_request_token or str(uuid4())

    @cached_property
    def hook(self) -> EmrServerlessHook:
        """Create and return an EmrServerlessHook."""
        return EmrServerlessHook(aws_conn_id=self.aws_conn_id)

    def execute(self, context: 'Context'):
        response = self.hook.conn.create_application(
            clientToken=self.client_request_token,
            releaseLabel=self.release_label,
            type=self.job_type,
            **self.config,
        )

        # Validate the response before reading fields from it, so that a failed call surfaces as
        # an AirflowException carrying the whole response instead of a KeyError.
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(f'Application Creation failed: {response}')

        application_id = response['applicationId']
        self.log.info('EMR serverless application created: %s', application_id)

        # This should be replaced with a boto waiter when available.
        self.hook.waiter(
            get_state_callable=self.hook.conn.get_application,
            get_state_args={'applicationId': application_id},
            parse_response=['application', 'state'],
            desired_state={'CREATED'},
            failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
            object_type='application',
            action='created',
        )

        self.log.info('Starting application %s', application_id)
        self.hook.conn.start_application(applicationId=application_id)

        if self.wait_for_completion:
            # This should be replaced with a boto waiter when available.
            self.hook.waiter(
                get_state_callable=self.hook.conn.get_application,
                get_state_args={'applicationId': application_id},
                parse_response=['application', 'state'],
                desired_state={'STARTED'},
                failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
                object_type='application',
                action='started',
            )

        return application_id
+
+
class EmrServerlessStartJobOperator(BaseOperator):
    """
    Operator to start EMR Serverless job.

    If the application is not in a ready state (``CREATED`` or ``STARTED``), it is started first
    and the operator waits for it to reach ``STARTED`` before submitting the job.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:EmrServerlessStartJobOperator`

    :param application_id: ID of the EMR Serverless application to run the job on.
    :param execution_role_arn: ARN of role to perform action.
    :param job_driver: Driver that the job runs on.
    :param configuration_overrides: Configuration specifications to override existing configurations.
      Omitted from the API call when None.
    :param client_request_token: The client idempotency token of the job run to start.
      Its value must be unique for each request. A random UUID is generated when empty.
    :param config: Optional dictionary for arbitrary parameters to the boto API start_job_run call.
    :param wait_for_completion: If true, wait for the job run to finish before returning and fail
      if it ends in a failure state. Defaults to True.
    :param aws_conn_id: AWS connection to use
    :return: The ID of the started job run.
    """

    template_fields: Sequence[str] = (
        'application_id',
        'execution_role_arn',
        'job_driver',
        'configuration_overrides',
    )

    def __init__(
        self,
        application_id: str,
        execution_role_arn: str,
        job_driver: dict,
        configuration_overrides: Optional[dict] = None,
        client_request_token: str = '',
        config: Optional[dict] = None,
        wait_for_completion: bool = True,
        aws_conn_id: str = 'aws_default',
        **kwargs,
    ):
        self.aws_conn_id = aws_conn_id
        self.application_id = application_id
        self.execution_role_arn = execution_role_arn
        self.job_driver = job_driver
        self.configuration_overrides = configuration_overrides
        self.wait_for_completion = wait_for_completion
        self.config = config or {}
        super().__init__(**kwargs)

        self.client_request_token = client_request_token or str(uuid4())

    @cached_property
    def hook(self) -> EmrServerlessHook:
        """Create and return an EmrServerlessHook."""
        return EmrServerlessHook(aws_conn_id=self.aws_conn_id)

    def execute(self, context: 'Context') -> str:
        self.log.info('Starting job on Application: %s', self.application_id)

        app_state = self.hook.conn.get_application(applicationId=self.application_id)['application']['state']
        if app_state not in EmrServerlessApplicationSensor.SUCCESS_STATES:
            self.hook.conn.start_application(applicationId=self.application_id)

            # Only wait for STARTED after actually requesting a start. An application that is already
            # CREATED is not started here, so an unconditional wait for STARTED would, unless something
            # else starts it, only end in the waiter's timeout.
            self.hook.waiter(
                get_state_callable=self.hook.conn.get_application,
                get_state_args={'applicationId': self.application_id},
                parse_response=['application', 'state'],
                desired_state={'STARTED'},
                failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
                object_type='application',
                action='started',
            )

        job_run_args: Dict[str, Any] = {
            'clientToken': self.client_request_token,
            'applicationId': self.application_id,
            'executionRoleArn': self.execution_role_arn,
            'jobDriver': self.job_driver,
        }
        # boto rejects None for structure parameters, so only send overrides when provided.
        if self.configuration_overrides is not None:
            job_run_args['configurationOverrides'] = self.configuration_overrides

        response = self.hook.conn.start_job_run(**job_run_args, **self.config)

        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(f'EMR serverless job failed to start: {response}')

        self.log.info('EMR serverless job started: %s', response['jobRunId'])
        if self.wait_for_completion:
            # Wait for SUCCESS only: the waiter checks desired states before failure states, so
            # passing TERMINAL_STATES (which contains FAILURE_STATES) would let a FAILED or
            # CANCELLED job end the wait as if it had succeeded.
            # This should be replaced with a boto waiter when available.
            self.hook.waiter(
                get_state_callable=self.hook.conn.get_job_run,
                get_state_args={
                    'applicationId': self.application_id,
                    'jobRunId': response['jobRunId'],
                },
                parse_response=['jobRun', 'state'],
                desired_state=EmrServerlessJobSensor.SUCCESS_STATES,
                failure_states=EmrServerlessJobSensor.FAILURE_STATES,
                object_type='job',
                action='run',
            )
        return response['jobRunId']
+
+
class EmrServerlessDeleteApplicationOperator(BaseOperator):
    """
    Operator to delete EMR Serverless application

    The application is stopped first (the operator waits for it to be ``STOPPED`` or
    ``TERMINATED``) and then deleted.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:EmrServerlessDeleteApplicationOperator`

    :param application_id: ID of the EMR Serverless application to delete.
    :param wait_for_completion: If true, wait for the Application to be deleted (reach the
      ``TERMINATED`` state) before returning. Default to True
    :param aws_conn_id: AWS connection to use
    """

    template_fields: Sequence[str] = ('application_id',)

    def __init__(
        self,
        application_id: str,
        wait_for_completion: bool = True,
        aws_conn_id: str = 'aws_default',
        **kwargs,
    ):
        self.aws_conn_id = aws_conn_id
        self.application_id = application_id
        self.wait_for_completion = wait_for_completion
        super().__init__(**kwargs)

    @cached_property
    def hook(self) -> EmrServerlessHook:
        """Create and return an EmrServerlessHook."""
        return EmrServerlessHook(aws_conn_id=self.aws_conn_id)

    def execute(self, context: 'Context') -> None:
        self.log.info('Stopping application: %s', self.application_id)
        self.hook.conn.stop_application(applicationId=self.application_id)

        # Wait until the application is STOPPED or TERMINATED (the sensor's FAILURE_STATES are the
        # desired states here); no state is treated as a failure during this wait.
        # This should be replaced with a boto waiter when available.
        self.hook.waiter(
            get_state_callable=self.hook.conn.get_application,
            get_state_args={
                'applicationId': self.application_id,
            },
            parse_response=['application', 'state'],
            desired_state=EmrServerlessApplicationSensor.FAILURE_STATES,
            failure_states=set(),
            object_type='application',
            action='stopped',
        )

        self.log.info('Deleting application: %s', self.application_id)
        response = self.hook.conn.delete_application(applicationId=self.application_id)

        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(f'Application deletion failed: {response}')

        if self.wait_for_completion:
            # TERMINATED is both desired and in FAILURE_STATES; the waiter checks desired states
            # first, so reaching it succeeds.
            # NOTE(review): FAILURE_STATES also contains STOPPED, so if get_application still reports
            # STOPPED right after deletion this wait fails immediately — confirm the API reports
            # TERMINATED synchronously.
            # This should be replaced with a boto waiter when available.
            self.hook.waiter(
                get_state_callable=self.hook.conn.get_application,
                get_state_args={'applicationId': self.application_id},
                parse_response=['application', 'state'],
                desired_state={'TERMINATED'},
                failure_states=EmrServerlessApplicationSensor.FAILURE_STATES,
                object_type='application',
                action='deleted',
            )

        self.log.info('EMR serverless application deleted')
diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py
index 62c74ea560..7c09e2fb89 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -15,15 +15,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Dict, FrozenSet, Iterable, Optional, Sequence, Set, Union
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
+from airflow.sensors.base import BaseSensorOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
from airflow.compat.functools import cached_property
-from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook
-from airflow.sensors.base import BaseSensorOperator
class EmrBaseSensor(BaseSensorOperator):
@@ -37,7 +38,7 @@ class EmrBaseSensor(BaseSensorOperator):
Subclasses should set ``target_states`` and ``failed_states`` fields.
- :param aws_conn_id: aws connection to uses
+ :param aws_conn_id: aws connection to use
"""
ui_color = '#66c3ff'
@@ -111,6 +112,137 @@ class EmrBaseSensor(BaseSensorOperator):
raise NotImplementedError('Please implement failure_message_from_response() in subclass')
class EmrServerlessJobSensor(BaseSensorOperator):
    """
    Poll an EMR Serverless job run until it reaches one of ``target_states``.

    The task fails as soon as the job run is observed in one of ``FAILURE_STATES``.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:EmrServerlessJobSensor`

    :param application_id: ID of the application the job run belongs to
    :param job_run_id: ID of the job run to watch
    :param target_states: set of states that end the wait successfully, defaults to {'SUCCESS'}
    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
    """

    INTERMEDIATE_STATES = {'PENDING', 'RUNNING', 'SCHEDULED', 'SUBMITTED'}
    FAILURE_STATES = {'FAILED', 'CANCELLING', 'CANCELLED'}
    SUCCESS_STATES = {'SUCCESS'}
    # Every state after which the job run will not change any more.
    TERMINAL_STATES = SUCCESS_STATES | FAILURE_STATES

    template_fields: Sequence[str] = ('application_id', 'job_run_id')

    def __init__(
        self,
        *,
        application_id: str,
        job_run_id: str,
        target_states: Union[Set, FrozenSet] = frozenset(SUCCESS_STATES),
        aws_conn_id: str = 'aws_default',
        **kwargs: Any,
    ) -> None:
        self.aws_conn_id = aws_conn_id
        self.target_states = target_states
        self.application_id = application_id
        self.job_run_id = job_run_id
        super().__init__(**kwargs)

    def poke(self, context: 'Context') -> bool:
        job_run_response = self.hook.conn.get_job_run(
            applicationId=self.application_id,
            jobRunId=self.job_run_id,
        )
        current_state = job_run_response['jobRun']['state']

        # Guard clause: anything that is not a failure just reports whether we are done.
        if current_state not in self.FAILURE_STATES:
            return current_state in self.target_states

        details = self.failure_message_from_response(job_run_response)
        raise AirflowException(f"EMR Serverless job failed: {details}")

    @cached_property
    def hook(self) -> EmrServerlessHook:
        """Create and return an EmrServerlessHook"""
        return EmrServerlessHook(aws_conn_id=self.aws_conn_id)

    @staticmethod
    def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]:
        """
        Extract the state details of a job run from a ``get_job_run`` response.

        :param response: response from AWS API
        :return: failure message
        :rtype: Optional[str]
        """
        job_run = response['jobRun']
        return job_run['stateDetails']
+
+
class EmrServerlessApplicationSensor(BaseSensorOperator):
    """
    Asks for the state of the application until it reaches a failure state or success state.
    If the application fails, the task will fail.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:EmrServerlessApplicationSensor`

    :param application_id: application_id to check the state of
    :param target_states: a set of states to wait for, defaults to {'CREATED', 'STARTED'}
    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
    """

    template_fields: Sequence[str] = ('application_id',)

    INTERMEDIATE_STATES = {'CREATING', 'STARTING', 'STOPPING'}
    FAILURE_STATES = {'STOPPED', 'TERMINATED'}
    SUCCESS_STATES = {'CREATED', 'STARTED'}

    def __init__(
        self,
        *,
        application_id: str,
        target_states: Union[Set, FrozenSet] = frozenset(SUCCESS_STATES),
        aws_conn_id: str = 'aws_default',
        **kwargs: Any,
    ) -> None:
        self.aws_conn_id = aws_conn_id
        self.target_states = target_states
        self.application_id = application_id
        super().__init__(**kwargs)

    def poke(self, context: 'Context') -> bool:
        response = self.hook.conn.get_application(applicationId=self.application_id)
        state = response['application']['state']

        if state in self.FAILURE_STATES:
            failure_message = (
                f"EMR Serverless application failed: {self.failure_message_from_response(response)}"
            )
            raise AirflowException(failure_message)

        return state in self.target_states

    @cached_property
    def hook(self) -> EmrServerlessHook:
        """Create and return an EmrServerlessHook"""
        return EmrServerlessHook(aws_conn_id=self.aws_conn_id)

    @staticmethod
    def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]:
        """
        Get failure message from response dictionary.

        ``stateDetails`` is read with ``.get`` so that a response without it still produces the
        intended AirflowException rather than a KeyError that hides the failing state.

        :param response: response from AWS API
        :return: failure message, or None if the response carries no state details
        :rtype: Optional[str]
        """
        return response['application'].get('stateDetails')
+
+
class EmrContainerSensor(BaseSensorOperator):
"""
Asks for the state of the job run until it reaches a failure state or success state.
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 3eb05f5b8c..3fa76b6397 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -117,6 +117,12 @@ integrations:
- /docs/apache-airflow-providers-amazon/operators/emr_eks.rst
logo: /integration-logos/aws/Amazon-EMR_light-bg@4x.png
tags: [aws]
+ - integration-name: Amazon EMR Serverless
+ external-doc-url: https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/emr-serverless.html
+ how-to-guide:
+ - /docs/apache-airflow-providers-amazon/operators/emr_serverless.rst
+ logo: /integration-logos/aws/Amazon-EMR_light-bg@4x.png
+ tags: [aws]
- integration-name: Amazon Glacier
external-doc-url: https://aws.amazon.com/glacier/
logo: /integration-logos/aws/Amazon-S3-Glacier_light-bg@4x.png
diff --git a/docs/apache-airflow-providers-amazon/operators/emr_serverless.rst b/docs/apache-airflow-providers-amazon/operators/emr_serverless.rst
new file mode 100644
index 0000000000..2496af2c40
--- /dev/null
+++ b/docs/apache-airflow-providers-amazon/operators/emr_serverless.rst
@@ -0,0 +1,113 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+===============================
+Amazon EMR Serverless Operators
+===============================
+
+`Amazon EMR Serverless <https://aws.amazon.com/emr/serverless/>`__ is a serverless option
+in Amazon EMR that makes it easy for data analysts and engineers to run open-source big
+data analytics frameworks without configuring, managing, and scaling clusters or servers.
+You get all the features and benefits of Amazon EMR without the need for experts to plan
+and manage clusters.
+
+Prerequisite Tasks
+------------------
+
+.. include:: _partials/prerequisite_tasks.rst
+
+Operators
+---------
+.. _howto/operator:EmrServerlessCreateApplicationOperator:
+
+Create an EMR Serverless Application
+====================================
+
+You can use :class:`~airflow.providers.amazon.aws.operators.emr.EmrServerlessCreateApplicationOperator` to
+create a new EMR Serverless Application.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_emr_serverless.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_emr_serverless_create_application]
+ :end-before: [END howto_operator_emr_serverless_create_application]
+
+.. _howto/operator:EmrServerlessStartJobOperator:
+
+Start an EMR Serverless Job
+============================
+
+You can use :class:`~airflow.providers.amazon.aws.operators.emr.EmrServerlessStartJobOperator` to
+start an EMR Serverless Job.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_emr_serverless.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_emr_serverless_start_job]
+ :end-before: [END howto_operator_emr_serverless_start_job]
+
+.. _howto/operator:EmrServerlessDeleteApplicationOperator:
+
+Delete an EMR Serverless Application
+====================================
+
+You can use :class:`~airflow.providers.amazon.aws.operators.emr.EmrServerlessDeleteApplicationOperator` to
+delete an EMR Serverless Application.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_emr_serverless.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_emr_serverless_delete_application]
+ :end-before: [END howto_operator_emr_serverless_delete_application]
+
+Sensors
+-------
+
+.. _howto/sensor:EmrServerlessJobSensor:
+
+Wait on an EMR Serverless Job state
+===================================
+
+To monitor the state of an EMR Serverless Job you can use
+:class:`~airflow.providers.amazon.aws.sensors.emr.EmrServerlessJobSensor`.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_emr_serverless.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_emr_serverless_job]
+ :end-before: [END howto_sensor_emr_serverless_job]
+
+.. _howto/sensor:EmrServerlessApplicationSensor:
+
+Wait on an EMR Serverless Application state
+============================================
+
+To monitor the state of an EMR Serverless Application you can use
+:class:`~airflow.providers.amazon.aws.sensors.emr.EmrServerlessApplicationSensor`.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_emr_serverless.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_emr_serverless_application]
+ :end-before: [END howto_sensor_emr_serverless_application]
+
+Reference
+---------
+
+* `AWS boto3 library documentation for EMR Serverless <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr-serverless.html>`__
+* `Configure IAM Roles for EMR Serverless permissions <https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/getting-started.html>`__
diff --git a/tests/providers/amazon/aws/hooks/test_emr_serverless.py b/tests/providers/amazon/aws/hooks/test_emr_serverless.py
new file mode 100644
index 0000000000..db17da63dc
--- /dev/null
+++ b/tests/providers/amazon/aws/hooks/test_emr_serverless.py
@@ -0,0 +1,136 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook
+
+# NOTE(review): none of these constants are referenced by ``TestEmrServerlessHook`` below;
+# they appear to be copied from the operator tests and could probably be removed.
+task_id = 'test_emr_serverless_create_application_operator'
+application_id = 'test_application_id'
+release_label = 'test'
+job_type = 'test'
+client_request_token = 'eac427d0-1c6d4df=-96aa-32423412'
+config = {'name': 'test_application_emr_serverless'}
+
+
+class TestEmrServerlessHook:
+ def test_conn_attribute(self):
+ hook = EmrServerlessHook(aws_conn_id='aws_default')
+ assert hasattr(hook, 'conn')
+ # Testing conn is a cached property
+ conn = hook.conn
+ conn2 = hook.conn
+ assert conn is conn2
+
+ def test_waiter_failure_then_success(self):
+ mock_call_function = mock.MagicMock()
+ mock_call_function.side_effect = [{'response': 'test_failure'}, {'response': 'test_success'}]
+ success_state = {'test_success'}
+ hook = EmrServerlessHook()
+ waiter_response = hook.waiter(
+ get_state_callable=mock_call_function,
+ get_state_args={},
+ parse_response=['response'],
+ desired_state=success_state,
+ failure_states={},
+ object_type='test_object',
+ action='testing',
+ check_interval_seconds=1,
+ )
+ assert mock_call_function.call_count == 2
+ assert waiter_response is None
+
+ def test_waiter_success_state(self):
+ mock_call_function = mock.MagicMock()
+ mock_call_function.return_value = {'response': 'test_success'}
+ success_state = {'test_success'}
+ hook = EmrServerlessHook()
+ waiter_response = hook.waiter(
+ get_state_callable=mock_call_function,
+ get_state_args={},
+ parse_response=['response'],
+ desired_state=success_state,
+ failure_states={},
+ object_type='test_object',
+ action='testing',
+ )
+ mock_call_function.assert_called_once()
+ assert waiter_response is None
+
+ def test_waiter_failure_state(self):
+ mock_call_function = mock.MagicMock()
+ failure_state = {'test_failure'}
+ mock_call_function.return_value = {'response': 'test_failure'}
+ hook = EmrServerlessHook()
+ with pytest.raises(AirflowException) as ex_message:
+ hook.waiter(
+ get_state_callable=mock_call_function,
+ get_state_args={},
+ parse_response=['response'],
+ desired_state={},
+ failure_states=failure_state,
+ object_type='test_object',
+ action='testing',
+ )
+ mock_call_function.assert_called_once()
+ assert str(ex_message.value) == f"Test_Object reached failure state {','.join(failure_state)}."
+
+ def test_nested_waiter_success_state(self):
+ mock_call_function = mock.MagicMock()
+ mock_call_function.return_value = {
+ 'layer1': {'key1': 'value1', 'layer2': {'response': 'test_success'}}
+ }
+ success_state = {'test_success'}
+ hook = EmrServerlessHook()
+ waiter_response = hook.waiter(
+ get_state_callable=mock_call_function,
+ get_state_args={},
+ parse_response=['layer1', 'layer2', 'response'],
+ desired_state=success_state,
+ failure_states={},
+ object_type='test_object',
+ action='testing',
+ )
+ mock_call_function.assert_called_once()
+ assert waiter_response is None
+
+ def test_waiter_timeout(self):
+ mock_call_function = mock.MagicMock()
+ success_state = {'test_success'}
+ mock_call_function.return_value = {'response': 'pending'}
+ hook = EmrServerlessHook()
+ with pytest.raises(RuntimeError) as ex_message:
+ hook.waiter(
+ get_state_callable=mock_call_function,
+ get_state_args={},
+ parse_response=['response'],
+ desired_state=success_state,
+ failure_states={},
+ object_type='test_object',
+ action='testing',
+ check_interval_seconds=1,
+ countdown=3,
+ )
+ assert mock_call_function.call_count == 4
+ assert (
+ str(ex_message.value)
+ == f'{"test_object".title()} still not {"testing".lower()} after the allocated time limit.'
+ )
diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py
new file mode 100644
index 0000000000..cdcea72949
--- /dev/null
+++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py
@@ -0,0 +1,409 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from unittest import mock
+from uuid import UUID
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.operators.emr import (
+ EmrServerlessCreateApplicationOperator,
+ EmrServerlessDeleteApplicationOperator,
+ EmrServerlessStartJobOperator,
+)
+
+# Shared fixture values for the operator tests below.
+task_id = 'test_emr_serverless_task_id'
+application_id = 'test_application_id'
+release_label = 'test'
+job_type = 'test'
+# Arbitrary string (not a valid UUID); the tests only check it is forwarded unchanged as ``clientToken``.
+client_request_token = 'eac427d0-1c6d-4dfb9a-32423412'
+config = {'name': 'test_application_emr_serverless'}
+
+# Inputs for EmrServerlessStartJobOperator and the job run id its mocked client returns.
+execution_role_arn = 'test_emr_serverless_role_arn'
+job_driver = {'test_key': 'test_value'}
+configuration_overrides = {'monitoringConfiguration': {'test_key': 'test_value'}}
+job_run_id = 'test_job_run_id'
+
+# Application id handed to EmrServerlessDeleteApplicationOperator.
+application_id_delete_operator = 'test_emr_serverless_delete_application_operator'
+
+
+class TestEmrServerlessCreateApplicationOperator:
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+ def test_execute_successfully_with_wait_for_completion(self, mock_conn, mock_waiter):
+ mock_waiter.return_value = True
+ mock_conn.create_application.return_value = {
+ "applicationId": application_id,
+ "ResponseMetadata": {"HTTPStatusCode": 200},
+ }
+
+ operator = EmrServerlessCreateApplicationOperator(
+ task_id=task_id,
+ release_label=release_label,
+ job_type=job_type,
+ client_request_token=client_request_token,
+ config=config,
+ )
+
+ id = operator.execute(None)
+
+ mock_conn.create_application.assert_called_once_with(
+ clientToken=client_request_token,
+ releaseLabel=release_label,
+ type=job_type,
+ **config,
+ )
+ mock_conn.start_application.assert_called_once_with(applicationId=application_id)
+
+ assert mock_waiter.call_count == 2
+ assert id == application_id
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+ def test_execute_successfully_no_wait_for_completion(self, mock_conn, mock_waiter):
+ mock_waiter.return_value = True
+ mock_conn.create_application.return_value = {
+ "applicationId": application_id,
+ "ResponseMetadata": {"HTTPStatusCode": 200},
+ }
+
+ operator = EmrServerlessCreateApplicationOperator(
+ task_id=task_id,
+ release_label=release_label,
+ job_type=job_type,
+ client_request_token=client_request_token,
+ wait_for_completion=False,
+ config=config,
+ )
+
+ id = operator.execute(None)
+
+ mock_conn.create_application.assert_called_once_with(
+ clientToken=client_request_token,
+ releaseLabel=release_label,
+ type=job_type,
+ **config,
+ )
+ mock_conn.start_application.assert_called_once_with(applicationId=application_id)
+
+ mock_waiter.assert_called_once()
+ assert id == application_id
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+ def test_failed_create_application(self, mock_conn, mock_waiter):
+ mock_waiter.return_value = True
+ mock_conn.create_application.return_value = {
+ "applicationId": application_id,
+ "ResponseMetadata": {"HTTPStatusCode": 404},
+ }
+
+ operator = EmrServerlessCreateApplicationOperator(
+ task_id=task_id,
+ release_label=release_label,
+ job_type=job_type,
+ client_request_token=client_request_token,
+ config=config,
+ )
+
+ with pytest.raises(AirflowException) as ex_message:
+ operator.execute(None)
+
+ assert "Application Creation failed:" in str(ex_message.value)
+
+ mock_conn.create_application.assert_called_once_with(
+ clientToken=client_request_token,
+ releaseLabel=release_label,
+ type=job_type,
+ **config,
+ )
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+ def test_no_client_request_token(self, mock_conn, mock_waiter):
+ mock_waiter.return_value = True
+ mock_conn.create_application.return_value = {
+ "applicationId": application_id,
+ "ResponseMetadata": {"HTTPStatusCode": 200},
+ }
+
+ operator = EmrServerlessCreateApplicationOperator(
+ task_id=task_id,
+ release_label=release_label,
+ job_type=job_type,
+ wait_for_completion=False,
+ config=config,
+ )
+
+ operator.execute(None)
+ generated_client_token = operator.client_request_token
+
+ assert str(UUID(generated_client_token, version=4)) == generated_client_token
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+ def test_application_in_failure_state(self, mock_conn):
+ fail_state = "STOPPED"
+ mock_conn.get_application.return_value = {"application": {"state": fail_state}}
+ mock_conn.create_application.return_value = {
+ "applicationId": application_id,
+ "ResponseMetadata": {"HTTPStatusCode": 200},
+ }
+
+ operator = EmrServerlessCreateApplicationOperator(
+ task_id=task_id,
+ release_label=release_label,
+ job_type=job_type,
+ client_request_token=client_request_token,
+ config=config,
+ )
+
+ with pytest.raises(AirflowException) as ex_message:
+ operator.execute(None)
+
+ assert str(ex_message.value) == f"Application reached failure state {fail_state}."
+
+ mock_conn.create_application.assert_called_once_with(
+ clientToken=client_request_token,
+ releaseLabel=release_label,
+ type=job_type,
+ **config,
+ )
+
+
+class TestEmrServerlessStartJobOperator:
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+ def test_job_run_app_started(self, mock_conn, mock_waiter):
+ mock_waiter.return_value = True
+ mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
+ mock_conn.start_job_run.return_value = {
+ 'jobRunId': job_run_id,
+ 'ResponseMetadata': {'HTTPStatusCode': 200},
+ }
+
+ operator = EmrServerlessStartJobOperator(
+ task_id=task_id,
+ client_request_token=client_request_token,
+ application_id=application_id,
+ execution_role_arn=execution_role_arn,
+ job_driver=job_driver,
+ configuration_overrides=configuration_overrides,
+ )
+
+ id = operator.execute(None)
+
+ assert operator.wait_for_completion is True
+ mock_conn.get_application.assert_called_once_with(applicationId=application_id)
+ mock_waiter.assert_called_once()
+ assert id == job_run_id
+ mock_conn.start_job_run.assert_called_once_with(
+ clientToken=client_request_token,
+ applicationId=application_id,
+ executionRoleArn=execution_role_arn,
+ jobDriver=job_driver,
+ configurationOverrides=configuration_overrides,
+ )
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+ def test_job_run_app_not_started(self, mock_conn, mock_waiter):
+ mock_waiter.return_value = True
+ mock_conn.get_application.return_value = {"application": {"state": "CREATING"}}
+ mock_conn.start_job_run.return_value = {
+ 'jobRunId': job_run_id,
+ 'ResponseMetadata': {'HTTPStatusCode': 200},
+ }
+
+ operator = EmrServerlessStartJobOperator(
+ task_id=task_id,
+ client_request_token=client_request_token,
+ application_id=application_id,
+ execution_role_arn=execution_role_arn,
+ job_driver=job_driver,
+ configuration_overrides=configuration_overrides,
+ )
+
+ id = operator.execute(None)
+
+ assert operator.wait_for_completion is True
+ mock_conn.get_application.assert_called_once_with(applicationId=application_id)
+ assert mock_waiter.call_count == 2
+ assert id == job_run_id
+ mock_conn.start_job_run.assert_called_once_with(
+ clientToken=client_request_token,
+ applicationId=application_id,
+ executionRoleArn=execution_role_arn,
+ jobDriver=job_driver,
+ configurationOverrides=configuration_overrides,
+ )
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+ def test_job_run_app_not_started_no_wait_for_completion(self, mock_conn, mock_waiter):
+ mock_waiter.return_value = True
+ mock_conn.get_application.return_value = {"application": {"state": "CREATING"}}
+ mock_conn.start_job_run.return_value = {
+ 'jobRunId': job_run_id,
+ 'ResponseMetadata': {'HTTPStatusCode': 200},
+ }
+
+ operator = EmrServerlessStartJobOperator(
+ task_id=task_id,
+ client_request_token=client_request_token,
+ application_id=application_id,
+ execution_role_arn=execution_role_arn,
+ job_driver=job_driver,
+ configuration_overrides=configuration_overrides,
+ wait_for_completion=False,
+ )
+
+ id = operator.execute(None)
+
+ mock_conn.get_application.assert_called_once_with(applicationId=application_id)
+ mock_waiter.assert_called_once()
+ assert id == job_run_id
+ mock_conn.start_job_run.assert_called_once_with(
+ clientToken=client_request_token,
+ applicationId=application_id,
+ executionRoleArn=execution_role_arn,
+ jobDriver=job_driver,
+ configurationOverrides=configuration_overrides,
+ )
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+ def test_job_run_app_started_no_wait_for_completion(self, mock_conn, mock_waiter):
+ mock_waiter.return_value = True
+ mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
+ mock_conn.start_job_run.return_value = {
+ 'jobRunId': job_run_id,
+ 'ResponseMetadata': {'HTTPStatusCode': 200},
+ }
+
+ operator = EmrServerlessStartJobOperator(
+ task_id=task_id,
+ client_request_token=client_request_token,
+ application_id=application_id,
+ execution_role_arn=execution_role_arn,
+ job_driver=job_driver,
+ configuration_overrides=configuration_overrides,
+ wait_for_completion=False,
+ )
+
+ id = operator.execute(None)
+ assert id == job_run_id
+ mock_conn.start_job_run.assert_called_once_with(
+ clientToken=client_request_token,
+ applicationId=application_id,
+ executionRoleArn=execution_role_arn,
+ jobDriver=job_driver,
+ configurationOverrides=configuration_overrides,
+ )
+ assert not mock_waiter.called
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+ def test_failed_start_job_run(self, mock_conn, mock_waiter):
+ mock_waiter.return_value = True
+ mock_conn.get_application.return_value = {"application": {"state": "CREATING"}}
+ mock_conn.start_job_run.return_value = {
+ 'jobRunId': job_run_id,
+ 'ResponseMetadata': {'HTTPStatusCode': 404},
+ }
+
+ operator = EmrServerlessStartJobOperator(
+ task_id=task_id,
+ client_request_token=client_request_token,
+ application_id=application_id,
+ execution_role_arn=execution_role_arn,
+ job_driver=job_driver,
+ configuration_overrides=configuration_overrides,
+ )
+ with pytest.raises(AirflowException) as ex_message:
+ operator.execute(None)
+
+ assert "EMR serverless job failed to start:" in str(ex_message.value)
+ mock_conn.get_application.assert_called_once_with(applicationId=application_id)
+ mock_waiter.assert_called_once()
+ mock_conn.start_job_run.assert_called_once_with(
+ clientToken=client_request_token,
+ applicationId=application_id,
+ executionRoleArn=execution_role_arn,
+ jobDriver=job_driver,
+ configurationOverrides=configuration_overrides,
+ )
+
+
+class TestEmrServerlessDeleteOperator:
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+ def test_delete_application_with_wait_for_completion_successfully(self, mock_conn, mock_waiter):
+ mock_waiter.return_value = True
+ mock_conn.stop_application.return_value = {}
+ mock_conn.delete_application.return_value = {'ResponseMetadata': {'HTTPStatusCode': 200}}
+
+ operator = EmrServerlessDeleteApplicationOperator(
+ task_id=task_id, application_id=application_id_delete_operator
+ )
+
+ operator.execute(None)
+
+ assert operator.wait_for_completion is True
+ assert mock_waiter.call_count == 2
+ mock_conn.stop_application.assert_called_once()
+ mock_conn.delete_application.assert_called_once_with(applicationId=application_id_delete_operator)
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+ def test_delete_application_without_wait_for_completion_successfully(self, mock_conn, mock_waiter):
+ mock_waiter.return_value = True
+ mock_conn.stop_application.return_value = {}
+ mock_conn.delete_application.return_value = {'ResponseMetadata': {'HTTPStatusCode': 200}}
+
+ operator = EmrServerlessDeleteApplicationOperator(
+ task_id=task_id,
+ application_id=application_id_delete_operator,
+ wait_for_completion=False,
+ )
+
+ operator.execute(None)
+
+ assert operator.wait_for_completion is False
+ mock_waiter.assert_called_once()
+ mock_conn.stop_application.assert_called_once()
+ mock_conn.delete_application.assert_called_once_with(applicationId=application_id_delete_operator)
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
+ @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+ def test_delete_application_failed_deleteion(self, mock_conn, mock_waiter):
+ mock_waiter.return_value = True
+ mock_conn.stop_application.return_value = {}
+ mock_conn.delete_application.return_value = {'ResponseMetadata': {'HTTPStatusCode': 400}}
+
+ operator = EmrServerlessDeleteApplicationOperator(
+ task_id=task_id, application_id=application_id_delete_operator
+ )
+ with pytest.raises(AirflowException) as ex_message:
+ operator.execute(None)
+
+ assert "Application deletion failed:" in str(ex_message.value)
+
+ assert operator.wait_for_completion is True
+ mock_waiter.assert_called_once()
+ mock_conn.stop_application.assert_called_once()
+ mock_conn.delete_application.assert_called_once_with(applicationId=application_id_delete_operator)