You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/08/05 15:43:21 UTC

[airflow] branch main updated: Resolve Amazon Hook's `region_name` and `config` in wrapper (#25336)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4193558e80 Resolve Amazon Hook's `region_name` and `config` in wrapper (#25336)
4193558e80 is described below

commit 4193558e808af0d0eac0636b4bb6f88606ca54c6
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Fri Aug 5 19:43:09 2022 +0400

    Resolve Amazon Hook's `region_name` and `config` in wrapper (#25336)
---
 airflow/providers/amazon/aws/hooks/base_aws.py     | 111 ++++++++++++---------
 .../amazon/aws/utils/connection_wrapper.py         |  64 ++++++++++--
 tests/providers/amazon/aws/hooks/test_base_aws.py  | 104 ++++++++++++++++++-
 .../amazon/aws/utils/test_connection_wrapper.py    | 100 ++++++++++++++++---
 4 files changed, 306 insertions(+), 73 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index f0de49556d..1132892e73 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -44,7 +44,7 @@ from slugify import slugify
 
 from airflow.compat.functools import cached_property
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowNotFoundException
 from airflow.hooks.base import BaseHook
 from airflow.models.connection import Connection
 from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
@@ -66,19 +66,24 @@ class BaseSessionFactory(LoggingMixin):
     """
 
     def __init__(
-        self, conn: Union[Connection, AwsConnectionWrapper], region_name: Optional[str], config: Config
+        self,
+        conn: Optional[Union[Connection, AwsConnectionWrapper]],
+        region_name: Optional[str] = None,
+        config: Optional[Config] = None,
     ) -> None:
         super().__init__()
         self._conn = conn
         self._region_name = region_name
-        self.config = config
+        self._config = config
 
     @cached_property
     def conn(self) -> AwsConnectionWrapper:
         """Cached AWS Connection Wrapper."""
-        if isinstance(self._conn, AwsConnectionWrapper):
-            return self._conn
-        return AwsConnectionWrapper(self._conn)
+        return AwsConnectionWrapper(
+            conn=self._conn,
+            region_name=self._region_name,
+            botocore_config=self._config,
+        )
 
     @cached_property
     def basic_session(self) -> boto3.session.Session:
@@ -92,12 +97,13 @@ class BaseSessionFactory(LoggingMixin):
 
     @property
     def region_name(self) -> Optional[str]:
-        """Resolve region name.
+        """AWS Region Name read-only property."""
+        return self.conn.region_name
 
-        1. SessionFactory region_name
-        2. Connection region_name
-        """
-        return self._region_name or self.conn.region_name
+    @property
+    def config(self) -> Optional[Config]:
+        """Configuration for botocore client read-only property."""
+        return self.conn.botocore_config
 
     @property
     def role_arn(self) -> Optional[str]:
@@ -105,8 +111,15 @@ class BaseSessionFactory(LoggingMixin):
         return self.conn.role_arn
 
     def create_session(self) -> boto3.session.Session:
-        """Create AWS session."""
-        if not self.role_arn:
+        """Create boto3 Session from connection config."""
+        if not self.conn:
+            self.log.info(
+                "No connection ID provided. Fallback on boto3 credential strategy (region_name=%r). "
+                "See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html",
+                self.region_name,
+            )
+            return boto3.session.Session(region_name=self.region_name)
+        elif not self.role_arn:
             return self.basic_session
         return self._create_session_with_assume_role(session_kwargs=self.conn.session_kwargs)
 
@@ -381,45 +394,50 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         self.verify = verify
         self.client_type = client_type
         self.resource_type = resource_type
-        self.region_name = region_name
-        self.config = config
-
-    def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:
-
-        if not self.aws_conn_id:
-            session = boto3.session.Session(region_name=region_name)
-            return session, None
+        self._region_name = region_name
+        self._config = config
 
-        self.log.debug("Airflow Connection: aws_conn_id=%s", self.aws_conn_id)
+    @cached_property
+    def conn_config(self) -> AwsConnectionWrapper:
+        """Get the Airflow Connection object and wrap it in helper (cached)."""
+        connection = None
+        if self.aws_conn_id:
+            try:
+                connection = self.get_connection(self.aws_conn_id)
+            except AirflowNotFoundException:
+                warnings.warn(
+                    f"Unable to find AWS Connection ID '{self.aws_conn_id}', switching to empty. "
+                    "This behaviour is deprecated and will be removed in a future releases. "
+                    "Please provide existed AWS connection ID or if required boto3 credential strategy "
+                    "explicit set AWS Connection ID to None.",
+                    DeprecationWarning,
+                    stacklevel=2,
+                )
 
-        try:
-            # Fetch the Airflow connection object and wrap it in helper
-            connection_object = AwsConnectionWrapper(self.get_connection(self.aws_conn_id))
+        return AwsConnectionWrapper(
+            conn=connection or Connection(conn_id=None, conn_type="aws"),
+            region_name=self._region_name,
+            botocore_config=self._config,
+        )
 
-            if connection_object.botocore_config:
-                # For historical reason botocore.config.Config from connection overwrites
-                # config which explicitly set in Hook.
-                self.config = connection_object.botocore_config
+    @property
+    def region_name(self) -> Optional[str]:
+        """AWS Region Name read-only property."""
+        return self.conn_config.region_name
 
-            session = SessionFactory(
-                conn=connection_object, region_name=region_name, config=self.config
-            ).create_session()
+    @property
+    def config(self) -> Optional[Config]:
+        """Configuration for botocore client read-only property."""
+        return self.conn_config.botocore_config
 
-            return session, connection_object.endpoint_url
+    def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:
+        self.log.debug("Airflow Connection: aws_conn_id=%s", self.aws_conn_id)
 
-        except AirflowException:
-            self.log.warning(
-                "Unable to use Airflow Connection for credentials. "
-                "Fallback on boto3 credential strategy. See: "
-                "https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html"
-            )
+        session = SessionFactory(
+            conn=self.conn_config, region_name=region_name, config=self.config
+        ).create_session()
 
-        self.log.debug(
-            "Creating session using boto3 credential strategy region_name=%s",
-            region_name,
-        )
-        session = boto3.session.Session(region_name=region_name)
-        return session, None
+        return session, self.conn_config.endpoint_url
 
     def get_client_type(
         self,
@@ -491,6 +509,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
 
     @cached_property
     def conn_client_meta(self) -> ClientMeta:
+        """Get botocore client metadata from Hook connection (cached)."""
         conn = self.conn
         if isinstance(conn, botocore.client.BaseClient):
             return conn.meta
@@ -498,10 +517,12 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
 
     @property
     def conn_region_name(self) -> str:
+        """Get actual AWS Region Name from Hook connection (cached)."""
         return self.conn_client_meta.region_name
 
     @property
     def conn_partition(self) -> str:
+        """Get associated AWS Region Partition from Hook connection (cached)."""
         return self.conn_client_meta.partition
 
     def get_conn(self) -> BaseAwsConnection:
diff --git a/airflow/providers/amazon/aws/utils/connection_wrapper.py b/airflow/providers/amazon/aws/utils/connection_wrapper.py
index 6672b971f6..cf7c33a823 100644
--- a/airflow/providers/amazon/aws/utils/connection_wrapper.py
+++ b/airflow/providers/amazon/aws/utils/connection_wrapper.py
@@ -17,8 +17,8 @@
 
 import warnings
 from copy import deepcopy
-from dataclasses import InitVar, dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+from dataclasses import MISSING, InitVar, dataclass, field, fields
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
 from botocore.config import Config
 
@@ -35,25 +35,43 @@ class AwsConnectionWrapper(LoggingMixin):
     """
     AWS Connection Wrapper class helper.
     Use for validate and resolve AWS Connection parameters.
+
+    ``conn`` is a reference to an Airflow Connection object or AwsConnectionWrapper;
+        if it is set to ``None``, default values are used.
+
+    The precedence rules for ``region_name``:
+        1. Explicitly set (in Hook) ``region_name``.
+        2. Airflow Connection Extra 'region_name'.
+
+    The precedence rules for ``botocore_config``:
+        1. Explicitly set (in Hook) ``botocore_config``.
+        2. Constructed from Airflow Connection Extra 'config_kwargs'.
+        3. The wrapper's default value (``None``).
     """
 
-    conn: InitVar[Optional["Connection"]]
+    conn: InitVar[Optional[Union["Connection", "AwsConnectionWrapper"]]]
+    region_name: Optional[str] = field(default=None)
+    botocore_config: Optional[Config] = field(default=None)
 
+    # References to Airflow Connection attributes.
+    # ``extra_config`` contains the original Airflow Connection Extra.
     conn_id: Optional[str] = field(init=False, default=None)
     conn_type: Optional[str] = field(init=False, default=None)
     login: Optional[str] = field(init=False, repr=False, default=None)
     password: Optional[str] = field(init=False, repr=False, default=None)
     extra_config: Dict[str, Any] = field(init=False, repr=False, default_factory=dict)
 
-    aws_access_key_id: Optional[str] = field(init=False)
-    aws_secret_access_key: Optional[str] = field(init=False)
-    aws_session_token: Optional[str] = field(init=False)
+    # AWS Credentials from connection.
+    aws_access_key_id: Optional[str] = field(init=False, default=None)
+    aws_secret_access_key: Optional[str] = field(init=False, default=None)
+    aws_session_token: Optional[str] = field(init=False, default=None)
 
-    region_name: Optional[str] = field(init=False, default=None)
+    # Additional boto3.session.Session keyword arguments.
     session_kwargs: Dict[str, Any] = field(init=False, default_factory=dict)
-    botocore_config: Optional[Config] = field(init=False, default=None)
+    # Custom endpoint_url for boto3.client and boto3.resource
     endpoint_url: Optional[str] = field(init=False, default=None)
 
+    # Assume Role Configurations
     role_arn: Optional[str] = field(init=False, default=None)
     assume_role_method: Optional[str] = field(init=False, default=None)
     assume_role_kwargs: Dict[str, Any] = field(init=False, default_factory=dict)
@@ -63,7 +81,30 @@ class AwsConnectionWrapper(LoggingMixin):
         return f"AWS Connection (conn_id={self.conn_id!r}, conn_type={self.conn_type!r})"
 
     def __post_init__(self, conn: "Connection"):
-        if not conn:
+        if isinstance(conn, type(self)):
+            # For every field with init=False, copy the reference value from the original wrapper.
+            # For every field with init=True, keep the explicit init value unless it equals the default.
+            # We can't use ``dataclasses.replace`` in a classmethod because
+            # we are limited by InitVar arguments, which are not stored in the object,
+            # and we also do not want to run __post_init__ again, which would print all logs/warnings again.
+            for fl in fields(conn):
+                value = getattr(conn, fl.name)
+                if not fl.init:
+                    setattr(self, fl.name, value)
+                else:
+                    if fl.default is not MISSING:
+                        default = fl.default
+                    elif fl.default_factory is not MISSING:
+                        default = fl.default_factory()  # zero-argument callable
+                    else:
+                        continue  # Value mandatory, skip
+
+                    orig_value = getattr(self, fl.name)
+                    if orig_value == default:
+                        # Inherit the original wrapper's value only when this field still has its default value
+                        setattr(self, fl.name, value)
+            return
+        elif not conn:
             return
 
         extra = deepcopy(conn.extra_dejson)
@@ -86,7 +127,7 @@ class AwsConnectionWrapper(LoggingMixin):
         init_credentials = self._get_credentials(**extra)
         self.aws_access_key_id, self.aws_secret_access_key, self.aws_session_token = init_credentials
 
-        if "region_name" in extra:
+        if not self.region_name and "region_name" in extra:
             self.region_name = extra["region_name"]
             self.log.info("Retrieving region_name=%s from %s extra.", self.region_name, self.conn_repr)
 
@@ -106,7 +147,7 @@ class AwsConnectionWrapper(LoggingMixin):
                 )
 
         config_kwargs = extra.get("config_kwargs")
-        if config_kwargs:
+        if not self.botocore_config and config_kwargs:
             # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
             self.log.info("Retrieving botocore config=%s from %s extra.", config_kwargs, self.conn_repr)
             self.botocore_config = Config(**config_kwargs)
@@ -119,6 +160,7 @@ class AwsConnectionWrapper(LoggingMixin):
 
     @property
     def extra_dejson(self):
+        """Compatibility with `airflow.models.Connection.extra_dejson` property."""
         return self.extra_config
 
     def __bool__(self):
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index a04473b5b0..f2d55cd3bb 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -17,6 +17,7 @@
 # under the License.
 #
 import json
+import os
 import unittest
 from base64 import b64encode
 from datetime import datetime, timedelta, timezone
@@ -24,6 +25,7 @@ from unittest import mock
 
 import boto3
 import pytest
+from botocore.config import Config
 from moto.core import ACCOUNT_ID
 
 from airflow.models import Connection
@@ -45,6 +47,8 @@ except ImportError:
 
 MOCK_AWS_CONN_ID = "mock-conn-id"
 MOCK_CONN_TYPE = "aws"
+MOCK_BOTO3_SESSION = mock.MagicMock(return_value="Mock boto3.session.Session")
+
 
 SAML_ASSERTION = """
 <?xml version="1.0"?>
@@ -150,6 +154,70 @@ class TestSessionFactory:
         assert session_factory_conn.conn_type == MOCK_CONN_TYPE
         assert sf.conn is session_factory_conn
 
+    def test_empty_conn_property(self):
+        sf = BaseSessionFactory(conn=None, region_name=None, config=None)
+        assert isinstance(sf.conn, AwsConnectionWrapper)
+
+    @pytest.mark.parametrize(
+        "region_name,conn_region_name",
+        [
+            ("eu-west-1", "cn-north-1"),
+            ("eu-west-1", None),
+            (None, "cn-north-1"),
+            (None, None),
+        ],
+    )
+    def test_resolve_region_name(self, region_name, conn_region_name):
+        conn = AwsConnectionWrapper(
+            conn=Connection(conn_type=MOCK_CONN_TYPE, conn_id=MOCK_AWS_CONN_ID),
+            region_name=conn_region_name,
+        )
+        sf = BaseSessionFactory(conn=conn, region_name=region_name, config=None)
+        expected = region_name or conn_region_name
+        assert sf.region_name == expected
+
+    @pytest.mark.parametrize(
+        "botocore_config, conn_botocore_config",
+        [
+            (Config(s3={"us_east_1_regional_endpoint": "regional"}), None),
+            (Config(s3={"us_east_1_regional_endpoint": "regional"}), Config(region_name="ap-southeast-1")),
+            (None, Config(region_name="ap-southeast-1")),
+            (None, None),
+        ],
+    )
+    def test_resolve_botocore_config(self, botocore_config, conn_botocore_config):
+        conn = AwsConnectionWrapper(
+            conn=Connection(conn_type=MOCK_CONN_TYPE, conn_id=MOCK_AWS_CONN_ID),
+            botocore_config=conn_botocore_config,
+        )
+        sf = BaseSessionFactory(conn=conn, config=botocore_config)
+        expected = botocore_config or conn_botocore_config
+        assert sf.config == expected
+
+    @pytest.mark.parametrize("region_name", ["eu-central-1", None])
+    @mock.patch("boto3.session.Session", new_callable=mock.PropertyMock, return_value=MOCK_BOTO3_SESSION)
+    def test_create_session_boto3_credential_strategy(self, mock_boto3_session, region_name, caplog):
+        sf = BaseSessionFactory(conn=AwsConnectionWrapper(conn=None), region_name=region_name, config=None)
+        session = sf.create_session()
+        mock_boto3_session.assert_called_once_with(region_name=region_name)
+        assert session == MOCK_BOTO3_SESSION
+        logging_message = "No connection ID provided. Fallback on boto3 credential strategy"
+        assert any(logging_message in log_text for log_text in caplog.messages)
+
+    @pytest.mark.parametrize("region_name", ["eu-central-1", None])
+    @mock.patch("boto3.session.Session", new_callable=mock.PropertyMock, return_value=MOCK_BOTO3_SESSION)
+    def test_create_session_from_credentials(self, mock_boto3_session, region_name):
+        mock_conn = AwsConnectionWrapper(conn=Connection(conn_type=MOCK_CONN_TYPE, conn_id=MOCK_AWS_CONN_ID))
+        sf = BaseSessionFactory(conn=mock_conn, region_name=region_name, config=None)
+        session = sf.create_session()
+        mock_boto3_session.assert_called_once_with(
+            aws_access_key_id=mock_conn.aws_access_key_id,
+            aws_secret_access_key=mock_conn.aws_secret_access_key,
+            aws_session_token=mock_conn.aws_session_token,
+            region_name=region_name,
+        )
+        assert session == MOCK_BOTO3_SESSION
+
 
 class TestAwsBaseHook:
     @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
@@ -341,7 +409,11 @@ class TestAwsBaseHook:
     @mock.patch.object(AwsBaseHook, 'get_connection')
     @mock_sts
     def test_get_credentials_from_role_arn(self, mock_get_connection):
-        mock_connection = Connection(extra='{"role_arn":"arn:aws:iam::123456:role/role_arn"}')
+        mock_connection = Connection(
+            conn_id='aws_default',
+            conn_type=MOCK_CONN_TYPE,
+            extra='{"role_arn":"arn:aws:iam::123456:role/role_arn"}',
+        )
         mock_get_connection.return_value = mock_connection
         hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
         credentials_from_hook = hook.get_credentials()
@@ -445,6 +517,7 @@ class TestAwsBaseHook:
         duration_seconds = 901
 
         mock_connection = Connection(
+            conn_id=MOCK_AWS_CONN_ID,
             extra=json.dumps(
                 {
                     "role_arn": role_arn,
@@ -459,7 +532,7 @@ class TestAwsBaseHook:
                     },
                     "assume_role_kwargs": {"DurationSeconds": duration_seconds},
                 }
-            )
+            ),
         )
         mock_get_connection.return_value = mock_connection
 
