You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ar...@apache.org on 2017/04/06 15:30:37 UTC
incubator-airflow git commit: [AIRFLOW-1028] Databricks Operator for
Airflow
Repository: incubator-airflow
Updated Branches:
refs/heads/master 5a6f18f1c -> 53ca50845
[AIRFLOW-1028] Databricks Operator for Airflow
Add DatabricksSubmitRun Operator
In this PR, we contribute a DatabricksSubmitRun operator and a
Databricks hook. This operator enables easy integration of Airflow
with Databricks. In addition to the operator, we have created a
databricks_default connection, an example_dag using this
DatabricksSubmitRunOperator, and matching documentation.
Closes #2202 from andrewmchen/databricks-operator-
squashed
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/53ca5084
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/53ca5084
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/53ca5084
Branch: refs/heads/master
Commit: 53ca5084561fd5c13996609f2eda6baf717249b5
Parents: 5a6f18f
Author: Andrew Chen <an...@databricks.com>
Authored: Thu Apr 6 08:30:01 2017 -0700
Committer: Arthur Wiedmer <ar...@gmail.com>
Committed: Thu Apr 6 08:30:33 2017 -0700
----------------------------------------------------------------------
.../example_dags/example_databricks_operator.py | 82 +++++++
airflow/contrib/hooks/databricks_hook.py | 202 +++++++++++++++++
.../contrib/operators/databricks_operator.py | 211 +++++++++++++++++
airflow/exceptions.py | 2 +-
airflow/models.py | 1 +
airflow/utils/db.py | 4 +
docs/code.rst | 1 +
docs/integration.rst | 13 ++
setup.py | 2 +
tests/contrib/hooks/databricks_hook.py | 226 +++++++++++++++++++
tests/contrib/operators/databricks_operator.py | 185 +++++++++++++++
11 files changed, 928 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/contrib/example_dags/example_databricks_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/example_dags/example_databricks_operator.py b/airflow/contrib/example_dags/example_databricks_operator.py
new file mode 100644
index 0000000..abf6844
--- /dev/null
+++ b/airflow/contrib/example_dags/example_databricks_operator.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import airflow
+
+from airflow import DAG
+from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator
+
+# This is an example DAG which uses the DatabricksSubmitRunOperator.
+# In this example, we create two tasks which execute sequentially.
+# The first task is to run a notebook at the workspace path
+# "/Users/airflow@example.com/PrepareData" and the second task is to
+# run a JAR uploaded to DBFS. Both tasks use new clusters.
+#
+# Because we have set a downstream dependency on the notebook task,
+# the spark jar task will NOT run until the notebook task completes
+# successfully.
+#
+# A run is considered successful if it has a result_state of "SUCCESS".
+# For more information about the state of a run refer to
+# https://docs.databricks.com/api/latest/jobs.html#runstate
+
+args = {
+ 'owner': 'airflow',
+ 'email': ['airflow@example.com'],
+ 'depends_on_past': False,
+ 'start_date': airflow.utils.dates.days_ago(2)
+}
+
+dag = DAG(
+ dag_id='example_databricks_operator', default_args=args,
+ schedule_interval='@daily')
+
+new_cluster = {
+ 'spark_version': '2.1.0-db3-scala2.11',
+ 'node_type_id': 'r3.xlarge',
+ 'aws_attributes': {
+ 'availability': 'ON_DEMAND'
+ },
+ 'num_workers': 8
+}
+
+notebook_task_params = {
+ 'new_cluster': new_cluster,
+ 'notebook_task': {
+ 'notebook_path': '/Users/airflow@example.com/PrepareData',
+ },
+}
+# Example of using the JSON parameter to initialize the operator.
+notebook_task = DatabricksSubmitRunOperator(
+ task_id='notebook_task',
+ dag=dag,
+ json=notebook_task_params)
+
+# Example of using the named parameters of DatabricksSubmitRunOperator
+# to initialize the operator.
+spark_jar_task = DatabricksSubmitRunOperator(
+ task_id='spark_jar_task',
+ dag=dag,
+ new_cluster=new_cluster,
+ spark_jar_task={
+ 'main_class_name': 'com.example.ProcessData'
+ },
+ libraries=[
+ {
+ 'jar': 'dbfs:/lib/etl-0.1.jar'
+ }
+ ]
+)
+
+notebook_task.set_downstream(spark_jar_task)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/contrib/hooks/databricks_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py
new file mode 100644
index 0000000..0cd5d0f
--- /dev/null
+++ b/airflow/contrib/hooks/databricks_hook.py
@@ -0,0 +1,202 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import requests
+
+from airflow import __version__
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from requests import exceptions as requests_exceptions
+
+
+try:
+ from urllib import parse as urlparse
+except ImportError:
+ import urlparse
+
+
+SUBMIT_RUN_ENDPOINT = ('POST', 'api/2.0/jobs/runs/submit')
+GET_RUN_ENDPOINT = ('GET', 'api/2.0/jobs/runs/get')
+CANCEL_RUN_ENDPOINT = ('POST', 'api/2.0/jobs/runs/cancel')
+USER_AGENT_HEADER = {'user-agent': 'airflow-{v}'.format(v=__version__)}
+
+
+class DatabricksHook(BaseHook):
+ """
+ Interact with Databricks.
+ """
+ def __init__(
+ self,
+ databricks_conn_id='databricks_default',
+ timeout_seconds=180,
+ retry_limit=3):
+ """
+ :param databricks_conn_id: The name of the databricks connection to use.
+ :type databricks_conn_id: string
+ :param timeout_seconds: The amount of time in seconds the requests library
+ will wait before timing-out.
+ :type timeout_seconds: int
+ :param retry_limit: The number of times to retry the connection in case of
+ service outages.
+ :type retry_limit: int
+ """
+ self.databricks_conn_id = databricks_conn_id
+ self.databricks_conn = self.get_connection(databricks_conn_id)
+ self.timeout_seconds = timeout_seconds
+ assert retry_limit >= 1, 'Retry limit must be greater than equal to 1'
+ self.retry_limit = retry_limit
+
+ def _parse_host(self, host):
+ """
+ The purpose of this function is to be robust to improper connection
+ settings provided by users, specifically in the host field.
+
+
+ For example -- when users supply ``https://xx.cloud.databricks.com`` as the
+ host, we must strip out the protocol to get the host.
+ >>> h = DatabricksHook()
+ >>> assert h._parse_host('https://xx.cloud.databricks.com') == \
+ 'xx.cloud.databricks.com'
+
+ In the case where users supply the correct ``xx.cloud.databricks.com`` as the
+ host, this function is a no-op.
+ >>> assert h._parse_host('xx.cloud.databricks.com') == 'xx.cloud.databricks.com'
+ """
+ urlparse_host = urlparse.urlparse(host).hostname
+ if urlparse_host:
+ # In this case, host = https://xx.cloud.databricks.com
+ return urlparse_host
+ else:
+ # In this case, host = xx.cloud.databricks.com
+ return host
+
+ def _do_api_call(self, endpoint_info, json):
+ """
+ Utility function to perform an API call with retries
+ :param endpoint_info: Tuple of method and endpoint
+ :type endpoint_info: (string, string)
+ :param json: Parameters for this API call.
+ :type json: dict
+ :return: If the api call returns an OK status code,
+ this function returns the response in JSON. Otherwise,
+ we throw an AirflowException.
+ :rtype: dict
+ """
+ method, endpoint = endpoint_info
+ url = 'https://{host}/{endpoint}'.format(
+ host=self._parse_host(self.databricks_conn.host),
+ endpoint=endpoint)
+ auth = (self.databricks_conn.login, self.databricks_conn.password)
+ if method == 'GET':
+ request_func = requests.get
+ elif method == 'POST':
+ request_func = requests.post
+ else:
+ raise AirflowException('Unexpected HTTP Method: ' + method)
+
+ for attempt_num in range(1, self.retry_limit+1):
+ try:
+ response = request_func(
+ url,
+ json=json,
+ auth=auth,
+ headers=USER_AGENT_HEADER,
+ timeout=self.timeout_seconds)
+ if response.status_code == requests.codes.ok:
+ return response.json()
+ else:
+ # In this case, the user probably made a mistake.
+ # Don't retry.
+ raise AirflowException('Response: {0}, Status Code: {1}'.format(
+ response.content, response.status_code))
+ except (requests_exceptions.ConnectionError,
+ requests_exceptions.Timeout) as e:
+ logging.error(('Attempt {0} API Request to Databricks failed ' +
+ 'with reason: {1}').format(attempt_num, e))
+ raise AirflowException(('API requests to Databricks failed {} times. ' +
+ 'Giving up.').format(self.retry_limit))
+
+ def submit_run(self, json):
+ """
+ Utility function to call the ``api/2.0/jobs/runs/submit`` endpoint.
+
+ :param json: The data used in the body of the request to the ``submit`` endpoint.
+ :type json: dict
+ :return: the run_id as a string
+ :rtype: string
+ """
+ response = self._do_api_call(SUBMIT_RUN_ENDPOINT, json)
+ return response['run_id']
+
+ def get_run_page_url(self, run_id):
+ json = {'run_id': run_id}
+ response = self._do_api_call(GET_RUN_ENDPOINT, json)
+ return response['run_page_url']
+
+ def get_run_state(self, run_id):
+ json = {'run_id': run_id}
+ response = self._do_api_call(GET_RUN_ENDPOINT, json)
+ state = response['state']
+ life_cycle_state = state['life_cycle_state']
+ # result_state may not be in the state if not terminal
+ result_state = state.get('result_state', None)
+ state_message = state['state_message']
+ return RunState(life_cycle_state, result_state, state_message)
+
+ def cancel_run(self, run_id):
+ json = {'run_id': run_id}
+ self._do_api_call(CANCEL_RUN_ENDPOINT, json)
+
+
+RUN_LIFE_CYCLE_STATES = [
+ 'PENDING',
+ 'RUNNING',
+ 'TERMINATING',
+ 'TERMINATED',
+ 'SKIPPED',
+ 'INTERNAL_ERROR'
+]
+
+
+class RunState:
+ """
+ Utility class for the run state concept of Databricks runs.
+ """
+ def __init__(self, life_cycle_state, result_state, state_message):
+ self.life_cycle_state = life_cycle_state
+ self.result_state = result_state
+ self.state_message = state_message
+
+ @property
+ def is_terminal(self):
+ if self.life_cycle_state not in RUN_LIFE_CYCLE_STATES:
+ raise AirflowException(('Unexpected life cycle state: {}: If the state has '
+ 'been introduced recently, please check the Databricks user '
+ 'guide for troubleshooting information').format(
+ self.life_cycle_state))
+ return self.life_cycle_state in ('TERMINATED', 'SKIPPED', 'INTERNAL_ERROR')
+
+ @property
+ def is_successful(self):
+ return self.result_state == 'SUCCESS'
+
+ def __eq__(self, other):
+ return self.life_cycle_state == other.life_cycle_state and \
+ self.result_state == other.result_state and \
+ self.state_message == other.state_message
+
+ def __repr__(self):
+ return str(self.__dict__)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/contrib/operators/databricks_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py
new file mode 100644
index 0000000..46b1659
--- /dev/null
+++ b/airflow/contrib/operators/databricks_operator.py
@@ -0,0 +1,211 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import time
+
+from airflow.exceptions import AirflowException
+from airflow.contrib.hooks.databricks_hook import DatabricksHook
+from airflow.models import BaseOperator
+
+LINE_BREAK = ('-' * 80)
+
+
+class DatabricksSubmitRunOperator(BaseOperator):
+ """
+ Submits a Spark job run to Databricks using the
+ `api/2.0/jobs/runs/submit
+ <https://docs.databricks.com/api/latest/jobs.html#runs-submit>`_
+ API endpoint.
+
+ There are two ways to instantiate this operator.
+
+ In the first way, you can take the JSON payload that you typically use
+ to call the ``api/2.0/jobs/runs/submit`` endpoint and pass it directly
+ to our ``DatabricksSubmitRunOperator`` through the ``json`` parameter.
+ For example ::
+ json = {
+ 'new_cluster': {
+ 'spark_version': '2.1.0-db3-scala2.11',
+ 'num_workers': 2
+ },
+ 'notebook_task': {
+ 'notebook_path': '/Users/airflow@example.com/PrepareData',
+ },
+ }
+ notebook_run = DatabricksSubmitRunOperator(task_id='notebook_run', json=json)
+
+ Another way to accomplish the same thing is to use the named parameters
+ of the ``DatabricksSubmitRunOperator`` directly. Note that there is exactly
+ one named parameter for each top level parameter in the ``runs/submit``
+ endpoint. In this method, your code would look like this: ::
+ new_cluster = {
+ 'spark_version': '2.1.0-db3-scala2.11',
+ 'num_workers': 2
+ }
+ notebook_task = {
+ 'notebook_path': '/Users/airflow@example.com/PrepareData',
+ }
+ notebook_run = DatabricksSubmitRunOperator(
+ task_id='notebook_run',
+ new_cluster=new_cluster,
+ notebook_task=notebook_task)
+
+ In the case where both the json parameter **AND** the named parameters
+ are provided, they will be merged together. If there are conflicts during the merge,
+ the named parameters will take precedence and override the top level ``json`` keys.
+
+ Currently the named parameters that ``DatabricksSubmitRunOperator`` supports are
+ - ``spark_jar_task``
+ - ``notebook_task``
+ - ``new_cluster``
+ - ``existing_cluster_id``
+ - ``libraries``
+ - ``run_name``
+ - ``timeout_seconds``
+
+ :param json: A JSON object containing API parameters which will be passed
+ directly to the ``api/2.0/jobs/runs/submit`` endpoint. The other named parameters
+ (i.e. ``spark_jar_task``, ``notebook_task``..) to this operator will
+ be merged with this json dictionary if they are provided.
+ If there are conflicts during the merge, the named parameters will
+ take precedence and override the top level json keys.
+ https://docs.databricks.com/api/latest/jobs.html#runs-submit
+ :type json: dict
+ :param spark_jar_task: The main class and parameters for the JAR task. Note that
+ the actual JAR is specified in the ``libraries``.
+ *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` should be specified.
+ https://docs.databricks.com/api/latest/jobs.html#jobssparkjartask
+ :type spark_jar_task: dict
+ :param notebook_task: The notebook path and parameters for the notebook task.
+ *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` should be specified.
+ https://docs.databricks.com/api/latest/jobs.html#jobsnotebooktask
+ :type notebook_task: dict
+ :param new_cluster: Specs for a new cluster on which this task will be run.
+ *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified.
+ https://docs.databricks.com/api/latest/jobs.html#jobsclusterspecnewcluster
+ :type new_cluster: dict
+ :param existing_cluster_id: ID for existing cluster on which to run this task.
+ *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified.
+ :type existing_cluster_id: string
+ :param libraries: Libraries which this run will use.
+ https://docs.databricks.com/api/latest/libraries.html#managedlibrarieslibrary
+ :type libraries: list of dicts
+ :param run_name: The run name used for this task.
+ By default this will be set to the Airflow ``task_id``. This ``task_id`` is a
+ required parameter of the superclass ``BaseOperator``.
+ :type run_name: string
+ :param timeout_seconds: The timeout for this run. By default a value of 0 is used
+ which means to have no timeout.
+ :type timeout_seconds: int32
+ :param databricks_conn_id: The name of the Airflow connection to use.
+ By default and in the common case this will be ``databricks_default``.
+ :type databricks_conn_id: string
+ :param polling_period_seconds: Controls the rate at which we poll for the result of
+ this run. By default the operator will poll every 30 seconds.
+ :type polling_period_seconds: int
+ :param databricks_retry_limit: Number of times to retry if the Databricks backend is
+ unreachable. Its value must be greater than or equal to 1.
+ :type databricks_retry_limit: int
+ """
+ # Databricks brand color (blue) under white text
+ ui_color = '#1CB1C2'
+ ui_fgcolor = '#fff'
+
+ def __init__(
+ self,
+ json=None,
+ spark_jar_task=None,
+ notebook_task=None,
+ new_cluster=None,
+ existing_cluster_id=None,
+ libraries=None,
+ run_name=None,
+ timeout_seconds=None,
+ databricks_conn_id='databricks_default',
+ polling_period_seconds=30,
+ databricks_retry_limit=3,
+ **kwargs):
+ """
+ Creates a new ``DatabricksSubmitRunOperator``.
+ """
+ super(DatabricksSubmitRunOperator, self).__init__(**kwargs)
+ self.json = json or {}
+ self.databricks_conn_id = databricks_conn_id
+ self.polling_period_seconds = polling_period_seconds
+ self.databricks_retry_limit = databricks_retry_limit
+ if spark_jar_task is not None:
+ self.json['spark_jar_task'] = spark_jar_task
+ if notebook_task is not None:
+ self.json['notebook_task'] = notebook_task
+ if new_cluster is not None:
+ self.json['new_cluster'] = new_cluster
+ if existing_cluster_id is not None:
+ self.json['existing_cluster_id'] = existing_cluster_id
+ if libraries is not None:
+ self.json['libraries'] = libraries
+ if run_name is not None:
+ self.json['run_name'] = run_name
+ if timeout_seconds is not None:
+ self.json['timeout_seconds'] = timeout_seconds
+ if 'run_name' not in self.json:
+ self.json['run_name'] = run_name or kwargs['task_id']
+
+ # This variable will be used in case our task gets killed.
+ self.run_id = None
+
+ def _log_run_page_url(self, url):
+ logging.info('View run status, Spark UI, and logs at {}'.format(url))
+
+ def get_hook(self):
+ return DatabricksHook(
+ self.databricks_conn_id,
+ retry_limit=self.databricks_retry_limit)
+
+ def execute(self, context):
+ hook = self.get_hook()
+ self.run_id = hook.submit_run(self.json)
+ run_page_url = hook.get_run_page_url(self.run_id)
+ logging.info(LINE_BREAK)
+ logging.info('Run submitted with run_id: {}'.format(self.run_id))
+ self._log_run_page_url(run_page_url)
+ logging.info(LINE_BREAK)
+ while True:
+ run_state = hook.get_run_state(self.run_id)
+ if run_state.is_terminal:
+ if run_state.is_successful:
+ logging.info('{} completed successfully.'.format(
+ self.task_id))
+ self._log_run_page_url(run_page_url)
+ return
+ else:
+ error_message = '{t} failed with terminal state: {s}'.format(
+ t=self.task_id,
+ s=run_state)
+ raise AirflowException(error_message)
+ else:
+ logging.info('{t} in run state: {s}'.format(t=self.task_id,
+ s=run_state))
+ self._log_run_page_url(run_page_url)
+ logging.info('Sleeping for {} seconds.'.format(
+ self.polling_period_seconds))
+ time.sleep(self.polling_period_seconds)
+
+ def on_kill(self):
+ hook = self.get_hook()
+ hook.cancel_run(self.run_id)
+ logging.info('Task: {t} with run_id: {r} was requested to be cancelled.'.format(
+ t=self.task_id,
+ r=self.run_id))
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/exceptions.py
----------------------------------------------------------------------
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index 2231208..90d3e22 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -22,7 +22,7 @@ class AirflowException(Exception):
class AirflowConfigException(AirflowException):
pass
-
+
class AirflowSensorTimeout(AirflowException):
pass
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index 95e2255..42b621d 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -543,6 +543,7 @@ class Connection(Base):
('jira', 'JIRA',),
('redis', 'Redis',),
('wasb', 'Azure Blob Storage'),
+ ('databricks', 'Databricks',),
]
def __init__(
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/utils/db.py
----------------------------------------------------------------------
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 7da9217..54254f6 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -249,6 +249,10 @@ def initdb():
]
}
'''))
+ merge_conn(
+ models.Connection(
+ conn_id='databricks_default', conn_type='databricks',
+ host='localhost'))
# Known event types
KET = models.KnownEventType
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/docs/code.rst
----------------------------------------------------------------------
diff --git a/docs/code.rst b/docs/code.rst
index 683e85f..c31061c 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -97,6 +97,7 @@ Community-contributed Operators
.. autoclass:: airflow.contrib.operators.bigquery_operator.BigQueryOperator
.. autoclass:: airflow.contrib.operators.bigquery_to_gcs.BigQueryToCloudStorageOperator
+.. autoclass:: airflow.contrib.operators.databricks_operator.DatabricksSubmitRunOperator
.. autoclass:: airflow.contrib.operators.ecs_operator.ECSOperator
.. autoclass:: airflow.contrib.operators.file_to_wasb.FileToWasbOperator
.. autoclass:: airflow.contrib.operators.gcs_download_operator.GoogleCloudStorageDownloadOperator
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/docs/integration.rst
----------------------------------------------------------------------
diff --git a/docs/integration.rst b/docs/integration.rst
index 4a6b676..a6c9d7c 100644
--- a/docs/integration.rst
+++ b/docs/integration.rst
@@ -61,6 +61,19 @@ AWS: Amazon Webservices
---
+.. _Databricks:
+
+Databricks
+--------------------------
+`Databricks <https://databricks.com/>`_ has contributed an Airflow operator which enables
+submitting runs to the Databricks platform. Internally the operator talks to the
+``api/2.0/jobs/runs/submit`` `endpoint <https://docs.databricks.com/api/latest/jobs.html#runs-submit>`_.
+
+DatabricksSubmitRunOperator
+''''''''''''''''''''''''''''
+
+.. autoclass:: airflow.contrib.operators.databricks_operator.DatabricksSubmitRunOperator
+
.. _GCP:
GCP: Google Cloud Platform
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/setup.py
----------------------------------------------------------------------
diff --git a/setup.py b/setup.py
index ea60dca..6691208 100644
--- a/setup.py
+++ b/setup.py
@@ -116,6 +116,7 @@ crypto = ['cryptography>=0.9.3']
dask = [
'distributed>=1.15.2, <2'
]
+databricks = ['requests>=2.5.1, <3']
datadog = ['datadog>=0.14.0']
doc = [
'sphinx>=1.2.3',
@@ -244,6 +245,7 @@ def do_setup():
'cloudant': cloudant,
'crypto': crypto,
'dask': dask,
+ 'databricks': databricks,
'datadog': datadog,
'devel': devel_minreq,
'devel_hadoop': devel_hadoop,
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/tests/contrib/hooks/databricks_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/databricks_hook.py b/tests/contrib/hooks/databricks_hook.py
new file mode 100644
index 0000000..6c789f9
--- /dev/null
+++ b/tests/contrib/hooks/databricks_hook.py
@@ -0,0 +1,226 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from airflow import __version__
+from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT
+from airflow.exceptions import AirflowException
+from airflow.models import Connection
+from airflow.utils import db
+from requests import exceptions as requests_exceptions
+
+try:
+ from unittest import mock
+except ImportError:
+ try:
+ import mock
+ except ImportError:
+ mock = None
+
+TASK_ID = 'databricks-operator'
+DEFAULT_CONN_ID = 'databricks_default'
+NOTEBOOK_TASK = {
+ 'notebook_path': '/test'
+}
+NEW_CLUSTER = {
+ 'spark_version': '2.0.x-scala2.10',
+ 'node_type_id': 'r3.xlarge',
+ 'num_workers': 1
+}
+RUN_ID = 1
+HOST = 'xx.cloud.databricks.com'
+HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'
+LOGIN = 'login'
+PASSWORD = 'password'
+USER_AGENT_HEADER = {'user-agent': 'airflow-{v}'.format(v=__version__)}
+RUN_PAGE_URL = 'https://XX.cloud.databricks.com/#jobs/1/runs/1'
+LIFE_CYCLE_STATE = 'PENDING'
+STATE_MESSAGE = 'Waiting for cluster'
+GET_RUN_RESPONSE = {
+ 'run_page_url': RUN_PAGE_URL,
+ 'state': {
+ 'life_cycle_state': LIFE_CYCLE_STATE,
+ 'state_message': STATE_MESSAGE
+ }
+}
+RESULT_STATE = None
+
+
+def submit_run_endpoint(host):
+ """
+ Utility function to generate the submit run endpoint given the host.
+ """
+ return 'https://{}/api/2.0/jobs/runs/submit'.format(host)
+
+
+def get_run_endpoint(host):
+ """
+ Utility function to generate the get run endpoint given the host.
+ """
+ return 'https://{}/api/2.0/jobs/runs/get'.format(host)
+
+def cancel_run_endpoint(host):
+ """
+ Utility function to generate the cancel run endpoint given the host.
+ """
+ return 'https://{}/api/2.0/jobs/runs/cancel'.format(host)
+
+class DatabricksHookTest(unittest.TestCase):
+ """
+ Tests for DatabricksHook.
+ """
+ @db.provide_session
+ def setUp(self, session=None):
+ conn = session.query(Connection) \
+ .filter(Connection.conn_id == DEFAULT_CONN_ID) \
+ .first()
+ conn.host = HOST
+ conn.login = LOGIN
+ conn.password = PASSWORD
+ session.commit()
+
+ self.hook = DatabricksHook()
+
+ def test_parse_host_with_proper_host(self):
+ host = self.hook._parse_host(HOST)
+ self.assertEquals(host, HOST)
+
+ def test_parse_host_with_scheme(self):
+ host = self.hook._parse_host(HOST_WITH_SCHEME)
+ self.assertEquals(host, HOST)
+
+ def test_init_bad_retry_limit(self):
+ with self.assertRaises(AssertionError):
+ DatabricksHook(retry_limit = 0)
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.logging')
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_do_api_call_with_error_retry(self, mock_requests, mock_logging):
+ for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]:
+ mock_requests.reset_mock()
+ mock_logging.reset_mock()
+ mock_requests.post.side_effect = exception()
+
+ with self.assertRaises(AirflowException):
+ self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+
+ self.assertEquals(len(mock_logging.error.mock_calls), self.hook.retry_limit)
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_do_api_call_with_bad_status_code(self, mock_requests):
+ mock_requests.codes.ok = 200
+ status_code_mock = mock.PropertyMock(return_value=500)
+ type(mock_requests.post.return_value).status_code = status_code_mock
+ with self.assertRaises(AirflowException):
+ self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_submit_run(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.post.return_value.json.return_value = {'run_id': '1'}
+ status_code_mock = mock.PropertyMock(return_value=200)
+ type(mock_requests.post.return_value).status_code = status_code_mock
+ json = {
+ 'notebook_task': NOTEBOOK_TASK,
+ 'new_cluster': NEW_CLUSTER
+ }
+ run_id = self.hook.submit_run(json)
+
+ self.assertEquals(run_id, '1')
+ mock_requests.post.assert_called_once_with(
+ submit_run_endpoint(HOST),
+ json={
+ 'notebook_task': NOTEBOOK_TASK,
+ 'new_cluster': NEW_CLUSTER,
+ },
+ auth=(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds)
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_get_run_page_url(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
+ status_code_mock = mock.PropertyMock(return_value=200)
+ type(mock_requests.get.return_value).status_code = status_code_mock
+
+ run_page_url = self.hook.get_run_page_url(RUN_ID)
+
+ self.assertEquals(run_page_url, RUN_PAGE_URL)
+ mock_requests.get.assert_called_once_with(
+ get_run_endpoint(HOST),
+ json={'run_id': RUN_ID},
+ auth=(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds)
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_get_run_state(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
+ status_code_mock = mock.PropertyMock(return_value=200)
+ type(mock_requests.get.return_value).status_code = status_code_mock
+
+ run_state = self.hook.get_run_state(RUN_ID)
+
+ self.assertEquals(run_state, RunState(
+ LIFE_CYCLE_STATE,
+ RESULT_STATE,
+ STATE_MESSAGE))
+ mock_requests.get.assert_called_once_with(
+ get_run_endpoint(HOST),
+ json={'run_id': RUN_ID},
+ auth=(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds)
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_cancel_run(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE
+ status_code_mock = mock.PropertyMock(return_value=200)
+ type(mock_requests.post.return_value).status_code = status_code_mock
+
+ self.hook.cancel_run(RUN_ID)
+
+ mock_requests.post.assert_called_once_with(
+ cancel_run_endpoint(HOST),
+ json={'run_id': RUN_ID},
+ auth=(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds)
+
+class RunStateTest(unittest.TestCase):
+ def test_is_terminal_true(self):
+ terminal_states = ['TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']
+ for state in terminal_states:
+ run_state = RunState(state, '', '')
+ self.assertTrue(run_state.is_terminal)
+
+ def test_is_terminal_false(self):
+ non_terminal_states = ['PENDING', 'RUNNING', 'TERMINATING']
+ for state in non_terminal_states:
+ run_state = RunState(state, '', '')
+ self.assertFalse(run_state.is_terminal)
+
+ def test_is_terminal_with_nonexistent_life_cycle_state(self):
+ run_state = RunState('blah', '', '')
+ with self.assertRaises(AirflowException):
+ run_state.is_terminal
+
+ def test_is_successful(self):
+ run_state = RunState('TERMINATED', 'SUCCESS', '')
+ self.assertTrue(run_state.is_successful)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/tests/contrib/operators/databricks_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/databricks_operator.py b/tests/contrib/operators/databricks_operator.py
new file mode 100644
index 0000000..aab47fa
--- /dev/null
+++ b/tests/contrib/operators/databricks_operator.py
@@ -0,0 +1,185 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from airflow.contrib.hooks.databricks_hook import RunState
+from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator
+from airflow.exceptions import AirflowException
+
+try:
+ from unittest import mock
+except ImportError:
+ try:
+ import mock
+ except ImportError:
+ mock = None
+
+# Shared fixtures for the DatabricksSubmitRunOperator tests below.
+TASK_ID = 'databricks-operator'
+# Connection id the execute tests expect to be passed to DatabricksHook.
+DEFAULT_CONN_ID = 'databricks_default'
+NOTEBOOK_TASK = {'notebook_path': '/test'}
+SPARK_JAR_TASK = {'main_class_name': 'com.databricks.Test'}
+NEW_CLUSTER = {
+    'spark_version': '2.0.x-scala2.10',
+    'node_type_id': 'development-node',
+    'num_workers': 1,
+}
+EXISTING_CLUSTER_ID = 'existing-cluster-id'
+# Explicit run_name used to check that it is not replaced by the task id.
+RUN_NAME = 'run-name'
+# Run id asserted by the execute and on_kill tests.
+RUN_ID = 1
+
+
+class DatabricksSubmitRunOperatorTest(unittest.TestCase):
+    """Tests for DatabricksSubmitRunOperator: json assembly, execute, on_kill."""
+
+    def test_init_with_named_parameters(self):
+        """
+        Test the initializer with the named parameters.
+        """
+        op = DatabricksSubmitRunOperator(
+            task_id=TASK_ID, new_cluster=NEW_CLUSTER, notebook_task=NOTEBOOK_TASK)
+        expected = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': TASK_ID
+        }
+        self.assertDictEqual(expected, op.json)
+
+    def test_init_with_json(self):
+        """
+        Test the initializer with json data.
+        """
+        json = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK
+        }
+        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
+        expected = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': TASK_ID
+        }
+        self.assertDictEqual(expected, op.json)
+
+    def test_init_with_specified_run_name(self):
+        """
+        Test the initializer with a specified run_name.
+        """
+        json = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': RUN_NAME
+        }
+        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
+        expected = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': RUN_NAME
+        }
+        self.assertDictEqual(expected, op.json)
+
+    def test_init_with_merging(self):
+        """
+        Test the initializer when json and other named parameters are both
+        provided. The named parameters should override top level keys in the
+        json dict.
+        """
+        override_new_cluster = {'workers': 999}
+        json = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+        }
+        op = DatabricksSubmitRunOperator(
+            task_id=TASK_ID, json=json, new_cluster=override_new_cluster)
+        # Only new_cluster is overridden; notebook_task still comes from json.
+        expected = {
+            'new_cluster': override_new_cluster,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': TASK_ID,
+        }
+        self.assertDictEqual(expected, op.json)
+
+    @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
+    def test_exec_success(self, db_mock_class):
+        """
+        Test the execute function in case where the run is successful.
+        """
+        run = {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK}
+        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
+        # Stub the hook: the run gets id RUN_ID and is already TERMINATED/SUCCESS.
+        db_mock = db_mock_class.return_value
+        db_mock.submit_run.return_value = RUN_ID
+        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
+
+        op.execute(None)
+
+        expected = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': TASK_ID
+        }
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit)
+        db_mock.submit_run.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        db_mock.get_run_state.assert_called_once_with(RUN_ID)
+        self.assertEqual(RUN_ID, op.run_id)
+
+    @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
+    def test_exec_failure(self, db_mock_class):
+        """
+        Test the execute function in case where the run failed.
+        """
+        run = {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK}
+        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
+        # Stub the hook: the run gets id RUN_ID and is already TERMINATED/FAILED.
+        db_mock = db_mock_class.return_value
+        db_mock.submit_run.return_value = RUN_ID
+        db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '')
+
+        with self.assertRaises(AirflowException):
+            op.execute(None)
+
+        expected = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': TASK_ID,
+        }
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit)
+        db_mock.submit_run.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        db_mock.get_run_state.assert_called_once_with(RUN_ID)
+        self.assertEqual(RUN_ID, op.run_id)
+
+    @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
+    def test_on_kill(self, db_mock_class):
+        """
+        Test that on_kill cancels the run whose id is stored on the operator.
+        """
+        run = {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK}
+        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
+        db_mock = db_mock_class.return_value
+        # execute() is not called here, so the run id is set by hand.
+        op.run_id = RUN_ID
+
+        op.on_kill()
+
+        db_mock.cancel_run.assert_called_once_with(RUN_ID)
+