You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/06/28 10:57:02 UTC

[airflow] branch main updated: AWS Hook - allow IDP HTTP retry (#12639) (#16612)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0d80383  AWS Hook - allow IDP HTTP retry (#12639) (#16612)
0d80383 is described below

commit 0d80383bdd506c2eff8ef29d0ff461620a966f86
Author: Bjorn Olsen <bj...@gmail.com>
AuthorDate: Mon Jun 28 12:55:34 2021 +0200

    AWS Hook - allow IDP HTTP retry (#12639) (#16612)
---
 airflow/providers/amazon/aws/hooks/base_aws.py     |  39 +++++-
 .../connections/aws.rst                            |   7 +
 tests/providers/amazon/aws/hooks/test_base_aws.py  | 142 +++++++++++++++++++++
 3 files changed, 181 insertions(+), 7 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index c1c5b1d..5a3c313 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -33,6 +33,7 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union
 import boto3
 import botocore
 import botocore.session
+import requests
 import tenacity
 from botocore.config import Config
 from botocore.credentials import ReadOnlyCredentials
@@ -41,6 +42,7 @@ try:
     from functools import cached_property
 except ImportError:
     from cached_property import cached_property
+
 from dateutil.tz import tzlocal
 
 from airflow.exceptions import AirflowException
@@ -214,18 +216,42 @@ class _SessionFactory(LoggingMixin):
             RoleArn=role_arn, PrincipalArn=principal_arn, SAMLAssertion=saml_assertion, **assume_role_kwargs
         )
 
-    def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: Dict[str, Any]) -> str:
-        import requests
+    def _get_idp_response(
+        self, saml_config: Dict[str, Any], auth: requests.auth.AuthBase
+    ) -> requests.models.Response:
+        idp_url = saml_config["idp_url"]
+        self.log.info("idp_url= %s", idp_url)
+
+        session = requests.Session()
+
+        # Configurable Retry when querying the IDP endpoint
+        if "idp_request_retry_kwargs" in saml_config:
+            idp_request_retry_kwargs = saml_config["idp_request_retry_kwargs"]
+            self.log.info("idp_request_retry_kwargs= %s", idp_request_retry_kwargs)
+            from requests.adapters import HTTPAdapter
+            from requests.packages.urllib3.util.retry import Retry
+
+            retry_strategy = Retry(**idp_request_retry_kwargs)
+            adapter = HTTPAdapter(max_retries=retry_strategy)
+            session.mount("https://", adapter)
+            session.mount("http://", adapter)
+
+        idp_request_kwargs = {}
+        if "idp_request_kwargs" in saml_config:
+            idp_request_kwargs = saml_config["idp_request_kwargs"]
 
+        idp_response = session.get(idp_url, auth=auth, **idp_request_kwargs)
+        idp_response.raise_for_status()
+
+        return idp_response
+
+    def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: Dict[str, Any]) -> str:
         # requests_gssapi will need paramiko > 2.6 since you'll need
         # 'gssapi' not 'python-gssapi' from PyPi.
         # https://github.com/paramiko/paramiko/pull/1311
         import requests_gssapi
         from lxml import etree
 
-        idp_url = saml_config["idp_url"]
-        self.log.info("idp_url= %s", idp_url)
-        idp_request_kwargs = saml_config["idp_request_kwargs"]
         auth = requests_gssapi.HTTPSPNEGOAuth()
         if 'mutual_authentication' in saml_config:
             mutual_auth = saml_config['mutual_authentication']
@@ -242,8 +268,7 @@ class _SessionFactory(LoggingMixin):
                     '(Exclude this setting will default to HTTPSPNEGOAuth() ).'
                 )
         # Query the IDP
-        idp_response = requests.get(idp_url, auth=auth, **idp_request_kwargs)
-        idp_response.raise_for_status()
+        idp_response = self._get_idp_response(saml_config, auth=auth)
         # Assist with debugging. Note: contains sensitive info!
         xpath = saml_config['saml_response_xpath']
         log_idp_response = 'log_idp_response' in saml_config and saml_config['log_idp_response']
diff --git a/docs/apache-airflow-providers-amazon/connections/aws.rst b/docs/apache-airflow-providers-amazon/connections/aws.rst
index 1c8f361..0d7b04c 100644
--- a/docs/apache-airflow-providers-amazon/connections/aws.rst
+++ b/docs/apache-airflow-providers-amazon/connections/aws.rst
@@ -159,6 +159,12 @@ This assumes all other Connection fields eg **Login** are empty.
           "headers":{"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"},
           "verify":false
         },
+        "idp_request_retry_kwargs": {
+          "total": 10,
+          "backoff_factor":1,
+          "status":10,
+          "status_forcelist": [400, 429, 500, 502, 503, 504]
+        },
         "log_idp_response":false,
         "saml_response_xpath":"////INPUT[@NAME='SAMLResponse']/@VALUE",
       },
