You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ta...@apache.org on 2022/12/12 22:57:37 UTC
[airflow] branch main updated: Better support for Boto Waiters (#28236)
This is an automated email from the ASF dual-hosted git repository.
taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 27569a8b37 Better support for Boto Waiters (#28236)
27569a8b37 is described below
commit 27569a8b374a2f7a019f1f08b18a33be84d61693
Author: D. Ferruzzi <fe...@amazon.com>
AuthorDate: Mon Dec 12 14:57:31 2022 -0800
Better support for Boto Waiters (#28236)
* Add waiters to EKS Operators
* Custom waiter setup
* Move the heavy lifting into base hook - Individual service hooks are no longer modified
Add unit testing
Update README
---
airflow/providers/amazon/aws/hooks/base_aws.py | 42 +++++
airflow/providers/amazon/aws/operators/eks.py | 141 ++++++++------
airflow/providers/amazon/aws/waiters/README.md | 100 ++++++++++
.../providers/amazon/aws/waiters/__init__.py | 23 ---
.../providers/amazon/aws/waiters/base_waiter.py | 33 ++--
airflow/providers/amazon/aws/waiters/eks.json | 24 +++
dev/provider_packages/MANIFEST_TEMPLATE.in.jinja2 | 1 +
tests/providers/amazon/aws/operators/test_eks.py | 210 ++++++++++++++++++++-
.../providers/amazon/aws/waiters/__init__.py | 23 ---
.../amazon/aws/waiters/test_custom_waiters.py | 100 ++++++++++
10 files changed, 569 insertions(+), 128 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index 6ff5f21467..e8ddadf1e9 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -33,6 +33,8 @@ import uuid
import warnings
from copy import deepcopy
from functools import wraps
+from os import PathLike
+from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
import boto3
@@ -43,6 +45,7 @@ import tenacity
from botocore.client import ClientMeta
from botocore.config import Config
from botocore.credentials import ReadOnlyCredentials
+from botocore.waiter import Waiter, WaiterModel
from dateutil.tz import tzlocal
from slugify import slugify
@@ -51,6 +54,7 @@ from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
+from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter
from airflow.providers_manager import ProvidersManager
from airflow.utils.helpers import exactly_one
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -764,6 +768,44 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
except Exception as e:
return False, str(f"{type(e).__name__!r} error occurred while testing connection: {e}")
+ @cached_property
+ def waiter_path(self) -> PathLike[str] | None:
+ path = Path(__file__).parents[1].joinpath(f"waiters/{self.client_type}.json").resolve()
+ return path if path.exists() else None
+
+ def get_waiter(self, waiter_name: str) -> Waiter:
+ """
+ First checks if there is a custom waiter with the provided waiter_name and
+ uses that if it exists, otherwise it will check the service client for a
+ waiter that matches the name and pass that through.
+
+ :param waiter_name: The name of the waiter. The name should exactly match the
+ name of the key in the waiter model file (typically this is CamelCase).
+ """
+ if self.waiter_path and (waiter_name in self._list_custom_waiters()):
+ # Technically if waiter_name is in custom_waiters then self.waiter_path must
+ # exist but MyPy doesn't like the fact that self.waiter_path could be None.
+ with open(self.waiter_path) as config_file:
+ config = json.load(config_file)
+ return BaseBotoWaiter(client=self.conn, model_config=config).waiter(waiter_name)
+ # If there is no custom waiter found for the provided name,
+ # then try checking the service's official waiters.
+ return self.conn.get_waiter(waiter_name)
+
+ def list_waiters(self) -> list[str]:
+ """Returns a list containing the names of all waiters for the service, official and custom."""
+ return [*self._list_official_waiters(), *self._list_custom_waiters()]
+
+ def _list_official_waiters(self) -> list[str]:
+ return self.conn.waiter_names
+
+ def _list_custom_waiters(self) -> list[str]:
+ if not self.waiter_path:
+ return []
+ with open(self.waiter_path) as config_file:
+ model_config = json.load(config_file)
+ return WaiterModel(model_config).waiter_names
+
class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):
"""
diff --git a/airflow/providers/amazon/aws/operators/eks.py b/airflow/providers/amazon/aws/operators/eks.py
index 70d74f4cda..050c91e894 100644
--- a/airflow/providers/amazon/aws/operators/eks.py
+++ b/airflow/providers/amazon/aws/operators/eks.py
@@ -19,12 +19,13 @@ from __future__ import annotations
import warnings
from ast import literal_eval
-from time import sleep
from typing import TYPE_CHECKING, Any, List, Sequence, cast
+from botocore.exceptions import ClientError, WaiterError
+
from airflow import AirflowException
from airflow.models import BaseOperator
-from airflow.providers.amazon.aws.hooks.eks import ClusterStates, EksHook, FargateProfileStates
+from airflow.providers.amazon.aws.hooks.eks import EksHook
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
if TYPE_CHECKING:
@@ -39,7 +40,6 @@ DEFAULT_FARGATE_PROFILE_NAME = "profile"
DEFAULT_NAMESPACE_NAME = "default"
DEFAULT_NODEGROUP_NAME = "nodegroup"
-ABORT_MSG = "{compute} are still active after the allocated time limit. Aborting."
CAN_NOT_DELETE_MSG = "A cluster can not be deleted with attached {compute}. Deleting {count} {compute}."
MISSING_ARN_MSG = "Creating an {compute} requires {requirement} to be passed in."
SUCCESS_MSG = "No {compute} remain, deleting cluster."
@@ -77,6 +77,7 @@ class EksCreateClusterOperator(BaseOperator):
:param compute: The type of compute architecture to generate along with the cluster. (templated)
Defaults to 'nodegroup' to generate an EKS Managed Nodegroup.
:param create_cluster_kwargs: Optional parameters to pass to the CreateCluster API (templated)
+ :param wait_for_completion: If True, waits for operator to complete. (default: False) (templated)
:param aws_conn_id: The Airflow connection used for AWS credentials. (templated)
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
@@ -117,6 +118,7 @@ class EksCreateClusterOperator(BaseOperator):
"fargate_pod_execution_role_arn",
"fargate_selectors",
"create_fargate_profile_kwargs",
+ "wait_for_completion",
"aws_conn_id",
"region",
)
@@ -135,6 +137,7 @@ class EksCreateClusterOperator(BaseOperator):
fargate_pod_execution_role_arn: str | None = None,
fargate_selectors: list | None = None,
create_fargate_profile_kwargs: dict | None = None,
+ wait_for_completion: bool = False,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
**kwargs,
@@ -151,6 +154,7 @@ class EksCreateClusterOperator(BaseOperator):
self.fargate_pod_execution_role_arn = fargate_pod_execution_role_arn
self.fargate_selectors = fargate_selectors or [{"namespace": DEFAULT_NAMESPACE_NAME}]
self.create_fargate_profile_kwargs = create_fargate_profile_kwargs or {}
+ self.wait_for_completion = wait_for_completion
self.aws_conn_id = aws_conn_id
self.region = region
super().__init__(**kwargs)
@@ -182,28 +186,21 @@ class EksCreateClusterOperator(BaseOperator):
**self.create_cluster_kwargs,
)
- if not self.compute:
+ # Short circuit early if we don't need to wait to attach compute
+ # and the caller hasn't requested to wait for the cluster either.
+ if not self.compute and not self.wait_for_completion:
return None
self.log.info("Waiting for EKS Cluster to provision. This will take some time.")
+ client = eks_hook.conn
- countdown = TIMEOUT_SECONDS
- while eks_hook.get_cluster_state(clusterName=self.cluster_name) != ClusterStates.ACTIVE:
- if countdown >= CHECK_INTERVAL_SECONDS:
- countdown -= CHECK_INTERVAL_SECONDS
- self.log.info(
- "Waiting for cluster to start. Checking again in %d seconds", CHECK_INTERVAL_SECONDS
- )
- sleep(CHECK_INTERVAL_SECONDS)
- else:
- message = (
- "Cluster is still inactive after the allocated time limit. "
- "Failed cluster will be torn down."
- )
- self.log.error(message)
- # If there is something preventing the cluster for activating, tear it down and abort.
- eks_hook.delete_cluster(name=self.cluster_name)
- raise RuntimeError(message)
+ try:
+ client.get_waiter("cluster_active").wait(name=self.cluster_name)
+ except (ClientError, WaiterError) as e:
+ self.log.error("Cluster failed to start and will be torn down.\n %s", e)
+ eks_hook.delete_cluster(name=self.cluster_name)
+ client.get_waiter("cluster_deleted").wait(name=self.cluster_name)
+ raise
if self.compute == "nodegroup":
eks_hook.create_nodegroup(
@@ -213,6 +210,12 @@ class EksCreateClusterOperator(BaseOperator):
nodeRole=self.nodegroup_role_arn,
**self.create_nodegroup_kwargs,
)
+ if self.wait_for_completion:
+ self.log.info("Waiting for nodegroup to provision. This will take some time.")
+ client.get_waiter("nodegroup_active").wait(
+ clusterName=self.cluster_name,
+ nodegroupName=self.nodegroup_name,
+ )
elif self.compute == "fargate":
eks_hook.create_fargate_profile(
clusterName=self.cluster_name,
@@ -221,6 +224,12 @@ class EksCreateClusterOperator(BaseOperator):
selectors=self.fargate_selectors,
**self.create_fargate_profile_kwargs,
)
+ if self.wait_for_completion:
+ self.log.info("Waiting for Fargate profile to provision. This will take some time.")
+ client.get_waiter("fargate_profile_active").wait(
+ clusterName=self.cluster_name,
+ fargateProfileName=self.fargate_profile_name,
+ )
class EksCreateNodegroupOperator(BaseOperator):
@@ -238,6 +247,7 @@ class EksCreateNodegroupOperator(BaseOperator):
:param nodegroup_role_arn:
The Amazon Resource Name (ARN) of the IAM role to associate with the managed nodegroup. (templated)
:param create_nodegroup_kwargs: Optional parameters to pass to the Create Nodegroup API (templated)
+ :param wait_for_completion: If True, waits for operator to complete. (default: False) (templated)
:param aws_conn_id: The Airflow connection used for AWS credentials. (templated)
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
@@ -254,6 +264,7 @@ class EksCreateNodegroupOperator(BaseOperator):
"nodegroup_role_arn",
"nodegroup_name",
"create_nodegroup_kwargs",
+ "wait_for_completion",
"aws_conn_id",
"region",
)
@@ -265,6 +276,7 @@ class EksCreateNodegroupOperator(BaseOperator):
nodegroup_role_arn: str,
nodegroup_name: str = DEFAULT_NODEGROUP_NAME,
create_nodegroup_kwargs: dict | None = None,
+ wait_for_completion: bool = False,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
**kwargs,
@@ -273,6 +285,7 @@ class EksCreateNodegroupOperator(BaseOperator):
self.nodegroup_role_arn = nodegroup_role_arn
self.nodegroup_name = nodegroup_name
self.create_nodegroup_kwargs = create_nodegroup_kwargs or {}
+ self.wait_for_completion = wait_for_completion
self.aws_conn_id = aws_conn_id
self.region = region
self.nodegroup_subnets = nodegroup_subnets
@@ -304,6 +317,12 @@ class EksCreateNodegroupOperator(BaseOperator):
**self.create_nodegroup_kwargs,
)
+ if self.wait_for_completion:
+ self.log.info("Waiting for nodegroup to provision. This will take some time.")
+ eks_hook.conn.get_waiter("nodegroup_active").wait(
+ clusterName=self.cluster_name, nodegroupName=self.nodegroup_name
+ )
+
class EksCreateFargateProfileOperator(BaseOperator):
"""
@@ -320,6 +339,7 @@ class EksCreateFargateProfileOperator(BaseOperator):
:param fargate_profile_name: The unique name to give your AWS Fargate profile. (templated)
:param create_fargate_profile_kwargs: Optional parameters to pass to the CreateFargate Profile API
(templated)
+ :param wait_for_completion: If True, waits for operator to complete. (default: False) (templated)
:param aws_conn_id: The Airflow connection used for AWS credentials. (templated)
If this is None or empty then the default boto3 behaviour is used. If
@@ -336,6 +356,7 @@ class EksCreateFargateProfileOperator(BaseOperator):
"selectors",
"fargate_profile_name",
"create_fargate_profile_kwargs",
+ "wait_for_completion",
"aws_conn_id",
"region",
)
@@ -347,6 +368,7 @@ class EksCreateFargateProfileOperator(BaseOperator):
selectors: list,
fargate_profile_name: str | None = DEFAULT_FARGATE_PROFILE_NAME,
create_fargate_profile_kwargs: dict | None = None,
+ wait_for_completion: bool = False,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
**kwargs,
@@ -356,6 +378,7 @@ class EksCreateFargateProfileOperator(BaseOperator):
self.selectors = selectors
self.fargate_profile_name = fargate_profile_name
self.create_fargate_profile_kwargs = create_fargate_profile_kwargs or {}
+ self.wait_for_completion = wait_for_completion
self.aws_conn_id = aws_conn_id
self.region = region
super().__init__(**kwargs)
@@ -374,6 +397,12 @@ class EksCreateFargateProfileOperator(BaseOperator):
**self.create_fargate_profile_kwargs,
)
+ if self.wait_for_completion:
+ self.log.info("Waiting for Fargate profile to provision. This will take some time.")
+ eks_hook.conn.get_waiter("fargate_profile_active").wait(
+ clusterName=self.cluster_name, fargateProfileName=self.fargate_profile_name
+ )
+
class EksDeleteClusterOperator(BaseOperator):
"""
@@ -386,6 +415,7 @@ class EksDeleteClusterOperator(BaseOperator):
:param cluster_name: The name of the Amazon EKS Cluster to delete. (templated)
:param force_delete_compute: If True, will delete any attached resources. (templated)
Defaults to False.
+ :param wait_for_completion: If True, waits for operator to complete. (default: False) (templated)
:param aws_conn_id: The Airflow connection used for AWS credentials. (templated)
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
@@ -399,6 +429,7 @@ class EksDeleteClusterOperator(BaseOperator):
template_fields: Sequence[str] = (
"cluster_name",
"force_delete_compute",
+ "wait_for_completion",
"aws_conn_id",
"region",
)
@@ -407,12 +438,14 @@ class EksDeleteClusterOperator(BaseOperator):
self,
cluster_name: str,
force_delete_compute: bool = False,
+ wait_for_completion: bool = False,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
**kwargs,
) -> None:
self.cluster_name = cluster_name
self.force_delete_compute = force_delete_compute
+ self.wait_for_completion = wait_for_completion
self.aws_conn_id = aws_conn_id
self.region = region
super().__init__(**kwargs)
@@ -429,34 +462,25 @@ class EksDeleteClusterOperator(BaseOperator):
eks_hook.delete_cluster(name=self.cluster_name)
+ if self.wait_for_completion:
+ self.log.info("Waiting for cluster to delete. This will take some time.")
+ eks_hook.conn.get_waiter("cluster_deleted").wait(name=self.cluster_name)
+
def delete_any_nodegroups(self, eks_hook) -> None:
"""
Deletes all Amazon EKS managed node groups for a provided Amazon EKS Cluster.
Amazon EKS managed node groups can be deleted in parallel, so we can send all
- of the delete commands in bulk and move on once the count of nodegroups is zero.
+ delete commands in bulk and move on once the count of nodegroups is zero.
"""
nodegroups = eks_hook.list_nodegroups(clusterName=self.cluster_name)
if nodegroups:
self.log.info(CAN_NOT_DELETE_MSG.format(compute=NODEGROUP_FULL_NAME, count=len(nodegroups)))
for group in nodegroups:
eks_hook.delete_nodegroup(clusterName=self.cluster_name, nodegroupName=group)
-
- # Scaling up the timeout based on the number of nodegroups that are being processed.
- additional_seconds = 5 * 60
- countdown = TIMEOUT_SECONDS + (len(nodegroups) * additional_seconds)
- while eks_hook.list_nodegroups(clusterName=self.cluster_name):
- if countdown >= CHECK_INTERVAL_SECONDS:
- countdown -= CHECK_INTERVAL_SECONDS
- sleep(CHECK_INTERVAL_SECONDS)
- self.log.info(
- "Waiting for the remaining %s nodegroups to delete. "
- "Checking again in %d seconds.",
- len(nodegroups),
- CHECK_INTERVAL_SECONDS,
- )
- else:
- raise RuntimeError(ABORT_MSG.format(compute=NODEGROUP_FULL_NAME))
+ # Note this is a custom waiter so we're using hook.get_waiter(), not hook.conn.get_waiter().
+ self.log.info("Waiting for all nodegroups to delete. This will take some time.")
+ eks_hook.get_waiter("all_nodegroups_deleted").wait(clusterName=self.cluster_name)
self.log.info(SUCCESS_MSG.format(compute=NODEGROUP_FULL_NAME))
def delete_any_fargate_profiles(self, eks_hook) -> None:
@@ -469,30 +493,15 @@ class EksDeleteClusterOperator(BaseOperator):
fargate_profiles = eks_hook.list_fargate_profiles(clusterName=self.cluster_name)
if fargate_profiles:
self.log.info(CAN_NOT_DELETE_MSG.format(compute=FARGATE_FULL_NAME, count=len(fargate_profiles)))
+ self.log.info("Waiting for Fargate profiles to delete. This will take some time.")
for profile in fargate_profiles:
# The API will return a (cluster) ResourceInUseException if you try
# to delete Fargate profiles in parallel the way we can with nodegroups,
# so each must be deleted sequentially
eks_hook.delete_fargate_profile(clusterName=self.cluster_name, fargateProfileName=profile)
-
- countdown = TIMEOUT_SECONDS
- while (
- eks_hook.get_fargate_profile_state(
- clusterName=self.cluster_name, fargateProfileName=profile
- )
- != FargateProfileStates.NONEXISTENT
- ):
- if countdown >= CHECK_INTERVAL_SECONDS:
- countdown -= CHECK_INTERVAL_SECONDS
- sleep(CHECK_INTERVAL_SECONDS)
- self.log.info(
- "Waiting for the AWS Fargate profile %s to delete. "
- "Checking again in %d seconds.",
- profile,
- CHECK_INTERVAL_SECONDS,
- )
- else:
- raise RuntimeError(ABORT_MSG.format(compute=FARGATE_FULL_NAME))
+ eks_hook.conn.get_waiter("fargate_profile_deleted").wait(
+ clusterName=self.cluster_name, fargateProfileName=profile
+ )
self.log.info(SUCCESS_MSG.format(compute=FARGATE_FULL_NAME))
@@ -506,6 +515,7 @@ class EksDeleteNodegroupOperator(BaseOperator):
:param cluster_name: The name of the Amazon EKS Cluster associated with your nodegroup. (templated)
:param nodegroup_name: The name of the nodegroup to delete. (templated)
+ :param wait_for_completion: If True, waits for operator to complete. (default: False) (templated)
:param aws_conn_id: The Airflow connection used for AWS credentials. (templated)
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
@@ -519,6 +529,7 @@ class EksDeleteNodegroupOperator(BaseOperator):
template_fields: Sequence[str] = (
"cluster_name",
"nodegroup_name",
+ "wait_for_completion",
"aws_conn_id",
"region",
)
@@ -527,12 +538,14 @@ class EksDeleteNodegroupOperator(BaseOperator):
self,
cluster_name: str,
nodegroup_name: str,
+ wait_for_completion: bool = False,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
**kwargs,
) -> None:
self.cluster_name = cluster_name
self.nodegroup_name = nodegroup_name
+ self.wait_for_completion = wait_for_completion
self.aws_conn_id = aws_conn_id
self.region = region
super().__init__(**kwargs)
@@ -544,6 +557,11 @@ class EksDeleteNodegroupOperator(BaseOperator):
)
eks_hook.delete_nodegroup(clusterName=self.cluster_name, nodegroupName=self.nodegroup_name)
+ if self.wait_for_completion:
+ self.log.info("Waiting for nodegroup to delete. This will take some time.")
+ eks_hook.conn.get_waiter("nodegroup_deleted").wait(
+ clusterName=self.cluster_name, nodegroupName=self.nodegroup_name
+ )
class EksDeleteFargateProfileOperator(BaseOperator):
@@ -556,6 +574,7 @@ class EksDeleteFargateProfileOperator(BaseOperator):
:param cluster_name: The name of the Amazon EKS cluster associated with your Fargate profile. (templated)
:param fargate_profile_name: The name of the AWS Fargate profile to delete. (templated)
+ :param wait_for_completion: If True, waits for operator to complete. (default: False) (templated)
:param aws_conn_id: The Airflow connection used for AWS credentials. (templated)
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
@@ -568,6 +587,7 @@ class EksDeleteFargateProfileOperator(BaseOperator):
template_fields: Sequence[str] = (
"cluster_name",
"fargate_profile_name",
+ "wait_for_completion",
"aws_conn_id",
"region",
)
@@ -576,6 +596,7 @@ class EksDeleteFargateProfileOperator(BaseOperator):
self,
cluster_name: str,
fargate_profile_name: str,
+ wait_for_completion: bool = False,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
**kwargs,
@@ -583,6 +604,7 @@ class EksDeleteFargateProfileOperator(BaseOperator):
super().__init__(**kwargs)
self.cluster_name = cluster_name
self.fargate_profile_name = fargate_profile_name
+ self.wait_for_completion = wait_for_completion
self.aws_conn_id = aws_conn_id
self.region = region
@@ -595,6 +617,11 @@ class EksDeleteFargateProfileOperator(BaseOperator):
eks_hook.delete_fargate_profile(
clusterName=self.cluster_name, fargateProfileName=self.fargate_profile_name
)
+ if self.wait_for_completion:
+ self.log.info("Waiting for Fargate profile to delete. This will take some time.")
+ eks_hook.conn.get_waiter("fargate_profile_deleted").wait(
+ clusterName=self.cluster_name, fargateProfileName=self.fargate_profile_name
+ )
class EksPodOperator(KubernetesPodOperator):
diff --git a/airflow/providers/amazon/aws/waiters/README.md b/airflow/providers/amazon/aws/waiters/README.md
new file mode 100644
index 0000000000..d6e8958b8e
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/README.md
@@ -0,0 +1,100 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+ -->
+
+This module is for custom Boto3 waiter configuration files. Since documentation
+on creating custom waiters is pretty sparse out in the wild, this document can
+act as a rough quickstart guide. It is not meant to cover all edge cases.
+
+# To add a new custom waiter
+
+## Create or modify the service waiter config file
+
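+Find or create a file for the service it is related to, for example `waiters/eks.json`. The file
+name must match the hook's `client_type` (e.g. `eks.json` for the `eks` client), because that is
+how the hook locates the file.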
+
+### In the service waiter config file
+
+Build or add to the waiter model config json in that file. For examples of what these
+should look like, have a look through some official waiter models. Some examples:
+
+* [Cloudwatch](https://github.com/boto/botocore/blob/develop/botocore/data/cloudwatch/2010-08-01/waiters-2.json)
+* [EC2](https://github.com/boto/botocore/blob/develop/botocore/data/ec2/2016-11-15/waiters-2.json)
+* [EKS](https://github.com/boto/botocore/blob/develop/botocore/data/eks/2017-11-01/waiters-2.json)
+
+Below is an example of a working waiter model config that will make an EKS waiter which will wait for
+all Nodegroups in a cluster to be deleted. An explanation follows the code snippet. Note the backticks
+around the integers in the "argument" values: in JMESPath, backticks mark JSON literals.
+
+```json
+{
+ "version": 2,
+ "waiters": {
+ "all_nodegroups_deleted": {
+ "operation": "ListNodegroups",
+ "delay": 30,
+ "maxAttempts": 60,
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "length(nodegroups[]) == `0`",
+ "expected": true,
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "expected": true,
+ "argument": "length(nodegroups[]) > `0`",
+ "state": "retry"
+ }
+ ]
+ }
+ }
+}
+```
+
+In the model config above we create a new waiter called `all_nodegroups_deleted` which calls
+the `ListNodegroups` API endpoint. The parameters for the endpoint call must be passed into
+the `waiter.wait()` call, the same as when using an official waiter. The waiter then evaluates
+the JMESPath expression in "argument" (in this case `` length(nodegroups[]) == `0` ``) against
+the response. If the expression returns the value in "expected" (in this case `true`) then the
+waiter's state is set to `success`, the waiter can close down, and the operator which called it
+can continue. If `` length(nodegroups[]) > `0` `` is `true` then the state is set to `retry`, and
+the waiter will wait "delay" (here 30) seconds before trying again. If the state does not reach
+`success` within "maxAttempts" (here 60) tries, the waiter raises a `WaiterError`. Both the delay
+and the maximum number of attempts can be overridden by the user when calling `waiter.wait()`,
+like any other waiter, by passing `WaiterConfig={"Delay": ..., "MaxAttempts": ...}`.
+
+### That's It!
+
+The AwsBaseHook handles the rest. Using the above waiter will look like this:
+`EksHook().get_waiter("all_nodegroups_deleted").wait(clusterName="my_cluster")`.
+For testing purposes, a `list_waiters()` helper method is provided which returns the names of
+all waiters available for the service, both official and custom: `EksHook().list_waiters()`
+
+
+### In your Operators (How to use these)
+
+Once configured correctly, the custom waiter will be nearly indistinguishable from an official waiter.
+Below is an example of an official waiter followed by a custom one.
+
+```python
+EksHook().conn.get_waiter("nodegroup_deleted").wait(clusterName=cluster_name, nodegroupName=nodegroup_name)
+EksHook().get_waiter("all_nodegroups_deleted").wait(clusterName=cluster_name)
+```
+
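+Note that since custom waiters are resolved by the hook's `get_waiter` method rather than by the
+boto3 client, a custom waiter must be retrieved with `hook.get_waiter` and not `hook.conn.get_waiter`
+(`hook.get_waiter` falls back to the client's official waiters when no custom waiter matches the
+name). Other than that, they should be identical.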
diff --git a/dev/provider_packages/MANIFEST_TEMPLATE.in.jinja2 b/airflow/providers/amazon/aws/waiters/__init__.py
similarity index 52%
copy from dev/provider_packages/MANIFEST_TEMPLATE.in.jinja2
copy to airflow/providers/amazon/aws/waiters/__init__.py
index 1cbbab4513..13a83393a9 100644
--- a/dev/provider_packages/MANIFEST_TEMPLATE.in.jinja2
+++ b/airflow/providers/amazon/aws/waiters/__init__.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,25 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE
-# OVERWRITTEN WHEN PREPARING PACKAGES.
-
-# IF YOU WANT TO MODIFY IT, YOU SHOULD MODIFY THE TEMPLATE
-# `MANIFEST_TEMPLATE.py.jinja2` IN the `provider_packages` DIRECTORY
-
-
-{% if PROVIDER_PACKAGE_ID == 'amazon' %}
-include airflow/providers/amazon/aws/hooks/batch_waiters.json
-{% elif PROVIDER_PACKAGE_ID == 'google' %}
-include airflow/providers/google/cloud/example_dags/*.yaml
-include airflow/providers/google/cloud/example_dags/*.sql
-{% elif PROVIDER_PACKAGE_ID == 'cncf.kubernetes' %}
-include airflow/providers/cncf/kubernetes/*.jinja2
-{% endif %}
-
-include NOTICE
-include LICENSE
-include CHANGELOG.txt
-include README.md
-global-exclude __pycache__ *.pyc
diff --git a/dev/provider_packages/MANIFEST_TEMPLATE.in.jinja2 b/airflow/providers/amazon/aws/waiters/base_waiter.py
similarity index 52%
copy from dev/provider_packages/MANIFEST_TEMPLATE.in.jinja2
copy to airflow/providers/amazon/aws/waiters/base_waiter.py
index 1cbbab4513..0d9f8a1d4e 100644
--- a/dev/provider_packages/MANIFEST_TEMPLATE.in.jinja2
+++ b/airflow/providers/amazon/aws/waiters/base_waiter.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -16,24 +15,22 @@
# specific language governing permissions and limitations
# under the License.
-# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE
-# OVERWRITTEN WHEN PREPARING PACKAGES.
+from __future__ import annotations
+
+import boto3
+from botocore.waiter import Waiter, WaiterModel, create_waiter_with_client
+
-# IF YOU WANT TO MODIFY IT, YOU SHOULD MODIFY THE TEMPLATE
-# `MANIFEST_TEMPLATE.py.jinja2` IN the `provider_packages` DIRECTORY
+class BaseBotoWaiter:
+ """
+ Used to create custom Boto3 Waiters.
+ For more details, see airflow/providers/amazon/aws/waiters/README.md
+ """
-{% if PROVIDER_PACKAGE_ID == 'amazon' %}
-include airflow/providers/amazon/aws/hooks/batch_waiters.json
-{% elif PROVIDER_PACKAGE_ID == 'google' %}
-include airflow/providers/google/cloud/example_dags/*.yaml
-include airflow/providers/google/cloud/example_dags/*.sql
-{% elif PROVIDER_PACKAGE_ID == 'cncf.kubernetes' %}
-include airflow/providers/cncf/kubernetes/*.jinja2
-{% endif %}
+ def __init__(self, client: boto3.client, model_config: dict) -> None:
+ self.model = WaiterModel(model_config)
+ self.client = client
-include NOTICE
-include LICENSE
-include CHANGELOG.txt
-include README.md
-global-exclude __pycache__ *.pyc
+ def waiter(self, waiter_name: str) -> Waiter:
+ return create_waiter_with_client(waiter_name=waiter_name, waiter_model=self.model, client=self.client)
diff --git a/airflow/providers/amazon/aws/waiters/eks.json b/airflow/providers/amazon/aws/waiters/eks.json
new file mode 100644
index 0000000000..71694520cb
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/eks.json
@@ -0,0 +1,24 @@
+{
+ "version": 2,
+ "waiters": {
+ "all_nodegroups_deleted": {
+ "operation": "ListNodegroups",
+ "delay": 30,
+ "maxAttempts": 60,
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "length(nodegroups[]) == `0`",
+ "expected": true,
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "expected": true,
+ "argument": "length(nodegroups[]) > `0`",
+ "state": "retry"
+ }
+ ]
+ }
+ }
+}
diff --git a/dev/provider_packages/MANIFEST_TEMPLATE.in.jinja2 b/dev/provider_packages/MANIFEST_TEMPLATE.in.jinja2
index 1cbbab4513..c83adbdd57 100644
--- a/dev/provider_packages/MANIFEST_TEMPLATE.in.jinja2
+++ b/dev/provider_packages/MANIFEST_TEMPLATE.in.jinja2
@@ -25,6 +25,7 @@
{% if PROVIDER_PACKAGE_ID == 'amazon' %}
include airflow/providers/amazon/aws/hooks/batch_waiters.json
+include airflow/providers/amazon/aws/waiters/*.json
{% elif PROVIDER_PACKAGE_ID == 'google' %}
include airflow/providers/google/cloud/example_dags/*.yaml
include airflow/providers/google/cloud/example_dags/*.sql
diff --git a/tests/providers/amazon/aws/operators/test_eks.py b/tests/providers/amazon/aws/operators/test_eks.py
index 9a99bf23e5..15dd50f103 100644
--- a/tests/providers/amazon/aws/operators/test_eks.py
+++ b/tests/providers/amazon/aws/operators/test_eks.py
@@ -21,6 +21,7 @@ from typing import Any
from unittest import mock
import pytest
+from botocore.waiter import Waiter
from airflow.providers.amazon.aws.hooks.eks import ClusterStates, EksHook
from airflow.providers.amazon.aws.operators.eks import (
@@ -60,6 +61,17 @@ CREATE_NODEGROUP_KWARGS = {
}
+def assert_expected_waiter_type(waiter: mock.MagicMock, expected: str):
+ """
+ There does not appear to be a straight-forward way to assert the type of waiter.
+ Instead, get the class name and check if it contains the expected name.
+
+ :param waiter: A mocked Boto3 Waiter object.
+ :param expected: The expected class name of the Waiter object, for example "ClusterActive".
+ """
+ assert expected in str(type(waiter.call_args[0][0]))
+
+
class ClusterParams(TypedDict):
cluster_name: str
cluster_role_arn: str
@@ -168,24 +180,75 @@ class TestEksCreateClusterOperator:
mock_create_cluster.assert_called_with(**convert_keys(parameters))
mock_create_nodegroup.assert_not_called()
+ @pytest.mark.parametrize(
+ "create_cluster_kwargs",
+ [
+ pytest.param(None, id="without cluster kwargs"),
+ pytest.param(CREATE_CLUSTER_KWARGS, id="with cluster kwargs"),
+ ],
+ )
+ @mock.patch.object(Waiter, "wait")
+ @mock.patch.object(EksHook, "create_cluster")
+ @mock.patch.object(EksHook, "create_nodegroup")  # bottom decorator -> first mock argument
+ def test_execute_create_cluster_with_wait(
+ self, mock_create_nodegroup, mock_create_cluster, mock_waiter, create_cluster_kwargs
+ ):
+ op_kwargs = {**self.create_cluster_params, "compute": None}
+ if create_cluster_kwargs:
+ op_kwargs["create_cluster_kwargs"] = create_cluster_kwargs
+ parameters = {**self.create_cluster_params, **create_cluster_kwargs}
+ else:
+ assert "create_cluster_kwargs" not in op_kwargs
+ parameters = self.create_cluster_params
+
+ operator = EksCreateClusterOperator(task_id=TASK_ID, **op_kwargs, wait_for_completion=True)
+ operator.execute({})
+ mock_create_cluster.assert_called_with(**convert_keys(parameters))
+ mock_create_nodegroup.assert_not_called()
+ mock_waiter.assert_called_once_with(mock.ANY, name=CLUSTER_NAME)  # mock.ANY stands in for the waiter instance
+ assert_expected_waiter_type(mock_waiter, "ClusterActive")
+
+ @mock.patch.object(Waiter, "wait")
@mock.patch.object(EksHook, "get_cluster_state")
@mock.patch.object(EksHook, "create_cluster")
@mock.patch.object(EksHook, "create_nodegroup")
def test_execute_when_called_with_nodegroup_creates_both(
- self, mock_create_nodegroup, mock_create_cluster, mock_cluster_state
+ self, mock_create_nodegroup, mock_create_cluster, mock_cluster_state, mock_waiter
+ ):
+ mock_cluster_state.return_value = ClusterStates.ACTIVE
+
+ self.create_cluster_operator_with_nodegroup.execute({})
+
+ mock_create_cluster.assert_called_once_with(**convert_keys(self.create_cluster_params))
+ mock_create_nodegroup.assert_called_once_with(**convert_keys(self.create_nodegroup_params))
+ mock_waiter.assert_called_once_with(mock.ANY, name=CLUSTER_NAME)
+ assert_expected_waiter_type(mock_waiter, "ClusterActive")
+
+ @mock.patch.object(Waiter, "wait")
+ @mock.patch.object(EksHook, "get_cluster_state")
+ @mock.patch.object(EksHook, "create_cluster")
+ @mock.patch.object(EksHook, "create_nodegroup")
+ def test_execute_with_wait_when_called_with_nodegroup_creates_both(
+ self, mock_create_nodegroup, mock_create_cluster, mock_cluster_state, mock_waiter
):
mock_cluster_state.return_value = ClusterStates.ACTIVE
+ self.create_cluster_operator_with_nodegroup.wait_for_completion = True  # opt the shared fixture into waiting
self.create_cluster_operator_with_nodegroup.execute({})
mock_create_cluster.assert_called_once_with(**convert_keys(self.create_cluster_params))
mock_create_nodegroup.assert_called_once_with(**convert_keys(self.create_nodegroup_params))
+ # Calls waiter once for the cluster and once for the nodegroup.
+ assert mock_waiter.call_count == 2
+ mock_waiter.assert_called_with(mock.ANY, clusterName=CLUSTER_NAME, nodegroupName=NODEGROUP_NAME)  # checks only the latest call
+ assert_expected_waiter_type(mock_waiter, "NodegroupActive")
+ @mock.patch.object(Waiter, "wait")
@mock.patch.object(EksHook, "get_cluster_state")
@mock.patch.object(EksHook, "create_cluster")
@mock.patch.object(EksHook, "create_fargate_profile")
def test_execute_when_called_with_fargate_creates_both(
- self, mock_create_fargate_profile, mock_create_cluster, mock_cluster_state
+ self, mock_create_fargate_profile, mock_create_cluster, mock_cluster_state, mock_waiter
):
mock_cluster_state.return_value = ClusterStates.ACTIVE
@@ -195,6 +258,31 @@ class TestEksCreateClusterOperator:
mock_create_fargate_profile.assert_called_once_with(
**convert_keys(self.create_fargate_profile_params)
)
+ mock_waiter.assert_called_once_with(mock.ANY, name=CLUSTER_NAME)
+ assert_expected_waiter_type(mock_waiter, "ClusterActive")
+
+ @mock.patch.object(Waiter, "wait")
+ @mock.patch.object(EksHook, "get_cluster_state")
+ @mock.patch.object(EksHook, "create_cluster")
+ @mock.patch.object(EksHook, "create_fargate_profile")
+ def test_execute_with_wait_when_called_with_fargate_creates_both(
+ self, mock_create_fargate_profile, mock_create_cluster, mock_cluster_state, mock_waiter
+ ):
+ mock_cluster_state.return_value = ClusterStates.ACTIVE
+ self.create_cluster_operator_with_fargate_profile.wait_for_completion = True
+
+ self.create_cluster_operator_with_fargate_profile.execute({})
+
+ mock_create_cluster.assert_called_once_with(**convert_keys(self.create_cluster_params))
+ mock_create_fargate_profile.assert_called_once_with(
+ **convert_keys(self.create_fargate_profile_params)
+ )
+ # Calls waiter once for the cluster and once for the Fargate profile.
+ assert mock_waiter.call_count == 2
+ mock_waiter.assert_called_with(  # checks only the latest (Fargate profile) call
+ mock.ANY, clusterName=CLUSTER_NAME, fargateProfileName=FARGATE_PROFILE_NAME
+ )
+ assert_expected_waiter_type(mock_waiter, "FargateProfileActive")
def test_invalid_compute_value(self):
invalid_compute = EksCreateClusterOperator(
@@ -249,9 +337,10 @@ class TestEksCreateFargateProfileOperator:
pytest.param(CREATE_FARGATE_PROFILE_KWARGS, id="with fargate profile kwargs"),
],
)
+ @mock.patch.object(Waiter, "wait")
@mock.patch.object(EksHook, "create_fargate_profile")
def test_execute_when_fargate_profile_does_not_already_exist(
- self, mock_create_fargate_profile, create_fargate_profile_kwargs
+ self, mock_create_fargate_profile, mock_waiter, create_fargate_profile_kwargs
):
op_kwargs = {**self.create_fargate_profile_params}
if create_fargate_profile_kwargs:
@@ -264,6 +353,35 @@ class TestEksCreateFargateProfileOperator:
operator = EksCreateFargateProfileOperator(task_id=TASK_ID, **op_kwargs)
operator.execute({})
mock_create_fargate_profile.assert_called_with(**convert_keys(parameters))
+ mock_waiter.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "create_fargate_profile_kwargs",
+ [
+ pytest.param(None, id="without fargate profile kwargs"),
+ pytest.param(CREATE_FARGATE_PROFILE_KWARGS, id="with fargate profile kwargs"),
+ ],
+ )
+ @mock.patch.object(Waiter, "wait")
+ @mock.patch.object(EksHook, "create_fargate_profile")
+ def test_execute_with_wait_when_fargate_profile_does_not_already_exist(
+ self, mock_create_fargate_profile, mock_waiter, create_fargate_profile_kwargs
+ ):
+ op_kwargs = {**self.create_fargate_profile_params}
+ if create_fargate_profile_kwargs:
+ op_kwargs["create_fargate_profile_kwargs"] = create_fargate_profile_kwargs
+ parameters = {**self.create_fargate_profile_params, **create_fargate_profile_kwargs}
+ else:
+ assert "create_fargate_profile_kwargs" not in op_kwargs
+ parameters = self.create_fargate_profile_params
+
+ operator = EksCreateFargateProfileOperator(task_id=TASK_ID, **op_kwargs, wait_for_completion=True)
+ operator.execute({})
+ mock_create_fargate_profile.assert_called_with(**convert_keys(parameters))
+ mock_waiter.assert_called_with(  # checks only the most recent waiter call
+ mock.ANY, clusterName=CLUSTER_NAME, fargateProfileName=FARGATE_PROFILE_NAME
+ )
+ assert_expected_waiter_type(mock_waiter, "FargateProfileActive")
class TestEksCreateNodegroupOperator:
@@ -282,9 +400,10 @@ class TestEksCreateNodegroupOperator:
pytest.param(CREATE_NODEGROUP_KWARGS, id="with nodegroup kwargs"),
],
)
+ @mock.patch.object(Waiter, "wait")
@mock.patch.object(EksHook, "create_nodegroup")
def test_execute_when_nodegroup_does_not_already_exist(
- self, mock_create_nodegroup, create_nodegroup_kwargs
+ self, mock_create_nodegroup, mock_waiter, create_nodegroup_kwargs
):
op_kwargs = {**self.create_nodegroup_params}
if create_nodegroup_kwargs:
@@ -297,6 +416,33 @@ class TestEksCreateNodegroupOperator:
operator = EksCreateNodegroupOperator(task_id=TASK_ID, **op_kwargs)
operator.execute({})
mock_create_nodegroup.assert_called_with(**convert_keys(parameters))
+ mock_waiter.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "create_nodegroup_kwargs",
+ [
+ pytest.param(None, id="without nodegroup kwargs"),
+ pytest.param(CREATE_NODEGROUP_KWARGS, id="with nodegroup kwargs"),
+ ],
+ )
+ @mock.patch.object(Waiter, "wait")
+ @mock.patch.object(EksHook, "create_nodegroup")
+ def test_execute_with_wait_when_nodegroup_does_not_already_exist(
+ self, mock_create_nodegroup, mock_waiter, create_nodegroup_kwargs
+ ):
+ op_kwargs = {**self.create_nodegroup_params}
+ if create_nodegroup_kwargs:
+ op_kwargs["create_nodegroup_kwargs"] = create_nodegroup_kwargs
+ parameters = {**self.create_nodegroup_params, **create_nodegroup_kwargs}
+ else:
+ assert "create_nodegroup_kwargs" not in op_kwargs  # same key the if-branch would set
+ parameters = self.create_nodegroup_params
+
+ operator = EksCreateNodegroupOperator(task_id=TASK_ID, **op_kwargs, wait_for_completion=True)
+ operator.execute({})
+ mock_create_nodegroup.assert_called_with(**convert_keys(parameters))
+ mock_waiter.assert_called_with(mock.ANY, clusterName=CLUSTER_NAME, nodegroupName=NODEGROUP_NAME)
+ assert_expected_waiter_type(mock_waiter, "NodegroupActive")
class TestEksDeleteClusterOperator:
@@ -307,12 +453,30 @@ class TestEksDeleteClusterOperator:
task_id=TASK_ID, cluster_name=self.cluster_name
)
+ @mock.patch.object(Waiter, "wait")
+ @mock.patch.object(EksHook, "list_nodegroups")
+ @mock.patch.object(EksHook, "delete_cluster")
+ def test_existing_cluster_not_in_use(self, mock_delete_cluster, mock_list_nodegroups, mock_waiter):
+ mock_list_nodegroups.return_value = []
+ self.delete_cluster_operator.execute({})
+ mock_delete_cluster.assert_called_once_with(name=self.cluster_name)
+ mock_waiter.assert_not_called()  # the fixture operator does not wait unless told to
+
+ @mock.patch.object(Waiter, "wait")
@mock.patch.object(EksHook, "list_nodegroups")
@mock.patch.object(EksHook, "delete_cluster")
- def test_existing_cluster_not_in_use(self, mock_delete_cluster, mock_list_nodegroups):
+ def test_existing_cluster_not_in_use_with_wait(
+ self, mock_delete_cluster, mock_list_nodegroups, mock_waiter
+ ):
mock_list_nodegroups.return_value = []
+ self.delete_cluster_operator.wait_for_completion = True
+
self.delete_cluster_operator.execute({})
+
+ mock_list_nodegroups.assert_not_called()  # was a bare `.assert_called_once` (never asserted); NOTE(review): assumes compute is not force-deleted by default
mock_delete_cluster.assert_called_once_with(name=self.cluster_name)
+ mock_waiter.assert_called_with(mock.ANY, name=CLUSTER_NAME)
+ assert_expected_waiter_type(mock_waiter, "ClusterDeleted")
class TestEksDeleteNodegroupOperator:
@@ -324,13 +488,28 @@ class TestEksDeleteNodegroupOperator:
task_id=TASK_ID, cluster_name=self.cluster_name, nodegroup_name=self.nodegroup_name
)
+ @mock.patch.object(Waiter, "wait")
+ @mock.patch.object(EksHook, "delete_nodegroup")
+ def test_existing_nodegroup(self, mock_delete_nodegroup, mock_waiter):
+ self.delete_nodegroup_operator.execute({})
+
+ mock_delete_nodegroup.assert_called_once_with(
+ clusterName=self.cluster_name, nodegroupName=self.nodegroup_name
+ )
+ mock_waiter.assert_not_called()  # the fixture operator does not wait unless told to
+
+ @mock.patch.object(Waiter, "wait")
@mock.patch.object(EksHook, "delete_nodegroup")
- def test_existing_nodegroup(self, mock_delete_nodegroup):
+ def test_existing_nodegroup_with_wait(self, mock_delete_nodegroup, mock_waiter):
+ self.delete_nodegroup_operator.wait_for_completion = True
+
self.delete_nodegroup_operator.execute({})
mock_delete_nodegroup.assert_called_once_with(
clusterName=self.cluster_name, nodegroupName=self.nodegroup_name
)
+ mock_waiter.assert_called_with(mock.ANY, clusterName=CLUSTER_NAME, nodegroupName=NODEGROUP_NAME)  # mock.ANY is the waiter instance
+ assert_expected_waiter_type(mock_waiter, "NodegroupDeleted")
class TestEksDeleteFargateProfileOperator:
@@ -342,13 +521,30 @@ class TestEksDeleteFargateProfileOperator:
task_id=TASK_ID, cluster_name=self.cluster_name, fargate_profile_name=self.fargate_profile_name
)
+ @mock.patch.object(Waiter, "wait")
+ @mock.patch.object(EksHook, "delete_fargate_profile")
+ def test_existing_fargate_profile(self, mock_delete_fargate_profile, mock_waiter):
+ self.delete_fargate_profile_operator.execute({})
+
+ mock_delete_fargate_profile.assert_called_once_with(
+ clusterName=self.cluster_name, fargateProfileName=self.fargate_profile_name
+ )
+ mock_waiter.assert_not_called()  # the fixture operator does not wait unless told to
+
+ @mock.patch.object(Waiter, "wait")
@mock.patch.object(EksHook, "delete_fargate_profile")
- def test_existing_fargate_profile(self, mock_delete_fargate_profile):
+ def test_existing_fargate_profile_with_wait(self, mock_delete_fargate_profile, mock_waiter):
+ self.delete_fargate_profile_operator.wait_for_completion = True
+
self.delete_fargate_profile_operator.execute({})
mock_delete_fargate_profile.assert_called_once_with(
clusterName=self.cluster_name, fargateProfileName=self.fargate_profile_name
)
+ mock_waiter.assert_called_with(  # mock.ANY is the waiter instance
+ mock.ANY, clusterName=CLUSTER_NAME, fargateProfileName=FARGATE_PROFILE_NAME
+ )
+ assert_expected_waiter_type(mock_waiter, "FargateProfileDeleted")
class TestEksPodOperator:
diff --git a/dev/provider_packages/MANIFEST_TEMPLATE.in.jinja2 b/tests/providers/amazon/aws/waiters/__init__.py
similarity index 52%
copy from dev/provider_packages/MANIFEST_TEMPLATE.in.jinja2
copy to tests/providers/amazon/aws/waiters/__init__.py
index 1cbbab4513..13a83393a9 100644
--- a/dev/provider_packages/MANIFEST_TEMPLATE.in.jinja2
+++ b/tests/providers/amazon/aws/waiters/__init__.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,25 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE
-# OVERWRITTEN WHEN PREPARING PACKAGES.
-
-# IF YOU WANT TO MODIFY IT, YOU SHOULD MODIFY THE TEMPLATE
-# `MANIFEST_TEMPLATE.py.jinja2` IN the `provider_packages` DIRECTORY
-
-
-{% if PROVIDER_PACKAGE_ID == 'amazon' %}
-include airflow/providers/amazon/aws/hooks/batch_waiters.json
-{% elif PROVIDER_PACKAGE_ID == 'google' %}
-include airflow/providers/google/cloud/example_dags/*.yaml
-include airflow/providers/google/cloud/example_dags/*.sql
-{% elif PROVIDER_PACKAGE_ID == 'cncf.kubernetes' %}
-include airflow/providers/cncf/kubernetes/*.jinja2
-{% endif %}
-
-include NOTICE
-include LICENSE
-include CHANGELOG.txt
-include README.md
-global-exclude __pycache__ *.pyc
diff --git a/tests/providers/amazon/aws/waiters/test_custom_waiters.py b/tests/providers/amazon/aws/waiters/test_custom_waiters.py
new file mode 100644
index 0000000000..b5ce808d3a
--- /dev/null
+++ b/tests/providers/amazon/aws/waiters/test_custom_waiters.py
@@ -0,0 +1,100 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+
+import boto3
+from botocore.waiter import WaiterModel
+from moto import mock_eks
+
+from airflow.providers.amazon.aws.hooks.eks import EksHook
+from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter
+
+
+def assert_all_match(*args):
+ assert len(set(args)) == 1  # passes only when every arg is equal; args must be hashable
+
+
+class TestBaseWaiter:
+ def test_init(self):
+ waiter_name = "test_waiter"
+ client_name = "test_client"
+ waiter_model_config = {
+ "version": 2,
+ "waiters": {
+ waiter_name: {
+ "operation": "ListNodegroups",
+ "delay": 30,
+ "maxAttempts": 60,
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "length(nodegroups[]) == `0`",
+ "expected": True,
+ "state": "success",
+ },
+ {
+ "matcher": "path",
+ "expected": True,
+ "argument": "length(nodegroups[]) > `0`",
+ "state": "retry",
+ },
+ ],
+ }
+ },
+ }
+ expected_model = WaiterModel(waiter_model_config)
+
+ waiter = BaseBotoWaiter(client_name, waiter_model_config)  # a plain string stands in for the client here
+
+ # WaiterModel doesn't implement __eq__, so compare its instance attributes one by one.
+ for attr, expected_value in vars(expected_model).items():
+ assert getattr(waiter.model, attr) == expected_value
+ assert waiter.client == client_name
+
+
+class TestServiceWaiters:
+ def test_service_waiters(self):
+ hook = EksHook()
+ with open(hook.waiter_path) as config_file:
+ expected_waiters = json.load(config_file)["waiters"]
+
+ for waiter in expected_waiters:  # iterating a dict yields its keys (the waiter names)
+ assert waiter in hook.list_waiters()
+ assert waiter in hook._list_custom_waiters()
+
+ @mock_eks
+ def test_existing_waiter_inherited(self):
+ """
+ AwsBaseHook::get_waiter will first check if there is a custom waiter with the
+ provided name and pass that through if it exists, otherwise it falls back to the
+ official boto3 waiters for the given service. This test checks to make sure that the
+ waiter is the same whichever way you get it and no modifications are made.
+ """
+ hook_waiter = EksHook().get_waiter("cluster_active")
+ client_waiter = EksHook().conn.get_waiter("cluster_active")
+ boto_waiter = boto3.client("eks").get_waiter("cluster_active")
+
+ assert_all_match(hook_waiter.name, client_waiter.name, boto_waiter.name)
+ assert_all_match(len(hook_waiter.__dict__), len(client_waiter.__dict__), len(boto_waiter.__dict__))
+ for attr in vars(hook_waiter):
+ # Not all attributes in a Waiter are directly comparable
+ # so the best we can do is make sure the same attrs exist.
+ assert hasattr(boto_waiter, attr)
+ assert hasattr(client_waiter, attr)