You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2018/08/06 09:23:19 UTC

[GitHub] Fokko closed pull request #3690: [AIRFLOW-2845] Remove asserts from the contrib package

Fokko closed pull request #3690: [AIRFLOW-2845] Remove asserts from the contrib package
URL: https://github.com/apache/incubator-airflow/pull/3690
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request came from a fork, GitHub no longer displays
its diff after the merge, so the full diff is reproduced below:

diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py
index f4c1a3b520..aa8fc382a6 100644
--- a/airflow/contrib/hooks/bigquery_hook.py
+++ b/airflow/contrib/hooks/bigquery_hook.py
@@ -592,9 +592,11 @@ def run_query(self,
         }
 
         if destination_dataset_table:
-            assert '.' in destination_dataset_table, (
-                'Expected destination_dataset_table in the format of '
-                '<dataset>.<table>. Got: {}').format(destination_dataset_table)
+            if '.' not in destination_dataset_table:
+                raise ValueError(
+                    'Expected destination_dataset_table name in the format of '
+                    '<dataset>.<table>. Got: {}'.format(
+                        destination_dataset_table))
             destination_project, destination_dataset, destination_table = \
                 _split_tablename(table_input=destination_dataset_table,
                                  default_project_id=self.project_id)
@@ -610,7 +612,9 @@ def run_query(self,
                 }
             })
         if udf_config:
-            assert isinstance(udf_config, list)
+            if not isinstance(udf_config, list):
+                raise TypeError("udf_config argument must have a type 'list'"
+                                " not {}".format(type(udf_config)))
             configuration['query'].update({
                 'userDefinedFunctionResources': udf_config
             })
