You are viewing a plain text version of this content. The canonical link for it is available in the original HTML version (the hyperlink is not preserved in this plain-text copy).
Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2018/12/17 23:52:50 UTC

[GitHub] stale[bot] closed pull request #2932: [AIRFLOW-1974] Improve Databricks Hook/Operator

stale[bot] closed pull request #2932: [AIRFLOW-1974] Improve Databricks Hook/Operator
URL: https://github.com/apache/incubator-airflow/pull/2932
 
 
   

This is a PR from a forked repository; it was closed without being merged.
As GitHub hides the original diff once such a PR is closed, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as GitHub would not otherwise show it here):

diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py
index 54f00e0090..77a97a53d2 100644
--- a/airflow/contrib/hooks/databricks_hook.py
+++ b/airflow/contrib/hooks/databricks_hook.py
@@ -19,43 +19,69 @@
 #
 import requests
 
+from collections import namedtuple
+from requests.exceptions import ConnectionError, Timeout
+from requests.auth import AuthBase
+
 from airflow import __version__
 from airflow.exceptions import AirflowException
 from airflow.hooks.base_hook import BaseHook
-from requests import exceptions as requests_exceptions
-from requests.auth import AuthBase
-
-from airflow.utils.log.logging_mixin import LoggingMixin
 
 try:
-    from urllib import parse as urlparse
+    from urllib.parse import urlparse
 except ImportError:
-    import urlparse
-
+    from urlparse import urlparse
 
-SUBMIT_RUN_ENDPOINT = ('POST', 'api/2.0/jobs/runs/submit')
-GET_RUN_ENDPOINT = ('GET', 'api/2.0/jobs/runs/get')
-CANCEL_RUN_ENDPOINT = ('POST', 'api/2.0/jobs/runs/cancel')
 USER_AGENT_HEADER = {'user-agent': 'airflow-{v}'.format(v=__version__)}
+DEFAULT_API_VERSION = '2.0'
+
+Endpoint = namedtuple('Endpoint', ['http_method', 'path', 'method'])
 
 
-class DatabricksHook(BaseHook, LoggingMixin):
+class DatabricksHook(BaseHook):
     """
     Interact with Databricks.
     """
-    def __init__(
-            self,
-            databricks_conn_id='databricks_default',
-            timeout_seconds=180,
-            retry_limit=3):
+
+    API = {
+        # API V2.0
+        #   JOBS API
+        # '2.0/jobs/create': Endpoint('POST', '2.0/jobs/create', ''),
+        # '2.0/jobs/list': Endpoint('GET', '2.0/jobs/list', ''),
+        # '2.0/jobs/delete': Endpoint('POST', '2.0/jobs/delete', ''),
+        # '2.0/jobs/get': Endpoint('GET', '2.0/jobs/get', ''),
+        # '2.0/jobs/reset': Endpoint('POST', '2.0/jobs/reset', ''),
+        '2.0/jobs/run-now': Endpoint('POST', '2.0/jobs/run-now', ''),
+        '2.0/jobs/runs/submit': Endpoint('POST', '2.0/jobs/runs/submit',
+                                         'jobs_runs_submit'),
+        # '2.0/jobs/runs/list': Endpoint('GET', '2.0/jobs/runs/list', ''),
+        '2.0/jobs/runs/get': Endpoint('GET', '2.0/jobs/runs/get',
+                                      'jobs_runs_get'),
+        # '2.0/jobs/runs/export': Endpoint('GET', '2.0/jobs/runs/export', ''),
+        '2.0/jobs/runs/cancel': Endpoint('POST', '2.0/jobs/runs/cancel',
+                                         'jobs_runs_cancel')
+        # '2.0/jobs/runs/get-output': Endpoint('GET',
+        #                                      '2.0/jobs/runs/get-output', '')
+    }
+    # TODO: https://docs.databricks.com/api/latest/index.html
+    # TODO: https://docs.databricks.com/api/latest/dbfs.html
+    # TODO: https://docs.databricks.com/api/latest/groups.html
+    # TODO: https://docs.databricks.com/api/latest/instance-profiles.html
+    # TODO: https://docs.databricks.com/api/latest/libraries.html
+    # TODO: https://docs.databricks.com/api/latest/tokens.html
+    # TODO: https://docs.databricks.com/api/latest/workspace.html
+
+    def __init__(self, databricks_conn_id='databricks_default',
+                 timeout_seconds=180, retry_limit=3):
         """
-        :param databricks_conn_id: The name of the databricks connection to use.
+        :param databricks_conn_id: The name of the databricks connection to
+            use.
         :type databricks_conn_id: string
-        :param timeout_seconds: The amount of time in seconds the requests library
-            will wait before timing-out.
+        :param timeout_seconds: The amount of time in seconds the requests
+            library will wait before timing-out.
         :type timeout_seconds: int
-        :param retry_limit: The number of times to retry the connection in case of
-            service outages.
+        :param retry_limit: The number of times to retry the connection in
+            case of service outages.
         :type retry_limit: int
         """
         self.databricks_conn_id = databricks_conn_id
