Posted to commits@airflow.apache.org by po...@apache.org on 2022/02/26 21:56:10 UTC
[airflow] branch main updated: Databricks: add support for triggering jobs by name (#21663)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a1845c6 Databricks: add support for triggering jobs by name (#21663)
a1845c6 is described below
commit a1845c68f9a04e61dd99ccc0a23d17a277babf57
Author: Eugene Karimov <13...@users.noreply.github.com>
AuthorDate: Sat Feb 26 22:55:30 2022 +0100
Databricks: add support for triggering jobs by name (#21663)
---
airflow/providers/databricks/hooks/databricks.py | 50 ++++++++-
.../providers/databricks/operators/databricks.py | 17 ++-
.../providers/databricks/hooks/test_databricks.py | 117 +++++++++++++++++++++
.../databricks/operators/test_databricks.py | 57 ++++++++++
4 files changed, 239 insertions(+), 2 deletions(-)
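In practice the change lets a DAG trigger an existing Databricks job by its display name instead of its numeric id. A minimal usage sketch; the connection id, DAG id, job name, and notebook parameter below are illustrative assumptions, not taken from this commit:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator

    with DAG(
        dag_id='databricks_run_by_name_example',
        start_date=datetime(2022, 1, 1),
        schedule_interval=None,
    ) as dag:
        # The operator resolves the job name to a job_id at execute time,
        # via the DatabricksHook.find_job_id_by_name() call added here.
        run_job = DatabricksRunNowOperator(
            task_id='run_nightly_etl',
            databricks_conn_id='databricks_default',
            job_name='nightly-etl',  # hypothetical name; must be unique in the workspace
            notebook_params={'run_date': '{{ ds }}'},
        )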
diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index ee2523c..5f4d90c 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -28,7 +28,7 @@ or the ``api/2.1/jobs/runs/submit``
import sys
import time
from time import sleep
-from typing import Dict
+from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import requests
@@ -57,6 +57,8 @@ CANCEL_RUN_ENDPOINT = ('POST', 'api/2.1/jobs/runs/cancel')
INSTALL_LIBS_ENDPOINT = ('POST', 'api/2.0/libraries/install')
UNINSTALL_LIBS_ENDPOINT = ('POST', 'api/2.0/libraries/uninstall')
+LIST_JOBS_ENDPOINT = ('GET', 'api/2.1/jobs/list')
+
USER_AGENT_HEADER = {'user-agent': f'airflow-{__version__}'}
RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']
@@ -403,6 +405,52 @@ class DatabricksHook(BaseHook):
response = self._do_api_call(SUBMIT_RUN_ENDPOINT, json)
return response['run_id']
+ def list_jobs(self, limit: int = 25, offset: int = 0, expand_tasks: bool = False) -> List[Dict[str, Any]]:
+ """
+ Lists the jobs in the Databricks Job Service.
+
+ :param limit: The limit/batch size used to retrieve jobs.
+ :param offset: The offset of the first job to return, relative to the most recently created job.
+ :param expand_tasks: Whether to include task and cluster details in the response.
+ :return: A list of jobs.
+ """
+ has_more = True
+ jobs = []
+
+ while has_more:
+ json = {
+ 'limit': limit,
+ 'offset': offset,
+ 'expand_tasks': expand_tasks,
+ }
+ response = self._do_api_call(LIST_JOBS_ENDPOINT, json)
+ jobs += response['jobs'] if 'jobs' in response else []
+ has_more = response.get('has_more', False)
+ if has_more:
+ offset += len(response['jobs'])
+
+ return jobs
+
+ def find_job_id_by_name(self, job_name: str) -> Optional[int]:
+ """
+        Finds a job id by its name. If there are multiple jobs with the same name, raises AirflowException.
+
+ :param job_name: The name of the job to look up.
+ :return: The job_id as an int or None if no job was found.
+ """
+ all_jobs = self.list_jobs()
+ matching_jobs = [j for j in all_jobs if j['settings']['name'] == job_name]
+
+ if len(matching_jobs) > 1:
+ raise AirflowException(
+ f"There are more than one job with name {job_name}. Please delete duplicated jobs first"
+ )
+
+ if not matching_jobs:
+ return None
+ else:
+ return matching_jobs[0]['job_id']
+
def get_run_page_url(self, run_id: int) -> str:
"""
Retrieves run_page_url.
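The new hook methods can also be called directly. A minimal sketch, assuming a configured 'databricks_default' connection; the job name is hypothetical:

    from airflow.providers.databricks.hooks.databricks import DatabricksHook

    hook = DatabricksHook(databricks_conn_id='databricks_default')

    # list_jobs() pages through api/2.1/jobs/list, advancing the offset until
    # the response reports has_more == False, and returns the accumulated list.
    all_jobs = hook.list_jobs(limit=25, expand_tasks=False)

    # find_job_id_by_name() returns None when no job matches and raises
    # AirflowException when more than one job shares the name.
    job_id = hook.find_job_id_by_name('nightly-etl')  # hypothetical job name
    if job_id is not None:
        run_id = hook.run_now({'job_id': job_id})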
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index e53ff9a..d50c522 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -28,7 +28,6 @@ from airflow.providers.databricks.hooks.databricks import DatabricksHook
if TYPE_CHECKING:
from airflow.utils.context import Context
-
XCOM_RUN_ID_KEY = 'run_id'
XCOM_RUN_PAGE_URL_KEY = 'run_page_url'
@@ -397,6 +396,7 @@ class DatabricksRunNowOperator(BaseOperator):
Currently the named parameters that ``DatabricksRunNowOperator`` supports are
- ``job_id``
+ - ``job_name``
- ``json``
- ``notebook_params``
- ``python_params``
@@ -409,6 +409,10 @@ class DatabricksRunNowOperator(BaseOperator):
.. seealso::
https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunNow
+    :param job_name: the name of an existing Databricks job.
+        Exactly one job with the specified name must exist.
+ ``job_id`` and ``job_name`` are mutually exclusive.
+ This field will be templated.
:param json: A JSON object containing API parameters which will be passed
directly to the ``api/2.1/jobs/run-now`` endpoint. The other named parameters
(i.e. ``notebook_params``, ``spark_submit_params``..) to this operator will
@@ -489,6 +493,7 @@ class DatabricksRunNowOperator(BaseOperator):
self,
*,
job_id: Optional[str] = None,
+ job_name: Optional[str] = None,
json: Optional[Any] = None,
notebook_params: Optional[Dict[str, str]] = None,
python_params: Optional[List[str]] = None,
@@ -513,6 +518,10 @@ class DatabricksRunNowOperator(BaseOperator):
if job_id is not None:
self.json['job_id'] = job_id
+ if job_name is not None:
+ self.json['job_name'] = job_name
+ if 'job_id' in self.json and 'job_name' in self.json:
+ raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'")
if notebook_params is not None:
self.json['notebook_params'] = notebook_params
if python_params is not None:
@@ -536,6 +545,12 @@ class DatabricksRunNowOperator(BaseOperator):
def execute(self, context: 'Context'):
hook = self._get_hook()
+ if 'job_name' in self.json:
+ job_id = hook.find_job_id_by_name(self.json['job_name'])
+ if job_id is None:
+ raise AirflowException(f"Job ID for job name {self.json['job_name']} can not be found")
+ self.json['job_id'] = job_id
+ del self.json['job_name']
self.run_id = hook.run_now(self.json)
_handle_databricks_operator_execution(self, hook, self.log, context)
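Because the name lookup happens in execute(), a missing or ambiguous job name fails the task at run time rather than at DAG parse time. A minimal pytest-style sketch of that failure path, mirroring the tests further below (the test name and job name are illustrative):

    from unittest import mock

    import pytest

    from airflow.exceptions import AirflowException
    from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator

    @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
    def test_unknown_job_name_fails_at_execute(db_mock_class):
        op = DatabricksRunNowOperator(task_id='run', job_name='does-not-exist')
        # Simulate a workspace where no job carries that name.
        db_mock_class.return_value.find_job_id_by_name.return_value = None
        with pytest.raises(AirflowException, match='cannot be found'):
            op.execute(None)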
diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index 951673a..83b4366 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -50,6 +50,7 @@ NEW_CLUSTER = {'spark_version': '2.0.x-scala2.10', 'node_type_id': 'r3.xlarge',
CLUSTER_ID = 'cluster_id'
RUN_ID = 1
JOB_ID = 42
+JOB_NAME = 'job-name'
HOST = 'xx.cloud.databricks.com'
HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'
LOGIN = 'login'
@@ -71,6 +72,17 @@ LIBRARIES = [
{"jar": "dbfs:/mnt/libraries/library.jar"},
{"maven": {"coordinates": "org.jsoup:jsoup:1.7.2", "exclusions": ["slf4j:slf4j"]}},
]
+LIST_JOBS_RESPONSE = {
+ 'jobs': [
+ {
+ 'job_id': JOB_ID,
+ 'settings': {
+ 'name': JOB_NAME,
+ },
+ },
+ ],
+ 'has_more': False,
+}
def run_now_endpoint(host):
@@ -136,6 +148,13 @@ def uninstall_endpoint(host):
return f'https://{host}/api/2.0/libraries/uninstall'
+def list_jobs_endpoint(host):
+ """
+    Utility function to generate the list jobs endpoint given the host
+ """
+ return f'https://{host}/api/2.1/jobs/list'
+
+
def create_valid_response_mock(content):
response = mock.MagicMock()
response.json.return_value = content
@@ -531,6 +550,104 @@ class TestDatabricksHook(unittest.TestCase):
aad_token = {'token': 'my_token', 'expires_on': int(time.time())}
self.assertFalse(self.hook._is_aad_token_valid(aad_token))
+ @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
+ def test_list_jobs_success_single_page(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = LIST_JOBS_RESPONSE
+
+ jobs = self.hook.list_jobs()
+
+ mock_requests.get.assert_called_once_with(
+ list_jobs_endpoint(HOST),
+ json=None,
+ params={'limit': 25, 'offset': 0, 'expand_tasks': False},
+ auth=(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds,
+ )
+
+ assert jobs == LIST_JOBS_RESPONSE['jobs']
+
+ @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
+ def test_list_jobs_success_multiple_pages(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.get.side_effect = [
+ create_successful_response_mock({**LIST_JOBS_RESPONSE, 'has_more': True}),
+ create_successful_response_mock(LIST_JOBS_RESPONSE),
+ ]
+
+ jobs = self.hook.list_jobs()
+
+ assert mock_requests.get.call_count == 2
+
+ first_call_args = mock_requests.method_calls[0]
+ assert first_call_args[1][0] == list_jobs_endpoint(HOST)
+ assert first_call_args[2]['params'] == {'limit': 25, 'offset': 0, 'expand_tasks': False}
+
+ second_call_args = mock_requests.method_calls[1]
+ assert second_call_args[1][0] == list_jobs_endpoint(HOST)
+ assert second_call_args[2]['params'] == {'limit': 25, 'offset': 1, 'expand_tasks': False}
+
+ assert len(jobs) == 2
+ assert jobs == LIST_JOBS_RESPONSE['jobs'] * 2
+
+ @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
+ def test_get_job_id_by_name_success(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = LIST_JOBS_RESPONSE
+
+ job_id = self.hook.find_job_id_by_name(JOB_NAME)
+
+ mock_requests.get.assert_called_once_with(
+ list_jobs_endpoint(HOST),
+ json=None,
+ params={'limit': 25, 'offset': 0, 'expand_tasks': False},
+ auth=(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds,
+ )
+
+ assert job_id == JOB_ID
+
+ @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
+ def test_get_job_id_by_name_not_found(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = LIST_JOBS_RESPONSE
+
+ job_id = self.hook.find_job_id_by_name("Non existing job")
+
+ mock_requests.get.assert_called_once_with(
+ list_jobs_endpoint(HOST),
+ json=None,
+ params={'limit': 25, 'offset': 0, 'expand_tasks': False},
+ auth=(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds,
+ )
+
+ assert job_id is None
+
+ @mock.patch('airflow.providers.databricks.hooks.databricks.requests')
+ def test_get_job_id_by_name_raise_exception_with_duplicates(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = {
+ **LIST_JOBS_RESPONSE,
+ 'jobs': LIST_JOBS_RESPONSE['jobs'] * 2,
+ }
+
+        exception_message = f'There is more than one job with the name {JOB_NAME}.'
+ with pytest.raises(AirflowException, match=exception_message):
+ self.hook.find_job_id_by_name(JOB_NAME)
+
+ mock_requests.get.assert_called_once_with(
+ list_jobs_endpoint(HOST),
+ json=None,
+ params={'limit': 25, 'offset': 0, 'expand_tasks': False},
+ auth=(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds,
+ )
+
class TestDatabricksHookToken(unittest.TestCase):
"""
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index e96c883..0e93138 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -47,6 +47,7 @@ EXISTING_CLUSTER_ID = 'existing-cluster-id'
RUN_NAME = 'run-name'
RUN_ID = 1
JOB_ID = "42"
+JOB_NAME = "job-name"
NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"}
JAR_PARAMS = ["param1", "param2"]
RENDERED_TEMPLATED_JAR_PARAMS = [f'/test-{DATE}']
@@ -545,3 +546,59 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_not_called()
+
+    def test_init_exception_with_job_name_and_job_id(self):
+ exception_message = "Argument 'job_name' is not allowed with argument 'job_id'"
+
+ with pytest.raises(AirflowException, match=exception_message):
+ DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, job_name=JOB_NAME)
+
+ with pytest.raises(AirflowException, match=exception_message):
+ run = {'job_id': JOB_ID, 'job_name': JOB_NAME}
+ DatabricksRunNowOperator(task_id=TASK_ID, json=run)
+
+ with pytest.raises(AirflowException, match=exception_message):
+ run = {'job_id': JOB_ID}
+ DatabricksRunNowOperator(task_id=TASK_ID, json=run, job_name=JOB_NAME)
+
+ @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
+ def test_exec_with_job_name(self, db_mock_class):
+ run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS}
+ op = DatabricksRunNowOperator(task_id=TASK_ID, job_name=JOB_NAME, json=run)
+ db_mock = db_mock_class.return_value
+ db_mock.find_job_id_by_name.return_value = JOB_ID
+ db_mock.run_now.return_value = 1
+ db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
+
+ op.execute(None)
+
+ expected = databricks_operator._deep_string_coerce(
+ {
+ 'notebook_params': NOTEBOOK_PARAMS,
+ 'notebook_task': NOTEBOOK_TASK,
+ 'jar_params': JAR_PARAMS,
+ 'job_id': JOB_ID,
+ }
+ )
+
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+ )
+ db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME)
+ db_mock.run_now.assert_called_once_with(expected)
+ db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+ db_mock.get_run_state.assert_called_once_with(RUN_ID)
+ assert RUN_ID == op.run_id
+
+ @mock.patch('airflow.providers.databricks.operators.databricks.DatabricksHook')
+ def test_exec_failure_if_job_id_not_found(self, db_mock_class):
+ run = {'notebook_params': NOTEBOOK_PARAMS, 'notebook_task': NOTEBOOK_TASK, 'jar_params': JAR_PARAMS}
+ op = DatabricksRunNowOperator(task_id=TASK_ID, job_name=JOB_NAME, json=run)
+ db_mock = db_mock_class.return_value
+ db_mock.find_job_id_by_name.return_value = None
+
+ exception_message = f"Job ID for job name {JOB_NAME} can not be found"
+ with pytest.raises(AirflowException, match=exception_message):
+ op.execute(None)
+
+ db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME)