Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2018/08/31 07:33:24 UTC

[GitHub] Fokko closed pull request #3817: [AIRFLOW-2974] Extended Databricks hook with cluster operation

Fokko closed pull request #3817: [AIRFLOW-2974] Extended Databricks hook with cluster operation
URL: https://github.com/apache/incubator-airflow/pull/3817

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py
index 5b97a0eba0..cb2ba9bd00 100644
--- a/airflow/contrib/hooks/databricks_hook.py
+++ b/airflow/contrib/hooks/databricks_hook.py
@@ -33,6 +33,9 @@
 except ImportError:
     import urlparse
 
+RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart")
+START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start")
+TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete")
 
 SUBMIT_RUN_ENDPOINT = ('POST', 'api/2.0/jobs/runs/submit')
 GET_RUN_ENDPOINT = ('GET', 'api/2.0/jobs/runs/get')
@@ -189,6 +192,15 @@ def cancel_run(self, run_id):
         json = {'run_id': run_id}
         self._do_api_call(CANCEL_RUN_ENDPOINT, json)
 
+    def restart_cluster(self, json):
+        self._do_api_call(RESTART_CLUSTER_ENDPOINT, json)
+
+    def start_cluster(self, json):
+        self._do_api_call(START_CLUSTER_ENDPOINT, json)
+
+    def terminate_cluster(self, json):
+        self._do_api_call(TERMINATE_CLUSTER_ENDPOINT, json)
+
 
 def _retryable_error(exception):
     return isinstance(exception, requests_exceptions.ConnectionError) \
diff --git a/setup.cfg b/setup.cfg
index 622cc1303a..881fe0107d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 [metadata]
 name = Airflow
 summary = Airflow is a system to programmatically author, schedule and monitor data pipelines.
@@ -34,4 +35,3 @@ all_files = 1
 upload-dir = docs/_build/html
 
 [easy_install]
-
diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py
index a022431899..04a7c8dc3c 100644
--- a/tests/contrib/hooks/test_databricks_hook.py
+++ b/tests/contrib/hooks/test_databricks_hook.py
@@ -52,6 +52,7 @@
     'node_type_id': 'r3.xlarge',
     'num_workers': 1
 }
+CLUSTER_ID = 'cluster_id'
 RUN_ID = 1
 HOST = 'xx.cloud.databricks.com'
 HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'
@@ -93,6 +94,26 @@ def cancel_run_endpoint(host):
     return 'https://{}/api/2.0/jobs/runs/cancel'.format(host)
 
 
+def start_cluster_endpoint(host):
+    """
+    Utility function to generate the start cluster endpoint given the host.
+    """
+    return 'https://{}/api/2.0/clusters/start'.format(host)
+
+
+def restart_cluster_endpoint(host):
+    """
+    Utility function to generate the restart cluster endpoint given the host.
+    """
+    return 'https://{}/api/2.0/clusters/restart'.format(host)
+
+
+def terminate_cluster_endpoint(host):
+    """
+    Utility function to generate the terminate cluster endpoint given the host.
+    """
+    return 'https://{}/api/2.0/clusters/delete'.format(host)
+
 def create_valid_response_mock(content):
     response = mock.MagicMock()
     response.json.return_value = content
@@ -293,6 +314,54 @@ def test_cancel_run(self, mock_requests):
             headers=USER_AGENT_HEADER,
             timeout=self.hook.timeout_seconds)
 
+    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+    def test_start_cluster(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.post.return_value.json.return_value = {}
+        status_code_mock = mock.PropertyMock(return_value=200)
+        type(mock_requests.post.return_value).status_code = status_code_mock
+
+        self.hook.start_cluster({"cluster_id": CLUSTER_ID})
+
+        mock_requests.post.assert_called_once_with(
+            start_cluster_endpoint(HOST),
+            json={'cluster_id': CLUSTER_ID},
+            auth=(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds)
+
+    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+    def test_restart_cluster(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.post.return_value.json.return_value = {}
+        status_code_mock = mock.PropertyMock(return_value=200)
+        type(mock_requests.post.return_value).status_code = status_code_mock
+
+        self.hook.restart_cluster({"cluster_id": CLUSTER_ID})
+
+        mock_requests.post.assert_called_once_with(
+            restart_cluster_endpoint(HOST),
+            json={'cluster_id': CLUSTER_ID},
+            auth=(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds)
+
+    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+    def test_terminate_cluster(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.post.return_value.json.return_value = {}
+        status_code_mock = mock.PropertyMock(return_value=200)
+        type(mock_requests.post.return_value).status_code = status_code_mock
+
+        self.hook.terminate_cluster({"cluster_id": CLUSTER_ID})
+
+        mock_requests.post.assert_called_once_with(
+            terminate_cluster_endpoint(HOST),
+            json={'cluster_id': CLUSTER_ID},
+            auth=(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds)
+
 
 class DatabricksHookTokenTest(unittest.TestCase):
     """


 

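For context, the three new hook methods above each forward the JSON payload they
are given to the matching Databricks clusters REST endpoint through the hook's
_do_api_call helper. A minimal usage sketch, not part of the PR itself (the
connection id and cluster id below are placeholders), could look like:

    from airflow.contrib.hooks.databricks_hook import DatabricksHook

    # Assumes a Databricks connection is configured in Airflow;
    # 'databricks_default' is the hook's default connection id.
    hook = DatabricksHook(databricks_conn_id='databricks_default')

    # Placeholder cluster id; each call issues a POST to the
    # corresponding api/2.0/clusters/* endpoint.
    payload = {'cluster_id': '1234-567890-abc123'}

    hook.start_cluster(payload)      # api/2.0/clusters/start
    hook.restart_cluster(payload)    # api/2.0/clusters/restart
    hook.terminate_cluster(payload)  # api/2.0/clusters/delete

As with the existing run methods, failures surface as an AirflowException once
the retry logic in _do_api_call is exhausted.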
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services