Posted to commits@airflow.apache.org by bo...@apache.org on 2017/04/17 08:05:08 UTC
[1/3] incubator-airflow git commit: [AIRFLOW-1094] Run unit tests under contrib in Travis
Repository: incubator-airflow
Updated Branches:
refs/heads/master 74c1ce254 -> 219c50641
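
The renames in this commit (e.g. datadog_sensor.py -> test_datadog_sensor.py) follow the
test_*.py pattern that discovery-based runners key on; modules without the prefix are
silently skipped, which is presumably why these contrib tests were not running in Travis.
A minimal sketch of that discovery behaviour, assuming plain unittest (the actual Travis
invocation goes through Airflow's own test scripts and may differ):

    import unittest

    # Collects only modules matching test_*.py under tests/contrib; the old
    # names (datadog_sensor.py, emr_base_sensor.py, ...) would not match.
    suite = unittest.defaultTestLoader.discover('tests/contrib', pattern='test_*.py')
    unittest.TextTestRunner(verbosity=2).run(suite)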
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_sqoop_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_sqoop_operator.py b/tests/contrib/operators/test_sqoop_operator.py
new file mode 100644
index 0000000..a46dc93
--- /dev/null
+++ b/tests/contrib/operators/test_sqoop_operator.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+import unittest
+
+from airflow import DAG, configuration
+from airflow.contrib.operators.sqoop_operator import SqoopOperator
+
+
+class TestSqoopOperator(unittest.TestCase):
+ _config = {
+ 'cmd_type': 'export',
+ 'table': 'target_table',
+ 'query': 'SELECT * FROM schema.table',
+ 'target_dir': '/path/on/hdfs/to/import',
+ 'append': True,
+ 'file_type': 'avro',
+ 'columns': 'a,b,c',
+ 'num_mappers': 22,
+ 'split_by': 'id',
+ 'export_dir': '/path/on/hdfs/to/export',
+ 'input_null_string': '\n',
+ 'input_null_non_string': '\t',
+ 'staging_table': 'target_table_staging',
+ 'clear_staging_table': True,
+ 'enclosed_by': '"',
+ 'escaped_by': '\\',
+ 'input_fields_terminated_by': '|',
+ 'input_lines_terminated_by': '\n',
+ 'input_optionally_enclosed_by': '"',
+ 'batch': True,
+ 'relaxed_isolation': True,
+ 'direct': True,
+ 'driver': 'com.microsoft.jdbc.sqlserver.SQLServerDriver',
+ 'properties': {
+ 'mapred.map.max.attempts': '1'
+ }
+ }
+
+ def setUp(self):
+ configuration.load_test_config()
+ args = {
+ 'owner': 'airflow',
+ 'start_date': datetime.datetime(2017, 1, 1)
+ }
+ self.dag = DAG('test_dag_id', default_args=args)
+
+ def test_execute(self, conn_id='sqoop_default'):
+ operator = SqoopOperator(
+ task_id='sqoop_job',
+ dag=self.dag,
+ **self._config
+ )
+
+ self.assertEqual(conn_id, operator.conn_id)
+
+ self.assertEqual(self._config['cmd_type'], operator.cmd_type)
+ self.assertEqual(self._config['table'], operator.table)
+ self.assertEqual(self._config['target_dir'], operator.target_dir)
+ self.assertEqual(self._config['append'], operator.append)
+ self.assertEqual(self._config['file_type'], operator.file_type)
+ self.assertEqual(self._config['num_mappers'], operator.num_mappers)
+ self.assertEqual(self._config['split_by'], operator.split_by)
+ self.assertEqual(self._config['input_null_string'],
+ operator.input_null_string)
+ self.assertEqual(self._config['input_null_non_string'],
+ operator.input_null_non_string)
+ self.assertEqual(self._config['staging_table'], operator.staging_table)
+ self.assertEqual(self._config['clear_staging_table'],
+ operator.clear_staging_table)
+ self.assertEqual(self._config['batch'], operator.batch)
+ self.assertEqual(self._config['relaxed_isolation'],
+ operator.relaxed_isolation)
+ self.assertEqual(self._config['direct'], operator.direct)
+ self.assertEqual(self._config['driver'], operator.driver)
+ self.assertEqual(self._config['properties'], operator.properties)
+
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_ssh_execute_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_ssh_execute_operator.py b/tests/contrib/operators/test_ssh_execute_operator.py
new file mode 100644
index 0000000..0c2b9f2
--- /dev/null
+++ b/tests/contrib/operators/test_ssh_execute_operator.py
@@ -0,0 +1,95 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import os
+import sys
+from datetime import datetime
+from io import StringIO
+
+import mock
+
+from airflow import configuration
+from airflow.settings import Session
+from airflow import models, DAG
+from airflow.contrib.operators.ssh_execute_operator import SSHExecuteOperator
+
+
+TEST_DAG_ID = 'unit_tests'
+DEFAULT_DATE = datetime(2015, 1, 1)
+configuration.load_test_config()
+
+
+def reset(dag_id=TEST_DAG_ID):
+ session = Session()
+ tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id)
+ tis.delete()
+ session.commit()
+ session.close()
+
+reset()
+
+
+class SSHExecuteOperatorTest(unittest.TestCase):
+
+ def setUp(self):
+
+ if sys.version_info[0] == 3:
+ raise unittest.SkipTest('SSHExecuteOperatorTest won\'t work with '
+ 'python3. No need to test anything here')
+
+ configuration.load_test_config()
+ from airflow.contrib.hooks.ssh_hook import SSHHook
+ hook = mock.MagicMock(spec=SSHHook)
+ hook.no_host_key_check = True
+ hook.Popen.return_value.stdout = StringIO(u'stdout')
+ hook.Popen.return_value.returncode = False
+ args = {
+ 'owner': 'airflow',
+ 'start_date': DEFAULT_DATE,
+ 'provide_context': True
+ }
+ dag = DAG(TEST_DAG_ID+'test_schedule_dag_once', default_args=args)
+ dag.schedule_interval = '@once'
+ self.hook = hook
+ self.dag = dag
+
+ @mock.patch('airflow.contrib.operators.ssh_execute_operator.SSHTempFileContent')
+ def test_simple(self, temp_file):
+ temp_file.return_value.__enter__ = lambda x: 'filepath'
+ task = SSHExecuteOperator(
+ task_id="test",
+ bash_command="echo airflow",
+ ssh_hook=self.hook,
+ dag=self.dag,
+ )
+ task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+ @mock.patch('airflow.contrib.operators.ssh_execute_operator.SSHTempFileContent')
+ def test_with_env(self, temp_file):
+ temp_file.return_value.__enter__ = lambda x: 'filepath'
+ test_env = os.environ.copy()
+ test_env['AIRFLOW_test'] = "test"
+ task = SSHExecuteOperator(
+ task_id="test",
+ bash_command="echo $AIRFLOW_HOME",
+ ssh_hook=self.hook,
+ env=test_env,
+ dag=self.dag,
+ )
+ task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+
+if __name__ == '__main__':
+ unittest.main()
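
The hook mock above is built with mock.MagicMock(spec=SSHHook), so only attributes that
exist on the real hook can be touched; a typo in the operator under test fails fast
instead of silently returning a child mock. A standalone sketch of that spec= behaviour
(illustrative class, not from the commit):

    from mock import MagicMock

    class Hook(object):
        def run_command(self):
            pass

    m = MagicMock(spec=Hook)
    m.run_command()        # allowed: defined on Hook
    try:
        m.copy_file()      # not on Hook: spec= raises AttributeError
    except AttributeError:
        print('spec= rejects attributes the real class lacks')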
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/sensors/datadog_sensor.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/datadog_sensor.py b/tests/contrib/sensors/datadog_sensor.py
deleted file mode 100644
index 4d601e1..0000000
--- a/tests/contrib/sensors/datadog_sensor.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-from mock import patch
-
-from airflow.contrib.sensors.datadog_sensor import DatadogSensor
-
-
-at_least_one_event = [{'alert_type': 'info',
- 'comments': [],
- 'date_happened': 1419436860,
- 'device_name': None,
- 'host': None,
- 'id': 2603387619536318140,
- 'is_aggregate': False,
- 'priority': 'normal',
- 'resource': '/api/v1/events/2603387619536318140',
- 'source': 'My Apps',
- 'tags': ['application:web', 'version:1'],
- 'text': 'And let me tell you all about it here!',
- 'title': 'Something big happened!',
- 'url': '/event/jump_to?event_id=2603387619536318140'},
- {'alert_type': 'info',
- 'comments': [],
- 'date_happened': 1419436865,
- 'device_name': None,
- 'host': None,
- 'id': 2603387619536318141,
- 'is_aggregate': False,
- 'priority': 'normal',
- 'resource': '/api/v1/events/2603387619536318141',
- 'source': 'My Apps',
- 'tags': ['application:web', 'version:1'],
- 'text': 'And let me tell you all about it here!',
- 'title': 'Something big happened!',
- 'url': '/event/jump_to?event_id=2603387619536318141'}]
-
-zero_events = []
-
-
-class TestDatadogSensor(unittest.TestCase):
- @patch('airflow.contrib.hooks.datadog_hook.api.Event.query')
- @patch('airflow.contrib.sensors.datadog_sensor.api.Event.query')
- def test_sensor_ok(self, api1, api2):
- api1.return_value = at_least_one_event
- api2.return_value = at_least_one_event
-
- sensor = DatadogSensor(
- task_id='test_datadog',
- datadog_conn_id='datadog_default',
- from_seconds_ago=3600,
- up_to_seconds_from_now=0,
- priority=None,
- sources=None,
- tags=None,
- response_check=None)
-
- self.assertTrue(sensor.poke({}))
-
- @patch('airflow.contrib.hooks.datadog_hook.api.Event.query')
- @patch('airflow.contrib.sensors.datadog_sensor.api.Event.query')
- def test_sensor_fail(self, api1, api2):
- api1.return_value = zero_events
- api2.return_value = zero_events
-
- sensor = DatadogSensor(
- task_id='test_datadog',
- datadog_conn_id='datadog_default',
- from_seconds_ago=0,
- up_to_seconds_from_now=0,
- priority=None,
- sources=None,
- tags=None,
- response_check=None)
-
- self.assertFalse(sensor.poke({}))
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/sensors/emr_base_sensor.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/emr_base_sensor.py b/tests/contrib/sensors/emr_base_sensor.py
deleted file mode 100644
index 0b8ad2f..0000000
--- a/tests/contrib/sensors/emr_base_sensor.py
+++ /dev/null
@@ -1,126 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-from airflow import configuration
-from airflow.exceptions import AirflowException
-from airflow.contrib.sensors.emr_base_sensor import EmrBaseSensor
-
-
-class TestEmrBaseSensor(unittest.TestCase):
- def setUp(self):
- configuration.load_test_config()
-
- def test_subclasses_that_implment_required_methods_and_constants_succeed_when_response_is_good(self):
- class EmrBaseSensorSubclass(EmrBaseSensor):
- NON_TERMINAL_STATES = ['PENDING', 'RUNNING', 'CONTINUE']
- FAILED_STATE = 'FAILED'
-
- def get_emr_response(self):
- return {
- 'SomeKey': {'State': 'COMPLETED'},
- 'ResponseMetadata': {'HTTPStatusCode': 200}
- }
-
- def state_from_response(self, response):
- return response['SomeKey']['State']
-
- operator = EmrBaseSensorSubclass(
- task_id='test_task',
- poke_interval=2,
- job_flow_id='j-8989898989',
- aws_conn_id='aws_test'
- )
-
- operator.execute(None)
-
- def test_poke_returns_false_when_state_is_a_non_terminal_state(self):
- class EmrBaseSensorSubclass(EmrBaseSensor):
- NON_TERMINAL_STATES = ['PENDING', 'RUNNING', 'CONTINUE']
- FAILED_STATE = 'FAILED'
-
- def get_emr_response(self):
- return {
- 'SomeKey': {'State': 'PENDING'},
- 'ResponseMetadata': {'HTTPStatusCode': 200}
- }
-
- def state_from_response(self, response):
- return response['SomeKey']['State']
-
- operator = EmrBaseSensorSubclass(
- task_id='test_task',
- poke_interval=2,
- job_flow_id='j-8989898989',
- aws_conn_id='aws_test'
- )
-
- self.assertEqual(operator.poke(None), False)
-
- def test_poke_returns_false_when_http_response_is_bad(self):
- class EmrBaseSensorSubclass(EmrBaseSensor):
- NON_TERMINAL_STATES = ['PENDING', 'RUNNING', 'CONTINUE']
- FAILED_STATE = 'FAILED'
-
- def get_emr_response(self):
- return {
- 'SomeKey': {'State': 'COMPLETED'},
- 'ResponseMetadata': {'HTTPStatusCode': 400}
- }
-
- def state_from_response(self, response):
- return response['SomeKey']['State']
-
- operator = EmrBaseSensorSubclass(
- task_id='test_task',
- poke_interval=2,
- job_flow_id='j-8989898989',
- aws_conn_id='aws_test'
- )
-
- self.assertEqual(operator.poke(None), False)
-
-
- def test_poke_raises_error_when_job_has_failed(self):
- class EmrBaseSensorSubclass(EmrBaseSensor):
- NON_TERMINAL_STATES = ['PENDING', 'RUNNING', 'CONTINUE']
- FAILED_STATE = 'FAILED'
-
- def get_emr_response(self):
- return {
- 'SomeKey': {'State': 'FAILED'},
- 'ResponseMetadata': {'HTTPStatusCode': 200}
- }
-
- def state_from_response(self, response):
- return response['SomeKey']['State']
-
- operator = EmrBaseSensorSubclass(
- task_id='test_task',
- poke_interval=2,
- job_flow_id='j-8989898989',
- aws_conn_id='aws_test'
- )
-
- with self.assertRaises(AirflowException) as context:
-
- operator.poke(None)
-
-
- self.assertTrue('EMR job failed' in context.exception)
-
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/sensors/emr_job_flow_sensor.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/emr_job_flow_sensor.py b/tests/contrib/sensors/emr_job_flow_sensor.py
deleted file mode 100644
index f993786..0000000
--- a/tests/contrib/sensors/emr_job_flow_sensor.py
+++ /dev/null
@@ -1,123 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-import datetime
-from dateutil.tz import tzlocal
-from mock import MagicMock, patch
-
-from airflow import configuration
-from airflow.contrib.sensors.emr_job_flow_sensor import EmrJobFlowSensor
-
-DESCRIBE_CLUSTER_RUNNING_RETURN = {
- 'Cluster': {
- 'Applications': [
- {'Name': 'Spark', 'Version': '1.6.1'}
- ],
- 'AutoTerminate': True,
- 'Configurations': [],
- 'Ec2InstanceAttributes': {'IamInstanceProfile': 'EMR_EC2_DefaultRole'},
- 'Id': 'j-27ZY9GBEEU2GU',
- 'LogUri': 's3n://some-location/',
- 'Name': 'PiCalc',
- 'NormalizedInstanceHours': 0,
- 'ReleaseLabel': 'emr-4.6.0',
- 'ServiceRole': 'EMR_DefaultRole',
- 'Status': {
- 'State': 'STARTING',
- 'StateChangeReason': {},
- 'Timeline': {'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal())}
- },
- 'Tags': [
- {'Key': 'app', 'Value': 'analytics'},
- {'Key': 'environment', 'Value': 'development'}
- ],
- 'TerminationProtected': False,
- 'VisibleToAllUsers': True
- },
- 'ResponseMetadata': {
- 'HTTPStatusCode': 200,
- 'RequestId': 'd5456308-3caa-11e6-9d46-951401f04e0e'
- }
-}
-
-DESCRIBE_CLUSTER_TERMINATED_RETURN = {
- 'Cluster': {
- 'Applications': [
- {'Name': 'Spark', 'Version': '1.6.1'}
- ],
- 'AutoTerminate': True,
- 'Configurations': [],
- 'Ec2InstanceAttributes': {'IamInstanceProfile': 'EMR_EC2_DefaultRole'},
- 'Id': 'j-27ZY9GBEEU2GU',
- 'LogUri': 's3n://some-location/',
- 'Name': 'PiCalc',
- 'NormalizedInstanceHours': 0,
- 'ReleaseLabel': 'emr-4.6.0',
- 'ServiceRole': 'EMR_DefaultRole',
- 'Status': {
- 'State': 'TERMINATED',
- 'StateChangeReason': {},
- 'Timeline': {'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal())}
- },
- 'Tags': [
- {'Key': 'app', 'Value': 'analytics'},
- {'Key': 'environment', 'Value': 'development'}
- ],
- 'TerminationProtected': False,
- 'VisibleToAllUsers': True
- },
- 'ResponseMetadata': {
- 'HTTPStatusCode': 200,
- 'RequestId': 'd5456308-3caa-11e6-9d46-951401f04e0e'
- }
-}
-
-
-class TestEmrJobFlowSensor(unittest.TestCase):
- def setUp(self):
- configuration.load_test_config()
-
- # Mock out the emr_client (moto has incorrect response)
- self.mock_emr_client = MagicMock()
- self.mock_emr_client.describe_cluster.side_effect = [
- DESCRIBE_CLUSTER_RUNNING_RETURN,
- DESCRIBE_CLUSTER_TERMINATED_RETURN
- ]
-
- # Mock out the emr_client creator
- self.boto3_client_mock = MagicMock(return_value=self.mock_emr_client)
-
-
- def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_terminal_state(self):
- with patch('boto3.client', self.boto3_client_mock):
-
- operator = EmrJobFlowSensor(
- task_id='test_task',
- poke_interval=2,
- job_flow_id='j-8989898989',
- aws_conn_id='aws_default'
- )
-
- operator.execute(None)
-
- # make sure we called twice
- self.assertEqual(self.mock_emr_client.describe_cluster.call_count, 2)
-
- # make sure it was called with the job_flow_id
- self.mock_emr_client.describe_cluster.assert_called_with(ClusterId='j-8989898989')
-
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/sensors/emr_step_sensor.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/emr_step_sensor.py b/tests/contrib/sensors/emr_step_sensor.py
deleted file mode 100644
index 58ee461..0000000
--- a/tests/contrib/sensors/emr_step_sensor.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-import datetime
-from dateutil.tz import tzlocal
-from mock import MagicMock, patch
-import boto3
-
-from airflow import configuration
-from airflow.contrib.sensors.emr_step_sensor import EmrStepSensor
-
-DESCRIBE_JOB_STEP_RUNNING_RETURN = {
- 'ResponseMetadata': {
- 'HTTPStatusCode': 200,
- 'RequestId': '8dee8db2-3719-11e6-9e20-35b2f861a2a6'
- },
- 'Step': {
- 'ActionOnFailure': 'CONTINUE',
- 'Config': {
- 'Args': [
- '/usr/lib/spark/bin/run-example',
- 'SparkPi',
- '10'
- ],
- 'Jar': 'command-runner.jar',
- 'Properties': {}
- },
- 'Id': 's-VK57YR1Z9Z5N',
- 'Name': 'calculate_pi',
- 'Status': {
- 'State': 'RUNNING',
- 'StateChangeReason': {},
- 'Timeline': {
- 'CreationDateTime': datetime.datetime(2016, 6, 20, 19, 0, 18, 787000, tzinfo=tzlocal()),
- 'StartDateTime': datetime.datetime(2016, 6, 20, 19, 2, 34, 889000, tzinfo=tzlocal())
- }
- }
- }
-}
-
-DESCRIBE_JOB_STEP_COMPLETED_RETURN = {
- 'ResponseMetadata': {
- 'HTTPStatusCode': 200,
- 'RequestId': '8dee8db2-3719-11e6-9e20-35b2f861a2a6'
- },
- 'Step': {
- 'ActionOnFailure': 'CONTINUE',
- 'Config': {
- 'Args': [
- '/usr/lib/spark/bin/run-example',
- 'SparkPi',
- '10'
- ],
- 'Jar': 'command-runner.jar',
- 'Properties': {}
- },
- 'Id': 's-VK57YR1Z9Z5N',
- 'Name': 'calculate_pi',
- 'Status': {
- 'State': 'COMPLETED',
- 'StateChangeReason': {},
- 'Timeline': {
- 'CreationDateTime': datetime.datetime(2016, 6, 20, 19, 0, 18, 787000, tzinfo=tzlocal()),
- 'StartDateTime': datetime.datetime(2016, 6, 20, 19, 2, 34, 889000, tzinfo=tzlocal())
- }
- }
- }
-}
-
-
-class TestEmrStepSensor(unittest.TestCase):
- def setUp(self):
- configuration.load_test_config()
-
- # Mock out the emr_client (moto has incorrect response)
- self.mock_emr_client = MagicMock()
- self.mock_emr_client.describe_step.side_effect = [
- DESCRIBE_JOB_STEP_RUNNING_RETURN,
- DESCRIBE_JOB_STEP_COMPLETED_RETURN
- ]
-
- # Mock out the emr_client creator
- self.boto3_client_mock = MagicMock(return_value=self.mock_emr_client)
-
-
- def test_execute_calls_with_the_job_flow_id_and_step_id_until_it_reaches_a_terminal_state(self):
- with patch('boto3.client', self.boto3_client_mock):
-
- operator = EmrStepSensor(
- task_id='test_task',
- poke_interval=1,
- job_flow_id='j-8989898989',
- step_id='s-VK57YR1Z9Z5N',
- aws_conn_id='aws_default',
- )
-
- operator.execute(None)
-
- # make sure we called twice
- self.assertEqual(self.mock_emr_client.describe_step.call_count, 2)
-
- # make sure it was called with the job_flow_id and step_id
- self.mock_emr_client.describe_step.assert_called_with(ClusterId='j-8989898989', StepId='s-VK57YR1Z9Z5N')
-
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/sensors/ftp_sensor.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/ftp_sensor.py b/tests/contrib/sensors/ftp_sensor.py
deleted file mode 100644
index 50f8b8b..0000000
--- a/tests/contrib/sensors/ftp_sensor.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-from ftplib import error_perm
-
-from mock import MagicMock
-
-from airflow.contrib.hooks.ftp_hook import FTPHook
-from airflow.contrib.sensors.ftp_sensor import FTPSensor
-
-
-class TestFTPSensor(unittest.TestCase):
- def setUp(self):
- super(TestFTPSensor, self).setUp()
- self._create_hook_orig = FTPSensor._create_hook
- self.hook_mock = MagicMock(spec=FTPHook)
-
- def _create_hook_mock(sensor):
- mock = MagicMock()
- mock.__enter__ = lambda x: self.hook_mock
-
- return mock
-
- FTPSensor._create_hook = _create_hook_mock
-
- def tearDown(self):
- FTPSensor._create_hook = self._create_hook_orig
- super(TestFTPSensor, self).tearDown()
-
- def test_poke(self):
- op = FTPSensor(path="foobar.json", ftp_conn_id="bob_ftp",
- task_id="test_task")
-
- self.hook_mock.get_mod_time.side_effect = \
- [error_perm("550: Can't check for file existence"), None]
-
- self.assertFalse(op.poke(None))
- self.assertTrue(op.poke(None))
-
- def test_poke_fails_due_error(self):
- op = FTPSensor(path="foobar.json", ftp_conn_id="bob_ftp",
- task_id="test_task")
-
- self.hook_mock.get_mod_time.side_effect = \
- error_perm("530: Login authentication failed")
-
- with self.assertRaises(error_perm) as context:
- op.execute(None)
-
- self.assertTrue("530" in str(context.exception))
-
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/sensors/hdfs_sensors.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/hdfs_sensors.py b/tests/contrib/sensors/hdfs_sensors.py
deleted file mode 100644
index 0e2ed0c..0000000
--- a/tests/contrib/sensors/hdfs_sensors.py
+++ /dev/null
@@ -1,251 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-import sys
-import unittest
-import re
-from datetime import timedelta
-from airflow.contrib.sensors.hdfs_sensors import HdfsSensorFolder, HdfsSensorRegex
-from airflow.exceptions import AirflowSensorTimeout
-
-
-class HdfsSensorFolderTests(unittest.TestCase):
- def setUp(self):
- if sys.version_info[0] == 3:
- raise unittest.SkipTest('HdfsSensor won\'t work with python3. No need to test anything here')
- from tests.core import FakeHDFSHook
- self.hook = FakeHDFSHook
- self.logger = logging.getLogger()
- self.logger.setLevel(logging.DEBUG)
-
- def test_should_be_empty_directory(self):
- """
- test the empty directory behaviour
- :return:
- """
- # Given
- self.logger.debug('#' * 10)
- self.logger.debug('Running %s', self._testMethodName)
- self.logger.debug('#' * 10)
- task = HdfsSensorFolder(task_id='Should_be_empty_directory',
- filepath='/datadirectory/empty_directory',
- be_empty=True,
- timeout=1,
- retry_delay=timedelta(seconds=1),
- poke_interval=1,
- hook=self.hook)
-
- # When
- task.execute(None)
-
- # Then
- # Nothing happens, nothing is raised exec is ok
-
- def test_should_be_empty_directory_fail(self):
- """
- test the empty directory behaviour
- :return:
- """
- # Given
- self.logger.debug('#' * 10)
- self.logger.debug('Running %s', self._testMethodName)
- self.logger.debug('#' * 10)
- task = HdfsSensorFolder(task_id='Should_be_empty_directory_fail',
- filepath='/datadirectory/not_empty_directory',
- be_empty=True,
- timeout=1,
- retry_delay=timedelta(seconds=1),
- poke_interval=1,
- hook=self.hook)
-
- # When
- # Then
- with self.assertRaises(AirflowSensorTimeout):
- task.execute(None)
-
- def test_should_be_a_non_empty_directory(self):
- """
- test the empty directory behaviour
- :return:
- """
- # Given
- self.logger.debug('#' * 10)
- self.logger.debug('Running %s', self._testMethodName)
- self.logger.debug('#' * 10)
- task = HdfsSensorFolder(task_id='Should_be_non_empty_directory',
- filepath='/datadirectory/not_empty_directory',
- timeout=1,
- retry_delay=timedelta(seconds=1),
- poke_interval=1,
- hook=self.hook)
-
- # When
- task.execute(None)
-
- # Then
- # Nothing happens, nothing is raised exec is ok
-
- def test_should_be_non_empty_directory_fail(self):
- """
- test the empty directory behaviour
- :return:
- """
- # Given
- self.logger.debug('#' * 10)
- self.logger.debug('Running %s', self._testMethodName)
- self.logger.debug('#' * 10)
- task = HdfsSensorFolder(task_id='Should_be_empty_directory_fail',
- filepath='/datadirectory/empty_directory',
- timeout=1,
- retry_delay=timedelta(seconds=1),
- poke_interval=1,
- hook=self.hook)
-
- # When
- # Then
- with self.assertRaises(AirflowSensorTimeout):
- task.execute(None)
-
-
-class HdfsSensorRegexTests(unittest.TestCase):
- def setUp(self):
- if sys.version_info[0] == 3:
- raise unittest.SkipTest('HdfsSensor won\'t work with python3. No need to test anything here')
- from tests.core import FakeHDFSHook
- self.hook = FakeHDFSHook
- self.logger = logging.getLogger()
- self.logger.setLevel(logging.DEBUG)
-
- def test_should_match_regex(self):
- """
- test the empty directory behaviour
- :return:
- """
- # Given
- self.logger.debug('#' * 10)
- self.logger.debug('Running %s', self._testMethodName)
- self.logger.debug('#' * 10)
- compiled_regex = re.compile("test[1-2]file")
- task = HdfsSensorRegex(task_id='Should_match_the_regex',
- filepath='/datadirectory/regex_dir',
- regex=compiled_regex,
- timeout=1,
- retry_delay=timedelta(seconds=1),
- poke_interval=1,
- hook=self.hook)
-
- # When
- task.execute(None)
-
- # Then
- # Nothing happens, nothing is raised exec is ok
-
- def test_should_not_match_regex(self):
- """
- test the empty directory behaviour
- :return:
- """
- # Given
- self.logger.debug('#' * 10)
- self.logger.debug('Running %s', self._testMethodName)
- self.logger.debug('#' * 10)
- compiled_regex = re.compile("^IDoNotExist")
- task = HdfsSensorRegex(task_id='Should_not_match_the_regex',
- filepath='/datadirectory/regex_dir',
- regex=compiled_regex,
- timeout=1,
- retry_delay=timedelta(seconds=1),
- poke_interval=1,
- hook=self.hook)
-
- # When
- # Then
- with self.assertRaises(AirflowSensorTimeout):
- task.execute(None)
-
- def test_should_match_regex_and_filesize(self):
- """
- test the file size behaviour with regex
- :return:
- """
- # Given
- self.logger.debug('#' * 10)
- self.logger.debug('Running %s', self._testMethodName)
- self.logger.debug('#' * 10)
- compiled_regex = re.compile("test[1-2]file")
- task = HdfsSensorRegex(task_id='Should_match_the_regex_and_filesize',
- filepath='/datadirectory/regex_dir',
- regex=compiled_regex,
- ignore_copying=True,
- ignored_ext=['_COPYING_', 'sftp'],
- file_size=10,
- timeout=1,
- retry_delay=timedelta(seconds=1),
- poke_interval=1,
- hook=self.hook)
-
- # When
- task.execute(None)
-
- # Then
- # Nothing happens, nothing is raised exec is ok
-
- def test_should_match_regex_but_filesize(self):
- """
- test the file size behaviour with regex
- :return:
- """
- # Given
- self.logger.debug('#' * 10)
- self.logger.debug('Running %s', self._testMethodName)
- self.logger.debug('#' * 10)
- compiled_regex = re.compile("test[1-2]file")
- task = HdfsSensorRegex(task_id='Should_match_the_regex_but_filesize',
- filepath='/datadirectory/regex_dir',
- regex=compiled_regex,
- file_size=20,
- timeout=1,
- retry_delay=timedelta(seconds=1),
- poke_interval=1,
- hook=self.hook)
-
- # When
- # Then
- with self.assertRaises(AirflowSensorTimeout):
- task.execute(None)
-
- def test_should_match_regex_but_copyingext(self):
- """
- test the file size behaviour with regex
- :return:
- """
- # Given
- self.logger.debug('#' * 10)
- self.logger.debug('Running %s', self._testMethodName)
- self.logger.debug('#' * 10)
- compiled_regex = re.compile("copying_file_\d+.txt")
- task = HdfsSensorRegex(task_id='Should_match_the_regex_but_filesize',
- filepath='/datadirectory/regex_dir',
- regex=compiled_regex,
- ignored_ext=['_COPYING_', 'sftp'],
- file_size=20,
- timeout=1,
- retry_delay=timedelta(seconds=1),
- poke_interval=1,
- hook=self.hook)
-
- # When
- # Then
- with self.assertRaises(AirflowSensorTimeout):
- task.execute(None)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/sensors/jira_sensor_test.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/jira_sensor_test.py b/tests/contrib/sensors/jira_sensor_test.py
deleted file mode 100644
index 77ca97f..0000000
--- a/tests/contrib/sensors/jira_sensor_test.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-import datetime
-from mock import Mock
-from mock import patch
-
-from airflow import DAG, configuration
-from airflow.contrib.sensors.jira_sensor import JiraTicketSensor
-from airflow import models
-from airflow.utils import db
-
-DEFAULT_DATE = datetime.datetime(2017, 1, 1)
-jira_client_mock = Mock(
- name="jira_client_for_test"
-)
-
-minimal_test_ticket = {
- "id": "911539",
- "self": "https://sandbox.localhost/jira/rest/api/2/issue/911539",
- "key": "TEST-1226",
- "fields": {
- "labels": [
- "test-label-1",
- "test-label-2"
- ],
- "description": "this is a test description",
- }
-}
-
-
-class TestJiraSensor(unittest.TestCase):
- def setUp(self):
- configuration.load_test_config()
- args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
- }
- dag = DAG('test_dag_id', default_args=args)
- self.dag = dag
- db.merge_conn(
- models.Connection(
- conn_id='jira_default', conn_type='jira',
- host='https://localhost/jira/', port=443,
- extra='{"verify": "False", "project": "AIRFLOW"}'))
-
- @patch("airflow.contrib.hooks.jira_hook.JIRA",
- autospec=True, return_value=jira_client_mock)
- def test_issue_label_set(self, jira_mock):
- jira_mock.return_value.issue.return_value = minimal_test_ticket
-
- ticket_label_sensor = JiraTicketSensor(task_id='search-ticket-test',
- ticket_id='TEST-1226',
- field_checker_func=
- TestJiraSensor.field_checker_func,
- timeout=518400,
- poke_interval=10,
- dag=self.dag)
-
- ticket_label_sensor.run(start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE, ignore_ti_state=True)
-
- self.assertTrue(jira_mock.called)
- self.assertTrue(jira_mock.return_value.issue.called)
-
- @staticmethod
- def field_checker_func(context, issue):
- return "test-label-1" in issue['fields']['labels']
-
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/sensors/redis_sensor.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/redis_sensor.py b/tests/contrib/sensors/redis_sensor.py
deleted file mode 100644
index 8022a92..0000000
--- a/tests/contrib/sensors/redis_sensor.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import unittest
-import datetime
-
-from mock import patch
-
-from airflow import DAG
-from airflow import configuration
-from airflow.contrib.sensors.redis_key_sensor import RedisKeySensor
-
-DEFAULT_DATE = datetime.datetime(2017, 1, 1)
-
-
-class TestRedisSensor(unittest.TestCase):
-
- def setUp(self):
- configuration.load_test_config()
- args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
- }
-
- self.dag = DAG('test_dag_id', default_args=args)
- self.sensor = RedisKeySensor(
- task_id='test_task',
- redis_conn_id='redis_default',
- dag=self.dag,
- key='test_key'
- )
-
- @patch("airflow.contrib.hooks.redis_hook.RedisHook.key_exists")
- def test_poke(self, key_exists):
- key_exists.return_value = True
- self.assertTrue(self.sensor.poke(None))
-
- key_exists.return_value = False
- self.assertFalse(self.sensor.poke(None))
-
- @patch("airflow.contrib.hooks.redis_hook.StrictRedis.exists")
- def test_existing_key_called(self, redis_client_exists):
- self.sensor.run(
- start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE, ignore_ti_state=True
- )
-
- self.assertTrue(redis_client_exists.called_with('test_key'))
-
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/sensors/test_datadog_sensor.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/test_datadog_sensor.py b/tests/contrib/sensors/test_datadog_sensor.py
new file mode 100644
index 0000000..d845c54
--- /dev/null
+++ b/tests/contrib/sensors/test_datadog_sensor.py
@@ -0,0 +1,106 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import unittest
+from mock import patch
+
+from airflow import configuration
+from airflow.utils import db
+from airflow import models
+from airflow.contrib.sensors.datadog_sensor import DatadogSensor
+
+
+at_least_one_event = [{'alert_type': 'info',
+ 'comments': [],
+ 'date_happened': 1419436860,
+ 'device_name': None,
+ 'host': None,
+ 'id': 2603387619536318140,
+ 'is_aggregate': False,
+ 'priority': 'normal',
+ 'resource': '/api/v1/events/2603387619536318140',
+ 'source': 'My Apps',
+ 'tags': ['application:web', 'version:1'],
+ 'text': 'And let me tell you all about it here!',
+ 'title': 'Something big happened!',
+ 'url': '/event/jump_to?event_id=2603387619536318140'},
+ {'alert_type': 'info',
+ 'comments': [],
+ 'date_happened': 1419436865,
+ 'device_name': None,
+ 'host': None,
+ 'id': 2603387619536318141,
+ 'is_aggregate': False,
+ 'priority': 'normal',
+ 'resource': '/api/v1/events/2603387619536318141',
+ 'source': 'My Apps',
+ 'tags': ['application:web', 'version:1'],
+ 'text': 'And let me tell you all about it here!',
+ 'title': 'Something big happened!',
+ 'url': '/event/jump_to?event_id=2603387619536318141'}]
+
+zero_events = []
+
+
+class TestDatadogSensor(unittest.TestCase):
+
+ def setUp(self):
+ configuration.load_test_config()
+ db.merge_conn(
+ models.Connection(
+ conn_id='datadog_default', conn_type='datadog',
+ login='login', password='password',
+ extra=json.dumps({'api_key': 'api_key', 'app_key': 'app_key'})
+ )
+ )
+
+ @patch('airflow.contrib.hooks.datadog_hook.api.Event.query')
+ @patch('airflow.contrib.sensors.datadog_sensor.api.Event.query')
+ def test_sensor_ok(self, api1, api2):
+ api1.return_value = at_least_one_event
+ api2.return_value = at_least_one_event
+
+ sensor = DatadogSensor(
+ task_id='test_datadog',
+ datadog_conn_id='datadog_default',
+ from_seconds_ago=3600,
+ up_to_seconds_from_now=0,
+ priority=None,
+ sources=None,
+ tags=None,
+ response_check=None)
+
+ self.assertTrue(sensor.poke({}))
+
+ @patch('airflow.contrib.hooks.datadog_hook.api.Event.query')
+ @patch('airflow.contrib.sensors.datadog_sensor.api.Event.query')
+ def test_sensor_fail(self, api1, api2):
+ api1.return_value = zero_events
+ api2.return_value = zero_events
+
+ sensor = DatadogSensor(
+ task_id='test_datadog',
+ datadog_conn_id='datadog_default',
+ from_seconds_ago=0,
+ up_to_seconds_from_now=0,
+ priority=None,
+ sources=None,
+ tags=None,
+ response_check=None)
+
+ self.assertFalse(sensor.poke({}))
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/sensors/test_emr_base_sensor.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/test_emr_base_sensor.py b/tests/contrib/sensors/test_emr_base_sensor.py
new file mode 100644
index 0000000..9c39abb
--- /dev/null
+++ b/tests/contrib/sensors/test_emr_base_sensor.py
@@ -0,0 +1,126 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from airflow import configuration
+from airflow.exceptions import AirflowException
+from airflow.contrib.sensors.emr_base_sensor import EmrBaseSensor
+
+
+class TestEmrBaseSensor(unittest.TestCase):
+ def setUp(self):
+ configuration.load_test_config()
+
+ def test_subclasses_that_implement_required_methods_and_constants_succeed_when_response_is_good(self):
+ class EmrBaseSensorSubclass(EmrBaseSensor):
+ NON_TERMINAL_STATES = ['PENDING', 'RUNNING', 'CONTINUE']
+ FAILED_STATE = 'FAILED'
+
+ def get_emr_response(self):
+ return {
+ 'SomeKey': {'State': 'COMPLETED'},
+ 'ResponseMetadata': {'HTTPStatusCode': 200}
+ }
+
+ def state_from_response(self, response):
+ return response['SomeKey']['State']
+
+ operator = EmrBaseSensorSubclass(
+ task_id='test_task',
+ poke_interval=2,
+ job_flow_id='j-8989898989',
+ aws_conn_id='aws_test'
+ )
+
+ operator.execute(None)
+
+ def test_poke_returns_false_when_state_is_a_non_terminal_state(self):
+ class EmrBaseSensorSubclass(EmrBaseSensor):
+ NON_TERMINAL_STATES = ['PENDING', 'RUNNING', 'CONTINUE']
+ FAILED_STATE = 'FAILED'
+
+ def get_emr_response(self):
+ return {
+ 'SomeKey': {'State': 'PENDING'},
+ 'ResponseMetadata': {'HTTPStatusCode': 200}
+ }
+
+ def state_from_response(self, response):
+ return response['SomeKey']['State']
+
+ operator = EmrBaseSensorSubclass(
+ task_id='test_task',
+ poke_interval=2,
+ job_flow_id='j-8989898989',
+ aws_conn_id='aws_test'
+ )
+
+ self.assertEqual(operator.poke(None), False)
+
+ def test_poke_returns_false_when_http_response_is_bad(self):
+ class EmrBaseSensorSubclass(EmrBaseSensor):
+ NON_TERMINAL_STATES = ['PENDING', 'RUNNING', 'CONTINUE']
+ FAILED_STATE = 'FAILED'
+
+ def get_emr_response(self):
+ return {
+ 'SomeKey': {'State': 'COMPLETED'},
+ 'ResponseMetadata': {'HTTPStatusCode': 400}
+ }
+
+ def state_from_response(self, response):
+ return response['SomeKey']['State']
+
+ operator = EmrBaseSensorSubclass(
+ task_id='test_task',
+ poke_interval=2,
+ job_flow_id='j-8989898989',
+ aws_conn_id='aws_test'
+ )
+
+ self.assertEqual(operator.poke(None), False)
+
+
+ def test_poke_raises_error_when_job_has_failed(self):
+ class EmrBaseSensorSubclass(EmrBaseSensor):
+ NON_TERMINAL_STATES = ['PENDING', 'RUNNING', 'CONTINUE']
+ FAILED_STATE = 'FAILED'
+
+ def get_emr_response(self):
+ return {
+ 'SomeKey': {'State': 'FAILED'},
+ 'ResponseMetadata': {'HTTPStatusCode': 200}
+ }
+
+ def state_from_response(self, response):
+ return response['SomeKey']['State']
+
+ operator = EmrBaseSensorSubclass(
+ task_id='test_task',
+ poke_interval=2,
+ job_flow_id='j-8989898989',
+ aws_conn_id='aws_test'
+ )
+
+ with self.assertRaises(AirflowException) as context:
+
+ operator.poke(None)
+
+
+ self.assertIn('EMR job failed', str(context.exception))
+
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/sensors/test_emr_job_flow_sensor.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/test_emr_job_flow_sensor.py b/tests/contrib/sensors/test_emr_job_flow_sensor.py
new file mode 100644
index 0000000..f993786
--- /dev/null
+++ b/tests/contrib/sensors/test_emr_job_flow_sensor.py
@@ -0,0 +1,123 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import datetime
+from dateutil.tz import tzlocal
+from mock import MagicMock, patch
+
+from airflow import configuration
+from airflow.contrib.sensors.emr_job_flow_sensor import EmrJobFlowSensor
+
+DESCRIBE_CLUSTER_RUNNING_RETURN = {
+ 'Cluster': {
+ 'Applications': [
+ {'Name': 'Spark', 'Version': '1.6.1'}
+ ],
+ 'AutoTerminate': True,
+ 'Configurations': [],
+ 'Ec2InstanceAttributes': {'IamInstanceProfile': 'EMR_EC2_DefaultRole'},
+ 'Id': 'j-27ZY9GBEEU2GU',
+ 'LogUri': 's3n://some-location/',
+ 'Name': 'PiCalc',
+ 'NormalizedInstanceHours': 0,
+ 'ReleaseLabel': 'emr-4.6.0',
+ 'ServiceRole': 'EMR_DefaultRole',
+ 'Status': {
+ 'State': 'STARTING',
+ 'StateChangeReason': {},
+ 'Timeline': {'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal())}
+ },
+ 'Tags': [
+ {'Key': 'app', 'Value': 'analytics'},
+ {'Key': 'environment', 'Value': 'development'}
+ ],
+ 'TerminationProtected': False,
+ 'VisibleToAllUsers': True
+ },
+ 'ResponseMetadata': {
+ 'HTTPStatusCode': 200,
+ 'RequestId': 'd5456308-3caa-11e6-9d46-951401f04e0e'
+ }
+}
+
+DESCRIBE_CLUSTER_TERMINATED_RETURN = {
+ 'Cluster': {
+ 'Applications': [
+ {'Name': 'Spark', 'Version': '1.6.1'}
+ ],
+ 'AutoTerminate': True,
+ 'Configurations': [],
+ 'Ec2InstanceAttributes': {'IamInstanceProfile': 'EMR_EC2_DefaultRole'},
+ 'Id': 'j-27ZY9GBEEU2GU',
+ 'LogUri': 's3n://some-location/',
+ 'Name': 'PiCalc',
+ 'NormalizedInstanceHours': 0,
+ 'ReleaseLabel': 'emr-4.6.0',
+ 'ServiceRole': 'EMR_DefaultRole',
+ 'Status': {
+ 'State': 'TERMINATED',
+ 'StateChangeReason': {},
+ 'Timeline': {'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal())}
+ },
+ 'Tags': [
+ {'Key': 'app', 'Value': 'analytics'},
+ {'Key': 'environment', 'Value': 'development'}
+ ],
+ 'TerminationProtected': False,
+ 'VisibleToAllUsers': True
+ },
+ 'ResponseMetadata': {
+ 'HTTPStatusCode': 200,
+ 'RequestId': 'd5456308-3caa-11e6-9d46-951401f04e0e'
+ }
+}
+
+
+class TestEmrJobFlowSensor(unittest.TestCase):
+ def setUp(self):
+ configuration.load_test_config()
+
+ # Mock out the emr_client (moto has incorrect response)
+ self.mock_emr_client = MagicMock()
+ self.mock_emr_client.describe_cluster.side_effect = [
+ DESCRIBE_CLUSTER_RUNNING_RETURN,
+ DESCRIBE_CLUSTER_TERMINATED_RETURN
+ ]
+
+ # Mock out the emr_client creator
+ self.boto3_client_mock = MagicMock(return_value=self.mock_emr_client)
+
+
+ def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_terminal_state(self):
+ with patch('boto3.client', self.boto3_client_mock):
+
+ operator = EmrJobFlowSensor(
+ task_id='test_task',
+ poke_interval=2,
+ job_flow_id='j-8989898989',
+ aws_conn_id='aws_default'
+ )
+
+ operator.execute(None)
+
+ # make sure we called twice
+ self.assertEqual(self.mock_emr_client.describe_cluster.call_count, 2)
+
+ # make sure it was called with the job_flow_id
+ self.mock_emr_client.describe_cluster.assert_called_with(ClusterId='j-8989898989')
+
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/sensors/test_emr_step_sensor.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/test_emr_step_sensor.py b/tests/contrib/sensors/test_emr_step_sensor.py
new file mode 100644
index 0000000..58ee461
--- /dev/null
+++ b/tests/contrib/sensors/test_emr_step_sensor.py
@@ -0,0 +1,119 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import datetime
+from dateutil.tz import tzlocal
+from mock import MagicMock, patch
+import boto3
+
+from airflow import configuration
+from airflow.contrib.sensors.emr_step_sensor import EmrStepSensor
+
+DESCRIBE_JOB_STEP_RUNNING_RETURN = {
+ 'ResponseMetadata': {
+ 'HTTPStatusCode': 200,
+ 'RequestId': '8dee8db2-3719-11e6-9e20-35b2f861a2a6'
+ },
+ 'Step': {
+ 'ActionOnFailure': 'CONTINUE',
+ 'Config': {
+ 'Args': [
+ '/usr/lib/spark/bin/run-example',
+ 'SparkPi',
+ '10'
+ ],
+ 'Jar': 'command-runner.jar',
+ 'Properties': {}
+ },
+ 'Id': 's-VK57YR1Z9Z5N',
+ 'Name': 'calculate_pi',
+ 'Status': {
+ 'State': 'RUNNING',
+ 'StateChangeReason': {},
+ 'Timeline': {
+ 'CreationDateTime': datetime.datetime(2016, 6, 20, 19, 0, 18, 787000, tzinfo=tzlocal()),
+ 'StartDateTime': datetime.datetime(2016, 6, 20, 19, 2, 34, 889000, tzinfo=tzlocal())
+ }
+ }
+ }
+}
+
+DESCRIBE_JOB_STEP_COMPLETED_RETURN = {
+ 'ResponseMetadata': {
+ 'HTTPStatusCode': 200,
+ 'RequestId': '8dee8db2-3719-11e6-9e20-35b2f861a2a6'
+ },
+ 'Step': {
+ 'ActionOnFailure': 'CONTINUE',
+ 'Config': {
+ 'Args': [
+ '/usr/lib/spark/bin/run-example',
+ 'SparkPi',
+ '10'
+ ],
+ 'Jar': 'command-runner.jar',
+ 'Properties': {}
+ },
+ 'Id': 's-VK57YR1Z9Z5N',
+ 'Name': 'calculate_pi',
+ 'Status': {
+ 'State': 'COMPLETED',
+ 'StateChangeReason': {},
+ 'Timeline': {
+ 'CreationDateTime': datetime.datetime(2016, 6, 20, 19, 0, 18, 787000, tzinfo=tzlocal()),
+ 'StartDateTime': datetime.datetime(2016, 6, 20, 19, 2, 34, 889000, tzinfo=tzlocal())
+ }
+ }
+ }
+}
+
+
+class TestEmrStepSensor(unittest.TestCase):
+ def setUp(self):
+ configuration.load_test_config()
+
+ # Mock out the emr_client (moto has incorrect response)
+ self.mock_emr_client = MagicMock()
+ self.mock_emr_client.describe_step.side_effect = [
+ DESCRIBE_JOB_STEP_RUNNING_RETURN,
+ DESCRIBE_JOB_STEP_COMPLETED_RETURN
+ ]
+
+ # Mock out the emr_client creator
+ self.boto3_client_mock = MagicMock(return_value=self.mock_emr_client)
+
+
+ def test_execute_calls_with_the_job_flow_id_and_step_id_until_it_reaches_a_terminal_state(self):
+ with patch('boto3.client', self.boto3_client_mock):
+
+ operator = EmrStepSensor(
+ task_id='test_task',
+ poke_interval=1,
+ job_flow_id='j-8989898989',
+ step_id='s-VK57YR1Z9Z5N',
+ aws_conn_id='aws_default',
+ )
+
+ operator.execute(None)
+
+ # make sure we called twice
+ self.assertEqual(self.mock_emr_client.describe_step.call_count, 2)
+
+ # make sure it was called with the job_flow_id and step_id
+ self.mock_emr_client.describe_step.assert_called_with(ClusterId='j-8989898989', StepId='s-VK57YR1Z9Z5N')
+
+
+if __name__ == '__main__':
+ unittest.main()
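
Both EMR sensor tests drive the polling loop through a side_effect list: each call to the
mocked describe_* method returns the next canned response, so the sensor sees a running
state on the first poke and a terminal state on the second. A condensed sketch of that
mechanic (response shapes trimmed from the fixtures above):

    from mock import MagicMock

    client = MagicMock()
    client.describe_step.side_effect = [
        {'Step': {'Status': {'State': 'RUNNING'}}},    # first poke
        {'Step': {'Status': {'State': 'COMPLETED'}}},  # second poke
    ]

    assert client.describe_step()['Step']['Status']['State'] == 'RUNNING'
    assert client.describe_step()['Step']['Status']['State'] == 'COMPLETED'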
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/sensors/test_ftp_sensor.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/test_ftp_sensor.py b/tests/contrib/sensors/test_ftp_sensor.py
new file mode 100644
index 0000000..50f8b8b
--- /dev/null
+++ b/tests/contrib/sensors/test_ftp_sensor.py
@@ -0,0 +1,66 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from ftplib import error_perm
+
+from mock import MagicMock
+
+from airflow.contrib.hooks.ftp_hook import FTPHook
+from airflow.contrib.sensors.ftp_sensor import FTPSensor
+
+
+class TestFTPSensor(unittest.TestCase):
+ def setUp(self):
+ super(TestFTPSensor, self).setUp()
+ self._create_hook_orig = FTPSensor._create_hook
+ self.hook_mock = MagicMock(spec=FTPHook)
+
+ def _create_hook_mock(sensor):
+ mock = MagicMock()
+ mock.__enter__ = lambda x: self.hook_mock
+
+ return mock
+
+ FTPSensor._create_hook = _create_hook_mock
+
+ def tearDown(self):
+ FTPSensor._create_hook = self._create_hook_orig
+ super(TestFTPSensor, self).tearDown()
+
+ def test_poke(self):
+ op = FTPSensor(path="foobar.json", ftp_conn_id="bob_ftp",
+ task_id="test_task")
+
+ self.hook_mock.get_mod_time.side_effect = \
+ [error_perm("550: Can't check for file existence"), None]
+
+ self.assertFalse(op.poke(None))
+ self.assertTrue(op.poke(None))
+
+ def test_poke_fails_due_error(self):
+ op = FTPSensor(path="foobar.json", ftp_conn_id="bob_ftp",
+ task_id="test_task")
+
+ self.hook_mock.get_mod_time.side_effect = \
+ error_perm("530: Login authentication failed")
+
+ with self.assertRaises(error_perm) as context:
+ op.execute(None)
+
+ self.assertTrue("530" in str(context.exception))
+
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/sensors/test_hdfs_sensors.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/test_hdfs_sensors.py b/tests/contrib/sensors/test_hdfs_sensors.py
new file mode 100644
index 0000000..0e2ed0c
--- /dev/null
+++ b/tests/contrib/sensors/test_hdfs_sensors.py
@@ -0,0 +1,251 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import sys
+import unittest
+import re
+from datetime import timedelta
+from airflow.contrib.sensors.hdfs_sensors import HdfsSensorFolder, HdfsSensorRegex
+from airflow.exceptions import AirflowSensorTimeout
+
+
+class HdfsSensorFolderTests(unittest.TestCase):
+ def setUp(self):
+ if sys.version_info[0] == 3:
+ raise unittest.SkipTest('HdfsSensor won\'t work with python3. No need to test anything here')
+ from tests.core import FakeHDFSHook
+ self.hook = FakeHDFSHook
+ self.logger = logging.getLogger()
+ self.logger.setLevel(logging.DEBUG)
+
+ def test_should_be_empty_directory(self):
+ """
+ test the empty directory behaviour
+ :return:
+ """
+ # Given
+ self.logger.debug('#' * 10)
+ self.logger.debug('Running %s', self._testMethodName)
+ self.logger.debug('#' * 10)
+ task = HdfsSensorFolder(task_id='Should_be_empty_directory',
+ filepath='/datadirectory/empty_directory',
+ be_empty=True,
+ timeout=1,
+ retry_delay=timedelta(seconds=1),
+ poke_interval=1,
+ hook=self.hook)
+
+ # When
+ task.execute(None)
+
+ # Then
+ # Nothing happens, nothing is raised, execution is OK
+
+ def test_should_be_empty_directory_fail(self):
+ """
+ test the empty directory behaviour
+ :return:
+ """
+ # Given
+ self.logger.debug('#' * 10)
+ self.logger.debug('Running %s', self._testMethodName)
+ self.logger.debug('#' * 10)
+ task = HdfsSensorFolder(task_id='Should_be_empty_directory_fail',
+ filepath='/datadirectory/not_empty_directory',
+ be_empty=True,
+ timeout=1,
+ retry_delay=timedelta(seconds=1),
+ poke_interval=1,
+ hook=self.hook)
+
+ # When
+ # Then
+ with self.assertRaises(AirflowSensorTimeout):
+ task.execute(None)
+
+ def test_should_be_a_non_empty_directory(self):
+ """
+ test the non-empty directory behaviour
+ :return:
+ """
+ # Given
+ self.logger.debug('#' * 10)
+ self.logger.debug('Running %s', self._testMethodName)
+ self.logger.debug('#' * 10)
+ task = HdfsSensorFolder(task_id='Should_be_non_empty_directory',
+ filepath='/datadirectory/not_empty_directory',
+ timeout=1,
+ retry_delay=timedelta(seconds=1),
+ poke_interval=1,
+ hook=self.hook)
+
+ # When
+ task.execute(None)
+
+ # Then
+ # Nothing happens, nothing is raised, execution is OK
+
+ def test_should_be_non_empty_directory_fail(self):
+ """
+ test the non-empty directory behaviour
+ :return:
+ """
+ # Given
+ self.logger.debug('#' * 10)
+ self.logger.debug('Running %s', self._testMethodName)
+ self.logger.debug('#' * 10)
+ task = HdfsSensorFolder(task_id='Should_be_empty_directory_fail',
+ filepath='/datadirectory/empty_directory',
+ timeout=1,
+ retry_delay=timedelta(seconds=1),
+ poke_interval=1,
+ hook=self.hook)
+
+ # When
+ # Then
+ with self.assertRaises(AirflowSensorTimeout):
+ task.execute(None)
+
+
+class HdfsSensorRegexTests(unittest.TestCase):
+ def setUp(self):
+ if sys.version_info[0] == 3:
+ raise unittest.SkipTest('HdfsSensor won\'t work with python3. No need to test anything here')
+ from tests.core import FakeHDFSHook
+ self.hook = FakeHDFSHook
+ self.logger = logging.getLogger()
+ self.logger.setLevel(logging.DEBUG)
+
+ def test_should_match_regex(self):
+ """
+ test the regex matching behaviour
+ :return:
+ """
+ # Given
+ self.logger.debug('#' * 10)
+ self.logger.debug('Running %s', self._testMethodName)
+ self.logger.debug('#' * 10)
+ compiled_regex = re.compile("test[1-2]file")
+ task = HdfsSensorRegex(task_id='Should_match_the_regex',
+ filepath='/datadirectory/regex_dir',
+ regex=compiled_regex,
+ timeout=1,
+ retry_delay=timedelta(seconds=1),
+ poke_interval=1,
+ hook=self.hook)
+
+ # When
+ task.execute(None)
+
+ # Then
+ # Nothing happens, nothing is raised: execution is OK
+
+ def test_should_not_match_regex(self):
+ """
+ test the regex non-matching behaviour
+ :return:
+ """
+ # Given
+ self.logger.debug('#' * 10)
+ self.logger.debug('Running %s', self._testMethodName)
+ self.logger.debug('#' * 10)
+ compiled_regex = re.compile("^IDoNotExist")
+ task = HdfsSensorRegex(task_id='Should_not_match_the_regex',
+ filepath='/datadirectory/regex_dir',
+ regex=compiled_regex,
+ timeout=1,
+ retry_delay=timedelta(seconds=1),
+ poke_interval=1,
+ hook=self.hook)
+
+ # When
+ # Then
+ with self.assertRaises(AirflowSensorTimeout):
+ task.execute(None)
+
+ def test_should_match_regex_and_filesize(self):
+ """
+ test the file size behaviour with regex
+ :return:
+ """
+ # Given
+ self.logger.debug('#' * 10)
+ self.logger.debug('Running %s', self._testMethodName)
+ self.logger.debug('#' * 10)
+ compiled_regex = re.compile("test[1-2]file")
+ task = HdfsSensorRegex(task_id='Should_match_the_regex_and_filesize',
+ filepath='/datadirectory/regex_dir',
+ regex=compiled_regex,
+ ignore_copying=True,
+ ignored_ext=['_COPYING_', 'sftp'],
+ file_size=10,
+ timeout=1,
+ retry_delay=timedelta(seconds=1),
+ poke_interval=1,
+ hook=self.hook)
+
+ # When
+ task.execute(None)
+
+ # Then
+ # Nothing happens, nothing is raised: execution is OK
+
+ def test_should_match_regex_but_filesize(self):
+ """
+ test the file size behaviour with regex
+ :return:
+ """
+ # Given
+ self.logger.debug('#' * 10)
+ self.logger.debug('Running %s', self._testMethodName)
+ self.logger.debug('#' * 10)
+ compiled_regex = re.compile("test[1-2]file")
+ task = HdfsSensorRegex(task_id='Should_match_the_regex_but_filesize',
+ filepath='/datadirectory/regex_dir',
+ regex=compiled_regex,
+ file_size=20,
+ timeout=1,
+ retry_delay=timedelta(seconds=1),
+ poke_interval=1,
+ hook=self.hook)
+
+ # When
+ # Then
+ with self.assertRaises(AirflowSensorTimeout):
+ task.execute(None)
+
+ def test_should_match_regex_but_copyingext(self):
+ """
+ test the ignored copying-extension behaviour with regex
+ :return:
+ """
+ # Given
+ self.logger.debug('#' * 10)
+ self.logger.debug('Running %s', self._testMethodName)
+ self.logger.debug('#' * 10)
+ compiled_regex = re.compile("copying_file_\d+.txt")
+ task = HdfsSensorRegex(task_id='Should_match_the_regex_but_copyingext',
+ filepath='/datadirectory/regex_dir',
+ regex=compiled_regex,
+ ignored_ext=['_COPYING_', 'sftp'],
+ file_size=20,
+ timeout=1,
+ retry_delay=timedelta(seconds=1),
+ poke_interval=1,
+ hook=self.hook)
+
+ # When
+ # Then
+ with self.assertRaises(AirflowSensorTimeout):
+ task.execute(None)
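All of the failure-path tests above lean on the same trick: timeout=1 and
poke_interval=1 make the sensor's poke loop give up almost immediately, so
AirflowSensorTimeout surfaces within a couple of seconds instead of stalling
the suite. A minimal sketch of that mechanism, assuming the Airflow 1.8-era
sensor API (AlwaysFalseSensor and its task_id are hypothetical):

    from datetime import timedelta

    from airflow.exceptions import AirflowSensorTimeout
    from airflow.operators.sensors import BaseSensorOperator

    class AlwaysFalseSensor(BaseSensorOperator):
        """Stub sensor whose condition never becomes true."""
        def poke(self, context):
            return False

    task = AlwaysFalseSensor(task_id='never_succeeds',
                             timeout=1,
                             poke_interval=1,
                             retry_delay=timedelta(seconds=1))
    try:
        # pokes, sleeps ~1s, pokes again, then exceeds the 1s budget
        task.execute(None)
    except AirflowSensorTimeout:
        pass  # the failure-path tests assert exactly this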
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/sensors/test_jira_sensor_test.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/test_jira_sensor_test.py b/tests/contrib/sensors/test_jira_sensor_test.py
new file mode 100644
index 0000000..77ca97f
--- /dev/null
+++ b/tests/contrib/sensors/test_jira_sensor_test.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import datetime
+from mock import Mock
+from mock import patch
+
+from airflow import DAG, configuration
+from airflow.contrib.sensors.jira_sensor import JiraTicketSensor
+from airflow import models
+from airflow.utils import db
+
+DEFAULT_DATE = datetime.datetime(2017, 1, 1)
+jira_client_mock = Mock(
+ name="jira_client_for_test"
+)
+
+minimal_test_ticket = {
+ "id": "911539",
+ "self": "https://sandbox.localhost/jira/rest/api/2/issue/911539",
+ "key": "TEST-1226",
+ "fields": {
+ "labels": [
+ "test-label-1",
+ "test-label-2"
+ ],
+ "description": "this is a test description",
+ }
+}
+
+
+class TestJiraSensor(unittest.TestCase):
+ def setUp(self):
+ configuration.load_test_config()
+ args = {
+ 'owner': 'airflow',
+ 'start_date': DEFAULT_DATE
+ }
+ dag = DAG('test_dag_id', default_args=args)
+ self.dag = dag
+ db.merge_conn(
+ models.Connection(
+ conn_id='jira_default', conn_type='jira',
+ host='https://localhost/jira/', port=443,
+ extra='{"verify": "False", "project": "AIRFLOW"}'))
+
+ @patch("airflow.contrib.hooks.jira_hook.JIRA",
+ autospec=True, return_value=jira_client_mock)
+ def test_issue_label_set(self, jira_mock):
+ jira_mock.return_value.issue.return_value = minimal_test_ticket
+
+ ticket_label_sensor = JiraTicketSensor(task_id='search-ticket-test',
+ ticket_id='TEST-1226',
+ field_checker_func=TestJiraSensor.field_checker_func,
+ timeout=518400,
+ poke_interval=10,
+ dag=self.dag)
+
+ ticket_label_sensor.run(start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+ self.assertTrue(jira_mock.called)
+ self.assertTrue(jira_mock.return_value.issue.called)
+
+ @staticmethod
+ def field_checker_func(context, issue):
+ return "test-label-1" in issue['fields']['labels']
+
+
+if __name__ == '__main__':
+ unittest.main()
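Two details make this test tick: JIRA is patched where jira_hook imports it,
so the hook under test constructs jira_client_mock instead of a real client,
and field_checker_func is a plain callback that receives the poke context
plus the fetched issue and decides whether the sensor's condition is met.
The callback contract in isolation, as a runnable sketch built on the same
minimal ticket shape used above:

    minimal_issue = {
        "key": "TEST-1226",
        "fields": {"labels": ["test-label-1", "test-label-2"]},
    }

    def field_checker_func(context, issue):
        # A truthy return tells the sensor its condition is met.
        return "test-label-1" in issue["fields"]["labels"]

    assert field_checker_func(None, minimal_issue)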
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/sensors/test_redis_sensor.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/test_redis_sensor.py b/tests/contrib/sensors/test_redis_sensor.py
new file mode 100644
index 0000000..8022a92
--- /dev/null
+++ b/tests/contrib/sensors/test_redis_sensor.py
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+import datetime
+
+from mock import patch
+
+from airflow import DAG
+from airflow import configuration
+from airflow.contrib.sensors.redis_key_sensor import RedisKeySensor
+
+DEFAULT_DATE = datetime.datetime(2017, 1, 1)
+
+
+class TestRedisSensor(unittest.TestCase):
+
+ def setUp(self):
+ configuration.load_test_config()
+ args = {
+ 'owner': 'airflow',
+ 'start_date': DEFAULT_DATE
+ }
+
+ self.dag = DAG('test_dag_id', default_args=args)
+ self.sensor = RedisKeySensor(
+ task_id='test_task',
+ redis_conn_id='redis_default',
+ dag=self.dag,
+ key='test_key'
+ )
+
+ @patch("airflow.contrib.hooks.redis_hook.RedisHook.key_exists")
+ def test_poke(self, key_exists):
+ key_exists.return_value = True
+ self.assertTrue(self.sensor.poke(None))
+
+ key_exists.return_value = False
+ self.assertFalse(self.sensor.poke(None))
+
+ @patch("airflow.contrib.hooks.redis_hook.StrictRedis.exists")
+ def test_existing_key_called(self, redis_client_exists):
+ self.sensor.run(
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE, ignore_ti_state=True
+ )
+
+ redis_client_exists.assert_called_with('test_key')
+
+
+if __name__ == '__main__':
+ unittest.main()
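A caution on that final assertion: Mock auto-creates any attribute you
access, so the superficially similar assertTrue(mock.called_with(...)) can
never fail; it just returns a truthy child mock. Only the assert_called_with
family actually verifies the recorded call. A minimal demonstration of the
trap:

    from mock import Mock

    m = Mock()
    # Auto-created attribute: truthy even though m was never called.
    assert m.called_with('test_key')
    # The real check raises when the expected call is missing.
    try:
        m.assert_called_with('test_key')
    except AssertionError:
        pass  # correctly reports that m('test_key') never happened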
[2/3] incubator-airflow git commit: [AIRFLOW-1094] Run unit tests under contrib in Travis
Posted by bo...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/zendesk_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/zendesk_hook.py b/tests/contrib/hooks/zendesk_hook.py
deleted file mode 100644
index 66b8e6b..0000000
--- a/tests/contrib/hooks/zendesk_hook.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-from unittest.mock import Mock, patch
-from plugins.hooks.zendesk_hook import ZendeskHook
-from zdesk import RateLimitError
-from pytest import raises
-
-
-@patch("plugins.hooks.zendesk_hook.time")
-@patch("plugins.hooks.zendesk_hook.Zendesk")
-def test_sleeps_for_correct_interval(_, mocked_time):
- sleep_time = 10
-
- # To break out of the otherwise infinite tries
- mocked_time.sleep = Mock(side_effect=ValueError)
- conn_mock = Mock()
- mock_response = Mock()
- mock_response.headers.get.return_value = sleep_time
- conn_mock.call = Mock(
- side_effect=RateLimitError(msg="some message", code="some code",
- response=mock_response))
-
- zendesk_hook = ZendeskHook("conn_id")
- zendesk_hook.get_conn = Mock(return_value=conn_mock)
-
- with raises(ValueError):
- zendesk_hook.call("some_path", get_all_pages=False)
- mocked_time.sleep.assert_called_with(sleep_time)
-
-
-@patch("plugins.hooks.zendesk_hook.Zendesk")
-def test_returns_single_page_if_get_all_pages_false(_):
- zendesk_hook = ZendeskHook("conn_id")
- mock_connection = Mock()
- mock_connection.host = "some_host"
- zendesk_hook.get_connection = Mock(return_value=mock_connection)
- zendesk_hook.get_conn()
-
- mock_conn = Mock()
- mock_call = Mock(
- return_value={'next_page': 'https://some_host/something', 'path': []})
- mock_conn.call = mock_call
- zendesk_hook.get_conn = Mock(return_value=mock_conn)
- zendesk_hook.call("path", get_all_pages=False)
- mock_call.assert_called_once_with("path", None)
-
-
-@patch("plugins.hooks.zendesk_hook.Zendesk")
-def test_returns_multiple_pages_if_get_all_pages_true(_):
- zendesk_hook = ZendeskHook("conn_id")
- mock_connection = Mock()
- mock_connection.host = "some_host"
- zendesk_hook.get_connection = Mock(return_value=mock_connection)
- zendesk_hook.get_conn()
-
- mock_conn = Mock()
- mock_call = Mock(
- return_value={'next_page': 'https://some_host/something', 'path': []})
- mock_conn.call = mock_call
- zendesk_hook.get_conn = Mock(return_value=mock_conn)
- zendesk_hook.call("path", get_all_pages=True)
- assert mock_call.call_count == 2
-
-
-@patch("plugins.hooks.zendesk_hook.Zendesk")
-def test_zdesk_is_inited_correctly(mock_zendesk):
- conn_mock = Mock()
- conn_mock.host = "conn_host"
- conn_mock.login = "conn_login"
- conn_mock.password = "conn_pass"
-
- zendesk_hook = ZendeskHook("conn_id")
- zendesk_hook.get_connection = Mock(return_value=conn_mock)
- zendesk_hook.get_conn()
- mock_zendesk.assert_called_with('https://conn_host', 'conn_login',
- 'conn_pass', True)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/__init__.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/__init__.py b/tests/contrib/operators/__init__.py
index 6e38bea..cdd2147 100644
--- a/tests/contrib/operators/__init__.py
+++ b/tests/contrib/operators/__init__.py
@@ -13,6 +13,3 @@
# limitations under the License.
#
-from __future__ import absolute_import
-from .ssh_execute_operator import *
-from .fs_operator import *
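The wildcard imports removed here existed so the old, unprefixed test
modules would still be picked up; with the files renamed to a test_ prefix
(see the renames below), name-based discovery finds them directly. A tiny
sketch of the naming rule, assuming unittest/nose-style test_*.py discovery
(filenames taken from this patch):

    import fnmatch

    assert fnmatch.fnmatch('test_hdfs_sensors.py', 'test_*.py')
    assert not fnmatch.fnmatch('fs_operator.py', 'test_*.py')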
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/databricks_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/databricks_operator.py b/tests/contrib/operators/databricks_operator.py
deleted file mode 100644
index aab47fa..0000000
--- a/tests/contrib/operators/databricks_operator.py
+++ /dev/null
@@ -1,185 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-
-from airflow.contrib.hooks.databricks_hook import RunState
-from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator
-from airflow.exceptions import AirflowException
-
-try:
- from unittest import mock
-except ImportError:
- try:
- import mock
- except ImportError:
- mock = None
-
-TASK_ID = 'databricks-operator'
-DEFAULT_CONN_ID = 'databricks_default'
-NOTEBOOK_TASK = {
- 'notebook_path': '/test'
-}
-SPARK_JAR_TASK = {
- 'main_class_name': 'com.databricks.Test'
-}
-NEW_CLUSTER = {
- 'spark_version': '2.0.x-scala2.10',
- 'node_type_id': 'development-node',
- 'num_workers': 1
-}
-EXISTING_CLUSTER_ID = 'existing-cluster-id'
-RUN_NAME = 'run-name'
-RUN_ID = 1
-
-
-class DatabricksSubmitRunOperatorTest(unittest.TestCase):
- def test_init_with_named_parameters(self):
- """
- Test the initializer with the named parameters.
- """
- op = DatabricksSubmitRunOperator(task_id=TASK_ID, new_cluster=NEW_CLUSTER, notebook_task=NOTEBOOK_TASK)
- expected = {
- 'new_cluster': NEW_CLUSTER,
- 'notebook_task': NOTEBOOK_TASK,
- 'run_name': TASK_ID
- }
- self.assertDictEqual(expected, op.json)
-
- def test_init_with_json(self):
- """
- Test the initializer with json data.
- """
- json = {
- 'new_cluster': NEW_CLUSTER,
- 'notebook_task': NOTEBOOK_TASK
- }
- op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
- expected = {
- 'new_cluster': NEW_CLUSTER,
- 'notebook_task': NOTEBOOK_TASK,
- 'run_name': TASK_ID
- }
- self.assertDictEqual(expected, op.json)
-
- def test_init_with_specified_run_name(self):
- """
- Test the initializer with a specified run_name.
- """
- json = {
- 'new_cluster': NEW_CLUSTER,
- 'notebook_task': NOTEBOOK_TASK,
- 'run_name': RUN_NAME
- }
- op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
- expected = {
- 'new_cluster': NEW_CLUSTER,
- 'notebook_task': NOTEBOOK_TASK,
- 'run_name': RUN_NAME
- }
- self.assertDictEqual(expected, op.json)
-
- def test_init_with_merging(self):
- """
- Test the initializer when json and other named parameters are both
- provided. The named parameters should override top level keys in the
- json dict.
- """
- override_new_cluster = {'workers': 999}
- json = {
- 'new_cluster': NEW_CLUSTER,
- 'notebook_task': NOTEBOOK_TASK,
- }
- op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json, new_cluster=override_new_cluster)
- expected = {
- 'new_cluster': override_new_cluster,
- 'notebook_task': NOTEBOOK_TASK,
- 'run_name': TASK_ID,
- }
- self.assertDictEqual(expected, op.json)
-
- @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
- def test_exec_success(self, db_mock_class):
- """
- Test the execute function in case where the run is successful.
- """
- run = {
- 'new_cluster': NEW_CLUSTER,
- 'notebook_task': NOTEBOOK_TASK,
- }
- op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
- db_mock = db_mock_class.return_value
- db_mock.submit_run.return_value = 1
- db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
-
- op.execute(None)
-
- expected = {
- 'new_cluster': NEW_CLUSTER,
- 'notebook_task': NOTEBOOK_TASK,
- 'run_name': TASK_ID
- }
- db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID,
- retry_limit=op.databricks_retry_limit)
- db_mock.submit_run.assert_called_once_with(expected)
- db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
- db_mock.get_run_state.assert_called_once_with(RUN_ID)
- self.assertEquals(RUN_ID, op.run_id)
-
- @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
- def test_exec_failure(self, db_mock_class):
- """
- Test the execute function in case where the run failed.
- """
- run = {
- 'new_cluster': NEW_CLUSTER,
- 'notebook_task': NOTEBOOK_TASK,
- }
- op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
- db_mock = db_mock_class.return_value
- db_mock.submit_run.return_value = 1
- db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '')
-
- with self.assertRaises(AirflowException):
- op.execute(None)
-
- expected = {
- 'new_cluster': NEW_CLUSTER,
- 'notebook_task': NOTEBOOK_TASK,
- 'run_name': TASK_ID,
- }
- db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID,
- retry_limit=op.databricks_retry_limit)
- db_mock.submit_run.assert_called_once_with(expected)
- db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
- db_mock.get_run_state.assert_called_once_with(RUN_ID)
- self.assertEquals(RUN_ID, op.run_id)
-
- @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
- def test_on_kill(self, db_mock_class):
- run = {
- 'new_cluster': NEW_CLUSTER,
- 'notebook_task': NOTEBOOK_TASK,
- }
- op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
- db_mock = db_mock_class.return_value
- op.run_id = RUN_ID
-
- op.on_kill()
-
- db_mock.cancel_run.assert_called_once_with(RUN_ID)
-
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/dataflow_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/dataflow_operator.py b/tests/contrib/operators/dataflow_operator.py
deleted file mode 100644
index 7455a45..0000000
--- a/tests/contrib/operators/dataflow_operator.py
+++ /dev/null
@@ -1,82 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-
-from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
-
-try:
- from unittest import mock
-except ImportError:
- try:
- import mock
- except ImportError:
- mock = None
-
-
-TASK_ID = 'test-python-dataflow'
-PY_FILE = 'gs://my-bucket/my-object.py'
-PY_OPTIONS = ['-m']
-DEFAULT_OPTIONS = {
- 'project': 'test',
- 'stagingLocation': 'gs://test/staging'
-}
-ADDITIONAL_OPTIONS = {
- 'output': 'gs://test/output'
-}
-GCS_HOOK_STRING = 'airflow.contrib.operators.dataflow_operator.{}'
-
-
-class DataFlowPythonOperatorTest(unittest.TestCase):
-
- def setUp(self):
- self.dataflow = DataFlowPythonOperator(
- task_id=TASK_ID,
- py_file=PY_FILE,
- py_options=PY_OPTIONS,
- dataflow_default_options=DEFAULT_OPTIONS,
- options=ADDITIONAL_OPTIONS)
-
- def test_init(self):
- """Test DataFlowPythonOperator instance is properly initialized."""
- self.assertEqual(self.dataflow.task_id, TASK_ID)
- self.assertEqual(self.dataflow.py_file, PY_FILE)
- self.assertEqual(self.dataflow.py_options, PY_OPTIONS)
- self.assertEqual(self.dataflow.dataflow_default_options,
- DEFAULT_OPTIONS)
- self.assertEqual(self.dataflow.options,
- ADDITIONAL_OPTIONS)
-
- @mock.patch('airflow.contrib.operators.dataflow_operator.DataFlowHook')
- @mock.patch(GCS_HOOK_STRING.format('GoogleCloudStorageHook'))
- def test_exec(self, gcs_hook, dataflow_mock):
- """Test DataFlowHook is created and the right args are passed to
- start_python_workflow.
-
- """
- start_python_hook = dataflow_mock.return_value.start_python_dataflow
- gcs_download_hook = gcs_hook.return_value.download
- self.dataflow.execute(None)
- self.assertTrue(dataflow_mock.called)
- expected_options = {
- 'project': 'test',
- 'staging_location': 'gs://test/staging',
- 'output': 'gs://test/output'
- }
- gcs_download_hook.assert_called_once_with(
- 'my-bucket', 'my-object.py', mock.ANY)
- start_python_hook.assert_called_once_with(TASK_ID, expected_options,
- mock.ANY, PY_OPTIONS)
- self.assertTrue(self.dataflow.py_file.startswith('/tmp/dataflow'))
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/ecs_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/ecs_operator.py b/tests/contrib/operators/ecs_operator.py
deleted file mode 100644
index 5a593a6..0000000
--- a/tests/contrib/operators/ecs_operator.py
+++ /dev/null
@@ -1,207 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import sys
-import unittest
-from copy import deepcopy
-
-from airflow import configuration
-from airflow.exceptions import AirflowException
-from airflow.contrib.operators.ecs_operator import ECSOperator
-
-try:
- from unittest import mock
-except ImportError:
- try:
- import mock
- except ImportError:
- mock = None
-
-
-RESPONSE_WITHOUT_FAILURES = {
- "failures": [],
- "tasks": [
- {
- "containers": [
- {
- "containerArn": "arn:aws:ecs:us-east-1:012345678910:container/e1ed7aac-d9b2-4315-8726-d2432bf11868",
- "lastStatus": "PENDING",
- "name": "wordpress",
- "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55"
- }
- ],
- "desiredStatus": "RUNNING",
- "lastStatus": "PENDING",
- "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55",
- "taskDefinitionArn": "arn:aws:ecs:us-east-1:012345678910:task-definition/hello_world:11"
- }
- ]
-}
-
-
-class TestECSOperator(unittest.TestCase):
-
- @mock.patch('airflow.contrib.operators.ecs_operator.AwsHook')
- def setUp(self, aws_hook_mock):
- configuration.load_test_config()
-
- self.aws_hook_mock = aws_hook_mock
- self.ecs = ECSOperator(
- task_id='task',
- task_definition='t',
- cluster='c',
- overrides={},
- aws_conn_id=None,
- region_name='eu-west-1')
-
- def test_init(self):
-
- self.assertEqual(self.ecs.region_name, 'eu-west-1')
- self.assertEqual(self.ecs.task_definition, 't')
- self.assertEqual(self.ecs.aws_conn_id, None)
- self.assertEqual(self.ecs.cluster, 'c')
- self.assertEqual(self.ecs.overrides, {})
- self.assertEqual(self.ecs.hook, self.aws_hook_mock.return_value)
-
- self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)
-
- def test_template_fields_overrides(self):
- self.assertEqual(self.ecs.template_fields, ('overrides',))
-
- @mock.patch.object(ECSOperator, '_wait_for_task_ended')
- @mock.patch.object(ECSOperator, '_check_success_task')
- def test_execute_without_failures(self, check_mock, wait_mock):
-
- client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
- client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES
-
- self.ecs.execute(None)
-
- self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('ecs', region_name='eu-west-1')
- client_mock.run_task.assert_called_once_with(
- cluster='c',
- overrides={},
- startedBy='Airflow',
- taskDefinition='t'
- )
-
- wait_mock.assert_called_once_with()
- check_mock.assert_called_once_with()
- self.assertEqual(self.ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55')
-
- def test_execute_with_failures(self):
-
- client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
- resp_failures = deepcopy(RESPONSE_WITHOUT_FAILURES)
- resp_failures['failures'].append('dummy error')
- client_mock.run_task.return_value = resp_failures
-
- with self.assertRaises(AirflowException):
- self.ecs.execute(None)
-
- self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('ecs', region_name='eu-west-1')
- client_mock.run_task.assert_called_once_with(
- cluster='c',
- overrides={},
- startedBy='Airflow',
- taskDefinition='t'
- )
-
- def test_wait_end_tasks(self):
-
- client_mock = mock.Mock()
- self.ecs.arn = 'arn'
- self.ecs.client = client_mock
-
- self.ecs._wait_for_task_ended()
- client_mock.get_waiter.assert_called_once_with('tasks_stopped')
- client_mock.get_waiter.return_value.wait.assert_called_once_with(cluster='c', tasks=['arn'])
- self.assertEquals(sys.maxint, client_mock.get_waiter.return_value.config.max_attempts)
-
- def test_check_success_tasks_raises(self):
- client_mock = mock.Mock()
- self.ecs.arn = 'arn'
- self.ecs.client = client_mock
-
- client_mock.describe_tasks.return_value = {
- 'tasks': [{
- 'containers': [{
- 'name': 'foo',
- 'lastStatus': 'STOPPED',
- 'exitCode': 1
- }]
- }]
- }
- with self.assertRaises(Exception) as e:
- self.ecs._check_success_task()
-
- self.assertEquals(str(e.exception), "This task is not in success state {'containers': [{'lastStatus': 'STOPPED', 'name': 'foo', 'exitCode': 1}]}")
- client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
-
- def test_check_success_tasks_raises_pending(self):
- client_mock = mock.Mock()
- self.ecs.client = client_mock
- self.ecs.arn = 'arn'
- client_mock.describe_tasks.return_value = {
- 'tasks': [{
- 'containers': [{
- 'name': 'container-name',
- 'lastStatus': 'PENDING'
- }]
- }]
- }
- with self.assertRaises(Exception) as e:
- self.ecs._check_success_task()
- self.assertEquals(str(e.exception), "This task is still pending {'containers': [{'lastStatus': 'PENDING', 'name': 'container-name'}]}")
- client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
-
- def test_check_success_tasks_raises_mutliple(self):
- client_mock = mock.Mock()
- self.ecs.client = client_mock
- self.ecs.arn = 'arn'
- client_mock.describe_tasks.return_value = {
- 'tasks': [{
- 'containers': [{
- 'name': 'foo',
- 'exitCode': 1
- }, {
- 'name': 'bar',
- 'lastStatus': 'STOPPED',
- 'exitCode': 0
- }]
- }]
- }
- self.ecs._check_success_task()
- client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
-
- def test_check_success_task_not_raises(self):
- client_mock = mock.Mock()
- self.ecs.client = client_mock
- self.ecs.arn = 'arn'
- client_mock.describe_tasks.return_value = {
- 'tasks': [{
- 'containers': [{
- 'name': 'container-name',
- 'lastStatus': 'STOPPED',
- 'exitCode': 0
- }]
- }]
- }
- self.ecs._check_success_task()
- client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
-
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/emr_add_steps_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/emr_add_steps_operator.py b/tests/contrib/operators/emr_add_steps_operator.py
deleted file mode 100644
index 37f9a4c..0000000
--- a/tests/contrib/operators/emr_add_steps_operator.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-from mock import MagicMock, patch
-
-from airflow import configuration
-from airflow.contrib.operators.emr_add_steps_operator import EmrAddStepsOperator
-
-ADD_STEPS_SUCCESS_RETURN = {
- 'ResponseMetadata': {
- 'HTTPStatusCode': 200
- },
- 'StepIds': ['s-2LH3R5GW3A53T']
-}
-
-
-class TestEmrAddStepsOperator(unittest.TestCase):
- def setUp(self):
- configuration.load_test_config()
-
- # Mock out the emr_client (moto has incorrect response)
- mock_emr_client = MagicMock()
- mock_emr_client.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN
-
- # Mock out the emr_client creator
- self.boto3_client_mock = MagicMock(return_value=mock_emr_client)
-
-
- def test_execute_adds_steps_to_the_job_flow_and_returns_step_ids(self):
- with patch('boto3.client', self.boto3_client_mock):
-
- operator = EmrAddStepsOperator(
- task_id='test_task',
- job_flow_id='j-8989898989',
- aws_conn_id='aws_default'
- )
-
- self.assertEqual(operator.execute(None), ['s-2LH3R5GW3A53T'])
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/emr_create_job_flow_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/emr_create_job_flow_operator.py b/tests/contrib/operators/emr_create_job_flow_operator.py
deleted file mode 100644
index 4aa4cd2..0000000
--- a/tests/contrib/operators/emr_create_job_flow_operator.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-from mock import MagicMock, patch
-
-from airflow import configuration
-from airflow.contrib.operators.emr_create_job_flow_operator import EmrCreateJobFlowOperator
-
-RUN_JOB_FLOW_SUCCESS_RETURN = {
- 'ResponseMetadata': {
- 'HTTPStatusCode': 200
- },
- 'JobFlowId': 'j-8989898989'
-}
-
-class TestEmrCreateJobFlowOperator(unittest.TestCase):
- def setUp(self):
- configuration.load_test_config()
-
- # Mock out the emr_client (moto has incorrect response)
- mock_emr_client = MagicMock()
- mock_emr_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
-
- # Mock out the emr_client creator
- self.boto3_client_mock = MagicMock(return_value=mock_emr_client)
-
-
- def test_execute_uses_the_emr_config_to_create_a_cluster_and_returns_job_id(self):
- with patch('boto3.client', self.boto3_client_mock):
-
- operator = EmrCreateJobFlowOperator(
- task_id='test_task',
- aws_conn_id='aws_default',
- emr_conn_id='emr_default'
- )
-
- self.assertEqual(operator.execute(None), 'j-8989898989')
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/emr_terminate_job_flow_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/emr_terminate_job_flow_operator.py b/tests/contrib/operators/emr_terminate_job_flow_operator.py
deleted file mode 100644
index 94c0124..0000000
--- a/tests/contrib/operators/emr_terminate_job_flow_operator.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-from mock import MagicMock, patch
-
-from airflow import configuration
-from airflow.contrib.operators.emr_terminate_job_flow_operator import EmrTerminateJobFlowOperator
-
-TERMINATE_SUCCESS_RETURN = {
- 'ResponseMetadata': {
- 'HTTPStatusCode': 200
- }
-}
-
-
-class TestEmrTerminateJobFlowOperator(unittest.TestCase):
- def setUp(self):
- configuration.load_test_config()
-
- # Mock out the emr_client (moto has incorrect response)
- mock_emr_client = MagicMock()
- mock_emr_client.terminate_job_flows.return_value = TERMINATE_SUCCESS_RETURN
-
- # Mock out the emr_client creator
- self.boto3_client_mock = MagicMock(return_value=mock_emr_client)
-
-
- def test_execute_terminates_the_job_flow_and_does_not_error(self):
- with patch('boto3.client', self.boto3_client_mock):
-
- operator = EmrTerminateJobFlowOperator(
- task_id='test_task',
- job_flow_id='j-8989898989',
- aws_conn_id='aws_default'
- )
-
- operator.execute(None)
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/fs_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/fs_operator.py b/tests/contrib/operators/fs_operator.py
deleted file mode 100644
index f990157..0000000
--- a/tests/contrib/operators/fs_operator.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-from datetime import datetime
-
-from airflow import configuration
-from airflow.settings import Session
-from airflow import models, DAG
-from airflow.contrib.operators.fs_operator import FileSensor
-
-TEST_DAG_ID = 'unit_tests'
-DEFAULT_DATE = datetime(2015, 1, 1)
-configuration.load_test_config()
-
-
-def reset(dag_id=TEST_DAG_ID):
- session = Session()
- tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id)
- tis.delete()
- session.commit()
- session.close()
-
-reset()
-
-class FileSensorTest(unittest.TestCase):
- def setUp(self):
- configuration.load_test_config()
- from airflow.contrib.hooks.fs_hook import FSHook
- hook = FSHook()
- args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE,
- 'provide_context': True
- }
- dag = DAG(TEST_DAG_ID+'test_schedule_dag_once', default_args=args)
- dag.schedule_interval = '@once'
- self.hook = hook
- self.dag = dag
-
- def test_simple(self):
- task = FileSensor(
- task_id="test",
- filepath="etc/hosts",
- fs_conn_id='fs_default',
- _hook=self.hook,
- dag=self.dag,
- )
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/hipchat_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/hipchat_operator.py b/tests/contrib/operators/hipchat_operator.py
deleted file mode 100644
index 65a2edd..0000000
--- a/tests/contrib/operators/hipchat_operator.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-import requests
-
-from airflow.contrib.operators.hipchat_operator import \
- HipChatAPISendRoomNotificationOperator
-from airflow.exceptions import AirflowException
-from airflow import configuration
-
-try:
- from unittest import mock
-except ImportError:
- try:
- import mock
- except ImportError:
- mock = None
-
-
-class HipChatOperatorTest(unittest.TestCase):
- def setUp(self):
- configuration.load_test_config()
-
- @unittest.skipIf(mock is None, 'mock package not present')
- @mock.patch('requests.request')
- def test_execute(self, request_mock):
- resp = requests.Response()
- resp.status_code = 200
- request_mock.return_value = resp
-
- operator = HipChatAPISendRoomNotificationOperator(
- task_id='test_hipchat_success',
- owner = 'airflow',
- token='abc123',
- room_id='room_id',
- message='hello world!'
- )
-
- operator.execute(None)
-
- @unittest.skipIf(mock is None, 'mock package not present')
- @mock.patch('requests.request')
- def test_execute_error_response(self, request_mock):
- resp = requests.Response()
- resp.status_code = 404
- resp.reason = 'Not Found'
- request_mock.return_value = resp
-
- operator = HipChatAPISendRoomNotificationOperator(
- task_id='test_hipchat_failure',
- owner='airflow',
- token='abc123',
- room_id='room_id',
- message='hello world!'
- )
-
- with self.assertRaises(AirflowException):
- operator.execute(None)
-
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/jira_operator_test.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/jira_operator_test.py b/tests/contrib/operators/jira_operator_test.py
deleted file mode 100644
index 6d615df..0000000
--- a/tests/contrib/operators/jira_operator_test.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-import datetime
-from mock import Mock
-from mock import patch
-
-from airflow import DAG, configuration
-from airflow.contrib.operators.jira_operator import JiraOperator
-from airflow import models
-from airflow.utils import db
-
-DEFAULT_DATE = datetime.datetime(2017, 1, 1)
-jira_client_mock = Mock(
- name="jira_client_for_test"
-)
-
-minimal_test_ticket = {
- "id": "911539",
- "self": "https://sandbox.localhost/jira/rest/api/2/issue/911539",
- "key": "TEST-1226",
- "fields": {
- "labels": [
- "test-label-1",
- "test-label-2"
- ],
- "description": "this is a test description",
- }
-}
-
-
-class TestJiraOperator(unittest.TestCase):
- def setUp(self):
- configuration.load_test_config()
- args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
- }
- dag = DAG('test_dag_id', default_args=args)
- self.dag = dag
- db.merge_conn(
- models.Connection(
- conn_id='jira_default', conn_type='jira',
- host='https://localhost/jira/', port=443,
- extra='{"verify": "False", "project": "AIRFLOW"}'))
-
- @patch("airflow.contrib.hooks.jira_hook.JIRA",
- autospec=True, return_value=jira_client_mock)
- def test_issue_search(self, jira_mock):
- jql_str = 'issuekey=TEST-1226'
- jira_mock.return_value.search_issues.return_value = minimal_test_ticket
-
- jira_ticket_search_operator = JiraOperator(task_id='search-ticket-test',
- jira_method="search_issues",
- jira_method_args={
- 'jql_str': jql_str,
- 'maxResults': '1'
- },
- dag=self.dag)
-
- jira_ticket_search_operator.run(start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE, ignore_ti_state=True)
-
- self.assertTrue(jira_mock.called)
- self.assertTrue(jira_mock.return_value.search_issues.called)
-
- @patch("airflow.contrib.hooks.jira_hook.JIRA",
- autospec=True, return_value=jira_client_mock)
- def test_update_issue(self, jira_mock):
- jira_mock.return_value.add_comment.return_value = True
-
- add_comment_operator = JiraOperator(task_id='add_comment_test',
- jira_method="add_comment",
- jira_method_args={
- 'issue': minimal_test_ticket.get("key"),
- 'body': 'this is test comment'
- },
- dag=self.dag)
-
- add_comment_operator.run(start_date=DEFAULT_DATE,
- end_date=DEFAULT_DATE, ignore_ti_state=True)
-
- self.assertTrue(jira_mock.called)
- self.assertTrue(jira_mock.return_value.add_comment.called)
-
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/spark_submit_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/spark_submit_operator.py b/tests/contrib/operators/spark_submit_operator.py
deleted file mode 100644
index 4e2afb2..0000000
--- a/tests/contrib/operators/spark_submit_operator.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-import datetime
-
-from airflow import DAG, configuration
-from airflow.contrib.operators.spark_submit_operator import SparkSubmitOperator
-
-DEFAULT_DATE = datetime.datetime(2017, 1, 1)
-
-
-class TestSparkSubmitOperator(unittest.TestCase):
- _config = {
- 'conf': {
- 'parquet.compression': 'SNAPPY'
- },
- 'files': 'hive-site.xml',
- 'py_files': 'sample_library.py',
- 'jars': 'parquet.jar',
- 'executor_cores': 4,
- 'executor_memory': '22g',
- 'keytab': 'privileged_user.keytab',
- 'principal': 'user/spark@airflow.org',
- 'name': 'spark-job',
- 'num_executors': 10,
- 'verbose': True,
- 'application': 'test_application.py',
- 'driver_memory': '3g',
- 'java_class': 'com.foo.bar.AppMain'
- }
-
- def setUp(self):
- configuration.load_test_config()
- args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
- }
- self.dag = DAG('test_dag_id', default_args=args)
-
- def test_execute(self, conn_id='spark_default'):
- operator = SparkSubmitOperator(
- task_id='spark_submit_job',
- dag=self.dag,
- **self._config
- )
-
- self.assertEqual(conn_id, operator._conn_id)
-
- self.assertEqual(self._config['application'], operator._application)
- self.assertEqual(self._config['conf'], operator._conf)
- self.assertEqual(self._config['files'], operator._files)
- self.assertEqual(self._config['py_files'], operator._py_files)
- self.assertEqual(self._config['jars'], operator._jars)
- self.assertEqual(self._config['executor_cores'], operator._executor_cores)
- self.assertEqual(self._config['executor_memory'], operator._executor_memory)
- self.assertEqual(self._config['keytab'], operator._keytab)
- self.assertEqual(self._config['principal'], operator._principal)
- self.assertEqual(self._config['name'], operator._name)
- self.assertEqual(self._config['num_executors'], operator._num_executors)
- self.assertEqual(self._config['verbose'], operator._verbose)
- self.assertEqual(self._config['java_class'], operator._java_class)
- self.assertEqual(self._config['driver_memory'], operator._driver_memory)
-
-
-
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/sqoop_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/sqoop_operator.py b/tests/contrib/operators/sqoop_operator.py
deleted file mode 100644
index a46dc93..0000000
--- a/tests/contrib/operators/sqoop_operator.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import datetime
-import unittest
-
-from airflow import DAG, configuration
-from airflow.contrib.operators.sqoop_operator import SqoopOperator
-
-
-class TestSqoopOperator(unittest.TestCase):
- _config = {
- 'cmd_type': 'export',
- 'table': 'target_table',
- 'query': 'SELECT * FROM schema.table',
- 'target_dir': '/path/on/hdfs/to/import',
- 'append': True,
- 'file_type': 'avro',
- 'columns': 'a,b,c',
- 'num_mappers': 22,
- 'split_by': 'id',
- 'export_dir': '/path/on/hdfs/to/export',
- 'input_null_string': '\n',
- 'input_null_non_string': '\t',
- 'staging_table': 'target_table_staging',
- 'clear_staging_table': True,
- 'enclosed_by': '"',
- 'escaped_by': '\\',
- 'input_fields_terminated_by': '|',
- 'input_lines_terminated_by': '\n',
- 'input_optionally_enclosed_by': '"',
- 'batch': True,
- 'relaxed_isolation': True,
- 'direct': True,
- 'driver': 'com.microsoft.jdbc.sqlserver.SQLServerDriver',
- 'properties': {
- 'mapred.map.max.attempts': '1'
- }
- }
-
- def setUp(self):
- configuration.load_test_config()
- args = {
- 'owner': 'airflow',
- 'start_date': datetime.datetime(2017, 1, 1)
- }
- self.dag = DAG('test_dag_id', default_args=args)
-
- def test_execute(self, conn_id='sqoop_default'):
- operator = SqoopOperator(
- task_id='sqoop_job',
- dag=self.dag,
- **self._config
- )
-
- self.assertEqual(conn_id, operator.conn_id)
-
- self.assertEqual(self._config['cmd_type'], operator.cmd_type)
- self.assertEqual(self._config['table'], operator.table)
- self.assertEqual(self._config['target_dir'], operator.target_dir)
- self.assertEqual(self._config['append'], operator.append)
- self.assertEqual(self._config['file_type'], operator.file_type)
- self.assertEqual(self._config['num_mappers'], operator.num_mappers)
- self.assertEqual(self._config['split_by'], operator.split_by)
- self.assertEqual(self._config['input_null_string'],
- operator.input_null_string)
- self.assertEqual(self._config['input_null_non_string'],
- operator.input_null_non_string)
- self.assertEqual(self._config['staging_table'], operator.staging_table)
- self.assertEqual(self._config['clear_staging_table'],
- operator.clear_staging_table)
- self.assertEqual(self._config['batch'], operator.batch)
- self.assertEqual(self._config['relaxed_isolation'],
- operator.relaxed_isolation)
- self.assertEqual(self._config['direct'], operator.direct)
- self.assertEqual(self._config['driver'], operator.driver)
- self.assertEqual(self._config['properties'], operator.properties)
-
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/ssh_execute_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/ssh_execute_operator.py b/tests/contrib/operators/ssh_execute_operator.py
deleted file mode 100644
index ef8162c..0000000
--- a/tests/contrib/operators/ssh_execute_operator.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-import os
-from datetime import datetime
-
-from airflow import configuration
-from airflow.settings import Session
-from airflow import models, DAG
-from airflow.contrib.operators.ssh_execute_operator import SSHExecuteOperator
-
-
-TEST_DAG_ID = 'unit_tests'
-DEFAULT_DATE = datetime(2015, 1, 1)
-configuration.load_test_config()
-
-
-def reset(dag_id=TEST_DAG_ID):
- session = Session()
- tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id)
- tis.delete()
- session.commit()
- session.close()
-
-reset()
-
-
-class SSHExecuteOperatorTest(unittest.TestCase):
- def setUp(self):
- configuration.load_test_config()
- from airflow.contrib.hooks.ssh_hook import SSHHook
- hook = SSHHook()
- hook.no_host_key_check = True
- args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE,
- 'provide_context': True
- }
- dag = DAG(TEST_DAG_ID+'test_schedule_dag_once', default_args=args)
- dag.schedule_interval = '@once'
- self.hook = hook
- self.dag = dag
-
- def test_simple(self):
- task = SSHExecuteOperator(
- task_id="test",
- bash_command="echo airflow",
- ssh_hook=self.hook,
- dag=self.dag,
- )
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
-
- def test_with_env(self):
- test_env = os.environ.copy()
- test_env['AIRFLOW_test'] = "test"
- task = SSHExecuteOperator(
- task_id="test",
- bash_command="echo $AIRFLOW_HOME",
- ssh_hook=self.hook,
- env=test_env,
- dag=self.dag,
- )
- task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
-
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_databricks_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_databricks_operator.py b/tests/contrib/operators/test_databricks_operator.py
new file mode 100644
index 0000000..aab47fa
--- /dev/null
+++ b/tests/contrib/operators/test_databricks_operator.py
@@ -0,0 +1,185 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from airflow.contrib.hooks.databricks_hook import RunState
+from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator
+from airflow.exceptions import AirflowException
+
+try:
+ from unittest import mock
+except ImportError:
+ try:
+ import mock
+ except ImportError:
+ mock = None
+
+TASK_ID = 'databricks-operator'
+DEFAULT_CONN_ID = 'databricks_default'
+NOTEBOOK_TASK = {
+ 'notebook_path': '/test'
+}
+SPARK_JAR_TASK = {
+ 'main_class_name': 'com.databricks.Test'
+}
+NEW_CLUSTER = {
+ 'spark_version': '2.0.x-scala2.10',
+ 'node_type_id': 'development-node',
+ 'num_workers': 1
+}
+EXISTING_CLUSTER_ID = 'existing-cluster-id'
+RUN_NAME = 'run-name'
+RUN_ID = 1
+
+
+class DatabricksSubmitRunOperatorTest(unittest.TestCase):
+ def test_init_with_named_parameters(self):
+ """
+ Test the initializer with the named parameters.
+ """
+ op = DatabricksSubmitRunOperator(task_id=TASK_ID, new_cluster=NEW_CLUSTER, notebook_task=NOTEBOOK_TASK)
+ expected = {
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': NOTEBOOK_TASK,
+ 'run_name': TASK_ID
+ }
+ self.assertDictEqual(expected, op.json)
+
+ def test_init_with_json(self):
+ """
+ Test the initializer with json data.
+ """
+ json = {
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': NOTEBOOK_TASK
+ }
+ op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
+ expected = {
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': NOTEBOOK_TASK,
+ 'run_name': TASK_ID
+ }
+ self.assertDictEqual(expected, op.json)
+
+ def test_init_with_specified_run_name(self):
+ """
+ Test the initializer with a specified run_name.
+ """
+ json = {
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': NOTEBOOK_TASK,
+ 'run_name': RUN_NAME
+ }
+ op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
+ expected = {
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': NOTEBOOK_TASK,
+ 'run_name': RUN_NAME
+ }
+ self.assertDictEqual(expected, op.json)
+
+ def test_init_with_merging(self):
+ """
+ Test the initializer when json and other named parameters are both
+ provided. The named parameters should override top level keys in the
+ json dict.
+ """
+ override_new_cluster = {'workers': 999}
+ json = {
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': NOTEBOOK_TASK,
+ }
+ op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json, new_cluster=override_new_cluster)
+ expected = {
+ 'new_cluster': override_new_cluster,
+ 'notebook_task': NOTEBOOK_TASK,
+ 'run_name': TASK_ID,
+ }
+ self.assertDictEqual(expected, op.json)
+
+ @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
+ def test_exec_success(self, db_mock_class):
+ """
+ Test the execute function in case where the run is successful.
+ """
+ run = {
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': NOTEBOOK_TASK,
+ }
+ op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
+ db_mock = db_mock_class.return_value
+ db_mock.submit_run.return_value = 1
+ db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
+
+ op.execute(None)
+
+ expected = {
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': NOTEBOOK_TASK,
+ 'run_name': TASK_ID
+ }
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit)
+ db_mock.submit_run.assert_called_once_with(expected)
+ db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+ db_mock.get_run_state.assert_called_once_with(RUN_ID)
+ self.assertEqual(RUN_ID, op.run_id)
+
+ @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
+ def test_exec_failure(self, db_mock_class):
+ """
+ Test the execute function in case where the run failed.
+ """
+ run = {
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': NOTEBOOK_TASK,
+ }
+ op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
+ db_mock = db_mock_class.return_value
+ db_mock.submit_run.return_value = 1
+ db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '')
+
+ with self.assertRaises(AirflowException):
+ op.execute(None)
+
+ expected = {
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': NOTEBOOK_TASK,
+ 'run_name': TASK_ID,
+ }
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit)
+ db_mock.submit_run.assert_called_once_with(expected)
+ db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+ db_mock.get_run_state.assert_called_once_with(RUN_ID)
+ self.assertEqual(RUN_ID, op.run_id)
+
+ @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
+ def test_on_kill(self, db_mock_class):
+ run = {
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': NOTEBOOK_TASK,
+ }
+ op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
+ db_mock = db_mock_class.return_value
+ op.run_id = RUN_ID
+
+ op.on_kill()
+
+ db_mock.cancel_run.assert_called_once_with(RUN_ID)
+
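The merging test above pins down the submit-run JSON semantics: a named operator
parameter replaces the whole top-level key of the `json` argument, and `run_name`
defaults to the task id when not supplied. A condensed sketch of that rule;
`build_submit_json` is an illustrative helper, not the operator's actual code:

    def build_submit_json(task_id, json=None, **named_params):
        # Shallow merge: a named parameter replaces the whole top-level key.
        merged = dict(json or {})
        for key, value in named_params.items():
            if value is not None:
                merged[key] = value
        # Fall back to the task id for the run name.
        merged.setdefault('run_name', task_id)
        return merged

    assert build_submit_json(
        'run-now', json={'new_cluster': {'workers': 2}},
        new_cluster={'workers': 999}
    ) == {'new_cluster': {'workers': 999}, 'run_name': 'run-now'}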
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_dataflow_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_dataflow_operator.py b/tests/contrib/operators/test_dataflow_operator.py
new file mode 100644
index 0000000..0423616
--- /dev/null
+++ b/tests/contrib/operators/test_dataflow_operator.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
+
+try:
+ from unittest import mock
+except ImportError:
+ try:
+ import mock
+ except ImportError:
+ mock = None
+
+
+TASK_ID = 'test-python-dataflow'
+PY_FILE = 'gs://my-bucket/my-object.py'
+PY_OPTIONS = ['-m']
+DEFAULT_OPTIONS = {
+ 'project': 'test',
+ 'stagingLocation': 'gs://test/staging'
+}
+ADDITIONAL_OPTIONS = {
+ 'output': 'gs://test/output'
+}
+GCS_HOOK_STRING = 'airflow.contrib.operators.dataflow_operator.{}'
+
+
+class DataFlowPythonOperatorTest(unittest.TestCase):
+
+ def setUp(self):
+ self.dataflow = DataFlowPythonOperator(
+ task_id=TASK_ID,
+ py_file=PY_FILE,
+ py_options=PY_OPTIONS,
+ dataflow_default_options=DEFAULT_OPTIONS,
+ options=ADDITIONAL_OPTIONS)
+
+ def test_init(self):
+ """Test DataFlowPythonOperator instance is properly initialized."""
+ self.assertEqual(self.dataflow.task_id, TASK_ID)
+ self.assertEqual(self.dataflow.py_file, PY_FILE)
+ self.assertEqual(self.dataflow.py_options, PY_OPTIONS)
+ self.assertEqual(self.dataflow.dataflow_default_options,
+ DEFAULT_OPTIONS)
+ self.assertEqual(self.dataflow.options,
+ ADDITIONAL_OPTIONS)
+
+ @mock.patch('airflow.contrib.operators.dataflow_operator.DataFlowHook')
+ @mock.patch(GCS_HOOK_STRING.format('GoogleCloudBucketHelper'))
+ def test_exec(self, gcs_hook, dataflow_mock):
+ """Test DataFlowHook is created and the right args are passed to
+ start_python_workflow.
+
+ """
+ start_python_hook = dataflow_mock.return_value.start_python_dataflow
+ gcs_download_hook = gcs_hook.return_value.google_cloud_to_local
+ self.dataflow.execute(None)
+ self.assertTrue(dataflow_mock.called)
+ expected_options = {
+ 'project': 'test',
+ 'staging_location': 'gs://test/staging',
+ 'output': 'gs://test/output'
+ }
+ gcs_download_hook.assert_called_once_with(PY_FILE)
+ start_python_hook.assert_called_once_with(TASK_ID, expected_options,
+ mock.ANY, PY_OPTIONS)
+ self.assertTrue(self.dataflow.py_file.startswith('/tmp/dataflow'))
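Note the shape of `expected_options` in the exec test: the per-task `options`
are merged on top of `dataflow_default_options`, and camelCase keys such as
`stagingLocation` come out snake_cased. A rough sketch of that normalization,
assuming a simple regex-based conversion rather than the operator's exact code
(`camel_to_snake` and `merge_dataflow_options` are illustrative names):

    import re

    def camel_to_snake(name):
        # 'stagingLocation' -> 'staging_location'
        return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', name).lower()

    def merge_dataflow_options(defaults, options):
        merged = dict(defaults)
        merged.update(options)
        return {camel_to_snake(k): v for k, v in merged.items()}

    print(merge_dataflow_options(
        {'project': 'test', 'stagingLocation': 'gs://test/staging'},
        {'output': 'gs://test/output'}))
    # {'project': 'test', 'staging_location': 'gs://test/staging',
    #  'output': 'gs://test/output'}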
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_ecs_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_ecs_operator.py b/tests/contrib/operators/test_ecs_operator.py
new file mode 100644
index 0000000..80dedd3
--- /dev/null
+++ b/tests/contrib/operators/test_ecs_operator.py
@@ -0,0 +1,214 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+import unittest
+from copy import deepcopy
+
+from airflow import configuration
+from airflow.exceptions import AirflowException
+from airflow.contrib.operators.ecs_operator import ECSOperator
+
+try:
+ from unittest import mock
+except ImportError:
+ try:
+ import mock
+ except ImportError:
+ mock = None
+
+
+RESPONSE_WITHOUT_FAILURES = {
+ "failures": [],
+ "tasks": [
+ {
+ "containers": [
+ {
+ "containerArn": "arn:aws:ecs:us-east-1:012345678910:container/e1ed7aac-d9b2-4315-8726-d2432bf11868",
+ "lastStatus": "PENDING",
+ "name": "wordpress",
+ "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55"
+ }
+ ],
+ "desiredStatus": "RUNNING",
+ "lastStatus": "PENDING",
+ "taskArn": "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55",
+ "taskDefinitionArn": "arn:aws:ecs:us-east-1:012345678910:task-definition/hello_world:11"
+ }
+ ]
+}
+
+
+class TestECSOperator(unittest.TestCase):
+
+ @mock.patch('airflow.contrib.operators.ecs_operator.AwsHook')
+ def setUp(self, aws_hook_mock):
+ configuration.load_test_config()
+
+ self.aws_hook_mock = aws_hook_mock
+ self.ecs = ECSOperator(
+ task_id='task',
+ task_definition='t',
+ cluster='c',
+ overrides={},
+ aws_conn_id=None,
+ region_name='eu-west-1')
+
+ def test_init(self):
+
+ self.assertEqual(self.ecs.region_name, 'eu-west-1')
+ self.assertEqual(self.ecs.task_definition, 't')
+ self.assertEqual(self.ecs.aws_conn_id, None)
+ self.assertEqual(self.ecs.cluster, 'c')
+ self.assertEqual(self.ecs.overrides, {})
+ self.assertEqual(self.ecs.hook, self.aws_hook_mock.return_value)
+
+ self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)
+
+ def test_template_fields_overrides(self):
+ self.assertEqual(self.ecs.template_fields, ('overrides',))
+
+ @mock.patch.object(ECSOperator, '_wait_for_task_ended')
+ @mock.patch.object(ECSOperator, '_check_success_task')
+ def test_execute_without_failures(self, check_mock, wait_mock):
+
+ client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
+ client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES
+
+ self.ecs.execute(None)
+
+ self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('ecs', region_name='eu-west-1')
+ client_mock.run_task.assert_called_once_with(
+ cluster='c',
+ overrides={},
+ startedBy=mock.ANY, # Can be 'airflow' or 'Airflow'
+ taskDefinition='t'
+ )
+
+ wait_mock.assert_called_once_with()
+ check_mock.assert_called_once_with()
+ self.assertEqual(self.ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55')
+
+ def test_execute_with_failures(self):
+
+ client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
+ resp_failures = deepcopy(RESPONSE_WITHOUT_FAILURES)
+ resp_failures['failures'].append('dummy error')
+ client_mock.run_task.return_value = resp_failures
+
+ with self.assertRaises(AirflowException):
+ self.ecs.execute(None)
+
+ self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('ecs', region_name='eu-west-1')
+ client_mock.run_task.assert_called_once_with(
+ cluster='c',
+ overrides={},
+ startedBy=mock.ANY, # Can be 'airflow' or 'Airflow'
+ taskDefinition='t'
+ )
+
+ def test_wait_end_tasks(self):
+
+ client_mock = mock.Mock()
+ self.ecs.arn = 'arn'
+ self.ecs.client = client_mock
+
+ self.ecs._wait_for_task_ended()
+ client_mock.get_waiter.assert_called_once_with('tasks_stopped')
+ client_mock.get_waiter.return_value.wait.assert_called_once_with(cluster='c', tasks=['arn'])
+ self.assertEqual(sys.maxsize, client_mock.get_waiter.return_value.config.max_attempts)
+
+ def test_check_success_tasks_raises(self):
+ client_mock = mock.Mock()
+ self.ecs.arn = 'arn'
+ self.ecs.client = client_mock
+
+ client_mock.describe_tasks.return_value = {
+ 'tasks': [{
+ 'containers': [{
+ 'name': 'foo',
+ 'lastStatus': 'STOPPED',
+ 'exitCode': 1
+ }]
+ }]
+ }
+ with self.assertRaises(Exception) as e:
+ self.ecs._check_success_task()
+
+ # Ordering of str(dict) is not guaranteed.
+ self.assertIn("This task is not in success state ", str(e.exception))
+ self.assertIn("'name': 'foo'", str(e.exception))
+ self.assertIn("'lastStatus': 'STOPPED'", str(e.exception))
+ self.assertIn("'exitCode': 1", str(e.exception))
+ client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
+
+ def test_check_success_tasks_raises_pending(self):
+ client_mock = mock.Mock()
+ self.ecs.client = client_mock
+ self.ecs.arn = 'arn'
+ client_mock.describe_tasks.return_value = {
+ 'tasks': [{
+ 'containers': [{
+ 'name': 'container-name',
+ 'lastStatus': 'PENDING'
+ }]
+ }]
+ }
+ with self.assertRaises(Exception) as e:
+ self.ecs._check_success_task()
+ # Ordering of str(dict) is not guaranteed.
+ self.assertIn("This task is still pending ", str(e.exception))
+ self.assertIn("'name': 'container-name'", str(e.exception))
+ self.assertIn("'lastStatus': 'PENDING'", str(e.exception))
+ client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
+
+ def test_check_success_tasks_raises_multiple(self):
+ client_mock = mock.Mock()
+ self.ecs.client = client_mock
+ self.ecs.arn = 'arn'
+ client_mock.describe_tasks.return_value = {
+ 'tasks': [{
+ 'containers': [{
+ 'name': 'foo',
+ 'exitCode': 1
+ }, {
+ 'name': 'bar',
+ 'lastStatus': 'STOPPED',
+ 'exitCode': 0
+ }]
+ }]
+ }
+ self.ecs._check_success_task()
+ client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
+
+ def test_check_success_task_not_raises(self):
+ client_mock = mock.Mock()
+ self.ecs.client = client_mock
+ self.ecs.arn = 'arn'
+ client_mock.describe_tasks.return_value = {
+ 'tasks': [{
+ 'containers': [{
+ 'name': 'container-name',
+ 'lastStatus': 'STOPPED',
+ 'exitCode': 0
+ }]
+ }]
+ }
+ self.ecs._check_success_task()
+ client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])
+
+
+if __name__ == '__main__':
+ unittest.main()
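The three `_check_success_task` tests above encode the success rule: a PENDING
container means the task has not finished, a STOPPED container with a non-zero
exit code means failure, and containers that report no `lastStatus` are skipped
(which is why the multiple-containers case passes). A condensed stand-in for the
check follows; the real method's messages and traversal may differ in detail:

    def check_success(response):
        for task in response['tasks']:
            for container in task['containers']:
                if container.get('lastStatus') == 'PENDING':
                    raise Exception('This task is still pending ' + str(container))
                if (container.get('lastStatus') == 'STOPPED'
                        and container.get('exitCode', 0) != 0):
                    raise Exception(
                        'This task is not in success state ' + str(container))

    # Passes: the only stopped container exited cleanly.
    check_success({'tasks': [{'containers': [
        {'name': 'ok', 'lastStatus': 'STOPPED', 'exitCode': 0}]}]})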
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_emr_add_steps_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_emr_add_steps_operator.py b/tests/contrib/operators/test_emr_add_steps_operator.py
new file mode 100644
index 0000000..37f9a4c
--- /dev/null
+++ b/tests/contrib/operators/test_emr_add_steps_operator.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from mock import MagicMock, patch
+
+from airflow import configuration
+from airflow.contrib.operators.emr_add_steps_operator import EmrAddStepsOperator
+
+ADD_STEPS_SUCCESS_RETURN = {
+ 'ResponseMetadata': {
+ 'HTTPStatusCode': 200
+ },
+ 'StepIds': ['s-2LH3R5GW3A53T']
+}
+
+
+class TestEmrAddStepsOperator(unittest.TestCase):
+ def setUp(self):
+ configuration.load_test_config()
+
+ # Mock out the emr_client (moto has incorrect response)
+ mock_emr_client = MagicMock()
+ mock_emr_client.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN
+
+ # Mock out the emr_client creator
+ self.boto3_client_mock = MagicMock(return_value=mock_emr_client)
+
+ def test_execute_adds_steps_to_the_job_flow_and_returns_step_ids(self):
+ with patch('boto3.client', self.boto3_client_mock):
+
+ operator = EmrAddStepsOperator(
+ task_id='test_task',
+ job_flow_id='j-8989898989',
+ aws_conn_id='aws_default'
+ )
+
+ self.assertEqual(operator.execute(None), ['s-2LH3R5GW3A53T'])
+
+if __name__ == '__main__':
+ unittest.main()
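This and the two EMR operator tests that follow share one pattern: build a
MagicMock client preloaded with the canned API response, wrap it in a factory,
and patch `boto3.client` so the hook hands the operator that client. The
pattern in isolation (assumes boto3 and mock are installed):

    from mock import MagicMock, patch

    mock_client = MagicMock()
    mock_client.add_job_flow_steps.return_value = {
        'ResponseMetadata': {'HTTPStatusCode': 200},
        'StepIds': ['s-2LH3R5GW3A53T'],
    }

    with patch('boto3.client', MagicMock(return_value=mock_client)):
        import boto3
        client = boto3.client('emr')  # returns mock_client
        assert client.add_job_flow_steps()['StepIds'] == ['s-2LH3R5GW3A53T']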
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_emr_create_job_flow_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_emr_create_job_flow_operator.py b/tests/contrib/operators/test_emr_create_job_flow_operator.py
new file mode 100644
index 0000000..4aa4cd2
--- /dev/null
+++ b/tests/contrib/operators/test_emr_create_job_flow_operator.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from mock import MagicMock, patch
+
+from airflow import configuration
+from airflow.contrib.operators.emr_create_job_flow_operator import EmrCreateJobFlowOperator
+
+RUN_JOB_FLOW_SUCCESS_RETURN = {
+ 'ResponseMetadata': {
+ 'HTTPStatusCode': 200
+ },
+ 'JobFlowId': 'j-8989898989'
+}
+
+
+class TestEmrCreateJobFlowOperator(unittest.TestCase):
+ def setUp(self):
+ configuration.load_test_config()
+
+ # Mock out the emr_client (moto has incorrect response)
+ mock_emr_client = MagicMock()
+ mock_emr_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
+
+ # Mock out the emr_client creator
+ self.boto3_client_mock = MagicMock(return_value=mock_emr_client)
+
+ def test_execute_uses_the_emr_config_to_create_a_cluster_and_returns_job_id(self):
+ with patch('boto3.client', self.boto3_client_mock):
+
+ operator = EmrCreateJobFlowOperator(
+ task_id='test_task',
+ aws_conn_id='aws_default',
+ emr_conn_id='emr_default'
+ )
+
+ self.assertEqual(operator.execute(None), 'j-8989898989')
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_emr_terminate_job_flow_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_emr_terminate_job_flow_operator.py b/tests/contrib/operators/test_emr_terminate_job_flow_operator.py
new file mode 100644
index 0000000..94c0124
--- /dev/null
+++ b/tests/contrib/operators/test_emr_terminate_job_flow_operator.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from mock import MagicMock, patch
+
+from airflow import configuration
+from airflow.contrib.operators.emr_terminate_job_flow_operator import EmrTerminateJobFlowOperator
+
+TERMINATE_SUCCESS_RETURN = {
+ 'ResponseMetadata': {
+ 'HTTPStatusCode': 200
+ }
+}
+
+
+class TestEmrTerminateJobFlowOperator(unittest.TestCase):
+ def setUp(self):
+ configuration.load_test_config()
+
+ # Mock out the emr_client (moto has incorrect response)
+ mock_emr_client = MagicMock()
+ mock_emr_client.terminate_job_flows.return_value = TERMINATE_SUCCESS_RETURN
+
+ # Mock out the emr_client creator
+ self.boto3_client_mock = MagicMock(return_value=mock_emr_client)
+
+ def test_execute_terminates_the_job_flow_and_does_not_error(self):
+ with patch('boto3.client', self.boto3_client_mock):
+
+ operator = EmrTerminateJobFlowOperator(
+ task_id='test_task',
+ job_flow_id='j-8989898989',
+ aws_conn_id='aws_default'
+ )
+
+ operator.execute(None)
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_fs_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_fs_operator.py b/tests/contrib/operators/test_fs_operator.py
new file mode 100644
index 0000000..f990157
--- /dev/null
+++ b/tests/contrib/operators/test_fs_operator.py
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from datetime import datetime
+
+from airflow import configuration
+from airflow.settings import Session
+from airflow import models, DAG
+from airflow.contrib.operators.fs_operator import FileSensor
+
+TEST_DAG_ID = 'unit_tests'
+DEFAULT_DATE = datetime(2015, 1, 1)
+configuration.load_test_config()
+
+
+def reset(dag_id=TEST_DAG_ID):
+ session = Session()
+ tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id)
+ tis.delete()
+ session.commit()
+ session.close()
+
+reset()
+
+
+class FileSensorTest(unittest.TestCase):
+ def setUp(self):
+ configuration.load_test_config()
+ from airflow.contrib.hooks.fs_hook import FSHook
+ hook = FSHook()
+ args = {
+ 'owner': 'airflow',
+ 'start_date': DEFAULT_DATE,
+ 'provide_context': True
+ }
+ dag = DAG(TEST_DAG_ID+'test_schedule_dag_once', default_args=args)
+ dag.schedule_interval = '@once'
+ self.hook = hook
+ self.dag = dag
+
+ def test_simple(self):
+ task = FileSensor(
+ task_id="test",
+ filepath="etc/hosts",
+ fs_conn_id='fs_default',
+ _hook=self.hook,
+ dag=self.dag,
+ )
+ task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_hipchat_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_hipchat_operator.py b/tests/contrib/operators/test_hipchat_operator.py
new file mode 100644
index 0000000..65a2edd
--- /dev/null
+++ b/tests/contrib/operators/test_hipchat_operator.py
@@ -0,0 +1,74 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import requests
+
+from airflow.contrib.operators.hipchat_operator import \
+ HipChatAPISendRoomNotificationOperator
+from airflow.exceptions import AirflowException
+from airflow import configuration
+
+try:
+ from unittest import mock
+except ImportError:
+ try:
+ import mock
+ except ImportError:
+ mock = None
+
+
+class HipChatOperatorTest(unittest.TestCase):
+ def setUp(self):
+ configuration.load_test_config()
+
+ @unittest.skipIf(mock is None, 'mock package not present')
+ @mock.patch('requests.request')
+ def test_execute(self, request_mock):
+ resp = requests.Response()
+ resp.status_code = 200
+ request_mock.return_value = resp
+
+ operator = HipChatAPISendRoomNotificationOperator(
+ task_id='test_hipchat_success',
+ owner='airflow',
+ token='abc123',
+ room_id='room_id',
+ message='hello world!'
+ )
+
+ operator.execute(None)
+
+ @unittest.skipIf(mock is None, 'mock package not present')
+ @mock.patch('requests.request')
+ def test_execute_error_response(self, request_mock):
+ resp = requests.Response()
+ resp.status_code = 404
+ resp.reason = 'Not Found'
+ request_mock.return_value = resp
+
+ operator = HipChatAPISendRoomNotificationOperator(
+ task_id='test_hipchat_failure',
+ owner='airflow',
+ token='abc123',
+ room_id='room_id',
+ message='hello world!'
+ )
+
+ with self.assertRaises(AirflowException):
+ operator.execute(None)
+
+
+if __name__ == '__main__':
+ unittest.main()
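Rather than mocking out the whole response, these HipChat tests build a real
`requests.Response` and set only the fields the operator inspects, so
`raise_for_status()` behaves exactly as it would against a live server. The
trick in isolation:

    import requests

    resp = requests.Response()
    resp.status_code = 404
    resp.reason = 'Not Found'
    try:
        resp.raise_for_status()
    except requests.exceptions.HTTPError as err:
        print(err)  # 404 Client Error: Not Found for url: None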
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_jira_operator_test.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_jira_operator_test.py b/tests/contrib/operators/test_jira_operator_test.py
new file mode 100644
index 0000000..6d615df
--- /dev/null
+++ b/tests/contrib/operators/test_jira_operator_test.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import datetime
+from mock import Mock
+from mock import patch
+
+from airflow import DAG, configuration
+from airflow.contrib.operators.jira_operator import JiraOperator
+from airflow import models
+from airflow.utils import db
+
+DEFAULT_DATE = datetime.datetime(2017, 1, 1)
+jira_client_mock = Mock(
+ name="jira_client_for_test"
+)
+
+minimal_test_ticket = {
+ "id": "911539",
+ "self": "https://sandbox.localhost/jira/rest/api/2/issue/911539",
+ "key": "TEST-1226",
+ "fields": {
+ "labels": [
+ "test-label-1",
+ "test-label-2"
+ ],
+ "description": "this is a test description",
+ }
+}
+
+
+class TestJiraOperator(unittest.TestCase):
+ def setUp(self):
+ configuration.load_test_config()
+ args = {
+ 'owner': 'airflow',
+ 'start_date': DEFAULT_DATE
+ }
+ dag = DAG('test_dag_id', default_args=args)
+ self.dag = dag
+ db.merge_conn(
+ models.Connection(
+ conn_id='jira_default', conn_type='jira',
+ host='https://localhost/jira/', port=443,
+ extra='{"verify": "False", "project": "AIRFLOW"}'))
+
+ @patch("airflow.contrib.hooks.jira_hook.JIRA",
+ autospec=True, return_value=jira_client_mock)
+ def test_issue_search(self, jira_mock):
+ jql_str = 'issuekey=TEST-1226'
+ jira_mock.return_value.search_issues.return_value = minimal_test_ticket
+
+ jira_ticket_search_operator = JiraOperator(task_id='search-ticket-test',
+ jira_method="search_issues",
+ jira_method_args={
+ 'jql_str': jql_str,
+ 'maxResults': '1'
+ },
+ dag=self.dag)
+
+ jira_ticket_search_operator.run(start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+ self.assertTrue(jira_mock.called)
+ self.assertTrue(jira_mock.return_value.search_issues.called)
+
+ @patch("airflow.contrib.hooks.jira_hook.JIRA",
+ autospec=True, return_value=jira_client_mock)
+ def test_update_issue(self, jira_mock):
+ jira_mock.return_value.add_comment.return_value = True
+
+ add_comment_operator = JiraOperator(task_id='add_comment_test',
+ jira_method="add_comment",
+ jira_method_args={
+ 'issue': minimal_test_ticket.get("key"),
+ 'body': 'this is test comment'
+ },
+ dag=self.dag)
+
+ add_comment_operator.run(start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+ self.assertTrue(jira_mock.called)
+ self.assertTrue(jira_mock.return_value.add_comment.called)
+
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/operators/test_spark_submit_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_spark_submit_operator.py b/tests/contrib/operators/test_spark_submit_operator.py
new file mode 100644
index 0000000..3c11dbb
--- /dev/null
+++ b/tests/contrib/operators/test_spark_submit_operator.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import datetime
+import sys
+
+from airflow import DAG, configuration
+from airflow.contrib.operators.spark_submit_operator import SparkSubmitOperator
+
+DEFAULT_DATE = datetime.datetime(2017, 1, 1)
+
+
+class TestSparkSubmitOperator(unittest.TestCase):
+
+ _config = {
+ 'conf': {
+ 'parquet.compression': 'SNAPPY'
+ },
+ 'files': 'hive-site.xml',
+ 'py_files': 'sample_library.py',
+ 'jars': 'parquet.jar',
+ 'executor_cores': 4,
+ 'executor_memory': '22g',
+ 'keytab': 'privileged_user.keytab',
+ 'principal': 'user/spark@airflow.org',
+ 'name': 'spark-job',
+ 'num_executors': 10,
+ 'verbose': True,
+ 'application': 'test_application.py',
+ 'driver_memory': '3g',
+ 'java_class': 'com.foo.bar.AppMain'
+ }
+
+ def setUp(self):
+
+ if sys.version_info[0] == 3:
+ raise unittest.SkipTest('TestSparkSubmitOperator won\'t work with '
+ 'python3. No need to test anything here')
+
+ configuration.load_test_config()
+ args = {
+ 'owner': 'airflow',
+ 'start_date': DEFAULT_DATE
+ }
+ self.dag = DAG('test_dag_id', default_args=args)
+
+ def test_execute(self, conn_id='spark_default'):
+ operator = SparkSubmitOperator(
+ task_id='spark_submit_job',
+ dag=self.dag,
+ **self._config
+ )
+
+ self.assertEqual(conn_id, operator._conn_id)
+
+ self.assertEqual(self._config['application'], operator._application)
+ self.assertEqual(self._config['conf'], operator._conf)
+ self.assertEqual(self._config['files'], operator._files)
+ self.assertEqual(self._config['py_files'], operator._py_files)
+ self.assertEqual(self._config['jars'], operator._jars)
+ self.assertEqual(self._config['executor_cores'], operator._executor_cores)
+ self.assertEqual(self._config['executor_memory'], operator._executor_memory)
+ self.assertEqual(self._config['keytab'], operator._keytab)
+ self.assertEqual(self._config['principal'], operator._principal)
+ self.assertEqual(self._config['name'], operator._name)
+ self.assertEqual(self._config['num_executors'], operator._num_executors)
+ self.assertEqual(self._config['verbose'], operator._verbose)
+ self.assertEqual(self._config['java_class'], operator._java_class)
+ self.assertEqual(self._config['driver_memory'], operator._driver_memory)
+
+if __name__ == '__main__':
+ unittest.main()
[3/3] incubator-airflow git commit: [AIRFLOW-1094] Run unit tests under contrib in Travis
Posted by bo...@apache.org.
[AIRFLOW-1094] Run unit tests under contrib in Travis
Rename all unit tests under tests/contrib to start with test_* and fix
broken unit tests so that they run for the Python 2 and 3 builds.
Closes #2234 from hgrif/AIRFLOW-1094
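The rename matters because test runners collect modules by name: a file called
ecs_operator.py under tests/ never matches the default test pattern and is
silently skipped. As a rough illustration with the standard library loader
(nose, which the CI requirements pull in, applies a similar name-based match):

    import unittest

    # Only modules matching the pattern are collected; the old names
    # (ecs_operator.py, emr_hook.py, ...) never matched test_*.py.
    suite = unittest.TestLoader().discover('tests/contrib', pattern='test_*.py')
    unittest.TextTestRunner().run(suite)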
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/219c5064
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/219c5064
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/219c5064
Branch: refs/heads/master
Commit: 219c5064142c66cf8f051455199f2dda9b164584
Parents: 74c1ce2
Author: Henk Griffioen <hg...@users.noreply.github.com>
Authored: Mon Apr 17 10:04:29 2017 +0200
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Mon Apr 17 10:04:36 2017 +0200
----------------------------------------------------------------------
airflow/contrib/operators/ecs_operator.py | 2 +-
airflow/hooks/__init__.py | 1 +
airflow/hooks/zendesk_hook.py | 2 +-
scripts/ci/requirements.txt | 5 +
tests/contrib/hooks/aws_hook.py | 47 ----
tests/contrib/hooks/bigquery_hook.py | 139 ----------
tests/contrib/hooks/databricks_hook.py | 226 -----------------
tests/contrib/hooks/emr_hook.py | 53 ----
tests/contrib/hooks/gcp_dataflow_hook.py | 56 -----
tests/contrib/hooks/spark_submit_hook.py | 185 --------------
tests/contrib/hooks/sqoop_hook.py | 219 ----------------
tests/contrib/hooks/test_aws_hook.py | 47 ++++
tests/contrib/hooks/test_bigquery_hook.py | 139 ++++++++++
tests/contrib/hooks/test_databricks_hook.py | 226 +++++++++++++++++
tests/contrib/hooks/test_emr_hook.py | 53 ++++
tests/contrib/hooks/test_gcp_dataflow_hook.py | 56 +++++
tests/contrib/hooks/test_spark_submit_hook.py | 197 +++++++++++++++
tests/contrib/hooks/test_sqoop_hook.py | 218 ++++++++++++++++
tests/contrib/hooks/test_zendesk_hook.py | 89 +++++++
tests/contrib/hooks/zendesk_hook.py | 90 -------
tests/contrib/operators/__init__.py | 3 -
tests/contrib/operators/databricks_operator.py | 185 --------------
tests/contrib/operators/dataflow_operator.py | 82 ------
tests/contrib/operators/ecs_operator.py | 207 ---------------
.../contrib/operators/emr_add_steps_operator.py | 53 ----
.../operators/emr_create_job_flow_operator.py | 53 ----
.../emr_terminate_job_flow_operator.py | 52 ----
tests/contrib/operators/fs_operator.py | 64 -----
tests/contrib/operators/hipchat_operator.py | 74 ------
tests/contrib/operators/jira_operator_test.py | 101 --------
.../contrib/operators/spark_submit_operator.py | 81 ------
tests/contrib/operators/sqoop_operator.py | 93 -------
tests/contrib/operators/ssh_execute_operator.py | 79 ------
.../operators/test_databricks_operator.py | 185 ++++++++++++++
.../contrib/operators/test_dataflow_operator.py | 81 ++++++
tests/contrib/operators/test_ecs_operator.py | 214 ++++++++++++++++
.../operators/test_emr_add_steps_operator.py | 53 ++++
.../test_emr_create_job_flow_operator.py | 53 ++++
.../test_emr_terminate_job_flow_operator.py | 52 ++++
tests/contrib/operators/test_fs_operator.py | 64 +++++
.../contrib/operators/test_hipchat_operator.py | 74 ++++++
.../operators/test_jira_operator_test.py | 101 ++++++++
.../operators/test_spark_submit_operator.py | 88 +++++++
tests/contrib/operators/test_sqoop_operator.py | 93 +++++++
.../operators/test_ssh_execute_operator.py | 95 +++++++
tests/contrib/sensors/datadog_sensor.py | 91 -------
tests/contrib/sensors/emr_base_sensor.py | 126 ----------
tests/contrib/sensors/emr_job_flow_sensor.py | 123 ---------
tests/contrib/sensors/emr_step_sensor.py | 119 ---------
tests/contrib/sensors/ftp_sensor.py | 66 -----
tests/contrib/sensors/hdfs_sensors.py | 251 -------------------
tests/contrib/sensors/jira_sensor_test.py | 85 -------
tests/contrib/sensors/redis_sensor.py | 64 -----
tests/contrib/sensors/test_datadog_sensor.py | 106 ++++++++
tests/contrib/sensors/test_emr_base_sensor.py | 126 ++++++++++
.../contrib/sensors/test_emr_job_flow_sensor.py | 123 +++++++++
tests/contrib/sensors/test_emr_step_sensor.py | 119 +++++++++
tests/contrib/sensors/test_ftp_sensor.py | 66 +++++
tests/contrib/sensors/test_hdfs_sensors.py | 251 +++++++++++++++++++
tests/contrib/sensors/test_jira_sensor_test.py | 85 +++++++
tests/contrib/sensors/test_redis_sensor.py | 64 +++++
61 files changed, 3126 insertions(+), 3069 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/airflow/contrib/operators/ecs_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/ecs_operator.py b/airflow/contrib/operators/ecs_operator.py
index df02c4e..11f8c94 100644
--- a/airflow/contrib/operators/ecs_operator.py
+++ b/airflow/contrib/operators/ecs_operator.py
@@ -89,7 +89,7 @@ class ECSOperator(BaseOperator):
def _wait_for_task_ended(self):
waiter = self.client.get_waiter('tasks_stopped')
- waiter.config.max_attempts = sys.maxint # timeout is managed by airflow
+ waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow
waiter.wait(
cluster=self.cluster,
tasks=[self.arn]
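This one-liner is a Python 3 compatibility fix: sys.maxint was removed in
Python 3 (ints are unbounded there), while sys.maxsize exists on both major
versions. A quick check:

    import sys

    print(sys.maxsize)             # works on Python 2 and 3
    print(hasattr(sys, 'maxint'))  # True on Python 2, False on Python 3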
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/airflow/hooks/__init__.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/__init__.py b/airflow/hooks/__init__.py
index cc09f5a..bb02967 100644
--- a/airflow/hooks/__init__.py
+++ b/airflow/hooks/__init__.py
@@ -48,6 +48,7 @@ _hooks = {
'samba_hook': ['SambaHook'],
'sqlite_hook': ['SqliteHook'],
'S3_hook': ['S3Hook'],
+ 'zendesk_hook': ['ZendeskHook'],
'http_hook': ['HttpHook'],
'druid_hook': ['DruidHook'],
'jdbc_hook': ['JdbcHook'],
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/airflow/hooks/zendesk_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/zendesk_hook.py b/airflow/hooks/zendesk_hook.py
index 438597f..907d1e8 100644
--- a/airflow/hooks/zendesk_hook.py
+++ b/airflow/hooks/zendesk_hook.py
@@ -21,7 +21,7 @@ A hook to talk to Zendesk
import logging
import time
from zdesk import Zendesk, RateLimitError, ZendeskError
-from airflow.hooks import BaseHook
+from airflow.hooks.base_hook import BaseHook
class ZendeskHook(BaseHook):
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/scripts/ci/requirements.txt
----------------------------------------------------------------------
diff --git a/scripts/ci/requirements.txt b/scripts/ci/requirements.txt
index 1905398..751c13f 100644
--- a/scripts/ci/requirements.txt
+++ b/scripts/ci/requirements.txt
@@ -3,6 +3,7 @@ azure-storage>=0.34.0
bcrypt
bleach
boto
+boto3
celery
cgroupspy
chartkick
@@ -11,6 +12,7 @@ coverage
coveralls
croniter
cryptography
+datadog
dill
distributed
docker-py
@@ -25,6 +27,7 @@ Flask-WTF
flower
freezegun
future
+google-api-python-client>=1.5.0,<1.6.0
gunicorn
hdfs
hive-thrift-py
@@ -37,6 +40,7 @@ ldap3
lxml
markdown
mock
+moto
mysqlclient
nose
nose-exclude
@@ -69,3 +73,4 @@ statsd
thrift
thrift_sasl
unicodecsv
+zdesk
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/aws_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/aws_hook.py b/tests/contrib/hooks/aws_hook.py
deleted file mode 100644
index 6f13e58..0000000
--- a/tests/contrib/hooks/aws_hook.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-import boto3
-
-from airflow import configuration
-from airflow.contrib.hooks.aws_hook import AwsHook
-
-
-try:
- from moto import mock_emr
-except ImportError:
- mock_emr = None
-
-
-class TestAwsHook(unittest.TestCase):
- @mock_emr
- def setUp(self):
- configuration.load_test_config()
-
- @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
- @mock_emr
- def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
- client = boto3.client('emr', region_name='us-east-1')
- if len(client.list_clusters()['Clusters']):
- raise ValueError('AWS not properly mocked')
-
- hook = AwsHook(aws_conn_id='aws_default')
- client_from_hook = hook.get_client_type('emr')
-
- self.assertEqual(client_from_hook.list_clusters()['Clusters'], [])
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/bigquery_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/bigquery_hook.py b/tests/contrib/hooks/bigquery_hook.py
deleted file mode 100644
index 68856f8..0000000
--- a/tests/contrib/hooks/bigquery_hook.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-
-from airflow.contrib.hooks import bigquery_hook as hook
-
-
-class TestBigQueryTableSplitter(unittest.TestCase):
- def test_internal_need_default_project(self):
- with self.assertRaises(Exception) as context:
- hook._split_tablename('dataset.table', None)
-
- self.assertIn('INTERNAL: No default project is specified',
- str(context.exception), "")
-
- def test_split_dataset_table(self):
- project, dataset, table = hook._split_tablename('dataset.table',
- 'project')
- self.assertEqual("project", project)
- self.assertEqual("dataset", dataset)
- self.assertEqual("table", table)
-
- def test_split_project_dataset_table(self):
- project, dataset, table = hook._split_tablename('alternative:dataset.table',
- 'project')
- self.assertEqual("alternative", project)
- self.assertEqual("dataset", dataset)
- self.assertEqual("table", table)
-
- def test_sql_split_project_dataset_table(self):
- project, dataset, table = hook._split_tablename('alternative.dataset.table',
- 'project')
- self.assertEqual("alternative", project)
- self.assertEqual("dataset", dataset)
- self.assertEqual("table", table)
-
- def test_invalid_syntax_column_double_project(self):
- with self.assertRaises(Exception) as context:
- hook._split_tablename('alt1:alt.dataset.table',
- 'project')
-
- self.assertIn('Use either : or . to specify project',
- str(context.exception), "")
- self.assertFalse('Format exception for' in str(context.exception))
-
- def test_invalid_syntax_double_column(self):
- with self.assertRaises(Exception) as context:
- hook._split_tablename('alt1:alt:dataset.table',
- 'project')
-
- self.assertIn('Expect format of (<project:)<dataset>.<table>',
- str(context.exception), "")
- self.assertFalse('Format exception for' in str(context.exception))
-
- def test_invalid_syntax_tiple_dot(self):
- with self.assertRaises(Exception) as context:
- hook._split_tablename('alt1.alt.dataset.table',
- 'project')
-
- self.assertIn('Expect format of (<project.|<project:)<dataset>.<table>',
- str(context.exception), "")
- self.assertFalse('Format exception for' in str(context.exception))
-
- def test_invalid_syntax_column_double_project_var(self):
- with self.assertRaises(Exception) as context:
- hook._split_tablename('alt1:alt.dataset.table',
- 'project', 'var_x')
-
- self.assertIn('Use either : or . to specify project',
- str(context.exception), "")
- self.assertIn('Format exception for var_x:',
- str(context.exception), "")
-
- def test_invalid_syntax_double_column_var(self):
- with self.assertRaises(Exception) as context:
- hook._split_tablename('alt1:alt:dataset.table',
- 'project', 'var_x')
-
- self.assertIn('Expect format of (<project:)<dataset>.<table>',
- str(context.exception), "")
- self.assertIn('Format exception for var_x:',
- str(context.exception), "")
-
- def test_invalid_syntax_tiple_dot_var(self):
- with self.assertRaises(Exception) as context:
- hook._split_tablename('alt1.alt.dataset.table',
- 'project', 'var_x')
-
- self.assertIn('Expect format of (<project.|<project:)<dataset>.<table>',
- str(context.exception), "")
- self.assertIn('Format exception for var_x:',
- str(context.exception), "")
-
-class TestBigQueryHookSourceFormat(unittest.TestCase):
- def test_invalid_source_format(self):
- with self.assertRaises(Exception) as context:
- hook.BigQueryBaseCursor("test", "test").run_load("test.test", "test_schema.json", ["test_data.json"], source_format="json")
-
- # since we passed 'json' in, and it's not valid, make sure it's present in the error string.
- self.assertIn("json", str(context.exception))
-
-
-class TestBigQueryBaseCursor(unittest.TestCase):
- def test_invalid_schema_update_options(self):
- with self.assertRaises(Exception) as context:
- hook.BigQueryBaseCursor("test", "test").run_load(
- "test.test",
- "test_schema.json",
- ["test_data.json"],
- schema_update_options=["THIS IS NOT VALID"]
- )
- self.assertIn("THIS IS NOT VALID", str(context.exception))
-
- def test_invalid_schema_update_and_write_disposition(self):
- with self.assertRaises(Exception) as context:
- hook.BigQueryBaseCursor("test", "test").run_load(
- "test.test",
- "test_schema.json",
- ["test_data.json"],
- schema_update_options=['ALLOW_FIELD_ADDITION'],
- write_disposition='WRITE_EMPTY'
- )
- self.assertIn("schema_update_options is only", str(context.exception))
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/databricks_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/databricks_hook.py b/tests/contrib/hooks/databricks_hook.py
deleted file mode 100644
index 6c789f9..0000000
--- a/tests/contrib/hooks/databricks_hook.py
+++ /dev/null
@@ -1,226 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-
-from airflow import __version__
-from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT
-from airflow.exceptions import AirflowException
-from airflow.models import Connection
-from airflow.utils import db
-from requests import exceptions as requests_exceptions
-
-try:
- from unittest import mock
-except ImportError:
- try:
- import mock
- except ImportError:
- mock = None
-
-TASK_ID = 'databricks-operator'
-DEFAULT_CONN_ID = 'databricks_default'
-NOTEBOOK_TASK = {
- 'notebook_path': '/test'
-}
-NEW_CLUSTER = {
- 'spark_version': '2.0.x-scala2.10',
- 'node_type_id': 'r3.xlarge',
- 'num_workers': 1
-}
-RUN_ID = 1
-HOST = 'xx.cloud.databricks.com'
-HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'
-LOGIN = 'login'
-PASSWORD = 'password'
-USER_AGENT_HEADER = {'user-agent': 'airflow-{v}'.format(v=__version__)}
-RUN_PAGE_URL = 'https://XX.cloud.databricks.com/#jobs/1/runs/1'
-LIFE_CYCLE_STATE = 'PENDING'
-STATE_MESSAGE = 'Waiting for cluster'
-GET_RUN_RESPONSE = {
- 'run_page_url': RUN_PAGE_URL,
- 'state': {
- 'life_cycle_state': LIFE_CYCLE_STATE,
- 'state_message': STATE_MESSAGE
- }
-}
-RESULT_STATE = None
-
-
-def submit_run_endpoint(host):
- """
- Utility function to generate the submit run endpoint given the host.
- """
- return 'https://{}/api/2.0/jobs/runs/submit'.format(host)
-
-
-def get_run_endpoint(host):
- """
- Utility function to generate the get run endpoint given the host.
- """
- return 'https://{}/api/2.0/jobs/runs/get'.format(host)
-
-def cancel_run_endpoint(host):
- """
- Utility function to generate the get run endpoint given the host.
- """
- return 'https://{}/api/2.0/jobs/runs/cancel'.format(host)
-
-class DatabricksHookTest(unittest.TestCase):
- """
- Tests for DatabricksHook.
- """
- @db.provide_session
- def setUp(self, session=None):
- conn = session.query(Connection) \
- .filter(Connection.conn_id == DEFAULT_CONN_ID) \
- .first()
- conn.host = HOST
- conn.login = LOGIN
- conn.password = PASSWORD
- session.commit()
-
- self.hook = DatabricksHook()
-
- def test_parse_host_with_proper_host(self):
- host = self.hook._parse_host(HOST)
- self.assertEquals(host, HOST)
-
- def test_parse_host_with_scheme(self):
- host = self.hook._parse_host(HOST_WITH_SCHEME)
- self.assertEquals(host, HOST)
-
- def test_init_bad_retry_limit(self):
- with self.assertRaises(AssertionError):
- DatabricksHook(retry_limit = 0)
-
- @mock.patch('airflow.contrib.hooks.databricks_hook.logging')
- @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
- def test_do_api_call_with_error_retry(self, mock_requests, mock_logging):
- for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]:
- mock_requests.reset_mock()
- mock_logging.reset_mock()
- mock_requests.post.side_effect = exception()
-
- with self.assertRaises(AirflowException):
- self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
-
- self.assertEquals(len(mock_logging.error.mock_calls), self.hook.retry_limit)
-
- @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
- def test_do_api_call_with_bad_status_code(self, mock_requests):
- mock_requests.codes.ok = 200
- status_code_mock = mock.PropertyMock(return_value=500)
- type(mock_requests.post.return_value).status_code = status_code_mock
- with self.assertRaises(AirflowException):
- self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
-
- @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
- def test_submit_run(self, mock_requests):
- mock_requests.codes.ok = 200
- mock_requests.post.return_value.json.return_value = {'run_id': '1'}
- status_code_mock = mock.PropertyMock(return_value=200)
- type(mock_requests.post.return_value).status_code = status_code_mock
- json = {
- 'notebook_task': NOTEBOOK_TASK,
- 'new_cluster': NEW_CLUSTER
- }
- run_id = self.hook.submit_run(json)
-
- self.assertEquals(run_id, '1')
- mock_requests.post.assert_called_once_with(
- submit_run_endpoint(HOST),
- json={
- 'notebook_task': NOTEBOOK_TASK,
- 'new_cluster': NEW_CLUSTER,
- },
- auth=(LOGIN, PASSWORD),
- headers=USER_AGENT_HEADER,
- timeout=self.hook.timeout_seconds)
-
- @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
- def test_get_run_page_url(self, mock_requests):
- mock_requests.codes.ok = 200
- mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
- status_code_mock = mock.PropertyMock(return_value=200)
- type(mock_requests.get.return_value).status_code = status_code_mock
-
- run_page_url = self.hook.get_run_page_url(RUN_ID)
-
- self.assertEquals(run_page_url, RUN_PAGE_URL)
- mock_requests.get.assert_called_once_with(
- get_run_endpoint(HOST),
- json={'run_id': RUN_ID},
- auth=(LOGIN, PASSWORD),
- headers=USER_AGENT_HEADER,
- timeout=self.hook.timeout_seconds)
-
- @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
- def test_get_run_state(self, mock_requests):
- mock_requests.codes.ok = 200
- mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
- status_code_mock = mock.PropertyMock(return_value=200)
- type(mock_requests.get.return_value).status_code = status_code_mock
-
- run_state = self.hook.get_run_state(RUN_ID)
-
- self.assertEquals(run_state, RunState(
- LIFE_CYCLE_STATE,
- RESULT_STATE,
- STATE_MESSAGE))
- mock_requests.get.assert_called_once_with(
- get_run_endpoint(HOST),
- json={'run_id': RUN_ID},
- auth=(LOGIN, PASSWORD),
- headers=USER_AGENT_HEADER,
- timeout=self.hook.timeout_seconds)
-
- @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
- def test_cancel_run(self, mock_requests):
- mock_requests.codes.ok = 200
- mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE
- status_code_mock = mock.PropertyMock(return_value=200)
- type(mock_requests.post.return_value).status_code = status_code_mock
-
- self.hook.cancel_run(RUN_ID)
-
- mock_requests.post.assert_called_once_with(
- cancel_run_endpoint(HOST),
- json={'run_id': RUN_ID},
- auth=(LOGIN, PASSWORD),
- headers=USER_AGENT_HEADER,
- timeout=self.hook.timeout_seconds)
-
-class RunStateTest(unittest.TestCase):
- def test_is_terminal_true(self):
- terminal_states = ['TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']
- for state in terminal_states:
- run_state = RunState(state, '', '')
- self.assertTrue(run_state.is_terminal)
-
- def test_is_terminal_false(self):
- non_terminal_states = ['PENDING', 'RUNNING', 'TERMINATING']
- for state in non_terminal_states:
- run_state = RunState(state, '', '')
- self.assertFalse(run_state.is_terminal)
-
- def test_is_terminal_with_nonexistent_life_cycle_state(self):
- run_state = RunState('blah', '', '')
- with self.assertRaises(AirflowException):
- run_state.is_terminal
-
- def test_is_successful(self):
- run_state = RunState('TERMINATED', 'SUCCESS', '')
- self.assertTrue(run_state.is_successful)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/emr_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/emr_hook.py b/tests/contrib/hooks/emr_hook.py
deleted file mode 100644
index 119df99..0000000
--- a/tests/contrib/hooks/emr_hook.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-import boto3
-
-from airflow import configuration
-from airflow.contrib.hooks.emr_hook import EmrHook
-
-
-try:
- from moto import mock_emr
-except ImportError:
- mock_emr = None
-
-
-class TestEmrHook(unittest.TestCase):
- @mock_emr
- def setUp(self):
- configuration.load_test_config()
-
- @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
- @mock_emr
- def test_get_conn_returns_a_boto3_connection(self):
- hook = EmrHook(aws_conn_id='aws_default')
- self.assertIsNotNone(hook.get_conn().list_clusters())
-
- @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
- @mock_emr
- def test_create_job_flow_uses_the_emr_config_to_create_a_cluster(self):
- client = boto3.client('emr', region_name='us-east-1')
- if len(client.list_clusters()['Clusters']):
- raise ValueError('AWS not properly mocked')
-
- hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default')
- cluster = hook.create_job_flow({'Name': 'test_cluster'})
-
- self.assertEqual(client.list_clusters()['Clusters'][0]['Id'], cluster['JobFlowId'])
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/gcp_dataflow_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/gcp_dataflow_hook.py b/tests/contrib/hooks/gcp_dataflow_hook.py
deleted file mode 100644
index 797d40c..0000000
--- a/tests/contrib/hooks/gcp_dataflow_hook.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-from airflow.contrib.hooks.gcp_dataflow_hook import DataFlowHook
-
-try:
- from unittest import mock
-except ImportError:
- try:
- import mock
- except ImportError:
- mock = None
-
-
-TASK_ID = 'test-python-dataflow'
-PY_FILE = 'apache_beam.examples.wordcount'
-PY_OPTIONS = ['-m']
-OPTIONS = {
- 'project': 'test',
- 'staging_location': 'gs://test/staging'
-}
-BASE_STRING = 'airflow.contrib.hooks.gcp_api_base_hook.{}'
-DATAFLOW_STRING = 'airflow.contrib.hooks.gcp_dataflow_hook.{}'
-
-
-def mock_init(self, gcp_conn_id, delegate_to=None):
- pass
-
-
-class DataFlowHookTest(unittest.TestCase):
-
- def setUp(self):
- with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'),
- new=mock_init):
- self.dataflow_hook = DataFlowHook(gcp_conn_id='test')
-
- @mock.patch(DATAFLOW_STRING.format('DataFlowHook._start_dataflow'))
- def test_start_python_dataflow(self, internal_dataflow_mock):
- self.dataflow_hook.start_python_dataflow(
- task_id=TASK_ID, variables=OPTIONS,
- dataflow=PY_FILE, py_options=PY_OPTIONS)
- internal_dataflow_mock.assert_called_once_with(
- TASK_ID, OPTIONS, PY_FILE, mock.ANY, ['python'] + PY_OPTIONS)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/spark_submit_hook.py b/tests/contrib/hooks/spark_submit_hook.py
deleted file mode 100644
index 8f514c2..0000000
--- a/tests/contrib/hooks/spark_submit_hook.py
+++ /dev/null
@@ -1,185 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import os
-import unittest
-
-from airflow import configuration, models
-from airflow.utils import db
-from airflow.exceptions import AirflowException
-from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook
-
-
-class TestSparkSubmitHook(unittest.TestCase):
- _spark_job_file = 'test_application.py'
- _config = {
- 'conf': {
- 'parquet.compression': 'SNAPPY'
- },
- 'conn_id': 'default_spark',
- 'files': 'hive-site.xml',
- 'py_files': 'sample_library.py',
- 'jars': 'parquet.jar',
- 'executor_cores': 4,
- 'executor_memory': '22g',
- 'keytab': 'privileged_user.keytab',
- 'principal': 'user/spark@airflow.org',
- 'name': 'spark-job',
- 'num_executors': 10,
- 'verbose': True,
- 'driver_memory': '3g',
- 'java_class': 'com.foo.bar.AppMain'
- }
-
- def setUp(self):
- configuration.load_test_config()
- db.merge_conn(
- models.Connection(
- conn_id='spark_yarn_cluster', conn_type='spark',
- host='yarn://yarn-master', extra='{"queue": "root.etl", "deploy-mode": "cluster"}')
- )
- db.merge_conn(
- models.Connection(
- conn_id='spark_default_mesos', conn_type='spark',
- host='mesos://host', port=5050)
- )
-
- db.merge_conn(
- models.Connection(
- conn_id='spark_home_set', conn_type='spark',
- host='yarn://yarn-master',
- extra='{"spark-home": "/opt/myspark"}')
- )
-
- db.merge_conn(
- models.Connection(
- conn_id='spark_home_not_set', conn_type='spark',
- host='yarn://yarn-master')
- )
-
- def test_build_command(self):
- hook = SparkSubmitHook(**self._config)
-
- # The subprocess requires an array but we build the cmd by joining on a space
- cmd = ' '.join(hook._build_command(self._spark_job_file))
-
- # Check that the command gets built properly and everything exists.
- assert self._spark_job_file in cmd
-
- # Check all the parameters
- assert "--files {}".format(self._config['files']) in cmd
- assert "--py-files {}".format(self._config['py_files']) in cmd
- assert "--jars {}".format(self._config['jars']) in cmd
- assert "--executor-cores {}".format(self._config['executor_cores']) in cmd
- assert "--executor-memory {}".format(self._config['executor_memory']) in cmd
- assert "--keytab {}".format(self._config['keytab']) in cmd
- assert "--principal {}".format(self._config['principal']) in cmd
- assert "--name {}".format(self._config['name']) in cmd
- assert "--num-executors {}".format(self._config['num_executors']) in cmd
- assert "--class {}".format(self._config['java_class']) in cmd
- assert "--driver-memory {}".format(self._config['driver_memory']) in cmd
-
- # Check if all config settings are there
- for k in self._config['conf']:
- assert "--conf {0}={1}".format(k, self._config['conf'][k]) in cmd
-
- if self._config['verbose']:
- assert "--verbose" in cmd
-
- def test_submit(self):
- hook = SparkSubmitHook()
-
- # We don't have spark-submit available, and this is hard to mock, so just accept
- # an exception for now.
- with self.assertRaises(AirflowException):
- hook.submit(self._spark_job_file)
-
- def test_resolve_connection(self):
-
- # Default to the standard yarn connection because the conn_id does not exist
- hook = SparkSubmitHook(conn_id='')
- self.assertEqual(hook._resolve_connection(), ('yarn', None, None, None))
- assert "--master yarn" in ' '.join(hook._build_command(self._spark_job_file))
-
- # Default to the standard yarn connection
- hook = SparkSubmitHook(conn_id='spark_default')
- self.assertEqual(
- hook._resolve_connection(),
- ('yarn', 'root.default', None, None)
- )
- cmd = ' '.join(hook._build_command(self._spark_job_file))
- assert "--master yarn" in cmd
- assert "--queue root.default" in cmd
-
- # Connect to a mesos master
- hook = SparkSubmitHook(conn_id='spark_default_mesos')
- self.assertEqual(
- hook._resolve_connection(),
- ('mesos://host:5050', None, None, None)
- )
-
- cmd = ' '.join(hook._build_command(self._spark_job_file))
- assert "--master mesos://host:5050" in cmd
-
- # Set specific queue and deploy mode
- hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
- self.assertEqual(
- hook._resolve_connection(),
- ('yarn://yarn-master', 'root.etl', 'cluster', None)
- )
-
- cmd = ' '.join(hook._build_command(self._spark_job_file))
- assert "--master yarn://yarn-master" in cmd
- assert "--queue root.etl" in cmd
- assert "--deploy-mode cluster" in cmd
-
- # Set the spark home
- hook = SparkSubmitHook(conn_id='spark_home_set')
- self.assertEqual(
- hook._resolve_connection(),
- ('yarn://yarn-master', None, None, '/opt/myspark')
- )
-
- cmd = ' '.join(hook._build_command(self._spark_job_file))
- assert cmd.startswith('/opt/myspark/bin/spark-submit')
-
- # Spark home not set
- hook = SparkSubmitHook(conn_id='spark_home_not_set')
- self.assertEqual(
- hook._resolve_connection(),
- ('yarn://yarn-master', None, None, None)
- )
-
- cmd = ' '.join(hook._build_command(self._spark_job_file))
- assert cmd.startswith('spark-submit')
-
- def test_process_log(self):
- # Must select yarn connection
- hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
-
- log_lines = [
- 'SPARK_MAJOR_VERSION is set to 2, using Spark2',
- 'WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable',
- 'WARN DomainSocketFactory: The short-circuit local reads feature cannot be used because libhadoop cannot be loaded.',
- 'INFO Client: Requesting a new application from cluster with 10 NodeManagers',
- 'INFO Client: Submitting application application_1486558679801_1820 to ResourceManager'
- ]
-
- hook._process_log(log_lines)
-
- assert hook._yarn_application_id == 'application_1486558679801_1820'
-
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/sqoop_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/sqoop_hook.py b/tests/contrib/hooks/sqoop_hook.py
deleted file mode 100644
index 1d85e43..0000000
--- a/tests/contrib/hooks/sqoop_hook.py
+++ /dev/null
@@ -1,219 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import json
-import unittest
-from exceptions import OSError
-
-from airflow import configuration, models
-from airflow.contrib.hooks.sqoop_hook import SqoopHook
-from airflow.utils import db
-
-
-class TestSqoopHook(unittest.TestCase):
- _config = {
- 'conn_id': 'sqoop_test',
- 'num_mappers': 22,
- 'verbose': True,
- 'properties': {
- 'mapred.map.max.attempts': '1'
- }
- }
- _config_export = {
- 'table': 'domino.export_data_to',
- 'export_dir': '/hdfs/data/to/be/exported',
- 'input_null_string': '\n',
- 'input_null_non_string': '\t',
- 'staging_table': 'database.staging',
- 'clear_staging_table': True,
- 'enclosed_by': '"',
- 'escaped_by': '\\',
- 'input_fields_terminated_by': '|',
- 'input_lines_terminated_by': '\n',
- 'input_optionally_enclosed_by': '"',
- 'batch': True,
- 'relaxed_isolation': True
- }
- _config_import = {
- 'target_dir': '/hdfs/data/target/location',
- 'append': True,
- 'file_type': 'parquet',
- 'split_by': '\n',
- 'direct': True,
- 'driver': 'com.microsoft.jdbc.sqlserver.SQLServerDriver'
- }
-
- _config_json = {
- 'namenode': 'http://0.0.0.0:50070/',
- 'job_tracker': 'http://0.0.0.0:50030/',
- 'libjars': '/path/to/jars',
- 'files': '/path/to/files',
- 'archives': '/path/to/archives'
- }
-
- def setUp(self):
- configuration.load_test_config()
- db.merge_conn(
- models.Connection(
- conn_id='sqoop_test', conn_type='sqoop',
- host='rmdbs', port=5050, extra=json.dumps(self._config_json)
- )
- )
-
- def test_popen(self):
- hook = SqoopHook(**self._config)
-
- # Should go well
- hook.Popen(['ls'])
-
- # Should give an exception
- with self.assertRaises(OSError):
- hook.Popen('exit 1')
-
- def test_submit(self):
- hook = SqoopHook(**self._config)
-
- cmd = ' '.join(hook._prepare_command())
-
- # Check if the config has been extracted from the json
- if self._config_json['namenode']:
- assert "-fs {}".format(self._config_json['namenode']) in cmd
-
- if self._config_json['job_tracker']:
- assert "-jt {}".format(self._config_json['job_tracker']) in cmd
-
- if self._config_json['libjars']:
- assert "-libjars {}".format(self._config_json['libjars']) in cmd
-
- if self._config_json['files']:
- assert "-files {}".format(self._config_json['files']) in cmd
-
- if self._config_json['archives']:
- assert "-archives {}".format(self._config_json['archives']) in cmd
-
- # Check the regular settings passed in via the default constructor
- if self._config['verbose']:
- assert "--verbose" in cmd
-
- if self._config['num_mappers']:
- assert "--num-mappers {}".format(
- self._config['num_mappers']) in cmd
-
- print(self._config['properties'])
- for key, value in self._config['properties'].items():
- assert "-D {}={}".format(key, value) in cmd
-
- # We don't have the sqoop binary available, and this is hard to mock,
- # so just accept an exception for now.
- with self.assertRaises(OSError):
- hook.export_table(**self._config_export)
-
- with self.assertRaises(OSError):
- hook.import_table(table='schema.table',
- target_dir='/sqoop/example/path')
-
- with self.assertRaises(OSError):
- hook.import_query(query='SELECT * FROM sometable',
- target_dir='/sqoop/example/path')
-
- def test_export_cmd(self):
- hook = SqoopHook()
-
- # The subprocess requires an array but we build the cmd by joining on a space
- cmd = ' '.join(
- hook._export_cmd(
- self._config_export['table'],
- self._config_export['export_dir'],
- input_null_string=self._config_export['input_null_string'],
- input_null_non_string=self._config_export[
- 'input_null_non_string'],
- staging_table=self._config_export['staging_table'],
- clear_staging_table=self._config_export['clear_staging_table'],
- enclosed_by=self._config_export['enclosed_by'],
- escaped_by=self._config_export['escaped_by'],
- input_fields_terminated_by=self._config_export[
- 'input_fields_terminated_by'],
- input_lines_terminated_by=self._config_export[
- 'input_lines_terminated_by'],
- input_optionally_enclosed_by=self._config_export[
- 'input_optionally_enclosed_by'],
- batch=self._config_export['batch'],
- relaxed_isolation=self._config_export['relaxed_isolation'])
- )
-
- assert "--input-null-string {}".format(
- self._config_export['input_null_string']) in cmd
- assert "--input-null-non-string {}".format(
- self._config_export['input_null_non_string']) in cmd
- assert "--staging-table {}".format(
- self._config_export['staging_table']) in cmd
- assert "--enclosed-by {}".format(
- self._config_export['enclosed_by']) in cmd
- assert "--escaped-by {}".format(
- self._config_export['escaped_by']) in cmd
- assert "--input-fields-terminated-by {}".format(
- self._config_export['input_fields_terminated_by']) in cmd
- assert "--input-lines-terminated-by {}".format(
- self._config_export['input_lines_terminated_by']) in cmd
- assert "--input-optionally-enclosed-by {}".format(
- self._config_export['input_optionally_enclosed_by']) in cmd
-
- if self._config_export['clear_staging_table']:
- assert "--clear-staging-table" in cmd
-
- if self._config_export['batch']:
- assert "--batch" in cmd
-
- if self._config_export['relaxed_isolation']:
- assert "--relaxed-isolation" in cmd
-
- def test_import_cmd(self):
- hook = SqoopHook()
-
- # The subprocess requires an array but we build the cmd by joining on a space
- cmd = ' '.join(
- hook._import_cmd(self._config_import['target_dir'],
- append=self._config_import['append'],
- file_type=self._config_import['file_type'],
- split_by=self._config_import['split_by'],
- direct=self._config_import['direct'],
- driver=self._config_import['driver'])
- )
-
- if self._config_import['append']:
- assert '--append' in cmd
-
- if self._config_import['direct']:
- assert '--direct' in cmd
-
- assert '--target-dir {}'.format(
- self._config_import['target_dir']) in cmd
-
- assert '--driver {}'.format(self._config_import['driver']) in cmd
- assert '--split-by {}'.format(self._config_import['split_by']) in cmd
-
- def test_get_export_format_argument(self):
- hook = SqoopHook()
- assert "--as-avrodatafile" in hook._get_export_format_argument('avro')
- assert "--as-parquetfile" in hook._get_export_format_argument(
- 'parquet')
- assert "--as-sequencefile" in hook._get_export_format_argument(
- 'sequence')
- assert "--as-textfile" in hook._get_export_format_argument('text')
- assert "--as-textfile" in hook._get_export_format_argument('unknown')
-
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_aws_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_aws_hook.py b/tests/contrib/hooks/test_aws_hook.py
new file mode 100644
index 0000000..6f13e58
--- /dev/null
+++ b/tests/contrib/hooks/test_aws_hook.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import boto3
+
+from airflow import configuration
+from airflow.contrib.hooks.aws_hook import AwsHook
+
+
+try:
+ from moto import mock_emr
+except ImportError:
+ mock_emr = None
+
+# When moto is missing, '@mock_emr' below would be '@None' and raise a
+# TypeError while the class body is evaluated, so fall back to a
+# pass-through decorator and remember whether the real library is present.
+HAS_MOTO = mock_emr is not None
+if not HAS_MOTO:
+ def mock_emr(func):
+ return func
+
+
+class TestAwsHook(unittest.TestCase):
+ @mock_emr
+ def setUp(self):
+ configuration.load_test_config()
+
+ @unittest.skipIf(not HAS_MOTO, 'moto package not present')
+ @mock_emr
+ def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
+ client = boto3.client('emr', region_name='us-east-1')
+ if len(client.list_clusters()['Clusters']):
+ raise ValueError('AWS not properly mocked')
+
+ hook = AwsHook(aws_conn_id='aws_default')
+ client_from_hook = hook.get_client_type('emr')
+
+ self.assertEqual(client_from_hook.list_clusters()['Clusters'], [])
+
+if __name__ == '__main__':
+ unittest.main()
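For reference, a minimal standalone sketch of the moto pattern the test
above relies on (assuming moto and boto3 are installed): inside a function
decorated with @mock_emr, boto3 calls are served by an in-memory fake, so
no AWS credentials or network access are needed.

    import boto3
    from moto import mock_emr

    @mock_emr
    def count_clusters():
        # Served entirely by moto's in-memory EMR backend.
        client = boto3.client('emr', region_name='us-east-1')
        return len(client.list_clusters()['Clusters'])

    print(count_clusters())  # 0 -- a freshly mocked account has no clusters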
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_bigquery_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_bigquery_hook.py b/tests/contrib/hooks/test_bigquery_hook.py
new file mode 100644
index 0000000..0adffc5
--- /dev/null
+++ b/tests/contrib/hooks/test_bigquery_hook.py
@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from airflow.contrib.hooks import bigquery_hook as hook
+
+
+class TestBigQueryTableSplitter(unittest.TestCase):
+ def test_internal_need_default_project(self):
+ with self.assertRaises(Exception) as context:
+ hook._split_tablename('dataset.table', None)
+
+ self.assertIn('INTERNAL: No default project is specified',
+ str(context.exception), "")
+
+ def test_split_dataset_table(self):
+ project, dataset, table = hook._split_tablename('dataset.table',
+ 'project')
+ self.assertEqual("project", project)
+ self.assertEqual("dataset", dataset)
+ self.assertEqual("table", table)
+
+ def test_split_project_dataset_table(self):
+ project, dataset, table = hook._split_tablename('alternative:dataset.table',
+ 'project')
+ self.assertEqual("alternative", project)
+ self.assertEqual("dataset", dataset)
+ self.assertEqual("table", table)
+
+ def test_sql_split_project_dataset_table(self):
+ project, dataset, table = hook._split_tablename('alternative.dataset.table',
+ 'project')
+ self.assertEqual("alternative", project)
+ self.assertEqual("dataset", dataset)
+ self.assertEqual("table", table)
+
+ def test_invalid_syntax_column_double_project(self):
+ with self.assertRaises(Exception) as context:
+ hook._split_tablename('alt1:alt.dataset.table',
+ 'project')
+
+ self.assertIn('Use either : or . to specify project',
+ str(context.exception), "")
+ self.assertFalse('Format exception for' in str(context.exception))
+
+ def test_invalid_syntax_double_column(self):
+ with self.assertRaises(Exception) as context:
+ hook._split_tablename('alt1:alt:dataset.table',
+ 'project')
+
+ self.assertIn('Expect format of (<project:)<dataset>.<table>',
+ str(context.exception), "")
+ self.assertFalse('Format exception for' in str(context.exception))
+
+ def test_invalid_syntax_triple_dot(self):
+ with self.assertRaises(Exception) as context:
+ hook._split_tablename('alt1.alt.dataset.table',
+ 'project')
+
+ self.assertIn('Expect format of (<project.|<project:)<dataset>.<table>',
+ str(context.exception), "")
+ self.assertFalse('Format exception for' in str(context.exception))
+
+ def test_invalid_syntax_column_double_project_var(self):
+ with self.assertRaises(Exception) as context:
+ hook._split_tablename('alt1:alt.dataset.table',
+ 'project', 'var_x')
+
+ self.assertIn('Use either : or . to specify project',
+ str(context.exception), "")
+ self.assertIn('Format exception for var_x:',
+ str(context.exception), "")
+
+ def test_invalid_syntax_double_column_var(self):
+ with self.assertRaises(Exception) as context:
+ hook._split_tablename('alt1:alt:dataset.table',
+ 'project', 'var_x')
+
+ self.assertIn('Expect format of (<project:)<dataset>.<table>',
+ str(context.exception), "")
+ self.assertIn('Format exception for var_x:',
+ str(context.exception), "")
+
+ def test_invalid_syntax_triple_dot_var(self):
+ with self.assertRaises(Exception) as context:
+ hook._split_tablename('alt1.alt.dataset.table',
+ 'project', 'var_x')
+
+ self.assertIn('Expect format of (<project.|<project:)<dataset>.<table>',
+ str(context.exception), "")
+ self.assertIn('Format exception for var_x:',
+ str(context.exception), "")
+
+class TestBigQueryHookSourceFormat(unittest.TestCase):
+ def test_invalid_source_format(self):
+ with self.assertRaises(Exception) as context:
+ hook.BigQueryBaseCursor("test", "test").run_load("test.test", "test_schema.json", ["test_data.json"], source_format="json")
+
+ # since we passed 'json' in, and it's not valid, make sure it's present in the error string.
+ self.assertIn("JSON", str(context.exception))
+
+
+class TestBigQueryBaseCursor(unittest.TestCase):
+ def test_invalid_schema_update_options(self):
+ with self.assertRaises(Exception) as context:
+ hook.BigQueryBaseCursor("test", "test").run_load(
+ "test.test",
+ "test_schema.json",
+ ["test_data.json"],
+ schema_update_options=["THIS IS NOT VALID"]
+ )
+ self.assertIn("THIS IS NOT VALID", str(context.exception))
+
+ def test_invalid_schema_update_and_write_disposition(self):
+ with self.assertRaises(Exception) as context:
+ hook.BigQueryBaseCursor("test", "test").run_load(
+ "test.test",
+ "test_schema.json",
+ ["test_data.json"],
+ schema_update_options=['ALLOW_FIELD_ADDITION'],
+ write_disposition='WRITE_EMPTY'
+ )
+ self.assertIn("schema_update_options is only", str(context.exception))
+
+if __name__ == '__main__':
+ unittest.main()
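For reference, the three table-name shapes the splitter tests above
exercise, as a quick illustrative sketch (it calls the same private helper
the tests use):

    from airflow.contrib.hooks import bigquery_hook as hook

    for name in ('dataset.table',               # default project fills the gap
                 'alternative:dataset.table',   # colon-separated project
                 'alternative.dataset.table'):  # dot-separated project
        print(hook._split_tablename(name, 'default-project'))
    # ('default-project', 'dataset', 'table')
    # ('alternative', 'dataset', 'table')
    # ('alternative', 'dataset', 'table')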
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_databricks_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py
new file mode 100644
index 0000000..6c789f9
--- /dev/null
+++ b/tests/contrib/hooks/test_databricks_hook.py
@@ -0,0 +1,226 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from airflow import __version__
+from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT
+from airflow.exceptions import AirflowException
+from airflow.models import Connection
+from airflow.utils import db
+from requests import exceptions as requests_exceptions
+
+try:
+ from unittest import mock
+except ImportError:
+ try:
+ import mock
+ except ImportError:
+ mock = None
+
+TASK_ID = 'databricks-operator'
+DEFAULT_CONN_ID = 'databricks_default'
+NOTEBOOK_TASK = {
+ 'notebook_path': '/test'
+}
+NEW_CLUSTER = {
+ 'spark_version': '2.0.x-scala2.10',
+ 'node_type_id': 'r3.xlarge',
+ 'num_workers': 1
+}
+RUN_ID = 1
+HOST = 'xx.cloud.databricks.com'
+HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'
+LOGIN = 'login'
+PASSWORD = 'password'
+USER_AGENT_HEADER = {'user-agent': 'airflow-{v}'.format(v=__version__)}
+RUN_PAGE_URL = 'https://XX.cloud.databricks.com/#jobs/1/runs/1'
+LIFE_CYCLE_STATE = 'PENDING'
+STATE_MESSAGE = 'Waiting for cluster'
+GET_RUN_RESPONSE = {
+ 'run_page_url': RUN_PAGE_URL,
+ 'state': {
+ 'life_cycle_state': LIFE_CYCLE_STATE,
+ 'state_message': STATE_MESSAGE
+ }
+}
+RESULT_STATE = None
+
+
+def submit_run_endpoint(host):
+ """
+ Utility function to generate the submit run endpoint given the host.
+ """
+ return 'https://{}/api/2.0/jobs/runs/submit'.format(host)
+
+
+def get_run_endpoint(host):
+ """
+ Utility function to generate the get run endpoint given the host.
+ """
+ return 'https://{}/api/2.0/jobs/runs/get'.format(host)
+
+def cancel_run_endpoint(host):
+ """
+ Utility function to generate the cancel run endpoint given the host.
+ """
+ return 'https://{}/api/2.0/jobs/runs/cancel'.format(host)
+
+class DatabricksHookTest(unittest.TestCase):
+ """
+ Tests for DatabricksHook.
+ """
+ @db.provide_session
+ def setUp(self, session=None):
+ conn = session.query(Connection) \
+ .filter(Connection.conn_id == DEFAULT_CONN_ID) \
+ .first()
+ conn.host = HOST
+ conn.login = LOGIN
+ conn.password = PASSWORD
+ session.commit()
+
+ self.hook = DatabricksHook()
+
+ def test_parse_host_with_proper_host(self):
+ host = self.hook._parse_host(HOST)
+ self.assertEqual(host, HOST)
+
+ def test_parse_host_with_scheme(self):
+ host = self.hook._parse_host(HOST_WITH_SCHEME)
+ self.assertEqual(host, HOST)
+
+ def test_init_bad_retry_limit(self):
+ with self.assertRaises(AssertionError):
+ DatabricksHook(retry_limit=0)
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.logging')
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_do_api_call_with_error_retry(self, mock_requests, mock_logging):
+ for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]:
+ mock_requests.reset_mock()
+ mock_logging.reset_mock()
+ mock_requests.post.side_effect = exception()
+
+ with self.assertRaises(AirflowException):
+ self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+
+ self.assertEqual(len(mock_logging.error.mock_calls), self.hook.retry_limit)
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_do_api_call_with_bad_status_code(self, mock_requests):
+ mock_requests.codes.ok = 200
+ status_code_mock = mock.PropertyMock(return_value=500)
+ type(mock_requests.post.return_value).status_code = status_code_mock
+ with self.assertRaises(AirflowException):
+ self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_submit_run(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.post.return_value.json.return_value = {'run_id': '1'}
+ status_code_mock = mock.PropertyMock(return_value=200)
+ type(mock_requests.post.return_value).status_code = status_code_mock
+ json = {
+ 'notebook_task': NOTEBOOK_TASK,
+ 'new_cluster': NEW_CLUSTER
+ }
+ run_id = self.hook.submit_run(json)
+
+ self.assertEqual(run_id, '1')
+ mock_requests.post.assert_called_once_with(
+ submit_run_endpoint(HOST),
+ json={
+ 'notebook_task': NOTEBOOK_TASK,
+ 'new_cluster': NEW_CLUSTER,
+ },
+ auth=(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds)
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_get_run_page_url(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
+ status_code_mock = mock.PropertyMock(return_value=200)
+ type(mock_requests.get.return_value).status_code = status_code_mock
+
+ run_page_url = self.hook.get_run_page_url(RUN_ID)
+
+ self.assertEqual(run_page_url, RUN_PAGE_URL)
+ mock_requests.get.assert_called_once_with(
+ get_run_endpoint(HOST),
+ json={'run_id': RUN_ID},
+ auth=(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds)
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_get_run_state(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
+ status_code_mock = mock.PropertyMock(return_value=200)
+ type(mock_requests.get.return_value).status_code = status_code_mock
+
+ run_state = self.hook.get_run_state(RUN_ID)
+
+ self.assertEqual(run_state, RunState(
+ LIFE_CYCLE_STATE,
+ RESULT_STATE,
+ STATE_MESSAGE))
+ mock_requests.get.assert_called_once_with(
+ get_run_endpoint(HOST),
+ json={'run_id': RUN_ID},
+ auth=(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds)
+
+ @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+ def test_cancel_run(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE
+ status_code_mock = mock.PropertyMock(return_value=200)
+ type(mock_requests.post.return_value).status_code = status_code_mock
+
+ self.hook.cancel_run(RUN_ID)
+
+ mock_requests.post.assert_called_once_with(
+ cancel_run_endpoint(HOST),
+ json={'run_id': RUN_ID},
+ auth=(LOGIN, PASSWORD),
+ headers=USER_AGENT_HEADER,
+ timeout=self.hook.timeout_seconds)
+
+class RunStateTest(unittest.TestCase):
+ def test_is_terminal_true(self):
+ terminal_states = ['TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']
+ for state in terminal_states:
+ run_state = RunState(state, '', '')
+ self.assertTrue(run_state.is_terminal)
+
+ def test_is_terminal_false(self):
+ non_terminal_states = ['PENDING', 'RUNNING', 'TERMINATING']
+ for state in non_terminal_states:
+ run_state = RunState(state, '', '')
+ self.assertFalse(run_state.is_terminal)
+
+ def test_is_terminal_with_nonexistent_life_cycle_state(self):
+ run_state = RunState('blah', '', '')
+ with self.assertRaises(AirflowException):
+ run_state.is_terminal
+
+ def test_is_successful(self):
+ run_state = RunState('TERMINATED', 'SUCCESS', '')
+ self.assertTrue(run_state.is_successful)
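For reference, a short sketch of the RunState semantics these tests pin
down (constructor arguments as used above: life cycle state, result state,
state message):

    from airflow.contrib.hooks.databricks_hook import RunState

    state = RunState('TERMINATED', 'SUCCESS', '')
    assert state.is_terminal     # TERMINATED, SKIPPED, INTERNAL_ERROR end a run
    assert state.is_successful   # terminal and the result state is SUCCESS

    pending = RunState('PENDING', None, 'Waiting for cluster')
    assert not pending.is_terminal  # callers keep polling in this case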
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_emr_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_emr_hook.py b/tests/contrib/hooks/test_emr_hook.py
new file mode 100644
index 0000000..119df99
--- /dev/null
+++ b/tests/contrib/hooks/test_emr_hook.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import boto3
+
+from airflow import configuration
+from airflow.contrib.hooks.emr_hook import EmrHook
+
+
+try:
+ from moto import mock_emr
+except ImportError:
+ mock_emr = None
+
+# When moto is missing, '@mock_emr' below would be '@None' and raise a
+# TypeError while the class body is evaluated, so fall back to a
+# pass-through decorator and remember whether the real library is present.
+HAS_MOTO = mock_emr is not None
+if not HAS_MOTO:
+ def mock_emr(func):
+ return func
+
+
+class TestEmrHook(unittest.TestCase):
+ @mock_emr
+ def setUp(self):
+ configuration.load_test_config()
+
+ @unittest.skipIf(not HAS_MOTO, 'moto package not present')
+ @mock_emr
+ def test_get_conn_returns_a_boto3_connection(self):
+ hook = EmrHook(aws_conn_id='aws_default')
+ self.assertIsNotNone(hook.get_conn().list_clusters())
+
+ @unittest.skipIf(not HAS_MOTO, 'moto package not present')
+ @mock_emr
+ def test_create_job_flow_uses_the_emr_config_to_create_a_cluster(self):
+ client = boto3.client('emr', region_name='us-east-1')
+ if len(client.list_clusters()['Clusters']):
+ raise ValueError('AWS not properly mocked')
+
+ hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default')
+ cluster = hook.create_job_flow({'Name': 'test_cluster'})
+
+ self.assertEqual(client.list_clusters()['Clusters'][0]['Id'], cluster['JobFlowId'])
+
+if __name__ == '__main__':
+ unittest.main()
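For reference, a sketch of how the create_job_flow() call above is expected
to compose its configuration (same moto assumption as in test_aws_hook.py):
the dict passed in appears to be layered over the JSON stored in the
emr_default connection, so a caller only supplies the fields it wants to
override.

    from moto import mock_emr
    from airflow.contrib.hooks.emr_hook import EmrHook

    @mock_emr
    def create_named_cluster(name):
        hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default')
        # Everything except 'Name' comes from the connection's extra JSON.
        return hook.create_job_flow({'Name': name})['JobFlowId']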
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_gcp_dataflow_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_dataflow_hook.py b/tests/contrib/hooks/test_gcp_dataflow_hook.py
new file mode 100644
index 0000000..797d40c
--- /dev/null
+++ b/tests/contrib/hooks/test_gcp_dataflow_hook.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from airflow.contrib.hooks.gcp_dataflow_hook import DataFlowHook
+
+try:
+ from unittest import mock
+except ImportError:
+ try:
+ import mock
+ except ImportError:
+ mock = None
+
+
+TASK_ID = 'test-python-dataflow'
+PY_FILE = 'apache_beam.examples.wordcount'
+PY_OPTIONS = ['-m']
+OPTIONS = {
+ 'project': 'test',
+ 'staging_location': 'gs://test/staging'
+}
+BASE_STRING = 'airflow.contrib.hooks.gcp_api_base_hook.{}'
+DATAFLOW_STRING = 'airflow.contrib.hooks.gcp_dataflow_hook.{}'
+
+
+def mock_init(self, gcp_conn_id, delegate_to=None):
+ pass
+
+
+class DataFlowHookTest(unittest.TestCase):
+
+ def setUp(self):
+ with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'),
+ new=mock_init):
+ self.dataflow_hook = DataFlowHook(gcp_conn_id='test')
+
+ @mock.patch(DATAFLOW_STRING.format('DataFlowHook._start_dataflow'))
+ def test_start_python_dataflow(self, internal_dataflow_mock):
+ self.dataflow_hook.start_python_dataflow(
+ task_id=TASK_ID, variables=OPTIONS,
+ dataflow=PY_FILE, py_options=PY_OPTIONS)
+ internal_dataflow_mock.assert_called_once_with(
+ TASK_ID, OPTIONS, PY_FILE, mock.ANY, ['python'] + PY_OPTIONS)
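The setUp above uses a generic trick worth noting: patching
GoogleCloudBaseHook.__init__ with a no-op lets a GCP hook be unit-tested
without credentials. A minimal standalone sketch:

    import mock

    def noop_init(self, gcp_conn_id, delegate_to=None):
        pass  # skip the credential lookup the real __init__ performs

    with mock.patch('airflow.contrib.hooks.gcp_api_base_hook.'
                    'GoogleCloudBaseHook.__init__', new=noop_init):
        from airflow.contrib.hooks.gcp_dataflow_hook import DataFlowHook
        hook = DataFlowHook(gcp_conn_id='any-conn-id')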
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_spark_submit_hook.py b/tests/contrib/hooks/test_spark_submit_hook.py
new file mode 100644
index 0000000..24315fa
--- /dev/null
+++ b/tests/contrib/hooks/test_spark_submit_hook.py
@@ -0,0 +1,197 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import sys
+import unittest
+from io import StringIO
+
+import mock
+
+from airflow import configuration, models
+from airflow.utils import db
+from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook
+
+
+class TestSparkSubmitHook(unittest.TestCase):
+
+ _spark_job_file = 'test_application.py'
+ _config = {
+ 'conf': {
+ 'parquet.compression': 'SNAPPY'
+ },
+ 'conn_id': 'default_spark',
+ 'files': 'hive-site.xml',
+ 'py_files': 'sample_library.py',
+ 'jars': 'parquet.jar',
+ 'executor_cores': 4,
+ 'executor_memory': '22g',
+ 'keytab': 'privileged_user.keytab',
+ 'principal': 'user/spark@airflow.org',
+ 'name': 'spark-job',
+ 'num_executors': 10,
+ 'verbose': True,
+ 'driver_memory': '3g',
+ 'java_class': 'com.foo.bar.AppMain'
+ }
+
+ def setUp(self):
+
+ if sys.version_info[0] == 3:
+ raise unittest.SkipTest('TestSparkSubmitHook won\'t work with '
+ 'python3. No need to test anything here')
+
+ configuration.load_test_config()
+ db.merge_conn(
+ models.Connection(
+ conn_id='spark_yarn_cluster', conn_type='spark',
+ host='yarn://yarn-master', extra='{"queue": "root.etl", "deploy-mode": "cluster"}')
+ )
+ db.merge_conn(
+ models.Connection(
+ conn_id='spark_default_mesos', conn_type='spark',
+ host='mesos://host', port=5050)
+ )
+
+ db.merge_conn(
+ models.Connection(
+ conn_id='spark_home_set', conn_type='spark',
+ host='yarn://yarn-master',
+ extra='{"spark-home": "/opt/myspark"}')
+ )
+
+ db.merge_conn(
+ models.Connection(
+ conn_id='spark_home_not_set', conn_type='spark',
+ host='yarn://yarn-master')
+ )
+
+ def test_build_command(self):
+ hook = SparkSubmitHook(**self._config)
+
+ # The subprocess requires an array but we build the cmd by joining on a space
+ cmd = ' '.join(hook._build_command(self._spark_job_file))
+
+ # Check that the command gets built properly and everything exists.
+ assert self._spark_job_file in cmd
+
+ # Check all the parameters
+ assert "--files {}".format(self._config['files']) in cmd
+ assert "--py-files {}".format(self._config['py_files']) in cmd
+ assert "--jars {}".format(self._config['jars']) in cmd
+ assert "--executor-cores {}".format(self._config['executor_cores']) in cmd
+ assert "--executor-memory {}".format(self._config['executor_memory']) in cmd
+ assert "--keytab {}".format(self._config['keytab']) in cmd
+ assert "--principal {}".format(self._config['principal']) in cmd
+ assert "--name {}".format(self._config['name']) in cmd
+ assert "--num-executors {}".format(self._config['num_executors']) in cmd
+ assert "--class {}".format(self._config['java_class']) in cmd
+ assert "--driver-memory {}".format(self._config['driver_memory']) in cmd
+
+ # Check if all config settings are there
+ for k in self._config['conf']:
+ assert "--conf {0}={1}".format(k, self._config['conf'][k]) in cmd
+
+ if self._config['verbose']:
+ assert "--verbose" in cmd
+
+ @mock.patch('airflow.contrib.hooks.spark_submit_hook.subprocess')
+ def test_submit(self, mock_process):
+ # We don't have spark-submit available, and this is hard to mock, so let's
+ # just use this simple mock.
+ mock_Popen = mock_process.Popen.return_value
+ mock_Popen.stdout = StringIO(u'stdout')
+ mock_Popen.stderr = StringIO(u'stderr')
+ mock_Popen.returncode = None
+ mock_Popen.communicate.return_value = ['extra stdout', 'extra stderr']
+ hook = SparkSubmitHook()
+ hook.submit(self._spark_job_file)
+
+ def test_resolve_connection(self):
+
+ # Default to the standard yarn connection because the conn_id does not exist
+ hook = SparkSubmitHook(conn_id='')
+ self.assertEqual(hook._resolve_connection(), ('yarn', None, None, None))
+ assert "--master yarn" in ' '.join(hook._build_command(self._spark_job_file))
+
+ # Default to the standard yarn connection
+ hook = SparkSubmitHook(conn_id='spark_default')
+ self.assertEqual(
+ hook._resolve_connection(),
+ ('yarn', 'root.default', None, None)
+ )
+ cmd = ' '.join(hook._build_command(self._spark_job_file))
+ assert "--master yarn" in cmd
+ assert "--queue root.default" in cmd
+
+ # Connect to a mesos master
+ hook = SparkSubmitHook(conn_id='spark_default_mesos')
+ self.assertEqual(
+ hook._resolve_connection(),
+ ('mesos://host:5050', None, None, None)
+ )
+
+ cmd = ' '.join(hook._build_command(self._spark_job_file))
+ assert "--master mesos://host:5050" in cmd
+
+ # Set specific queue and deploy mode
+ hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
+ self.assertEqual(
+ hook._resolve_connection(),
+ ('yarn://yarn-master', 'root.etl', 'cluster', None)
+ )
+
+ cmd = ' '.join(hook._build_command(self._spark_job_file))
+ assert "--master yarn://yarn-master" in cmd
+ assert "--queue root.etl" in cmd
+ assert "--deploy-mode cluster" in cmd
+
+ # Set the spark home
+ hook = SparkSubmitHook(conn_id='spark_home_set')
+ self.assertEqual(
+ hook._resolve_connection(),
+ ('yarn://yarn-master', None, None, '/opt/myspark')
+ )
+
+ cmd = ' '.join(hook._build_command(self._spark_job_file))
+ assert cmd.startswith('/opt/myspark/bin/spark-submit')
+
+ # Spark home not set
+ hook = SparkSubmitHook(conn_id='spark_home_not_set')
+ self.assertEqual(
+ hook._resolve_connection(),
+ ('yarn://yarn-master', None, None, None)
+ )
+
+ cmd = ' '.join(hook._build_command(self._spark_job_file))
+ assert cmd.startswith('spark-submit')
+
+ def test_process_log(self):
+ # Must select yarn connection
+ hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
+
+ log_lines = [
+ 'SPARK_MAJOR_VERSION is set to 2, using Spark2',
+ 'WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable',
+ 'WARN DomainSocketFactory: The short-circuit local reads feature cannot be used because libhadoop cannot be loaded.',
+ 'INFO Client: Requesting a new application from cluster with 10 NodeManagers',
+ 'INFO Client: Submitting application application_1486558679801_1820 to ResourceManager'
+ ]
+
+ hook._process_log(log_lines)
+
+ assert hook._yarn_application_id == 'application_1486558679801_1820'
+
+
+if __name__ == '__main__':
+ unittest.main()
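For reference, an illustrative sketch (not the hook's actual code) of the
flag mapping test_build_command verifies: each constructor kwarg becomes a
spark-submit flag, and the command is built as a list so it can be handed
to subprocess directly; the tests join on a space only to make the
substring assertions easy.

    settings = [('--executor-cores', 4),
                ('--executor-memory', '22g'),
                ('--name', 'spark-job')]
    cmd = ['spark-submit']
    for flag, value in settings:
        cmd += [flag, str(value)]
    cmd.append('test_application.py')
    print(' '.join(cmd))
    # spark-submit --executor-cores 4 --executor-memory 22g --name spark-job test_application.py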
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_sqoop_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_sqoop_hook.py b/tests/contrib/hooks/test_sqoop_hook.py
new file mode 100644
index 0000000..ca8033b
--- /dev/null
+++ b/tests/contrib/hooks/test_sqoop_hook.py
@@ -0,0 +1,218 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import unittest
+
+from airflow import configuration, models
+from airflow.contrib.hooks.sqoop_hook import SqoopHook
+from airflow.utils import db
+
+
+class TestSqoopHook(unittest.TestCase):
+ _config = {
+ 'conn_id': 'sqoop_test',
+ 'num_mappers': 22,
+ 'verbose': True,
+ 'properties': {
+ 'mapred.map.max.attempts': '1'
+ }
+ }
+ _config_export = {
+ 'table': 'domino.export_data_to',
+ 'export_dir': '/hdfs/data/to/be/exported',
+ 'input_null_string': '\n',
+ 'input_null_non_string': '\t',
+ 'staging_table': 'database.staging',
+ 'clear_staging_table': True,
+ 'enclosed_by': '"',
+ 'escaped_by': '\\',
+ 'input_fields_terminated_by': '|',
+ 'input_lines_terminated_by': '\n',
+ 'input_optionally_enclosed_by': '"',
+ 'batch': True,
+ 'relaxed_isolation': True
+ }
+ _config_import = {
+ 'target_dir': '/hdfs/data/target/location',
+ 'append': True,
+ 'file_type': 'parquet',
+ 'split_by': '\n',
+ 'direct': True,
+ 'driver': 'com.microsoft.jdbc.sqlserver.SQLServerDriver'
+ }
+
+ _config_json = {
+ 'namenode': 'http://0.0.0.0:50070/',
+ 'job_tracker': 'http://0.0.0.0:50030/',
+ 'libjars': '/path/to/jars',
+ 'files': '/path/to/files',
+ 'archives': '/path/to/archives'
+ }
+
+ def setUp(self):
+ configuration.load_test_config()
+ db.merge_conn(
+ models.Connection(
+ conn_id='sqoop_test', conn_type='sqoop',
+ host='rmdbs', port=5050, extra=json.dumps(self._config_json)
+ )
+ )
+
+ def test_popen(self):
+ hook = SqoopHook(**self._config)
+
+ # Should go well
+ hook.Popen(['ls'])
+
+ # Should give an exception
+ with self.assertRaises(OSError):
+ hook.Popen('exit 1')
+
+ def test_submit(self):
+ hook = SqoopHook(**self._config)
+
+ cmd = ' '.join(hook._prepare_command())
+
+ # Check if the config has been extracted from the json
+ if self._config_json['namenode']:
+ assert "-fs {}".format(self._config_json['namenode']) in cmd
+
+ if self._config_json['job_tracker']:
+ assert "-jt {}".format(self._config_json['job_tracker']) in cmd
+
+ if self._config_json['libjars']:
+ assert "-libjars {}".format(self._config_json['libjars']) in cmd
+
+ if self._config_json['files']:
+ assert "-files {}".format(self._config_json['files']) in cmd
+
+ if self._config_json['archives']:
+ assert "-archives {}".format(self._config_json['archives']) in cmd
+
+ # Check the regular settings passed in via the default constructor
+ if self._config['verbose']:
+ assert "--verbose" in cmd
+
+ if self._config['num_mappers']:
+ assert "--num-mappers {}".format(
+ self._config['num_mappers']) in cmd
+
+ print(self._config['properties'])
+ for key, value in self._config['properties'].items():
+ assert "-D {}={}".format(key, value) in cmd
+
+ # We don't have the sqoop binary available, and this is hard to mock,
+ # so just accept an exception for now.
+ with self.assertRaises(OSError):
+ hook.export_table(**self._config_export)
+
+ with self.assertRaises(OSError):
+ hook.import_table(table='schema.table',
+ target_dir='/sqoop/example/path')
+
+ with self.assertRaises(OSError):
+ hook.import_query(query='SELECT * FROM sometable',
+ target_dir='/sqoop/example/path')
+
+ def test_export_cmd(self):
+ hook = SqoopHook()
+
+ # The subprocess requires an array but we build the cmd by joining on a space
+ cmd = ' '.join(
+ hook._export_cmd(
+ self._config_export['table'],
+ self._config_export['export_dir'],
+ input_null_string=self._config_export['input_null_string'],
+ input_null_non_string=self._config_export[
+ 'input_null_non_string'],
+ staging_table=self._config_export['staging_table'],
+ clear_staging_table=self._config_export['clear_staging_table'],
+ enclosed_by=self._config_export['enclosed_by'],
+ escaped_by=self._config_export['escaped_by'],
+ input_fields_terminated_by=self._config_export[
+ 'input_fields_terminated_by'],
+ input_lines_terminated_by=self._config_export[
+ 'input_lines_terminated_by'],
+ input_optionally_enclosed_by=self._config_export[
+ 'input_optionally_enclosed_by'],
+ batch=self._config_export['batch'],
+ relaxed_isolation=self._config_export['relaxed_isolation'])
+ )
+
+ assert "--input-null-string {}".format(
+ self._config_export['input_null_string']) in cmd
+ assert "--input-null-non-string {}".format(
+ self._config_export['input_null_non_string']) in cmd
+ assert "--staging-table {}".format(
+ self._config_export['staging_table']) in cmd
+ assert "--enclosed-by {}".format(
+ self._config_export['enclosed_by']) in cmd
+ assert "--escaped-by {}".format(
+ self._config_export['escaped_by']) in cmd
+ assert "--input-fields-terminated-by {}".format(
+ self._config_export['input_fields_terminated_by']) in cmd
+ assert "--input-lines-terminated-by {}".format(
+ self._config_export['input_lines_terminated_by']) in cmd
+ assert "--input-optionally-enclosed-by {}".format(
+ self._config_export['input_optionally_enclosed_by']) in cmd
+
+ if self._config_export['clear_staging_table']:
+ assert "--clear-staging-table" in cmd
+
+ if self._config_export['batch']:
+ assert "--batch" in cmd
+
+ if self._config_export['relaxed_isolation']:
+ assert "--relaxed-isolation" in cmd
+
+ def test_import_cmd(self):
+ hook = SqoopHook()
+
+ # The subprocess requires an array but we build the cmd by joining on a space
+ cmd = ' '.join(
+ hook._import_cmd(self._config_import['target_dir'],
+ append=self._config_import['append'],
+ file_type=self._config_import['file_type'],
+ split_by=self._config_import['split_by'],
+ direct=self._config_import['direct'],
+ driver=self._config_import['driver'])
+ )
+
+ if self._config_import['append']:
+ assert '--append' in cmd
+
+ if self._config_import['direct']:
+ assert '--direct' in cmd
+
+ assert '--target-dir {}'.format(
+ self._config_import['target_dir']) in cmd
+
+ assert '--driver {}'.format(self._config_import['driver']) in cmd
+ assert '--split-by {}'.format(self._config_import['split_by']) in cmd
+
+ def test_get_export_format_argument(self):
+ hook = SqoopHook()
+ assert "--as-avrodatafile" in hook._get_export_format_argument('avro')
+ assert "--as-parquetfile" in hook._get_export_format_argument(
+ 'parquet')
+ assert "--as-sequencefile" in hook._get_export_format_argument(
+ 'sequence')
+ assert "--as-textfile" in hook._get_export_format_argument('text')
+ assert "--as-textfile" in hook._get_export_format_argument('unknown')
+
+
+if __name__ == '__main__':
+ unittest.main()
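For reference, a rough sketch (not the hook's implementation) of the
command shape test_submit takes apart: generic Hadoop arguments such as
-fs, -jt and -D come from the connection's extra JSON, while tool arguments
such as --num-mappers and --verbose come from the hook's constructor.

    cmd = ['sqoop', 'import',
           '-fs', 'http://0.0.0.0:50070/',     # namenode, from the connection
           '-jt', 'http://0.0.0.0:50030/',     # job tracker, from the connection
           '-D', 'mapred.map.max.attempts=1',  # properties dict
           '--num-mappers', '22',
           '--verbose']
    print(' '.join(cmd))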
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/219c5064/tests/contrib/hooks/test_zendesk_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_zendesk_hook.py b/tests/contrib/hooks/test_zendesk_hook.py
new file mode 100644
index 0000000..7751a2b
--- /dev/null
+++ b/tests/contrib/hooks/test_zendesk_hook.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import mock
+
+from airflow.hooks.zendesk_hook import ZendeskHook
+from zdesk import RateLimitError
+
+
+class TestZendeskHook(unittest.TestCase):
+
+ @mock.patch("airflow.hooks.zendesk_hook.time")
+ def test_sleeps_for_correct_interval(self, mocked_time):
+ sleep_time = 10
+ # To break out of the otherwise infinite tries
+ mocked_time.sleep = mock.Mock(side_effect=ValueError, return_value=3)
+ conn_mock = mock.Mock()
+ mock_response = mock.Mock()
+ mock_response.headers.get.return_value = sleep_time
+ conn_mock.call = mock.Mock(
+ side_effect=RateLimitError(msg="some message", code="some code",
+ response=mock_response))
+
+ zendesk_hook = ZendeskHook("conn_id")
+ zendesk_hook.get_conn = mock.Mock(return_value=conn_mock)
+
+ with self.assertRaises(ValueError):
+ zendesk_hook.call("some_path", get_all_pages=False)
+ mocked_time.sleep.assert_called_with(sleep_time)
+
+ @mock.patch("airflow.hooks.zendesk_hook.Zendesk")
+ def test_returns_single_page_if_get_all_pages_false(self, _):
+ zendesk_hook = ZendeskHook("conn_id")
+ mock_connection = mock.Mock()
+ mock_connection.host = "some_host"
+ zendesk_hook.get_connection = mock.Mock(return_value=mock_connection)
+ zendesk_hook.get_conn()
+
+ mock_conn = mock.Mock()
+ mock_call = mock.Mock(
+ return_value={'next_page': 'https://some_host/something', 'path':
+ []})
+ mock_conn.call = mock_call
+ zendesk_hook.get_conn = mock.Mock(return_value=mock_conn)
+ zendesk_hook.call("path", get_all_pages=False)
+ mock_call.assert_called_once_with("path", None)
+
+ @mock.patch("airflow.hooks.zendesk_hook.Zendesk")
+ def test_returns_multiple_pages_if_get_all_pages_true(self, _):
+ zendesk_hook = ZendeskHook("conn_id")
+ mock_connection = mock.Mock()
+ mock_connection.host = "some_host"
+ zendesk_hook.get_connection = mock.Mock(return_value=mock_connection)
+ zendesk_hook.get_conn()
+
+ mock_conn = mock.Mock()
+ mock_call = mock.Mock(
+ return_value={'next_page': 'https://some_host/something', 'path': []})
+ mock_conn.call = mock_call
+ zendesk_hook.get_conn = mock.Mock(return_value=mock_conn)
+ zendesk_hook.call("path", get_all_pages=True)
+ assert mock_call.call_count == 2
+
+ @mock.patch("airflow.hooks.zendesk_hook.Zendesk")
+ def test_zdesk_is_inited_correctly(self, mock_zendesk):
+ conn_mock = mock.Mock()
+ conn_mock.host = "conn_host"
+ conn_mock.login = "conn_login"
+ conn_mock.password = "conn_pass"
+
+ zendesk_hook = ZendeskHook("conn_id")
+ zendesk_hook.get_connection = mock.Mock(return_value=conn_mock)
+ zendesk_hook.get_conn()
+ mock_zendesk.assert_called_with('https://conn_host', 'conn_login',
+ 'conn_pass', True)
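The first test above pins down a retry loop worth spelling out: on a
RateLimitError the hook sleeps for the interval reported by the API and
then calls again. A simplified sketch of that loop (the 'Retry-After'
header name is an assumption for illustration; the test only mocks
response.headers.get):

    import time

    from zdesk import RateLimitError

    def call_with_retry(conn, path):
        while True:
            try:
                return conn.call(path)
            except RateLimitError as err:
                # Header name assumed for illustration.
                wait = float(err.response.headers.get('Retry-After', 60))
                time.sleep(wait)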