You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/26 12:55:04 UTC

[airflow] branch main updated: Resolve and validate AWS Connection parameters in wrapper (#25256)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 432977be0c Resolve and validate AWS Connection parameters in wrapper (#25256)
432977be0c is described below

commit 432977be0cd1e95f623fa5edda2a227798fa2939
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Tue Jul 26 16:54:54 2022 +0400

    Resolve and validate AWS Connection parameters in wrapper (#25256)
---
 airflow/providers/amazon/aws/hooks/base_aws.py     | 320 ++++++++-------------
 .../amazon/aws/utils/connection_wrapper.py         | 282 ++++++++++++++++++
 tests/providers/amazon/aws/hooks/test_base_aws.py  |  97 ++-----
 .../amazon/aws/utils/test_connection_wrapper.py    | 283 ++++++++++++++++++
 4 files changed, 699 insertions(+), 283 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index a69798990f..37e3780c58 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -24,7 +24,6 @@ This module contains Base AWS Hook.
     :ref:`howto/connection:AWSHook`
 """
 
-import configparser
 import datetime
 import json
 import logging
@@ -48,6 +47,7 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 from airflow.models.connection import Connection
+from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource])
@@ -65,78 +65,62 @@ class BaseSessionFactory(LoggingMixin):
         :ref:`howto/connection:aws:session-factory`
     """
 
-    def __init__(self, conn: Connection, region_name: Optional[str], config: Config) -> None:
+    def __init__(
+        self, conn: Union[Connection, AwsConnectionWrapper], region_name: Optional[str], config: Config
+    ) -> None:
         super().__init__()
-        self.conn = conn
-        self.region_name = region_name
+        self._conn = conn  # raw input; the ``conn`` property wraps it on first access
+        self._region_name = region_name  # explicit override, preferred by ``region_name`` property
         self.config = config
-        self.extra_config = self.conn.extra_dejson
 
-        self.basic_session: Optional[boto3.session.Session] = None
-        self.role_arn: Optional[str] = None
+    @cached_property
+    def conn(self) -> AwsConnectionWrapper:
+        """AWS Connection Wrapper for the given connection, built on first access and cached."""
+        if isinstance(self._conn, AwsConnectionWrapper):
+            return self._conn  # already wrapped: reuse as-is
+        return AwsConnectionWrapper(self._conn)
 
-    def create_session(self) -> boto3.session.Session:
-        """Create AWS session."""
-        session_kwargs = {}
-        if "session_kwargs" in self.extra_config:
-            self.log.info(
-                "Retrieving session_kwargs from Connection.extra_config['session_kwargs']: %s",
-                self.extra_config["session_kwargs"],
-            )
-            session_kwargs = self.extra_config["session_kwargs"]
-
-        if "profile" in self.extra_config and "s3_config_file" not in self.extra_config:
-            if "profile_name" not in session_kwargs:
-                self.log.warning(
-                    "Found 'profile' without specifying 's3_config_file'. "
-                    "If required profile from AWS Shared Credentials please "
-                    "set 'profile_name' in extra 'session_kwargs'."
-                )
+    @cached_property
+    def basic_session(self) -> boto3.session.Session:
+        """Cached boto3 Session built from connection credentials, region and ``session_kwargs``."""
+        return self._create_basic_session(session_kwargs=self.conn.session_kwargs)
 
-        self.basic_session = self._create_basic_session(session_kwargs=session_kwargs)
-        self.role_arn = self._read_role_arn_from_extra_config()
-        # If role_arn was specified then STS + assume_role
-        if self.role_arn is None:
-            return self.basic_session
+    @property
+    def extra_config(self) -> Dict[str, Any]:
+        """Extra config of the AWS Connection, as held by the connection wrapper."""
+        return self.conn.extra_config
 
-        return self._create_session_with_assume_role(session_kwargs=session_kwargs)
+    @property
+    def region_name(self) -> Optional[str]:
+        """Resolve region name.

-    def _get_region_name(self) -> Optional[str]:
-        region_name = self.region_name
-        if self.region_name is None and 'region_name' in self.extra_config:
-            self.log.info("Retrieving region_name from Connection.extra_config['region_name']")
-            region_name = self.extra_config["region_name"]
-        return region_name
+        1. ``region_name`` passed to the SessionFactory (used when truthy)
+        2. ``region_name`` from the AWS Connection extra
+        """
+        return self._region_name or self.conn.region_name
 
-    def _create_basic_session(self, session_kwargs: Dict[str, Any]) -> boto3.session.Session:
-        aws_access_key_id, aws_secret_access_key = self._read_credentials_from_connection()
-        aws_session_token = self.extra_config.get("aws_session_token")
-        region_name = self._get_region_name()
-        self.log.debug(
-            "Creating session with aws_access_key_id=%s region_name=%s",
-            aws_access_key_id,
-            region_name,
-        )
+    @property
+    def role_arn(self) -> Optional[str]:
+        """Assume Role ARN from the AWS Connection extra (``None`` when not configured)."""
+        return self.conn.role_arn
 
+    def create_session(self) -> boto3.session.Session:
+        """Create AWS session: the basic session, or an assumed-role session if role_arn is set."""
+        if not self.role_arn:
+            return self.basic_session
+        return self._create_session_with_assume_role(session_kwargs=self.conn.session_kwargs)
+
+    def _create_basic_session(self, session_kwargs: Dict[str, Any]) -> boto3.session.Session:
         return boto3.session.Session(
-            aws_access_key_id=aws_access_key_id,
-            aws_secret_access_key=aws_secret_access_key,
-            region_name=region_name,
-            aws_session_token=aws_session_token,
+            aws_access_key_id=self.conn.aws_access_key_id,  # credentials resolved by the wrapper
+            aws_secret_access_key=self.conn.aws_secret_access_key,
+            aws_session_token=self.conn.aws_session_token,
+            region_name=self.region_name,  # factory region_name first, then the connection's
             **session_kwargs,
         )
 
     def _create_session_with_assume_role(self, session_kwargs: Dict[str, Any]) -> boto3.session.Session:
-        assume_role_method = self.extra_config.get('assume_role_method', 'assume_role')
-        self.log.debug("assume_role_method=%s", assume_role_method)
-        supported_methods = ['assume_role', 'assume_role_with_saml', 'assume_role_with_web_identity']
-        if assume_role_method not in supported_methods:
-            raise NotImplementedError(
-                f'assume_role_method={assume_role_method} in Connection {self.conn.conn_id} Extra.'
-                f'Currently {supported_methods} are supported.'
-                '(Exclude this setting will default to "assume_role").'
-            )
-        if assume_role_method == 'assume_role_with_web_identity':
+        if self.conn.assume_role_method == 'assume_role_with_web_identity':
             # Deferred credentials have no initial credentials
             credential_fetcher = self._get_web_identity_credential_fetcher()
             credentials = botocore.credentials.DeferredRefreshableCredentials(
@@ -151,12 +135,9 @@ class BaseSessionFactory(LoggingMixin):
                 refresh_using=self._refresh_credentials,
                 method="sts-assume-role",
             )
+
         session = botocore.session.get_session()
         session._credentials = credentials
-
-        if self.basic_session is None:
-            raise RuntimeError("The basic session should be created here!")
-
         region_name = self.basic_session.region_name
         session.set_config_variable("region", region_name)
 
@@ -164,25 +145,21 @@ class BaseSessionFactory(LoggingMixin):
 
     def _refresh_credentials(self) -> Dict[str, Any]:
         self.log.debug('Refreshing credentials')
-        assume_role_method = self.extra_config.get('assume_role_method', 'assume_role')
-        sts_session = self.basic_session
-
-        if sts_session is None:
-            raise RuntimeError(
-                "Session should be initialized when refresh credentials with assume_role is used!"
-            )
+        assume_role_method = self.conn.assume_role_method
+        if assume_role_method not in ('assume_role', 'assume_role_with_saml'):
+            raise NotImplementedError(f'assume_role_method={assume_role_method} not expected')
 
-        sts_client = sts_session.client("sts", config=self.config)
+        sts_client = self.basic_session.client("sts", config=self.config)
 
         if assume_role_method == 'assume_role':
             sts_response = self._assume_role(sts_client=sts_client)
-        elif assume_role_method == 'assume_role_with_saml':
-            sts_response = self._assume_role_with_saml(sts_client=sts_client)
         else:
-            raise NotImplementedError(f'assume_role_method={assume_role_method} not expected')
+            sts_response = self._assume_role_with_saml(sts_client=sts_client)
+
         sts_response_http_status = sts_response['ResponseMetadata']['HTTPStatusCode']
-        if not sts_response_http_status == 200:
+        if sts_response_http_status != 200:
             raise RuntimeError(f'sts_response_http_status={sts_response_http_status}')
+
         credentials = sts_response['Credentials']
         expiry_time = credentials.get('Expiration').isoformat()
         self.log.debug('New credentials expiry_time: %s', expiry_time)
@@ -194,70 +171,13 @@ class BaseSessionFactory(LoggingMixin):
         }
         return credentials
 
-    def _read_role_arn_from_extra_config(self) -> Optional[str]:
-        aws_account_id = self.extra_config.get("aws_account_id")
-        aws_iam_role = self.extra_config.get("aws_iam_role")
-        role_arn = self.extra_config.get("role_arn")
-        if role_arn is None and aws_account_id is not None and aws_iam_role is not None:
-            self.log.info("Constructing role_arn from aws_account_id and aws_iam_role")
-            warnings.warn(
-                "Constructing 'role_arn' from 'aws_account_id' and 'aws_iam_role' is deprecated and "
-                "will be removed in a future releases. Please set 'role_arn' in extra config.",
-                DeprecationWarning,
-                stacklevel=3,
-            )
-            role_arn = f"arn:aws:iam::{aws_account_id}:role/{aws_iam_role}"
-        self.log.debug("role_arn is %s", role_arn)
-        return role_arn
-
-    def _read_credentials_from_connection(self) -> Tuple[Optional[str], Optional[str]]:
-        aws_access_key_id = None
-        aws_secret_access_key = None
-        if self.conn.login:
-            aws_access_key_id = self.conn.login
-            aws_secret_access_key = self.conn.password
-            self.log.info("Credentials retrieved from login")
-        elif "aws_access_key_id" in self.extra_config and "aws_secret_access_key" in self.extra_config:
-            aws_access_key_id = self.extra_config["aws_access_key_id"]
-            aws_secret_access_key = self.extra_config["aws_secret_access_key"]
-            self.log.info("Credentials retrieved from extra_config")
-        elif "s3_config_file" in self.extra_config:
-            warnings.warn(
-                "Use local credentials file is never documented and well tested. "
-                "Obtain credentials by this way deprecated and will be removed in a future releases.",
-                DeprecationWarning,
-                stacklevel=3,
-            )
-            aws_access_key_id, aws_secret_access_key = _parse_s3_config(
-                self.extra_config["s3_config_file"],
-                self.extra_config.get("s3_config_format"),
-                self.extra_config.get("profile"),
-            )
-            self.log.info("Credentials retrieved from extra_config['s3_config_file']")
-        return aws_access_key_id, aws_secret_access_key
-
-    def _strip_invalid_session_name_characters(self, role_session_name: str) -> str:
-        return slugify(role_session_name, regex_pattern=r'[^\w+=,.@-]+')
-
     def _assume_role(self, sts_client: boto3.client) -> Dict:
-        assume_role_kwargs = self.extra_config.get("assume_role_kwargs", {})
-        if "ExternalId" not in assume_role_kwargs and "external_id" in self.extra_config:
-            warnings.warn(
-                "'external_id' in extra config is deprecated and will be removed in a future releases. "
-                "Set 'ExternalId' in 'assume_role_kwargs' in extra config.",
-                DeprecationWarning,
-                stacklevel=3,
-            )
-            assume_role_kwargs["ExternalId"] = self.extra_config.get("external_id")
-        role_session_name = self._strip_invalid_session_name_characters(f"Airflow_{self.conn.conn_id}")
-        self.log.debug(
-            "Doing sts_client.assume_role to role_arn=%s (role_session_name=%s)",
-            self.role_arn,
-            role_session_name,
-        )
-        return sts_client.assume_role(
-            RoleArn=self.role_arn, RoleSessionName=role_session_name, **assume_role_kwargs
-        )
+        kw = {
+            "RoleSessionName": self._strip_invalid_session_name_characters(f"Airflow_{self.conn.conn_id}"),
+            **self.conn.assume_role_kwargs,  # may override the default RoleSessionName above
+            "RoleArn": self.role_arn,  # listed last so assume_role_kwargs cannot override it
+        }
+        return sts_client.assume_role(**kw)
 
     def _assume_role_with_saml(self, sts_client: boto3.client) -> Dict[str, Any]:
         saml_config = self.extra_config['assume_role_with_saml']
@@ -273,12 +193,11 @@ class BaseSessionFactory(LoggingMixin):
             )
 
         self.log.debug("Doing sts_client.assume_role_with_saml to role_arn=%s", self.role_arn)
