You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/02/26 21:56:10 UTC

[airflow] branch main updated: Databricks: add support for triggering jobs by name (#21663)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a1845c6  Databricks: add support for triggering jobs by name (#21663)
a1845c6 is described below

commit a1845c68f9a04e61dd99ccc0a23d17a277babf57
Author: Eugene Karimov <13...@users.noreply.github.com>
AuthorDate: Sat Feb 26 22:55:30 2022 +0100

    Databricks: add support for triggering jobs by name (#21663)
---
 airflow/providers/databricks/hooks/databricks.py   |  50 ++++++++-
 .../providers/databricks/operators/databricks.py   |  17 ++-
 .../providers/databricks/hooks/test_databricks.py  | 117 +++++++++++++++++++++
 .../databricks/operators/test_databricks.py        |  57 ++++++++++
 4 files changed, 239 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index ee2523c..5f4d90c 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -28,7 +28,7 @@ or the ``api/2.1/jobs/runs/submit``
 import sys
 import time
 from time import sleep
-from typing import Dict
+from typing import Any, Dict, List, Optional
 from urllib.parse import urlparse

 import requests
@@ -57,6 +57,8 @@ CANCEL_RUN_ENDPOINT = ('POST', 'api/2.1/jobs/runs/cancel')
 INSTALL_LIBS_ENDPOINT = ('POST', 'api/2.0/libraries/install')
 UNINSTALL_LIBS_ENDPOINT = ('POST', 'api/2.0/libraries/uninstall')

+LIST_JOBS_ENDPOINT = ('GET', 'api/2.1/jobs/list')
+
 USER_AGENT_HEADER = {'user-agent': f'airflow-{__version__}'}

 RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']
@@ -403,6 +405,52 @@ class DatabricksHook(BaseHook):
         response = self._do_api_call(SUBMIT_RUN_ENDPOINT, json)
         return response['run_id']

+    def list_jobs(self, limit: int = 25, offset: int = 0, expand_tasks: bool = False) -> List[Dict[str, Any]]:
+        """
+        Lists the jobs in the Databricks Job Service, following pagination until exhausted.
+
+        :param limit: The limit/batch size used to retrieve jobs.
+        :param offset: The offset of the first job to return, relative to the most recently created job.
+        :param expand_tasks: Whether to include task and cluster details in the response.
+        :return: A list of jobs.
+        """
+        has_more = True
+        jobs: List[Dict[str, Any]] = []
+
+        while has_more:
+            json = {
+                'limit': limit,
+                'offset': offset,
+                'expand_tasks': expand_tasks,
+            }
+            response = self._do_api_call(LIST_JOBS_ENDPOINT, json)
+            page = response.get('jobs', [])
+            jobs += page
+            # Stop on an empty page even if 'has_more' is set: the offset could not advance.
+            has_more = bool(page) and response.get('has_more', False)
+            offset += len(page)
+        return jobs
+
+    def find_job_id_by_name(self, job_name: str) -> Optional[int]:
+        """
+        Finds job id by its name. If there are multiple jobs with the same name, raises AirflowException.
+
+        :param job_name: The name of the job to look up.
+        :return: The job_id as an int or None if no job was found.
+        """
+        all_jobs = self.list_jobs()
+        matching_jobs = [j for j in all_jobs if j['settings']['name'] == job_name]
+
+        if len(matching_jobs) > 1:
+            raise AirflowException(
+                f"There are more than one job with name {job_name}. Please delete duplicated jobs first"
+            )
+
+        if not matching_jobs:
+            return None
+        else:
+            return matching_jobs[0]['job_id']
+
     def get_run_page_url(self, run_id: int) -> str:
         """
         Retrieves run_page_url.
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index e53ff9a..d50c522 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -28,7 +28,6 @@ from airflow.providers.databricks.hooks.databricks import DatabricksHook
 if TYPE_CHECKING:
     from airflow.utils.context import Context

-
 XCOM_RUN_ID_KEY = 'run_id'
 XCOM_RUN_PAGE_URL_KEY = 'run_page_url'

@@ -397,6 +396,7 @@ class DatabricksRunNowOperator(BaseOperator):

     Currently the named parameters that ``DatabricksRunNowOperator`` supports are
         - ``job_id``
+        - ``job_name``
         - ``json``
         - ``notebook_params``
         - ``python_params``
@@ -409,6 +409,10 @@ class DatabricksRunNowOperator(BaseOperator):

         .. seealso::
             https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunNow
+    :param job_name: the name of the existing Databricks job.
+        Exactly one job with the specified name must exist.
+        ``job_id`` and ``job_name`` are mutually exclusive.
+        This field will be templated.
     :param json: A JSON object containing API parameters which will be passed
         directly to the ``api/2.1/jobs/run-now`` endpoint. The other named parameters
         (i.e. ``notebook_params``, ``spark_submit_params``..) to this operator will
@@ -489,6 +493,7 @@ class DatabricksRunNowOperator(BaseOperator):
         self,
         *,
         job_id: Optional[str] = None,
+        job_name: Optional[str] = None,
         json: Optional[Any] = None,
         notebook_params: Optional[Dict[str, str]] = None,
         python_params: Optional[List[str]] = None,
@@ -513,6 +518,10 @@ class DatabricksRunNowOperator(BaseOperator):

         if job_id is not None:
             self.json['job_id'] = job_id
+        if job_name is not None:
+            self.json['job_name'] = job_name
+        if 'job_id' in self.json and 'job_name' in self.json:
+            raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'")
         if notebook_params is not None:
             self.json['notebook_params'] = notebook_params
         if python_params is not None:
@@ -536,6 +545,12 @@ class DatabricksRunNowOperator(BaseOperator):

     def execute(self, context: 'Context'):
         hook = self._get_hook()
+        if 'job_name' in self.json:
+            job_id = hook.find_job_id_by_name(self.json['job_name'])
+            if job_id is None:
+                raise AirflowException(f"Job ID for job name {self.json['job_name']} can not be found")
+            self.json['job_id'] = job_id
+            del self.json['job_name']
         self.run_id = hook.run_now(self.json)
         _handle_databricks_operator_execution(self, hook, self.log, context)

diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index 951673a..83b4366 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -50,6 +50,7 @@ NEW_CLUSTER = {'spark_version': '2.0.x-scala2.10', 'node_type_id': 'r3.xlarge',
 CLUSTER_ID = 'cluster_id'
 RUN_ID = 1
 JOB_ID = 42
+JOB_NAME = 'job-name'
 HOST = 'xx.cloud.databricks.com'
 HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'
 LOGIN = 'login'
@@ -71,6 +72,17 @@ LIBRARIES = [
     {"jar": "dbfs:/mnt/libraries/library.jar"},
     {"maven": {"coordinates": "org.jsoup:jsoup:1.7.2", "exclusions": ["slf4j:slf4j"]}},
 ]
+LIST_JOBS_RESPONSE = {
+    'jobs': [
+        {
+            'job_id': JOB_ID,
+            'settings': {
+                'name': JOB_NAME,
+            },
+        },
+    ],
+    'has_more': False,
+}


 def run_now_endpoint(host):
@@ -136,6 +148,13 @@ def uninstall_endpoint(host):
     return f'https://{host}/api/2.0/libraries/uninstall'


+def list_jobs_endpoint(host):
+    """
+    Utility function to generate the list jobs endpoint given the host
+    """
+    return f'https://{host}/api/2.1/jobs/list'
+
+
 def create_valid_response_mock(content):
     response = mock.MagicMock()
     response.json.return_value = content
@@ -531,6 +550,104 @@ class TestDatabricksHook(unittest.TestCase):
         aad_token = {'token': 'my_token', 'expires_on': int(time.time())}
         self.assertFalse(self.hook._is_aad_token_valid(aad_token))

+    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
+    def test_list_jobs_success_single_page(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = LIST_JOBS_RESPONSE
+
+        jobs = self.hook.list_jobs()
+
+        mock_requests.get.assert_called_once_with(
+            list_jobs_endpoint(HOST),
+            json=None,
+            params={'limit': 25, 'offset': 0, 'expand_tasks': False},
+            auth=(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds,
+        )
+
+        assert jobs == LIST_JOBS_RESPONSE['jobs']
+
+    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
+    def test_list_jobs_success_multiple_pages(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.side_effect = [
+            create_successful_response_mock({**LIST_JOBS_RESPONSE, 'has_more': True}),
+            create_successful_response_mock(LIST_JOBS_RESPONSE),
+        ]
+
+        jobs = self.hook.list_jobs()
+
+        assert mock_requests.get.call_count == 2
+
+        first_call_args = mock_requests.method_calls[0]
+        assert first_call_args[1][0] == list_jobs_endpoint(HOST)
+        assert first_call_args[2]['params'] == {'limit': 25, 'offset': 0, 'expand_tasks': False}
+
+        second_call_args = mock_requests.method_calls[1]
+        assert second_call_args[1][0] == list_jobs_endpoint(HOST)
+        assert second_call_args[2]['params'] == {'limit': 25, 'offset': 1, 'expand_tasks': False}
+
+        assert len(jobs) == 2
+        assert jobs == LIST_JOBS_RESPONSE['jobs'] * 2
+
+    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
+    def test_find_job_id_by_name_success(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = LIST_JOBS_RESPONSE
+
+        job_id = self.hook.find_job_id_by_name(JOB_NAME)
+
+        mock_requests.get.assert_called_once_with(
+            list_jobs_endpoint(HOST),
+            json=None,
+            params={'limit': 25, 'offset': 0, 'expand_tasks': False},
+            auth=(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds,
+        )
+
+        assert job_id == JOB_ID
+
+    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
+    def test_find_job_id_by_name_not_found(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = LIST_JOBS_RESPONSE
+
+        job_id = self.hook.find_job_id_by_name("Non existing job")
+
+        mock_requests.get.assert_called_once_with(
+            list_jobs_endpoint(HOST),
+            json=None,
+            params={'limit': 25, 'offset': 0, 'expand_tasks': False},
+            auth=(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds,
+        )
+
+        assert job_id is None
+
+    @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
+    def test_find_job_id_by_name_raise_exception_with_duplicates(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = {
+            **LIST_JOBS_RESPONSE,
+            'jobs': LIST_JOBS_RESPONSE['jobs'] * 2,
+        }
+
+        exception_message = f'There are more than one job with name {JOB_NAME}.'
+        with pytest.raises(AirflowException, match=exception_message):
+            self.hook.find_job_id_by_name(JOB_NAME)
+
+        mock_requests.get.assert_called_once_with(
+            list_jobs_endpoint(HOST),
+            json=None,
+            params={'limit': 25, 'offset': 0, 'expand_tasks': False},
+            auth=(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds,
+        )
+

 class TestDatabricksHookToken(unittest.TestCase):
     """
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index e96c883..0e93138 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -47,6 +47,7 @@ EXISTING_CLUSTER_ID = 'existing-cluster-id'
 RUN_NAME = 'run-name'
 RUN_ID = 1
 JOB_ID = "42"
+JOB_NAME = "job-name"
 NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"}
 JAR_PARAMS = ["param1", "param2"]
 RENDERED_TEMPLATED_JAR_PARAMS = [f'/test-{DATE}']
@@ -545,3 +546,59 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
         db_mock.run_now.assert_called_once_with(expected)
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
         db_mock.get_run_state.assert_not_called()
+
+    def test_init_exception_with_job_name_and_job_id(self):
+        exception_message = "Argument 'job_name' is not allowed with argument 'job_id'"
+
+        with pytest.raises(AirflowException, match=exception_message):
+            DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, job_name=JOB_NAME)
+
+        run = {'job_id': JOB_ID, 'job_name': JOB_NAME}
+        with pytest.raises(AirflowException, match=exception_message):
+            DatabricksRunNowOperator(task_id=TASK_ID, json=run)
+
+        run = {'job_id': JOB_ID}
+        with pytest.raises(AirflowException, match=exception_message):
+            DatabricksRunNowOperator(task_id=TASK_ID, json=run, job_name=JOB_NAME)
+
+    @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
+    def test_exec_with_job_name(self, db_mock_class):
+        run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS}
+        op = DatabricksRunNowOperator(task_id=TASK_ID, job_name=JOB_NAME, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.find_job_id_by_name.return_value = JOB_ID
+        db_mock.run_now.return_value = 1
+        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
+
+        op.execute(None)
+
+        expected = databricks_operator._deep_string_coerce(
+            {
+                'notebook_params': NOTEBOOK_PARAMS,
+                'notebook_task': NOTEBOOK_TASK,
+                'jar_params': JAR_PARAMS,
+                'job_id': JOB_ID,
+            }
+        )
+
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+        )
+        db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME)
+        db_mock.run_now.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        db_mock.get_run_state.assert_called_once_with(RUN_ID)
+        assert RUN_ID == op.run_id
+
+    @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
+    def test_exec_failure_if_job_id_not_found(self, db_mock_class):
+        run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS}
+        op = DatabricksRunNowOperator(task_id=TASK_ID, job_name=JOB_NAME, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.find_job_id_by_name.return_value = None
+
+        exception_message = f"Job ID for job name {JOB_NAME} can not be found"
+        with pytest.raises(AirflowException, match=exception_message):
+            op.execute(None)
+
+        db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME)