You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/08/05 15:43:21 UTC
[airflow] branch main updated: Resolve Amazon Hook's `region_name` and `config` in wrapper (#25336)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4193558e80 Resolve Amazon Hook's `region_name` and `config` in wrapper (#25336)
4193558e80 is described below
commit 4193558e808af0d0eac0636b4bb6f88606ca54c6
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Fri Aug 5 19:43:09 2022 +0400
Resolve Amazon Hook's `region_name` and `config` in wrapper (#25336)
---
airflow/providers/amazon/aws/hooks/base_aws.py | 111 ++++++++++++---------
.../amazon/aws/utils/connection_wrapper.py | 64 ++++++++++--
tests/providers/amazon/aws/hooks/test_base_aws.py | 104 ++++++++++++++++++-
.../amazon/aws/utils/test_connection_wrapper.py | 100 ++++++++++++++++---
4 files changed, 306 insertions(+), 73 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index f0de49556d..1132892e73 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -44,7 +44,7 @@ from slugify import slugify
from airflow.compat.functools import cached_property
from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.models.connection import Connection
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
@@ -66,19 +66,24 @@ class BaseSessionFactory(LoggingMixin):
"""
def __init__(
- self, conn: Union[Connection, AwsConnectionWrapper], region_name: Optional[str], config: Config
+ self,
+ conn: Optional[Union[Connection, AwsConnectionWrapper]],
+ region_name: Optional[str] = None,
+ config: Optional[Config] = None,
) -> None:
super().__init__()
self._conn = conn
self._region_name = region_name
- self.config = config
+ self._config = config
@cached_property
def conn(self) -> AwsConnectionWrapper:
"""Cached AWS Connection Wrapper."""
- if isinstance(self._conn, AwsConnectionWrapper):
- return self._conn
- return AwsConnectionWrapper(self._conn)
+ return AwsConnectionWrapper(
+ conn=self._conn,
+ region_name=self._region_name,
+ botocore_config=self._config,
+ )
@cached_property
def basic_session(self) -> boto3.session.Session:
@@ -92,12 +97,13 @@ class BaseSessionFactory(LoggingMixin):
@property
def region_name(self) -> Optional[str]:
- """Resolve region name.
+ """AWS Region Name read-only property."""
+ return self.conn.region_name
- 1. SessionFactory region_name
- 2. Connection region_name
- """
- return self._region_name or self.conn.region_name
+ @property
+ def config(self) -> Optional[Config]:
+ """Configuration for botocore client read-only property."""
+ return self.conn.botocore_config
@property
def role_arn(self) -> Optional[str]:
@@ -105,8 +111,15 @@ class BaseSessionFactory(LoggingMixin):
return self.conn.role_arn
def create_session(self) -> boto3.session.Session:
- """Create AWS session."""
- if not self.role_arn:
+ """Create boto3 Session from connection config."""
+ if not self.conn:
+ self.log.info(
+ "No connection ID provided. Fallback on boto3 credential strategy (region_name=%r). "
+ "See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html",
+ self.region_name,
+ )
+ return boto3.session.Session(region_name=self.region_name)
+ elif not self.role_arn:
return self.basic_session
return self._create_session_with_assume_role(session_kwargs=self.conn.session_kwargs)
@@ -381,45 +394,50 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
self.verify = verify
self.client_type = client_type
self.resource_type = resource_type
- self.region_name = region_name
- self.config = config
-
- def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:
-
- if not self.aws_conn_id:
- session = boto3.session.Session(region_name=region_name)
- return session, None
+ self._region_name = region_name
+ self._config = config
- self.log.debug("Airflow Connection: aws_conn_id=%s", self.aws_conn_id)
+ @cached_property
+ def conn_config(self) -> AwsConnectionWrapper:
+ """Get the Airflow Connection object and wrap it in helper (cached)."""
+ connection = None
+ if self.aws_conn_id:
+ try:
+ connection = self.get_connection(self.aws_conn_id)
+ except AirflowNotFoundException:
+ warnings.warn(
+ f"Unable to find AWS Connection ID '{self.aws_conn_id}', switching to empty. "
+ "This behaviour is deprecated and will be removed in a future releases. "
+ "Please provide existed AWS connection ID or if required boto3 credential strategy "
+ "explicit set AWS Connection ID to None.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
- try:
- # Fetch the Airflow connection object and wrap it in helper
- connection_object = AwsConnectionWrapper(self.get_connection(self.aws_conn_id))
+ return AwsConnectionWrapper(
+ conn=connection or Connection(conn_id=None, conn_type="aws"),
+ region_name=self._region_name,
+ botocore_config=self._config,
+ )
- if connection_object.botocore_config:
- # For historical reason botocore.config.Config from connection overwrites
- # config which explicitly set in Hook.
- self.config = connection_object.botocore_config
+ @property
+ def region_name(self) -> Optional[str]:
+ """AWS Region Name read-only property."""
+ return self.conn_config.region_name
- session = SessionFactory(
- conn=connection_object, region_name=region_name, config=self.config
- ).create_session()
+ @property
+ def config(self) -> Optional[Config]:
+ """Configuration for botocore client read-only property."""
+ return self.conn_config.botocore_config
- return session, connection_object.endpoint_url
+ def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:
+ self.log.debug("Airflow Connection: aws_conn_id=%s", self.aws_conn_id)
- except AirflowException:
- self.log.warning(
- "Unable to use Airflow Connection for credentials. "
- "Fallback on boto3 credential strategy. See: "
- "https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html"
- )
+ session = SessionFactory(
+ conn=self.conn_config, region_name=region_name, config=self.config
+ ).create_session()
- self.log.debug(
- "Creating session using boto3 credential strategy region_name=%s",
- region_name,
- )
- session = boto3.session.Session(region_name=region_name)
- return session, None
+ return session, self.conn_config.endpoint_url
def get_client_type(
self,
@@ -491,6 +509,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
@cached_property
def conn_client_meta(self) -> ClientMeta:
+ """Get botocore client metadata from Hook connection (cached)."""
conn = self.conn
if isinstance(conn, botocore.client.BaseClient):
return conn.meta
@@ -498,10 +517,12 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
@property
def conn_region_name(self) -> str:
+ """Get actual AWS Region Name from Hook connection (cached)."""
return self.conn_client_meta.region_name
@property
def conn_partition(self) -> str:
+ """Get associated AWS Region Partition from Hook connection (cached)."""
return self.conn_client_meta.partition
def get_conn(self) -> BaseAwsConnection:
diff --git a/airflow/providers/amazon/aws/utils/connection_wrapper.py b/airflow/providers/amazon/aws/utils/connection_wrapper.py
index 6672b971f6..cf7c33a823 100644
--- a/airflow/providers/amazon/aws/utils/connection_wrapper.py
+++ b/airflow/providers/amazon/aws/utils/connection_wrapper.py
@@ -17,8 +17,8 @@
import warnings
from copy import deepcopy
-from dataclasses import InitVar, dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+from dataclasses import MISSING, InitVar, dataclass, field, fields
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from botocore.config import Config
@@ -35,25 +35,43 @@ class AwsConnectionWrapper(LoggingMixin):
"""
AWS Connection Wrapper class helper.
Use for validate and resolve AWS Connection parameters.
+
+ ``conn`` is a reference to an Airflow Connection object or AwsConnectionWrapper;
+ if it is set to ``None``, default values are used.
+
+ The precedence rules for ``region_name``
+ 1. Explicit set (in Hook) ``region_name``.
+ 2. Airflow Connection Extra 'region_name'.
+
+ The precedence rules for ``botocore_config``
+ 1. Explicit set (in Hook) ``botocore_config``.
+ 2. Construct from Airflow Connection Extra 'config_kwargs'.
+ 3. The wrapper's default value
"""
- conn: InitVar[Optional["Connection"]]
+ conn: InitVar[Optional[Union["Connection", "AwsConnectionWrapper"]]]
+ region_name: Optional[str] = field(default=None)
+ botocore_config: Optional[Config] = field(default=None)
+ # Reference to Airflow Connection attributes
+ # ``extra_config`` contains original Airflow Connection Extra.
conn_id: Optional[str] = field(init=False, default=None)
conn_type: Optional[str] = field(init=False, default=None)
login: Optional[str] = field(init=False, repr=False, default=None)
password: Optional[str] = field(init=False, repr=False, default=None)
extra_config: Dict[str, Any] = field(init=False, repr=False, default_factory=dict)
- aws_access_key_id: Optional[str] = field(init=False)
- aws_secret_access_key: Optional[str] = field(init=False)
- aws_session_token: Optional[str] = field(init=False)
+ # AWS Credentials from connection.
+ aws_access_key_id: Optional[str] = field(init=False, default=None)
+ aws_secret_access_key: Optional[str] = field(init=False, default=None)
+ aws_session_token: Optional[str] = field(init=False, default=None)
- region_name: Optional[str] = field(init=False, default=None)
+ # Additional boto3.session.Session keyword arguments.
session_kwargs: Dict[str, Any] = field(init=False, default_factory=dict)
- botocore_config: Optional[Config] = field(init=False, default=None)
+ # Custom endpoint_url for boto3.client and boto3.resource
endpoint_url: Optional[str] = field(init=False, default=None)
+ # Assume Role Configurations
role_arn: Optional[str] = field(init=False, default=None)
assume_role_method: Optional[str] = field(init=False, default=None)
assume_role_kwargs: Dict[str, Any] = field(init=False, default_factory=dict)
@@ -63,7 +81,30 @@ class AwsConnectionWrapper(LoggingMixin):
return f"AWS Connection (conn_id={self.conn_id!r}, conn_type={self.conn_type!r})"
def __post_init__(self, conn: "Connection"):
- if not conn:
+ if isinstance(conn, type(self)):
+ # For every field with init=False we copy reference value from original wrapper
+ # For every field with init=True we use init values if it not equal default
+ # We can't use ``dataclasses.replace`` in classmethod because
+ # we limited by InitVar arguments since it not stored in object,
+ # and also we do not want to run __post_init__ method again which print all logs/warnings again.
+ for fl in fields(conn):
+ value = getattr(conn, fl.name)
+ if not fl.init:
+ setattr(self, fl.name, value)
+ else:
+ if fl.default is not MISSING:
+ default = fl.default
+ elif fl.default_factory is not MISSING:
+ default = fl.default_factory() # zero-argument callable
+ else:
+ continue # Value mandatory, skip
+
+ orig_value = getattr(self, fl.name)
+ if orig_value == default:
+ # Only replace value if it not equal default value
+ setattr(self, fl.name, value)
+ return
+ elif not conn:
return
extra = deepcopy(conn.extra_dejson)
@@ -86,7 +127,7 @@ class AwsConnectionWrapper(LoggingMixin):
init_credentials = self._get_credentials(**extra)
self.aws_access_key_id, self.aws_secret_access_key, self.aws_session_token = init_credentials
- if "region_name" in extra:
+ if not self.region_name and "region_name" in extra:
self.region_name = extra["region_name"]
self.log.info("Retrieving region_name=%s from %s extra.", self.region_name, self.conn_repr)
@@ -106,7 +147,7 @@ class AwsConnectionWrapper(LoggingMixin):
)
config_kwargs = extra.get("config_kwargs")
- if config_kwargs:
+ if not self.botocore_config and config_kwargs:
# https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
self.log.info("Retrieving botocore config=%s from %s extra.", config_kwargs, self.conn_repr)
self.botocore_config = Config(**config_kwargs)
@@ -119,6 +160,7 @@ class AwsConnectionWrapper(LoggingMixin):
@property
def extra_dejson(self):
+ """Compatibility with `airflow.models.Connection.extra_dejson` property."""
return self.extra_config
def __bool__(self):
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index a04473b5b0..f2d55cd3bb 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -17,6 +17,7 @@
# under the License.
#
import json
+import os
import unittest
from base64 import b64encode
from datetime import datetime, timedelta, timezone
@@ -24,6 +25,7 @@ from unittest import mock
import boto3
import pytest
+from botocore.config import Config
from moto.core import ACCOUNT_ID
from airflow.models import Connection
@@ -45,6 +47,8 @@ except ImportError:
MOCK_AWS_CONN_ID = "mock-conn-id"
MOCK_CONN_TYPE = "aws"
+MOCK_BOTO3_SESSION = mock.MagicMock(return_value="Mock boto3.session.Session")
+
SAML_ASSERTION = """
<?xml version="1.0"?>
@@ -150,6 +154,70 @@ class TestSessionFactory:
assert session_factory_conn.conn_type == MOCK_CONN_TYPE
assert sf.conn is session_factory_conn
+ def test_empty_conn_property(self):
+ sf = BaseSessionFactory(conn=None, region_name=None, config=None)
+ assert isinstance(sf.conn, AwsConnectionWrapper)
+
+ @pytest.mark.parametrize(
+ "region_name,conn_region_name",
+ [
+ ("eu-west-1", "cn-north-1"),
+ ("eu-west-1", None),
+ (None, "cn-north-1"),
+ (None, None),
+ ],
+ )
+ def test_resolve_region_name(self, region_name, conn_region_name):
+ conn = AwsConnectionWrapper(
+ conn=Connection(conn_type=MOCK_CONN_TYPE, conn_id=MOCK_AWS_CONN_ID),
+ region_name=conn_region_name,
+ )
+ sf = BaseSessionFactory(conn=conn, region_name=region_name, config=None)
+ expected = region_name or conn_region_name
+ assert sf.region_name == expected
+
+ @pytest.mark.parametrize(
+ "botocore_config, conn_botocore_config",
+ [
+ (Config(s3={"us_east_1_regional_endpoint": "regional"}), None),
+ (Config(s3={"us_east_1_regional_endpoint": "regional"}), Config(region_name="ap-southeast-1")),
+ (None, Config(region_name="ap-southeast-1")),
+ (None, None),
+ ],
+ )
+ def test_resolve_botocore_config(self, botocore_config, conn_botocore_config):
+ conn = AwsConnectionWrapper(
+ conn=Connection(conn_type=MOCK_CONN_TYPE, conn_id=MOCK_AWS_CONN_ID),
+ botocore_config=conn_botocore_config,
+ )
+ sf = BaseSessionFactory(conn=conn, config=botocore_config)
+ expected = botocore_config or conn_botocore_config
+ assert sf.config == expected
+
+ @pytest.mark.parametrize("region_name", ["eu-central-1", None])
+ @mock.patch("boto3.session.Session", new_callable=mock.PropertyMock, return_value=MOCK_BOTO3_SESSION)
+ def test_create_session_boto3_credential_strategy(self, mock_boto3_session, region_name, caplog):
+ sf = BaseSessionFactory(conn=AwsConnectionWrapper(conn=None), region_name=region_name, config=None)
+ session = sf.create_session()
+ mock_boto3_session.assert_called_once_with(region_name=region_name)
+ assert session == MOCK_BOTO3_SESSION
+ logging_message = "No connection ID provided. Fallback on boto3 credential strategy"
+ assert any(logging_message in log_text for log_text in caplog.messages)
+
+ @pytest.mark.parametrize("region_name", ["eu-central-1", None])
+ @mock.patch("boto3.session.Session", new_callable=mock.PropertyMock, return_value=MOCK_BOTO3_SESSION)
+ def test_create_session_from_credentials(self, mock_boto3_session, region_name):
+ mock_conn = AwsConnectionWrapper(conn=Connection(conn_type=MOCK_CONN_TYPE, conn_id=MOCK_AWS_CONN_ID))
+ sf = BaseSessionFactory(conn=mock_conn, region_name=region_name, config=None)
+ session = sf.create_session()
+ mock_boto3_session.assert_called_once_with(
+ aws_access_key_id=mock_conn.aws_access_key_id,
+ aws_secret_access_key=mock_conn.aws_secret_access_key,
+ aws_session_token=mock_conn.aws_session_token,
+ region_name=region_name,
+ )
+ assert session == MOCK_BOTO3_SESSION
+
class TestAwsBaseHook:
@unittest.skipIf(mock_emr is None, 'mock_emr package not present')
@@ -341,7 +409,11 @@ class TestAwsBaseHook:
@mock.patch.object(AwsBaseHook, 'get_connection')
@mock_sts
def test_get_credentials_from_role_arn(self, mock_get_connection):
- mock_connection = Connection(extra='{"role_arn":"arn:aws:iam::123456:role/role_arn"}')
+ mock_connection = Connection(
+ conn_id='aws_default',
+ conn_type=MOCK_CONN_TYPE,
+ extra='{"role_arn":"arn:aws:iam::123456:role/role_arn"}',
+ )
mock_get_connection.return_value = mock_connection
hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
credentials_from_hook = hook.get_credentials()
@@ -445,6 +517,7 @@ class TestAwsBaseHook:
duration_seconds = 901
mock_connection = Connection(
+ conn_id=MOCK_AWS_CONN_ID,
extra=json.dumps(
{
"role_arn": role_arn,
@@ -459,7 +532,7 @@ class TestAwsBaseHook:
},
"assume_role_kwargs": {"DurationSeconds": duration_seconds},
}
- )
+ ),
)
mock_get_connection.return_value = mock_connection
@@ -671,6 +744,33 @@ class TestAwsBaseHook:
assert result
assert hook.client_type == "s3" # Same client_type which defined during initialisation
+ @mock.patch.dict(os.environ, {f"AIRFLOW_CONN_{MOCK_AWS_CONN_ID.upper()}": "aws://"})
+ def test_conn_config_conn_id_exists(self):
+ """Test retrieve connection config if aws_conn_id exists."""
+ hook = AwsBaseHook(aws_conn_id=MOCK_AWS_CONN_ID)
+ conn_config_exist = hook.conn_config
+ assert conn_config_exist is hook.conn_config, "Expected cached Connection Config"
+ assert isinstance(conn_config_exist, AwsConnectionWrapper)
+ assert conn_config_exist
+
+ @pytest.mark.parametrize("aws_conn_id", ["", None], ids=["empty", "None"])
+ def test_conn_config_conn_id_empty(self, aws_conn_id):
+ """Test retrieve connection config if aws_conn_id empty or None."""
+ conn_config_empty = AwsBaseHook(aws_conn_id=aws_conn_id).conn_config
+ assert isinstance(conn_config_empty, AwsConnectionWrapper)
+ assert not conn_config_empty
+
+ def test_conn_config_conn_id_not_exists(self):
+ """Test fallback connection config if aws_conn_id not exists."""
+ warning_message = (
+ r"Unable to find AWS Connection ID '.*', switching to empty\. "
+ r"This behaviour is deprecated and will be removed in a future releases"
+ )
+ with pytest.warns(DeprecationWarning, match=warning_message):
+ conn_config_fallback_not_exists = AwsBaseHook(aws_conn_id="aws-conn-not-exists").conn_config
+ assert isinstance(conn_config_fallback_not_exists, AwsConnectionWrapper)
+ assert not conn_config_fallback_not_exists
+
class ThrowErrorUntilCount:
"""Holds counter state for invoking a method several times in a row."""
diff --git a/tests/providers/amazon/aws/utils/test_connection_wrapper.py b/tests/providers/amazon/aws/utils/test_connection_wrapper.py
index 74696db98e..47ed16fa73 100644
--- a/tests/providers/amazon/aws/utils/test_connection_wrapper.py
+++ b/tests/providers/amazon/aws/utils/test_connection_wrapper.py
@@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.
+from dataclasses import fields
from typing import Optional
from unittest import mock
import pytest
+from botocore.config import Config
from airflow.models import Connection
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
@@ -138,11 +140,24 @@ class TestAwsConnectionWrapper:
assert wrap_conn.aws_secret_access_key == aws_secret_access_key
assert wrap_conn.aws_session_token == aws_session_token
- @pytest.mark.parametrize("region_name", [None, "mock-aws-region"])
- def test_get_region_name(self, region_name):
- mock_conn = mock_connection_factory(extra={"region_name": region_name} if region_name else None)
- wrap_conn = AwsConnectionWrapper(conn=mock_conn)
- assert wrap_conn.region_name == region_name
+ @pytest.mark.parametrize(
+ "region_name,conn_region_name",
+ [
+ ("mock-region-name", None),
+ ("mock-region-name", "mock-connection-region-name"),
+ (None, "mock-connection-region-name"),
+ (None, None),
+ ],
+ )
+ def test_get_region_name(self, region_name, conn_region_name):
+ mock_conn = mock_connection_factory(
+ extra={"region_name": conn_region_name} if conn_region_name else None
+ )
+ wrap_conn = AwsConnectionWrapper(conn=mock_conn, region_name=region_name)
+ if region_name:
+ assert wrap_conn.region_name == region_name, "Expected provided region_name"
+ else:
+ assert wrap_conn.region_name == conn_region_name, "Expected connection region_name"
@pytest.mark.parametrize("session_kwargs", [None, {"profile_name": "mock-profile"}])
def test_get_session_kwargs(self, session_kwargs):
@@ -160,20 +175,30 @@ class TestAwsConnectionWrapper:
wrap_conn = AwsConnectionWrapper(conn=mock_conn)
assert "profile_name" not in wrap_conn.session_kwargs
- @mock.patch("airflow.providers.amazon.aws.utils.connection_wrapper.Config", autospec=True)
- @pytest.mark.parametrize("botocore_config_kwargs", [None, {"user_agent": "Airflow Amazon Provider"}])
- def test_get_botocore_config(self, mock_botocore_config, botocore_config_kwargs):
+ @mock.patch("airflow.providers.amazon.aws.utils.connection_wrapper.Config")
+ @pytest.mark.parametrize(
+ "botocore_config, botocore_config_kwargs",
+ [
+ (Config(s3={"us_east_1_regional_endpoint": "regional"}), None),
+ (Config(region_name="ap-southeast-1"), {"user_agent": "Airflow Amazon Provider"}),
+ (None, {"user_agent": "Airflow Amazon Provider"}),
+ (None, None),
+ ],
+ )
+ def test_get_botocore_config(self, mock_botocore_config, botocore_config, botocore_config_kwargs):
mock_conn = mock_connection_factory(
extra={"config_kwargs": botocore_config_kwargs} if botocore_config_kwargs else None
)
- wrap_conn = AwsConnectionWrapper(conn=mock_conn)
-
- if not botocore_config_kwargs:
- assert not mock_botocore_config.called
- assert wrap_conn.botocore_config is None
+ wrap_conn = AwsConnectionWrapper(conn=mock_conn, botocore_config=botocore_config)
+
+ if botocore_config:
+ assert wrap_conn.botocore_config == botocore_config, "Expected provided botocore_config"
+ mock_botocore_config.assert_not_called()
+ elif not botocore_config_kwargs:
+ assert wrap_conn.botocore_config is None, "Expected default botocore_config"
+ mock_botocore_config.assert_not_called()
else:
- assert mock_botocore_config.called
- assert mock_botocore_config.call_count == 1
+ mock_botocore_config.assert_called_once()
assert mock.call(**botocore_config_kwargs) in mock_botocore_config.mock_calls
@pytest.mark.parametrize("endpoint_url", [None, "https://example.org"])
@@ -281,3 +306,48 @@ class TestAwsConnectionWrapper:
wrap_conn = AwsConnectionWrapper(conn=mock_conn)
assert "ExternalId" in wrap_conn.assume_role_kwargs
assert wrap_conn.assume_role_kwargs["ExternalId"] == mock_external_id_in_extra
+
+ @pytest.mark.parametrize(
+ "orig_wrapper",
+ [
+ AwsConnectionWrapper(
+ conn=mock_connection_factory(
+ login="mock-login",
+ password="mock-password",
+ extra={
+ "region_name": "mock-region",
+ "config_kwargs": {"user_agent": "Airflow Amazon Provider"},
+ "role_arn": MOCK_ROLE_ARN,
+ "aws_session_token": "mock-aws-session-token",
+ },
+ ),
+ ),
+ AwsConnectionWrapper(conn=mock_connection_factory()),
+ AwsConnectionWrapper(conn=None),
+ AwsConnectionWrapper(
+ conn=None,
+ region_name="mock-region",
+ botocore_config=Config(user_agent="Airflow Amazon Provider"),
+ ),
+ ],
+ )
+ @pytest.mark.parametrize("region_name", [None, "ca-central-1"])
+ @pytest.mark.parametrize("botocore_config", [None, Config(region_name="ap-southeast-1")])
+ def test_wrap_wrapper(self, orig_wrapper, region_name, botocore_config):
+ wrap_kwargs = {}
+ if region_name:
+ wrap_kwargs["region_name"] = region_name
+ if botocore_config:
+ wrap_kwargs["botocore_config"] = botocore_config
+ wrap_conn = AwsConnectionWrapper(conn=orig_wrapper, **wrap_kwargs)
+
+ # Non init fields should be same in orig_wrapper and child wrapper
+ wrap_non_init_fields = [f.name for f in fields(wrap_conn) if not f.init]
+ for field in wrap_non_init_fields:
+ assert getattr(wrap_conn, field) == getattr(
+ orig_wrapper, field
+ ), "Expected no changes in non-init values"
+
+ # Test overwrite/inherit init fields
+ assert wrap_conn.region_name == (region_name or orig_wrapper.region_name)
+ assert wrap_conn.botocore_config == (botocore_config or orig_wrapper.botocore_config)