You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by fo...@apache.org on 2018/05/14 19:53:38 UTC

incubator-airflow git commit: [AIRFLOW-2110][AIRFLOW-2122] Enhance Http Hook

Repository: incubator-airflow
Updated Branches:
  refs/heads/master cb9ba02cf -> 6c19468e0


[AIRFLOW-2110][AIRFLOW-2122] Enhance Http Hook

- Use headers passed in the connection's "extra" field and
  add tenacity retry
- Fix the tests with proper mocking

Closes #3071 from albertocalderari/master


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/6c19468e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/6c19468e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/6c19468e

Branch: refs/heads/master
Commit: 6c19468e0b3b938249acc43e4b833a753d093efc
Parents: cb9ba02
Author: alberto.calderari <al...@just-eat.com>
Authored: Mon May 14 21:52:19 2018 +0200
Committer: Fokko Driesprong <fo...@godatadriven.com>
Committed: Mon May 14 21:52:22 2018 +0200

----------------------------------------------------------------------
 airflow/hooks/base_hook.py    |   2 -
 airflow/hooks/http_hook.py    | 128 +++++++++++++++----
 scripts/ci/requirements.txt   |   1 +
 setup.py                      |   1 +
 tests/core.py                 |  50 --------
 tests/hooks/test_http_hook.py | 254 +++++++++++++++++++++++++++++++++++++
 6 files changed, 356 insertions(+), 80 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c19468e/airflow/hooks/base_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/base_hook.py b/airflow/hooks/base_hook.py
index 65086ad..16b19e0 100644
--- a/airflow/hooks/base_hook.py
+++ b/airflow/hooks/base_hook.py
@@ -25,7 +25,6 @@ from __future__ import unicode_literals
 import os
 import random
 
-from airflow import settings
 from airflow.models import Connection
 from airflow.exceptions import AirflowException
 from airflow.utils.db import provide_session
@@ -45,7 +44,6 @@ class BaseHook(LoggingMixin):
     def __init__(self, source):
         pass
 
-
     @classmethod
     @provide_session
     def _get_connections_from_db(cls, conn_id, session=None):

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c19468e/airflow/hooks/http_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/http_hook.py b/airflow/hooks/http_hook.py
index aa316bb..a108b53 100644
--- a/airflow/hooks/http_hook.py
+++ b/airflow/hooks/http_hook.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -20,6 +20,7 @@
 from builtins import str
 
 import requests
+import tenacity
 
 from airflow.hooks.base_hook import BaseHook
 from airflow.exceptions import AirflowException
@@ -28,16 +29,31 @@ from airflow.exceptions import AirflowException
 class HttpHook(BaseHook):
     """
     Interact with HTTP servers.
+    :param http_conn_id: connection that has the base API url i.e https://www.google.com/
+        and optional authentication credentials. Default headers can also be specified in
+        the Extra field in json format.
+    :type http_conn_id: str
+    :param method: the API method to be called
+    :type method: str
     """
 
-    def __init__(self, method='POST', http_conn_id='http_default'):
+    def __init__(
+        self,
+        method='POST',
+        http_conn_id='http_default'
+    ):
         self.http_conn_id = http_conn_id
         self.method = method
+        self.base_url = None
+        self._retry_obj = None
 
-    # headers is required to make it required
-    def get_conn(self, headers):
+    # headers may be passed through directly or in the "extra" field in the connection
+    # definition
+    def get_conn(self, headers=None):
         """
         Returns http session for use with requests
+        :param headers: additional headers to be passed through as a dictionary
+        :type headers: dict
         """
         conn = self.get_connection(self.http_conn_id)
         session = requests.Session()
@@ -53,6 +69,8 @@ class HttpHook(BaseHook):
             self.base_url = self.base_url + ":" + str(conn.port) + "/"
         if conn.login:
             session.auth = (conn.login, conn.password)
+        if conn.extra:
+            session.headers.update(conn.extra_dejson)
         if headers:
             session.headers.update(headers)
 