@@ -671,6 +744,33 @@ class TestAwsBaseHook:
         assert result
         assert hook.client_type == "s3"  # Same client_type which defined during initialisation
 
+    @mock.patch.dict(os.environ, {f"AIRFLOW_CONN_{MOCK_AWS_CONN_ID.upper()}": "aws://"})
+    def test_conn_config_conn_id_exists(self):
+        """Test retrieving connection config if aws_conn_id exists."""
+        hook = AwsBaseHook(aws_conn_id=MOCK_AWS_CONN_ID)
+        conn_config_exist = hook.conn_config
+        assert conn_config_exist is hook.conn_config, "Expected cached Connection Config"
+        assert isinstance(conn_config_exist, AwsConnectionWrapper)
+        assert conn_config_exist
+
+    @pytest.mark.parametrize("aws_conn_id", ["", None], ids=["empty", "None"])
+    def test_conn_config_conn_id_empty(self, aws_conn_id):
+        """Test retrieving connection config if aws_conn_id is empty or None."""
+        conn_config_empty = AwsBaseHook(aws_conn_id=aws_conn_id).conn_config
+        assert isinstance(conn_config_empty, AwsConnectionWrapper)
+        assert not conn_config_empty
+
+    def test_conn_config_conn_id_not_exists(self):
+        """Test fallback connection config if aws_conn_id does not exist."""
+        warning_message = (
+            r"Unable to find AWS Connection ID '.*', switching to empty\. "
+            r"This behaviour is deprecated and will be removed in a future releases"
+        )
+        with pytest.warns(DeprecationWarning, match=warning_message):
+            conn_config_fallback_not_exists = AwsBaseHook(aws_conn_id="aws-conn-not-exists").conn_config
+        assert isinstance(conn_config_fallback_not_exists, AwsConnectionWrapper)
+        assert not conn_config_fallback_not_exists
+
 
 class ThrowErrorUntilCount:
     """Holds counter state for invoking a method several times in a row."""
