Posted to commits@airflow.apache.org by bo...@apache.org on 2017/04/17 08:05:10 UTC
[3/3] incubator-airflow git commit: [AIRFLOW-1094] Run unit tests under contrib in Travis
[AIRFLOW-1094] Run unit tests under contrib in Travis

Rename all unit tests under tests/contrib to start with test_* and fix
broken unit tests so that they run for the Python 2 and 3 builds.
Closes #2234 from hgrif/AIRFLOW-1094
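The rename matters because nose-style collectors only pick up modules whose names look like tests; anything else in tests/contrib was silently skipped. A rough sketch of that filter, assuming a pattern close to nose's default testMatch regex (illustrative, not the exact expression the CI uses):

import re

# Approximation of nose's default testMatch; modules that do not match
# are never collected, which is why tests/contrib was not running.
TEST_MATCH = re.compile(r'(?:^|[_\.-])[Tt]est')

for name in ('aws_hook.py', 'test_aws_hook.py'):
    status = 'collected' if TEST_MATCH.search(name) else 'skipped'
    print('{0} -> {1}'.format(name, status))
# aws_hook.py -> skipped
# test_aws_hook.py -> collected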
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/219c5064
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/219c5064
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/219c5064
Branch: refs/heads/master
Commit: 219c5064142c66cf8f051455199f2dda9b164584
Parents: 74c1ce2
Author: Henk Griffioen <hg...@users.noreply.github.com>
Authored: Mon Apr 17 10:04:29 2017 +0200
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Mon Apr 17 10:04:36 2017 +0200
----------------------------------------------------------------------
airflow/contrib/operators/ecs_operator.py | 2 +-
airflow/hooks/__init__.py | 1 +
airflow/hooks/zendesk_hook.py | 2 +-
scripts/ci/requirements.txt | 5 +
tests/contrib/hooks/aws_hook.py | 47 ----
tests/contrib/hooks/bigquery_hook.py | 139 ----------
tests/contrib/hooks/databricks_hook.py | 226 -----------------
tests/contrib/hooks/emr_hook.py | 53 ----
tests/contrib/hooks/gcp_dataflow_hook.py | 56 -----
tests/contrib/hooks/spark_submit_hook.py | 185 --------------
tests/contrib/hooks/sqoop_hook.py | 219 ----------------
tests/contrib/hooks/test_aws_hook.py | 47 ++++
tests/contrib/hooks/test_bigquery_hook.py | 139 ++++++++++
tests/contrib/hooks/test_databricks_hook.py | 226 +++++++++++++++++
tests/contrib/hooks/test_emr_hook.py | 53 ++++
tests/contrib/hooks/test_gcp_dataflow_hook.py | 56 +++++
tests/contrib/hooks/test_spark_submit_hook.py | 197 +++++++++++++++
tests/contrib/hooks/test_sqoop_hook.py | 218 ++++++++++++++++
tests/contrib/hooks/test_zendesk_hook.py | 89 +++++++
tests/contrib/hooks/zendesk_hook.py | 90 -------
tests/contrib/operators/__init__.py | 3 -
tests/contrib/operators/databricks_operator.py | 185 --------------
tests/contrib/operators/dataflow_operator.py | 82 ------
tests/contrib/operators/ecs_operator.py | 207 ---------------
.../contrib/operators/emr_add_steps_operator.py | 53 ----
.../operators/emr_create_job_flow_operator.py | 53 ----
.../emr_terminate_job_flow_operator.py | 52 ----
tests/contrib/operators/fs_operator.py | 64 -----
tests/contrib/operators/hipchat_operator.py | 74 ------
tests/contrib/operators/jira_operator_test.py | 101 --------
.../contrib/operators/spark_submit_operator.py | 81 ------
tests/contrib/operators/sqoop_operator.py | 93 -------
tests/contrib/operators/ssh_execute_operator.py | 79 ------
.../operators/test_databricks_operator.py | 185 ++++++++++++++
.../contrib/operators/test_dataflow_operator.py | 81 ++++++
tests/contrib/operators/test_ecs_operator.py | 214 ++++++++++++++++
.../operators/test_emr_add_steps_operator.py | 53 ++++
.../test_emr_create_job_flow_operator.py | 53 ++++
.../test_emr_terminate_job_flow_operator.py | 52 ++++
tests/contrib/operators/test_fs_operator.py | 64 +++++
.../contrib/operators/test_hipchat_operator.py | 74 ++++++
.../operators/test_jira_operator_test.py | 101 ++++++++
.../operators/test_spark_submit_operator.py | 88 +++++++
tests/contrib/operators/test_sqoop_operator.py | 93 +++++++
.../operators/test_ssh_execute_operator.py | 95 +++++++
tests/contrib/sensors/datadog_sensor.py | 91 -------
tests/contrib/sensors/emr_base_sensor.py | 126 ----------
tests/contrib/sensors/emr_job_flow_sensor.py | 123 ---------
tests/contrib/sensors/emr_step_sensor.py | 119 ---------
tests/contrib/sensors/ftp_sensor.py | 66 -----
tests/contrib/sensors/hdfs_sensors.py | 251 -------------------
tests/contrib/sensors/jira_sensor_test.py | 85 -------
tests/contrib/sensors/redis_sensor.py | 64 -----
tests/contrib/sensors/test_datadog_sensor.py | 106 ++++++++
tests/contrib/sensors/test_emr_base_sensor.py | 126 ++++++++++
.../contrib/sensors/test_emr_job_flow_sensor.py | 123 +++++++++
tests/contrib/sensors/test_emr_step_sensor.py | 119 +++++++++
tests/contrib/sensors/test_ftp_sensor.py | 66 +++++
tests/contrib/sensors/test_hdfs_sensors.py | 251 +++++++++++++++++++
tests/contrib/sensors/test_jira_sensor_test.py | 85 +++++++
tests/contrib/sensors/test_redis_sensor.py | 64 +++++
61 files changed, 3126 insertions(+), 3069 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/airflow/contrib/operators/ecs_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/ecs_operator.py b/airflow/contrib/operators/ecs_operator.py
index df02c4e..11f8c94 100644
--- a/airflow/contrib/operators/ecs_operator.py
+++ b/airflow/contrib/operators/ecs_operator.py
@@ -89,7 +89,7 @@ class ECSOperator(BaseOperator):
def _wait_for_task_ended(self):
waiter = self.client.get_waiter('tasks_stopped')
- waiter.config.max_attempts = sys.maxint # timeout is managed by airflow
+ waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow
waiter.wait(
cluster=self.cluster,
tasks=[self.arn]
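This one-character hunk is the Python 3 fix: sys.maxint was removed in Python 3, while sys.maxsize exists on both lines and is just as serviceable as an effectively unbounded waiter limit. A minimal compatibility sketch:

import sys

# sys.maxint is Python 2 only; sys.maxsize (Python 2.6+ and 3) works
# everywhere and is large enough that the waiter never gives up on its
# own, leaving the timeout to Airflow as the comment in the hunk says.
try:
    max_attempts = sys.maxint
except AttributeError:
    max_attempts = sys.maxsize
print(max_attempts >= 2 ** 31 - 1)  # True on any supported platform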
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/airflow/hooks/__init__.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/__init__.py b/airflow/hooks/__init__.py
index cc09f5a..bb02967 100644
--- a/airflow/hooks/__init__.py
+++ b/airflow/hooks/__init__.py
@@ -48,6 +48,7 @@ _hooks = {
'samba_hook': ['SambaHook'],
'sqlite_hook': ['SqliteHook'],
'S3_hook': ['S3Hook'],
+ 'zendesk_hook': ['ZendeskHook'],
'http_hook': ['HttpHook'],
'druid_hook': ['DruidHook'],
'jdbc_hook': ['JdbcHook'],
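The new registry entry is what makes "from airflow.hooks import ZendeskHook" resolve: _hooks maps module names to the class names they export, and the package imports each module lazily on first use. A simplified sketch of that pattern (the loader below is illustrative, not Airflow's exact implementation):

import importlib

_hooks = {
    'zendesk_hook': ['ZendeskHook'],
    'http_hook': ['HttpHook'],
}

def _load_hook(class_name, package='airflow.hooks'):
    # Find the module that advertises the requested class, import it
    # on demand, and hand back the attribute.
    for module_name, exported in _hooks.items():
        if class_name in exported:
            module = importlib.import_module(package + '.' + module_name)
            return getattr(module, class_name)
    raise ImportError('No hook registered for ' + class_name)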
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/airflow/hooks/zendesk_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/zendesk_hook.py b/airflow/hooks/zendesk_hook.py
index 438597f..907d1e8 100644
--- a/airflow/hooks/zendesk_hook.py
+++ b/airflow/hooks/zendesk_hook.py
@@ -21,7 +21,7 @@ A hook to talk to Zendesk
import logging
import time
from zdesk import Zendesk, RateLimitError, ZendeskError
-from airflow.hooks import BaseHook
+from airflow.hooks.base_hook import BaseHook
class ZendeskHook(BaseHook):
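The import fix stops going through the package __init__, which resolves names via the lazy registry above and can misbehave while airflow.hooks is still initializing; importing the defining module directly is safe at import time. In sketch form (requires Airflow on the path):

# Fragile: resolved through airflow/hooks/__init__.py and its registry,
# which now also references zendesk_hook itself.
# from airflow.hooks import BaseHook

# Robust: imports the concrete module that defines the class.
from airflow.hooks.base_hook import BaseHook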
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/scripts/ci/requirements.txt
----------------------------------------------------------------------
diff --git a/scripts/ci/requirements.txt b/scripts/ci/requirements.txt
index 1905398..751c13f 100644
--- a/scripts/ci/requirements.txt
+++ b/scripts/ci/requirements.txt
@@ -3,6 +3,7 @@ azure-storage>=0.34.0
bcrypt
bleach
boto
+boto3
celery
cgroupspy
chartkick
@@ -11,6 +12,7 @@ coverage
coveralls
croniter
cryptography
+datadog
dill
distributed
docker-py
@@ -25,6 +27,7 @@ Flask-WTF
flower
freezegun
future
+google-api-python-client>=1.5.0,<1.6.0
gunicorn
hdfs
hive-thrift-py
@@ -37,6 +40,7 @@ ldap3
lxml
markdown
mock
+moto
mysqlclient
nose
nose-exclude
@@ -69,3 +73,4 @@ statsd
thrift
thrift_sasl
unicodecsv
+zdesk
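Each CI pin backs one of the renamed suites: boto3 and moto for the AWS and EMR hooks, datadog for the Datadog sensor, google-api-python-client for the BigQuery and Dataflow hooks, and zdesk for the Zendesk hook. The test files guard the optional packages behind a try/except so a missing dependency skips the test instead of failing the import, as the files below do. A minimal sketch of that pattern:

import unittest

try:
    from moto import mock_emr   # optional CI dependency
except ImportError:
    mock_emr = None

class TestGuardedImport(unittest.TestCase):
    @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
    def test_only_runs_when_moto_is_installed(self):
        self.assertIsNotNone(mock_emr)

if __name__ == '__main__':
    unittest.main()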
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/aws_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/aws_hook.py b/tests/contrib/hooks/aws_hook.py
deleted file mode 100644
index 6f13e58..0000000
--- a/tests/contrib/hooks/aws_hook.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-import boto3
-
-from airflow import configuration
-from airflow.contrib.hooks.aws_hook import AwsHook
-
-
-try:
- from moto import mock_emr
-except ImportError:
- mock_emr = None
-
-
-class TestAwsHook(unittest.TestCase):
- @mock_emr
- def setUp(self):
- configuration.load_test_config()
-
- @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
- @mock_emr
- def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
- client = boto3.client('emr', region_name='us-east-1')
- if len(client.list_clusters()['Clusters']):
- raise ValueError('AWS not properly mocked')
-
- hook = AwsHook(aws_conn_id='aws_default')
- client_from_hook = hook.get_client_type('emr')
-
- self.assertEqual(client_from_hook.list_clusters()['Clusters'], [])
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/bigquery_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/bigquery_hook.py b/tests/contrib/hooks/bigquery_hook.py
deleted file mode 100644
index 68856f8..0000000
--- a/tests/contrib/hooks/bigquery_hook.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-
-from airflow.contrib.hooks import bigquery_hook as hook
-
-
-class TestBigQueryTableSplitter(unittest.TestCase):
- def test_internal_need_default_project(self):
- with self.assertRaises(Exception) as context:
- hook._split_tablename('dataset.table', None)
-
- self.assertIn('INTERNAL: No default project is specified',
- str(context.exception), "")
-
- def test_split_dataset_table(self):
- project, dataset, table = hook._split_tablename('dataset.table',
- 'project')
- self.assertEqual("project", project)
- self.assertEqual("dataset", dataset)
- self.assertEqual("table", table)
-
- def test_split_project_dataset_table(self):
- project, dataset, table = hook._split_tablename('alternative:dataset.table',
- 'project')
- self.assertEqual("alternative", project)
- self.assertEqual("dataset", dataset)
- self.assertEqual("table", table)
-
- def test_sql_split_project_dataset_table(self):
- project, dataset, table = hook._split_tablename('alternative.dataset.table',
- 'project')
- self.assertEqual("alternative", project)
- self.assertEqual("dataset", dataset)
- self.assertEqual("table", table)
-
- def test_invalid_syntax_column_double_project(self):
- with self.assertRaises(Exception) as context:
- hook._split_tablename('alt1:alt.dataset.table',
- 'project')
-
- self.assertIn('Use either : or . to specify project',
- str(context.exception), "")
- self.assertFalse('Format exception for' in str(context.exception))
-
- def test_invalid_syntax_double_column(self):
- with self.assertRaises(Exception) as context:
- hook._split_tablename('alt1:alt:dataset.table',
- 'project')
-
- self.assertIn('Expect format of (<project:)<dataset>.<table>',
- str(context.exception), "")
- self.assertFalse('Format exception for' in str(context.exception))
-
- def test_invalid_syntax_tiple_dot(self):
- with self.assertRaises(Exception) as context:
- hook._split_tablename('alt1.alt.dataset.table',
- 'project')
-
- self.assertIn('Expect format of (<project.|<project:)<dataset>.<table>',
- str(context.exception), "")
- self.assertFalse('Format exception for' in str(context.exception))
-
- def test_invalid_syntax_column_double_project_var(self):
- with self.assertRaises(Exception) as context:
- hook._split_tablename('alt1:alt.dataset.table',
- 'project', 'var_x')
-
- self.assertIn('Use either : or . to specify project',
- str(context.exception), "")
- self.assertIn('Format exception for var_x:',
- str(context.exception), "")
-
- def test_invalid_syntax_double_column_var(self):
- with self.assertRaises(Exception) as context:
- hook._split_tablename('alt1:alt:dataset.table',
- 'project', 'var_x')
-
- self.assertIn('Expect format of (<project:)<dataset>.<table>',
- str(context.exception), "")
- self.assertIn('Format exception for var_x:',
- str(context.exception), "")
-
- def test_invalid_syntax_tiple_dot_var(self):
- with self.assertRaises(Exception) as context:
- hook._split_tablename('alt1.alt.dataset.table',
- 'project', 'var_x')
-
- self.assertIn('Expect format of (<project.|<project:)<dataset>.<table>',
- str(context.exception), "")
- self.assertIn('Format exception for var_x:',
- str(context.exception), "")
-
-class TestBigQueryHookSourceFormat(unittest.TestCase):
- def test_invalid_source_format(self):
- with self.assertRaises(Exception) as context:
- hook.BigQueryBaseCursor("test", "test").run_load("test.test", "test_schema.json", ["test_data.json"], source_format="json")
-
- # since we passed 'json' in, and it's not valid, make sure it's present in the error string.
- self.assertIn("json", str(context.exception))
-
-
-class TestBigQueryBaseCursor(unittest.TestCase):
- def test_invalid_schema_update_options(self):
- with self.assertRaises(Exception) as context:
- hook.BigQueryBaseCursor("test", "test").run_load(
- "test.test",
- "test_schema.json",
- ["test_data.json"],
- schema_update_options=["THIS IS NOT VALID"]
- )
- self.assertIn("THIS IS NOT VALID", str(context.exception))
-
- def test_invalid_schema_update_and_write_disposition(self):
- with self.assertRaises(Exception) as context:
- hook.BigQueryBaseCursor("test", "test").run_load(
- "test.test",
- "test_schema.json",
- ["test_data.json"],
- schema_update_options=['ALLOW_FIELD_ADDITION'],
- write_disposition='WRITE_EMPTY'
- )
- self.assertIn("schema_update_options is only", str(context.exception))
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/databricks_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/databricks_hook.py b/tests/contrib/hooks/databricks_hook.py
deleted file mode 100644
index 6c789f9..0000000
--- a/tests/contrib/hooks/databricks_hook.py
+++ /dev/null
@@ -1,226 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-
-from airflow import __version__
-from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT
-from airflow.exceptions import AirflowException
-from airflow.models import Connection
-from airflow.utils import db
-from requests import exceptions as requests_exceptions
-
-try:
- from unittest import mock
-except ImportError:
- try:
- import mock
- except ImportError:
- mock = None
-
-TASK_ID = 'databricks-operator'
-DEFAULT_CONN_ID = 'databricks_default'
-NOTEBOOK_TASK = {
- 'notebook_path': '/test'
-}
-NEW_CLUSTER = {
- 'spark_version': '2.0.x-scala2.10',
- 'node_type_id': 'r3.xlarge',
- 'num_workers': 1
-}
-RUN_ID = 1
-HOST = 'xx.cloud.databricks.com'
-HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'
-LOGIN = 'login'
-PASSWORD = 'password'
-USER_AGENT_HEADER = {'user-agent': 'airflow-{v}'.format(v=__version__)}
-RUN_PAGE_URL = 'https://XX.cloud.databricks.com/#jobs/1/runs/1'
-LIFE_CYCLE_STATE = 'PENDING'
-STATE_MESSAGE = 'Waiting for cluster'
-GET_RUN_RESPONSE = {
- 'run_page_url': RUN_PAGE_URL,
- 'state': {
- 'life_cycle_state': LIFE_CYCLE_STATE,
- 'state_message': STATE_MESSAGE
- }
-}
-RESULT_STATE = None
-
-
-def submit_run_endpoint(host):
- """
- Utility function to generate the submit run endpoint given the host.
- """
- return 'https://{}/api/2.0/jobs/runs/submit'.format(host)
-
-
-def get_run_endpoint(host):
- """
- Utility function to generate the get run endpoint given the host.
- """
- return 'https://{}/api/2.0/jobs/runs/get'.format(host)
-
-def cancel_run_endpoint(host):
- """
- Utility function to generate the get run endpoint given the host.
- """
- return 'https://{}/api/2.0/jobs/runs/cancel'.format(host)
-
-class DatabricksHookTest(unittest.TestCase):
- """
- Tests for DatabricksHook.
- """
- @db.provide_session
- def setUp(self, session=None):
- conn = session.query(Connection) \
- .filter(Connection.conn_id == DEFAULT_CONN_ID) \
- .first()
- conn.host = HOST
- conn.login = LOGIN
- conn.password = PASSWORD
- session.commit()
-
- self.hook = DatabricksHook()
-
- def test_parse_host_with_proper_host(self):
- host = self.hook._parse_host(HOST)
- self.assertEquals(host, HOST)
-
- def test_parse_host_with_scheme(self):
- host = self.hook._parse_host(HOST_WITH_SCHEME)
- self.assertEquals(host, HOST)
-
- def test_init_bad_retry_limit(self):
- with self.assertRaises(AssertionError):
- DatabricksHook(retry_limit = 0)
-
- @mock.patch('airflow.contrib.hooks.databricks_hook.logging')
- @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
- def test_do_api_call_with_error_retry(self, mock_requests, mock_logging):
- for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]:
- mock_requests.reset_mock()
- mock_logging.reset_mock()
- mock_requests.post.side_effect = exception()
-
- with self.assertRaises(AirflowException):
- self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
-
- self.assertEquals(len(mock_logging.error.mock_calls), self.hook.retry_limit)
-
- @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
- def test_do_api_call_with_bad_status_code(self, mock_requests):
- mock_requests.codes.ok = 200
- status_code_mock = mock.PropertyMock(return_value=500)
- type(mock_requests.post.return_value).status_code = status_code_mock
- with self.assertRaises(AirflowException):
- self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
-
- @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
- def test_submit_run(self, mock_requests):
- mock_requests.codes.ok = 200
- mock_requests.post.return_value.json.return_value = {'run_id': '1'}
- status_code_mock = mock.PropertyMock(return_value=200)
- type(mock_requests.post.return_value).status_code = status_code_mock
- json = {
- 'notebook_task': NOTEBOOK_TASK,
- 'new_cluster': NEW_CLUSTER
- }
- run_id = self.hook.submit_run(json)
-
- self.assertEquals(run_id, '1')
- mock_requests.post.assert_called_once_with(
- submit_run_endpoint(HOST),
- json={
- 'notebook_task': NOTEBOOK_TASK,
- 'new_cluster': NEW_CLUSTER,
- },
- auth=(LOGIN, PASSWORD),
- headers=USER_AGENT_HEADER,
- timeout=self.hook.timeout_seconds)
-
- @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
- def test_get_run_page_url(self, mock_requests):
- mock_requests.codes.ok = 200
- mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
- status_code_mock = mock.PropertyMock(return_value=200)
- type(mock_requests.get.return_value).status_code = status_code_mock
-
- run_page_url = self.hook.get_run_page_url(RUN_ID)
-
- self.assertEquals(run_page_url, RUN_PAGE_URL)
- mock_requests.get.assert_called_once_with(
- get_run_endpoint(HOST),
- json={'run_id': RUN_ID},
- auth=(LOGIN, PASSWORD),
- headers=USER_AGENT_HEADER,
- timeout=self.hook.timeout_seconds)
-
- @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
- def test_get_run_state(self, mock_requests):
- mock_requests.codes.ok = 200
- mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
- status_code_mock = mock.PropertyMock(return_value=200)
- type(mock_requests.get.return_value).status_code = status_code_mock
-
- run_state = self.hook.get_run_state(RUN_ID)
-
- self.assertEquals(run_state, RunState(
- LIFE_CYCLE_STATE,
- RESULT_STATE,
- STATE_MESSAGE))
- mock_requests.get.assert_called_once_with(
- get_run_endpoint(HOST),
- json={'run_id': RUN_ID},
- auth=(LOGIN, PASSWORD),
- headers=USER_AGENT_HEADER,
- timeout=self.hook.timeout_seconds)
-
- @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
- def test_cancel_run(self, mock_requests):
- mock_requests.codes.ok = 200
- mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE
- status_code_mock = mock.PropertyMock(return_value=200)
- type(mock_requests.post.return_value).status_code = status_code_mock
-
- self.hook.cancel_run(RUN_ID)
-
- mock_requests.post.assert_called_once_with(
- cancel_run_endpoint(HOST),
- json={'run_id': RUN_ID},
- auth=(LOGIN, PASSWORD),
- headers=USER_AGENT_HEADER,
- timeout=self.hook.timeout_seconds)
-
-class RunStateTest(unittest.TestCase):
- def test_is_terminal_true(self):
- terminal_states = ['TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']
- for state in terminal_states:
- run_state = RunState(state, '', '')
- self.assertTrue(run_state.is_terminal)
-
- def test_is_terminal_false(self):
- non_terminal_states = ['PENDING', 'RUNNING', 'TERMINATING']
- for state in non_terminal_states:
- run_state = RunState(state, '', '')
- self.assertFalse(run_state.is_terminal)
-
- def test_is_terminal_with_nonexistent_life_cycle_state(self):
- run_state = RunState('blah', '', '')
- with self.assertRaises(AirflowException):
- run_state.is_terminal
-
- def test_is_successful(self):
- run_state = RunState('TERMINATED', 'SUCCESS', '')
- self.assertTrue(run_state.is_successful)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/emr_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/emr_hook.py b/tests/contrib/hooks/emr_hook.py
deleted file mode 100644
index 119df99..0000000
--- a/tests/contrib/hooks/emr_hook.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-import boto3
-
-from airflow import configuration
-from airflow.contrib.hooks.emr_hook import EmrHook
-
-
-try:
- from moto import mock_emr
-except ImportError:
- mock_emr = None
-
-
-class TestEmrHook(unittest.TestCase):
- @mock_emr
- def setUp(self):
- configuration.load_test_config()
-
- @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
- @mock_emr
- def test_get_conn_returns_a_boto3_connection(self):
- hook = EmrHook(aws_conn_id='aws_default')
- self.assertIsNotNone(hook.get_conn().list_clusters())
-
- @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
- @mock_emr
- def test_create_job_flow_uses_the_emr_config_to_create_a_cluster(self):
- client = boto3.client('emr', region_name='us-east-1')
- if len(client.list_clusters()['Clusters']):
- raise ValueError('AWS not properly mocked')
-
- hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default')
- cluster = hook.create_job_flow({'Name': 'test_cluster'})
-
- self.assertEqual(client.list_clusters()['Clusters'][0]['Id'], cluster['JobFlowId'])
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/gcp_dataflow_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/gcp_dataflow_hook.py b/tests/contrib/hooks/gcp_dataflow_hook.py
deleted file mode 100644
index 797d40c..0000000
--- a/tests/contrib/hooks/gcp_dataflow_hook.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-from airflow.contrib.hooks.gcp_dataflow_hook import DataFlowHook
-
-try:
- from unittest import mock
-except ImportError:
- try:
- import mock
- except ImportError:
- mock = None
-
-
-TASK_ID = 'test-python-dataflow'
-PY_FILE = 'apache_beam.examples.wordcount'
-PY_OPTIONS = ['-m']
-OPTIONS = {
- 'project': 'test',
- 'staging_location': 'gs://test/staging'
-}
-BASE_STRING = 'airflow.contrib.hooks.gcp_api_base_hook.{}'
-DATAFLOW_STRING = 'airflow.contrib.hooks.gcp_dataflow_hook.{}'
-
-
-def mock_init(self, gcp_conn_id, delegate_to=None):
- pass
-
-
-class DataFlowHookTest(unittest.TestCase):
-
- def setUp(self):
- with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'),
- new=mock_init):
- self.dataflow_hook = DataFlowHook(gcp_conn_id='test')
-
- @mock.patch(DATAFLOW_STRING.format('DataFlowHook._start_dataflow'))
- def test_start_python_dataflow(self, internal_dataflow_mock):
- self.dataflow_hook.start_python_dataflow(
- task_id=TASK_ID, variables=OPTIONS,
- dataflow=PY_FILE, py_options=PY_OPTIONS)
- internal_dataflow_mock.assert_called_once_with(
- TASK_ID, OPTIONS, PY_FILE, mock.ANY, ['python'] + PY_OPTIONS)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/spark_submit_hook.py b/tests/contrib/hooks/spark_submit_hook.py
deleted file mode 100644
index 8f514c2..0000000
--- a/tests/contrib/hooks/spark_submit_hook.py
+++ /dev/null
@@ -1,185 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import os
-import unittest
-
-from airflow import configuration, models
-from airflow.utils import db
-from airflow.exceptions import AirflowException
-from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook
-
-
-class TestSparkSubmitHook(unittest.TestCase):
- _spark_job_file = 'test_application.py'
- _config = {
- 'conf': {
- 'parquet.compression': 'SNAPPY'
- },
- 'conn_id': 'default_spark',
- 'files': 'hive-site.xml',
- 'py_files': 'sample_library.py',
- 'jars': 'parquet.jar',
- 'executor_cores': 4,
- 'executor_memory': '22g',
- 'keytab': 'privileged_user.keytab',
- 'principal': 'user/spark@airflow.org',
- 'name': 'spark-job',
- 'num_executors': 10,
- 'verbose': True,
- 'driver_memory': '3g',
- 'java_class': 'com.foo.bar.AppMain'
- }
-
- def setUp(self):
- configuration.load_test_config()
- db.merge_conn(
- models.Connection(
- conn_id='spark_yarn_cluster', conn_type='spark',
- host='yarn://yarn-master', extra='{"queue": "root.etl", "deploy-mode": "cluster"}')
- )
- db.merge_conn(
- models.Connection(
- conn_id='spark_default_mesos', conn_type='spark',
- host='mesos://host', port=5050)
- )
-
- db.merge_conn(
- models.Connection(
- conn_id='spark_home_set', conn_type='spark',
- host='yarn://yarn-master',
- extra='{"spark-home": "/opt/myspark"}')
- )
-
- db.merge_conn(
- models.Connection(
- conn_id='spark_home_not_set', conn_type='spark',
- host='yarn://yarn-master')
- )
-
- def test_build_command(self):
- hook = SparkSubmitHook(**self._config)
-
- # The subprocess requires an array but we build the cmd by joining on a space
- cmd = ' '.join(hook._build_command(self._spark_job_file))
-
- # Check if the URL gets build properly and everything exists.
- assert self._spark_job_file in cmd
-
- # Check all the parameters
- assert "--files {}".format(self._config['files']) in cmd
- assert "--py-files {}".format(self._config['py_files']) in cmd
- assert "--jars {}".format(self._config['jars']) in cmd
- assert "--executor-cores {}".format(self._config['executor_cores']) in cmd
- assert "--executor-memory {}".format(self._config['executor_memory']) in cmd
- assert "--keytab {}".format(self._config['keytab']) in cmd
- assert "--principal {}".format(self._config['principal']) in cmd
- assert "--name {}".format(self._config['name']) in cmd
- assert "--num-executors {}".format(self._config['num_executors']) in cmd
- assert "--class {}".format(self._config['java_class']) in cmd
- assert "--driver-memory {}".format(self._config['driver_memory']) in cmd
-
- # Check if all config settings are there
- for k in self._config['conf']:
- assert "--conf {0}={1}".format(k, self._config['conf'][k]) in cmd
-
- if self._config['verbose']:
- assert "--verbose" in cmd
-
- def test_submit(self):
- hook = SparkSubmitHook()
-
- # We don't have spark-submit available, and this is hard to mock, so just accept
- # an exception for now.
- with self.assertRaises(AirflowException):
- hook.submit(self._spark_job_file)
-
- def test_resolve_connection(self):
-
- # Default to the standard yarn connection because conn_id does not exists
- hook = SparkSubmitHook(conn_id='')
- self.assertEqual(hook._resolve_connection(), ('yarn', None, None, None))
- assert "--master yarn" in ' '.join(hook._build_command(self._spark_job_file))
-
- # Default to the standard yarn connection
- hook = SparkSubmitHook(conn_id='spark_default')
- self.assertEqual(
- hook._resolve_connection(),
- ('yarn', 'root.default', None, None)
- )
- cmd = ' '.join(hook._build_command(self._spark_job_file))
- assert "--master yarn" in cmd
- assert "--queue root.default" in cmd
-
- # Connect to a mesos master
- hook = SparkSubmitHook(conn_id='spark_default_mesos')
- self.assertEqual(
- hook._resolve_connection(),
- ('mesos://host:5050', None, None, None)
- )
-
- cmd = ' '.join(hook._build_command(self._spark_job_file))
- assert "--master mesos://host:5050" in cmd
-
- # Set specific queue and deploy mode
- hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
- self.assertEqual(
- hook._resolve_connection(),
- ('yarn://yarn-master', 'root.etl', 'cluster', None)
- )
-
- cmd = ' '.join(hook._build_command(self._spark_job_file))
- assert "--master yarn://yarn-master" in cmd
- assert "--queue root.etl" in cmd
- assert "--deploy-mode cluster" in cmd
-
- # Set the spark home
- hook = SparkSubmitHook(conn_id='spark_home_set')
- self.assertEqual(
- hook._resolve_connection(),
- ('yarn://yarn-master', None, None, '/opt/myspark')
- )
-
- cmd = ' '.join(hook._build_command(self._spark_job_file))
- assert cmd.startswith('/opt/myspark/bin/spark-submit')
-
- # Spark home not set
- hook = SparkSubmitHook(conn_id='spark_home_not_set')
- self.assertEqual(
- hook._resolve_connection(),
- ('yarn://yarn-master', None, None, None)
- )
-
- cmd = ' '.join(hook._build_command(self._spark_job_file))
- assert cmd.startswith('spark-submit')
-
- def test_process_log(self):
- # Must select yarn connection
- hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
-
- log_lines = [
- 'SPARK_MAJOR_VERSION is set to 2, using Spark2',
- 'WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable',
- 'WARN DomainSocketFactory: The short-circuit local reads feature cannot be used because libhadoop cannot be loaded.',
- 'INFO Client: Requesting a new application from cluster with 10 NodeManagers',
- 'INFO Client: Submitting application application_1486558679801_1820 to ResourceManager'
- ]
-
- hook._process_log(log_lines)
-
- assert hook._yarn_application_id == 'application_1486558679801_1820'
-
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/sqoop_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/sqoop_hook.py b/tests/contrib/hooks/sqoop_hook.py
deleted file mode 100644
index 1d85e43..0000000
--- a/tests/contrib/hooks/sqoop_hook.py
+++ /dev/null
@@ -1,219 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import json
-import unittest
-from exceptions import OSError
-
-from airflow import configuration, models
-from airflow.contrib.hooks.sqoop_hook import SqoopHook
-from airflow.utils import db
-
-
-class TestSqoopHook(unittest.TestCase):
- _config = {
- 'conn_id': 'sqoop_test',
- 'num_mappers': 22,
- 'verbose': True,
- 'properties': {
- 'mapred.map.max.attempts': '1'
- }
- }
- _config_export = {
- 'table': 'domino.export_data_to',
- 'export_dir': '/hdfs/data/to/be/exported',
- 'input_null_string': '\n',
- 'input_null_non_string': '\t',
- 'staging_table': 'database.staging',
- 'clear_staging_table': True,
- 'enclosed_by': '"',
- 'escaped_by': '\\',
- 'input_fields_terminated_by': '|',
- 'input_lines_terminated_by': '\n',
- 'input_optionally_enclosed_by': '"',
- 'batch': True,
- 'relaxed_isolation': True
- }
- _config_import = {
- 'target_dir': '/hdfs/data/target/location',
- 'append': True,
- 'file_type': 'parquet',
- 'split_by': '\n',
- 'direct': True,
- 'driver': 'com.microsoft.jdbc.sqlserver.SQLServerDriver'
- }
-
- _config_json = {
- 'namenode': 'http://0.0.0.0:50070/',
- 'job_tracker': 'http://0.0.0.0:50030/',
- 'libjars': '/path/to/jars',
- 'files': '/path/to/files',
- 'archives': '/path/to/archives'
- }
-
- def setUp(self):
- configuration.load_test_config()
- db.merge_conn(
- models.Connection(
- conn_id='sqoop_test', conn_type='sqoop',
- host='rmdbs', port=5050, extra=json.dumps(self._config_json)
- )
- )
-
- def test_popen(self):
- hook = SqoopHook(**self._config)
-
- # Should go well
- hook.Popen(['ls'])
-
- # Should give an exception
- with self.assertRaises(OSError):
- hook.Popen('exit 1')
-
- def test_submit(self):
- hook = SqoopHook(**self._config)
-
- cmd = ' '.join(hook._prepare_command())
-
- # Check if the config has been extracted from the json
- if self._config_json['namenode']:
- assert "-fs {}".format(self._config_json['namenode']) in cmd
-
- if self._config_json['job_tracker']:
- assert "-jt {}".format(self._config_json['job_tracker']) in cmd
-
- if self._config_json['libjars']:
- assert "-libjars {}".format(self._config_json['libjars']) in cmd
-
- if self._config_json['files']:
- assert "-files {}".format(self._config_json['files']) in cmd
-
- if self._config_json['archives']:
- assert "-archives {}".format(self._config_json['archives']) in cmd
-
- # Check the regulator stuff passed by the default constructor
- if self._config['verbose']:
- assert "--verbose" in cmd
-
- if self._config['num_mappers']:
- assert "--num-mappers {}".format(
- self._config['num_mappers']) in cmd
-
- print(self._config['properties'])
- for key, value in self._config['properties'].items():
- assert "-D {}={}".format(key, value) in cmd
-
- # We don't have the sqoop binary available, and this is hard to mock,
- # so just accept an exception for now.
- with self.assertRaises(OSError):
- hook.export_table(**self._config_export)
-
- with self.assertRaises(OSError):
- hook.import_table(table='schema.table',
- target_dir='/sqoop/example/path')
-
- with self.assertRaises(OSError):
- hook.import_query(query='SELECT * FROM sometable',
- target_dir='/sqoop/example/path')
-
- def test_export_cmd(self):
- hook = SqoopHook()
-
- # The subprocess requires an array but we build the cmd by joining on a space
- cmd = ' '.join(
- hook._export_cmd(
- self._config_export['table'],
- self._config_export['export_dir'],
- input_null_string=self._config_export['input_null_string'],
- input_null_non_string=self._config_export[
- 'input_null_non_string'],
- staging_table=self._config_export['staging_table'],
- clear_staging_table=self._config_export['clear_staging_table'],
- enclosed_by=self._config_export['enclosed_by'],
- escaped_by=self._config_export['escaped_by'],
- input_fields_terminated_by=self._config_export[
- 'input_fields_terminated_by'],
- input_lines_terminated_by=self._config_export[
- 'input_lines_terminated_by'],
- input_optionally_enclosed_by=self._config_export[
- 'input_optionally_enclosed_by'],
- batch=self._config_export['batch'],
- relaxed_isolation=self._config_export['relaxed_isolation'])
- )
-
- assert "--input-null-string {}".format(
- self._config_export['input_null_string']) in cmd
- assert "--input-null-non-string {}".format(
- self._config_export['input_null_non_string']) in cmd
- assert "--staging-table {}".format(
- self._config_export['staging_table']) in cmd
- assert "--enclosed-by {}".format(
- self._config_export['enclosed_by']) in cmd
- assert "--escaped-by {}".format(
- self._config_export['escaped_by']) in cmd
- assert "--input-fields-terminated-by {}".format(
- self._config_export['input_fields_terminated_by']) in cmd
- assert "--input-lines-terminated-by {}".format(
- self._config_export['input_lines_terminated_by']) in cmd
- assert "--input-optionally-enclosed-by {}".format(
- self._config_export['input_optionally_enclosed_by']) in cmd
-
- if self._config_export['clear_staging_table']:
- assert "--clear-staging-table" in cmd
-
- if self._config_export['batch']:
- assert "--batch" in cmd
-
- if self._config_export['relaxed_isolation']:
- assert "--relaxed-isolation" in cmd
-
- def test_import_cmd(self):
- hook = SqoopHook()
-
- # The subprocess requires an array but we build the cmd by joining on a space
- cmd = ' '.join(
- hook._import_cmd(self._config_import['target_dir'],
- append=self._config_import['append'],
- file_type=self._config_import['file_type'],
- split_by=self._config_import['split_by'],
- direct=self._config_import['direct'],
- driver=self._config_import['driver'])
- )
-
- if self._config_import['append']:
- assert '--append' in cmd
-
- if self._config_import['direct']:
- assert '--direct' in cmd
-
- assert '--target-dir {}'.format(
- self._config_import['target_dir']) in cmd
-
- assert '--driver {}'.format(self._config_import['driver']) in cmd
- assert '--split-by {}'.format(self._config_import['split_by']) in cmd
-
- def test_get_export_format_argument(self):
- hook = SqoopHook()
- assert "--as-avrodatafile" in hook._get_export_format_argument('avro')
- assert "--as-parquetfile" in hook._get_export_format_argument(
- 'parquet')
- assert "--as-sequencefile" in hook._get_export_format_argument(
- 'sequence')
- assert "--as-textfile" in hook._get_export_format_argument('text')
- assert "--as-textfile" in hook._get_export_format_argument('unknown')
-
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_aws_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_aws_hook.py b/tests/contrib/hooks/test_aws_hook.py
new file mode 100644
index 0000000..6f13e58
--- /dev/null
+++ b/tests/contrib/hooks/test_aws_hook.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import boto3
+
+from airflow import configuration
+from airflow.contrib.hooks.aws_hook import AwsHook
+
+
+try:
+ from moto import mock_emr
+except ImportError:
+ mock_emr = None
+
+
+class TestAwsHook(unittest.TestCase):
+ @mock_emr
+ def setUp(self):
+ configuration.load_test_config()
+
+ @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
+ @mock_emr
+ def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
+ client = boto3.client('emr', region_name='us-east-1')
+ if len(client.list_clusters()['Clusters']):
+ raise ValueError('AWS not properly mocked')
+
+ hook = AwsHook(aws_conn_id='aws_default')
+ client_from_hook = hook.get_client_type('emr')
+
+ self.assertEqual(client_from_hook.list_clusters()['Clusters'], [])
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_bigquery_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_bigquery_hook.py b/tests/contrib/hooks/test_bigquery_hook.py
new file mode 100644
index 0000000..0adffc5
--- /dev/null
+++ b/tests/contrib/hooks/test_bigquery_hook.py
@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from airflow.contrib.hooks import bigquery_hook as hook
+
+
+class TestBigQueryTableSplitter(unittest.TestCase):
+ def test_internal_need_default_project(self):
+ with self.assertRaises(Exception) as context:
+ hook._split_tablename('dataset.table', None)
+
+ self.assertIn('INTERNAL: No default project is specified',
+ str(context.exception), "")
+
+ def test_split_dataset_table(self):
+ project, dataset, table = hook._split_tablename('dataset.table',
+ 'project')
+ self.assertEqual("project", project)
+ self.assertEqual("dataset", dataset)
+ self.assertEqual("table", table)
+
+ def test_split_project_dataset_table(self):
+ project, dataset, table = hook._split_tablename('alternative:dataset.table',
+ 'project')
+ self.assertEqual("alternative", project)
+ self.assertEqual("dataset", dataset)
+ self.assertEqual("table", table)
+
+ def test_sql_split_project_dataset_table(self):
+ project, dataset, table = hook._split_tablename('alternative.dataset.table',
+ 'project')
+ self.assertEqual("alternative", project)
+ self.assertEqual("dataset", dataset)
+ self.assertEqual("table", table)
+
+ def test_invalid_syntax_column_double_project(self):
+ with self.assertRaises(Exception) as context:
+ hook._split_tablename('alt1:alt.dataset.table',
+ 'project')
+
+ self.assertIn('Use either : or . to specify project',
+ str(context.exception), "")
+ self.assertFalse('Format exception for' in str(context.exception))
+
+ def test_invalid_syntax_double_column(self):
+ with self.assertRaises(Exception) as context:
+ hook._split_tablename('alt1:alt:dataset.table',
+ 'project')
+
+ self.assertIn('Expect format of (<project:)<dataset>.<table>',
+ str(context.exception), "")
+ self.assertFalse('Format exception for' in str(context.exception))
+
+ def test_invalid_syntax_tiple_dot(self):
+ with self.assertRaises(Exception) as context:
+ hook._split_tablename('alt1.alt.dataset.table',
+ 'project')
+
+ self.assertIn('Expect format of (<project.|<project:)<dataset>.<table>',
+ str(context.exception), "")
+ self.assertFalse('Format exception for' in str(context.exception))
+
+ def test_invalid_syntax_column_double_project_var(self):
+ with self.assertRaises(Exception) as context:
+ hook._split_tablename('alt1:alt.dataset.table',
+ 'project', 'var_x')
+
+ self.assertIn('Use either : or . to specify project',
+ str(context.exception), "")
+ self.assertIn('Format exception for var_x:',
+ str(context.exception), "")
+
+ def test_invalid_syntax_double_column_var(self):
+ with self.assertRaises(Exception) as context:
+ hook._split_tablename('alt1:alt:dataset.table',
+ 'project', 'var_x')
+
+ self.assertIn('Expect format of (<project:)<dataset>.<table>',
+ str(context.exception), "")
+ self.assertIn('Format exception for var_x:',
+ str(context.exception), "")
+
+ def test_invalid_syntax_tiple_dot_var(self):
+ with self.assertRaises(Exception) as context:
+ hook._split_tablename('alt1.alt.dataset.table',
+ 'project', 'var_x')
+
+ self.assertIn('Expect format of (<project.|<project:)<dataset>.<table>',
+ str(context.exception), "")
+ self.assertIn('Format exception for var_x:',
+ str(context.exception), "")
+
+class TestBigQueryHookSourceFormat(unittest.TestCase):
+ def test_invalid_source_format(self):
+ with self.assertRaises(Exception) as context:
+ hook.BigQueryBaseCursor("test", "test").run_load("test.test", "test_schema.json", ["test_data.json"], source_format="json")
+
+ # since we passed 'json' in, and it's not valid, make sure it's present in the error string.
+ self.assertIn("JSON", str(context.exception))
+
+
+class TestBigQueryBaseCursor(unittest.TestCase):
+ def test_invalid_schema_update_options(self):
+ with self.assertRaises(Exception) as context:
+ hook.BigQueryBaseCursor("test", "test").run_load(
+ "test.test",
+ "test_schema.json",
+ ["test_data.json"],
+ schema_update_options=["THIS IS NOT VALID"]
+ )
+ self.assertIn("THIS IS NOT VALID", str(context.exception))
+
+ def test_invalid_schema_update_and_write_disposition(self):
+ with self.assertRaises(Exception) as context:
+ hook.BigQueryBaseCursor("test", "test").run_load(
+ "test.test",
+ "test_schema.json",
+ ["test_data.json"],
+ schema_update_options=['ALLOW_FIELD_ADDITION'],
+ write_disposition='WRITE_EMPTY'
+ )
+ self.assertIn("schema_update_options is only", str(context.exception))
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_databricks_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py
new file mode 100644
index 0000000..6c789f9
--- /dev/null
+++ b/tests/contrib/hooks/test_databricks_hook.py
@@ -0,0 +1,226 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from airflow import __version__
+from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT
+from airflow.exceptions import AirflowException
+from airflow.models import Connection
+from airflow.utils import db
+from requests import exceptions as requests_exceptions
+
+try:
+ from unittest import mock
+except ImportError:
+ try:
+ import mock
+ except ImportError:
+ mock = None
+
+TASK_ID = 'databricks-operator'
+DEFAULT_CONN_ID = 'databricks_default'
+NOTEBOOK_TASK = {
+ 'notebook_path': '/test'
+}
+NEW_CLUSTER = {
+ 'spark_version': '2.0.x-scala2.10',
+ 'node_type_id': 'r3.xlarge',
+ 'num_workers': 1
+}
+RUN_ID = 1
+HOST = 'xx.cloud.databricks.com'
+HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'
+LOGIN = 'login'
+PASSWORD = 'password'
+USER_AGENT_HEADER = {'user-agent': 'airflow-{v}'.format(v=__version__)}
+RUN_PAGE_URL = 'https://XX.cloud.databricks.com/#jobs/1/runs/1'
+LIFE_CYCLE_STATE = 'PENDING'
+STATE_MESSAGE = 'Waiting for cluster'
+GET_RUN_RESPONSE = {
+ 'run_page_url': RUN_PAGE_URL,
+ 'state': {
+ 'life_cycle_state': LIFE_CYCLE_STATE,
+ 'state_message': STATE_MESSAGE
+ }
+}
+RESULT_STATE = None
+
+
+def submit_run_endpoint(host):
+ """
+ Utility function to generate the submit run endpoint given the host.
+ """
+ return 'https://{}/api/2.0/jobs/runs/submit'.format(host)
+
+
+def get_run_endpoint(host):
+ """
+ Utility function to generate the get run endpoint given the host.
+ """
+ return 'https://{}/api/2.0/jobs/runs/get'.format(host)
+
+def cancel_run_endpoint(host):
+ """
+ Utility function to generate the get run endpoint given the host.
+ """
+ return 'https://{}/api/2.0/jobs/runs/cancel'.format(host)
+
+class DatabricksHookTest(unittest.TestCase):
+ """
+ Tests for DatabricksHook.
+ """
+ @db.provide_session
+ def setUp(self, session=None):
+ conn = session.query(Connection) \
+ .filter(Connection.conn_id == DEFAULT_CONN_ID) \
+ .first()
+ conn.host = HOST
+ conn.login = LOGIN
+ conn.password = PASSWORD
+ session.commit()
+
+ self.hook = DatabricksHook()
+
+ def test_parse_host_with_proper_host(self):
+ host = self.hook._parse_host(HOST)
+ self.assertEquals(host, HOST)
+
+ def test_parse_host_with_scheme(self):
+ host = self.hook._parse_host(HOST_WITH_SCHEME)
+ self.assertEquals(host, HOST)
+
+ def test_init_bad_retry_limit(self):
+ with self.assertRaises(AssertionError):
+ DatabricksHook(retry_limit = 0)
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.logging')
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_do_api_call_with_error_retry(self, mock_requests, mock_logging):
+ for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]:
+ mock_requests.reset_mock()
+ mock_logging.reset_mock()
+ mock_requests.post.side_effect = exception()
+
+ with self.assertRaises(AirflowException):
+ self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+
+ self.assertEquals(len(mock_logging.error.mock_calls), self.hook.retry_limit)
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_do_api_call_with_bad_status_code(self, mock_requests):
+ mock_requests.codes.ok = 200
+ status_code_mock = mock.PropertyMock(return_value=500)
+ type(mock_requests.post.return_value).status_code = status_code_mock
+ with self.assertRaises(AirflowException):
+ self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_submit_run(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.post.return_value.json.return_value = {'run_id': '1'}
+ status_code_mock = mock.PropertyMock(return_value=200)
+ type(mock_requests.post.return_value).status_code = status_code_mock
+ json = {
+ 'notebook_task': NOTEBOOK_TASK,
+ 'new_cluster': NEW_CLUSTER
+ }
+ run_id = self.hook.submit_run(json)
+
+ self.assertEquals(run_id, '1')
+ mock_requests.post.assert_called_once_with(
+ submit_run_endpoint(HOST),
+ json={
+ 'notebook_task': NOTEBOOK_TASK,
+ 'new_cluster': NEW_CLUSTER,
+ },
+ auth=(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds)
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_get_run_page_url(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
+ status_code_mock = mock.PropertyMock(return_value=200)
+ type(mock_requests.get.return_value).status_code = status_code_mock
+
+ run_page_url = self.hook.get_run_page_url(RUN_ID)
+
+ self.assertEquals(run_page_url, RUN_PAGE_URL)
+ mock_requests.get.assert_called_once_with(
+ get_run_endpoint(HOST),
+ json={'run_id': RUN_ID},
+ auth=(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds)
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_get_run_state(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
+ status_code_mock = mock.PropertyMock(return_value=200)
+ type(mock_requests.get.return_value).status_code = status_code_mock
+
+ run_state = self.hook.get_run_state(RUN_ID)
+
+ self.assertEquals(run_state, RunState(
+ LIFE_CYCLE_STATE,
+ RESULT_STATE,
+ STATE_MESSAGE))
+ mock_requests.get.assert_called_once_with(
+ get_run_endpoint(HOST),
+ json={'run_id': RUN_ID},
+ auth=(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds)
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_cancel_run(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE
+ status_code_mock = mock.PropertyMock(return_value=200)
+ type(mock_requests.post.return_value).status_code = status_code_mock
+
+ self.hook.cancel_run(RUN_ID)
+
+ mock_requests.post.assert_called_once_with(
+ cancel_run_endpoint(HOST),
+ json={'run_id': RUN_ID},
+ auth=(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds)
+
+class RunStateTest(unittest.TestCase):
+ def test_is_terminal_true(self):
+ terminal_states = ['TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']
+ for state in terminal_states:
+ run_state = RunState(state, '', '')
+ self.assertTrue(run_state.is_terminal)
+
+ def test_is_terminal_false(self):
+ non_terminal_states = ['PENDING', 'RUNNING', 'TERMINATING']
+ for state in non_terminal_states:
+ run_state = RunState(state, '', '')
+ self.assertFalse(run_state.is_terminal)
+
+ def test_is_terminal_with_nonexistent_life_cycle_state(self):
+ run_state = RunState('blah', '', '')
+ with self.assertRaises(AirflowException):
+ run_state.is_terminal
+
+ def test_is_successful(self):
+ run_state = RunState('TERMINATED', 'SUCCESS', '')
+ self.assertTrue(run_state.is_successful)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_emr_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_emr_hook.py b/tests/contrib/hooks/test_emr_hook.py
new file mode 100644
index 0000000..119df99
--- /dev/null
+++ b/tests/contrib/hooks/test_emr_hook.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import boto3
+
+from airflow import configuration
+from airflow.contrib.hooks.emr_hook import EmrHook
+
+
+try:
+ from moto import mock_emr
+except ImportError:
+ mock_emr = None
+
+
+class TestEmrHook(unittest.TestCase):
+ @mock_emr
+ def setUp(self):
+ configuration.load_test_config()
+
+ @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
+ @mock_emr
+ def test_get_conn_returns_a_boto3_connection(self):
+ hook = EmrHook(aws_conn_id='aws_default')
+ self.assertIsNotNone(hook.get_conn().list_clusters())
+
+ @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
+ @mock_emr
+ def test_create_job_flow_uses_the_emr_config_to_create_a_cluster(self):
+ client = boto3.client('emr', region_name='us-east-1')
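+ # Sanity check: the freshly mocked EMR endpoint must report no
+ # clusters, otherwise we would be talking to real AWS.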
+ if len(client.list_clusters()['Clusters']):
+ raise ValueError('AWS not properly mocked')
+
+ hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default')
+ cluster = hook.create_job_flow({'Name': 'test_cluster'})
+
+ self.assertEqual(client.list_clusters()['Clusters'][0]['Id'], cluster['JobFlowId'])
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_gcp_dataflow_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_dataflow_hook.py b/tests/contrib/hooks/test_gcp_dataflow_hook.py
new file mode 100644
index 0000000..797d40c
--- /dev/null
+++ b/tests/contrib/hooks/test_gcp_dataflow_hook.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from airflow.contrib.hooks.gcp_dataflow_hook import DataFlowHook
+
+try:
+ from unittest import mock
+except ImportError:
+ try:
+ import mock
+ except ImportError:
+ # With no mock library at all the @mock.patch decorators below would
+ # fail at import time, so skip the whole module instead.
+ raise unittest.SkipTest('mock package not available')
+
+
+TASK_ID = 'test-python-dataflow'
+PY_FILE = 'apache_beam.examples.wordcount'
+PY_OPTIONS = ['-m']
+OPTIONS = {
+ 'project': 'test',
+ 'staging_location': 'gs://test/staging'
+}
+BASE_STRING = 'airflow.contrib.hooks.gcp_api_base_hook.{}'
+DATAFLOW_STRING = 'airflow.contrib.hooks.gcp_dataflow_hook.{}'
+
+
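+# GoogleCloudBaseHook.__init__ normally resolves an Airflow connection;
+# stub it out so the hook can be constructed without GCP credentials.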
+def mock_init(self, gcp_conn_id, delegate_to=None):
+ pass
+
+
+class DataFlowHookTest(unittest.TestCase):
+
+ def setUp(self):
+ with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'),
+ new=mock_init):
+ self.dataflow_hook = DataFlowHook(gcp_conn_id='test')
+
+ @mock.patch(DATAFLOW_STRING.format('DataFlowHook._start_dataflow'))
+ def test_start_python_dataflow(self, internal_dataflow_mock):
+ self.dataflow_hook.start_python_dataflow(
+ task_id=TASK_ID, variables=OPTIONS,
+ dataflow=PY_FILE, py_options=PY_OPTIONS)
+ internal_dataflow_mock.assert_called_once_with(
+ TASK_ID, OPTIONS, PY_FILE, mock.ANY, ['python'] + PY_OPTIONS)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_spark_submit_hook.py b/tests/contrib/hooks/test_spark_submit_hook.py
new file mode 100644
index 0000000..24315fa
--- /dev/null
+++ b/tests/contrib/hooks/test_spark_submit_hook.py
@@ -0,0 +1,197 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import sys
+import unittest
+from io import StringIO
+
+try:
+ from unittest import mock
+except ImportError:
+ import mock
+
+from airflow import configuration, models
+from airflow.utils import db
+from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook
+
+
+class TestSparkSubmitHook(unittest.TestCase):
+
+ _spark_job_file = 'test_application.py'
+ _config = {
+ 'conf': {
+ 'parquet.compression': 'SNAPPY'
+ },
+ 'conn_id': 'default_spark',
+ 'files': 'hive-site.xml',
+ 'py_files': 'sample_library.py',
+ 'jars': 'parquet.jar',
+ 'executor_cores': 4,
+ 'executor_memory': '22g',
+ 'keytab': 'privileged_user.keytab',
+ 'principal': 'user/spark@airflow.org',
+ 'name': 'spark-job',
+ 'num_executors': 10,
+ 'verbose': True,
+ 'driver_memory': '3g',
+ 'java_class': 'com.foo.bar.AppMain'
+ }
+
+ def setUp(self):
+
+ if sys.version_info[0] == 3:
+ raise unittest.SkipTest("TestSparkSubmitHook doesn't work with "
+ "Python 3; there is nothing to test here")
+
+ configuration.load_test_config()
+ db.merge_conn(
+ models.Connection(
+ conn_id='spark_yarn_cluster', conn_type='spark',
+ host='yarn://yarn-master', extra='{"queue": "root.etl", "deploy-mode": "cluster"}')
+ )
+ db.merge_conn(
+ models.Connection(
+ conn_id='spark_default_mesos', conn_type='spark',
+ host='mesos://host', port=5050)
+ )
+
+ db.merge_conn(
+ models.Connection(
+ conn_id='spark_home_set', conn_type='spark',
+ host='yarn://yarn-master',
+ extra='{"spark-home": "/opt/myspark"}')
+ )
+
+ db.merge_conn(
+ models.Connection(
+ conn_id='spark_home_not_set', conn_type='spark',
+ host='yarn://yarn-master')
+ )
+
+ def test_build_command(self):
+ hook = SparkSubmitHook(**self._config)
+
+ # The subprocess requires an array but we build the cmd by joining on a space
+ cmd = ' '.join(hook._build_command(self._spark_job_file))
+
+ # Check that the command gets built properly and includes the job file.
+ assert self._spark_job_file in cmd
+
+ # Check all the parameters
+ assert "--files {}".format(self._config['files']) in cmd
+ assert "--py-files {}".format(self._config['py_files']) in cmd
+ assert "--jars {}".format(self._config['jars']) in cmd
+ assert "--executor-cores {}".format(self._config['executor_cores']) in cmd
+ assert "--executor-memory {}".format(self._config['executor_memory']) in cmd
+ assert "--keytab {}".format(self._config['keytab']) in cmd
+ assert "--principal {}".format(self._config['principal']) in cmd
+ assert "--name {}".format(self._config['name']) in cmd
+ assert "--num-executors {}".format(self._config['num_executors']) in cmd
+ assert "--class {}".format(self._config['java_class']) in cmd
+ assert "--driver-memory {}".format(self._config['driver_memory']) in cmd
+
+ # Check if all config settings are there
+ for k in self._config['conf']:
+ assert "--conf {0}={1}".format(k, self._config['conf'][k]) in cmd
+
+ if self._config['verbose']:
+ assert "--verbose" in cmd
+
+ @mock.patch('airflow.contrib.hooks.spark_submit_hook.subprocess')
+ def test_submit(self, mock_process):
+ # spark-submit isn't available in the test environment and is hard to
+ # fake realistically, so stub out subprocess and just exercise submit().
+ mock_Popen = mock_process.Popen.return_value
+ mock_Popen.stdout = StringIO(u'stdout')
+ mock_Popen.stderr = StringIO(u'stderr')
+ mock_Popen.returncode = None
+ mock_Popen.communicate.return_value = ['extra stdout', 'extra stderr']
+ hook = SparkSubmitHook()
+ hook.submit(self._spark_job_file)
+
+ def test_resolve_connection(self):
+
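+ # _resolve_connection() returns a (master, queue, deploy_mode,
+ # spark_home) tuple built from the connection's host, port and extras.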
+ # Default to the standard yarn connection because the conn_id does not exist
+ hook = SparkSubmitHook(conn_id='')
+ self.assertEqual(hook._resolve_connection(), ('yarn', None, None, None))
+ assert "--master yarn" in ' '.join(hook._build_command(self._spark_job_file))
+
+ # Default to the standard yarn connection
+ hook = SparkSubmitHook(conn_id='spark_default')
+ self.assertEqual(
+ hook._resolve_connection(),
+ ('yarn', 'root.default', None, None)
+ )
+ cmd = ' '.join(hook._build_command(self._spark_job_file))
+ assert "--master yarn" in cmd
+ assert "--queue root.default" in cmd
+
+ # Connect to a mesos master
+ hook = SparkSubmitHook(conn_id='spark_default_mesos')
+ self.assertEqual(
+ hook._resolve_connection(),
+ ('mesos://host:5050', None, None, None)
+ )
+
+ cmd = ' '.join(hook._build_command(self._spark_job_file))
+ assert "--master mesos://host:5050" in cmd
+
+ # Set specific queue and deploy mode
+ hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
+ self.assertEqual(
+ hook._resolve_connection(),
+ ('yarn://yarn-master', 'root.etl', 'cluster', None)
+ )
+
+ cmd = ' '.join(hook._build_command(self._spark_job_file))
+ assert "--master yarn://yarn-master" in cmd
+ assert "--queue root.etl" in cmd
+ assert "--deploy-mode cluster" in cmd
+
+ # Set the spark home
+ hook = SparkSubmitHook(conn_id='spark_home_set')
+ self.assertEqual(
+ hook._resolve_connection(),
+ ('yarn://yarn-master', None, None, '/opt/myspark')
+ )
+
+ cmd = ' '.join(hook._build_command(self._spark_job_file))
+ assert cmd.startswith('/opt/myspark/bin/spark-submit')
+
+ # Spark home not set
+ hook = SparkSubmitHook(conn_id='spark_home_not_set')
+ self.assertEqual(
+ hook._resolve_connection(),
+ ('yarn://yarn-master', None, None, None)
+ )
+
+ cmd = ' '.join(hook._build_command(self._spark_job_file))
+ assert cmd.startswith('spark-submit')
+
+ def test_process_log(self):
+ # Must select yarn connection
+ hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
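+ # The hook scrapes the YARN application id out of the spark-submit log
+ # lines so the submitted job can be tracked afterwards.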
+
+ log_lines = [
+ 'SPARK_MAJOR_VERSION is set to 2, using Spark2',
+ 'WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable',
+ 'WARN DomainSocketFactory: The short-circuit local reads feature cannot be used because libhadoop cannot be loaded.',
+ 'INFO Client: Requesting a new application from cluster with 10 NodeManagers',
+ 'INFO Client: Submitting application application_1486558679801_1820 to ResourceManager'
+ ]
+
+ hook._process_log(log_lines)
+
+ assert hook._yarn_application_id == 'application_1486558679801_1820'
+
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_sqoop_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_sqoop_hook.py b/tests/contrib/hooks/test_sqoop_hook.py
new file mode 100644
index 0000000..ca8033b
--- /dev/null
+++ b/tests/contrib/hooks/test_sqoop_hook.py
@@ -0,0 +1,218 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import unittest
+
+from airflow import configuration, models
+from airflow.contrib.hooks.sqoop_hook import SqoopHook
+from airflow.utils import db
+
+
+class TestSqoopHook(unittest.TestCase):
+ _config = {
+ 'conn_id': 'sqoop_test',
+ 'num_mappers': 22,
+ 'verbose': True,
+ 'properties': {
+ 'mapred.map.max.attempts': '1'
+ }
+ }
+ _config_export = {
+ 'table': 'domino.export_data_to',
+ 'export_dir': '/hdfs/data/to/be/exported',
+ 'input_null_string': '\n',
+ 'input_null_non_string': '\t',
+ 'staging_table': 'database.staging',
+ 'clear_staging_table': True,
+ 'enclosed_by': '"',
+ 'escaped_by': '\\',
+ 'input_fields_terminated_by': '|',
+ 'input_lines_terminated_by': '\n',
+ 'input_optionally_enclosed_by': '"',
+ 'batch': True,
+ 'relaxed_isolation': True
+ }
+ _config_import = {
+ 'target_dir': '/hdfs/data/target/location',
+ 'append': True,
+ 'file_type': 'parquet',
+ 'split_by': '\n',
+ 'direct': True,
+ 'driver': 'com.microsoft.jdbc.sqlserver.SQLServerDriver'
+ }
+
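+ # These extras become generic Hadoop arguments (-fs, -jt, -libjars,
+ # -files, -archives); test_submit below checks each of them.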
+ _config_json = {
+ 'namenode': 'http://0.0.0.0:50070/',
+ 'job_tracker': 'http://0.0.0.0:50030/',
+ 'libjars': '/path/to/jars',
+ 'files': '/path/to/files',
+ 'archives': '/path/to/archives'
+ }
+
+ def setUp(self):
+ configuration.load_test_config()
+ db.merge_conn(
+ models.Connection(
+ conn_id='sqoop_test', conn_type='sqoop',
+ host='rmdbs', port=5050, extra=json.dumps(self._config_json)
+ )
+ )
+
+ def test_popen(self):
+ hook = SqoopHook(**self._config)
+
+ # A valid argument list should run fine
+ hook.Popen(['ls'])
+
+ # 'exit 1' is not an executable, so spawning it should raise OSError
+ with self.assertRaises(OSError):
+ hook.Popen('exit 1')
+
+ def test_submit(self):
+ hook = SqoopHook(**self._config)
+
+ cmd = ' '.join(hook._prepare_command())
+
+ # Check if the config has been extracted from the json
+ if self._config_json['namenode']:
+ assert "-fs {}".format(self._config_json['namenode']) in cmd
+
+ if self._config_json['job_tracker']:
+ assert "-jt {}".format(self._config_json['job_tracker']) in cmd
+
+ if self._config_json['libjars']:
+ assert "-libjars {}".format(self._config_json['libjars']) in cmd
+
+ if self._config_json['files']:
+ assert "-files {}".format(self._config_json['files']) in cmd
+
+ if self._config_json['archives']:
+ assert "-archives {}".format(self._config_json['archives']) in cmd
+
+ # Check the regular arguments passed in via the constructor
+ if self._config['verbose']:
+ assert "--verbose" in cmd
+
+ if self._config['num_mappers']:
+ assert "--num-mappers {}".format(
+ self._config['num_mappers']) in cmd
+
+ for key, value in self._config['properties'].items():
+ assert "-D {}={}".format(key, value) in cmd
+
+ # We don't have the sqoop binary available, and this is hard to mock,
+ # so just accept an exception for now.
+ with self.assertRaises(OSError):
+ hook.export_table(**self._config_export)
+
+ with self.assertRaises(OSError):
+ hook.import_table(table='schema.table',
+ target_dir='/sqoop/example/path')
+
+ with self.assertRaises(OSError):
+ hook.import_query(query='SELECT * FROM sometable',
+ target_dir='/sqoop/example/path')
+
+ def test_export_cmd(self):
+ hook = SqoopHook()
+
+ # The subprocess requires an array but we build the cmd by joining on a space
+ cmd = ' '.join(
+ hook._export_cmd(
+ self._config_export['table'],
+ self._config_export['export_dir'],
+ input_null_string=self._config_export['input_null_string'],
+ input_null_non_string=self._config_export[
+ 'input_null_non_string'],
+ staging_table=self._config_export['staging_table'],
+ clear_staging_table=self._config_export['clear_staging_table'],
+ enclosed_by=self._config_export['enclosed_by'],
+ escaped_by=self._config_export['escaped_by'],
+ input_fields_terminated_by=self._config_export[
+ 'input_fields_terminated_by'],
+ input_lines_terminated_by=self._config_export[
+ 'input_lines_terminated_by'],
+ input_optionally_enclosed_by=self._config_export[
+ 'input_optionally_enclosed_by'],
+ batch=self._config_export['batch'],
+ relaxed_isolation=self._config_export['relaxed_isolation'])
+ )
+
+ assert "--input-null-string {}".format(
+ self._config_export['input_null_string']) in cmd
+ assert "--input-null-non-string {}".format(
+ self._config_export['input_null_non_string']) in cmd
+ assert "--staging-table {}".format(
+ self._config_export['staging_table']) in cmd
+ assert "--enclosed-by {}".format(
+ self._config_export['enclosed_by']) in cmd
+ assert "--escaped-by {}".format(
+ self._config_export['escaped_by']) in cmd
+ assert "--input-fields-terminated-by {}".format(
+ self._config_export['input_fields_terminated_by']) in cmd
+ assert "--input-lines-terminated-by {}".format(
+ self._config_export['input_lines_terminated_by']) in cmd
+ assert "--input-optionally-enclosed-by {}".format(
+ self._config_export['input_optionally_enclosed_by']) in cmd
+
+ if self._config_export['clear_staging_table']:
+ assert "--clear-staging-table" in cmd
+
+ if self._config_export['batch']:
+ assert "--batch" in cmd
+
+ if self._config_export['relaxed_isolation']:
+ assert "--relaxed-isolation" in cmd
+
+ def test_import_cmd(self):
+ hook = SqoopHook()
+
+ # The subprocess requires an array but we build the cmd by joining on a space
+ cmd = ' '.join(
+ hook._import_cmd(self._config_import['target_dir'],
+ append=self._config_import['append'],
+ file_type=self._config_import['file_type'],
+ split_by=self._config_import['split_by'],
+ direct=self._config_import['direct'],
+ driver=self._config_import['driver'])
+ )
+
+ if self._config_import['append']:
+ assert '--append' in cmd
+
+ if self._config_import['direct']:
+ assert '--direct' in cmd
+
+ assert '--target-dir {}'.format(
+ self._config_import['target_dir']) in cmd
+
+ assert '--driver {}'.format(self._config_import['driver']) in cmd
+ assert '--split-by {}'.format(self._config_import['split_by']) in cmd
+
+ def test_get_export_format_argument(self):
+ hook = SqoopHook()
+ assert "--as-avrodatafile" in hook._get_export_format_argument('avro')
+ assert "--as-parquetfile" in hook._get_export_format_argument(
+ 'parquet')
+ assert "--as-sequencefile" in hook._get_export_format_argument(
+ 'sequence')
+ assert "--as-textfile" in hook._get_export_format_argument('text')
+ assert "--as-textfile" in hook._get_export_format_argument('unknown')
+
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_zendesk_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_zendesk_hook.py b/tests/contrib/hooks/test_zendesk_hook.py
new file mode 100644
index 0000000..7751a2b
--- /dev/null
+++ b/tests/contrib/hooks/test_zendesk_hook.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+try:
+ from unittest import mock
+except ImportError:
+ import mock
+
+from airflow.hooks.zendesk_hook import ZendeskHook
+from zdesk import RateLimitError
+
+
+class TestZendeskHook(unittest.TestCase):
+
+ @mock.patch("airflow.hooks.zendesk_hook.time")
+ def test_sleeps_for_correct_interval(self, mocked_time):
+ sleep_time = 10
+ # To break out of the otherwise infinite tries
+ mocked_time.sleep = mock.Mock(side_effect=ValueError, return_value=3)
+ conn_mock = mock.Mock()
+ mock_response = mock.Mock()
+ mock_response.headers.get.return_value = sleep_time
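+ # The hook pulls its sleep interval from the rate-limited response's
+ # headers, so this value should end up in time.sleep().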
+ conn_mock.call = mock.Mock(
+ side_effect=RateLimitError(msg="some message", code="some code",
+ response=mock_response))
+
+ zendesk_hook = ZendeskHook("conn_id")
+ zendesk_hook.get_conn = mock.Mock(return_value=conn_mock)
+
+ with self.assertRaises(ValueError):
+ zendesk_hook.call("some_path", get_all_pages=False)
+ mocked_time.sleep.assert_called_with(sleep_time)
+
+ @mock.patch("airflow.hooks.zendesk_hook.Zendesk")
+ def test_returns_single_page_if_get_all_pages_false(self, _):
+ zendesk_hook = ZendeskHook("conn_id")
+ mock_connection = mock.Mock()
+ mock_connection.host = "some_host"
+ zendesk_hook.get_connection = mock.Mock(return_value=mock_connection)
+ zendesk_hook.get_conn()
+
+ mock_conn = mock.Mock()
+ mock_call = mock.Mock(
+ return_value={'next_page': 'https://some_host/something',
+ 'path': []})
+ mock_conn.call = mock_call
+ zendesk_hook.get_conn = mock.Mock(return_value=mock_conn)
+ zendesk_hook.call("path", get_all_pages=False)
+ mock_call.assert_called_once_with("path", None)
+
+ @mock.patch("airflow.hooks.zendesk_hook.Zendesk")
+ def test_returns_multiple_pages_if_get_all_pages_true(self, _):
+ zendesk_hook = ZendeskHook("conn_id")
+ mock_connection = mock.Mock()
+ mock_connection.host = "some_host"
+ zendesk_hook.get_connection = mock.Mock(return_value=mock_connection)
+ zendesk_hook.get_conn()
+
+ mock_conn = mock.Mock()
+ mock_call = mock.Mock(
+ return_value={'next_page': 'https://some_host/something', 'path': []})
+ mock_conn.call = mock_call
+ zendesk_hook.get_conn = mock.Mock(return_value=mock_conn)
+ zendesk_hook.call("path", get_all_pages=True)
+ assert mock_call.call_count == 2
+
+ @mock.patch("airflow.hooks.zendesk_hook.Zendesk")
+ def test_zdesk_is_inited_correctly(self, mock_zendesk):
+ conn_mock = mock.Mock()
+ conn_mock.host = "conn_host"
+ conn_mock.login = "conn_login"
+ conn_mock.password = "conn_pass"
+
+ zendesk_hook = ZendeskHook("conn_id")
+ zendesk_hook.get_connection = mock.Mock(return_value=conn_mock)
+ zendesk_hook.get_conn()
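+ # The hook is expected to prefix the bare connection host with https://
+ # before handing it to the Zendesk client.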
+ mock_zendesk.assert_called_with('https://conn_host', 'conn_login',
+ 'conn_pass', True)