You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by on...@apache.org on 2023/06/20 21:20:47 UTC
[airflow] branch main updated: Add custom waiters to EMR Serverless (#30463)
This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 743bf5a0ae Add custom waiters to EMR Serverless (#30463)
743bf5a0ae is described below
commit 743bf5a0ae1279c96d018aad54dcce108f16dc96
Author: Syed Hussaain <10...@users.noreply.github.com>
AuthorDate: Tue Jun 20 14:20:39 2023 -0700
Add custom waiters to EMR Serverless (#30463)
* Move waiter logic to utils folder
---------
Co-authored-by: Raphaël Vandon <va...@amazon.com>
---
airflow/providers/amazon/aws/operators/emr.py | 279 ++++++++++------
.../amazon/aws/utils/waiter_with_logging.py | 90 +++++
.../amazon/aws/waiters/emr-serverless.json | 139 ++++++++
.../amazon/aws/operators/test_emr_serverless.py | 369 +++++++++++++++------
.../amazon/aws/utils/test_waiter_with_logging.py | 304 +++++++++++++++++
5 files changed, 976 insertions(+), 205 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index b8ca53226e..9fdad3b918 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -34,6 +34,7 @@ from airflow.providers.amazon.aws.triggers.emr import (
EmrTerminateJobFlowTrigger,
)
from airflow.providers.amazon.aws.utils.waiter import waiter
+from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
from airflow.utils.helpers import exactly_one, prune_dict
from airflow.utils.types import NOTSET, ArgNotSet
@@ -945,10 +946,13 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
Its value must be unique for each request.
:param config: Optional dictionary for arbitrary parameters to the boto API create_application call.
:param aws_conn_id: AWS connection to use
- :param waiter_countdown: Total amount of time, in seconds, the operator will wait for
+ :param waiter_countdown: (deprecated) Total amount of time, in seconds, the operator will wait for
the application to start. Defaults to 25 minutes.
- :param waiter_check_interval_seconds: Number of seconds between polling the state of the application.
- Defaults to 60 seconds.
+ :param waiter_check_interval_seconds: (deprecated) Number of seconds between polling the state
+ of the application. Defaults to 60 seconds.
+ :param waiter_max_attempts: Number of times the waiter should poll the application to check the state.
+ If not set, the waiter will use its default value.
+ :param waiter_delay: Number of seconds between polling the state of the application.
"""
def __init__(
@@ -959,18 +963,41 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
config: dict | None = None,
wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
- waiter_countdown: int = 25 * 60,
- waiter_check_interval_seconds: int = 60,
+ waiter_countdown: int | ArgNotSet = NOTSET,
+ waiter_check_interval_seconds: int | ArgNotSet = NOTSET,
+ waiter_max_attempts: int | ArgNotSet = NOTSET,
+ waiter_delay: int | ArgNotSet = NOTSET,
**kwargs,
):
+ if waiter_check_interval_seconds is NOTSET:
+ waiter_delay = 60 if waiter_delay is NOTSET else waiter_delay
+ else:
+ waiter_delay = waiter_check_interval_seconds if waiter_delay is NOTSET else waiter_delay
+ warnings.warn(
+ "The parameter waiter_check_interval_seconds has been deprecated to standardize "
+ "naming conventions. Please use waiter_delay instead. In the "
+ "future this will default to None and defer to the waiter's default value."
+ )
+ if waiter_countdown is NOTSET:
+ waiter_max_attempts = 25 if waiter_max_attempts is NOTSET else waiter_max_attempts
+ else:
+ if waiter_max_attempts is NOTSET:
+ # ignoring mypy because it doesn't like ArgNotSet as an operand, but neither variables
+ # are of type ArgNotSet at this point.
+ waiter_max_attempts = waiter_countdown // waiter_delay # type: ignore[operator]
+ warnings.warn(
+ "The parameter waiter_countdown has been deprecated to standardize "
+ "naming conventions. Please use waiter_max_attempts instead. In the "
+ "future this will default to None and defer to the waiter's default value."
+ )
self.aws_conn_id = aws_conn_id
self.release_label = release_label
self.job_type = job_type
self.wait_for_completion = wait_for_completion
self.kwargs = kwargs
self.config = config or {}
- self.waiter_countdown = waiter_countdown
- self.waiter_check_interval_seconds = waiter_check_interval_seconds
+ self.waiter_max_attempts = int(waiter_max_attempts) # type: ignore[arg-type]
+ self.waiter_delay = int(waiter_delay) # type: ignore[arg-type]
super().__init__(**kwargs)
self.client_request_token = client_request_token or str(uuid4())
@@ -993,37 +1020,31 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
raise AirflowException(f"Application Creation failed: {response}")
self.log.info("EMR serverless application created: %s", application_id)
+ waiter = self.hook.get_waiter("serverless_app_created")
- # This should be replaced with a boto waiter when available.
- waiter(
- get_state_callable=self.hook.conn.get_application,
- get_state_args={"applicationId": application_id},
- parse_response=["application", "state"],
- desired_state={"CREATED"},
- failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
- object_type="application",
- action="created",
- countdown=self.waiter_countdown,
- check_interval_seconds=self.waiter_check_interval_seconds,
+ wait(
+ waiter=waiter,
+ waiter_delay=self.waiter_delay,
+ max_attempts=self.waiter_max_attempts,
+ args={"applicationId": application_id},
+ failure_message="Serverless Application creation failed",
+ status_message="Serverless Application status is",
+ status_args=["application.state", "application.stateDetails"],
)
-
self.log.info("Starting application %s", application_id)
self.hook.conn.start_application(applicationId=application_id)
if self.wait_for_completion:
- # This should be replaced with a boto waiter when available.
- waiter(
- get_state_callable=self.hook.conn.get_application,
- get_state_args={"applicationId": application_id},
- parse_response=["application", "state"],
- desired_state={"STARTED"},
- failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
- object_type="application",
- action="started",
- countdown=self.waiter_countdown,
- check_interval_seconds=self.waiter_check_interval_seconds,
+ waiter = self.hook.get_waiter("serverless_app_started")
+ wait(
+ waiter=waiter,
+ max_attempts=self.waiter_max_attempts,
+ waiter_delay=self.waiter_delay,
+ args={"applicationId": application_id},
+ failure_message="Serverless Application failed to start",
+ status_message="Serverless Application status is",
+ status_args=["application.state", "application.stateDetails"],
)
-
return application_id
@@ -1047,10 +1068,13 @@ class EmrServerlessStartJobOperator(BaseOperator):
when waiting for the application be to in the ``STARTED`` state.
:param aws_conn_id: AWS connection to use.
:param name: Name for the EMR Serverless job. If not provided, a default name will be assigned.
- :param waiter_countdown: Total amount of time, in seconds, the operator will wait for
+ :param waiter_countdown: (deprecated) Total amount of time, in seconds, the operator will wait for
the job finish. Defaults to 25 minutes.
- :param waiter_check_interval_seconds: Number of seconds between polling the state of the job.
+ :param waiter_check_interval_seconds: (deprecated) Number of seconds between polling the state of the job.
Defaults to 60 seconds.
+ :param waiter_max_attempts: Number of times the waiter should poll the application to check the state.
+ If not set, the waiter will use its default value.
+ :param waiter_delay: Number of seconds between polling the state of the job run.
"""
template_fields: Sequence[str] = (
@@ -1077,10 +1101,33 @@ class EmrServerlessStartJobOperator(BaseOperator):
wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
name: str | None = None,
- waiter_countdown: int = 25 * 60,
- waiter_check_interval_seconds: int = 60,
+ waiter_countdown: int | ArgNotSet = NOTSET,
+ waiter_check_interval_seconds: int | ArgNotSet = NOTSET,
+ waiter_max_attempts: int | ArgNotSet = NOTSET,
+ waiter_delay: int | ArgNotSet = NOTSET,
**kwargs,
):
+ if waiter_check_interval_seconds is NOTSET:
+ waiter_delay = 60 if waiter_delay is NOTSET else waiter_delay
+ else:
+ waiter_delay = waiter_check_interval_seconds if waiter_delay is NOTSET else waiter_delay
+ warnings.warn(
+ "The parameter waiter_check_interval_seconds has been deprecated to standardize "
+ "naming conventions. Please use waiter_delay instead. In the "
+ "future this will default to None and defer to the waiter's default value."
+ )
+ if waiter_countdown is NOTSET:
+ waiter_max_attempts = 25 if waiter_max_attempts is NOTSET else waiter_max_attempts
+ else:
+ if waiter_max_attempts is NOTSET:
+ # ignoring mypy because it doesn't like ArgNotSet as an operand, but neither variables
+ # are of type ArgNotSet at this point.
+ waiter_max_attempts = waiter_countdown // waiter_delay # type: ignore[operator]
+ warnings.warn(
+ "The parameter waiter_countdown has been deprecated to standardize "
+ "naming conventions. Please use waiter_max_attempts instead. In the "
+ "future this will default to None and defer to the waiter's default value."
+ )
self.aws_conn_id = aws_conn_id
self.application_id = application_id
self.execution_role_arn = execution_role_arn
@@ -1089,8 +1136,8 @@ class EmrServerlessStartJobOperator(BaseOperator):
self.wait_for_completion = wait_for_completion
self.config = config or {}
self.name = name or self.config.pop("name", f"emr_serverless_job_airflow_{uuid4()}")
- self.waiter_countdown = waiter_countdown
- self.waiter_check_interval_seconds = waiter_check_interval_seconds
+ self.waiter_max_attempts = int(waiter_max_attempts) # type: ignore[arg-type]
+ self.waiter_delay = int(waiter_delay) # type: ignore[arg-type]
self.job_id: str | None = None
super().__init__(**kwargs)
@@ -1107,17 +1154,16 @@ class EmrServerlessStartJobOperator(BaseOperator):
app_state = self.hook.conn.get_application(applicationId=self.application_id)["application"]["state"]
if app_state not in EmrServerlessHook.APPLICATION_SUCCESS_STATES:
self.hook.conn.start_application(applicationId=self.application_id)
-
- waiter(
- get_state_callable=self.hook.conn.get_application,
- get_state_args={"applicationId": self.application_id},
- parse_response=["application", "state"],
- desired_state={"STARTED"},
- failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
- object_type="application",
- action="started",
- countdown=self.waiter_countdown,
- check_interval_seconds=self.waiter_check_interval_seconds,
+ waiter = self.hook.get_waiter("serverless_app_started")
+
+ wait(
+ waiter=waiter,
+ max_attempts=self.waiter_max_attempts,
+ waiter_delay=self.waiter_delay,
+ args={"applicationId": self.application_id},
+ failure_message="Serverless Application failed to start",
+ status_message="Serverless Application status is",
+ status_args=["application.state", "application.stateDetails"],
)
response = self.hook.conn.start_job_run(
@@ -1136,21 +1182,17 @@ class EmrServerlessStartJobOperator(BaseOperator):
self.job_id = response["jobRunId"]
self.log.info("EMR serverless job started: %s", self.job_id)
if self.wait_for_completion:
- # This should be replaced with a boto waiter when available.
- waiter(
- get_state_callable=self.hook.conn.get_job_run,
- get_state_args={
- "applicationId": self.application_id,
- "jobRunId": self.job_id,
- },
- parse_response=["jobRun", "state"],
- desired_state=EmrServerlessHook.JOB_SUCCESS_STATES,
- failure_states=EmrServerlessHook.JOB_FAILURE_STATES,
- object_type="job",
- action="run",
- countdown=self.waiter_countdown,
- check_interval_seconds=self.waiter_check_interval_seconds,
+ waiter = self.hook.get_waiter("serverless_job_completed")
+ wait(
+ waiter=waiter,
+ max_attempts=self.waiter_max_attempts,
+ waiter_delay=self.waiter_delay,
+ args={"applicationId": self.application_id, "jobRunId": self.job_id},
+ failure_message="Serverless Job failed",
+ status_message="Serverless Job status is",
+ status_args=["jobRun.state", "jobRun.stateDetails"],
)
+
return self.job_id
def on_kill(self) -> None:
@@ -1180,8 +1222,8 @@ class EmrServerlessStartJobOperator(BaseOperator):
failure_states=set(),
object_type="job",
action="cancelled",
- countdown=self.waiter_countdown,
- check_interval_seconds=self.waiter_check_interval_seconds,
+ countdown=self.waiter_delay * self.waiter_max_attempts,
+ check_interval_seconds=self.waiter_delay,
)
@@ -1213,16 +1255,39 @@ class EmrServerlessStopApplicationOperator(BaseOperator):
application_id: str,
wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
- waiter_countdown: int = 5 * 60,
- waiter_check_interval_seconds: int = 30,
+ waiter_countdown: int | ArgNotSet = NOTSET,
+ waiter_check_interval_seconds: int | ArgNotSet = NOTSET,
+ waiter_max_attempts: int | ArgNotSet = NOTSET,
+ waiter_delay: int | ArgNotSet = NOTSET,
force_stop: bool = False,
**kwargs,
):
+ if waiter_check_interval_seconds is NOTSET:
+ waiter_delay = 60 if waiter_delay is NOTSET else waiter_delay
+ else:
+ waiter_delay = waiter_check_interval_seconds if waiter_delay is NOTSET else waiter_delay
+ warnings.warn(
+ "The parameter waiter_check_interval_seconds has been deprecated to standardize "
+ "naming conventions. Please use waiter_delay instead. In the "
+ "future this will default to None and defer to the waiter's default value."
+ )
+ if waiter_countdown is NOTSET:
+ waiter_max_attempts = 25 if waiter_max_attempts is NOTSET else waiter_max_attempts
+ else:
+ if waiter_max_attempts is NOTSET:
+ # ignoring mypy because it doesn't like ArgNotSet as an operand, but neither variables
+ # are of type ArgNotSet at this point.
+ waiter_max_attempts = waiter_countdown // waiter_delay # type: ignore[operator]
+ warnings.warn(
+ "The parameter waiter_countdown has been deprecated to standardize "
+ "naming conventions. Please use waiter_max_attempts instead. In the "
+ "future this will default to None and defer to the waiter's default value."
+ )
self.aws_conn_id = aws_conn_id
self.application_id = application_id
self.wait_for_completion = wait_for_completion
- self.waiter_countdown = waiter_countdown
- self.waiter_check_interval_seconds = waiter_check_interval_seconds
+ self.waiter_max_attempts = int(waiter_max_attempts) # type: ignore[arg-type]
+ self.waiter_delay = int(waiter_delay) # type: ignore[arg-type]
self.force_stop = force_stop
super().__init__(**kwargs)
@@ -1238,27 +1303,23 @@ class EmrServerlessStopApplicationOperator(BaseOperator):
self.hook.cancel_running_jobs(
self.application_id,
waiter_config={
- "Delay": self.waiter_check_interval_seconds,
- "MaxAttempts": self.waiter_countdown / self.waiter_check_interval_seconds,
+ "Delay": self.waiter_delay,
+ "MaxAttempts": self.waiter_max_attempts,
},
)
self.hook.conn.stop_application(applicationId=self.application_id)
if self.wait_for_completion:
- # This should be replaced with a boto waiter when available.
- waiter(
- get_state_callable=self.hook.conn.get_application,
- get_state_args={
- "applicationId": self.application_id,
- },
- parse_response=["application", "state"],
- desired_state=EmrServerlessHook.APPLICATION_FAILURE_STATES,
- failure_states=set(),
- object_type="application",
- action="stopped",
- countdown=self.waiter_countdown,
- check_interval_seconds=self.waiter_check_interval_seconds,
+ waiter = self.hook.get_waiter("serverless_app_stopped")
+ wait(
+ waiter=waiter,
+ max_attempts=self.waiter_max_attempts,
+ waiter_delay=self.waiter_delay,
+ args={"applicationId": self.application_id},
+ failure_message="Error stopping application",
+ status_message="Serverless Application status is",
+ status_args=["application.state", "application.stateDetails"],
)
self.log.info("EMR serverless application %s stopped successfully", self.application_id)
@@ -1292,11 +1353,34 @@ class EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperato
application_id: str,
wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
- waiter_countdown: int = 25 * 60,
- waiter_check_interval_seconds: int = 60,
+ waiter_countdown: int | ArgNotSet = NOTSET,
+ waiter_check_interval_seconds: int | ArgNotSet = NOTSET,
+ waiter_max_attempts: int | ArgNotSet = NOTSET,
+ waiter_delay: int | ArgNotSet = NOTSET,
force_stop: bool = False,
**kwargs,
):
+ if waiter_check_interval_seconds is NOTSET:
+ waiter_delay = 60 if waiter_delay is NOTSET else waiter_delay
+ else:
+ waiter_delay = waiter_check_interval_seconds if waiter_delay is NOTSET else waiter_delay
+ warnings.warn(
+ "The parameter waiter_check_interval_seconds has been deprecated to standardize "
+ "naming conventions. Please use waiter_delay instead. In the "
+ "future this will default to None and defer to the waiter's default value."
+ )
+ if waiter_countdown is NOTSET:
+ waiter_max_attempts = 25 if waiter_max_attempts is NOTSET else waiter_max_attempts
+ else:
+ if waiter_max_attempts is NOTSET:
+ # ignoring mypy because it doesn't like ArgNotSet as an operand, but neither variables
+ # are of type ArgNotSet at this point.
+ waiter_max_attempts = waiter_countdown // waiter_delay # type: ignore[operator]
+ warnings.warn(
+ "The parameter waiter_countdown has been deprecated to standardize "
+ "naming conventions. Please use waiter_max_attempts instead. In the "
+ "future this will default to None and defer to the waiter's default value."
+ )
self.wait_for_delete_completion = wait_for_completion
# super stops the app
super().__init__(
@@ -1304,8 +1388,8 @@ class EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperato
# when deleting an app, we always need to wait for it to stop before we can call delete()
wait_for_completion=True,
aws_conn_id=aws_conn_id,
- waiter_countdown=waiter_countdown,
- waiter_check_interval_seconds=waiter_check_interval_seconds,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
force_stop=force_stop,
**kwargs,
)
@@ -1321,17 +1405,16 @@ class EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperato
raise AirflowException(f"Application deletion failed: {response}")
if self.wait_for_delete_completion:
- # This should be replaced with a boto waiter when available.
- waiter(
- get_state_callable=self.hook.conn.get_application,
- get_state_args={"applicationId": self.application_id},
- parse_response=["application", "state"],
- desired_state={"TERMINATED"},
- failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
- object_type="application",
- action="deleted",
- countdown=self.waiter_countdown,
- check_interval_seconds=self.waiter_check_interval_seconds,
+ waiter = self.hook.get_waiter("serverless_app_terminated")
+
+ wait(
+ waiter=waiter,
+ max_attempts=self.waiter_max_attempts,
+ waiter_delay=self.waiter_delay,
+ args={"applicationId": self.application_id},
+ failure_message="Error terminating application",
+ status_message="Serverless Application status is",
+ status_args=["application.state", "application.stateDetails"],
)
self.log.info("EMR serverless application deleted")
diff --git a/airflow/providers/amazon/aws/utils/waiter_with_logging.py b/airflow/providers/amazon/aws/utils/waiter_with_logging.py
new file mode 100644
index 0000000000..8c9e33077f
--- /dev/null
+++ b/airflow/providers/amazon/aws/utils/waiter_with_logging.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import logging
+import time
+
+import jmespath
+from botocore.exceptions import WaiterError
+from botocore.waiter import Waiter
+
+from airflow.exceptions import AirflowException
+
+
+def wait(
+    waiter: Waiter,
+    waiter_delay: int,
+    max_attempts: int,
+    args: dict,
+    failure_message: str,
+    status_message: str,
+    status_args: list,
+) -> None:
+    """
+    Use a boto waiter to poll an AWS service for the specified state.
+
+    Although this function uses boto waiters to poll the state of the service, it logs
+    the response of the service after every attempt, which is not currently supported
+    by boto waiters.
+
+    :param waiter: The boto waiter to use.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param max_attempts: The maximum number of attempts to be made.
+    :param args: The arguments to pass to the waiter.
+    :param failure_message: Prefix of the exception message raised if a failure state is reached.
+    :param status_message: The message logged when printing the status of the service.
+    :param status_args: A list containing the arguments to retrieve status information from
+        the waiter response.
+        e.g.
+        response = {"Cluster": {"state": "CREATING"}}
+        status_args = ["Cluster.state"]
+
+        response = {
+            "Clusters": [{"state": "CREATING", "details": "User initiated."},]
+        }
+        status_args = ["Clusters[0].state", "Clusters[0].details"]
+    :raises AirflowException: If the waiter reports a terminal failure, or if ``max_attempts``
+        polls complete without the waiter succeeding.
+    """
+    log = logging.getLogger(__name__)
+    attempt = 0
+    while True:
+        attempt += 1
+        try:
+            # Poll exactly once per iteration so the status can be logged between polls.
+            waiter.wait(**args, WaiterConfig={"MaxAttempts": 1})
+            break
+        except WaiterError as error:
+            if "terminal failure" in str(error):
+                # Keep the WaiterError as the cause so the original details are not lost.
+                raise AirflowException(f"{failure_message}: {error}") from error
+            status_string = _format_status_string(status_args, error.last_response)
+            log.info("%s: %s", status_message, status_string)
+
+            # Check the attempt budget before sleeping: once the last attempt has failed
+            # there is nothing left to wait for, so do not burn one more delay.
+            if attempt >= max_attempts:
+                raise AirflowException("Waiter error: max attempts reached") from error
+            time.sleep(waiter_delay)
+
+
+def _format_status_string(args, response):
+    """
+    Build a status string from values found in the waiter response.
+
+    Each entry of ``args`` is evaluated as a JMESPath expression against ``response``; results
+    that are ``None`` or an empty string are dropped, and the remaining values are joined
+    with " - ".
+    """
+    looked_up = (jmespath.search(expression, response) for expression in args)
+    return " - ".join(str(found) for found in looked_up if found is not None and found != "")
diff --git a/airflow/providers/amazon/aws/waiters/emr-serverless.json b/airflow/providers/amazon/aws/waiters/emr-serverless.json
index a77d07f243..4066109382 100644
--- a/airflow/providers/amazon/aws/waiters/emr-serverless.json
+++ b/airflow/providers/amazon/aws/waiters/emr-serverless.json
@@ -13,6 +13,145 @@
"state": "success"
}
]
+ },
+ "serverless_app_created": {
+ "operation": "GetApplication",
+ "delay": 60,
+ "maxAttempts": 1500,
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "application.state",
+ "expected": "CREATED",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "application.state",
+ "expected": "TERMINATED",
+ "state": "failure"
+ }
+ ]
+ },
+ "serverless_app_started": {
+ "operation": "GetApplication",
+ "delay": 60,
+ "maxAttempts": 1500,
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "application.state",
+ "expected": "STARTED",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "application.state",
+ "expected": "TERMINATED",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "application.state",
+ "expected": "STOPPED",
+ "state": "failure"
+ }
+ ]
+ },
+ "serverless_app_stopped": {
+ "operation": "GetApplication",
+ "delay": 60,
+ "maxAttempts": 1500,
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "application.state",
+ "expected": "STOPPED",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "application.state",
+ "expected": "TERMINATED",
+ "state": "failure"
+ }
+ ]
+ },
+ "serverless_app_terminated": {
+ "operation": "GetApplication",
+ "delay": 60,
+ "maxAttempts": 1500,
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "application.state",
+ "expected": "TERMINATED",
+ "state": "success"
+ }
+ ]
+ },
+ "serverless_job_completed": {
+ "operation": "GetJobRun",
+ "delay": 60,
+ "maxAttempts": 1500,
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "jobRun.state",
+ "expected": "SUCCESS",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "jobRun.state",
+ "expected": "FAILED",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "jobRun.state",
+ "expected": "CANCELLED",
+ "state": "failure"
+ }
+ ]
+ },
+ "serverless_job_running": {
+ "operation": "GetJobRun",
+ "delay": 60,
+ "maxAttempts": 1500,
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "jobRun.state",
+ "expected": "RUNNING",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "jobRun.state",
+ "expected": "FAILED",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "jobRun.state",
+ "expected": "CANCELLED",
+ "state": "failure"
+ }
+ ]
+ },
+ "serverless_app_deleted": {
+ "operation": "GetApplication",
+ "delay": 60,
+ "maxAttempts": 1500,
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "application.state",
+ "expected": "TERMINATED",
+ "state": "success"
+ }
+ ]
}
}
}
diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py
index 6889a374ce..8cb4eb1707 100644
--- a/tests/providers/amazon/aws/operators/test_emr_serverless.py
+++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py
@@ -21,14 +21,17 @@ from unittest.mock import MagicMock, PropertyMock
from uuid import UUID
import pytest
+from botocore.exceptions import WaiterError
from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook
from airflow.providers.amazon.aws.operators.emr import (
EmrServerlessCreateApplicationOperator,
EmrServerlessDeleteApplicationOperator,
EmrServerlessStartJobOperator,
EmrServerlessStopApplicationOperator,
)
+from airflow.utils.types import NOTSET
task_id = "test_emr_serverless_task_id"
application_id = "test_application_id"
@@ -46,8 +49,10 @@ application_id_delete_operator = "test_emr_serverless_delete_application_operato
class TestEmrServerlessCreateApplicationOperator:
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_execute_successfully_with_wait_for_completion(self, mock_conn):
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_execute_successfully_with_wait_for_completion(self, mock_conn, mock_waiter):
+ mock_waiter().wait.return_value = True
mock_conn.create_application.return_value = {
"applicationId": application_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
@@ -63,6 +68,8 @@ class TestEmrServerlessCreateApplicationOperator:
job_type=job_type,
client_request_token=client_request_token,
config=config,
+ waiter_max_attempts=3,
+ waiter_delay=0,
)
id = operator.execute(None)
@@ -73,15 +80,22 @@ class TestEmrServerlessCreateApplicationOperator:
type=job_type,
**config,
)
+ mock_waiter().wait.assert_called_with(
+ applicationId=application_id,
+ WaiterConfig={
+ "MaxAttempts": 1,
+ },
+ )
+ assert mock_waiter().wait.call_count == 2
+
mock_conn.start_application.assert_called_once_with(applicationId=application_id)
assert id == application_id
mock_conn.get_application.call_count == 2
- # @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
- @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
def test_execute_successfully_no_wait_for_completion(self, mock_conn, mock_waiter):
- mock_waiter.return_value = True
+ mock_waiter().wait.return_value = True
mock_conn.create_application.return_value = {
"applicationId": application_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
@@ -106,13 +120,11 @@ class TestEmrServerlessCreateApplicationOperator:
)
mock_conn.start_application.assert_called_once_with(applicationId=application_id)
- mock_waiter.assert_called_once()
+ mock_waiter().wait.assert_called_once()
assert id == application_id
- @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_failed_create_application_request(self, mock_conn, mock_waiter):
- mock_waiter.return_value = True
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_failed_create_application_request(self, mock_conn):
mock_conn.create_application.return_value = {
"applicationId": application_id,
"ResponseMetadata": {"HTTPStatusCode": 404},
@@ -138,13 +150,19 @@ class TestEmrServerlessCreateApplicationOperator:
**config,
)
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_failed_create_application(self, mock_conn):
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_failed_create_application(self, mock_conn, mock_get_waiter):
+ error = WaiterError(
+ name="test_name",
+ reason="Waiter encountered a terminal failure state:",
+ last_response={"application": {"state": "FAILED"}},
+ )
+ mock_get_waiter().wait.side_effect = error
mock_conn.create_application.return_value = {
"applicationId": application_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
}
- mock_conn.get_application.return_value = {"application": {"state": "TERMINATED"}}
operator = EmrServerlessCreateApplicationOperator(
task_id=task_id,
@@ -157,7 +175,7 @@ class TestEmrServerlessCreateApplicationOperator:
with pytest.raises(AirflowException) as ex_message:
operator.execute(None)
- assert "Application reached failure state" in str(ex_message.value)
+ assert "Serverless Application creation failed:" in str(ex_message.value)
mock_conn.create_application.assert_called_once_with(
clientToken=client_request_token,
@@ -165,18 +183,51 @@ class TestEmrServerlessCreateApplicationOperator:
type=job_type,
**config,
)
- mock_conn.get_application.assert_called_once_with(applicationId=application_id)
+ mock_conn.create_application.return_value = {
+ "applicationId": application_id,
+ "ResponseMetadata": {"HTTPStatusCode": 200},
+ }
+ error = WaiterError(
+ name="test_name",
+ reason="Waiter encountered a terminal failure state:",
+ last_response={"application": {"state": "TERMINATED"}},
+ )
+ mock_get_waiter().wait.side_effect = error
+
+ operator = EmrServerlessCreateApplicationOperator(
+ task_id=task_id,
+ release_label=release_label,
+ job_type=job_type,
+ client_request_token=client_request_token,
+ config=config,
+ )
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_failed_start_application(self, mock_conn):
+ with pytest.raises(AirflowException) as ex_message:
+ operator.execute(None)
+
+ assert "Serverless Application creation failed:" in str(ex_message.value)
+
+ mock_conn.create_application.assert_called_with(
+ clientToken=client_request_token,
+ releaseLabel=release_label,
+ type=job_type,
+ **config,
+ )
+ mock_conn.create_application.call_count == 2
+
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_failed_start_application(self, mock_conn, mock_get_waiter):
+ error = WaiterError(
+ name="test_name",
+ reason="Waiter encountered a terminal failure state:",
+ last_response={"application": {"state": "TERMINATED"}},
+ )
+ mock_get_waiter().wait.side_effect = [True, error]
mock_conn.create_application.return_value = {
"applicationId": application_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
}
- mock_conn.get_application.side_effect = [
- {"application": {"state": "CREATED"}},
- {"application": {"state": "TERMINATED"}},
- ]
operator = EmrServerlessCreateApplicationOperator(
task_id=task_id,
@@ -189,7 +240,7 @@ class TestEmrServerlessCreateApplicationOperator:
with pytest.raises(AirflowException) as ex_message:
operator.execute(None)
- assert "Application reached failure state" in str(ex_message.value)
+ assert "Serverless Application failed to start:" in str(ex_message.value)
mock_conn.create_application.assert_called_once_with(
clientToken=client_request_token,
@@ -197,12 +248,11 @@ class TestEmrServerlessCreateApplicationOperator:
type=job_type,
**config,
)
- mock_conn.get_application.call_count == 2
- @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
def test_no_client_request_token(self, mock_conn, mock_waiter):
- mock_waiter.return_value = True
+ mock_waiter().wait.return_value = True
mock_conn.create_application.return_value = {
"applicationId": application_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
@@ -221,10 +271,16 @@ class TestEmrServerlessCreateApplicationOperator:
assert str(UUID(generated_client_token, version=4)) == generated_client_token
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_application_in_failure_state(self, mock_conn):
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_application_in_failure_state(self, mock_conn, mock_get_waiter):
fail_state = "STOPPED"
- mock_conn.get_application.return_value = {"application": {"state": fail_state}}
+ error = WaiterError(
+ name="test_name",
+ reason="Waiter encountered a terminal failure state:",
+ last_response={"application": {"state": fail_state}},
+ )
+ mock_get_waiter().wait.side_effect = [error]
mock_conn.create_application.return_value = {
"applicationId": application_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
@@ -241,7 +297,7 @@ class TestEmrServerlessCreateApplicationOperator:
with pytest.raises(AirflowException) as ex_message:
operator.execute(None)
- assert str(ex_message.value) == f"Application reached failure state {fail_state}."
+ assert str(ex_message.value) == f"Serverless Application creation failed: {error}"
mock_conn.create_application.assert_called_once_with(
clientToken=client_request_token,
@@ -250,10 +306,39 @@ class TestEmrServerlessCreateApplicationOperator:
**config,
)
+ @pytest.mark.parametrize(
+ "waiter_delay, waiter_max_attempts, waiter_countdown, waiter_check_interval_seconds, expected",
+ [
+ (NOTSET, NOTSET, NOTSET, NOTSET, [60, 25]),
+ (30, 10, NOTSET, NOTSET, [30, 10]),
+ (NOTSET, NOTSET, 30 * 15, 15, [15, 30]),
+ (10, 20, 30, 40, [10, 20]),
+ ],
+ )
+ def test_create_application_waiter_params(
+ self, waiter_delay, waiter_max_attempts, waiter_countdown, waiter_check_interval_seconds, expected
+ ):
+ operator = EmrServerlessCreateApplicationOperator(
+ task_id=task_id,
+ release_label=release_label,
+ job_type=job_type,
+ client_request_token=client_request_token,
+ config=config,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ waiter_countdown=waiter_countdown,
+ waiter_check_interval_seconds=waiter_check_interval_seconds,
+ )
+ assert operator.wait_for_completion is True
+ assert operator.waiter_delay == expected[0]
+ assert operator.waiter_max_attempts == expected[1]
+
class TestEmrServerlessStartJobOperator:
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_job_run_app_started(self, mock_conn):
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_job_run_app_started(self, mock_conn, mock_get_waiter):
+ mock_get_waiter().wait.return_value = True
mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
mock_conn.start_job_run.return_value = {
"jobRunId": job_run_id,
@@ -283,18 +368,22 @@ class TestEmrServerlessStartJobOperator:
configurationOverrides=configuration_overrides,
name=default_name,
)
- mock_conn.get_job_run.assert_called_once_with(applicationId=application_id, jobRunId=job_run_id)
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_job_run_job_failed(self, mock_conn):
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_job_run_job_failed(self, mock_conn, mock_get_waiter):
+ error = WaiterError(
+ name="test_name",
+ reason="Waiter encountered a terminal failure state:",
+ last_response={"jobRun": {"state": "FAILED"}},
+ )
+ mock_get_waiter().wait.side_effect = [error]
mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
mock_conn.start_job_run.return_value = {
"jobRunId": job_run_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
}
- mock_conn.get_job_run.return_value = {"jobRun": {"state": "FAILED"}}
-
operator = EmrServerlessStartJobOperator(
task_id=task_id,
client_request_token=client_request_token,
@@ -307,9 +396,8 @@ class TestEmrServerlessStartJobOperator:
with pytest.raises(AirflowException) as ex_message:
id = operator.execute(None)
assert id == job_run_id
- assert "Job reached failure state FAILED." in str(ex_message.value)
+ assert "Serverless Job failed:" in str(ex_message.value)
mock_conn.get_application.assert_called_once_with(applicationId=application_id)
- mock_conn.get_job_run.assert_called_once_with(applicationId=application_id, jobRunId=job_run_id)
mock_conn.start_job_run.assert_called_once_with(
clientToken=client_request_token,
applicationId=application_id,
@@ -319,10 +407,10 @@ class TestEmrServerlessStartJobOperator:
name=default_name,
)
- @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_job_run_app_not_started(self, mock_conn, mock_waiter):
- mock_waiter.return_value = True
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_job_run_app_not_started(self, mock_conn, mock_get_waiter):
+ mock_get_waiter().wait.return_value = True
mock_conn.get_application.return_value = {"application": {"state": "CREATING"}}
mock_conn.start_job_run.return_value = {
"jobRunId": job_run_id,
@@ -343,7 +431,7 @@ class TestEmrServerlessStartJobOperator:
assert operator.wait_for_completion is True
mock_conn.get_application.assert_called_once_with(applicationId=application_id)
- assert mock_waiter.call_count == 2
+ assert mock_get_waiter().wait.call_count == 2
assert id == job_run_id
mock_conn.start_job_run.assert_called_once_with(
clientToken=client_request_token,
@@ -354,12 +442,21 @@ class TestEmrServerlessStartJobOperator:
name=default_name,
)
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_job_run_app_not_started_app_failed(self, mock_conn):
- mock_conn.get_application.side_effect = [
- {"application": {"state": "CREATING"}},
- {"application": {"state": "TERMINATED"}},
- ]
+ @mock.patch("time.sleep", return_value=True)
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_job_run_app_not_started_app_failed(self, mock_conn, mock_get_waiter, mock_time):
+ error1 = WaiterError(
+ name="test_name",
+ reason="test-reason",
+ last_response={"application": {"state": "CREATING", "stateDetails": "test-details"}},
+ )
+ error2 = WaiterError(
+ name="test_name",
+ reason="Waiter encountered a terminal failure state:",
+ last_response={"application": {"state": "TERMINATED", "stateDetails": "test-details"}},
+ )
+ mock_get_waiter().wait.side_effect = [error1, error2]
mock_conn.start_job_run.return_value = {
"jobRunId": job_run_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
@@ -375,15 +472,14 @@ class TestEmrServerlessStartJobOperator:
)
with pytest.raises(AirflowException) as ex_message:
operator.execute(None)
- assert "Application reached failure state" in str(ex_message.value)
+ assert "Serverless Application failed to start:" in str(ex_message.value)
assert operator.wait_for_completion is True
- mock_conn.get_application.call_count == 2
- mock_conn.assert_not_called()
+ assert mock_get_waiter().wait.call_count == 2
- @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_job_run_app_not_started_no_wait_for_completion(self, mock_conn, mock_waiter):
- mock_waiter.return_value = True
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_job_run_app_not_started_no_wait_for_completion(self, mock_conn, mock_get_waiter):
+ mock_get_waiter().wait.return_value = True
mock_conn.get_application.return_value = {"application": {"state": "CREATING"}}
mock_conn.start_job_run.return_value = {
"jobRunId": job_run_id,
@@ -403,7 +499,7 @@ class TestEmrServerlessStartJobOperator:
id = operator.execute(None)
mock_conn.get_application.assert_called_once_with(applicationId=application_id)
- mock_waiter.assert_called_once()
+ mock_get_waiter().wait.assert_called_once()
assert id == job_run_id
mock_conn.start_job_run.assert_called_once_with(
clientToken=client_request_token,
@@ -414,10 +510,10 @@ class TestEmrServerlessStartJobOperator:
name=default_name,
)
- @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_job_run_app_started_no_wait_for_completion(self, mock_conn, mock_waiter):
- mock_waiter.return_value = True
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_job_run_app_started_no_wait_for_completion(self, mock_conn, mock_get_waiter):
+ mock_get_waiter().wait.return_value = True
mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
mock_conn.start_job_run.return_value = {
"jobRunId": job_run_id,
@@ -444,12 +540,12 @@ class TestEmrServerlessStartJobOperator:
configurationOverrides=configuration_overrides,
name=default_name,
)
- assert not mock_waiter.called
+ assert not mock_get_waiter().wait.called
- @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_failed_start_job_run(self, mock_conn, mock_waiter):
- mock_waiter.return_value = True
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_failed_start_job_run(self, mock_conn, mock_get_waiter):
+ mock_get_waiter().wait.return_value = True
mock_conn.get_application.return_value = {"application": {"state": "CREATING"}}
mock_conn.start_job_run.return_value = {
"jobRunId": job_run_id,
@@ -470,7 +566,7 @@ class TestEmrServerlessStartJobOperator:
assert "EMR serverless job failed to start:" in str(ex_message.value)
mock_conn.get_application.assert_called_once_with(applicationId=application_id)
- mock_waiter.assert_called_once()
+ mock_get_waiter().wait.assert_called_once()
mock_conn.start_job_run.assert_called_once_with(
clientToken=client_request_token,
applicationId=application_id,
@@ -480,15 +576,20 @@ class TestEmrServerlessStartJobOperator:
name=default_name,
)
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_start_job_run_fail_on_wait_for_completion(self, mock_conn):
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_start_job_run_fail_on_wait_for_completion(self, mock_conn, mock_get_waiter):
+ error = WaiterError(
+ name="mock_waiter_error",
+ reason="Waiter encountered a terminal failure state:",
+ last_response={"jobRun": {"state": "FAILED", "stateDetails": "Test Details"}},
+ )
+ mock_get_waiter().wait.side_effect = [error]
mock_conn.get_application.return_value = {"application": {"state": "CREATED"}}
mock_conn.start_job_run.return_value = {
"jobRunId": job_run_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
}
- mock_conn.get_job_run.return_value = {"jobRun": {"state": "FAILED"}}
-
operator = EmrServerlessStartJobOperator(
task_id=task_id,
client_request_token=client_request_token,
@@ -501,7 +602,7 @@ class TestEmrServerlessStartJobOperator:
with pytest.raises(AirflowException) as ex_message:
operator.execute(None)
- assert "Job reached failure state" in str(ex_message.value)
+ assert "Serverless Job failed:" in str(ex_message.value)
mock_conn.get_application.call_count == 2
mock_conn.start_job_run.assert_called_once_with(
clientToken=client_request_token,
@@ -511,15 +612,17 @@ class TestEmrServerlessStartJobOperator:
configurationOverrides=configuration_overrides,
name=default_name,
)
+ mock_get_waiter().wait.assert_called_once()
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_start_job_default_name(self, mock_conn):
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_start_job_default_name(self, mock_conn, mock_get_waiter):
mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
mock_conn.start_job_run.return_value = {
"jobRunId": job_run_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
}
- mock_conn.get_job_run.return_value = {"jobRun": {"state": "SUCCESS"}}
+ mock_get_waiter().wait.return_value = True
operator = EmrServerlessStartJobOperator(
task_id=task_id,
@@ -543,15 +646,16 @@ class TestEmrServerlessStartJobOperator:
name=f"emr_serverless_job_airflow_{str(UUID(generated_name_uuid, version=4))}",
)
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_start_job_custom_name(self, mock_conn):
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_start_job_custom_name(self, mock_conn, mock_get_waiter):
+ mock_get_waiter().wait.return_value = True
mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
custom_name = "test_name"
mock_conn.start_job_run.return_value = {
"jobRunId": job_run_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
}
- mock_conn.get_job_run.return_value = {"jobRun": {"state": "SUCCESS"}}
operator = EmrServerlessStartJobOperator(
task_id=task_id,
@@ -573,7 +677,7 @@ class TestEmrServerlessStartJobOperator:
name=custom_name,
)
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+ @mock.patch.object(EmrServerlessHook, "conn")
def test_cancel_job_run(self, mock_conn):
mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
mock_conn.start_job_run.return_value = {
@@ -599,12 +703,39 @@ class TestEmrServerlessStartJobOperator:
jobRunId=id,
)
+ @pytest.mark.parametrize(
+ "waiter_delay, waiter_max_attempts, waiter_countdown, waiter_check_interval_seconds, expected",
+ [
+ (NOTSET, NOTSET, NOTSET, NOTSET, [60, 25]),
+ (30, 10, NOTSET, NOTSET, [30, 10]),
+ (NOTSET, NOTSET, 30 * 15, 15, [15, 30]),
+ (10, 20, 30, 40, [10, 20]),
+ ],
+ )
+ def test_start_job_waiter_params(
+ self, waiter_delay, waiter_max_attempts, waiter_countdown, waiter_check_interval_seconds, expected
+ ):
+ operator = EmrServerlessStartJobOperator(
+ task_id=task_id,
+ application_id=application_id,
+ execution_role_arn=execution_role_arn,
+ job_driver=job_driver,
+ configuration_overrides=configuration_overrides,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ waiter_countdown=waiter_countdown,
+ waiter_check_interval_seconds=waiter_check_interval_seconds,
+ )
+ assert operator.wait_for_completion is True
+ assert operator.waiter_delay == expected[0]
+ assert operator.waiter_max_attempts == expected[1]
+
class TestEmrServerlessDeleteOperator:
- @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_delete_application_with_wait_for_completion_successfully(self, mock_conn, mock_waiter):
- mock_waiter.return_value = True
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_delete_application_with_wait_for_completion_successfully(self, mock_conn, mock_get_waiter):
+ mock_get_waiter().wait.return_value = True
mock_conn.stop_application.return_value = {}
mock_conn.delete_application.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}}
@@ -615,14 +746,14 @@ class TestEmrServerlessDeleteOperator:
operator.execute(None)
assert operator.wait_for_completion is True
- assert mock_waiter.call_count == 2
+ assert mock_get_waiter().wait.call_count == 2
mock_conn.stop_application.assert_called_once()
mock_conn.delete_application.assert_called_once_with(applicationId=application_id_delete_operator)
- @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_delete_application_without_wait_for_completion_successfully(self, mock_conn, mock_waiter):
- mock_waiter.return_value = True
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_delete_application_without_wait_for_completion_successfully(self, mock_conn, mock_get_waiter):
+ mock_get_waiter().wait.return_value = True
mock_conn.stop_application.return_value = {}
mock_conn.delete_application.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}}
@@ -634,14 +765,14 @@ class TestEmrServerlessDeleteOperator:
operator.execute(None)
- mock_waiter.assert_called_once()
+ mock_get_waiter().wait.assert_called_once()
mock_conn.stop_application.assert_called_once()
mock_conn.delete_application.assert_called_once_with(applicationId=application_id_delete_operator)
- @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_delete_application_failed_deletion(self, mock_conn, mock_waiter):
- mock_waiter.return_value = True
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_delete_application_failed_deletion(self, mock_conn, mock_get_waiter):
+ mock_get_waiter().wait.return_value = True
mock_conn.stop_application.return_value = {}
mock_conn.delete_application.return_value = {"ResponseMetadata": {"HTTPStatusCode": 400}}
@@ -653,37 +784,61 @@ class TestEmrServerlessDeleteOperator:
assert "Application deletion failed:" in str(ex_message.value)
- mock_waiter.assert_called_once()
+ mock_get_waiter().wait.assert_called_once()
mock_conn.stop_application.assert_called_once()
mock_conn.delete_application.assert_called_once_with(applicationId=application_id_delete_operator)
+ @pytest.mark.parametrize(
+ "waiter_delay, waiter_max_attempts, waiter_countdown, waiter_check_interval_seconds, expected",
+ [
+ (NOTSET, NOTSET, NOTSET, NOTSET, [60, 25]),
+ (30, 10, NOTSET, NOTSET, [30, 10]),
+ (NOTSET, NOTSET, 30 * 15, 15, [15, 30]),
+ (10, 20, 30, 40, [10, 20]),
+ ],
+ )
+ def test_delete_application_waiter_params(
+ self, waiter_delay, waiter_max_attempts, waiter_countdown, waiter_check_interval_seconds, expected
+ ):
+ operator = EmrServerlessDeleteApplicationOperator(
+ task_id=task_id,
+ application_id=application_id,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ waiter_countdown=waiter_countdown,
+ waiter_check_interval_seconds=waiter_check_interval_seconds,
+ )
+ assert operator.wait_for_completion is True
+ assert operator.waiter_delay == expected[0]
+ assert operator.waiter_max_attempts == expected[1]
+
class TestEmrServerlessStopOperator:
- @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_stop(self, mock_conn: MagicMock, mock_waiter: MagicMock):
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_stop(self, mock_conn: MagicMock, mock_get_waiter: MagicMock):
+ mock_get_waiter().wait.return_value = True
operator = EmrServerlessStopApplicationOperator(task_id=task_id, application_id="test")
operator.execute(None)
- mock_waiter.assert_called_once()
+ mock_get_waiter().wait.assert_called_once()
mock_conn.stop_application.assert_called_once()
- @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
- @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
- def test_stop_no_wait(self, mock_conn: MagicMock, mock_waiter: MagicMock):
+ @mock.patch.object(EmrServerlessHook, "get_waiter")
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_stop_no_wait(self, mock_conn: MagicMock, mock_get_waiter: MagicMock):
operator = EmrServerlessStopApplicationOperator(
task_id=task_id, application_id="test", wait_for_completion=False
)
operator.execute(None)
- mock_waiter.assert_not_called()
+ mock_get_waiter().wait.assert_not_called()
mock_conn.stop_application.assert_called_once()
- @mock.patch("airflow.providers.amazon.aws.operators.emr.waiter")
@mock.patch.object(EmrServerlessStopApplicationOperator, "hook", new_callable=PropertyMock)
- def test_force_stop(self, mock_hook: MagicMock, mock_waiter: MagicMock):
+ def test_force_stop(self, mock_hook: MagicMock):
operator = EmrServerlessStopApplicationOperator(
task_id=task_id, application_id="test", force_stop=True
)
@@ -692,4 +847,4 @@ class TestEmrServerlessStopOperator:
mock_hook().cancel_running_jobs.assert_called_once()
mock_hook().conn.stop_application.assert_called_once()
- mock_waiter.assert_called_once()
+ mock_hook().get_waiter().wait.assert_called_once()
diff --git a/tests/providers/amazon/aws/utils/test_waiter_with_logging.py b/tests/providers/amazon/aws/utils/test_waiter_with_logging.py
new file mode 100644
index 0000000000..2ca74936d7
--- /dev/null
+++ b/tests/providers/amazon/aws/utils/test_waiter_with_logging.py
@@ -0,0 +1,304 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from typing import Any
+from unittest import mock
+
+import pytest
+from botocore.exceptions import WaiterError
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
+
+
+def generate_response(state: str) -> dict[str, Any]:
+ return {
+ "Status": {
+ "State": state,
+ },
+ }
+
+
+class TestWaiter:
+ @mock.patch("time.sleep")
+ def test_wait(self, mock_sleep, caplog):
+ mock_sleep.return_value = True
+ mock_waiter = mock.MagicMock()
+ error = WaiterError(
+ name="test_waiter",
+ reason="test_reason",
+ last_response=generate_response("Pending"),
+ )
+ mock_waiter.wait.side_effect = [error, error, True]
+ wait(
+ waiter=mock_waiter,
+ waiter_delay=123,
+ max_attempts=456,
+ args={"test_arg": "test_value"},
+ failure_message="test failure message",
+ status_message="test status message",
+ status_args=["Status.State"],
+ )
+
+ mock_waiter.wait.assert_called_with(
+ **{"test_arg": "test_value"},
+ WaiterConfig={
+ "MaxAttempts": 1,
+ },
+ )
+ assert mock_waiter.wait.call_count == 3
+ mock_sleep.assert_called_with(123)
+ assert (
+ caplog.record_tuples
+ == [
+ (
+ "airflow.providers.amazon.aws.utils.waiter_with_logging",
+ logging.INFO,
+ "test status message: Pending",
+ )
+ ]
+ * 2
+ )
+
+ @mock.patch("time.sleep")
+ def test_wait_max_attempts_exceeded(self, mock_sleep, caplog):
+ mock_sleep.return_value = True
+ mock_waiter = mock.MagicMock()
+ error = WaiterError(
+ name="test_waiter",
+ reason="test_reason",
+ last_response=generate_response("Pending"),
+ )
+ mock_waiter.wait.side_effect = [error, error, error]
+ with pytest.raises(AirflowException) as exc:
+ wait(
+ waiter=mock_waiter,
+ waiter_delay=123,
+ max_attempts=2,
+ args={"test_arg": "test_value"},
+ failure_message="test failure message",
+ status_message="test status message",
+ status_args=["Status.State"],
+ )
+ assert "Waiter error: max attempts reached" in str(exc)
+ mock_waiter.wait.assert_called_with(
+ **{"test_arg": "test_value"},
+ WaiterConfig={
+ "MaxAttempts": 1,
+ },
+ )
+
+ mock_waiter.wait.call_count == 11
+ mock_sleep.assert_called_with(123)
+ assert (
+ caplog.record_tuples
+ == [
+ (
+ "airflow.providers.amazon.aws.utils.waiter_with_logging",
+ logging.INFO,
+ "test status message: Pending",
+ )
+ ]
+ * 2
+ )
+
+ @mock.patch("time.sleep")
+ def test_wait_with_failure(self, mock_sleep, caplog):
+ mock_sleep.return_value = True
+ mock_waiter = mock.MagicMock()
+ error = WaiterError(
+ name="test_waiter",
+ reason="test_reason",
+ last_response=generate_response("Pending"),
+ )
+ failure_error = WaiterError(
+ name="test_waiter",
+ reason="terminal failure in waiter",
+ last_response=generate_response("Failure"),
+ )
+ mock_waiter.wait.side_effect = [error, error, error, failure_error]
+ with pytest.raises(AirflowException) as exc:
+ wait(
+ waiter=mock_waiter,
+ waiter_delay=123,
+ max_attempts=10,
+ args={"test_arg": "test_value"},
+ failure_message="test failure message",
+ status_message="test status message",
+ status_args=["Status.State"],
+ )
+ assert "test failure message" in str(exc)
+ mock_waiter.wait.assert_called_with(
+ **{"test_arg": "test_value"},
+ WaiterConfig={
+ "MaxAttempts": 1,
+ },
+ )
+ assert mock_waiter.wait.call_count == 4
+ assert (
+ caplog.record_tuples
+ == [
+ (
+ "airflow.providers.amazon.aws.utils.waiter_with_logging",
+ logging.INFO,
+ "test status message: Pending",
+ )
+ ]
+ * 3
+ )
+
+ @mock.patch("time.sleep")
+ def test_wait_with_list_response(self, mock_sleep, caplog):
+ mock_sleep.return_value = True
+ mock_waiter = mock.MagicMock()
+ error = WaiterError(
+ name="test_waiter",
+ reason="test_reason",
+ last_response={
+ "Clusters": [
+ {
+ "Status": "Pending",
+ },
+ {
+ "Status": "Pending",
+ },
+ ]
+ },
+ )
+ mock_waiter.wait.side_effect = [error, error, True]
+ wait(
+ waiter=mock_waiter,
+ waiter_delay=123,
+ max_attempts=456,
+ args={"test_arg": "test_value"},
+ failure_message="test failure message",
+ status_message="test status message",
+ status_args=["Clusters[0].Status"],
+ )
+
+ mock_waiter.wait.assert_called_with(
+ **{"test_arg": "test_value"},
+ WaiterConfig={
+ "MaxAttempts": 1,
+ },
+ )
+ assert mock_waiter.wait.call_count == 3
+ mock_sleep.assert_called_with(123)
+ assert (
+ caplog.record_tuples
+ == [
+ (
+ "airflow.providers.amazon.aws.utils.waiter_with_logging",
+ logging.INFO,
+ "test status message: Pending",
+ )
+ ]
+ * 2
+ )
+
+ @mock.patch("time.sleep")
+ def test_wait_with_incorrect_args(self, mock_sleep, caplog):
+ mock_sleep.return_value = True
+ mock_waiter = mock.MagicMock()
+ error = WaiterError(
+ name="test_waiter",
+ reason="test_reason",
+ last_response={
+ "Clusters": [
+ {
+ "Status": "Pending",
+ },
+ {
+ "Status": "Pending",
+ },
+ ]
+ },
+ )
+ mock_waiter.wait.side_effect = [error, error, True]
+ wait(
+ waiter=mock_waiter,
+ waiter_delay=123,
+ max_attempts=456,
+ args={"test_arg": "test_value"},
+ failure_message="test failure message",
+ status_message="test status message",
+ status_args=["Clusters[0].State"], # this does not exist in the response
+ )
+
+ mock_waiter.wait.assert_called_with(
+ **{"test_arg": "test_value"},
+ WaiterConfig={
+ "MaxAttempts": 1,
+ },
+ )
+ assert mock_waiter.wait.call_count == 3
+ mock_sleep.assert_called_with(123)
+ assert (
+ caplog.record_tuples
+ == [
+ (
+ "airflow.providers.amazon.aws.utils.waiter_with_logging",
+ logging.INFO,
+ "test status message: ",
+ )
+ ]
+ * 2
+ )
+
+ @mock.patch("time.sleep")
+ def test_wait_with_multiple_args(self, mock_sleep, caplog):
+ mock_sleep.return_value = True
+ mock_waiter = mock.MagicMock()
+ error = WaiterError(
+ name="test_waiter",
+ reason="test_reason",
+ last_response={
+ "Clusters": [
+ {
+ "Status": "Pending",
+ "StatusDetails": "test_details",
+ "ClusterName": "test_name",
+ },
+ ]
+ },
+ )
+ mock_waiter.wait.side_effect = [error, error, True]
+ wait(
+ waiter=mock_waiter,
+ waiter_delay=123,
+ max_attempts=456,
+ args={"test_arg": "test_value"},
+ failure_message="test failure message",
+ status_message="test status message",
+ status_args=["Clusters[0].Status", "Clusters[0].StatusDetails", "Clusters[0].ClusterName"],
+ )
+ assert mock_waiter.wait.call_count == 3
+ mock_sleep.assert_called_with(123)
+ assert (
+ caplog.record_tuples
+ == [
+ (
+ "airflow.providers.amazon.aws.utils.waiter_with_logging",
+ logging.INFO,
+ "test status message: Pending - test_details - test_name",
+ )
+ ]
+ * 2
+ )