diff --git a/tests/providers/amazon/aws/utils/test_connection_wrapper.py b/tests/providers/amazon/aws/utils/test_connection_wrapper.py
index 74696db98e..47ed16fa73 100644
--- a/tests/providers/amazon/aws/utils/test_connection_wrapper.py
+++ b/tests/providers/amazon/aws/utils/test_connection_wrapper.py
@@ -15,10 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from dataclasses import fields
 from typing import Optional
 from unittest import mock
 
 import pytest
+from botocore.config import Config
 
 from airflow.models import Connection
 from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
@@ -138,11 +140,24 @@ class TestAwsConnectionWrapper:
         assert wrap_conn.aws_secret_access_key == aws_secret_access_key
         assert wrap_conn.aws_session_token == aws_session_token
 
-    @pytest.mark.parametrize("region_name", [None, "mock-aws-region"])
-    def test_get_region_name(self, region_name):
-        mock_conn = mock_connection_factory(extra={"region_name": region_name} if region_name else None)
-        wrap_conn = AwsConnectionWrapper(conn=mock_conn)
-        assert wrap_conn.region_name == region_name
+    @pytest.mark.parametrize(
+        "region_name,conn_region_name",
+        [
+            ("mock-region-name", None),
+            ("mock-region-name", "mock-connection-region-name"),
+            (None, "mock-connection-region-name"),
+            (None, None),
+        ],
+    )
+    def test_get_region_name(self, region_name, conn_region_name):
+        mock_conn = mock_connection_factory(
+            extra={"region_name": conn_region_name} if conn_region_name else None
+        )
+        wrap_conn = AwsConnectionWrapper(conn=mock_conn, region_name=region_name)
+        if region_name:
+            assert wrap_conn.region_name == region_name, "Expected provided region_name"
+        else:
+            assert wrap_conn.region_name == conn_region_name, "Expected connection region_name"
 
     @pytest.mark.parametrize("session_kwargs", [None, {"profile_name": "mock-profile"}])
     def test_get_session_kwargs(self, session_kwargs):
