You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/26 12:55:04 UTC
[airflow] branch main updated: Resolve and validate AWS Connection parameters in wrapper (#25256)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 432977be0c Resolve and validate AWS Connection parameters in wrapper (#25256)
432977be0c is described below
commit 432977be0cd1e95f623fa5edda2a227798fa2939
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Tue Jul 26 16:54:54 2022 +0400
Resolve and validate AWS Connection parameters in wrapper (#25256)
---
airflow/providers/amazon/aws/hooks/base_aws.py | 320 ++++++++-------------
.../amazon/aws/utils/connection_wrapper.py | 282 ++++++++++++++++++
tests/providers/amazon/aws/hooks/test_base_aws.py | 97 ++-----
.../amazon/aws/utils/test_connection_wrapper.py | 283 ++++++++++++++++++
4 files changed, 699 insertions(+), 283 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index a69798990f..37e3780c58 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -24,7 +24,6 @@ This module contains Base AWS Hook.
:ref:`howto/connection:AWSHook`
"""
-import configparser
import datetime
import json
import logging
@@ -48,6 +47,7 @@ from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models.connection import Connection
+from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
from airflow.utils.log.logging_mixin import LoggingMixin
BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource])
@@ -65,78 +65,62 @@ class BaseSessionFactory(LoggingMixin):
:ref:`howto/connection:aws:session-factory`
"""
- def __init__(self, conn: Connection, region_name: Optional[str], config: Config) -> None:
+ def __init__(
+ self, conn: Union[Connection, AwsConnectionWrapper], region_name: Optional[str], config: Config
+ ) -> None:
super().__init__()
- self.conn = conn
- self.region_name = region_name
+ self._conn = conn
+ self._region_name = region_name
self.config = config
- self.extra_config = self.conn.extra_dejson
- self.basic_session: Optional[boto3.session.Session] = None
- self.role_arn: Optional[str] = None
+ @cached_property
+ def conn(self) -> AwsConnectionWrapper:
+ """Cached AwsConnectionWrapper: a plain Connection is wrapped once, an existing wrapper is reused."""
+ if isinstance(self._conn, AwsConnectionWrapper):
+ return self._conn
+ return AwsConnectionWrapper(self._conn)
- def create_session(self) -> boto3.session.Session:
- """Create AWS session."""
- session_kwargs = {}
- if "session_kwargs" in self.extra_config:
- self.log.info(
- "Retrieving session_kwargs from Connection.extra_config['session_kwargs']: %s",
- self.extra_config["session_kwargs"],
- )
- session_kwargs = self.extra_config["session_kwargs"]
-
- if "profile" in self.extra_config and "s3_config_file" not in self.extra_config:
- if "profile_name" not in session_kwargs:
- self.log.warning(
- "Found 'profile' without specifying 's3_config_file'. "
- "If required profile from AWS Shared Credentials please "
- "set 'profile_name' in extra 'session_kwargs'."
- )
+ @cached_property
+ def basic_session(self) -> boto3.session.Session:
+ """Cached boto3 Session built by ``_create_basic_session`` with the connection's session_kwargs."""
+ return self._create_basic_session(session_kwargs=self.conn.session_kwargs)
- self.basic_session = self._create_basic_session(session_kwargs=session_kwargs)
- self.role_arn = self._read_role_arn_from_extra_config()
- # If role_arn was specified then STS + assume_role
- if self.role_arn is None:
- return self.basic_session
+ @property
+ def extra_config(self) -> Dict[str, Any]:
+ """Extra config of the wrapped AWS Connection (``self.conn.extra_config``)."""
+ return self.conn.extra_config
- return self._create_session_with_assume_role(session_kwargs=session_kwargs)
+ @property
+ def region_name(self) -> Optional[str]:
+ """Resolve region name.
- def _get_region_name(self) -> Optional[str]:
- region_name = self.region_name
- if self.region_name is None and 'region_name' in self.extra_config:
- self.log.info("Retrieving region_name from Connection.extra_config['region_name']")
- region_name = self.extra_config["region_name"]
- return region_name
+ 1. SessionFactory region_name
+ 2. Connection region_name
+ """
+ return self._region_name or self.conn.region_name
- def _create_basic_session(self, session_kwargs: Dict[str, Any]) -> boto3.session.Session:
- aws_access_key_id, aws_secret_access_key = self._read_credentials_from_connection()
- aws_session_token = self.extra_config.get("aws_session_token")
- region_name = self._get_region_name()
- self.log.debug(
- "Creating session with aws_access_key_id=%s region_name=%s",
- aws_access_key_id,
- region_name,
- )
+ @property
+ def role_arn(self) -> Optional[str]:
+ """Assume Role ARN resolved by the AWS Connection wrapper; falsy when no role should be assumed."""
+ return self.conn.role_arn
+ def create_session(self) -> boto3.session.Session:
+ """Create AWS session: ``basic_session`` if no ``role_arn`` is set, otherwise an assume-role session."""
+ if not self.role_arn:
+ return self.basic_session
+ return self._create_session_with_assume_role(session_kwargs=self.conn.session_kwargs)
+
+ def _create_basic_session(self, session_kwargs: Dict[str, Any]) -> boto3.session.Session:
return boto3.session.Session(
- aws_access_key_id=aws_access_key_id,
- aws_secret_access_key=aws_secret_access_key,
- region_name=region_name,
- aws_session_token=aws_session_token,
+ aws_access_key_id=self.conn.aws_access_key_id,
+ aws_secret_access_key=self.conn.aws_secret_access_key,
+ aws_session_token=self.conn.aws_session_token,
+ region_name=self.region_name,
**session_kwargs,
)
def _create_session_with_assume_role(self, session_kwargs: Dict[str, Any]) -> boto3.session.Session:
- assume_role_method = self.extra_config.get('assume_role_method', 'assume_role')
- self.log.debug("assume_role_method=%s", assume_role_method)
- supported_methods = ['assume_role', 'assume_role_with_saml', 'assume_role_with_web_identity']
- if assume_role_method not in supported_methods:
- raise NotImplementedError(
- f'assume_role_method={assume_role_method} in Connection {self.conn.conn_id} Extra.'
- f'Currently {supported_methods} are supported.'
- '(Exclude this setting will default to "assume_role").'
- )
- if assume_role_method == 'assume_role_with_web_identity':
+ if self.conn.assume_role_method == 'assume_role_with_web_identity':
# Deferred credentials have no initial credentials
credential_fetcher = self._get_web_identity_credential_fetcher()
credentials = botocore.credentials.DeferredRefreshableCredentials(
@@ -151,12 +135,9 @@ class BaseSessionFactory(LoggingMixin):
refresh_using=self._refresh_credentials,
method="sts-assume-role",
)
+
session = botocore.session.get_session()
session._credentials = credentials
-
- if self.basic_session is None:
- raise RuntimeError("The basic session should be created here!")
-
region_name = self.basic_session.region_name
session.set_config_variable("region", region_name)
@@ -164,25 +145,21 @@ class BaseSessionFactory(LoggingMixin):
def _refresh_credentials(self) -> Dict[str, Any]:
self.log.debug('Refreshing credentials')
- assume_role_method = self.extra_config.get('assume_role_method', 'assume_role')
- sts_session = self.basic_session
-
- if sts_session is None:
- raise RuntimeError(
- "Session should be initialized when refresh credentials with assume_role is used!"
- )
+ assume_role_method = self.conn.assume_role_method
+ if assume_role_method not in ('assume_role', 'assume_role_with_saml'):
+ raise NotImplementedError(f'assume_role_method={assume_role_method} not expected')
- sts_client = sts_session.client("sts", config=self.config)
+ sts_client = self.basic_session.client("sts", config=self.config)
if assume_role_method == 'assume_role':
sts_response = self._assume_role(sts_client=sts_client)
- elif assume_role_method == 'assume_role_with_saml':
- sts_response = self._assume_role_with_saml(sts_client=sts_client)
else:
- raise NotImplementedError(f'assume_role_method={assume_role_method} not expected')
+ sts_response = self._assume_role_with_saml(sts_client=sts_client)
+
sts_response_http_status = sts_response['ResponseMetadata']['HTTPStatusCode']
- if not sts_response_http_status == 200:
+ if sts_response_http_status != 200:
raise RuntimeError(f'sts_response_http_status={sts_response_http_status}')
+
credentials = sts_response['Credentials']
expiry_time = credentials.get('Expiration').isoformat()
self.log.debug('New credentials expiry_time: %s', expiry_time)
@@ -194,70 +171,13 @@ class BaseSessionFactory(LoggingMixin):
}
return credentials
- def _read_role_arn_from_extra_config(self) -> Optional[str]:
- aws_account_id = self.extra_config.get("aws_account_id")
- aws_iam_role = self.extra_config.get("aws_iam_role")
- role_arn = self.extra_config.get("role_arn")
- if role_arn is None and aws_account_id is not None and aws_iam_role is not None:
- self.log.info("Constructing role_arn from aws_account_id and aws_iam_role")
- warnings.warn(
- "Constructing 'role_arn' from 'aws_account_id' and 'aws_iam_role' is deprecated and "
- "will be removed in a future releases. Please set 'role_arn' in extra config.",
- DeprecationWarning,
- stacklevel=3,
- )
- role_arn = f"arn:aws:iam::{aws_account_id}:role/{aws_iam_role}"
- self.log.debug("role_arn is %s", role_arn)
- return role_arn
-
- def _read_credentials_from_connection(self) -> Tuple[Optional[str], Optional[str]]:
- aws_access_key_id = None
- aws_secret_access_key = None
- if self.conn.login:
- aws_access_key_id = self.conn.login
- aws_secret_access_key = self.conn.password
- self.log.info("Credentials retrieved from login")
- elif "aws_access_key_id" in self.extra_config and "aws_secret_access_key" in self.extra_config:
- aws_access_key_id = self.extra_config["aws_access_key_id"]
- aws_secret_access_key = self.extra_config["aws_secret_access_key"]
- self.log.info("Credentials retrieved from extra_config")
- elif "s3_config_file" in self.extra_config:
- warnings.warn(
- "Use local credentials file is never documented and well tested. "
- "Obtain credentials by this way deprecated and will be removed in a future releases.",
- DeprecationWarning,
- stacklevel=3,
- )
- aws_access_key_id, aws_secret_access_key = _parse_s3_config(
- self.extra_config["s3_config_file"],
- self.extra_config.get("s3_config_format"),
- self.extra_config.get("profile"),
- )
- self.log.info("Credentials retrieved from extra_config['s3_config_file']")
- return aws_access_key_id, aws_secret_access_key
-
- def _strip_invalid_session_name_characters(self, role_session_name: str) -> str:
- return slugify(role_session_name, regex_pattern=r'[^\w+=,.@-]+')
-
def _assume_role(self, sts_client: boto3.client) -> Dict:
- assume_role_kwargs = self.extra_config.get("assume_role_kwargs", {})
- if "ExternalId" not in assume_role_kwargs and "external_id" in self.extra_config:
- warnings.warn(
- "'external_id' in extra config is deprecated and will be removed in a future releases. "
- "Set 'ExternalId' in 'assume_role_kwargs' in extra config.",
- DeprecationWarning,
- stacklevel=3,
- )
- assume_role_kwargs["ExternalId"] = self.extra_config.get("external_id")
- role_session_name = self._strip_invalid_session_name_characters(f"Airflow_{self.conn.conn_id}")
- self.log.debug(
- "Doing sts_client.assume_role to role_arn=%s (role_session_name=%s)",
- self.role_arn,
- role_session_name,
- )
- return sts_client.assume_role(
- RoleArn=self.role_arn, RoleSessionName=role_session_name, **assume_role_kwargs
- )
+ kw = {
+ "RoleSessionName": self._strip_invalid_session_name_characters(f"Airflow_{self.conn.conn_id}"),
+ **self.conn.assume_role_kwargs,
+ "RoleArn": self.role_arn,
+ }
+ return sts_client.assume_role(**kw)
def _assume_role_with_saml(self, sts_client: boto3.client) -> Dict[str, Any]:
saml_config = self.extra_config['assume_role_with_saml']
@@ -273,12 +193,11 @@ class BaseSessionFactory(LoggingMixin):
)
self.log.debug("Doing sts_client.assume_role_with_saml to role_arn=%s", self.role_arn)
- assume_role_kwargs = self.extra_config.get("assume_role_kwargs", {})
return sts_client.assume_role_with_saml(
RoleArn=self.role_arn,
PrincipalArn=principal_arn,
SAMLAssertion=saml_assertion,
- **assume_role_kwargs,
+ **self.conn.assume_role_kwargs,
)
def _get_idp_response(
@@ -357,8 +276,6 @@ class BaseSessionFactory(LoggingMixin):
def _get_web_identity_credential_fetcher(
self,
) -> botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher:
- if self.basic_session is None:
- raise Exception("Session should be set where identity is fetched!")
base_session = self.basic_session._session or botocore.session.get_session()
client_creator = base_session.create_client
federation = self.extra_config.get('assume_role_with_web_identity_federation')
@@ -368,12 +285,11 @@ class BaseSessionFactory(LoggingMixin):
raise AirflowException(
f'Unsupported federation: {federation}. Currently "google" only are supported.'
)
- assume_role_kwargs = self.extra_config.get("assume_role_kwargs", {})
return botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher(
client_creator=client_creator,
web_identity_token_loader=web_identity_token_loader,
role_arn=self.role_arn,
- extra_args=assume_role_kwargs,
+ extra_args=self.conn.assume_role_kwargs,
)
def _get_google_identity_token_loader(self):
@@ -395,6 +311,37 @@ class BaseSessionFactory(LoggingMixin):
return web_identity_token_loader
+    def _strip_invalid_session_name_characters(self, role_session_name: str) -> str:
+        return slugify(role_session_name, regex_pattern=r'[^\w+=,.@-]+')
+
+    def _get_region_name(self) -> Optional[str]:
+        warnings.warn(
+            "`BaseSessionFactory._get_region_name` method will be deprecated in the future."
+            " Please use `BaseSessionFactory.region_name` property instead.",
+            PendingDeprecationWarning,
+            stacklevel=2,
+        )
+        return self.region_name
+
+    def _read_role_arn_from_extra_config(self) -> Optional[str]:
+        warnings.warn(
+            "`BaseSessionFactory._read_role_arn_from_extra_config` method will be deprecated in the future."
+            " Please use `BaseSessionFactory.role_arn` property instead.",
+            PendingDeprecationWarning,
+            stacklevel=2,
+        )
+        return self.role_arn
+
+    def _read_credentials_from_connection(self) -> Tuple[Optional[str], Optional[str]]:
+        warnings.warn(
+            "`BaseSessionFactory._read_credentials_from_connection` method will be deprecated in the future."
+            " Please use `BaseSessionFactory.conn.aws_access_key_id` and "
+            "`BaseSessionFactory.conn.aws_secret_access_key` properties instead.",
+            PendingDeprecationWarning,
+            stacklevel=2,
+        )
+        return self.conn.aws_access_key_id, self.conn.aws_secret_access_key
+
class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
"""
@@ -446,24 +393,19 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
self.log.debug("Airflow Connection: aws_conn_id=%s", self.aws_conn_id)
try:
- # Fetch the Airflow connection object
- connection_object = self.get_connection(self.aws_conn_id)
- extra_config = connection_object.extra_dejson
- endpoint_url = extra_config.get("host")
-
- # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
- if "config_kwargs" in extra_config:
- self.log.debug(
- "Retrieving config_kwargs from Connection.extra_config['config_kwargs']: %s",
- extra_config["config_kwargs"],
- )
- self.config = Config(**extra_config["config_kwargs"])
+ # Fetch the Airflow connection object and wrap it in helper
+ connection_object = AwsConnectionWrapper(self.get_connection(self.aws_conn_id))
+
+ if connection_object.botocore_config:
+ # For historical reason botocore.config.Config from connection overwrites
+ # config which explicitly set in Hook.
+ self.config = connection_object.botocore_config
session = SessionFactory(
conn=connection_object, region_name=region_name, config=self.config
).create_session()
- return session, endpoint_url
+ return session, connection_object.endpoint_url
except AirflowException:
self.log.warning(
@@ -675,57 +617,6 @@ class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):
"""
-def _parse_s3_config(
- config_file_name: str, config_format: Optional[str] = "boto", profile: Optional[str] = None
-) -> Tuple[Optional[str], Optional[str]]:
- """
- Parses a config file for s3 credentials. Can currently
- parse boto, s3cmd.conf and AWS SDK config formats
-
- :param config_file_name: path to the config file
- :param config_format: config type. One of "boto", "s3cmd" or "aws".
- Defaults to "boto"
- :param profile: profile name in AWS type config file
- """
- config = configparser.ConfigParser()
- if config.read(config_file_name): # pragma: no cover
- sections = config.sections()
- else:
- raise AirflowException(f"Couldn't read {config_file_name}")
- # Setting option names depending on file format
- if config_format is None:
- config_format = "boto"
- conf_format = config_format.lower()
- if conf_format == "boto": # pragma: no cover
- if profile is not None and "profile " + profile in sections:
- cred_section = "profile " + profile
- else:
- cred_section = "Credentials"
- elif conf_format == "aws" and profile is not None:
- cred_section = profile
- else:
- cred_section = "default"
- # Option names
- if conf_format in ("boto", "aws"): # pragma: no cover
- key_id_option = "aws_access_key_id"
- secret_key_option = "aws_secret_access_key"
- # security_token_option = 'aws_security_token'
- else:
- key_id_option = "access_key"
- secret_key_option = "secret_key"
- # Actual Parsing
- if cred_section not in sections:
- raise AirflowException("This config file format is not recognized")
- else:
- try:
- access_key = config.get(cred_section, key_id_option)
- secret_key = config.get(cred_section, secret_key_option)
- except Exception:
- logging.warning("Option Error in parsing s3 config file")
- raise
- return access_key, secret_key
-
-
def resolve_session_factory() -> Type[BaseSessionFactory]:
"""Resolves custom SessionFactory class"""
clazz = conf.getimport("aws", "session_factory", fallback=None)
@@ -740,3 +631,16 @@ def resolve_session_factory() -> Type[BaseSessionFactory]:
SessionFactory = resolve_session_factory()
+
+
+def _parse_s3_config(
+    config_file_name: str, config_format: Optional[str] = "boto", profile: Optional[str] = None
+) -> Tuple[Optional[str], Optional[str]]:
+    """For compatibility with airflow.contrib.hooks.aws_hook; delegates to the connection_wrapper helper."""
+    from airflow.providers.amazon.aws.utils.connection_wrapper import _parse_s3_config
+
+    return _parse_s3_config(
+        config_file_name=config_file_name,
+        config_format=config_format,
+        profile=profile,
+    )
diff --git a/airflow/providers/amazon/aws/utils/connection_wrapper.py b/airflow/providers/amazon/aws/utils/connection_wrapper.py
new file mode 100644
index 0000000000..6672b971f6
--- /dev/null
+++ b/airflow/providers/amazon/aws/utils/connection_wrapper.py
@@ -0,0 +1,282 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import warnings
+from copy import deepcopy
+from dataclasses import InitVar, dataclass, field
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+
+from botocore.config import Config
+
+from airflow.compat.functools import cached_property
+from airflow.exceptions import AirflowException
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+if TYPE_CHECKING:
+ from airflow.models.connection import Connection
+
+
+@dataclass
+class AwsConnectionWrapper(LoggingMixin):
+    """
+    AWS Connection Wrapper class helper.
+    Use to validate and resolve AWS Connection parameters.
+    """
+
+    conn: InitVar[Optional["Connection"]]
+
+    conn_id: Optional[str] = field(init=False, default=None)
+    conn_type: Optional[str] = field(init=False, default=None)
+    login: Optional[str] = field(init=False, repr=False, default=None)
+    password: Optional[str] = field(init=False, repr=False, default=None)
+    extra_config: Dict[str, Any] = field(init=False, repr=False, default_factory=dict)
+
+    aws_access_key_id: Optional[str] = field(init=False, default=None)
+    aws_secret_access_key: Optional[str] = field(init=False, repr=False, default=None)
+    aws_session_token: Optional[str] = field(init=False, repr=False, default=None)
+
+    region_name: Optional[str] = field(init=False, default=None)
+    session_kwargs: Dict[str, Any] = field(init=False, default_factory=dict)
+    botocore_config: Optional[Config] = field(init=False, default=None)
+    endpoint_url: Optional[str] = field(init=False, default=None)
+
+    role_arn: Optional[str] = field(init=False, default=None)
+    assume_role_method: Optional[str] = field(init=False, default=None)
+    assume_role_kwargs: Dict[str, Any] = field(init=False, default_factory=dict)
+
+    @cached_property
+    def conn_repr(self):
+        return f"AWS Connection (conn_id={self.conn_id!r}, conn_type={self.conn_type!r})"
+
+    def __post_init__(self, conn: Optional["Connection"]):
+        if not conn:
+            return
+
+        extra = deepcopy(conn.extra_dejson)
+
+        # Assign attributes from AWS Connection
+        self.conn_id = conn.conn_id
+        self.conn_type = conn.conn_type or "aws"
+        self.login = conn.login
+        self.password = conn.password
+        self.extra_config = deepcopy(conn.extra_dejson)  # independent copy: ``extra`` may be mutated below
+
+        if self.conn_type != "aws":
+            warnings.warn(
+                f"{self.conn_repr} expected connection type 'aws', got {self.conn_type!r}.",
+                UserWarning,
+                stacklevel=2,
+            )
+
+        # Retrieve initial connection credentials
+        init_credentials = self._get_credentials(**extra)
+        self.aws_access_key_id, self.aws_secret_access_key, self.aws_session_token = init_credentials
+
+        if "region_name" in extra:
+            self.region_name = extra["region_name"]
+            self.log.info("Retrieving region_name=%s from %s extra.", self.region_name, self.conn_repr)
+
+        if "session_kwargs" in extra:
+            self.session_kwargs = extra["session_kwargs"]
+            self.log.info("Retrieving session_kwargs=%s from %s extra.", self.session_kwargs, self.conn_repr)
+
+        # Warn that 'profile' (without 's3_config_file') is not the same as session_kwargs['profile_name'].
+        if "profile" in extra and "s3_config_file" not in extra:
+            if "profile_name" not in self.session_kwargs:
+                warnings.warn(
+                    f"Found 'profile' without specifying 's3_config_file' in {self.conn_repr} extra. "
+                    "If required profile from AWS Shared Credentials please "
+                    f"set 'profile_name' in {self.conn_repr} extra['session_kwargs'].",
+                    UserWarning,
+                    stacklevel=2,
+                )
+
+        config_kwargs = extra.get("config_kwargs")
+        if config_kwargs:
+            # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+            self.log.info("Retrieving botocore config=%s from %s extra.", config_kwargs, self.conn_repr)
+            self.botocore_config = Config(**config_kwargs)
+
+        self.endpoint_url = extra.get("host")
+
+        # Retrieve Assume Role Configuration
+        assume_role_configs = self._get_assume_role_configs(**extra)
+        self.role_arn, self.assume_role_method, self.assume_role_kwargs = assume_role_configs
+
+    @property
+    def extra_dejson(self):
+        return self.extra_config
+
+    def __bool__(self):
+        return self.conn_id is not None
+
+    def _get_credentials(
+        self,
+        *,
+        aws_access_key_id: Optional[str] = None,
+        aws_secret_access_key: Optional[str] = None,
+        aws_session_token: Optional[str] = None,
+        # Deprecated Values
+        s3_config_file: Optional[str] = None,
+        s3_config_format: Optional[str] = None,
+        profile: Optional[str] = None,
+        **kwargs,
+    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
+        """
+        Get AWS credentials from connection login/password and extra.
+
+        ``aws_access_key_id`` and ``aws_secret_access_key`` order
+        1. From Connection login and password
+        2. From Connection extra['aws_access_key_id'] and extra['aws_secret_access_key']
+        3. (deprecated) From local credentials file
+
+        Get ``aws_session_token`` from extra['aws_session_token']
+
+        """
+        if self.login and self.password:
+            self.log.info("%s credentials retrieved from login and password.", self.conn_repr)
+            aws_access_key_id, aws_secret_access_key = self.login, self.password
+        elif aws_access_key_id and aws_secret_access_key:
+            self.log.info("%s credentials retrieved from extra.", self.conn_repr)
+        elif s3_config_file:
+            aws_access_key_id, aws_secret_access_key = _parse_s3_config(
+                s3_config_file,
+                s3_config_format,
+                profile,
+            )
+            self.log.info("%s credentials retrieved from extra['s3_config_file']", self.conn_repr)
+
+        if aws_session_token:
+            self.log.info(
+                "%s session token retrieved from extra, please note you are responsible for renewing these.",
+                self.conn_repr,
+            )
+
+        return aws_access_key_id, aws_secret_access_key, aws_session_token
+
+    def _get_assume_role_configs(
+        self,
+        *,
+        role_arn: Optional[str] = None,
+        assume_role_method: str = "assume_role",
+        assume_role_kwargs: Optional[Dict[str, Any]] = None,
+        # Deprecated Values
+        aws_account_id: Optional[str] = None,
+        aws_iam_role: Optional[str] = None,
+        external_id: Optional[str] = None,
+        **kwargs,
+    ) -> Tuple[Optional[str], Optional[str], Dict[str, Any]]:
+        """Get assume role configs from Connection extra."""
+        if role_arn:
+            self.log.info("Retrieving role_arn=%r from %s extra.", role_arn, self.conn_repr)
+        elif aws_account_id and aws_iam_role:
+            warnings.warn(
+                "Constructing 'role_arn' from extra['aws_account_id'] and extra['aws_iam_role'] is deprecated"
+                " and will be removed in a future releases."
+                f" Please set 'role_arn' in {self.conn_repr} extra.",
+                DeprecationWarning,
+                stacklevel=3,
+            )
+            role_arn = f"arn:aws:iam::{aws_account_id}:role/{aws_iam_role}"
+            self.log.info(
+                "Constructions role_arn=%r from %s extra['aws_account_id'] and extra['aws_iam_role'].",
+                role_arn,
+                self.conn_repr,
+            )
+
+        if not role_arn:
+            # No reason to resolve `assume_role_method` and `assume_role_kwargs` if `role_arn` is not set.
+            return None, None, {}
+
+        supported_methods = ['assume_role', 'assume_role_with_saml', 'assume_role_with_web_identity']
+        if assume_role_method not in supported_methods:
+            raise NotImplementedError(
+                f'Found assume_role_method={assume_role_method!r} in {self.conn_repr} extra.'
+                f' Currently {supported_methods} are supported.'
+                ' (Exclude this setting will default to "assume_role").'
+            )
+        self.log.info("Retrieve assume_role_method=%r from %s.", assume_role_method, self.conn_repr)
+
+        assume_role_kwargs = assume_role_kwargs or {}
+        if "ExternalId" not in assume_role_kwargs and external_id:
+            warnings.warn(
+                "'external_id' in extra config is deprecated and will be removed in a future releases. "
+                f"Please set 'ExternalId' in 'assume_role_kwargs' in {self.conn_repr} extra.",
+                DeprecationWarning,
+                stacklevel=3,
+            )
+            assume_role_kwargs["ExternalId"] = external_id
+
+        return role_arn, assume_role_method, assume_role_kwargs
+
+
+def _parse_s3_config(
+    config_file_name: str, config_format: Optional[str] = "boto", profile: Optional[str] = None
+) -> Tuple[Optional[str], Optional[str]]:
+    """
+    Parses a config file for s3 credentials (deprecated, emits a DeprecationWarning).
+    Can currently parse boto, s3cmd.conf and AWS SDK config formats.
+
+    :param config_file_name: path to the config file
+    :param config_format: config type. One of "boto", "s3cmd" or "aws".
+        Defaults to "boto"
+    :param profile: profile name in AWS type config file
+    """
+    warnings.warn(
+        "Use local credentials file is never documented and well tested. "
+        "Obtain credentials by this way deprecated and will be removed in a future releases.",
+        DeprecationWarning,
+        stacklevel=4,
+    )
+
+    import configparser
+
+    config = configparser.ConfigParser()
+    if config.read(config_file_name):  # pragma: no cover
+        sections = config.sections()
+    else:
+        raise AirflowException(f"Couldn't read {config_file_name}")
+    # Setting option names depending on file format
+    if config_format is None:
+        config_format = "boto"
+    conf_format = config_format.lower()
+    if conf_format == "boto":  # pragma: no cover
+        if profile is not None and "profile " + profile in sections:
+            cred_section = "profile " + profile
+        else:
+            cred_section = "Credentials"
+    elif conf_format == "aws" and profile is not None:
+        cred_section = profile
+    else:
+        cred_section = "default"
+    # Option names
+    if conf_format in ("boto", "aws"):  # pragma: no cover
+        key_id_option = "aws_access_key_id"
+        secret_key_option = "aws_secret_access_key"
+    else:
+        key_id_option = "access_key"
+        secret_key_option = "secret_key"
+    # Actual Parsing
+    if cred_section not in sections:
+        raise AirflowException("This config file format is not recognized")
+    else:
+        try:
+            access_key = config.get(cred_section, key_id_option)
+            secret_key = config.get(cred_section, secret_key_option)
+        except Exception as err:
+            raise AirflowException("Option Error in parsing s3 config file") from err
+    return access_key, secret_key
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index d1c3da9cdd..a04473b5b0 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -32,6 +32,7 @@ from airflow.providers.amazon.aws.hooks.base_aws import (
BaseSessionFactory,
resolve_session_factory,
)
+from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
from tests.test_utils.config import conf_vars
try:
@@ -42,6 +43,9 @@ except ImportError:
mock_sts = None
mock_iam = None
+MOCK_AWS_CONN_ID = "mock-conn-id"
+MOCK_CONN_TYPE = "aws"
+
SAML_ASSERTION = """
<?xml version="1.0"?>
<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" ID="_00000000-0000-0000-0000-000000000000" Version="2.0" IssueInstant="2012-01-01T12:00:00.000Z" Destination="https://signin.aws.amazon.com/saml" Consent="urn:oasis:names:tc:SAML:2.0:consent:unspecified">
@@ -114,7 +118,7 @@ class CustomSessionFactory(BaseSessionFactory):
return mock.MagicMock()
-class TestAwsBaseHook:
+class TestSessionFactory:
@conf_vars(
{("aws", "session_factory"): "tests.providers.amazon.aws.hooks.test_base_aws.CustomSessionFactory"}
)
@@ -131,6 +135,23 @@ class TestAwsBaseHook:
cls = resolve_session_factory()
assert issubclass(cls, BaseSessionFactory)
+ @pytest.mark.parametrize(
+ "mock_conn",
+ [
+ Connection(conn_type=MOCK_CONN_TYPE, conn_id=MOCK_AWS_CONN_ID),
+ AwsConnectionWrapper(conn=Connection(conn_type=MOCK_CONN_TYPE, conn_id=MOCK_AWS_CONN_ID)),
+ ],
+ )
+ def test_conn_property(self, mock_conn):
+ sf = BaseSessionFactory(conn=mock_conn, region_name=None, config=None)
+ session_factory_conn = sf.conn
+ assert isinstance(session_factory_conn, AwsConnectionWrapper)
+ assert session_factory_conn.conn_id == MOCK_AWS_CONN_ID
+ assert session_factory_conn.conn_type == MOCK_CONN_TYPE
+ assert sf.conn is session_factory_conn  # `conn` is a cached_property: same wrapper on every access
+
+
+class TestAwsBaseHook:
@unittest.skipIf(mock_emr is None, 'mock_emr package not present')
@mock_emr
def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
@@ -260,80 +281,6 @@ class TestAwsBaseHook:
assert table.item_count == 0
- @mock.patch.object(AwsBaseHook, 'get_connection')
- def test_get_credentials_from_login_with_token(self, mock_get_connection):
- mock_connection = Connection(
- login='aws_access_key_id',
- password='aws_secret_access_key',
- extra='{"aws_session_token": "test_token"}',
- )
- mock_get_connection.return_value = mock_connection
- hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
- credentials_from_hook = hook.get_credentials()
- assert credentials_from_hook.access_key == 'aws_access_key_id'
- assert credentials_from_hook.secret_key == 'aws_secret_access_key'
- assert credentials_from_hook.token == 'test_token'
-
- @mock.patch.object(AwsBaseHook, 'get_connection')
- def test_get_credentials_from_login_without_token(self, mock_get_connection):
- mock_connection = Connection(
- login='aws_access_key_id',
- password='aws_secret_access_key',
- )
-
- mock_get_connection.return_value = mock_connection
- hook = AwsBaseHook(aws_conn_id='aws_default', client_type='spam')
- credentials_from_hook = hook.get_credentials()
- assert credentials_from_hook.access_key == 'aws_access_key_id'
- assert credentials_from_hook.secret_key == 'aws_secret_access_key'
- assert credentials_from_hook.token is None
-
- @mock.patch.object(AwsBaseHook, 'get_connection')
- def test_get_credentials_from_extra_with_token(self, mock_get_connection):
- mock_connection = Connection(
- extra='{"aws_access_key_id": "aws_access_key_id",'
- '"aws_secret_access_key": "aws_secret_access_key",'
- ' "aws_session_token": "session_token"}'
- )
- mock_get_connection.return_value = mock_connection
- hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
- credentials_from_hook = hook.get_credentials()
- assert credentials_from_hook.access_key == 'aws_access_key_id'
- assert credentials_from_hook.secret_key == 'aws_secret_access_key'
- assert credentials_from_hook.token == 'session_token'
-
- @mock.patch.object(AwsBaseHook, 'get_connection')
- def test_get_credentials_from_extra_without_token(self, mock_get_connection):
- mock_connection = Connection(
- extra='{"aws_access_key_id": "aws_access_key_id",'
- '"aws_secret_access_key": "aws_secret_access_key"}'
- )
- mock_get_connection.return_value = mock_connection
- hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
- credentials_from_hook = hook.get_credentials()
- assert credentials_from_hook.access_key == 'aws_access_key_id'
- assert credentials_from_hook.secret_key == 'aws_secret_access_key'
- assert credentials_from_hook.token is None
-
- @mock.patch(
- 'airflow.providers.amazon.aws.hooks.base_aws._parse_s3_config',
- return_value=('aws_access_key_id', 'aws_secret_access_key'),
- )
- @mock.patch.object(AwsBaseHook, 'get_connection')
- def test_get_credentials_from_extra_with_s3_config_and_profile(
- self, mock_get_connection, mock_parse_s3_config
- ):
- mock_connection = Connection(
- extra='{"s3_config_format": "aws", '
- '"profile": "test", '
- '"s3_config_file": "aws-credentials", '
- '"region_name": "us-east-1"}'
- )
- mock_get_connection.return_value = mock_connection
- hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
- hook._get_credentials(region_name=None)
- mock_parse_s3_config.assert_called_once_with('aws-credentials', 'aws', 'test')
-
@unittest.skipIf(mock_sts is None, 'mock_sts package not present')
@mock.patch.object(AwsBaseHook, 'get_connection')
@mock_sts
diff --git a/tests/providers/amazon/aws/utils/test_connection_wrapper.py b/tests/providers/amazon/aws/utils/test_connection_wrapper.py
new file mode 100644
index 0000000000..74696db98e
--- /dev/null
+++ b/tests/providers/amazon/aws/utils/test_connection_wrapper.py
@@ -0,0 +1,283 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional
+from unittest import mock
+
+import pytest
+
+from airflow.models import Connection
+from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
+
+MOCK_AWS_CONN_ID = "mock-conn-id"
+MOCK_CONN_TYPE = "aws"
+MOCK_ROLE_ARN = "arn:aws:iam::222222222222:role/awesome-role"
+
+
+def mock_connection_factory(
+    conn_id: Optional[str] = MOCK_AWS_CONN_ID, conn_type: Optional[str] = MOCK_CONN_TYPE, **kwargs
+) -> Connection:  # Build a Connection with test defaults; remaining kwargs go straight to Connection.
+    return Connection(conn_type=conn_type, conn_id=conn_id, **kwargs)
+
+
+class TestAwsConnectionWrapper:
+    @pytest.mark.parametrize("extra", [{"foo": "bar", "spam": "egg"}, '{"foo": "bar", "spam": "egg"}', None])
+    def test_values_from_connection(self, extra):
+        mock_conn = mock_connection_factory(
+            login="mock-login",
+            password="mock-password",
+            extra=extra,
+            # AwsBaseHook never uses these attributes from airflow.models.Connection
+            host="mock-host",
+            schema="mock-schema",
+            port=42,
+        )
+        wrap_conn = AwsConnectionWrapper(conn=mock_conn)
+
+        assert wrap_conn.conn_id == mock_conn.conn_id
+        assert wrap_conn.conn_type == mock_conn.conn_type
+        assert wrap_conn.login == mock_conn.login
+        assert wrap_conn.password == mock_conn.password
+
+        # The wrapper keeps an equal, but separate, copy of the connection's extra config
+        assert wrap_conn.extra_config == mock_conn.extra_dejson
+        assert wrap_conn.extra_config is not mock_conn.extra_dejson
+        # `extra_config` is the same object returned by the wrapper's `extra_dejson`
+        assert wrap_conn.extra_config is wrap_conn.extra_dejson
+
+        # Other airflow.models.Connection attributes must not be copied to the wrapper
+        assert not hasattr(wrap_conn, "host")
+        assert not hasattr(wrap_conn, "schema")
+        assert not hasattr(wrap_conn, "port")
+
+        # A wrapper built from a connection is truthy
+        assert wrap_conn
+
+    def test_no_connection(self):
+        assert not AwsConnectionWrapper(conn=None)
+
+    @pytest.mark.parametrize("conn_type", ["aws", None])
+    def test_expected_aws_connection_type(self, conn_type):
+        wrap_conn = AwsConnectionWrapper(conn=mock_connection_factory(conn_type=conn_type))
+        assert wrap_conn.conn_type == "aws"
+
+    @pytest.mark.parametrize("conn_type", ["AWS", "boto3", "s3", "emr", "google", "google-cloud-platform"])
+    def test_unexpected_aws_connection_type(self, conn_type):
+        warning_message = f"expected connection type 'aws', got '{conn_type}'"
+        with pytest.warns(UserWarning, match=warning_message):
+            wrap_conn = AwsConnectionWrapper(conn=mock_connection_factory(conn_type=conn_type))
+            assert wrap_conn.conn_type == conn_type
+
+    @pytest.mark.parametrize("aws_session_token", [None, "mock-aws-session-token"])
+    @pytest.mark.parametrize("aws_secret_access_key", ["mock-aws-secret-access-key"])
+    @pytest.mark.parametrize("aws_access_key_id", ["mock-aws-access-key-id"])
+    def test_get_credentials_from_login(self, aws_access_key_id, aws_secret_access_key, aws_session_token):
+        mock_conn = mock_connection_factory(
+            login=aws_access_key_id,
+            password=aws_secret_access_key,
+            extra={"aws_session_token": aws_session_token} if aws_session_token else None,
+        )
+
+        wrap_conn = AwsConnectionWrapper(conn=mock_conn)
+        assert wrap_conn.aws_access_key_id == aws_access_key_id
+        assert wrap_conn.aws_secret_access_key == aws_secret_access_key
+        assert wrap_conn.aws_session_token == aws_session_token
+
+    @pytest.mark.parametrize("aws_session_token", [None, "mock-aws-session-token"])
+    @pytest.mark.parametrize("aws_secret_access_key", ["mock-aws-secret-access-key"])
+    @pytest.mark.parametrize("aws_access_key_id", ["mock-aws-access-key-id"])
+    def test_get_credentials_from_extra(self, aws_access_key_id, aws_secret_access_key, aws_session_token):
+        mock_conn_extra = {
+            "aws_access_key_id": aws_access_key_id,
+            "aws_secret_access_key": aws_secret_access_key,
+        }
+        if aws_session_token:
+            mock_conn_extra["aws_session_token"] = aws_session_token
+        mock_conn = mock_connection_factory(login=None, password=None, extra=mock_conn_extra)
+
+        wrap_conn = AwsConnectionWrapper(conn=mock_conn)
+        assert wrap_conn.aws_access_key_id == aws_access_key_id
+        assert wrap_conn.aws_secret_access_key == aws_secret_access_key
+        assert wrap_conn.aws_session_token == aws_session_token
+
+    # `_parse_s3_config` was never tested and is marked as deprecated, so only check the expected output
+    @mock.patch("airflow.providers.amazon.aws.utils.connection_wrapper._parse_s3_config")
+    @pytest.mark.parametrize("aws_session_token", [None, "mock-aws-session-token"])
+    @pytest.mark.parametrize("aws_secret_access_key", ["mock-aws-secret-access-key"])
+    @pytest.mark.parametrize("aws_access_key_id", ["mock-aws-access-key-id"])
+    def test_get_credentials_from_s3_config(
+        self, mock_parse_s3_config, aws_access_key_id, aws_secret_access_key, aws_session_token
+    ):
+        mock_parse_s3_config.return_value = (aws_access_key_id, aws_secret_access_key)
+        mock_conn_extra = {
+            "s3_config_format": "aws",
+            "profile": "test",
+            "s3_config_file": "aws-credentials",
+        }
+        if aws_session_token:
+            mock_conn_extra["aws_session_token"] = aws_session_token
+        mock_conn = mock_connection_factory(login=None, password=None, extra=mock_conn_extra)
+
+        wrap_conn = AwsConnectionWrapper(conn=mock_conn)
+        mock_parse_s3_config.assert_called_once_with('aws-credentials', 'aws', 'test')
+        assert wrap_conn.aws_access_key_id == aws_access_key_id
+        assert wrap_conn.aws_secret_access_key == aws_secret_access_key
+        assert wrap_conn.aws_session_token == aws_session_token
+
+    @pytest.mark.parametrize("region_name", [None, "mock-aws-region"])
+    def test_get_region_name(self, region_name):
+        mock_conn = mock_connection_factory(extra={"region_name": region_name} if region_name else None)
+        wrap_conn = AwsConnectionWrapper(conn=mock_conn)
+        assert wrap_conn.region_name == region_name
+
+    @pytest.mark.parametrize("session_kwargs", [None, {"profile_name": "mock-profile"}])
+    def test_get_session_kwargs(self, session_kwargs):
+        mock_conn = mock_connection_factory(
+            extra={"session_kwargs": session_kwargs} if session_kwargs else None
+        )
+        expected = session_kwargs or {}
+        wrap_conn = AwsConnectionWrapper(conn=mock_conn)
+        assert wrap_conn.session_kwargs == expected
+
+    def test_warn_wrong_profile_param_used(self):
+        mock_conn = mock_connection_factory(extra={"profile": "mock-profile"})
+        warning_message = "Found 'profile' without specifying 's3_config_file' in .* set 'profile_name' in"
+        with pytest.warns(UserWarning, match=warning_message):
+            wrap_conn = AwsConnectionWrapper(conn=mock_conn)
+            assert "profile_name" not in wrap_conn.session_kwargs
+
+    @mock.patch("airflow.providers.amazon.aws.utils.connection_wrapper.Config", autospec=True)
+    @pytest.mark.parametrize("botocore_config_kwargs", [None, {"user_agent": "Airflow Amazon Provider"}])
+    def test_get_botocore_config(self, mock_botocore_config, botocore_config_kwargs):
+        mock_conn = mock_connection_factory(
+            extra={"config_kwargs": botocore_config_kwargs} if botocore_config_kwargs else None
+        )
+        wrap_conn = AwsConnectionWrapper(conn=mock_conn)
+
+        if not botocore_config_kwargs:
+            mock_botocore_config.assert_not_called()
+            assert wrap_conn.botocore_config is None
+        else:
+            mock_botocore_config.assert_called_once_with(**botocore_config_kwargs)
+            # The wrapper must expose the very Config instance built from ``config_kwargs``.
+            assert wrap_conn.botocore_config is mock_botocore_config.return_value
+
+    @pytest.mark.parametrize("endpoint_url", [None, "https://example.org"])
+    def test_get_endpoint_url(self, endpoint_url):
+        mock_conn = mock_connection_factory(extra={"host": endpoint_url} if endpoint_url else None)
+        wrap_conn = AwsConnectionWrapper(conn=mock_conn)
+        assert wrap_conn.endpoint_url == endpoint_url
+
+    @pytest.mark.parametrize("aws_account_id, aws_iam_role", [(None, None), ("111111111111", "another-role")])
+    def test_get_role_arn(self, aws_account_id, aws_iam_role):
+        mock_conn = mock_connection_factory(
+            extra={
+                "role_arn": MOCK_ROLE_ARN,
+                "aws_account_id": aws_account_id,
+                "aws_iam_role": aws_iam_role,
+            }
+        )
+        wrap_conn = AwsConnectionWrapper(conn=mock_conn)
+        assert wrap_conn.role_arn == MOCK_ROLE_ARN
+
+    @pytest.mark.parametrize(
+        "aws_account_id, aws_iam_role, expected",
+        [
+            ("222222222222", "mock-role", "arn:aws:iam::222222222222:role/mock-role"),
+            ("333333333333", "role-path/mock-role", "arn:aws:iam::333333333333:role/role-path/mock-role"),
+        ],
+    )
+    def test_constructing_role_arn(self, aws_account_id, aws_iam_role, expected):
+        mock_conn = mock_connection_factory(
+            extra={
+                "aws_account_id": aws_account_id,
+                "aws_iam_role": aws_iam_role,
+            }
+        )
+        with pytest.warns(DeprecationWarning, match="Please set 'role_arn' in .* extra"):
+            wrap_conn = AwsConnectionWrapper(conn=mock_conn)
+            assert wrap_conn.role_arn == expected
+
+    def test_empty_role_arn(self):
+        wrap_conn = AwsConnectionWrapper(conn=mock_connection_factory())
+        assert wrap_conn.role_arn is None
+        assert wrap_conn.assume_role_method is None
+        assert wrap_conn.assume_role_kwargs == {}
+
+    @pytest.mark.parametrize(
+        "assume_role_method", ['assume_role', 'assume_role_with_saml', 'assume_role_with_web_identity']
+    )
+    def test_get_assume_role_method(self, assume_role_method):
+        mock_conn = mock_connection_factory(
+            extra={"role_arn": MOCK_ROLE_ARN, "assume_role_method": assume_role_method}
+        )
+        wrap_conn = AwsConnectionWrapper(conn=mock_conn)
+        assert wrap_conn.assume_role_method == assume_role_method
+
+    def test_default_assume_role_method(self):
+        mock_conn = mock_connection_factory(
+            extra={
+                "role_arn": MOCK_ROLE_ARN,
+            }
+        )
+        wrap_conn = AwsConnectionWrapper(conn=mock_conn)
+        assert wrap_conn.assume_role_method == "assume_role"
+
+    def test_unsupported_assume_role_method(self):
+        mock_conn = mock_connection_factory(
+            extra={"role_arn": MOCK_ROLE_ARN, "assume_role_method": "dummy_method"}
+        )
+        with pytest.raises(NotImplementedError, match="Found assume_role_method='dummy_method' in .* extra"):
+            AwsConnectionWrapper(conn=mock_conn)
+
+    @pytest.mark.parametrize("assume_role_kwargs", [None, {"DurationSeconds": 42}])
+    def test_get_assume_role_kwargs(self, assume_role_kwargs):
+        mock_conn_extra = {"role_arn": MOCK_ROLE_ARN}
+        if assume_role_kwargs:
+            mock_conn_extra["assume_role_kwargs"] = assume_role_kwargs
+        mock_conn = mock_connection_factory(extra=mock_conn_extra)
+
+        wrap_conn = AwsConnectionWrapper(conn=mock_conn)
+        expected = assume_role_kwargs or {}
+        assert wrap_conn.assume_role_kwargs == expected
+
+    @pytest.mark.parametrize("external_id_in_extra", [None, "mock-external-id-in-extra"])
+    def test_get_assume_role_kwargs_external_id_in_kwargs(self, external_id_in_extra):
+        mock_external_id_in_kwargs = "mock-external-id-in-kwargs"
+        mock_conn_extra = {
+            "role_arn": MOCK_ROLE_ARN,
+            "assume_role_kwargs": {"ExternalId": mock_external_id_in_kwargs},
+        }
+        if external_id_in_extra:
+            mock_conn_extra["external_id"] = external_id_in_extra
+        mock_conn = mock_connection_factory(extra=mock_conn_extra)
+
+        wrap_conn = AwsConnectionWrapper(conn=mock_conn)
+        assert "ExternalId" in wrap_conn.assume_role_kwargs
+        assert wrap_conn.assume_role_kwargs["ExternalId"] == mock_external_id_in_kwargs
+        assert wrap_conn.assume_role_kwargs["ExternalId"] != external_id_in_extra
+
+    def test_get_assume_role_kwargs_external_id_in_extra(self):
+        mock_external_id_in_extra = "mock-external-id-in-extra"
+        mock_conn_extra = {"role_arn": MOCK_ROLE_ARN, "external_id": mock_external_id_in_extra}
+        mock_conn = mock_connection_factory(extra=mock_conn_extra)
+
+        warning_message = "Please set 'ExternalId' in 'assume_role_kwargs' in .* extra."
+        with pytest.warns(DeprecationWarning, match=warning_message):
+            wrap_conn = AwsConnectionWrapper(conn=mock_conn)
+            assert "ExternalId" in wrap_conn.assume_role_kwargs
+            assert wrap_conn.assume_role_kwargs["ExternalId"] == mock_external_id_in_extra