Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2018/09/06 07:07:31 UTC

[GitHub] Fokko closed pull request #3813: [AIRFLOW-1998] Implemented DatabricksRunNowOperator for jobs/run-now …

URL: https://github.com/apache/incubator-airflow/pull/3813
 
 
   

This is a PR merged from a forked repository. As GitHub hides the original
diff of a forked ("foreign") pull request once it is merged, the diff is
reproduced below for the sake of provenance:

diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py
index cb2ba9bd00..284db98d91 100644
--- a/airflow/contrib/hooks/databricks_hook.py
+++ b/airflow/contrib/hooks/databricks_hook.py
@@ -37,6 +37,7 @@
 START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start")
 TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete")
 
+RUN_NOW_ENDPOINT = ('POST', 'api/2.0/jobs/run-now')
 SUBMIT_RUN_ENDPOINT = ('POST', 'api/2.0/jobs/runs/submit')
 GET_RUN_ENDPOINT = ('GET', 'api/2.0/jobs/runs/get')
 CANCEL_RUN_ENDPOINT = ('POST', 'api/2.0/jobs/runs/cancel')
@@ -161,6 +162,18 @@ def _log_request_error(self, attempt_num, error):
             attempt_num, error
         )
 
+    def run_now(self, json):
+        """
+        Utility function to call the ``api/2.0/jobs/run-now`` endpoint.
+
+        :param json: The data used in the body of the request to the ``run-now`` endpoint.
+        :type json: dict
+        :return: the run_id as a string
+        :rtype: string
+        """
+        response = self._do_api_call(RUN_NOW_ENDPOINT, json)
+        return response['run_id']
+
     def submit_run(self, json):
         """
         Utility function to call the ``api/2.0/jobs/runs/submit`` endpoint.
diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py
index 3245a99256..aed2d8909c 100644
--- a/airflow/contrib/operators/databricks_operator.py
+++ b/airflow/contrib/operators/databricks_operator.py
@@ -30,6 +30,66 @@
 XCOM_RUN_PAGE_URL_KEY = 'run_page_url'
 
 
+def _deep_string_coerce(content, json_path='json'):
+    """
+    Coerces content, or all values of content if it is a dict, to a string.
+    The function raises an ``AirflowException`` if content contains
+    non-string or non-numeric types.
+
+    This function exists because the ``self.json`` field must be a dict with
+    only string values, since ``render_template`` will fail for numerical values.
+    """
+    c = _deep_string_coerce
+    if isinstance(content, six.string_types):
+        return content
+    elif isinstance(content, six.integer_types + (float,)):
+        # Databricks can tolerate either numeric or string types in the API backend.
+        return str(content)
+    elif isinstance(content, (list, tuple)):
+        return [c(e, '{0}[{1}]'.format(json_path, i)) for i, e in enumerate(content)]
+    elif isinstance(content, dict):
+        return {k: c(v, '{0}[{1}]'.format(json_path, k))
+                for k, v in list(content.items())}
+    else:
+        param_type = type(content)
+        msg = 'Type {0} used for parameter {1} is not a number or a string' \
+            .format(param_type, json_path)
+        raise AirflowException(msg)
+
+
+def _handle_databricks_operator_execution(operator, hook, log, context):
+    """
+    Handles the Airflow + Databricks lifecycle logic for a Databricks operator.
+    :param operator: Databricks operator being handled
+    :param context: Airflow context
+    """
+    if operator.do_xcom_push:
+        context['ti'].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id)
+    log.info('Run submitted with run_id: %s', operator.run_id)
+    run_page_url = hook.get_run_page_url(operator.run_id)
+    if operator.do_xcom_push:
+        context['ti'].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=run_page_url)
+
+    log.info('View run status, Spark UI, and logs at %s', run_page_url)
+    while True:
+        run_state = hook.get_run_state(operator.run_id)
+        if run_state.is_terminal:
+            if run_state.is_successful:
+                log.info('%s completed successfully.', operator.task_id)
+                log.info('View run status, Spark UI, and logs at %s', run_page_url)
+                return
+            else:
+                error_message = '{t} failed with terminal state: {s}'.format(
+                    t=operator.task_id,
+                    s=run_state)
+                raise AirflowException(error_message)
+        else:
+            log.info('%s in run state: %s', operator.task_id, run_state)
+            log.info('View run status, Spark UI, and logs at %s', run_page_url)
+            log.info('Sleeping for %s seconds.', operator.polling_period_seconds)
+            time.sleep(operator.polling_period_seconds)
+
+
 class DatabricksSubmitRunOperator(BaseOperator):
     """
     Submits a Spark job run to Databricks using the
@@ -200,39 +260,202 @@ def __init__(
         if 'run_name' not in self.json:
             self.json['run_name'] = run_name or kwargs['task_id']
 
-        self.json = self._deep_string_coerce(self.json)
+        self.json = _deep_string_coerce(self.json)
         # This variable will be used in case our task gets killed.
         self.run_id = None
         self.do_xcom_push = do_xcom_push
 
-    def _deep_string_coerce(self, content, json_path='json'):
-        """
-        Coerces content or all values of content if it is a dict to a string. The
-        function will throw if content contains non-string or non-numeric types.
+    def get_hook(self):
+        return DatabricksHook(
+            self.databricks_conn_id,
+            retry_limit=self.databricks_retry_limit,
+            retry_delay=self.databricks_retry_delay)
+
+    def execute(self, context):
+        hook = self.get_hook()
+        self.run_id = hook.submit_run(self.json)
+        _handle_databricks_operator_execution(self, hook, self.log, context)
+
+    def on_kill(self):
+        hook = self.get_hook()
+        hook.cancel_run(self.run_id)
+        self.log.info(
+            'Task: %s with run_id: %s was requested to be cancelled.',
+            self.task_id, self.run_id
+        )
+
+
+class DatabricksRunNowOperator(BaseOperator):
+    """
+    Runs an existing Spark job on Databricks using the
+    `api/2.0/jobs/run-now
+    <https://docs.databricks.com/api/latest/jobs.html#run-now>`_
+    API endpoint.
+
+    There are two ways to instantiate this operator.
+
+    In the first way, you can take the JSON payload that you typically use
+    to call the ``api/2.0/jobs/run-now`` endpoint and pass it directly
+    to our ``DatabricksRunNowOperator`` through the ``json`` parameter.
+    For example ::
+        json = {
+          "job_id": 42,
+          "notebook_params": {
+            "dry-run": "true",
+            "oldest-time-to-consider": "1457570074236"
+          }
+        }
+
+        notebook_run = DatabricksRunNowOperator(task_id='notebook_run', json=json)
+
+    Another way to accomplish the same thing is to use the named parameters
+    of the ``DatabricksRunNowOperator`` directly. Note that there is exactly
+    one named parameter for each top level parameter in the ``run-now``
+    endpoint. In this method, your code would look like this: ::
+
+        job_id=42
+
+        notebook_params = {
+            "dry-run": "true",
+            "oldest-time-to-consider": "1457570074236"
+        }
+
+        python_params = ["douglas adams", "42"]
+
+        spark_submit_params = ["--class", "org.apache.spark.examples.SparkPi"]
+
+        notebook_run = DatabricksRunNowOperator(
+            task_id='notebook_run', job_id=job_id,
+            notebook_params=notebook_params,
+            python_params=python_params,
+            spark_submit_params=spark_submit_params
+        )
+
+    In the case where both the json parameter **AND** the named parameters
+    are provided, they will be merged together. If there are conflicts during the merge,
+    the named parameters will take precedence and override the top level ``json`` keys.
+
+    Currently the named parameters that ``DatabricksRunNowOperator`` supports are
+        - ``job_id``
+        - ``json``
+        - ``notebook_params``
+        - ``python_params``
+        - ``spark_submit_params``
+
+
+    :param job_id: the job_id of the existing Databricks job.
+        This field will be templated.
+        .. seealso::
+            https://docs.databricks.com/api/latest/jobs.html#run-now
+    :type job_id: string
+    :param json: A JSON object containing API parameters which will be passed
+        directly to the ``api/2.0/jobs/run-now`` endpoint. The other named parameters
+        (i.e. ``notebook_params``, ``spark_submit_params``..) to this operator will
+        be merged with this json dictionary if they are provided.
+        If there are conflicts during the merge, the named parameters will
+        take precedence and override the top level json keys. (templated)
+
+        .. seealso::
+            For more information about templating see :ref:`jinja-templating`.
+            https://docs.databricks.com/api/latest/jobs.html#run-now
+    :type json: dict
+    :param notebook_params: A dict from keys to values for jobs with notebook task,
+        e.g. "notebook_params": {"name": "john doe", "age":  "35"}.
+        The map is passed to the notebook and will be accessible through the
+        dbutils.widgets.get function. See Widgets for more information.
+        If not specified upon run-now, the triggered run will use the
+        job's base parameters. notebook_params cannot be
+        specified in conjunction with jar_params. The json representation
+        of this field (i.e. {"notebook_params":{"name":"john doe","age":"35"}})
+        cannot exceed 10,000 bytes.
+        This field will be templated.
+
+        .. seealso::
+            https://docs.databricks.com/user-guide/notebooks/widgets.html
+    :type notebook_params: dict
+    :param python_params: A list of parameters for jobs with python tasks,
+        e.g. "python_params": ["john doe", "35"].
+        The parameters will be passed to the python file as command line parameters.
+        If specified upon run-now, they will overwrite the parameters specified in
+        the job setting.
+        The json representation of this field (i.e. {"python_params":["john doe","35"]})
+        cannot exceed 10,000 bytes.
+        This field will be templated.
+
+        .. seealso::
+            https://docs.databricks.com/api/latest/jobs.html#run-now
+    :type python_params: array of strings
+    :param spark_submit_params: A list of parameters for jobs with spark submit task,
+        e.g. "spark_submit_params": ["--class", "org.apache.spark.examples.SparkPi"].
+        The parameters will be passed to the spark-submit script as command line parameters.
+        If specified upon run-now, they will overwrite the parameters specified
+        in the job setting.
+        The json representation of this field cannot exceed 10,000 bytes.
+        This field will be templated.
+        .. seealso::
+            https://docs.databricks.com/api/latest/jobs.html#run-now
+    :type spark_submit_params: array of strings
+    :param timeout_seconds: The timeout for this run. By default a value of 0 is used,
+        which means the run has no timeout.
+        This field will be templated.
+    :type timeout_seconds: int32
+    :param databricks_conn_id: The name of the Airflow connection to use.
+        By default and in the common case this will be ``databricks_default``. To use
+        token based authentication, provide the key ``token`` in the extra field for the
+        connection.
+    :type databricks_conn_id: string
+    :param polling_period_seconds: Controls the rate at which we poll for the result of
+        this run. By default the operator will poll every 30 seconds.
+    :type polling_period_seconds: int
+    :param databricks_retry_limit: Number of times to retry if the Databricks backend is
+        unreachable. Its value must be greater than or equal to 1.
+    :type databricks_retry_limit: int
+    :param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
+    :type do_xcom_push: boolean
+    """
+    # Used in airflow.models.BaseOperator
+    template_fields = ('json',)
+    # Databricks brand color (blue) with white foreground text
+    ui_color = '#1CB1C2'
+    ui_fgcolor = '#fff'
+
+    def __init__(
+            self,
+            job_id,
+            json=None,
+            notebook_params=None,
+            python_params=None,
+            spark_submit_params=None,
+            databricks_conn_id='databricks_default',
+            polling_period_seconds=30,
+            databricks_retry_limit=3,
+            databricks_retry_delay=1,
+            do_xcom_push=False,
+            **kwargs):
 
-        The reason why we have this function is because the ``self.json`` field must be a
-         dict with only string values. This is because ``render_template`` will fail
-        for numerical values.
         """
-        c = self._deep_string_coerce
-        if isinstance(content, six.string_types):
-            return content
-        elif isinstance(content, six.integer_types + (float,)):
-            # Databricks can tolerate either numeric or string types in the API backend.
-            return str(content)
-        elif isinstance(content, (list, tuple)):
-            return [c(e, '{0}[{1}]'.format(json_path, i)) for i, e in enumerate(content)]
-        elif isinstance(content, dict):
-            return {k: c(v, '{0}[{1}]'.format(json_path, k))
-                    for k, v in list(content.items())}
-        else:
-            param_type = type(content)
-            msg = 'Type {0} used for parameter {1} is not a number or a string'\
-                .format(param_type, json_path)
-            raise AirflowException(msg)
+        Creates a new ``DatabricksRunNowOperator``.
+        """
+        super(DatabricksRunNowOperator, self).__init__(**kwargs)
+        self.json = json or {}
+        self.databricks_conn_id = databricks_conn_id
+        self.polling_period_seconds = polling_period_seconds
+        self.databricks_retry_limit = databricks_retry_limit
+        self.databricks_retry_delay = databricks_retry_delay
 
-    def _log_run_page_url(self, url):
-        self.log.info('View run status, Spark UI, and logs at %s', url)
+        if job_id is not None:
+            self.json['job_id'] = job_id
+        if notebook_params is not None:
+            self.json['notebook_params'] = notebook_params
+        if python_params is not None:
+            self.json['python_params'] = python_params
+        if spark_submit_params is not None:
+            self.json['spark_submit_params'] = spark_submit_params
+
+        self.json = _deep_string_coerce(self.json)
+        # This variable will be used in case our task gets killed.
+        self.run_id = None
+        self.do_xcom_push = do_xcom_push
 
     def get_hook(self):
         return DatabricksHook(
@@ -242,31 +465,8 @@ def get_hook(self):
 
     def execute(self, context):
         hook = self.get_hook()
-        self.run_id = hook.submit_run(self.json)
-        if self.do_xcom_push:
-            context['ti'].xcom_push(key=XCOM_RUN_ID_KEY, value=self.run_id)
-        self.log.info('Run submitted with run_id: %s', self.run_id)
-        run_page_url = hook.get_run_page_url(self.run_id)
-        if self.do_xcom_push:
-            context['ti'].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=run_page_url)
-        self._log_run_page_url(run_page_url)
-        while True:
-            run_state = hook.get_run_state(self.run_id)
-            if run_state.is_terminal:
-                if run_state.is_successful:
-                    self.log.info('%s completed successfully.', self.task_id)
-                    self._log_run_page_url(run_page_url)
-                    return
-                else:
-                    error_message = '{t} failed with terminal state: {s}'.format(
-                        t=self.task_id,
-                        s=run_state)
-                    raise AirflowException(error_message)
-            else:
-                self.log.info('%s in run state: %s', self.task_id, run_state)
-                self._log_run_page_url(run_page_url)
-                self.log.info('Sleeping for %s seconds.', self.polling_period_seconds)
-                time.sleep(self.polling_period_seconds)
+        self.run_id = hook.run_now(self.json)
+        _handle_databricks_operator_execution(self, hook, self.log, context)
 
     def on_kill(self):
         hook = self.get_hook()
diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py
index 04a7c8dc3c..090f46caeb 100644
--- a/tests/contrib/hooks/test_databricks_hook.py
+++ b/tests/contrib/hooks/test_databricks_hook.py
@@ -54,6 +54,7 @@
 }
 CLUSTER_ID = 'cluster_id'
 RUN_ID = 1
+JOB_ID = 42
 HOST = 'xx.cloud.databricks.com'
 HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'
 LOGIN = 'login'
@@ -70,9 +71,21 @@
         'state_message': STATE_MESSAGE
     }
 }
+NOTEBOOK_PARAMS = {
+    "dry-run": "true",
+    "oldest-time-to-consider": "1457570074236"
+}
+JAR_PARAMS = ["param1", "param2"]
 RESULT_STATE = None
 
 
+def run_now_endpoint(host):
+    """
+    Utility function to generate the run now endpoint given the host.
+    """
+    return 'https://{}/api/2.0/jobs/run-now'.format(host)
+
+
 def submit_run_endpoint(host):
     """
     Utility function to generate the submit run endpoint given the host.
@@ -160,6 +173,7 @@ def setUp(self, session=None):
         conn.host = HOST
         conn.login = LOGIN
         conn.password = PASSWORD
+        conn.extra = None
         session.commit()
 
         self.hook = DatabricksHook(retry_delay=0)
@@ -270,6 +284,32 @@ def test_submit_run(self, mock_requests):
             headers=USER_AGENT_HEADER,
             timeout=self.hook.timeout_seconds)
 
+    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+    def test_run_now(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
+        status_code_mock = mock.PropertyMock(return_value=200)
+        type(mock_requests.post.return_value).status_code = status_code_mock
+        json = {
+            'notebook_params': NOTEBOOK_PARAMS,
+            'jar_params': JAR_PARAMS,
+            'job_id': JOB_ID
+        }
+        run_id = self.hook.run_now(json)
+
+        self.assertEquals(run_id, '1')
+
+        mock_requests.post.assert_called_once_with(
+            run_now_endpoint(HOST),
+            json={
+                'notebook_params': NOTEBOOK_PARAMS,
+                'jar_params': JAR_PARAMS,
+                'job_id': JOB_ID
+            },
+            auth=(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds)
+
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
     def test_get_run_page_url(self, mock_requests):
         mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
diff --git a/tests/contrib/operators/test_databricks_operator.py b/tests/contrib/operators/test_databricks_operator.py
index afe1a92f28..5884fc3c98 100644
--- a/tests/contrib/operators/test_databricks_operator.py
+++ b/tests/contrib/operators/test_databricks_operator.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -23,7 +23,9 @@
 from datetime import datetime
 
 from airflow.contrib.hooks.databricks_hook import RunState
+import airflow.contrib.operators.databricks_operator as databricks_operator
 from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator
+from airflow.contrib.operators.databricks_operator import DatabricksRunNowOperator
 from airflow.exceptions import AirflowException
 from airflow.models import DAG
 
@@ -58,6 +60,40 @@
 EXISTING_CLUSTER_ID = 'existing-cluster-id'
 RUN_NAME = 'run-name'
 RUN_ID = 1
+JOB_ID = 42
+NOTEBOOK_PARAMS = {
+    "dry-run": "true",
+    "oldest-time-to-consider": "1457570074236"
+}
+JAR_PARAMS = ["param1", "param2"]
+RENDERED_TEMPLATED_JAR_PARAMS = [
+    '/test-{0}'.format(DATE)
+]
+TEMPLATED_JAR_PARAMS = [
+    '/test-{{ ds }}'
+]
+PYTHON_PARAMS = ["john doe", "35"]
+SPARK_SUBMIT_PARAMS = ["--class", "org.apache.spark.examples.SparkPi"]
+
+
+class DatabricksOperatorSharedFunctions(unittest.TestCase):
+    def test_deep_string_coerce(self):
+        test_json = {
+            'test_int': 1,
+            'test_float': 1.0,
+            'test_dict': {'key': 'value'},
+            'test_list': [1, 1.0, 'a', 'b'],
+            'test_tuple': (1, 1.0, 'a', 'b')
+        }
+
+        expected = {
+            'test_int': '1',
+            'test_float': '1.0',
+            'test_dict': {'key': 'value'},
+            'test_list': ['1', '1.0', 'a', 'b'],
+            'test_tuple': ['1', '1.0', 'a', 'b']
+        }
+        self.assertDictEqual(databricks_operator._deep_string_coerce(test_json), expected)
 
 
 class DatabricksSubmitRunOperatorTest(unittest.TestCase):
@@ -65,12 +101,15 @@ def test_init_with_named_parameters(self):
         """
         Test the initializer with the named parameters.
         """
-        op = DatabricksSubmitRunOperator(task_id=TASK_ID, new_cluster=NEW_CLUSTER, notebook_task=NOTEBOOK_TASK)
-        expected = op._deep_string_coerce({
-          'new_cluster': NEW_CLUSTER,
-          'notebook_task': NOTEBOOK_TASK,
-          'run_name': TASK_ID
+        op = DatabricksSubmitRunOperator(task_id=TASK_ID,
+                                         new_cluster=NEW_CLUSTER,
+                                         notebook_task=NOTEBOOK_TASK)
+        expected = databricks_operator._deep_string_coerce({
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': TASK_ID
         })
+
         self.assertDictEqual(expected, op.json)
 
     def test_init_with_json(self):
@@ -78,14 +117,14 @@ def test_init_with_json(self):
         Test the initializer with json data.
         """
         json = {
-          'new_cluster': NEW_CLUSTER,
-          'notebook_task': NOTEBOOK_TASK
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
-        expected = op._deep_string_coerce({
-          'new_cluster': NEW_CLUSTER,
-          'notebook_task': NOTEBOOK_TASK,
-          'run_name': TASK_ID
+        expected = databricks_operator._deep_string_coerce({
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': TASK_ID
         })
         self.assertDictEqual(expected, op.json)
 
@@ -99,10 +138,10 @@ def test_init_with_specified_run_name(self):
           'run_name': RUN_NAME
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
-        expected = op._deep_string_coerce({
-          'new_cluster': NEW_CLUSTER,
-          'notebook_task': NOTEBOOK_TASK,
-          'run_name': RUN_NAME
+        expected = databricks_operator._deep_string_coerce({
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': RUN_NAME
         })
         self.assertDictEqual(expected, op.json)
 
@@ -114,14 +153,16 @@ def test_init_with_merging(self):
         """
         override_new_cluster = {'workers': 999}
         json = {
-          'new_cluster': NEW_CLUSTER,
-          'notebook_task': NOTEBOOK_TASK,
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
         }
-        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json, new_cluster=override_new_cluster)
-        expected = op._deep_string_coerce({
-          'new_cluster': override_new_cluster,
-          'notebook_task': NOTEBOOK_TASK,
-          'run_name': TASK_ID,
+        op = DatabricksSubmitRunOperator(task_id=TASK_ID,
+                                         json=json,
+                                         new_cluster=override_new_cluster)
+        expected = databricks_operator._deep_string_coerce({
+            'new_cluster': override_new_cluster,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': TASK_ID,
         })
         self.assertDictEqual(expected, op.json)
 
@@ -133,10 +174,10 @@ def test_init_with_templating(self):
         dag = DAG('test', start_date=datetime.now())
         op = DatabricksSubmitRunOperator(dag=dag, task_id=TASK_ID, json=json)
         op.json = op.render_template('json', op.json, {'ds': DATE})
-        expected = op._deep_string_coerce({
-          'new_cluster': NEW_CLUSTER,
-          'notebook_task': RENDERED_TEMPLATED_NOTEBOOK_TASK,
-          'run_name': TASK_ID,
+        expected = databricks_operator._deep_string_coerce({
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': RENDERED_TEMPLATED_NOTEBOOK_TASK,
+            'run_name': TASK_ID,
         })
         self.assertDictEqual(expected, op.json)
 
@@ -146,27 +187,9 @@ def test_init_with_bad_type(self):
         }
         # Looks a bit weird since we have to escape regex reserved symbols.
         exception_message = 'Type \<(type|class) \'datetime.datetime\'\> used ' + \
-                        'for parameter json\[test\] is not a number or a string'
+                            'for parameter json\[test\] is not a number or a string'
         with self.assertRaisesRegexp(AirflowException, exception_message):
-            op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
-
-    def test_deep_string_coerce(self):
-        op = DatabricksSubmitRunOperator(task_id='test')
-        test_json = {
-            'test_int': 1,
-            'test_float': 1.0,
-            'test_dict': {'key': 'value'},
-            'test_list': [1, 1.0, 'a', 'b'],
-            'test_tuple': (1, 1.0, 'a', 'b')
-        }
-        expected = {
-            'test_int': '1',
-            'test_float': '1.0',
-            'test_dict': {'key': 'value'},
-            'test_list': ['1', '1.0', 'a', 'b'],
-            'test_tuple': ['1', '1.0', 'a', 'b']
-        }
-        self.assertDictEqual(op._deep_string_coerce(test_json), expected)
+            DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
 
     @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
     def test_exec_success(self, db_mock_class):
@@ -184,15 +207,16 @@ def test_exec_success(self, db_mock_class):
 
         op.execute(None)
 
-        expected = op._deep_string_coerce({
-          'new_cluster': NEW_CLUSTER,
-          'notebook_task': NOTEBOOK_TASK,
-          'run_name': TASK_ID
+        expected = databricks_operator._deep_string_coerce({
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': TASK_ID
         })
         db_mock_class.assert_called_once_with(
             DEFAULT_CONN_ID,
             retry_limit=op.databricks_retry_limit,
             retry_delay=op.databricks_retry_delay)
+
         db_mock.submit_run.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
         db_mock.get_run_state.assert_called_once_with(RUN_ID)
@@ -215,10 +239,10 @@ def test_exec_failure(self, db_mock_class):
         with self.assertRaises(AirflowException):
             op.execute(None)
 
-        expected = op._deep_string_coerce({
-          'new_cluster': NEW_CLUSTER,
-          'notebook_task': NOTEBOOK_TASK,
-          'run_name': TASK_ID,
+        expected = databricks_operator._deep_string_coerce({
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': TASK_ID,
         })
         db_mock_class.assert_called_once_with(
             DEFAULT_CONN_ID,
@@ -232,8 +256,8 @@ def test_exec_failure(self, db_mock_class):
     @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
     def test_on_kill(self, db_mock_class):
         run = {
-          'new_cluster': NEW_CLUSTER,
-          'notebook_task': NOTEBOOK_TASK,
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
         db_mock = db_mock_class.return_value
@@ -243,3 +267,173 @@ def test_on_kill(self, db_mock_class):
 
         db_mock.cancel_run.assert_called_once_with(RUN_ID)
 
+
+class DatabricksRunNowOperatorTest(unittest.TestCase):
+
+    def test_init_with_named_parameters(self):
+        """
+        Test the initializer with the named parameters.
+        """
+        op = DatabricksRunNowOperator(job_id=JOB_ID, task_id=TASK_ID)
+        expected = databricks_operator._deep_string_coerce({
+            'job_id': 42
+        })
+
+        self.assertDictEqual(expected, op.json)
+
+    def test_init_with_json(self):
+        """
+        Test the initializer with json data.
+        """
+        json = {
+            'notebook_params': NOTEBOOK_PARAMS,
+            'jar_params': JAR_PARAMS,
+            'python_params': PYTHON_PARAMS,
+            'spark_submit_params': SPARK_SUBMIT_PARAMS
+        }
+        op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json)
+
+        expected = databricks_operator._deep_string_coerce({
+            'notebook_params': NOTEBOOK_PARAMS,
+            'jar_params': JAR_PARAMS,
+            'python_params': PYTHON_PARAMS,
+            'spark_submit_params': SPARK_SUBMIT_PARAMS,
+            'job_id': JOB_ID
+        })
+
+        self.assertDictEqual(expected, op.json)
+
+    def test_init_with_merging(self):
+        """
+        Test the initializer when json and other named parameters are both
+        provided. The named parameters should override top level keys in the
+        json dict.
+        """
+        override_notebook_params = {'workers': 999}
+        json = {
+            'notebook_params': NOTEBOOK_PARAMS,
+            'jar_params': JAR_PARAMS
+        }
+
+        op = DatabricksRunNowOperator(task_id=TASK_ID,
+                                      json=json,
+                                      job_id=JOB_ID,
+                                      notebook_params=override_notebook_params,
+                                      python_params=PYTHON_PARAMS,
+                                      spark_submit_params=SPARK_SUBMIT_PARAMS)
+
+        expected = databricks_operator._deep_string_coerce({
+            'notebook_params': override_notebook_params,
+            'jar_params': JAR_PARAMS,
+            'python_params': PYTHON_PARAMS,
+            'spark_submit_params': SPARK_SUBMIT_PARAMS,
+            'job_id': JOB_ID
+        })
+
+        self.assertDictEqual(expected, op.json)
+
+    def test_init_with_templating(self):
+        json = {
+            'notebook_params': NOTEBOOK_PARAMS,
+            'jar_params': TEMPLATED_JAR_PARAMS
+        }
+
+        dag = DAG('test', start_date=datetime.now())
+        op = DatabricksRunNowOperator(dag=dag, task_id=TASK_ID, job_id=JOB_ID, json=json)
+        op.json = op.render_template('json', op.json, {'ds': DATE})
+        expected = databricks_operator._deep_string_coerce({
+            'notebook_params': NOTEBOOK_PARAMS,
+            'jar_params': RENDERED_TEMPLATED_JAR_PARAMS,
+            'job_id': JOB_ID
+        })
+        self.assertDictEqual(expected, op.json)
+
+    def test_init_with_bad_type(self):
+        json = {
+            'test': datetime.now()
+        }
+        # Looks a bit weird since we have to escape regex reserved symbols.
+        exception_message = 'Type \<(type|class) \'datetime.datetime\'\> used ' + \
+                            'for parameter json\[test\] is not a number or a string'
+        with self.assertRaisesRegexp(AirflowException, exception_message):
+            DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json)
+
+    @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
+    def test_exec_success(self, db_mock_class):
+        """
+        Test the execute function in case where the run is successful.
+        """
+        run = {
+            'notebook_params': NOTEBOOK_PARAMS,
+            'notebook_task': NOTEBOOK_TASK,
+            'jar_params': JAR_PARAMS
+        }
+        op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.run_now.return_value = 1
+        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
+
+        op.execute(None)
+
+        expected = databricks_operator._deep_string_coerce({
+            'notebook_params': NOTEBOOK_PARAMS,
+            'notebook_task': NOTEBOOK_TASK,
+            'jar_params': JAR_PARAMS,
+            'job_id': JOB_ID
+        })
+
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay)
+        db_mock.run_now.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        db_mock.get_run_state.assert_called_once_with(RUN_ID)
+        self.assertEquals(RUN_ID, op.run_id)
+
+    @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
+    def test_exec_failure(self, db_mock_class):
+        """
+        Test the execute function in case where the run failed.
+        """
+        run = {
+            'notebook_params': NOTEBOOK_PARAMS,
+            'notebook_task': NOTEBOOK_TASK,
+            'jar_params': JAR_PARAMS
+        }
+        op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.run_now.return_value = 1
+        db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '')
+
+        with self.assertRaises(AirflowException):
+            op.execute(None)
+
+        expected = databricks_operator._deep_string_coerce({
+            'notebook_params': NOTEBOOK_PARAMS,
+            'notebook_task': NOTEBOOK_TASK,
+            'jar_params': JAR_PARAMS,
+            'job_id': JOB_ID
+        })
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay)
+        db_mock.run_now.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        db_mock.get_run_state.assert_called_once_with(RUN_ID)
+        self.assertEquals(RUN_ID, op.run_id)
+
+    @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
+    def test_on_kill(self, db_mock_class):
+        run = {
+            'notebook_params': NOTEBOOK_PARAMS,
+            'notebook_task': NOTEBOOK_TASK,
+            'jar_params': JAR_PARAMS
+        }
+        op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
+        db_mock = db_mock_class.return_value
+        op.run_id = RUN_ID
+
+        op.on_kill()
+        db_mock.cancel_run.assert_called_once_with(RUN_ID)
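
For readers following along, a minimal sketch of how the new operator could be
used in a DAG once this change lands (the DAG id, schedule, job_id and
notebook_params below are illustrative assumptions, not part of this PR):

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.databricks_operator import DatabricksRunNowOperator

    # Hypothetical DAG; job_id 42 is assumed to refer to an existing Databricks job.
    dag = DAG('databricks_run_now_example',
              start_date=datetime(2018, 9, 1),
              schedule_interval=None)

    # Triggers the existing job via api/2.0/jobs/run-now, overriding its
    # notebook parameters for this run and polling every 30 seconds until
    # the run reaches a terminal state.
    run_job = DatabricksRunNowOperator(
        task_id='run_existing_job',
        job_id=42,
        notebook_params={'dry-run': 'true'},
        databricks_conn_id='databricks_default',
        polling_period_seconds=30,
        dag=dag)

Named parameters such as notebook_params are merged into the json payload and,
on conflict, take precedence over the corresponding top-level json keys.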


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services