You are viewing a plain-text version of this content. The canonical link to it ("here") was a hyperlink that is not preserved in this plain-text rendering.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/04 21:10:52 UTC

[airflow] branch main updated: Add test_connection method to AWS hook (#24662)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 210549c658 Add test_connection method to AWS hook (#24662)
210549c658 is described below

commit 210549c658c96ad0129609f50a46e40eebfdaa23
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Tue Jul 5 01:10:41 2022 +0400

    Add test_connection method to AWS hook (#24662)
---
 airflow/providers/amazon/aws/hooks/base_aws.py    | 39 ++++++++++++++++++-----
 tests/providers/amazon/aws/hooks/test_base_aws.py | 23 +++++++++++++
 2 files changed, 54 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index 5e518adde0..28135b0fa4 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -26,6 +26,7 @@ This module contains Base AWS Hook.
 
 import configparser
 import datetime
+import json
 import logging
 import warnings
 from functools import wraps
@@ -409,9 +410,6 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         self.region_name = region_name
         self.config = config
 
-        if not (self.client_type or self.resource_type):
-            raise AirflowException('Either client_type or resource_type must be provided.')
-
     def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:
 
         if not self.aws_conn_id:
@@ -510,13 +508,15 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         :return: boto3.client or boto3.resource
         :rtype: Union[boto3.client, boto3.resource]
         """
-        if self.client_type:
+        if not ((not self.client_type) ^ (not self.resource_type)):
+            raise ValueError(
+                f"Either client_type={self.client_type!r} or "
+                f"resource_type={self.resource_type!r} must be provided, not both."
+            )
+        elif self.client_type:
             return self.get_client_type(region_name=self.region_name)
-        elif self.resource_type:
-            return self.get_resource_type(region_name=self.region_name)
         else:
-            # Rare possibility - subclasses have not specified a client_type or resource_type
-            raise NotImplementedError('Could not get boto3 connection!')
+            return self.get_resource_type(region_name=self.region_name)
 
     @cached_property
     def conn_client_meta(self) -> ClientMeta:
@@ -611,6 +611,29 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
 
         return retry_decorator
 
+    def test_connection(self):
+        """
+        Tests the AWS connection by call AWS STS (Security Token Service) GetCallerIdentity API.
+
+        .. seealso::
+            https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html
+        """
+        orig_client_type, self.client_type = self.client_type, 'sts'
+        try:
+            res = self.get_client_type().get_caller_identity()
+            metadata = res.pop("ResponseMetadata", {})
+            if metadata.get("HTTPStatusCode") == 200:
+                return True, json.dumps(res)
+            else:
+                try:
+                    return False, json.dumps(metadata)
+                except TypeError:
+                    return False, str(metadata)
+        except Exception as e:
+            return False, str(e)
+        finally:
+            self.client_type = orig_client_type
+
 
 class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):
     """
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 00ffc31163..d1c3da9cdd 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -701,6 +701,29 @@ class TestAwsBaseHook:
 
             assert hook.conn_partition == expected_partition
 
+    @pytest.mark.parametrize(
+        "client_type,resource_type",
+        [
+            ("s3", "dynamodb"),
+            (None, None),
+            ("", ""),
+        ],
+    )
+    def test_connection_client_resource_types_check(self, client_type, resource_type):
+        # Should not raise any error during Hook initialisation.
+        hook = AwsBaseHook(aws_conn_id=None, client_type=client_type, resource_type=resource_type)
+
+        with pytest.raises(ValueError, match="Either client_type=.* or resource_type=.* must be provided"):
+            hook.get_conn()
+
+    @unittest.skipIf(mock_sts is None, 'mock_sts package not present')
+    @mock_sts
+    def test_hook_connection_test(self):
+        hook = AwsBaseHook(client_type="s3")
+        result, message = hook.test_connection()
+        assert result
+        assert hook.client_type == "s3"  # Same client_type which defined during initialisation
+
 
 class ThrowErrorUntilCount:
     """Holds counter state for invoking a method several times in a row."""