@@ -160,20 +175,30 @@ class TestAwsConnectionWrapper:
             wrap_conn = AwsConnectionWrapper(conn=mock_conn)
         assert "profile_name" not in wrap_conn.session_kwargs
 
-    @mock.patch("airflow.providers.amazon.aws.utils.connection_wrapper.Config", autospec=True)
-    @pytest.mark.parametrize("botocore_config_kwargs", [None, {"user_agent": "Airflow Amazon Provider"}])
-    def test_get_botocore_config(self, mock_botocore_config, botocore_config_kwargs):
+    @mock.patch("airflow.providers.amazon.aws.utils.connection_wrapper.Config")
+    @pytest.mark.parametrize(
+        "botocore_config, botocore_config_kwargs",
+        [
+            (Config(s3={"us_east_1_regional_endpoint": "regional"}), None),
+            (Config(region_name="ap-southeast-1"), {"user_agent": "Airflow Amazon Provider"}),
+            (None, {"user_agent": "Airflow Amazon Provider"}),
+            (None, None),
+        ],
+    )
+    def test_get_botocore_config(self, mock_botocore_config, botocore_config, botocore_config_kwargs):
         mock_conn = mock_connection_factory(
             extra={"config_kwargs": botocore_config_kwargs} if botocore_config_kwargs else None
         )
