You are viewing a plain-text version of this content. The canonical link for it was a hyperlink on the original archive page and is not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2018/08/29 07:26:30 UTC

[GitHub] Fokko closed pull request #3570: [AIRFLOW-2709] Improve error handling in Databricks hook

Fokko closed pull request #3570: [AIRFLOW-2709] Improve error handling in Databricks hook
URL: https://github.com/apache/incubator-airflow/pull/3570
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it after the merge:

diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py
index 54f00e0090..5b97a0eba0 100644
--- a/airflow/contrib/hooks/databricks_hook.py
+++ b/airflow/contrib/hooks/databricks_hook.py
@@ -24,6 +24,7 @@
 from airflow.hooks.base_hook import BaseHook
 from requests import exceptions as requests_exceptions
 from requests.auth import AuthBase
+from time import sleep
 
 from airflow.utils.log.logging_mixin import LoggingMixin
 
@@ -47,7 +48,8 @@ def __init__(
             self,
             databricks_conn_id='databricks_default',
             timeout_seconds=180,
-            retry_limit=3):
+            retry_limit=3,
+            retry_delay=1.0):
         """
         :param databricks_conn_id: The name of the databricks connection to use.
         :type databricks_conn_id: string
@@ -57,6 +59,9 @@ def __init__(
         :param retry_limit: The number of times to retry the connection in case of
             service outages.
         :type retry_limit: int
+        :param retry_delay: The number of seconds to wait between retries (it
+            might be a floating point number).
+        :type retry_delay: float
         """
         self.databricks_conn_id = databricks_conn_id
         self.databricks_conn = self.get_connection(databricks_conn_id)
@@ -64,6 +69,7 @@ def __init__(
         if retry_limit < 1:
             raise ValueError('Retry limit must be greater than equal to 1')
         self.retry_limit = retry_limit
+        self.retry_delay = retry_delay
 
     @staticmethod
     def _parse_host(host):
@@ -119,7 +125,8 @@ def _do_api_call(self, endpoint_info, json):
         else:
             raise AirflowException('Unexpected HTTP Method: ' + method)
 
-        for attempt_num in range(1, self.retry_limit + 1):
+        attempt_num = 1
+        while True:
             try:
                 response = request_func(
                     url,
@@ -127,21 +134,29 @@ def _do_api_call(self, endpoint_info, json):
                     auth=auth,
                     headers=USER_AGENT_HEADER,
                     timeout=self.timeout_seconds)
-                if response.status_code == requests.codes.ok:
-                    return response.json()
-                else:
+                response.raise_for_status()
+                return response.json()
+            except requests_exceptions.RequestException as e:
+                if not _retryable_error(e):
                     # In this case, the user probably made a mistake.
                     # Don't retry.
                     raise AirflowException('Response: {0}, Status Code: {1}'.format(
-                        response.content, response.status_code))
-            except (requests_exceptions.ConnectionError,
-                    requests_exceptions.Timeout) as e:
-                self.log.error(
-                    'Attempt %s API Request to Databricks failed with reason: %s',
-                    attempt_num, e
-                )
-        raise AirflowException(('API requests to Databricks failed {} times. ' +
-                               'Giving up.').format(self.retry_limit))
+                        e.response.content, e.response.status_code))
+
+                self._log_request_error(attempt_num, e)
+
+            if attempt_num == self.retry_limit:
+                raise AirflowException(('API requests to Databricks failed {} times. ' +
+                                        'Giving up.').format(self.retry_limit))
+
+            attempt_num += 1
+            sleep(self.retry_delay)
+
+    def _log_request_error(self, attempt_num, error):
+        self.log.error(
+            'Attempt %s API Request to Databricks failed with reason: %s',
+            attempt_num, error
+        )
 
     def submit_run(self, json):
         """
@@ -175,6 +190,12 @@ def cancel_run(self, run_id):
         self._do_api_call(CANCEL_RUN_ENDPOINT, json)
 
 
+def _retryable_error(exception):
+    return isinstance(exception, requests_exceptions.ConnectionError) \
+        or isinstance(exception, requests_exceptions.Timeout) \
+        or exception.response is not None and exception.response.status_code >= 500
+
+
 RUN_LIFE_CYCLE_STATES = [
     'PENDING',
     'RUNNING',
diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py
index 7b8d522dba..3245a99256 100644
--- a/airflow/contrib/operators/databricks_operator.py
+++ b/airflow/contrib/operators/databricks_operator.py
@@ -146,6 +146,9 @@ class DatabricksSubmitRunOperator(BaseOperator):
     :param databricks_retry_limit: Amount of times retry if the Databricks backend is
         unreachable. Its value must be greater than or equal to 1.
     :type databricks_retry_limit: int
+    :param databricks_retry_delay: Number of seconds to wait between retries (it
+            might be a floating point number).
+    :type databricks_retry_delay: float
     :param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
     :type do_xcom_push: boolean
     """
