You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/04 21:10:52 UTC
[airflow] branch main updated: Add test_connection method to AWS hook (#24662)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 210549c658 Add test_connection method to AWS hook (#24662)
210549c658 is described below
commit 210549c658c96ad0129609f50a46e40eebfdaa23
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Tue Jul 5 01:10:41 2022 +0400
Add test_connection method to AWS hook (#24662)
---
airflow/providers/amazon/aws/hooks/base_aws.py | 39 ++++++++++++++++++-----
tests/providers/amazon/aws/hooks/test_base_aws.py | 23 +++++++++++++
2 files changed, 54 insertions(+), 8 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index 5e518adde0..28135b0fa4 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -26,6 +26,7 @@ This module contains Base AWS Hook.
import configparser
import datetime
+import json
import logging
import warnings
from functools import wraps
@@ -409,9 +410,6 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
self.region_name = region_name
self.config = config
- if not (self.client_type or self.resource_type):
- raise AirflowException('Either client_type or resource_type must be provided.')
-
def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:
if not self.aws_conn_id:
@@ -510,13 +508,15 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
:return: boto3.client or boto3.resource
:rtype: Union[boto3.client, boto3.resource]
"""
- if self.client_type:
+ if not ((not self.client_type) ^ (not self.resource_type)):
+ raise ValueError(
+ f"Either client_type={self.client_type!r} or "
+ f"resource_type={self.resource_type!r} must be provided, not both."
+ )
+ elif self.client_type:
return self.get_client_type(region_name=self.region_name)
- elif self.resource_type:
- return self.get_resource_type(region_name=self.region_name)
else:
- # Rare possibility - subclasses have not specified a client_type or resource_type
- raise NotImplementedError('Could not get boto3 connection!')
+ return self.get_resource_type(region_name=self.region_name)
@cached_property
def conn_client_meta(self) -> ClientMeta:
@@ -611,6 +611,29 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
return retry_decorator
+ def test_connection(self):
+ """
+ Test the AWS connection by calling the AWS STS (Security Token Service) GetCallerIdentity API; returns ``(True, <JSON identity>)`` on HTTP 200, otherwise ``(False, <response metadata or error text>)``.
+
+ .. seealso::
+ https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html
+ """
+ orig_client_type, self.client_type = self.client_type, 'sts'  # temporarily switch to STS; restored in ``finally``. NOTE(review): mutates shared hook state
+ try:
+ res = self.get_client_type().get_caller_identity()  # presumably builds a client from self.client_type ('sts' here) - confirm
+ metadata = res.pop("ResponseMetadata", {})  # drop transport metadata so it is not part of the success message
+ if metadata.get("HTTPStatusCode") == 200:
+ return True, json.dumps(res)
+ else:
+ try:
+ return False, json.dumps(metadata)
+ except TypeError:  # metadata not JSON-serialisable; fall back to str()
+ return False, str(metadata)
+ except Exception as e:  # any failure is reported as a failed test rather than raised
+ return False, str(e)
+ finally:
+ self.client_type = orig_client_type  # always restore the caller's client_type
class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):
"""
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 00ffc31163..d1c3da9cdd 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -701,6 +701,29 @@ class TestAwsBaseHook:
assert hook.conn_partition == expected_partition
+ @pytest.mark.parametrize(
+ "client_type,resource_type",
+ [
+ ("s3", "dynamodb"),  # both provided
+ (None, None),  # neither provided
+ ("", ""),  # empty strings are falsy, so treated as not provided
+ ],
+ )
+ def test_connection_client_resource_types_check(self, client_type, resource_type):
+ # Hook construction must not validate client_type/resource_type; the check is done lazily in get_conn().
+ hook = AwsBaseHook(aws_conn_id=None, client_type=client_type, resource_type=resource_type)
+
+ with pytest.raises(ValueError, match="Either client_type=.* or resource_type=.* must be provided"):
+ hook.get_conn()
+
+ @unittest.skipIf(mock_sts is None, 'mock_sts package not present')
+ @mock_sts  # presumably moto's STS mock, so no real AWS call is made - confirm
+ def test_hook_connection_test(self):
+ hook = AwsBaseHook(client_type="s3")  # deliberately not 'sts', to check test_connection() restores it
+ result, message = hook.test_connection()
+ assert result  # ``message`` (the JSON identity payload) is not checked here
+ assert hook.client_type == "s3"  # test_connection() must restore the client_type set at initialisation
+
class ThrowErrorUntilCount:
"""Holds counter state for invoking a method several times in a row."""