@@ -1153,10 +1157,10 @@ def run_table_delete(self, deletion_dataset_table,
         :type ignore_if_missing: boolean
         :return:
         """
-
-        assert '.' in deletion_dataset_table, (
-            'Expected deletion_dataset_table in the format of '
-            '<dataset>.<table>. Got: {}').format(deletion_dataset_table)
+        if '.' not in deletion_dataset_table:
+            raise ValueError(
+                'Expected deletion_dataset_table name in the format of '
+                '<dataset>.<table>. Got: {}'.format(deletion_dataset_table))
         deletion_project, deletion_dataset, deletion_table = \
             _split_tablename(table_input=deletion_dataset_table,
                              default_project_id=self.project_id)
@@ -1284,14 +1288,10 @@ def run_grant_dataset_view_access(self,
             # if view is already in access, do nothing.
             self.log.info(
                 'Table %s:%s.%s already has authorized view access to %s:%s dataset.',
-                view_project, view_dataset, view_table, source_project,
-                source_dataset)
+                view_project, view_dataset, view_table, source_project, source_dataset)
             return source_dataset_resource
 
-    def delete_dataset(self,
-                       project_id,
-                       dataset_id
-                       ):
+    def delete_dataset(self, project_id, dataset_id):
         """
         Delete a dataset of Big query in your project.
         :param project_id: The name of the project where we have the dataset .
@@ -1308,9 +1308,8 @@ def delete_dataset(self,
             self.service.datasets().delete(
                 projectId=project_id,
                 datasetId=dataset_id).execute()
-
-            self.log.info('Dataset deleted successfully: In project %s Dataset %s',
-                          project_id, dataset_id)
+            self.log.info('Dataset deleted successfully: In project %s '
+                          'Dataset %s', project_id, dataset_id)
 
         except HttpError as err:
             raise AirflowException(
@@ -1518,14 +1517,17 @@ def _bq_cast(string_field, bq_type):
     elif bq_type == 'FLOAT' or bq_type == 'TIMESTAMP':
         return float(string_field)
     elif bq_type == 'BOOLEAN':
-        assert string_field in set(['true', 'false'])
+        if string_field not in ['true', 'false']:
+            raise ValueError("{} must have value 'true' or 'false'".format(
+                string_field))
         return string_field == 'true'
     else:
         return string_field
 
 
 def _split_tablename(table_input, default_project_id, var_name=None):
-    assert default_project_id is not None, "INTERNAL: No default project is specified"
+    if not default_project_id:
+        raise ValueError("INTERNAL: No default project is specified")
 
     def var_print(var_name):
         if var_name is None:
@@ -1537,7 +1539,6 @@ def var_print(var_name):
         raise Exception(('{var}Use either : or . to specify project '
                          'got {input}').format(
                              var=var_print(var_name), input=table_input))
-
     cmpt = table_input.rsplit(':', 1)
     project_id = None
     rest = table_input
@@ -1555,8 +1556,10 @@ def var_print(var_name):
 
     cmpt = rest.split('.')
     if len(cmpt) == 3:
-        assert project_id is None, ("{var}Use either : or . to specify project"
-                                    ).format(var=var_print(var_name))
+        if project_id:
+            raise ValueError(
+                "{var}Use either : or . to specify project".format(
+                    var=var_print(var_name)))
         project_id = cmpt[0]
         dataset_id = cmpt[1]
         table_id = cmpt[2]
@@ -1586,10 +1589,10 @@ def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in):
     # if it is a partitioned table ($ is in the table name) add partition load option
     time_partitioning_out = {}
     if destination_dataset_table and '$' in destination_dataset_table:
-        assert not time_partitioning_in.get('field'), (
-            "Cannot specify field partition and partition name "
-            "(dataset.table$partition) at the same time"
-        )
+        if time_partitioning_in.get('field'):
+            raise ValueError(
+                "Cannot specify field partition and partition name"
+                "(dataset.table$partition) at the same time")
         time_partitioning_out['type'] = 'DAY'
 
     time_partitioning_out.update(time_partitioning_in)
diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py
index 1443ff4740..2e5f1399b4 100644
--- a/airflow/contrib/hooks/databricks_hook.py
+++ b/airflow/contrib/hooks/databricks_hook.py
@@ -61,7 +61,8 @@ def __init__(
         self.databricks_conn_id = databricks_conn_id
         self.databricks_conn = self.get_connection(databricks_conn_id)
         self.timeout_seconds = timeout_seconds
-        assert retry_limit >= 1, 'Retry limit must be greater than equal to 1'
+        if retry_limit < 1:
+            raise ValueError('Retry limit must be greater than equal to 1')
         self.retry_limit = retry_limit
 
     def _parse_host(self, host):
diff --git a/airflow/contrib/hooks/gcp_dataflow_hook.py b/airflow/contrib/hooks/gcp_dataflow_hook.py
index 279b9dd21a..ee3b510ed7 100644
--- a/airflow/contrib/hooks/gcp_dataflow_hook.py
+++ b/airflow/contrib/hooks/gcp_dataflow_hook.py
@@ -225,11 +225,11 @@ def label_formatter(labels_dict):
     def _build_dataflow_job_name(task_id, append_job_name=True):
         task_id = str(task_id).replace('_', '-')
 
-        assert re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", task_id), \
-            'Invalid job_name ({}); the name must consist of ' \
-            'only the characters [-a-z0-9], starting with a ' \
-            'letter and ending with a letter or number '.format(
-                task_id)
+        if not re.match(r"^[a-z]([-a-z0-9]*[a-z0-9])?$", task_id):
+            raise ValueError(
+                'Invalid job_name ({}); the name must consist of'
+                'only the characters [-a-z0-9], starting with a '
+                'letter and ending with a letter or number '.format(task_id))
 
         if append_job_name:
             job_name = task_id + "-" + str(uuid.uuid1())[:8]
@@ -238,7 +238,8 @@ def _build_dataflow_job_name(task_id, append_job_name=True):
 
         return job_name
 
-    def _build_cmd(self, task_id, variables, label_formatter):
+    @staticmethod
+    def _build_cmd(task_id, variables, label_formatter):
         command = ["--runner=DataflowRunner"]
         if variables is not None:
             for attr, value in variables.items():
@@ -250,7 +251,8 @@ def _build_cmd(self, task_id, variables, label_formatter):
                     command.append("--" + attr + "=" + value)
         return command
 
-    def _start_template_dataflow(self, name, variables, parameters, dataflow_template):
+    def _start_template_dataflow(self, name, variables, parameters,
+                                 dataflow_template):
         # Builds RuntimeEnvironment from variables dictionary
         # https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
         environment = {}
@@ -262,9 +264,11 @@ def _start_template_dataflow(self, name, variables, parameters, dataflow_templat
                 "parameters": parameters,
                 "environment": environment}
         service = self.get_conn()
-        request = service.projects().templates().launch(projectId=variables['project'],
-                                                        gcsPath=dataflow_template,
-                                                        body=body)
+        request = service.projects().templates().launch(
+            projectId=variables['project'],
+            gcsPath=dataflow_template,
+            body=body
+        )
         response = request.execute()
         variables = self._set_variables(variables)
         _DataflowJob(self.get_conn(), variables['project'], name, variables['region'],
diff --git a/airflow/contrib/hooks/gcp_mlengine_hook.py b/airflow/contrib/hooks/gcp_mlengine_hook.py
index 66f392b156..b9f1008fa7 100644
--- a/airflow/contrib/hooks/gcp_mlengine_hook.py
+++ b/airflow/contrib/hooks/gcp_mlengine_hook.py
@@ -152,7 +152,8 @@ def _wait_for_job_done(self, project_id, job_id, interval=30):
             apiclient.errors.HttpError: if HTTP error is returned when getting
             the job
         """
-        assert interval > 0
+        if interval <= 0:
+            raise ValueError("Interval must be > 0")
         while True:
             job = self._get_job(project_id, job_id)
             if job['state'] in ['SUCCEEDED', 'FAILED', 'CANCELLED']:
@@ -242,7 +243,9 @@ def create_model(self, project_id, model):
         """
         Create a Model. Blocks until finished.
         """
-        assert model['name'] is not None and model['name'] is not ''
+        if not model['name']:
+            raise ValueError("Model name must be provided and "
+                             "could not be an empty string")
         project = 'projects/{}'.format(project_id)
 
         request = self._mlengine.projects().models().create(
@@ -253,7 +256,9 @@ def get_model(self, project_id, model_name):
         """
         Gets a Model. Blocks until finished.
         """
-        assert model_name is not None and model_name is not ''
+        if not model_name:
+            raise ValueError("Model name must be provided and "
+                             "it could not be an empty string")
         full_model_name = 'projects/{}/models/{}'.format(
             project_id, model_name)
         request = self._mlengine.projects().models().get(name=full_model_name)
diff --git a/airflow/contrib/hooks/gcs_hook.py b/airflow/contrib/hooks/gcs_hook.py
index c5e356f41c..08d44ce7fa 100644
--- a/airflow/contrib/hooks/gcs_hook.py
+++ b/airflow/contrib/hooks/gcs_hook.py
@@ -477,15 +477,16 @@ def create_bucket(self,
 
         self.log.info('Creating Bucket: %s; Location: %s; Storage Class: %s',
                       bucket_name, location, storage_class)
-        assert storage_class in storage_classes, \
-            'Invalid value ({}) passed to storage_class. Value should be ' \
-            'one of {}'.format(storage_class, storage_classes)
+        if storage_class not in storage_classes:
+            raise ValueError(
+                'Invalid value ({}) passed to storage_class. Value should be '
+                'one of {}'.format(storage_class, storage_classes))
 
-        assert re.match('[a-zA-Z0-9]+', bucket_name[0]), \
-            'Bucket names must start with a number or letter.'
+        if not re.match('[a-zA-Z0-9]+', bucket_name[0]):
+            raise ValueError('Bucket names must start with a number or letter.')
 
-        assert re.match('[a-zA-Z0-9]+', bucket_name[-1]), \
-            'Bucket names must end with a number or letter.'
+        if not re.match('[a-zA-Z0-9]+', bucket_name[-1]):
+            raise ValueError('Bucket names must end with a number or letter.')
 
         service = self.get_conn()
         bucket_resource = {
diff --git a/airflow/contrib/operators/mlengine_operator.py b/airflow/contrib/operators/mlengine_operator.py
index 9fe966d387..8e75b3c608 100644
--- a/airflow/contrib/operators/mlengine_operator.py
+++ b/airflow/contrib/operators/mlengine_operator.py
@@ -427,7 +427,9 @@ def execute(self, context):
             gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
 
         if self._operation == 'create':
-            assert self._version is not None
+            if not self._version:
+                raise ValueError("version attribute of {} could not "
+                                 "be empty".format(self.__class__.__name__))
             return hook.create_version(self._project_id, self._model_name,
                                        self._version)
         elif self._operation == 'set_default':
diff --git a/tests/contrib/hooks/test_bigquery_hook.py b/tests/contrib/hooks/test_bigquery_hook.py
index 2f39bd9bce..39435f0c4e 100644
--- a/tests/contrib/hooks/test_bigquery_hook.py
+++ b/tests/contrib/hooks/test_bigquery_hook.py
@@ -414,7 +414,7 @@ def test_extra_time_partitioning_options(self):
         self.assertEqual(tp_out, expect)
 
     def test_cant_add_dollar_and_field_name(self):
-        with self.assertRaises(AssertionError):
+        with self.assertRaises(ValueError):
             _cleanse_time_partitioning(
                 'test.teast$20170101',
                 {'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000}
diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py
index 6052a6d54f..aca8dd9600 100644
--- a/tests/contrib/hooks/test_databricks_hook.py
+++ b/tests/contrib/hooks/test_databricks_hook.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -110,7 +110,7 @@ def test_parse_host_with_scheme(self):
         self.assertEquals(host, HOST)
 
     def test_init_bad_retry_limit(self):
-        with self.assertRaises(AssertionError):
+        with self.assertRaises(ValueError):
             DatabricksHook(retry_limit = 0)
 
     @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
diff --git a/tests/contrib/hooks/test_gcp_dataflow_hook.py b/tests/contrib/hooks/test_gcp_dataflow_hook.py
index 90714c6ee4..bc7c587135 100644
--- a/tests/contrib/hooks/test_gcp_dataflow_hook.py
+++ b/tests/contrib/hooks/test_gcp_dataflow_hook.py
@@ -195,7 +195,7 @@ def test_invalid_dataflow_job_name(self):
         fixed_name = invalid_job_name.replace(
             '_', '-')
 
-        with self.assertRaises(AssertionError) as e:
+        with self.assertRaises(ValueError) as e:
             self.dataflow_hook._build_dataflow_job_name(
                 task_id=invalid_job_name, append_job_name=False
             )
@@ -222,19 +222,19 @@ def test_dataflow_job_regex_check(self):
         ), 'dfjob1')
 
         self.assertRaises(
-            AssertionError,
+            ValueError,
             self.dataflow_hook._build_dataflow_job_name,
             task_id='1dfjob', append_job_name=False
         )
 
         self.assertRaises(
-            AssertionError,
+            ValueError,
             self.dataflow_hook._build_dataflow_job_name,
             task_id='dfjob@', append_job_name=False
         )
 
         self.assertRaises(
-            AssertionError,
+            ValueError,
             self.dataflow_hook._build_dataflow_job_name,
             task_id='df^jo', append_job_name=False
         )
diff --git a/tests/contrib/hooks/test_gcs_hook.py b/tests/contrib/hooks/test_gcs_hook.py
index fb65938fd9..eedceff1f7 100644
--- a/tests/contrib/hooks/test_gcs_hook.py
+++ b/tests/contrib/hooks/test_gcs_hook.py
@@ -66,14 +66,14 @@ class TestGCSBucket(unittest.TestCase):
     def test_bucket_name_value(self):
 
         bad_start_bucket_name = '/testing123'
-        with self.assertRaises(AssertionError):
+        with self.assertRaises(ValueError):
 
             gcs_hook.GoogleCloudStorageHook().create_bucket(
                 bucket_name=bad_start_bucket_name
             )
 
         bad_end_bucket_name = 'testing123/'
-        with self.assertRaises(AssertionError):
+        with self.assertRaises(ValueError):
             gcs_hook.GoogleCloudStorageHook().create_bucket(
                 bucket_name=bad_end_bucket_name
             )


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services