@@ -168,6 +171,7 @@ def __init__(
             databricks_conn_id='databricks_default',
             polling_period_seconds=30,
             databricks_retry_limit=3,
+            databricks_retry_delay=1,
             do_xcom_push=False,
             **kwargs):
         """
@@ -178,6 +182,7 @@ def __init__(
         self.databricks_conn_id = databricks_conn_id
         self.polling_period_seconds = polling_period_seconds
         self.databricks_retry_limit = databricks_retry_limit
+        self.databricks_retry_delay = databricks_retry_delay
         if spark_jar_task is not None:
             self.json['spark_jar_task'] = spark_jar_task
         if notebook_task is not None:
@@ -232,7 +237,8 @@ def _log_run_page_url(self, url):
     def get_hook(self):
         return DatabricksHook(
             self.databricks_conn_id,
-            retry_limit=self.databricks_retry_limit)
+            retry_limit=self.databricks_retry_limit,
+            retry_delay=self.databricks_retry_delay)
 
     def execute(self, context):
         hook = self.get_hook()
diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py
index aca8dd9600..a022431899 100644
--- a/tests/contrib/hooks/test_databricks_hook.py
+++ b/tests/contrib/hooks/test_databricks_hook.py
@@ -18,15 +18,21 @@
 # under the License.
 #
 
+import itertools
 import json
 import unittest
 
+from requests import exceptions as requests_exceptions
+
 from airflow import __version__
-from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT, _TokenAuth
+from airflow.contrib.hooks.databricks_hook import (
+    DatabricksHook,
+    RunState,
+    SUBMIT_RUN_ENDPOINT
+)
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
 from airflow.utils import db
-from requests import exceptions as requests_exceptions
 
 try:
     from unittest import mock
@@ -79,12 +85,48 @@ def get_run_endpoint(host):
     """
     return 'https://{}/api/2.0/jobs/runs/get'.format(host)
 
+
 def cancel_run_endpoint(host):
     """
     Utility function to generate the get run endpoint given the host.
     """
     return 'https://{}/api/2.0/jobs/runs/cancel'.format(host)
 
