You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/06/28 10:57:02 UTC
[airflow] branch main updated: AWS Hook - allow IDP HTTP retry (#12639) (#16612)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0d80383 AWS Hook - allow IDP HTTP retry (#12639) (#16612)
0d80383 is described below
commit 0d80383bdd506c2eff8ef29d0ff461620a966f86
Author: Bjorn Olsen <bj...@gmail.com>
AuthorDate: Mon Jun 28 12:55:34 2021 +0200
AWS Hook - allow IDP HTTP retry (#12639) (#16612)
---
airflow/providers/amazon/aws/hooks/base_aws.py | 39 +++++-
.../connections/aws.rst | 7 +
tests/providers/amazon/aws/hooks/test_base_aws.py | 142 +++++++++++++++++++++
3 files changed, 181 insertions(+), 7 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index c1c5b1d..5a3c313 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -33,6 +33,7 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union
import boto3
import botocore
import botocore.session
+import requests
import tenacity
from botocore.config import Config
from botocore.credentials import ReadOnlyCredentials
@@ -41,6 +42,7 @@ try:
from functools import cached_property
except ImportError:
from cached_property import cached_property
+
from dateutil.tz import tzlocal
from airflow.exceptions import AirflowException
@@ -214,18 +216,42 @@ class _SessionFactory(LoggingMixin):
RoleArn=role_arn, PrincipalArn=principal_arn, SAMLAssertion=saml_assertion, **assume_role_kwargs
)
- def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: Dict[str, Any]) -> str:
- import requests
+ def _get_idp_response(
+ self, saml_config: Dict[str, Any], auth: requests.auth.AuthBase
+ ) -> requests.models.Response:
+ idp_url = saml_config["idp_url"]
+ self.log.info("idp_url= %s", idp_url)
+
+ session = requests.Session()
+
+ # Configurable Retry when querying the IDP endpoint
+ if "idp_request_retry_kwargs" in saml_config:
+ idp_request_retry_kwargs = saml_config["idp_request_retry_kwargs"]
+ self.log.info("idp_request_retry_kwargs= %s", idp_request_retry_kwargs)
+ from requests.adapters import HTTPAdapter
+ from requests.packages.urllib3.util.retry import Retry
+
+ retry_strategy = Retry(**idp_request_retry_kwargs)
+ adapter = HTTPAdapter(max_retries=retry_strategy)
+ session.mount("https://", adapter)
+ session.mount("http://", adapter)
+
+ idp_request_kwargs = {}
+ if "idp_request_kwargs" in saml_config:
+ idp_request_kwargs = saml_config["idp_request_kwargs"]
+ idp_response = session.get(idp_url, auth=auth, **idp_request_kwargs)
+ idp_response.raise_for_status()
+
+ return idp_response
+
+ def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: Dict[str, Any]) -> str:
# requests_gssapi will need paramiko > 2.6 since you'll need
# 'gssapi' not 'python-gssapi' from PyPi.
# https://github.com/paramiko/paramiko/pull/1311
import requests_gssapi
from lxml import etree
- idp_url = saml_config["idp_url"]
- self.log.info("idp_url= %s", idp_url)
- idp_request_kwargs = saml_config["idp_request_kwargs"]
auth = requests_gssapi.HTTPSPNEGOAuth()
if 'mutual_authentication' in saml_config:
mutual_auth = saml_config['mutual_authentication']
@@ -242,8 +268,7 @@ class _SessionFactory(LoggingMixin):
'(Exclude this setting will default to HTTPSPNEGOAuth() ).'
)
# Query the IDP
- idp_response = requests.get(idp_url, auth=auth, **idp_request_kwargs)
- idp_response.raise_for_status()
+ idp_response = self._get_idp_response(saml_config, auth=auth)
# Assist with debugging. Note: contains sensitive info!
xpath = saml_config['saml_response_xpath']
log_idp_response = 'log_idp_response' in saml_config and saml_config['log_idp_response']
diff --git a/docs/apache-airflow-providers-amazon/connections/aws.rst b/docs/apache-airflow-providers-amazon/connections/aws.rst
index 1c8f361..0d7b04c 100644
--- a/docs/apache-airflow-providers-amazon/connections/aws.rst
+++ b/docs/apache-airflow-providers-amazon/connections/aws.rst
@@ -159,6 +159,12 @@ This assumes all other Connection fields eg **Login** are empty.
"headers":{"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"},
"verify":false
},
+ "idp_request_retry_kwargs": {
+ "total": 10,
+ "backoff_factor":1,
+ "status":10,
+ "status_forcelist": [400, 429, 500, 502, 503, 504]
+ },
"log_idp_response":false,
"saml_response_xpath":"////INPUT[@NAME='SAMLResponse']/@VALUE",
},
@@ -173,6 +179,7 @@ The following settings may be used within the ``assume_role_with_saml`` containe
* ``idp_auth_method``: Specify "http_spegno_auth" to use the Python ``requests_gssapi`` library. This library is more up to date than ``requests_kerberos`` and is backward compatible. See ``requests_gssapi`` documentation on PyPI.
* ``mutual_authentication``: Can be "REQUIRED", "OPTIONAL" or "DISABLED". See ``requests_gssapi`` documentation on PyPI.
* ``idp_request_kwargs``: Additional ``kwargs`` passed to ``requests`` when requesting from the IDP (over HTTP/S).
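+ * ``idp_request_retry_kwargs``: Additional ``kwargs`` used to construct a ``urllib3.util.Retry`` object, which is passed to a ``requests`` ``HTTPAdapter`` (mounted for both ``http://`` and ``https://``) as the retry strategy when requesting from the IDP. If this setting is omitted, no custom retry strategy is configured. See the ``urllib3`` documentation for the available options.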
* ``log_idp_response``: Useful for debugging - if specified, print the IDP response content to the log. Note that a successful response will contain sensitive information!
* ``saml_response_xpath``: How to query the IDP response using XML / HTML xpath.
* ``assume_role_kwargs``: Additional ``kwargs`` passed to ``sts_client.assume_role_with_saml``.
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 383880d..a6367d8 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -18,10 +18,12 @@
#
import json
import unittest
+from base64 import b64encode
from unittest import mock
import boto3
import pytest
+from moto.core import ACCOUNT_ID
from airflow.models import Connection
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -34,6 +36,74 @@ except ImportError:
mock_sts = None
mock_iam = None
+# pylint: disable=line-too-long
+SAML_ASSERTION = """
+<?xml version="1.0"?>
+<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" ID="_00000000-0000-0000-0000-000000000000" Version="2.0" IssueInstant="2012-01-01T12:00:00.000Z" Destination="https://signin.aws.amazon.com/saml" Consent="urn:oasis:names:tc:SAML:2.0:consent:unspecified">
+ <Issuer xmlns="urn:oasis:names:tc:SAML:2.0:assertion">http://localhost/</Issuer>
+ <samlp:Status>
+ <samlp:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"/>
+ </samlp:Status>
+ <Assertion xmlns="urn:oasis:names:tc:SAML:2.0:assertion" ID="_00000000-0000-0000-0000-000000000000" IssueInstant="2012-12-01T12:00:00.000Z" Version="2.0">
+ <Issuer>http://localhost:3000/</Issuer>
+ <ds:Signature xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
+ <ds:SignedInfo>
+ <ds:CanonicalizationMethod Algorithm="http://www.w3.org/2001/10/xml-exc-c14n#"/>
+ <ds:SignatureMethod Algorithm="http://www.w3.org/2001/04/xmldsig-more#rsa-sha256"/>
+ <ds:Reference URI="#_00000000-0000-0000-0000-000000000000">
+ <ds:Transforms>
+ <ds:Transform Algorithm="http://www.w3.org/2000/09/xmldsig#enveloped-signature"/>
+ <ds:Transform Algorithm="http://www.w3.org/2001/10/xml-exc-c14n#"/>
+ </ds:Transforms>
+ <ds:DigestMethod Algorithm="http://www.w3.org/2001/04/xmlenc#sha256"/>
+ <ds:DigestValue>NTIyMzk0ZGI4MjI0ZjI5ZGNhYjkyOGQyZGQ1NTZjODViZjk5YTY4ODFjOWRjNjkyYzZmODY2ZDQ4NjlkZjY3YSAgLQo=</ds:DigestValue>
+ </ds:Reference>
+ </ds:SignedInfo>
+ <ds:SignatureValue>NTIyMzk0ZGI4MjI0ZjI5ZGNhYjkyOGQyZGQ1NTZjODViZjk5YTY4ODFjOWRjNjkyYzZmODY2ZDQ4NjlkZjY3YSAgLQo=</ds:SignatureValue>
+ <KeyInfo xmlns="http://www.w3.org/2000/09/xmldsig#">
+ <ds:X509Data>
+ <ds:X509Certificate>NTIyMzk0ZGI4MjI0ZjI5ZGNhYjkyOGQyZGQ1NTZjODViZjk5YTY4ODFjOWRjNjkyYzZmODY2ZDQ4NjlkZjY3YSAgLQo=</ds:X509Certificate>
+ </ds:X509Data>
+ </KeyInfo>
+ </ds:Signature>
+ <Subject>
+ <NameID Format="urn:oasis:names:tc:SAML:2.0:nameid-format:persistent">{username}</NameID>
+ <SubjectConfirmation Method="urn:oasis:names:tc:SAML:2.0:cm:bearer">
+ <SubjectConfirmationData NotOnOrAfter="2012-01-01T13:00:00.000Z" Recipient="https://signin.aws.amazon.com/saml"/>
+ </SubjectConfirmation>
+ </Subject>
+ <Conditions NotBefore="2012-01-01T12:00:00.000Z" NotOnOrAfter="2012-01-01T13:00:00.000Z">
+ <AudienceRestriction>
+ <Audience>urn:amazon:webservices</Audience>
+ </AudienceRestriction>
+ </Conditions>
+ <AttributeStatement>
+ <Attribute Name="https://aws.amazon.com/SAML/Attributes/RoleSessionName">
+ <AttributeValue>{username}@localhost</AttributeValue>
+ </Attribute>
+ <Attribute Name="https://aws.amazon.com/SAML/Attributes/Role">
+ <AttributeValue>arn:aws:iam::{account_id}:saml-provider/{provider_name},arn:aws:iam::{account_id}:role/{role_name}</AttributeValue>
+ </Attribute>
+ <Attribute Name="https://aws.amazon.com/SAML/Attributes/SessionDuration">
+ <AttributeValue>900</AttributeValue>
+ </Attribute>
+ </AttributeStatement>
+ <AuthnStatement AuthnInstant="2012-01-01T12:00:00.000Z" SessionIndex="_00000000-0000-0000-0000-000000000000">
+ <AuthnContext>
+ <AuthnContextClassRef>urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport</AuthnContextClassRef>
+ </AuthnContext>
+ </AuthnStatement>
+ </Assertion>
+</samlp:Response>""".format( # noqa: E501
+ account_id=ACCOUNT_ID,
+ role_name="test-role",
+ provider_name="TestProvFed",
+ username="testuser",
+).replace(
+ "\n", ""
+)
+# pylint: enable=line-too-long
+
class TestAwsBaseHook(unittest.TestCase):
@unittest.skipIf(mock_emr is None, 'mock_emr package not present')
@@ -252,6 +322,78 @@ class TestAwsBaseHook(unittest.TestCase):
[mock.call.get_default_id_token_credentials(target_audience='aws-federation.airflow.apache.org')]
)
+ @unittest.skipIf(mock_sts is None, 'mock_sts package not present')
+ @mock.patch.object(AwsBaseHook, 'get_connection')
+ @mock_sts
+ def test_assume_role_with_saml(self, mock_get_connection):
+
+ idp_url = "https://my-idp.local.corp"
+ principal_arn = "principal_arn_1234567890"
+ role_arn = "arn:aws:iam::123456:role/role_arn"
+ xpath = "1234"
+ duration_seconds = 901
+
+ mock_connection = Connection(
+ extra=json.dumps(
+ {
+ "role_arn": role_arn,
+ "assume_role_method": "assume_role_with_saml",
+ "assume_role_with_saml": {
+ "principal_arn": principal_arn,
+ "idp_url": idp_url,
+ "idp_auth_method": "http_spegno_auth",
+ "mutual_authentication": "REQUIRED",
+ "saml_response_xpath": xpath,
+ "log_idp_response": True,
+ },
+ "assume_role_kwargs": {"DurationSeconds": duration_seconds},
+ }
+ )
+ )
+ mock_get_connection.return_value = mock_connection
+
+ encoded_saml_assertion = b64encode(SAML_ASSERTION.encode("utf-8")).decode("utf-8")
+
+ # Store original __import__
+ orig_import = __import__
+ mock_requests_gssapi = mock.Mock()
+ mock_auth = mock_requests_gssapi.HTTPSPNEGOAuth()
+
+ mock_lxml = mock.Mock()
+ mock_xpath = mock_lxml.etree.fromstring.return_value.xpath
+ mock_xpath.return_value = encoded_saml_assertion
+
+ def import_mock(name, *args, **kwargs):
+ if name == 'requests_gssapi':
+ return mock_requests_gssapi
+ if name == 'lxml':
+ return mock_lxml
+ return orig_import(name, *args, **kwargs)
+
+ with mock.patch('builtins.__import__', side_effect=import_mock), mock.patch(
+ 'airflow.providers.amazon.aws.hooks.base_aws.requests.Session.get'
+ ) as mock_get, mock.patch('airflow.providers.amazon.aws.hooks.base_aws.boto3') as mock_boto3:
+ mock_get.return_value.ok = True
+
+ hook = AwsBaseHook(aws_conn_id='aws_default', client_type='s3')
+ hook.get_client_type('s3')
+
+ mock_get.assert_called_once_with(idp_url, auth=mock_auth)
+ mock_xpath.assert_called_once_with(xpath)
+
+ calls_assume_role_with_saml = [
+ mock.call.session.Session().client('sts', config=None),
+ mock.call.session.Session()
+ .client()
+ .assume_role_with_saml(
+ DurationSeconds=duration_seconds,
+ PrincipalArn=principal_arn,
+ RoleArn=role_arn,
+ SAMLAssertion=encoded_saml_assertion,
+ ),
+ ]
+ mock_boto3.assert_has_calls(calls_assume_role_with_saml)
+
@unittest.skipIf(mock_iam is None, 'mock_iam package not present')
@mock_iam
def test_expand_role(self):