@@ -67,34 +93,32 @@ def __init__(
 
     @staticmethod
     def _parse_host(host):
-        """
+        """Verify connection host setting provided by the user.
+
         The purpose of this function is to be robust to improper connections
         settings provided by users, specifically in the host field.
 
+        For example -- when users supply ``https://xx.cloud.databricks.com`` as
+        the host, we must strip out the protocol to get the host.
 
-        For example -- when users supply ``https://xx.cloud.databricks.com`` as the
-        host, we must strip out the protocol to get the host.
+        >>> host = 'https://xx.cloud.databricks.com'
         >>> h = DatabricksHook()
-        >>> assert h._parse_host('https://xx.cloud.databricks.com') == \
-            'xx.cloud.databricks.com'
+        >>> assert h._parse_host(host) == 'xx.cloud.databricks.com'
 
-        In the case where users supply the correct ``xx.cloud.databricks.com`` as the
+        In the case where users supply the correct ``xx.cloud.databricks.com``
         host, this function is a no-op.
-        >>> assert h._parse_host('xx.cloud.databricks.com') == 'xx.cloud.databricks.com'
-        """
-        urlparse_host = urlparse.urlparse(host).hostname
-        if urlparse_host:
-            # In this case, host = https://xx.cloud.databricks.com
-            return urlparse_host
-        else:
-            # In this case, host = xx.cloud.databricks.com
-            return host
 
-    def _do_api_call(self, endpoint_info, json):
+        >>> host = 'xx.cloud.databricks.com'
+        >>> assert h._parse_host(host) == 'xx.cloud.databricks.com'
+
         """
-        Utility function to perform an API call with retries
-        :param endpoint_info: Tuple of method and endpoint
-        :type endpoint_info: (string, string)
+        return urlparse(host).hostname or host
+
+    def _do_api_call(self, endpoint, json):
+        """Perform an API call with retries.
+
+        :param endpoint_info: Instance of Endpoint with http_method and path.
+        :type endpoint_info: Endpoint
         :param json: Parameters for this API call.
         :type json: dict
         :return: If the api call returns a OK status code,
@@ -102,24 +126,28 @@ def _do_api_call(self, endpoint_info, json):
             we throw an AirflowException.
         :rtype: dict
         """
-        method, endpoint = endpoint_info
-        url = 'https://{host}/{endpoint}'.format(
+        if isinstance(endpoint, str):
+            endpoint = self.API[endpoint]
+        url = 'https://{host}/api/{endpoint}'.format(
             host=self._parse_host(self.databricks_conn.host),
-            endpoint=endpoint)
+            endpoint=endpoint.path)
+        self.log.info('Calling endpoint (%s) with url: %s', endpoint, url)
         if 'token' in self.databricks_conn.extra_dejson:
             self.log.info('Using token auth.')
             auth = _TokenAuth(self.databricks_conn.extra_dejson['token'])
         else:
             self.log.info('Using basic auth.')
             auth = (self.databricks_conn.login, self.databricks_conn.password)
-        if method == 'GET':
+        if endpoint.http_method == 'GET':
             request_func = requests.get
-        elif method == 'POST':
+        elif endpoint.http_method == 'POST':
             request_func = requests.post
         else:
-            raise AirflowException('Unexpected HTTP Method: ' + method)
+            raise AirflowException('Unexpected HTTP Method: {}'.format(
+                endpoint.http_method
+            ))
 
-        for attempt_num in range(1, self.retry_limit + 1):
+        for attempt in range(1, self.retry_limit + 1):
             try:
                 response = request_func(
                     url,
@@ -132,79 +160,119 @@ def _do_api_call(self, endpoint_info, json):
                 else:
                     # In this case, the user probably made a mistake.
                     # Don't retry.
-                    raise AirflowException('Response: {0}, Status Code: {1}'.format(
-                        response.content, response.status_code))
-            except (requests_exceptions.ConnectionError,
-                    requests_exceptions.Timeout) as e:
-                self.log.error(
-                    'Attempt %s API Request to Databricks failed with reason: %s',
-                    attempt_num, e
-                )
-        raise AirflowException(('API requests to Databricks failed {} times. ' +
-                               'Giving up.').format(self.retry_limit))
-
-    def submit_run(self, json):
-        """
-        Utility function to call the ``api/2.0/jobs/runs/submit`` endpoint.
+                    msg = 'Response: {}, Status Code: {}'
+                    raise AirflowException(msg.format(response.content,
+                                                      response.status_code))
+            except (ConnectionError, Timeout) as exception:
+                msg = 'Attempt {} API Request to Databricks failed with'
+                msg += 'reason: {}'
+                self.log.error(msg.format(attempt, exception))
+
+        msg = 'API requests to Databricks failed {} times. Giving up.'
+        raise AirflowException(msg.format(self.retry_limit))
 
-        :param json: The data used in the body of the request to the ``submit`` endpoint.
+    def get_api_method(self, endpoint):
+        """Retrieve the method that handle one endpoint."""
+        try:
+            return getattr(self, self.API[endpoint].method)
+        except AttributeError:
+            AirflowException('Endpoint handler %s not yet implemented.',
+                             endpoint)
+
+    def jobs_runs_submit(self, json):
+        """Call the ``api/2.0/jobs/runs/submit`` endpoint.
+
+        :param json: The data used in the body of the request to the
+            ``submit`` endpoint.
         :type json: dict
-        :return: the run_id as a string
-        :rtype: string
+        :return: A dict with the run_id as a string
+        :rtype: dict
+        """
+        return self._do_api_call(self.API['2.0/jobs/runs/submit'], json)
+
+    def jobs_runs_get(self, run_id):
+        """Call the ``2.0/jobs/runs/get`` endpoint.
+
+        Retrieves the metadata of a run.
+
+        :param run_id: The id of a run you wish to get the data about.
+        :type run_id: int or string
+        :return: a dict with the metadata from a run (endpoint json response)
         """
-        response = self._do_api_call(SUBMIT_RUN_ENDPOINT, json)
-        return response['run_id']
+        return self._do_api_call(self.API['2.0/jobs/runs/get'],
+                                 {'run_id': run_id})
+
+    def jobs_runs_cancel(self, run_id):
+        """Call the ``2.0/jobs/runs/cancel`` endpoint.
+
+        Send a CANCEL command to a specific run.
+
+        This method does not verify if the run was actually cancelled.
+
+        :param  run_id: The id of the RUN.
+        :type run_id: int
+        """
+        self._do_api_call(self.API['2.0/jobs/runs/cancel'],
+                          {'run_id': run_id})
 
     def get_run_page_url(self, run_id):
-        json = {'run_id': run_id}
-        response = self._do_api_call(GET_RUN_ENDPOINT, json)
-        return response['run_page_url']
+        """Get the webpage url to follow the run status.
+
+        :param run_id: The id of the RUN.
+        :type run_id: int
+        :return: A URL, such as:
+            https://<user>.cloud.databricks.com/#job/1234/run/2
+        :rtype: string
+        """
+        return self.jobs_runs_get(run_id).get('run_page_url')
 
     def get_run_state(self, run_id):
-        json = {'run_id': run_id}
-        response = self._do_api_call(GET_RUN_ENDPOINT, json)
-        state = response['state']
-        life_cycle_state = state['life_cycle_state']
+        """Get the state of a RUN.
+
+        :param  run_id: The id of the RUN.
+        :type run_id: int
+        :return: A RunState object with the run state information.
+        :rtype: RunState
+        """
+        response = self.jobs_runs_get(run_id)
+        state = response.get('state')
+        life_cycle_state = state.get('life_cycle_state')
         # result_state may not be in the state if not terminal
         result_state = state.get('result_state', None)
-        state_message = state['state_message']
+        state_message = state.get('state_message')
         return RunState(life_cycle_state, result_state, state_message)
 
-    def cancel_run(self, run_id):
-        json = {'run_id': run_id}
-        self._do_api_call(CANCEL_RUN_ENDPOINT, json)
-
-
-RUN_LIFE_CYCLE_STATES = [
-    'PENDING',
-    'RUNNING',
-    'TERMINATING',
-    'TERMINATED',
-    'SKIPPED',
-    'INTERNAL_ERROR'
-]
-
 
 class RunState:
     """
     Utility class for the run state concept of Databricks runs.
     """
+
+    STATES = [
+        'PENDING',
+        'RUNNING',
+        'TERMINATING',
+        'TERMINATED',
+        'SKIPPED',
+        'INTERNAL_ERROR'
+    ]
+
+    TERMINAL_STATES = ('TERMINATED', 'SKIPPED', 'INTERNAL_ERROR')
+
     def __init__(self, life_cycle_state, result_state, state_message):
         self.life_cycle_state = life_cycle_state
         self.result_state = result_state
         self.state_message = state_message
 
-    @property
     def is_terminal(self):
-        if self.life_cycle_state not in RUN_LIFE_CYCLE_STATES:
-            raise AirflowException(
-                ('Unexpected life cycle state: {}: If the state has '
-                 'been introduced recently, please check the Databricks user '
-                 'guide for troubleshooting information').format(
-                    self.life_cycle_state))
-        return self.life_cycle_state in ('TERMINATED', 'SKIPPED', 'INTERNAL_ERROR')
-
-    @property
+        """Check wether the terminal is a terminal state."""
+        if self.life_cycle_state not in RunState.STATES:
+            msg = 'Unexpected life cycle state: {}: If the state has been '
+            msg += 'introduced recently, please check the Databricks user '
+            msg += 'guide for troubleshooting information'
+            raise AirflowException(msg.format(self.life_cycle_state))
+        return self.life_cycle_state in RunState.TERMINAL_STATES
+
     def is_successful(self):
         return self.result_state == 'SUCCESS'
 
@@ -218,9 +286,9 @@ def __repr__(self):
 
 
 class _TokenAuth(AuthBase):
-    """
-    Helper class for requests Auth field. AuthBase requires you to implement the __call__
-    magic function.
+    """Helper class for requests Auth field.
+
+    AuthBase requires you to implement the __call__ magic function.
     """
     def __init__(self, token):
         self.token = token
diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py
index 7b8d522dba..1be69801aa 100644
--- a/airflow/contrib/operators/databricks_operator.py
+++ b/airflow/contrib/operators/databricks_operator.py
@@ -18,11 +18,12 @@
 # under the License.
 #
 
-import six
 import time
 
-from airflow.exceptions import AirflowException
+import six
+
 from airflow.contrib.hooks.databricks_hook import DatabricksHook
+from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 
 
@@ -30,8 +31,165 @@
 XCOM_RUN_PAGE_URL_KEY = 'run_page_url'
 
 
-class DatabricksSubmitRunOperator(BaseOperator):
-    """
+class DatabricksOperator(BaseOperator):
+    """Talks to Databricks (Spark) via REST API."""
+
+    # Used in airflow.models.BaseOperator
+    template_fields = ('json',)
+    # Databricks brand color (blue) under white text
+    ui_color = '#1CB1C2'
+    ui_fgcolor = '#fff'
+
+    def __init__(self, endpoint, json=None, run_name=None,
+                 timeout_seconds=None, databricks_conn_id='databricks_default',
+                 polling_period_seconds=30, databricks_retry_limit=3,
+                 wait_run_end=True, do_xcom_push=False, **kwargs):
+        """Create a new ``DatabricksOperator``.
+
+        :param endpoint: A string representing the endpoint you wish to
+            interact with. One of the options from DatabricksHook.API.
+            Ex: '2.0/jobs/runs/get'
+        :type action: string
+        :param json: A JSON object containing the API parameters to be passed
+            directly to the requested endpoint relative to the given
+            ``action``. The ``timeout_seconds`` and ``run_name`` named
+            arguments will be merged onto this json dictionary if they are
+            provided. If specified, the named parameters will always take
+            precedence and override the top level json keys. This field will be
+            templated.
+
+            .. seealso::
+                For more information about templating see
+                :ref:`jinja-templating`.
+                https://docs.databricks.com/api/latest/jobs.html#runs-submit
+
+        :type json: dict
+        :param run_name: The run name used for this task.
+            By default this will be set to the Airflow ``task_id``. This
+            ``task_id`` is a required parameter of the superclass
+            ``BaseOperator``.
+            This field will be templated.
+        :type run_name: string
+        :param timeout_seconds: The timeout for this run. By default a value of
+            0 is used which means to have no timeout.
+            This field will be templated.
+        :type timeout_seconds: int
+        :param databricks_conn_id: The name of the Airflow connection to use.
+            By default and in the common case this will be
+            ``databricks_default``.
+        :type databricks_conn_id: string
+        :param polling_period_seconds: Controls the rate which we poll for the
+            result of this run. By default the operator will poll every 30
+            seconds.
+        :type polling_period_seconds: int
+        :param databricks_retry_limit: Amount of times retry if the Databricks
+            backend is unreachable. Its value must be greater than or equal to
+            1.
+        :type databricks_retry_limit: int
+        :param wait_run_end: If we should get the 'run_id' from the initial
+            ``action`` and wait until it reaches a terminal state.
+        :param do_xcom_push: Whether we should push run_id and run_page_url to
+            xcom.
+        :type do_xcom_push: boolean
+        """
+        super(DatabricksOperator, self).__init__(**kwargs)
+        self.json = json or {}
+        self.databricks_conn_id = databricks_conn_id
+        self.polling_period_seconds = polling_period_seconds
+        self.databricks_retry_limit = databricks_retry_limit
+        self.wait_run_end = wait_run_end
+        self.hook = self._get_hook()
+        self.action = self.hook.get_api_method(endpoint)
+        self.do_xcom_push = do_xcom_push
+        if timeout_seconds is not None:
+            self.json['timeout_seconds'] = timeout_seconds
+        if 'run_name' not in self.json:
+            self.json['run_name'] = run_name or kwargs.get('task_id', None)
+
+        self.json = self._deep_string_coerce(self.json)
+
+        self.run_id = None
+        self._run_page_url = None
+
+    def _deep_string_coerce(self, content, json_path='json'):
+        """Coerce content or all values of content if it is a dict to a string.
+
+        The function will throw if content contains non-string or non-numeric
+        types.
+
+        The reason why we have this function is because the ``self.json`` field
+        must be a dict with only string values. This is because
+        ``render_template`` will fail for numerical values.
+        """
+        coerced = self._deep_string_coerce
+        if isinstance(content, six.string_types):
+            return content
+        elif isinstance(content, six.integer_types + (float,)):
+            # Databricks can tolerate either numeric or string types in the API
+            # backend.
+            return str(content)
+        elif isinstance(content, (list, tuple)):
+            return [coerced(e, '{0}[{1}]'.format(json_path, i))
+                    for i, e in enumerate(content)]
+        elif isinstance(content, dict):
+            return {k: coerced(v, '{0}[{1}]'.format(json_path, k))
+                    for k, v in list(content.items())}
+        msg = 'Type {} used for parameter {} is not a number or a string'
+        raise AirflowException(msg.format(type(content), json_path))
+
+    def _get_hook(self):
+        return DatabricksHook(self.databricks_conn_id,
+                              retry_limit=self.databricks_retry_limit)
+
+    def _log_run_page_url(self):
+        msg = 'View run status, Spark UI, and logs at %s'
+        self.log.info(msg, self._run_page_url)
+
+    def _run_until_terminal_state(self):
+        """Keep running until the run reaches a terminal state."""
+        while True:
+            run_state = self.hook.get_run_state(self.run_id)
+            if run_state.is_terminal():
+                if run_state.is_successful():
+                    self.log.info('%s completed successfully.', self.task_id)
+                    self._log_run_page_url()
+                    return
+                msg = '{t} failed with terminal state: {s}'
+                raise AirflowException(msg.format(t=self.task_id, s=run_state))
+            self.log.info('%s in run state: %s', self.task_id, run_state)
+            self._log_run_page_url()
+            self.log.info('Sleeping for %s seconds.',
+                          self.polling_period_seconds)
+            time.sleep(self.polling_period_seconds)
+
+    def execute(self, context):
+        response = self.action(self.json)
+        if 'run_id' in response:
+            self.run_id = response.get('run_id')
+            if self.do_xcom_push:
+                context['ti'].xcom_push(key=XCOM_RUN_ID_KEY, value=self.run_id)
+
+            self._run_page_url = self.hook.get_run_page_url(self.run_id)
+            if self.do_xcom_push:
+                context['ti'].xcom_push(key=XCOM_RUN_PAGE_URL_KEY,
+                                        value=self._run_page_url)
+
+            self.log.info('Run submitted with run_id: %s', self.run_id)
+            self.log.info('View run status, Spark UI, and logs at %s',
+                          self._run_page_url)
+
+            if self.wait_run_end:
+                self._run_until_terminal_state()
+
+    def on_kill(self):
+        self.hook.cancel_run(self.run_id)
+        msg = 'Task: {t} with run_id: {r} was requested to be cancelled.'
+        self.log.info(msg, t=self.task_id, r=self.run_id)
+
+
+class DatabricksSubmitRunOperator(DatabricksOperator):
+    """Post to the job run Databricks endpoint.
+
     Submits an Spark job run to Databricks using the
     `api/2.0/jobs/runs/submit
     <https://docs.databricks.com/api/latest/jobs.html#runs-submit>`_
@@ -52,7 +210,8 @@ class DatabricksSubmitRunOperator(BaseOperator):
             'notebook_path': '/Users/airflow@example.com/PrepareData',
           },
         }