-        wrap_conn = AwsConnectionWrapper(conn=mock_conn)
-
-        if not botocore_config_kwargs:
-            assert not mock_botocore_config.called
-            assert wrap_conn.botocore_config is None
+        wrap_conn = AwsConnectionWrapper(conn=mock_conn, botocore_config=botocore_config)
+
+        if botocore_config:
+            assert wrap_conn.botocore_config == botocore_config, "Expected provided botocore_config"
+            assert mock_botocore_config.assert_not_called
+        elif not botocore_config_kwargs:
+            assert wrap_conn.botocore_config is None, "Expected default botocore_config"
+            assert mock_botocore_config.assert_not_called
         else:
-            assert mock_botocore_config.called
-            assert mock_botocore_config.call_count == 1
+            assert mock_botocore_config.assert_called_once
             assert mock.call(**botocore_config_kwargs) in mock_botocore_config.mock_calls
 
     @pytest.mark.parametrize("endpoint_url", [None, "https://example.org"])
@@ -281,3 +306,48 @@ class TestAwsConnectionWrapper:
             wrap_conn = AwsConnectionWrapper(conn=mock_conn)
         assert "ExternalId" in wrap_conn.assume_role_kwargs
         assert wrap_conn.assume_role_kwargs["ExternalId"] == mock_external_id_in_extra