-        assume_role_kwargs = self.extra_config.get("assume_role_kwargs", {})
         return sts_client.assume_role_with_saml(
             RoleArn=self.role_arn,
             PrincipalArn=principal_arn,
             SAMLAssertion=saml_assertion,
-            **assume_role_kwargs,
+            **self.conn.assume_role_kwargs,
         )
 
     def _get_idp_response(
@@ -357,8 +276,6 @@ class BaseSessionFactory(LoggingMixin):
     def _get_web_identity_credential_fetcher(
         self,
     ) -> botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher:
-        if self.basic_session is None:
-            raise Exception("Session should be set where identity is fetched!")
         base_session = self.basic_session._session or botocore.session.get_session()
         client_creator = base_session.create_client
         federation = self.extra_config.get('assume_role_with_web_identity_federation')
@@ -368,12 +285,11 @@ class BaseSessionFactory(LoggingMixin):
             raise AirflowException(
                 f'Unsupported federation: {federation}. Currently "google" only are supported.'
             )
-        assume_role_kwargs = self.extra_config.get("assume_role_kwargs", {})
         return botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher(
             client_creator=client_creator,
             web_identity_token_loader=web_identity_token_loader,
             role_arn=self.role_arn,
-            extra_args=assume_role_kwargs,
+            extra_args=self.conn.assume_role_kwargs,
         )
 
     def _get_google_identity_token_loader(self):
@@ -395,6 +311,37 @@ class BaseSessionFactory(LoggingMixin):
 
         return web_identity_token_loader
 
+    def _strip_invalid_session_name_characters(self, role_session_name: str) -> str:
+        return slugify(role_session_name, regex_pattern=r'[^\w+=,.@-]+')  # pattern = disallowed chars
+
+    def _get_region_name(self) -> Optional[str]:
+        warnings.warn(
+            "`BaseSessionFactory._get_region_name` method will be deprecated in the future."
+            " Please use `BaseSessionFactory.region_name` property instead.",
+            PendingDeprecationWarning,
+            stacklevel=2,
+        )
+        return self.region_name  # kept only for backward compatibility
+
+    def _read_role_arn_from_extra_config(self) -> Optional[str]:
+        warnings.warn(
+            "`BaseSessionFactory._read_role_arn_from_extra_config` method will be deprecated in the future."
+            " Please use `BaseSessionFactory.role_arn` property instead.",
+            PendingDeprecationWarning,
+            stacklevel=2,
+        )
+        return self.role_arn  # kept only for backward compatibility
+
+    def _read_credentials_from_connection(self) -> Tuple[Optional[str], Optional[str]]:
+        warnings.warn(
+            "`BaseSessionFactory._read_credentials_from_connection` method will be deprecated in the future."
+            " Please use `BaseSessionFactory.conn.aws_access_key_id` and "
+            "`BaseSessionFactory.conn.aws_secret_access_key` properties instead.",
+            PendingDeprecationWarning,
+            stacklevel=2,
+        )
+        return self.conn.aws_access_key_id, self.conn.aws_secret_access_key  # backward compatibility
+
 
 class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
     """
@@ -446,24 +393,19 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         self.log.debug("Airflow Connection: aws_conn_id=%s", self.aws_conn_id)
 
         try:
-            # Fetch the Airflow connection object
-            connection_object = self.get_connection(self.aws_conn_id)
-            extra_config = connection_object.extra_dejson
-            endpoint_url = extra_config.get("host")
-
-            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
-            if "config_kwargs" in extra_config:
-                self.log.debug(
-                    "Retrieving config_kwargs from Connection.extra_config['config_kwargs']: %s",
-                    extra_config["config_kwargs"],
-                )
-                self.config = Config(**extra_config["config_kwargs"])
+            # Fetch the Airflow connection object and wrap it in helper
+            connection_object = AwsConnectionWrapper(self.get_connection(self.aws_conn_id))
+
+            if connection_object.botocore_config:
+                # For historical reason botocore.config.Config from connection overwrites
+                # config which explicitly set in Hook.
+                self.config = connection_object.botocore_config
 
             session = SessionFactory(
                 conn=connection_object, region_name=region_name, config=self.config
             ).create_session()
 
-            return session, endpoint_url
+            return session, connection_object.endpoint_url
 
         except AirflowException:
             self.log.warning(
@@ -675,57 +617,6 @@ class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):
     """
 
 
-def _parse_s3_config(
-    config_file_name: str, config_format: Optional[str] = "boto", profile: Optional[str] = None
-) -> Tuple[Optional[str], Optional[str]]:
-    """
-    Parses a config file for s3 credentials. Can currently
-    parse boto, s3cmd.conf and AWS SDK config formats
-
-    :param config_file_name: path to the config file
-    :param config_format: config type. One of "boto", "s3cmd" or "aws".
-        Defaults to "boto"
-    :param profile: profile name in AWS type config file
-    """
-    config = configparser.ConfigParser()
-    if config.read(config_file_name):  # pragma: no cover
-        sections = config.sections()
-    else:
-        raise AirflowException(f"Couldn't read {config_file_name}")
-    # Setting option names depending on file format
-    if config_format is None:
-        config_format = "boto"
-    conf_format = config_format.lower()
-    if conf_format == "boto":  # pragma: no cover
-        if profile is not None and "profile " + profile in sections:
-            cred_section = "profile " + profile
-        else:
-            cred_section = "Credentials"
-    elif conf_format == "aws" and profile is not None:
-        cred_section = profile
-    else:
-        cred_section = "default"
-    # Option names
-    if conf_format in ("boto", "aws"):  # pragma: no cover
-        key_id_option = "aws_access_key_id"
-        secret_key_option = "aws_secret_access_key"
-        # security_token_option = 'aws_security_token'
-    else:
-        key_id_option = "access_key"
-        secret_key_option = "secret_key"
-    # Actual Parsing
-    if cred_section not in sections:
-        raise AirflowException("This config file format is not recognized")
-    else:
-        try:
-            access_key = config.get(cred_section, key_id_option)
-            secret_key = config.get(cred_section, secret_key_option)
-        except Exception:
-            logging.warning("Option Error in parsing s3 config file")
-            raise
-        return access_key, secret_key
-
-
 def resolve_session_factory() -> Type[BaseSessionFactory]:
     """Resolves custom SessionFactory class"""
     clazz = conf.getimport("aws", "session_factory", fallback=None)
@@ -740,3 +631,16 @@ def resolve_session_factory() -> Type[BaseSessionFactory]:
 
 
 SessionFactory = resolve_session_factory()
+
+
+def _parse_s3_config(
+    config_file_name: str, config_format: Optional[str] = "boto", profile: Optional[str] = None
+) -> Tuple[Optional[str], Optional[str]]:
+    """For compatibility with airflow.contrib.hooks.aws_hook; delegates to connection_wrapper."""
+    from airflow.providers.amazon.aws.utils.connection_wrapper import _parse_s3_config as _impl
+
+    return _impl(
+        config_file_name=config_file_name,
+        config_format=config_format,
+        profile=profile,
+    )
diff --git a/airflow/providers/amazon/aws/utils/connection_wrapper.py b/airflow/providers/amazon/aws/utils/connection_wrapper.py
new file mode 100644
index 0000000000..6672b971f6
--- /dev/null
+++ b/airflow/providers/amazon/aws/utils/connection_wrapper.py
@@ -0,0 +1,282 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import warnings
+from copy import deepcopy
+from dataclasses import InitVar, dataclass, field
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+
+from botocore.config import Config
+
+from airflow.compat.functools import cached_property
+from airflow.exceptions import AirflowException
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+if TYPE_CHECKING:
+    from airflow.models.connection import Connection
+
+
+@dataclass
+class AwsConnectionWrapper(LoggingMixin):
+    """
+    AWS Connection Wrapper class helper.
+    Used to validate and resolve AWS Connection parameters.
+    """
+
+    conn: InitVar[Optional["Connection"]]
+
+    conn_id: Optional[str] = field(init=False, default=None)
+    conn_type: Optional[str] = field(init=False, default=None)
+    login: Optional[str] = field(init=False, repr=False, default=None)
+    password: Optional[str] = field(init=False, repr=False, default=None)
+    extra_config: Dict[str, Any] = field(init=False, repr=False, default_factory=dict)
+
+    aws_access_key_id: Optional[str] = field(init=False, default=None)
+    aws_secret_access_key: Optional[str] = field(init=False, repr=False, default=None)
+    aws_session_token: Optional[str] = field(init=False, repr=False, default=None)
+
+    region_name: Optional[str] = field(init=False, default=None)
+    session_kwargs: Dict[str, Any] = field(init=False, default_factory=dict)
+    botocore_config: Optional[Config] = field(init=False, default=None)
+    endpoint_url: Optional[str] = field(init=False, default=None)
+
+    role_arn: Optional[str] = field(init=False, default=None)
+    assume_role_method: Optional[str] = field(init=False, default=None)
+    assume_role_kwargs: Dict[str, Any] = field(init=False, default_factory=dict)
+
+    @cached_property
+    def conn_repr(self) -> str:
+        return f"AWS Connection (conn_id={self.conn_id!r}, conn_type={self.conn_type!r})"
+
+    def __post_init__(self, conn: Optional["Connection"]) -> None:
+        if not conn:
+            return  # no connection: every field keeps its default value
+
+        extra = deepcopy(conn.extra_dejson)
+
+        # Assign attributes from AWS Connection
+        self.conn_id = conn.conn_id
+        self.conn_type = conn.conn_type or "aws"
+        self.login = conn.login
+        self.password = conn.password
+        self.extra_config = deepcopy(conn.extra_dejson)
+
+        if self.conn_type != "aws":
+            warnings.warn(
+                f"{self.conn_repr} expected connection type 'aws', got {self.conn_type!r}.",
+                UserWarning,
+                stacklevel=2,
+            )
+
+        # Retrieve initial connection credentials
+        init_credentials = self._get_credentials(**extra)
+        self.aws_access_key_id, self.aws_secret_access_key, self.aws_session_token = init_credentials
+
+        if "region_name" in extra:
+            self.region_name = extra["region_name"]
+            self.log.info("Retrieving region_name=%s from %s extra.", self.region_name, self.conn_repr)
+
+        if "session_kwargs" in extra:
+            self.session_kwargs = extra["session_kwargs"]
+            self.log.info("Retrieving session_kwargs=%s from %s extra.", self.session_kwargs, self.conn_repr)
+
+        # 'profile' only applies with 's3_config_file'; boto3 profiles go in session_kwargs['profile_name'].
+        if "profile" in extra and "s3_config_file" not in extra:
+            if "profile_name" not in self.session_kwargs:
+                warnings.warn(
+                    f"Found 'profile' without specifying 's3_config_file' in {self.conn_repr} extra. "
+                    "If required profile from AWS Shared Credentials please "
+                    f"set 'profile_name' in {self.conn_repr} extra['session_kwargs'].",
+                    UserWarning,
+                    stacklevel=2,
+                )
+
+        config_kwargs = extra.get("config_kwargs")
+        if config_kwargs:
+            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+            self.log.info("Retrieving botocore config=%s from %s extra.", config_kwargs, self.conn_repr)
+            self.botocore_config = Config(**config_kwargs)
+
+        self.endpoint_url = extra.get("host")
+
+        # Retrieve Assume Role Configuration
+        assume_role_configs = self._get_assume_role_configs(**extra)
+        self.role_arn, self.assume_role_method, self.assume_role_kwargs = assume_role_configs
+
+    @property
+    def extra_dejson(self) -> Dict[str, Any]:
+        return self.extra_config  # same name as Connection.extra_dejson, for drop-in use
+
+    def __bool__(self) -> bool:
+        return self.conn_id is not None  # falsy when no conn_id is set (e.g. built with conn=None)
+
+    def _get_credentials(
+        self,
+        *,
+        aws_access_key_id: Optional[str] = None,
+        aws_secret_access_key: Optional[str] = None,
+        aws_session_token: Optional[str] = None,
+        # Deprecated Values
+        s3_config_file: Optional[str] = None,
+        s3_config_format: Optional[str] = None,
+        profile: Optional[str] = None,
+        **kwargs,
+    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
+        """
+        Get AWS credentials from connection login/password and extra.
+
+        ``aws_access_key_id`` and ``aws_secret_access_key`` order
+        1. From Connection login and password
+        2. From Connection extra['aws_access_key_id'] and extra['aws_secret_access_key']
+        3. (deprecated) From local credentials file
+
+        Get ``aws_session_token`` from extra['aws_session_token']
+
+        """
+        if self.login and self.password:
+            self.log.info("%s credentials retrieved from login and password.", self.conn_repr)
+            aws_access_key_id, aws_secret_access_key = self.login, self.password
+        elif aws_access_key_id and aws_secret_access_key:
+            self.log.info("%s credentials retrieved from extra.", self.conn_repr)
+        elif s3_config_file:
+            aws_access_key_id, aws_secret_access_key = _parse_s3_config(
+                s3_config_file,
+                s3_config_format,
+                profile,
+            )
+            self.log.info("%s credentials retrieved from extra['s3_config_file']", self.conn_repr)
+
+        if aws_session_token:
+            self.log.info(
+                "%s session token retrieved from extra, please note you are responsible for renewing these.",
+                self.conn_repr,
+            )
+
+        return aws_access_key_id, aws_secret_access_key, aws_session_token
+
+    def _get_assume_role_configs(
+        self,
+        *,
+        role_arn: Optional[str] = None,
+        assume_role_method: str = "assume_role",
+        assume_role_kwargs: Optional[Dict[str, Any]] = None,
+        # Deprecated Values
+        aws_account_id: Optional[str] = None,
+        aws_iam_role: Optional[str] = None,
+        external_id: Optional[str] = None,
+        **kwargs,
+    ) -> Tuple[Optional[str], Optional[str], Dict[str, Any]]:
+        """Get (role_arn, assume_role_method, assume_role_kwargs) from Connection extra."""
+        if role_arn:
+            self.log.info("Retrieving role_arn=%r from %s extra.", role_arn, self.conn_repr)
+        elif aws_account_id and aws_iam_role:
+            warnings.warn(
+                "Constructing 'role_arn' from extra['aws_account_id'] and extra['aws_iam_role'] is deprecated"
+                " and will be removed in a future releases."
+                f" Please set 'role_arn' in {self.conn_repr} extra.",
+                DeprecationWarning,
+                stacklevel=3,
+            )
+            role_arn = f"arn:aws:iam::{aws_account_id}:role/{aws_iam_role}"
+            self.log.info(
+                "Constructions role_arn=%r from %s extra['aws_account_id'] and extra['aws_iam_role'].",
+                role_arn,
+                self.conn_repr,
+            )
+
+        if not role_arn:
+            # No role_arn: assume_role_method and assume_role_kwargs are irrelevant, so skip them.
+            return None, None, {}
+
+        supported_methods = ['assume_role', 'assume_role_with_saml', 'assume_role_with_web_identity']
+        if assume_role_method not in supported_methods:
+            raise NotImplementedError(
+                f'Found assume_role_method={assume_role_method!r} in {self.conn_repr} extra.'
+                f' Currently {supported_methods} are supported.'
+                ' (Exclude this setting will default to "assume_role").'
+            )
+        self.log.info("Retrieve assume_role_method=%r from %s.", assume_role_method, self.conn_repr)
+
+        assume_role_kwargs = assume_role_kwargs or {}
+        if "ExternalId" not in assume_role_kwargs and external_id:
+            warnings.warn(
+                "'external_id' in extra config is deprecated and will be removed in a future releases. "
+                f"Please set 'ExternalId' in 'assume_role_kwargs' in {self.conn_repr} extra.",
+                DeprecationWarning,
+                stacklevel=3,
+            )
+            assume_role_kwargs["ExternalId"] = external_id
+
+        return role_arn, assume_role_method, assume_role_kwargs
+
+
+def _parse_s3_config(
+    config_file_name: str, config_format: Optional[str] = "boto", profile: Optional[str] = None
+) -> Tuple[Optional[str], Optional[str]]:
+    """
+    Parse a (deprecated) local config file for S3 credentials.
+    Supports boto, s3cmd.conf and AWS SDK config formats.
+
+    :param config_file_name: path to the config file
+    :param config_format: config type. One of "boto", "s3cmd" or "aws". Defaults to "boto".
+    :param profile: profile name; used by "boto" ("profile <name>" section) and "aws" formats
+    :return: tuple of (access key id, secret access key); raises AirflowException on failure
+    """
+    warnings.warn(
+        "Use local credentials file is never documented and well tested. "
+        "Obtain credentials by this way deprecated and will be removed in a future releases.",
+        DeprecationWarning,
+        stacklevel=4,
+    )
+
+    import configparser
+
+    config = configparser.ConfigParser()
+    if config.read(config_file_name):  # pragma: no cover
+        sections = config.sections()
+    else:
+        raise AirflowException(f"Couldn't read {config_file_name}")
+    # Setting option names depending on file format
+    if config_format is None:
+        config_format = "boto"
+    conf_format = config_format.lower()
+    if conf_format == "boto":  # pragma: no cover
+        if profile is not None and "profile " + profile in sections:
+            cred_section = "profile " + profile
+        else:
+            cred_section = "Credentials"
+    elif conf_format == "aws" and profile is not None:
+        cred_section = profile
+    else:
+        cred_section = "default"
+    # Option names
+    if conf_format in ("boto", "aws"):  # pragma: no cover
+        key_id_option = "aws_access_key_id"
+        secret_key_option = "aws_secret_access_key"
+    else:
+        key_id_option = "access_key"
+        secret_key_option = "secret_key"
+    # Actual Parsing
+    if cred_section not in sections:
+        raise AirflowException("This config file format is not recognized")
+    else:
+        try:
+            access_key = config.get(cred_section, key_id_option)
+            secret_key = config.get(cred_section, secret_key_option)
+        except Exception as e:
+            raise AirflowException("Option Error in parsing s3 config file") from e
+        return access_key, secret_key
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index d1c3da9cdd..a04473b5b0 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -32,6 +32,7 @@ from airflow.providers.amazon.aws.hooks.base_aws import (
     BaseSessionFactory,
     resolve_session_factory,
 )
+from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
 from tests.test_utils.config import conf_vars
 
 try:
@@ -42,6 +43,9 @@ except ImportError:
     mock_sts = None
     mock_iam = None
 
+# Connection id and type used to build mock AWS ``Connection`` objects in these tests.
+MOCK_AWS_CONN_ID = "mock-conn-id"
+MOCK_CONN_TYPE = "aws"
+
 SAML_ASSERTION = """
 <?xml version="1.0"?>
 <samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" ID="_00000000-0000-0000-0000-000000000000" Version="2.0" IssueInstant="2012-01-01T12:00:00.000Z" Destination="https://signin.aws.amazon.com/saml" Consent="urn:oasis:names:tc:SAML:2.0:consent:unspecified">
@@ -114,7 +118,7 @@ class CustomSessionFactory(BaseSessionFactory):
         return mock.MagicMock()
 
 
-class TestAwsBaseHook:
+class TestSessionFactory:
     @conf_vars(
         {("aws", "session_factory"): "tests.providers.amazon.aws.hooks.test_base_aws.CustomSessionFactory"}
     )
@@ -131,6 +135,23 @@ class TestAwsBaseHook:
         cls = resolve_session_factory()
         assert issubclass(cls, BaseSessionFactory)
 
+    @pytest.mark.parametrize(
+        "mock_conn",
+        [
+            Connection(conn_type=MOCK_CONN_TYPE, conn_id=MOCK_AWS_CONN_ID),
+            AwsConnectionWrapper(conn=Connection(conn_type=MOCK_CONN_TYPE, conn_id=MOCK_AWS_CONN_ID)),
+        ],
+    )
+    def test_conn_property(self, mock_conn):
+        """``BaseSessionFactory.conn`` exposes an ``AwsConnectionWrapper`` and returns the same object each time.
+
+        The factory accepts both a plain ``Connection`` and an already wrapped connection.
+        """
+        factory = BaseSessionFactory(conn=mock_conn, region_name=None, config=None)
+        wrapped = factory.conn
+        assert isinstance(wrapped, AwsConnectionWrapper)
+        assert wrapped.conn_id == MOCK_AWS_CONN_ID
+        assert wrapped.conn_type == MOCK_CONN_TYPE
+        # A second access must hand back the very same wrapper, not build a new one.
+        assert factory.conn is wrapped
+
+
+class TestAwsBaseHook:
     @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
     @mock_emr
     def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
@@ -260,80 +281,6 @@ class TestAwsBaseHook:
 
         assert table.item_count == 0
 
-    @mock.patch.object(AwsBaseHook, 'get_connection')
-    def test_get_credentials_from_login_with_token(self, mock_get_connection):
-        mock_connection = Connection(
-            login='aws_access_key_id',
-            password='aws_secret_access_key',
-            extra='{"aws_session_token": "test_token"}',
-        )
-        mock_get_connection.return_value = mock_connection
-        hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
-        credentials_from_hook = hook.get_credentials()
-        assert credentials_from_hook.access_key == 'aws_access_key_id'
-        assert credentials_from_hook.secret_key == 'aws_secret_access_key'
-        assert credentials_from_hook.token == 'test_token'
-
-    @mock.patch.object(AwsBaseHook, 'get_connection')
-    def test_get_credentials_from_login_without_token(self, mock_get_connection):
-        mock_connection = Connection(
-            login='aws_access_key_id',
-            password='aws_secret_access_key',
-        )
-
-        mock_get_connection.return_value = mock_connection
-        hook = AwsBaseHook(aws_conn_id='aws_default', client_type='spam')
-        credentials_from_hook = hook.get_credentials()
-        assert credentials_from_hook.access_key == 'aws_access_key_id'
-        assert credentials_from_hook.secret_key == 'aws_secret_access_key'
-        assert credentials_from_hook.token is None
-
-    @mock.patch.object(AwsBaseHook, 'get_connection')
-    def test_get_credentials_from_extra_with_token(self, mock_get_connection):
-        mock_connection = Connection(
-            extra='{"aws_access_key_id": "aws_access_key_id",'
-            '"aws_secret_access_key": "aws_secret_access_key",'
-            ' "aws_session_token": "session_token"}'
-        )
-        mock_get_connection.return_value = mock_connection
-        hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
-        credentials_from_hook = hook.get_credentials()
-        assert credentials_from_hook.access_key == 'aws_access_key_id'
-        assert credentials_from_hook.secret_key == 'aws_secret_access_key'
-        assert credentials_from_hook.token == 'session_token'
-
-    @mock.patch.object(AwsBaseHook, 'get_connection')
-    def test_get_credentials_from_extra_without_token(self, mock_get_connection):
-        mock_connection = Connection(
-            extra='{"aws_access_key_id": "aws_access_key_id",'
-            '"aws_secret_access_key": "aws_secret_access_key"}'
-        )
-        mock_get_connection.return_value = mock_connection
-        hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
-        credentials_from_hook = hook.get_credentials()
-        assert credentials_from_hook.access_key == 'aws_access_key_id'
-        assert credentials_from_hook.secret_key == 'aws_secret_access_key'
-        assert credentials_from_hook.token is None
-
-    @mock.patch(
-        'airflow.providers.amazon.aws.hooks.base_aws._parse_s3_config',
-        return_value=('aws_access_key_id', 'aws_secret_access_key'),
-    )
-    @mock.patch.object(AwsBaseHook, 'get_connection')
-    def test_get_credentials_from_extra_with_s3_config_and_profile(
-        self, mock_get_connection, mock_parse_s3_config
-    ):
-        mock_connection = Connection(
-            extra='{"s3_config_format": "aws", '
-            '"profile": "test", '
-            '"s3_config_file": "aws-credentials", '
-            '"region_name": "us-east-1"}'
-        )
-        mock_get_connection.return_value = mock_connection
-        hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
-        hook._get_credentials(region_name=None)
-        mock_parse_s3_config.assert_called_once_with('aws-credentials', 'aws', 'test')
-
     @unittest.skipIf(mock_sts is None, 'mock_sts package not present')
     @mock.patch.object(AwsBaseHook, 'get_connection')
     @mock_sts
diff --git a/tests/providers/amazon/aws/utils/test_connection_wrapper.py b/tests/providers/amazon/aws/utils/test_connection_wrapper.py
new file mode 100644
index 0000000000..74696db98e
--- /dev/null
+++ b/tests/providers/amazon/aws/utils/test_connection_wrapper.py
@@ -0,0 +1,283 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional
+from unittest import mock
+
+import pytest
+
+from airflow.models import Connection
+from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
+
+# Default connection id/type for ``mock_connection_factory`` and an IAM role ARN
+# that the role-related tests put into the connection extra as ``role_arn``.
+MOCK_AWS_CONN_ID = "mock-conn-id"
+MOCK_CONN_TYPE = "aws"
+MOCK_ROLE_ARN = "arn:aws:iam::222222222222:role/awesome-role"
+
+
+def mock_connection_factory(
+    conn_id: Optional[str] = MOCK_AWS_CONN_ID, conn_type: Optional[str] = MOCK_CONN_TYPE, **kwargs
+) -> Connection:
+    """Build an ``airflow.models.Connection`` for tests.
+
+    ``conn_id`` and ``conn_type`` default to the module-level mock values; every other
+    keyword argument (``login``, ``password``, ``extra``, ...) is forwarded unchanged.
+    """
+    connection = Connection(conn_id=conn_id, conn_type=conn_type, **kwargs)
+    return connection
+
+
+class TestAwsConnectionWrapper:
+    """Tests for resolving and validating AWS connection parameters with ``AwsConnectionWrapper``."""
+
+    @staticmethod
+    def _assert_credentials(wrapper, access_key_id, secret_access_key, session_token):
+        """Check that ``wrapper`` resolved exactly the given key id, secret key and session token."""
+        assert wrapper.aws_access_key_id == access_key_id
+        assert wrapper.aws_secret_access_key == secret_access_key
+        assert wrapper.aws_session_token == session_token
+
+    @pytest.mark.parametrize("extra", [{"foo": "bar", "spam": "egg"}, '{"foo": "bar", "spam": "egg"}', None])
+    def test_values_from_connection(self, extra):
+        conn = mock_connection_factory(
+            login="mock-login",
+            password="mock-password",
+            extra=extra,
+            # These Connection attributes are never used by AwsBaseHook.
+            host="mock-host",
+            schema="mock-schema",
+            port=42,
+        )
+        wrapper = AwsConnectionWrapper(conn=conn)
+
+        # Identity and credential attributes are taken over from the connection.
+        for attr_name in ("conn_id", "conn_type", "login", "password"):
+            assert getattr(wrapper, attr_name) == getattr(conn, attr_name)
+
+        # The parsed extra is kept, but as the wrapper's own dict rather than the connection's.
+        assert wrapper.extra_config == conn.extra_dejson
+        assert wrapper.extra_config is not conn.extra_dejson
+        # On the wrapper, `extra_dejson` returns the very same object as `extra_config`.
+        assert wrapper.extra_config is wrapper.extra_dejson
+
+        # Connection attributes that are irrelevant for AWS are not copied to the wrapper.
+        for unused_attr in ("host", "schema", "port"):
+            assert not hasattr(wrapper, unused_attr)
+
+        # A wrapper built around an actual connection is truthy.
+        assert wrapper
+
+    def test_no_connection(self):
+        empty_wrapper = AwsConnectionWrapper(conn=None)
+        assert not empty_wrapper
+
+    @pytest.mark.parametrize("conn_type", ["aws", None])
+    def test_expected_aws_connection_type(self, conn_type):
+        conn = mock_connection_factory(conn_type=conn_type)
+        assert AwsConnectionWrapper(conn=conn).conn_type == "aws"
+
+    @pytest.mark.parametrize("conn_type", ["AWS", "boto3", "s3", "emr", "google", "google-cloud-platform"])
+    def test_unexpected_aws_connection_type(self, conn_type):
+        expected_warning = f"expected connection type 'aws', got '{conn_type}'"
+        with pytest.warns(UserWarning, match=expected_warning):
+            wrapper = AwsConnectionWrapper(conn=mock_connection_factory(conn_type=conn_type))
+            assert wrapper.conn_type == conn_type
+
+    @pytest.mark.parametrize("aws_session_token", [None, "mock-aws-session-token"])
+    @pytest.mark.parametrize("aws_secret_access_key", ["mock-aws-secret-access-key"])
+    @pytest.mark.parametrize("aws_access_key_id", ["mock-aws-access-key-id"])
+    def test_get_credentials_from_login(self, aws_access_key_id, aws_secret_access_key, aws_session_token):
+        extra = None
+        if aws_session_token:
+            extra = {"aws_session_token": aws_session_token}
+        conn = mock_connection_factory(login=aws_access_key_id, password=aws_secret_access_key, extra=extra)
+
+        wrapper = AwsConnectionWrapper(conn=conn)
+        self._assert_credentials(wrapper, aws_access_key_id, aws_secret_access_key, aws_session_token)
+
+    @pytest.mark.parametrize("aws_session_token", [None, "mock-aws-session-token"])
+    @pytest.mark.parametrize("aws_secret_access_key", ["mock-aws-secret-access-key"])
+    @pytest.mark.parametrize("aws_access_key_id", ["mock-aws-access-key-id"])
+    def test_get_credentials_from_extra(self, aws_access_key_id, aws_secret_access_key, aws_session_token):
+        extra = {
+            "aws_access_key_id": aws_access_key_id,
+            "aws_secret_access_key": aws_secret_access_key,
+        }
+        if aws_session_token:
+            extra["aws_session_token"] = aws_session_token
+        conn = mock_connection_factory(login=None, password=None, extra=extra)
+
+        wrapper = AwsConnectionWrapper(conn=conn)
+        self._assert_credentials(wrapper, aws_access_key_id, aws_secret_access_key, aws_session_token)
+
+    # The s3 config file parser is deprecated and has never been tested itself, so it is mocked here.
+    # Only the wiring of its output into the wrapper is checked.
+    @mock.patch("airflow.providers.amazon.aws.utils.connection_wrapper._parse_s3_config")
+    @pytest.mark.parametrize("aws_session_token", [None, "mock-aws-session-token"])
+    @pytest.mark.parametrize("aws_secret_access_key", ["mock-aws-secret-access-key"])
+    @pytest.mark.parametrize("aws_access_key_id", ["mock-aws-access-key-id"])
+    def test_get_credentials_from_s3_config(
+        self, mock_parse_s3_config, aws_access_key_id, aws_secret_access_key, aws_session_token
+    ):
+        mock_parse_s3_config.return_value = (aws_access_key_id, aws_secret_access_key)
+        extra = {
+            "s3_config_format": "aws",
+            "profile": "test",
+            "s3_config_file": "aws-credentials",
+        }
+        if aws_session_token:
+            extra["aws_session_token"] = aws_session_token
+        conn = mock_connection_factory(login=None, password=None, extra=extra)
+
+        wrapper = AwsConnectionWrapper(conn=conn)
+        mock_parse_s3_config.assert_called_once_with('aws-credentials', 'aws', 'test')
+        self._assert_credentials(wrapper, aws_access_key_id, aws_secret_access_key, aws_session_token)
+
+    @pytest.mark.parametrize("region_name", [None, "mock-aws-region"])
+    def test_get_region_name(self, region_name):
+        extra = {"region_name": region_name} if region_name else None
+        wrapper = AwsConnectionWrapper(conn=mock_connection_factory(extra=extra))
+        assert wrapper.region_name == region_name
+
+    @pytest.mark.parametrize("session_kwargs", [None, {"profile_name": "mock-profile"}])
+    def test_get_session_kwargs(self, session_kwargs):
+        extra = {"session_kwargs": session_kwargs} if session_kwargs else None
+        wrapper = AwsConnectionWrapper(conn=mock_connection_factory(extra=extra))
+        # Missing session kwargs resolve to an empty dict.
+        assert wrapper.session_kwargs == (session_kwargs or {})
+
+    def test_warn_wrong_profile_param_used(self):
+        conn = mock_connection_factory(extra={"profile": "mock-profile"})
+        expected_warning = "Found 'profile' without specifying 's3_config_file' in .* set 'profile_name' in"
+        with pytest.warns(UserWarning, match=expected_warning):
+            wrapper = AwsConnectionWrapper(conn=conn)
+        assert "profile_name" not in wrapper.session_kwargs
+
+    @mock.patch("airflow.providers.amazon.aws.utils.connection_wrapper.Config", autospec=True)
+    @pytest.mark.parametrize("botocore_config_kwargs", [None, {"user_agent": "Airflow Amazon Provider"}])
+    def test_get_botocore_config(self, mock_botocore_config, botocore_config_kwargs):
+        extra = {"config_kwargs": botocore_config_kwargs} if botocore_config_kwargs else None
+        wrapper = AwsConnectionWrapper(conn=mock_connection_factory(extra=extra))
+
+        if botocore_config_kwargs:
+            # The botocore Config is built exactly once, from the kwargs found in extra.
+            assert mock_botocore_config.called
+            assert mock_botocore_config.call_count == 1
+            assert mock.call(**botocore_config_kwargs) in mock_botocore_config.mock_calls
+        else:
+            # Without config_kwargs no Config object is created at all.
+            assert not mock_botocore_config.called
+            assert wrapper.botocore_config is None
+
+    @pytest.mark.parametrize("endpoint_url", [None, "https://example.org"])
+    def test_get_endpoint_url(self, endpoint_url):
+        # The endpoint URL is supplied through the "host" key of the connection extra.
+        extra = {"host": endpoint_url} if endpoint_url else None
+        wrapper = AwsConnectionWrapper(conn=mock_connection_factory(extra=extra))
+        assert wrapper.endpoint_url == endpoint_url
+
+    @pytest.mark.parametrize("aws_account_id, aws_iam_role", [(None, None), ("111111111111", "another-role")])
+    def test_get_role_arn(self, aws_account_id, aws_iam_role):
+        # An explicit role_arn takes precedence over the aws_account_id / aws_iam_role pair.
+        extra = {
+            "role_arn": MOCK_ROLE_ARN,
+            "aws_account_id": aws_account_id,
+            "aws_iam_role": aws_iam_role,
+        }
+        wrapper = AwsConnectionWrapper(conn=mock_connection_factory(extra=extra))
+        assert wrapper.role_arn == MOCK_ROLE_ARN
+
+    @pytest.mark.parametrize(
+        "aws_account_id, aws_iam_role, expected",
+        [
+            ("222222222222", "mock-role", "arn:aws:iam::222222222222:role/mock-role"),
+            ("333333333333", "role-path/mock-role", "arn:aws:iam::333333333333:role/role-path/mock-role"),
+        ],
+    )
+    def test_constructing_role_arn(self, aws_account_id, aws_iam_role, expected):
+        extra = {"aws_account_id": aws_account_id, "aws_iam_role": aws_iam_role}
+        conn = mock_connection_factory(extra=extra)
+        with pytest.warns(DeprecationWarning, match="Please set 'role_arn' in .* extra"):
+            wrapper = AwsConnectionWrapper(conn=conn)
+        assert wrapper.role_arn == expected
+
+    def test_empty_role_arn(self):
+        wrapper = AwsConnectionWrapper(conn=mock_connection_factory())
+        # No role configured: there is nothing to assume.
+        assert wrapper.role_arn is None
+        assert wrapper.assume_role_method is None
+        assert wrapper.assume_role_kwargs == {}
+
+    @pytest.mark.parametrize(
+        "assume_role_method", ['assume_role', 'assume_role_with_saml', 'assume_role_with_web_identity']
+    )
+    def test_get_assume_role_method(self, assume_role_method):
+        extra = {"role_arn": MOCK_ROLE_ARN, "assume_role_method": assume_role_method}
+        wrapper = AwsConnectionWrapper(conn=mock_connection_factory(extra=extra))
+        assert wrapper.assume_role_method == assume_role_method
+
+    def test_default_assume_role_method(self):
+        wrapper = AwsConnectionWrapper(conn=mock_connection_factory(extra={"role_arn": MOCK_ROLE_ARN}))
+        assert wrapper.assume_role_method == "assume_role"
+
+    def test_unsupported_assume_role_method(self):
+        extra = {"role_arn": MOCK_ROLE_ARN, "assume_role_method": "dummy_method"}
+        conn = mock_connection_factory(extra=extra)
+        with pytest.raises(NotImplementedError, match="Found assume_role_method='dummy_method' in .* extra"):
+            AwsConnectionWrapper(conn=conn)
+
+    @pytest.mark.parametrize("assume_role_kwargs", [None, {"DurationSeconds": 42}])
+    def test_get_assume_role_kwargs(self, assume_role_kwargs):
+        extra = {"role_arn": MOCK_ROLE_ARN}
+        if assume_role_kwargs:
+            extra["assume_role_kwargs"] = assume_role_kwargs
+        wrapper = AwsConnectionWrapper(conn=mock_connection_factory(extra=extra))
+        # Missing assume_role_kwargs resolve to an empty dict.
+        assert wrapper.assume_role_kwargs == (assume_role_kwargs or {})
+
+    @pytest.mark.parametrize("external_id_in_extra", [None, "mock-external-id-in-extra"])
+    def test_get_assume_role_kwargs_external_id_in_kwargs(self, external_id_in_extra):
+        external_id_in_kwargs = "mock-external-id-in-kwargs"
+        extra = {
+            "role_arn": MOCK_ROLE_ARN,
+            "assume_role_kwargs": {"ExternalId": external_id_in_kwargs},
+        }
+        if external_id_in_extra:
+            extra["external_id"] = external_id_in_extra
+
+        wrapper = AwsConnectionWrapper(conn=mock_connection_factory(extra=extra))
+        # An ExternalId inside assume_role_kwargs wins over a top-level external_id.
+        resolved_kwargs = wrapper.assume_role_kwargs
+        assert "ExternalId" in resolved_kwargs
+        assert resolved_kwargs["ExternalId"] == external_id_in_kwargs
+        assert resolved_kwargs["ExternalId"] != external_id_in_extra
+
+    def test_get_assume_role_kwargs_external_id_in_extra(self):
+        external_id_in_extra = "mock-external-id-in-extra"
+        conn = mock_connection_factory(extra={"role_arn": MOCK_ROLE_ARN, "external_id": external_id_in_extra})
+
+        expected_warning = "Please set 'ExternalId' in 'assume_role_kwargs' in .* extra."
+        with pytest.warns(DeprecationWarning, match=expected_warning):
+            wrapper = AwsConnectionWrapper(conn=conn)
+        assert "ExternalId" in wrapper.assume_role_kwargs
+        assert wrapper.assume_role_kwargs["ExternalId"] == external_id_in_extra