-        notebook_run = DatabricksSubmitRunOperator(task_id='notebook_run', json=json)
+        notebook_run = DatabricksSubmitRunOperator(task_id='notebook_run',
+                                                   json=json)
 
     Another way to accomplish the same thing is to use the named parameters
     of the ``DatabricksSubmitRunOperator`` directly. Note that there is exactly
@@ -70,11 +229,13 @@ class DatabricksSubmitRunOperator(BaseOperator):
             new_cluster=new_cluster,
             notebook_task=notebook_task)
 
-    In the case where both the json parameter **AND** the named parameters
-    are provided, they will be merged together. If there are conflicts during the merge,
-    the named parameters will take precedence and override the top level ``json`` keys.
+    In the case where both the json parameter **AND** the named parameters are
+    provided, they will be merged together. If there are conflicts during the
+    merge, the named parameters will take precedence and override the top level
+    ``json`` keys.
 
-    Currently the named parameters that ``DatabricksSubmitRunOperator`` supports are
+    Currently the named parameters that ``DatabricksSubmitRunOperator``
+    supports are:
         - ``spark_jar_task``
         - ``notebook_task``
         - ``new_cluster``
@@ -84,41 +245,42 @@ class DatabricksSubmitRunOperator(BaseOperator):
         - ``timeout_seconds``
 
     :param json: A JSON object containing API parameters which will be passed