+
+def create_valid_response_mock(content):
+    response = mock.MagicMock()
+    response.json.return_value = content
+    return response
+
+
+def create_post_side_effect(exception, status_code=500):
+    if exception != requests_exceptions.HTTPError:
+        return exception()
+    else:
+        response = mock.MagicMock()
+        response.status_code = status_code
+        response.raise_for_status.side_effect = exception(response=response)
+        return response
+
+
+def setup_mock_requests(
+        mock_requests,
+        exception,
+        status_code=500,
+        error_count=None,
+        response_content=None):
+
+    side_effect = create_post_side_effect(exception, status_code)
+
+    if error_count is None:
+        # POST requests will fail indefinitely
+        mock_requests.post.side_effect = itertools.repeat(side_effect)
+    else:
+        # POST requests will fail 'error_count' times, and then they will succeed (once)
+        mock_requests.post.side_effect = \
+            [side_effect] * error_count + [create_valid_response_mock(response_content)]
+
+
 class DatabricksHookTest(unittest.TestCase):
     """
     Tests for DatabricksHook.
@@ -99,7 +141,7 @@ def setUp(self, session=None):
         conn.password = PASSWORD
         session.commit()
 
-        self.hook = DatabricksHook()
+        self.hook = DatabricksHook(retry_delay=0)
 
     def test_parse_host_with_proper_host(self):
         host = self.hook._parse_host(HOST)
@@ -111,34 +153,85 @@ def test_parse_host_with_scheme(self):
 
     def test_init_bad_retry_limit(self):
         with self.assertRaises(ValueError):
-            DatabricksHook(retry_limit = 0)
-
-    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
-    def test_do_api_call_with_error_retry(self, mock_requests):
-        for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]:
-            with mock.patch.object(self.hook.log, 'error') as mock_errors:
-                mock_requests.reset_mock()
-                mock_requests.post.side_effect = exception()
+            DatabricksHook(retry_limit=0)
+
+    def test_do_api_call_retries_with_retryable_error(self):
+        for exception in [
+                requests_exceptions.ConnectionError,
+                requests_exceptions.SSLError,
+                requests_exceptions.Timeout,
+                requests_exceptions.ConnectTimeout,
+                requests_exceptions.HTTPError]:
+            with mock.patch(
+                'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \
+                    mock.patch.object(self.hook.log, 'error') as mock_errors:
+                setup_mock_requests(mock_requests, exception)
 
                 with self.assertRaises(AirflowException):
                     self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
 
-                self.assertEquals(len(mock_errors.mock_calls), self.hook.retry_limit)
+                self.assertEquals(mock_errors.call_count, self.hook.retry_limit)
 
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
-    def test_do_api_call_with_bad_status_code(self, mock_requests):
-        mock_requests.codes.ok = 200
-        status_code_mock = mock.PropertyMock(return_value=500)
-        type(mock_requests.post.return_value).status_code = status_code_mock
-        with self.assertRaises(AirflowException):
-            self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+    def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests):
+        setup_mock_requests(
+            mock_requests, requests_exceptions.HTTPError, status_code=400
+        )
+
+        with mock.patch.object(self.hook.log, 'error') as mock_errors:
+            with self.assertRaises(AirflowException):
+                self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+
+            mock_errors.assert_not_called()
+
+    def test_do_api_call_succeeds_after_retrying(self):
+        for exception in [
+                requests_exceptions.ConnectionError,
+                requests_exceptions.SSLError,
+                requests_exceptions.Timeout,
+                requests_exceptions.ConnectTimeout,
+                requests_exceptions.HTTPError]:
+            with mock.patch(
+                'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \
+                    mock.patch.object(self.hook.log, 'error') as mock_errors:
+                setup_mock_requests(
+                    mock_requests,
+                    exception,
+                    error_count=2,
+                    response_content={'run_id': '1'}
+                )
+
+                response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+
+                self.assertEquals(mock_errors.call_count, 2)
+                self.assertEquals(response, {'run_id': '1'})
+
+    @mock.patch('airflow.contrib.hooks.databricks_hook.sleep')
+    def test_do_api_call_waits_between_retries(self, mock_sleep):
+        retry_delay = 5
+        self.hook = DatabricksHook(retry_delay=retry_delay)
+
+        for exception in [
+                requests_exceptions.ConnectionError,
+                requests_exceptions.SSLError,
+                requests_exceptions.Timeout,
+                requests_exceptions.ConnectTimeout,
+                requests_exceptions.HTTPError]:
+            with mock.patch(
+                'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \
+                    mock.patch.object(self.hook.log, 'error'):
+                mock_sleep.reset_mock()
+                setup_mock_requests(mock_requests, exception)
+
+                with self.assertRaises(AirflowException):
+                    self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+
+                self.assertEquals(len(mock_sleep.mock_calls), self.hook.retry_limit - 1)
+                mock_sleep.assert_called_with(retry_delay)
 
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
     def test_submit_run(self, mock_requests):
-        mock_requests.codes.ok = 200
         mock_requests.post.return_value.json.return_value = {'run_id': '1'}
-        status_code_mock = mock.PropertyMock(return_value=200)
-        type(mock_requests.post.return_value).status_code = status_code_mock
         json = {
           'notebook_task': NOTEBOOK_TASK,
           'new_cluster': NEW_CLUSTER
@@ -158,10 +251,7 @@ def test_submit_run(self, mock_requests):
 
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
     def test_get_run_page_url(self, mock_requests):
-        mock_requests.codes.ok = 200
         mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
-        status_code_mock = mock.PropertyMock(return_value=200)
-        type(mock_requests.get.return_value).status_code = status_code_mock
 
         run_page_url = self.hook.get_run_page_url(RUN_ID)
 
@@ -175,10 +265,7 @@ def test_get_run_page_url(self, mock_requests):
 
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
     def test_get_run_state(self, mock_requests):
-        mock_requests.codes.ok = 200
         mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
-        status_code_mock = mock.PropertyMock(return_value=200)
-        type(mock_requests.get.return_value).status_code = status_code_mock
 
         run_state = self.hook.get_run_state(RUN_ID)
 
@@ -195,10 +282,7 @@ def test_get_run_state(self, mock_requests):
 
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
     def test_cancel_run(self, mock_requests):
-        mock_requests.codes.ok = 200
         mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE
-        status_code_mock = mock.PropertyMock(return_value=200)
-        type(mock_requests.post.return_value).status_code = status_code_mock
 
         self.hook.cancel_run(RUN_ID)
 
diff --git a/tests/contrib/operators/test_databricks_operator.py b/tests/contrib/operators/test_databricks_operator.py
index f77da2ec18..afe1a92f28 100644
--- a/tests/contrib/operators/test_databricks_operator.py
+++ b/tests/contrib/operators/test_databricks_operator.py
@@ -190,8 +190,9 @@ def test_exec_success(self, db_mock_class):
           'run_name': TASK_ID
         })
         db_mock_class.assert_called_once_with(
-                DEFAULT_CONN_ID,
-                retry_limit=op.databricks_retry_limit)
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay)
         db_mock.submit_run.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
         db_mock.get_run_state.assert_called_once_with(RUN_ID)
@@ -220,8 +221,9 @@ def test_exec_failure(self, db_mock_class):
           'run_name': TASK_ID,
         })
         db_mock_class.assert_called_once_with(
-                DEFAULT_CONN_ID,
-                retry_limit=op.databricks_retry_limit)
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay)
         db_mock.submit_run.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
         db_mock.get_run_state.assert_called_once_with(RUN_ID)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services