@@ -61,6 +79,16 @@ class HttpHook(BaseHook):
     def run(self, endpoint, data=None, headers=None, extra_options=None):
         """
         Performs the request
+        :param endpoint: the endpoint to be called i.e. resource/v1/query?
+        :type endpoint: str
+        :param data: payload to be uploaded or request parameters
+        :type data: dict
+        :param headers: additional headers to be passed through as a dictionary
+        :type headers: dict
+        :param extra_options: additional options to be used when executing the request
+            i.e. {'check_response': False} to avoid checking raising exceptions on non
+            2XX or 3XX status codes
+        :type extra_options: dict
         """
         extra_options = extra_options or {}
 
@@ -90,34 +118,78 @@ class HttpHook(BaseHook):
         self.log.info("Sending '%s' to url: %s", self.method, url)
         return self.run_and_check(session, prepped_request, extra_options)
 
+    def check_response(self, response):
+        """
+        Checks the status code and raise an AirflowException exception on non 2XX or 3XX
+        status codes
+        :param response: A requests response object
+        :type response: requests.response
+        """
+        try:
+            response.raise_for_status()
+        except requests.exceptions.HTTPError:
+            self.log.error("HTTP error: %s", response.reason)
+            if self.method not in ['GET', 'HEAD']:
+                self.log.error(response.text)
+            raise AirflowException(str(response.status_code) + ":" + response.reason)
+
     def run_and_check(self, session, prepped_request, extra_options):
         """
         Grabs extra options like timeout and actually runs the request,
         checking for the result
+        :param session: the session to be used to execute the request
+        :type session: requests.Session
+        :param prepped_request: the prepared request generated in run()
+        :type prepped_request: session.prepare_request
+        :param extra_options: additional options to be used when executing the request
+            i.e. {'check_response': False} to avoid checking raising exceptions on non 2XX
+            or 3XX status codes
+        :type extra_options: dict
         """
         extra_options = extra_options or {}
 
-        response = session.send(
-            prepped_request,
-            stream=extra_options.get("stream", False),
-            verify=extra_options.get("verify", False),
-            proxies=extra_options.get("proxies", {}),
-            cert=extra_options.get("cert"),
-            timeout=extra_options.get("timeout"),
-            allow_redirects=extra_options.get("allow_redirects", True))
-
         try:
-            response.raise_for_status()
-        except requests.exceptions.HTTPError:
-            # Tried rewrapping, but not supported. This way, it's possible
-            # to get reason and code for failure by checking first 3 chars
-            # for the code, or do a split on ':'
-            self.log.error("HTTP error: %s", response.reason)
-            if self.method not in ('GET', 'HEAD'):
-                # The sensor uses GET, so this prevents filling up the log
-                # with the body every time the GET 'misses'.
-                # That's ok to do, because GETs should be repeatable and
-                # all data should be visible in the log (no post data)
-                self.log.error(response.text)
-            raise AirflowException(str(response.status_code)+":"+response.reason)
-        return response
+            response = session.send(
+                prepped_request,
+                stream=extra_options.get("stream", False),
+                verify=extra_options.get("verify", False),
+                proxies=extra_options.get("proxies", {}),
+                cert=extra_options.get("cert"),
+                timeout=extra_options.get("timeout"),
+                allow_redirects=extra_options.get("allow_redirects", True))
+
+            if extra_options.get('check_response', True):
+                self.check_response(response)
+            return response
+
+        except requests.exceptions.ConnectionError as ex:
+            self.log.warn(str(ex) + ' Tenacity will retry to execute the operation')
+            raise ex
+
+    def run_with_advanced_retry(self, _retry_args, *args, **kwargs):
+        """
+        Runs Hook.run() with a Tenacity decorator attached to it. This is useful for
+        connectors which might be disturbed by intermittent issues and should not
+        instantly fail.
+        :param _retry_args: Arguments which define the retry behaviour.
+            See Tenacity documentation at https://github.com/jd/tenacity
+        :type _retry_args: dict
+
+
+        Example: ::
+            hook = HttpHook(http_conn_id='my_conn',method='GET')
+            retry_args = dict(
+                 wait=tenacity.wait_exponential(),
+                 stop=tenacity.stop_after_attempt(10),
+                 retry=requests.exceptions.ConnectionError
+             )
+             hook.run_with_advanced_retry(
+                     endpoint='v1/test',
+                     _retry_args=retry_args
+                 )
+        """
+        self._retry_obj = tenacity.Retrying(
+            **_retry_args
+        )
+
+        self._retry_obj(self.run, *args, **kwargs)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c19468e/scripts/ci/requirements.txt
----------------------------------------------------------------------
diff --git a/scripts/ci/requirements.txt b/scripts/ci/requirements.txt
index 8ab52fc..cb4dd41 100644
--- a/scripts/ci/requirements.txt
+++ b/scripts/ci/requirements.txt
@@ -90,6 +90,7 @@ Sphinx-PyPI-upload
 sphinx_rtd_theme
 sqlalchemy>=1.1.15, <1.2.0
 statsd