@@ -173,6 +179,7 @@ The following settings may be used within the ``assume_role_with_saml`` containe
     * ``idp_auth_method``: Specify "http_spegno_auth" to use the Python ``requests_gssapi`` library. This library is more up to date than ``requests_kerberos`` and is backward compatible. See ``requests_gssapi`` documentation on PyPI.
     * ``mutual_authentication``: Can be "REQUIRED", "OPTIONAL" or "DISABLED". See ``requests_gssapi`` documentation on PyPI.
     * ``idp_request_kwargs``: Additional ``kwargs`` passed to ``requests`` when requesting from the IDP (over HTTP/S).
+    * ``idp_request_retry_kwargs``: Additional ``kwargs`` to construct a ``urllib3.util.Retry`` used as a retry strategy when requesting from the IDP. See the ``urllib3`` documentation for more details.
     * ``log_idp_response``: Useful for debugging - if specified, print the IDP response content to the log. Note that a successful response will contain sensitive information!
     * ``saml_response_xpath``: How to query the IDP response using XML / HTML xpath.
     * ``assume_role_kwargs``: Additional ``kwargs`` passed to ``sts_client.assume_role_with_saml``.
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 383880d..a6367d8 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -18,10 +18,12 @@
 #
 import json
 import unittest
+from base64 import b64encode
 from unittest import mock
 
 import boto3
 import pytest
+from moto.core import ACCOUNT_ID
 
 from airflow.models import Connection
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -34,6 +36,74 @@ except ImportError:
     mock_sts = None
     mock_iam = None
 
+# pylint: disable=line-too-long
+SAML_ASSERTION = """
+<?xml version="1.0"?>
+<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" ID="_00000000-0000-0000-0000-000000000000" Version="2.0" IssueInstant="2012-01-01T12:00:00.000Z" Destination="https://signin.aws.amazon.com/saml" Consent="urn:oasis:names:tc:SAML:2.0:consent:unspecified">
+  <Issuer xmlns="urn:oasis:names:tc:SAML:2.0:assertion">http://localhost/</Issuer>
+  <samlp:Status>
+    <samlp:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"/>
+  </samlp:Status>
+  <Assertion xmlns="urn:oasis:names:tc:SAML:2.0:assertion" ID="_00000000-0000-0000-0000-000000000000" IssueInstant="2012-12-01T12:00:00.000Z" Version="2.0">
+    <Issuer>http://localhost:3000/</Issuer>
+    <ds:Signature xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
+      <ds:SignedInfo>
+        <ds:CanonicalizationMethod Algorithm="http://www.w3.org/2001/10/xml-exc-c14n#"/>
+        <ds:SignatureMethod Algorithm="http://www.w3.org/2001/04/xmldsig-more#rsa-sha256"/>
+        <ds:Reference URI="#_00000000-0000-0000-0000-000000000000">
+          <ds:Transforms>
+            <ds:Transform Algorithm="http://www.w3.org/2000/09/xmldsig#enveloped-signature"/>
+            <ds:Transform Algorithm="http://www.w3.org/2001/10/xml-exc-c14n#"/>
+          </ds:Transforms>
+          <ds:DigestMethod Algorithm="http://www.w3.org/2001/04/xmlenc#sha256"/>
+          <ds:DigestValue>NTIyMzk0ZGI4MjI0ZjI5ZGNhYjkyOGQyZGQ1NTZjODViZjk5YTY4ODFjOWRjNjkyYzZmODY2ZDQ4NjlkZjY3YSAgLQo=</ds:DigestValue>
+        </ds:Reference>
+      </ds:SignedInfo>
+      <ds:SignatureValue>NTIyMzk0ZGI4MjI0ZjI5ZGNhYjkyOGQyZGQ1NTZjODViZjk5YTY4ODFjOWRjNjkyYzZmODY2ZDQ4NjlkZjY3YSAgLQo=</ds:SignatureValue>
+      <KeyInfo xmlns="http://www.w3.org/2000/09/xmldsig#">
+        <ds:X509Data>
+          <ds:X509Certificate>NTIyMzk0ZGI4MjI0ZjI5ZGNhYjkyOGQyZGQ1NTZjODViZjk5YTY4ODFjOWRjNjkyYzZmODY2ZDQ4NjlkZjY3YSAgLQo=</ds:X509Certificate>
+        </ds:X509Data>
+      </KeyInfo>
+    </ds:Signature>
+    <Subject>
+      <NameID Format="urn:oasis:names:tc:SAML:2.0:nameid-format:persistent">{username}</NameID>
+      <SubjectConfirmation Method="urn:oasis:names:tc:SAML:2.0:cm:bearer">
+        <SubjectConfirmationData NotOnOrAfter="2012-01-01T13:00:00.000Z" Recipient="https://signin.aws.amazon.com/saml"/>
+      </SubjectConfirmation>
+    </Subject>
+    <Conditions NotBefore="2012-01-01T12:00:00.000Z" NotOnOrAfter="2012-01-01T13:00:00.000Z">
+      <AudienceRestriction>
+        <Audience>urn:amazon:webservices</Audience>
+      </AudienceRestriction>
+    </Conditions>
+    <AttributeStatement>
+      <Attribute Name="https://aws.amazon.com/SAML/Attributes/RoleSessionName">
+        <AttributeValue>{username}@localhost</AttributeValue>
+      </Attribute>
+      <Attribute Name="https://aws.amazon.com/SAML/Attributes/Role">
+        <AttributeValue>arn:aws:iam::{account_id}:saml-provider/{provider_name},arn:aws:iam::{account_id}:role/{role_name}</AttributeValue>
+      </Attribute>
+      <Attribute Name="https://aws.amazon.com/SAML/Attributes/SessionDuration">
+        <AttributeValue>900</AttributeValue>
+      </Attribute>
+    </AttributeStatement>
+    <AuthnStatement AuthnInstant="2012-01-01T12:00:00.000Z" SessionIndex="_00000000-0000-0000-0000-000000000000">
+      <AuthnContext>
+        <AuthnContextClassRef>urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport</AuthnContextClassRef>
+      </AuthnContext>
+    </AuthnStatement>
+  </Assertion>
+</samlp:Response>""".format(  # noqa: E501
+    account_id=ACCOUNT_ID,
+    role_name="test-role",
+    provider_name="TestProvFed",
+    username="testuser",
+).replace(
+    "\n", ""
+)
+# pylint: enable=line-too-long
+
 
 class TestAwsBaseHook(unittest.TestCase):
     @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
