You are viewing a plain-text version of this content; the link to the canonical version was a hyperlink that did not survive the plain-text conversion.
Posted to commits@airflow.apache.org by cr...@apache.org on 2017/09/08 18:24:29 UTC
incubator-airflow git commit: [AIRFLOW-1577] Add token support to DatabricksHook
Repository: incubator-airflow
Updated Branches:
refs/heads/master ea9ab96cb -> c2c51518e
[AIRFLOW-1577] Add token support to DatabricksHook
Closes #2579 from andrewmchen/token
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/c2c51518
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/c2c51518
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/c2c51518
Branch: refs/heads/master
Commit: c2c51518e8fc7ead22dc8c007289981e805827cf
Parents: ea9ab96
Author: Andrew Chen <an...@gmail.com>
Authored: Fri Sep 8 11:24:14 2017 -0700
Committer: Chris Riccomini <cr...@apache.org>
Committed: Fri Sep 8 11:24:14 2017 -0700
----------------------------------------------------------------------
airflow/contrib/hooks/databricks_hook.py | 21 ++++++++++-
.../contrib/operators/databricks_operator.py | 4 ++-
tests/contrib/hooks/test_databricks_hook.py | 37 +++++++++++++++++++-
3 files changed, 59 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c2c51518/airflow/contrib/hooks/databricks_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py
index 0cd5d0f..18e20c4 100644
--- a/airflow/contrib/hooks/databricks_hook.py
+++ b/airflow/contrib/hooks/databricks_hook.py
@@ -20,6 +20,7 @@ from airflow import __version__
from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
from requests import exceptions as requests_exceptions
+from requests.auth import AuthBase
try:
@@ -99,7 +100,12 @@ class DatabricksHook(BaseHook):
url = 'https://{host}/{endpoint}'.format(
host=self._parse_host(self.databricks_conn.host),
endpoint=endpoint)
- auth = (self.databricks_conn.login, self.databricks_conn.password)
+ if 'token' in self.databricks_conn.extra_dejson:
+ logging.info('Using token auth.')
+ auth = _TokenAuth(self.databricks_conn.extra_dejson['token'])
+ else:
+ logging.info('Using basic auth.')
+ auth = (self.databricks_conn.login, self.databricks_conn.password)
if method == 'GET':
request_func = requests.get
elif method == 'POST':
@@ -200,3 +206,16 @@ class RunState:
def __repr__(self):
return str(self.__dict__)
+
+
+class _TokenAuth(AuthBase):
+ """
+ Helper class for requests Auth field. AuthBase requires you to implement the __call__
+ magic function.
+ """
+ def __init__(self, token):
+ self.token = token
+
+ def __call__(self, r):
+ r.headers['Authorization'] = 'Bearer ' + self.token
+ return r
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c2c51518/airflow/contrib/operators/databricks_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py
index 9c995df..1aa1441 100644
--- a/airflow/contrib/operators/databricks_operator.py
+++ b/airflow/contrib/operators/databricks_operator.py
@@ -131,7 +131,9 @@ class DatabricksSubmitRunOperator(BaseOperator):
This field will be templated.
:type timeout_seconds: int32
:param databricks_conn_id: The name of the Airflow connection to use.
- By default and in the common case this will be ``databricks_default``.
+ By default and in the common case this will be ``databricks_default``. To use
+ token based authentication, provide the key ``token`` in the extra field for the
+ connection.
:type databricks_conn_id: string
:param polling_period_seconds: Controls the rate which we poll for the result of
this run. By default the operator will poll every 30 seconds.
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c2c51518/tests/contrib/hooks/test_databricks_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py
index 6c789f9..56288a1 100644
--- a/tests/contrib/hooks/test_databricks_hook.py
+++ b/tests/contrib/hooks/test_databricks_hook.py
@@ -13,10 +13,11 @@
# limitations under the License.
#
+import json
import unittest
from airflow import __version__
-from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT
+from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT, _TokenAuth
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.utils import db
@@ -45,6 +46,7 @@ HOST = 'xx.cloud.databricks.com'
HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'
LOGIN = 'login'
PASSWORD = 'password'
+TOKEN = 'token'
USER_AGENT_HEADER = {'user-agent': 'airflow-{v}'.format(v=__version__)}
RUN_PAGE_URL = 'https://XX.cloud.databricks.com/#jobs/1/runs/1'
LIFE_CYCLE_STATE = 'PENDING'
@@ -203,6 +205,39 @@ class DatabricksHookTest(unittest.TestCase):
headers=USER_AGENT_HEADER,
timeout=self.hook.timeout_seconds)
+
+class DatabricksHookTokenTest(unittest.TestCase):
+ """
+ Tests for DatabricksHook when auth is done with token.
+ """
+ @db.provide_session
+ def setUp(self, session=None):
+ conn = session.query(Connection) \
+ .filter(Connection.conn_id == DEFAULT_CONN_ID) \
+ .first()
+ conn.extra = json.dumps({'token': TOKEN})
+ session.commit()
+
+ self.hook = DatabricksHook()
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_submit_run(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.post.return_value.json.return_value = {'run_id': '1'}
+ status_code_mock = mock.PropertyMock(return_value=200)
+ type(mock_requests.post.return_value).status_code = status_code_mock
+ json = {
+ 'notebook_task': NOTEBOOK_TASK,
+ 'new_cluster': NEW_CLUSTER
+ }
+ run_id = self.hook.submit_run(json)
+
+ self.assertEquals(run_id, '1')
+ args = mock_requests.post.call_args
+ kwargs = args[1]
+ self.assertEquals(kwargs['auth'].token, TOKEN)
+
+
class RunStateTest(unittest.TestCase):
def test_is_terminal_true(self):
terminal_states = ['TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']