+tenacity==4.8.0
 thrift
 thrift_sasl
 unicodecsv

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c19468e/setup.py
----------------------------------------------------------------------
diff --git a/setup.py b/setup.py
index 8999695..8907c00 100644
--- a/setup.py
+++ b/setup.py
@@ -268,6 +268,7 @@ def do_setup():
             'sqlalchemy>=1.1.15, <1.2.0',
             'sqlalchemy-utc>=0.9.0',
             'tabulate>=0.7.5, <0.8.0',
+            'tenacity==4.8.0',
             'thrift>=0.9.2',
             'tzlocal>=1.4',
             'werkzeug>=0.14.1, <0.15.0',

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c19468e/tests/core.py
----------------------------------------------------------------------
diff --git a/tests/core.py b/tests/core.py
index ce32482..0230ecf 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -2342,56 +2342,6 @@ class HDFSHookTest(unittest.TestCase):
         client = HDFSHook().get_conn()
         self.assertIsInstance(client, snakebite.client.HAClient)
 
-
-try:
-    from airflow.hooks.http_hook import HttpHook
-except ImportError:
-    HttpHook = None
-
-
-@unittest.skipIf(HttpHook is None,
-                 "Skipping test because HttpHook is not installed")
-class HttpHookTest(unittest.TestCase):
-    def setUp(self):
-        configuration.load_test_config()
-
-    @mock.patch('airflow.hooks.http_hook.HttpHook.get_connection')
-    def test_http_connection(self, mock_get_connection):
-        c = models.Connection(conn_id='http_default', conn_type='http',
-                              host='localhost', schema='http')
-        mock_get_connection.return_value = c
-        hook = HttpHook()
-        hook.get_conn({})
-        self.assertEqual(hook.base_url, 'http://localhost')
-
-    @mock.patch('airflow.hooks.http_hook.HttpHook.get_connection')
-    def test_https_connection(self, mock_get_connection):
-        c = models.Connection(conn_id='http_default', conn_type='http',
-                              host='localhost', schema='https')
-        mock_get_connection.return_value = c
-        hook = HttpHook()
-        hook.get_conn({})
-        self.assertEqual(hook.base_url, 'https://localhost')
-
-    @mock.patch('airflow.hooks.http_hook.HttpHook.get_connection')
-    def test_host_encoded_http_connection(self, mock_get_connection):
-        c = models.Connection(conn_id='http_default', conn_type='http',
-                              host='http://localhost')
-        mock_get_connection.return_value = c
-        hook = HttpHook()
-        hook.get_conn({})
-        self.assertEqual(hook.base_url, 'http://localhost')
-
-    @mock.patch('airflow.hooks.http_hook.HttpHook.get_connection')
-    def test_host_encoded_https_connection(self, mock_get_connection):
-        c = models.Connection(conn_id='http_default', conn_type='http',
-                              host='https://localhost')
-        mock_get_connection.return_value = c
-        hook = HttpHook()
-        hook.get_conn({})
-        self.assertEqual(hook.base_url, 'https://localhost')
-
-
 send_email_test = mock.Mock()
 
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c19468e/tests/hooks/test_http_hook.py
----------------------------------------------------------------------
diff --git a/tests/hooks/test_http_hook.py b/tests/hooks/test_http_hook.py
new file mode 100644
index 0000000..c816332
--- /dev/null
+++ b/tests/hooks/test_http_hook.py
@@ -0,0 +1,254 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import json
+
+import requests
+import requests_mock
+
+import tenacity
+
+from airflow import configuration, models
+from airflow.exceptions import AirflowException
+from airflow.hooks.http_hook import HttpHook
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+
+def get_airflow_connection(conn_id=None):
+    return models.Connection(
+        conn_id='http_default',
+        conn_type='http',
+        host='test:8080/',
+        extra='{"bareer": "test"}'
+    )
+
+
+class TestHttpHook(unittest.TestCase):
+    """Test get, post and raise_for_status"""
+    def setUp(self):
+        session = requests.Session()
+        adapter = requests_mock.Adapter()
+        session.mount('mock', adapter)
+        self.get_hook = HttpHook(method='GET')
+        self.post_hook = HttpHook(method='POST')
+        configuration.load_test_config()
+
+    @requests_mock.mock()
+    def test_raise_for_status_with_200(self, m):
+
+        m.get(
+            'http://test:8080/v1/test',
+            status_code=200,
+            text='{"status":{"status": 200}}',
+            reason='OK'
+        )
+        with mock.patch(
+            'airflow.hooks.base_hook.BaseHook.get_connection',
+            side_effect=get_airflow_connection
+        ):
+
+            resp = self.get_hook.run('v1/test')
+            self.assertEquals(resp.text, '{"status":{"status": 200}}')
+
+    @requests_mock.mock()
+    def test_get_request_do_not_raise_for_status_if_check_response_is_false(self, m):
+
+        m.get(
+            'http://test:8080/v1/test',
+            status_code=404,
+            text='{"status":{"status": 404}}',
+            reason='Bad request'
+        )
+
+        with mock.patch(
+            'airflow.hooks.base_hook.BaseHook.get_connection',
+            side_effect=get_airflow_connection
+        ):
+            resp = self.get_hook.run('v1/test', extra_options={'check_response': False})
+            self.assertEquals(resp.text, '{"status":{"status": 404}}')
+
+    @requests_mock.mock()
+    def test_hook_contains_header_from_extra_field(self, m):
+        with mock.patch(
+            'airflow.hooks.base_hook.BaseHook.get_connection',
+            side_effect=get_airflow_connection
+        ):
+            expected_conn = get_airflow_connection()
+            conn = self.get_hook.get_conn()
+            self.assertDictContainsSubset(json.loads(expected_conn.extra), conn.headers)
+            self.assertEquals(conn.headers.get('bareer'), 'test')
+
+    @requests_mock.mock()
+    def test_hook_uses_provided_header(self, m):
+            conn = self.get_hook.get_conn(headers={"bareer": "newT0k3n"})
+            self.assertEquals(conn.headers.get('bareer'), "newT0k3n")
+
+    @requests_mock.mock()
+    def test_hook_has_no_header_from_extra(self, m):
+            conn = self.get_hook.get_conn()
+            self.assertIsNone(conn.headers.get('bareer'))
+
+    @requests_mock.mock()
+    def test_hooks_header_from_extra_is_overridden(self, m):
+        with mock.patch(
+            'airflow.hooks.base_hook.BaseHook.get_connection',
+            side_effect=get_airflow_connection
+        ):
+            conn = self.get_hook.get_conn(headers={"bareer": "newT0k3n"})
+            self.assertEquals(conn.headers.get('bareer'), 'newT0k3n')
+
+    @requests_mock.mock()
+    def test_post_request(self, m):
+
+        m.post(
+            'http://test:8080/v1/test',
+            status_code=200,
+            text='{"status":{"status": 200}}',
+            reason='OK'
+        )
+
+        with mock.patch(
+            'airflow.hooks.base_hook.BaseHook.get_connection',
+            side_effect=get_airflow_connection
+        ):
+            resp = self.post_hook.run('v1/test')
+            self.assertEquals(resp.status_code, 200)
+
+    @requests_mock.mock()
+    def test_post_request_with_error_code(self, m):
+
+        m.post(
+            'http://test:8080/v1/test',
+            status_code=418,
+            text='{"status":{"status": 418}}',
+            reason='I\'m a teapot'
+        )
+
+        with mock.patch(
+            'airflow.hooks.base_hook.BaseHook.get_connection',
+            side_effect=get_airflow_connection
+        ):
+            with self.assertRaises(AirflowException):
+                self.post_hook.run('v1/test')
+
+    @requests_mock.mock()
+    def test_post_request_do_not_raise_for_status_if_check_response_is_false(self, m):
+
+        m.post(
+            'http://test:8080/v1/test',
+            status_code=418,
+            text='{"status":{"status": 418}}',
+            reason='I\'m a teapot'
+        )
+
+        with mock.patch(
+            'airflow.hooks.base_hook.BaseHook.get_connection',
+            side_effect=get_airflow_connection
+        ):
+            resp = self.post_hook.run('v1/test', extra_options={'check_response': False})
+            self.assertEquals(resp.status_code, 418)
+
+    @mock.patch('airflow.hooks.http_hook.requests.Session')
+    def test_retry_on_conn_error(self, mocked_session):
+
+        retry_args = dict(
+            wait=tenacity.wait_none(),
+            stop=tenacity.stop_after_attempt(7),
+            retry=requests.exceptions.ConnectionError
+        )
+
+        def send_and_raise(request, **kwargs):
+            raise requests.exceptions.ConnectionError
+
+        mocked_session().send.side_effect = send_and_raise
+        # The job failed for some reason
+        with self.assertRaises(tenacity.RetryError):
+            self.get_hook.run_with_advanced_retry(
+                endpoint='v1/test',
+                _retry_args=retry_args
+            )
+        self.assertEquals(
+            self.get_hook._retry_obj.stop.max_attempt_number + 1,
+            mocked_session.call_count
+        )
+
+    def test_header_from_extra_and_run_method_are_merged(self):
+
+        def run_and_return(session, prepped_request, extra_options, **kwargs):
+            return prepped_request
+
+        # The job failed for some reason
+        with mock.patch(
+            'airflow.hooks.http_hook.HttpHook.run_and_check',
+            side_effect=run_and_return
+        ):
+            with mock.patch(
+                'airflow.hooks.base_hook.BaseHook.get_connection',
+                side_effect=get_airflow_connection
+            ):
+                pr = self.get_hook.run('v1/test', headers={'some_other_header': 'test'})
+                actual = dict(pr.headers)
+                self.assertEquals(actual.get('bareer'), 'test')
+                self.assertEquals(actual.get('some_other_header'), 'test')
+
+    @mock.patch('airflow.hooks.http_hook.HttpHook.get_connection')
+    def test_http_connection(self, mock_get_connection):
+        c = models.Connection(conn_id='http_default', conn_type='http',
+                              host='localhost', schema='http')
+        mock_get_connection.return_value = c
+        hook = HttpHook()
+        hook.get_conn({})
+        self.assertEqual(hook.base_url, 'http://localhost')
+
+    @mock.patch('airflow.hooks.http_hook.HttpHook.get_connection')
+    def test_https_connection(self, mock_get_connection):
+        c = models.Connection(conn_id='http_default', conn_type='http',
+                              host='localhost', schema='https')
+        mock_get_connection.return_value = c
+        hook = HttpHook()
+        hook.get_conn({})
+        self.assertEqual(hook.base_url, 'https://localhost')
+
+    @mock.patch('airflow.hooks.http_hook.HttpHook.get_connection')
+    def test_host_encoded_http_connection(self, mock_get_connection):
+        c = models.Connection(conn_id='http_default', conn_type='http',
+                              host='http://localhost')
+        mock_get_connection.return_value = c
+        hook = HttpHook()
+        hook.get_conn({})
+        self.assertEqual(hook.base_url, 'http://localhost')
+
+    @mock.patch('airflow.hooks.http_hook.HttpHook.get_connection')
+    def test_host_encoded_https_connection(self, mock_get_connection):
+        c = models.Connection(conn_id='http_default', conn_type='http',
+                              host='https://localhost')
+        mock_get_connection.return_value = c
+        hook = HttpHook()
+        hook.get_conn({})
+        self.assertEqual(hook.base_url, 'https://localhost')
+
+
+send_email_test = mock.Mock()
+
+
+if __name__ == '__main__':
+    unittest.main()