@@ -252,6 +322,78 @@ class TestAwsBaseHook(unittest.TestCase):
             [mock.call.get_default_id_token_credentials(target_audience='aws-federation.airflow.apache.org')]
         )
 
+    @unittest.skipIf(mock_sts is None, 'mock_sts package not present')
+    @mock.patch.object(AwsBaseHook, 'get_connection')
+    @mock_sts
+    def test_assume_role_with_saml(self, mock_get_connection):
+
+        idp_url = "https://my-idp.local.corp"
+        principal_arn = "principal_arn_1234567890"
+        role_arn = "arn:aws:iam::123456:role/role_arn"
+        xpath = "1234"
+        duration_seconds = 901
+
+        mock_connection = Connection(
+            extra=json.dumps(
+                {
+                    "role_arn": role_arn,
+                    "assume_role_method": "assume_role_with_saml",
+                    "assume_role_with_saml": {
+                        "principal_arn": principal_arn,
+                        "idp_url": idp_url,
+                        "idp_auth_method": "http_spegno_auth",
+                        "mutual_authentication": "REQUIRED",
+                        "saml_response_xpath": xpath,
+                        "log_idp_response": True,
+                    },
+                    "assume_role_kwargs": {"DurationSeconds": duration_seconds},
+                }
+            )
+        )
+        mock_get_connection.return_value = mock_connection
+
+        encoded_saml_assertion = b64encode(SAML_ASSERTION.encode("utf-8")).decode("utf-8")
+
+        # Store original __import__
+        orig_import = __import__
+        mock_requests_gssapi = mock.Mock()
+        mock_auth = mock_requests_gssapi.HTTPSPNEGOAuth()
+
+        mock_lxml = mock.Mock()
+        mock_xpath = mock_lxml.etree.fromstring.return_value.xpath
+        mock_xpath.return_value = encoded_saml_assertion
+
+        def import_mock(name, *args, **kwargs):
+            if name == 'requests_gssapi':
+                return mock_requests_gssapi
+            if name == 'lxml':
+                return mock_lxml
+            return orig_import(name, *args, **kwargs)
+
+        with mock.patch('builtins.__import__', side_effect=import_mock), mock.patch(
+            'airflow.providers.amazon.aws.hooks.base_aws.requests.Session.get'
+        ) as mock_get, mock.patch('airflow.providers.amazon.aws.hooks.base_aws.boto3') as mock_boto3:
+            mock_get.return_value.ok = True
+
+            hook = AwsBaseHook(aws_conn_id='aws_default', client_type='s3')
+            hook.get_client_type('s3')
+
+            mock_get.assert_called_once_with(idp_url, auth=mock_auth)
+            mock_xpath.assert_called_once_with(xpath)
+
+        calls_assume_role_with_saml = [
+            mock.call.session.Session().client('sts', config=None),
+            mock.call.session.Session()
+            .client()
+            .assume_role_with_saml(
+                DurationSeconds=duration_seconds,
+                PrincipalArn=principal_arn,
+                RoleArn=role_arn,
+                SAMLAssertion=encoded_saml_assertion,
+            ),
+        ]
+        mock_boto3.assert_has_calls(calls_assume_role_with_saml)
+
     @unittest.skipIf(mock_iam is None, 'mock_iam package not present')
     @mock_iam
     def test_expand_role(self):