+
+    @pytest.mark.parametrize(
+        "orig_wrapper",
+        [
+            AwsConnectionWrapper(
+                conn=mock_connection_factory(
+                    login="mock-login",
+                    password="mock-password",
+                    extra={
+                        "region_name": "mock-region",
+                        "botocore_kwargs": {"user_agent": "Airflow Amazon Provider"},
+                        "role_arn": MOCK_ROLE_ARN,
+                        "aws_session_token": "mock-aws-session-token",
+                    },
+                ),
+            ),
+            AwsConnectionWrapper(conn=mock_connection_factory()),
+            AwsConnectionWrapper(conn=None),
+            AwsConnectionWrapper(
+                conn=None,
+                region_name="mock-region",
+                botocore_config=Config(user_agent="Airflow Amazon Provider"),
+            ),
+        ],
+    )
+    @pytest.mark.parametrize("region_name", [None, "ca-central-1"])
+    @pytest.mark.parametrize("botocore_config", [None, Config(region_name="ap-southeast-1")])
+    def test_wrap_wrapper(self, orig_wrapper, region_name, botocore_config):
+        wrap_kwargs = {}
+        if region_name:
+            wrap_kwargs["region_name"] = region_name
+        if botocore_config:
+            wrap_kwargs["botocore_config"] = botocore_config
+        wrap_conn = AwsConnectionWrapper(conn=orig_wrapper, **wrap_kwargs)
+
+        # Non init fields should be same in orig_wrapper and child wrapper
+        wrap_non_init_fields = [f.name for f in fields(wrap_conn) if not f.init]
+        for field in wrap_non_init_fields:
+            assert getattr(wrap_conn, field) == getattr(
+                orig_wrapper, field
+            ), "Expected no changes in non-init values"
+
+        # Test overwrite/inherit init fields
+        assert wrap_conn.region_name == (region_name or orig_wrapper.region_name)
+        assert wrap_conn.botocore_config == (botocore_config or orig_wrapper.botocore_config)