-        directly to the ``api/2.0/jobs/runs/submit`` endpoint. The other named parameters
-        (i.e. ``spark_jar_task``, ``notebook_task``..) to this operator will
-        be merged with this json dictionary if they are provided.
-        If there are conflicts during the merge, the named parameters will
-        take precedence and override the top level json keys. (templated)
+        directly to the ``api/2.0/jobs/runs/submit`` endpoint. The other named
+        parameters (i.e. ``spark_jar_task``, ``notebook_task``..) to this
+        operator will be merged with this json dictionary if they are provided.
+        If there are conflicts during the merge, the named parameters will take
+        precedence and override the top level json keys. This field will be
+        templated.
 
         .. seealso::
             For more information about templating see :ref:`jinja-templating`.
             https://docs.databricks.com/api/latest/jobs.html#runs-submit
     :type json: dict
-    :param spark_jar_task: The main class and parameters for the JAR task. Note that
-        the actual JAR is specified in the ``libraries``.
+    :param spark_jar_task: The main class and parameters for the JAR task. Note
+        that the actual JAR is specified in the ``libraries``.
         *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` should be specified.
         This field will be templated.
 
         .. seealso::
             https://docs.databricks.com/api/latest/jobs.html#jobssparkjartask
     :type spark_jar_task: dict
-    :param notebook_task: The notebook path and parameters for the notebook task.
-        *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` should be specified.
-        This field will be templated.
+    :param notebook_task: The notebook path and parameters for the notebook
+        task. *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` should be
+        specified. This field will be templated.
 
         .. seealso::
             https://docs.databricks.com/api/latest/jobs.html#jobsnotebooktask
     :type notebook_task: dict
     :param new_cluster: Specs for a new cluster on which this task will be run.
-        *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified.
-        This field will be templated.
+        *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be
+        specified. This field will be templated.
 
         .. seealso::
             https://docs.databricks.com/api/latest/jobs.html#jobsclusterspecnewcluster
     :type new_cluster: dict
-    :param existing_cluster_id: ID for existing cluster on which to run this task.
-        *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified.
-        This field will be templated.
+    :param existing_cluster_id: ID for existing cluster on which to run this
+        task. *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be
+        specified. This field will be templated.
     :type existing_cluster_id: string
     :param libraries: Libraries which this run will use.
         This field will be templated.
@@ -127,145 +289,29 @@ class DatabricksSubmitRunOperator(BaseOperator):
             https://docs.databricks.com/api/latest/libraries.html#managedlibrarieslibrary
     :type libraries: list of dicts
     :param run_name: The run name used for this task.
-        By default this will be set to the Airflow ``task_id``. This ``task_id`` is a
-        required parameter of the superclass ``BaseOperator``.
+        By default this will be set to the Airflow ``task_id``. This
+        ``task_id`` is a required parameter of the superclass ``BaseOperator``.
         This field will be templated.
     :type run_name: string
-    :param timeout_seconds: The timeout for this run. By default a value of 0 is used
-        which means to have no timeout.
-        This field will be templated.
-    :type timeout_seconds: int32
-    :param databricks_conn_id: The name of the Airflow connection to use.
-        By default and in the common case this will be ``databricks_default``. To use
-        token based authentication, provide the key ``token`` in the extra field for the
-        connection.
-    :type databricks_conn_id: string
-    :param polling_period_seconds: Controls the rate which we poll for the result of
-        this run. By default the operator will poll every 30 seconds.
-    :type polling_period_seconds: int
-    :param databricks_retry_limit: Amount of times retry if the Databricks backend is
-        unreachable. Its value must be greater than or equal to 1.
-    :type databricks_retry_limit: int
-    :param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
-    :type do_xcom_push: boolean
     """
-    # Used in airflow.models.BaseOperator
-    template_fields = ('json',)
-    # Databricks brand color (blue) under white text
-    ui_color = '#1CB1C2'
-    ui_fgcolor = '#fff'
 
-    def __init__(
-            self,
-            json=None,
-            spark_jar_task=None,
-            notebook_task=None,
-            new_cluster=None,
-            existing_cluster_id=None,
-            libraries=None,
-            run_name=None,
-            timeout_seconds=None,
-            databricks_conn_id='databricks_default',
-            polling_period_seconds=30,
-            databricks_retry_limit=3,
-            do_xcom_push=False,
-            **kwargs):
-        """
-        Creates a new ``DatabricksSubmitRunOperator``.
-        """
-        super(DatabricksSubmitRunOperator, self).__init__(**kwargs)
-        self.json = json or {}
-        self.databricks_conn_id = databricks_conn_id
-        self.polling_period_seconds = polling_period_seconds
-        self.databricks_retry_limit = databricks_retry_limit
+    def __init__(self, json=None, spark_jar_task=None, notebook_task=None,
+                 new_cluster=None, existing_cluster_id=None, libraries=None,
+                 **kwargs):
+        """Create a new ``DatabricksSubmitRunOperator``."""
+        json = json or {}
         if spark_jar_task is not None:
-            self.json['spark_jar_task'] = spark_jar_task
+            json['spark_jar_task'] = spark_jar_task
         if notebook_task is not None:
-            self.json['notebook_task'] = notebook_task
+            json['notebook_task'] = notebook_task
         if new_cluster is not None:
-            self.json['new_cluster'] = new_cluster
+            json['new_cluster'] = new_cluster
         if existing_cluster_id is not None:
-            self.json['existing_cluster_id'] = existing_cluster_id
+            json['existing_cluster_id'] = existing_cluster_id
         if libraries is not None:
-            self.json['libraries'] = libraries
-        if run_name is not None:
-            self.json['run_name'] = run_name
-        if timeout_seconds is not None:
-            self.json['timeout_seconds'] = timeout_seconds
-        if 'run_name' not in self.json:
-            self.json['run_name'] = run_name or kwargs['task_id']
+            json['libraries'] = libraries
 
-        self.json = self._deep_string_coerce(self.json)
-        # This variable will be used in case our task gets killed.
-        self.run_id = None
-        self.do_xcom_push = do_xcom_push
-
-    def _deep_string_coerce(self, content, json_path='json'):
-        """
-        Coerces content or all values of content if it is a dict to a string. The
-        function will throw if content contains non-string or non-numeric types.
-
-        The reason why we have this function is because the ``self.json`` field must be a
-         dict with only string values. This is because ``render_template`` will fail
-        for numerical values.
-        """
-        c = self._deep_string_coerce
-        if isinstance(content, six.string_types):
-            return content
-        elif isinstance(content, six.integer_types + (float,)):
-            # Databricks can tolerate either numeric or string types in the API backend.
-            return str(content)
-        elif isinstance(content, (list, tuple)):
-            return [c(e, '{0}[{1}]'.format(json_path, i)) for i, e in enumerate(content)]
-        elif isinstance(content, dict):
-            return {k: c(v, '{0}[{1}]'.format(json_path, k))
-                    for k, v in list(content.items())}
-        else:
-            param_type = type(content)
-            msg = 'Type {0} used for parameter {1} is not a number or a string'\
-                .format(param_type, json_path)
-            raise AirflowException(msg)
-
-    def _log_run_page_url(self, url):
-        self.log.info('View run status, Spark UI, and logs at %s', url)
-
-    def get_hook(self):
-        return DatabricksHook(
-            self.databricks_conn_id,
-            retry_limit=self.databricks_retry_limit)
-
-    def execute(self, context):
-        hook = self.get_hook()
-        self.run_id = hook.submit_run(self.json)
-        if self.do_xcom_push:
-            context['ti'].xcom_push(key=XCOM_RUN_ID_KEY, value=self.run_id)
-        self.log.info('Run submitted with run_id: %s', self.run_id)
-        run_page_url = hook.get_run_page_url(self.run_id)
-        if self.do_xcom_push:
-            context['ti'].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=run_page_url)
-        self._log_run_page_url(run_page_url)
-        while True:
-            run_state = hook.get_run_state(self.run_id)
-            if run_state.is_terminal:
-                if run_state.is_successful:
-                    self.log.info('%s completed successfully.', self.task_id)
-                    self._log_run_page_url(run_page_url)
-                    return
-                else:
-                    error_message = '{t} failed with terminal state: {s}'.format(
-                        t=self.task_id,
-                        s=run_state)
-                    raise AirflowException(error_message)
-            else:
-                self.log.info('%s in run state: %s', self.task_id, run_state)
-                self._log_run_page_url(run_page_url)
-                self.log.info('Sleeping for %s seconds.', self.polling_period_seconds)
-                time.sleep(self.polling_period_seconds)
-
-    def on_kill(self):
-        hook = self.get_hook()
-        hook.cancel_run(self.run_id)
-        self.log.info(
-            'Task: %s with run_id: %s was requested to be cancelled.',
-            self.task_id, self.run_id
-        )
+        super(DatabricksSubmitRunOperator, self).__init__(
+            endpoint='2.0/jobs/runs/submit',
+            json=json,
+            **kwargs)
diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py
index aca8dd9600..ba6a419a0a 100644
--- a/tests/contrib/hooks/test_databricks_hook.py
+++ b/tests/contrib/hooks/test_databricks_hook.py
@@ -22,11 +22,11 @@
 import unittest
 
 from airflow import __version__
-from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT, _TokenAuth
+from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
 from airflow.utils import db
-from requests import exceptions as requests_exceptions
+from requests.exceptions import (ConnectionError, Timeout)
 
 try:
     from unittest import mock
@@ -79,12 +79,14 @@ def get_run_endpoint(host):
     """
     return 'https://{}/api/2.0/jobs/runs/get'.format(host)
 
+
 def cancel_run_endpoint(host):
     """
     Utility function to generate the get run endpoint given the host.
     """
     return 'https://{}/api/2.0/jobs/runs/cancel'.format(host)
 
+
 class DatabricksHookTest(unittest.TestCase):
     """
     Tests for DatabricksHook.
@@ -111,19 +113,20 @@ def test_parse_host_with_scheme(self):
 
     def test_init_bad_retry_limit(self):
         with self.assertRaises(ValueError):
-            DatabricksHook(retry_limit = 0)
+            DatabricksHook(retry_limit=0)
 
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
     def test_do_api_call_with_error_retry(self, mock_requests):
-        for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]:
+        for exception in [ConnectionError, Timeout]:
             with mock.patch.object(self.hook.log, 'error') as mock_errors:
                 mock_requests.reset_mock()
                 mock_requests.post.side_effect = exception()
 
                 with self.assertRaises(AirflowException):
-                    self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+                    self.hook._do_api_call('2.0/jobs/runs/submit', {})
 
-                self.assertEquals(len(mock_errors.mock_calls), self.hook.retry_limit)
+                self.assertEquals(len(mock_errors.mock_calls),
+                                  self.hook.retry_limit)
 
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
     def test_do_api_call_with_bad_status_code(self, mock_requests):
@@ -131,7 +134,7 @@ def test_do_api_call_with_bad_status_code(self, mock_requests):
         status_code_mock = mock.PropertyMock(return_value=500)
         type(mock_requests.post.return_value).status_code = status_code_mock
         with self.assertRaises(AirflowException):
-            self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+            self.hook._do_api_call('2.0/jobs/runs/submit', {})
 
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
     def test_submit_run(self, mock_requests):
@@ -140,10 +143,10 @@ def test_submit_run(self, mock_requests):
         status_code_mock = mock.PropertyMock(return_value=200)
         type(mock_requests.post.return_value).status_code = status_code_mock
         json = {
-          'notebook_task': NOTEBOOK_TASK,
-          'new_cluster': NEW_CLUSTER
+            'notebook_task': NOTEBOOK_TASK,
+            'new_cluster': NEW_CLUSTER
         }
-        run_id = self.hook.submit_run(json)
+        run_id = self.hook.jobs_runs_submit(json).get('run_id')
 
         self.assertEquals(run_id, '1')
         mock_requests.post.assert_called_once_with(
@@ -200,7 +203,7 @@ def test_cancel_run(self, mock_requests):
         status_code_mock = mock.PropertyMock(return_value=200)
         type(mock_requests.post.return_value).status_code = status_code_mock
 
-        self.hook.cancel_run(RUN_ID)
+        self.hook.jobs_runs_cancel(RUN_ID)
 
         mock_requests.post.assert_called_once_with(
             cancel_run_endpoint(HOST),
@@ -231,10 +234,10 @@ def test_submit_run(self, mock_requests):
         status_code_mock = mock.PropertyMock(return_value=200)
         type(mock_requests.post.return_value).status_code = status_code_mock
         json = {
-          'notebook_task': NOTEBOOK_TASK,
-          'new_cluster': NEW_CLUSTER
+            'notebook_task': NOTEBOOK_TASK,
+            'new_cluster': NEW_CLUSTER
         }
-        run_id = self.hook.submit_run(json)
+        run_id = self.hook.jobs_runs_submit(json).get('run_id')
 
         self.assertEquals(run_id, '1')
         args = mock_requests.post.call_args
diff --git a/tests/contrib/operators/test_databricks_operator.py b/tests/contrib/operators/test_databricks_operator.py
index f77da2ec18..9bc2af6db5 100644
--- a/tests/contrib/operators/test_databricks_operator.py
+++ b/tests/contrib/operators/test_databricks_operator.py
@@ -173,14 +173,16 @@ def test_exec_success(self, db_mock_class):
         """
         Test the execute function in case where the run is successful.
         """
+        db_mock = db_mock_class.return_value
+        db_mock.jobs_runs_submit.return_value = {'run_id': 1}
+        db_mock.get_api_method.return_value = db_mock.jobs_runs_submit
+        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
+
         run = {
           'new_cluster': NEW_CLUSTER,
           'notebook_task': NOTEBOOK_TASK,
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
-        db_mock = db_mock_class.return_value
-        db_mock.submit_run.return_value = 1
-        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
 
         op.execute(None)
 
@@ -190,9 +192,9 @@ def test_exec_success(self, db_mock_class):
           'run_name': TASK_ID
         })
         db_mock_class.assert_called_once_with(
-                DEFAULT_CONN_ID,
-                retry_limit=op.databricks_retry_limit)
-        db_mock.submit_run.assert_called_once_with(expected)
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit)
+        db_mock.jobs_runs_submit.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
         db_mock.get_run_state.assert_called_once_with(RUN_ID)
         self.assertEquals(RUN_ID, op.run_id)
@@ -202,14 +204,16 @@ def test_exec_failure(self, db_mock_class):
         """
         Test the execute function in case where the run failed.
         """
+        db_mock = db_mock_class.return_value
+        db_mock.jobs_runs_submit.return_value = {'run_id': 1}
+        db_mock.get_api_method.return_value = db_mock.jobs_runs_submit
+        db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '')
+
         run = {
           'new_cluster': NEW_CLUSTER,
           'notebook_task': NOTEBOOK_TASK,
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
-        db_mock = db_mock_class.return_value
-        db_mock.submit_run.return_value = 1
-        db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '')
 
         with self.assertRaises(AirflowException):
             op.execute(None)
@@ -219,10 +223,12 @@ def test_exec_failure(self, db_mock_class):
           'notebook_task': NOTEBOOK_TASK,
           'run_name': TASK_ID,
         })
+        endpoint = '2.0/jobs/runs/submit'
         db_mock_class.assert_called_once_with(
                 DEFAULT_CONN_ID,
                 retry_limit=op.databricks_retry_limit)
-        db_mock.submit_run.assert_called_once_with(expected)
+        db_mock.get_api_method.assert_called_once_with(endpoint)
+        db_mock.jobs_runs_submit.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
         db_mock.get_run_state.assert_called_once_with(RUN_ID)
         self.assertEquals(RUN_ID, op.run_id)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services