Posted to commits@airflow.apache.org by bo...@apache.org on 2017/03/13 04:45:10 UTC
[12/45] incubator-airflow git commit: [AIRFLOW-793] Enable compressed loading in S3ToHiveTransfer
[AIRFLOW-793] Enable compressed loading in S3ToHiveTransfer
Testing Done:
- Added new unit tests for the S3ToHiveTransfer module
Closes #2012 from krishnabhupatiraju/S3ToHiveTransfer_compress_loading
(cherry picked from commit ad15f5efd6c663bd5f0c8cd3f556d08182cc778c)
Signed-off-by: Bolke de Bruin <bo...@xs4all.nl>
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/1c231333
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/1c231333
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/1c231333
Branch: refs/heads/v1-8-stable
Commit: 1c2313338a586aae4a7752c3fb3b9de4e3564415
Parents: 3658bf3
Author: Krishna Bhupatiraju <kr...@airbnb.com>
Authored: Mon Feb 6 16:52:11 2017 -0800
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Sat Feb 18 15:56:37 2017 +0100
----------------------------------------------------------------------
airflow/operators/s3_to_hive_operator.py | 151 ++++++++++++----
airflow/utils/compression.py             |  38 ++++
tests/operators/__init__.py              |   1 +
tests/operators/s3_to_hive_operator.py   | 247 ++++++++++++++++++++++++++
tests/utils/compression.py               |  97 ++++++++++
5 files changed, 497 insertions(+), 37 deletions(-)
----------------------------------------------------------------------
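For context, here is a hedged usage sketch of the new input_compressed flag. The DAG id, S3 key, table, and field names below are illustrative, not part of this commit:

from collections import OrderedDict
from datetime import datetime

from airflow import DAG
from airflow.operators.s3_to_hive_operator import S3ToHiveTransfer

dag = DAG('s3_to_hive_example', start_date=datetime(2017, 1, 1))

load_events = S3ToHiveTransfer(
    task_id='load_gz_into_hive',
    s3_key='data/events.tsv.gz',  # hypothetical gz-compressed source file
    field_dict=OrderedDict([('id', 'BIGINT'), ('name', 'STRING')]),
    hive_table='events',
    delimiter='\t',
    headers=True,           # first row holds column names
    check_headers=True,     # validate the header against field_dict
    input_compressed=True,  # new in AIRFLOW-793: decompress before loading
    dag=dag)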
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1c231333/airflow/operators/s3_to_hive_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/s3_to_hive_operator.py b/airflow/operators/s3_to_hive_operator.py
index 3e01c29..92340f8 100644
--- a/airflow/operators/s3_to_hive_operator.py
+++ b/airflow/operators/s3_to_hive_operator.py
@@ -16,13 +16,18 @@ from builtins import next
from builtins import zip
import logging
from tempfile import NamedTemporaryFile
+from airflow.utils.file import TemporaryDirectory
+import gzip
+import bz2
+import tempfile
+import os
from airflow.exceptions import AirflowException
from airflow.hooks.S3_hook import S3Hook
from airflow.hooks.hive_hooks import HiveCliHook
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
-
+from airflow.utils.compression import uncompress_file
class S3ToHiveTransfer(BaseOperator):
"""
@@ -68,8 +73,11 @@ class S3ToHiveTransfer(BaseOperator):
:type delimiter: str
:param s3_conn_id: source s3 connection
:type s3_conn_id: str
- :param hive_conn_id: destination hive connection
- :type hive_conn_id: str
+ :param hive_cli_conn_id: destination hive connection
+ :type hive_cli_conn_id: str
+ :param input_compressed: Boolean to determine if file decompression is
+ required to process headers
+ :type input_compressed: bool
"""
template_fields = ('s3_key', 'partition', 'hive_table')
@@ -91,6 +99,7 @@ class S3ToHiveTransfer(BaseOperator):
wildcard_match=False,
s3_conn_id='s3_default',
hive_cli_conn_id='hive_cli_default',
+ input_compressed=False,
*args, **kwargs):
super(S3ToHiveTransfer, self).__init__(*args, **kwargs)
self.s3_key = s3_key
@@ -105,28 +114,41 @@ class S3ToHiveTransfer(BaseOperator):
self.wildcard_match = wildcard_match
self.hive_cli_conn_id = hive_cli_conn_id
self.s3_conn_id = s3_conn_id
+ self.input_compressed = input_compressed
+
+ if (self.check_headers and
+ not (self.field_dict is not None and self.headers)):
+ raise AirflowException("To check_headers provide " +
+ "field_dict and headers")
def execute(self, context):
- self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
+ # Downloading file from S3
self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
+ self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
logging.info("Downloading S3 file")
+
if self.wildcard_match:
if not self.s3.check_for_wildcard_key(self.s3_key):
- raise AirflowException("No key matches {0}".format(self.s3_key))
+ raise AirflowException("No key matches {0}"
+ .format(self.s3_key))
s3_key_object = self.s3.get_wildcard_key(self.s3_key)
else:
if not self.s3.check_for_key(self.s3_key):
raise AirflowException(
"The key {0} does not exist".format(self.s3_key))
s3_key_object = self.s3.get_key(self.s3_key)
- with NamedTemporaryFile("w") as f:
+ root, file_ext = os.path.splitext(s3_key_object.key)
+ with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
+ NamedTemporaryFile(mode="w",
+ dir=tmp_dir,
+ suffix=file_ext) as f:
logging.info("Dumping S3 key {0} contents to local"
" file {1}".format(s3_key_object.key, f.name))
s3_key_object.get_contents_to_file(f)
f.flush()
self.s3.connection.close()
if not self.headers:
- logging.info("Loading file into Hive")
+ logging.info("Loading file {0} into Hive".format(f.name))
self.hive.load_file(
f.name,
self.hive_table,
@@ -136,33 +158,88 @@ class S3ToHiveTransfer(BaseOperator):
delimiter=self.delimiter,
recreate=self.recreate)
else:
- with open(f.name, 'r') as tmpf:
- if self.check_headers:
- header_l = tmpf.readline()
- header_line = header_l.rstrip()
- header_list = header_line.split(self.delimiter)
- field_names = list(self.field_dict.keys())
- test_field_match = [h1.lower() == h2.lower() for h1, h2
- in zip(header_list, field_names)]
- if not all(test_field_match):
- logging.warning("Headers do not match field names"
- "File headers:\n {header_list}\n"
- "Field names: \n {field_names}\n"
- "".format(**locals()))
- raise AirflowException("Headers do not match the "
- "field_dict keys")
- with NamedTemporaryFile("w") as f_no_headers:
- tmpf.seek(0)
- next(tmpf)
- for line in tmpf:
- f_no_headers.write(line)
- f_no_headers.flush()
- logging.info("Loading file without headers into Hive")
- self.hive.load_file(
- f_no_headers.name,
- self.hive_table,
- field_dict=self.field_dict,
- create=self.create,
- partition=self.partition,
- delimiter=self.delimiter,
- recreate=self.recreate)
+ # Decompressing file
+ if self.input_compressed:
+ logging.info("Uncompressing file {0}".format(f.name))
+ fn_uncompressed = uncompress_file(f.name,
+ file_ext,
+ tmp_dir)
+ logging.info("Uncompressed to {0}".format(fn_uncompressed))
+ # uncompressed file available now so deleting
+ # compressed file to save disk space
+ f.close()
+ else:
+ fn_uncompressed = f.name
+
+ # Testing if header matches field_dict
+ if self.check_headers:
+ logging.info("Matching file header against field_dict")
+ header_list = self._get_top_row_as_list(fn_uncompressed)
+ if not self._match_headers(header_list):
+ raise AirflowException("Header check failed")
+
+ # Deleting top header row
+ logging.info("Removing header from file {0}".
+ format(fn_uncompressed))
+ headless_file = (
+ self._delete_top_row_and_compress(fn_uncompressed,
+ file_ext,
+ tmp_dir))
+ logging.info("Headless file {0}".format(headless_file))
+ logging.info("Loading file {0} into Hive".format(headless_file))
+ self.hive.load_file(headless_file,
+ self.hive_table,
+ field_dict=self.field_dict,
+ create=self.create,
+ partition=self.partition,
+ delimiter=self.delimiter,
+ recreate=self.recreate)
+
+ def _get_top_row_as_list(self, file_name):
+ with open(file_name, 'rt') as f:
+ header_line = f.readline().strip()
+ header_list = header_line.split(self.delimiter)
+ return header_list
+
+ def _match_headers(self, header_list):
+ if not header_list:
+ raise AirflowException("Unable to retrieve header row from file")
+ field_names = self.field_dict.keys()
+ if len(field_names) != len(header_list):
+ logging.warning("Headers count mismatch\n"
+ "File headers:\n {header_list}\n"
+ "Field names: \n {field_names}\n"
+ "".format(**locals()))
+ return False
+ test_field_match = [h1.lower() == h2.lower()
+ for h1, h2 in zip(header_list, field_names)]
+ if not all(test_field_match):
+ logging.warning("Headers do not match field names\n"
+ "File headers:\n {header_list}\n"
+ "Field names: \n {field_names}\n"
+ "".format(**locals()))
+ return False
+ else:
+ return True
+
+ def _delete_top_row_and_compress(
+ self,
+ input_file_name,
+ output_file_ext,
+ dest_dir):
+ # When output_file_ext is not defined, file is not compressed
+ open_fn = open
+ if output_file_ext.lower() == '.gz':
+ open_fn = gzip.GzipFile
+ elif output_file_ext.lower() == '.bz2':
+ open_fn = bz2.BZ2File
+
+ os_fh_output, fn_output = \
+ tempfile.mkstemp(suffix=output_file_ext, dir=dest_dir)
+ with open(input_file_name, 'rb') as f_in,\
+ open_fn(fn_output, 'wb') as f_out:
+ f_in.seek(0)
+ next(f_in)
+ for line in f_in:
+ f_out.write(line)
+ return fn_output
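A note on the new header check above: _match_headers compares the field_dict keys to the file's first row positionally and case-insensitively, failing fast on a length mismatch. A standalone sketch of that same logic, with illustrative sample values:

header_list = ['Sno', 'Some,Text']   # first row of the data file
field_names = ['sno', 'some,text']   # keys of field_dict

# Counts must agree first, then each pair must match ignoring case.
matches = (len(header_list) == len(field_names) and
           all(h.lower() == f.lower()
               for h, f in zip(header_list, field_names)))
assert matches  # order and count must both agree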
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1c231333/airflow/utils/compression.py
----------------------------------------------------------------------
diff --git a/airflow/utils/compression.py b/airflow/utils/compression.py
new file mode 100644
index 0000000..9d0785f
--- /dev/null
+++ b/airflow/utils/compression.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tempfile import NamedTemporaryFile
+import shutil
+import gzip
+import bz2
+
+
+def uncompress_file(input_file_name, file_extension, dest_dir):
+ """
+ Uncompress gz and bz2 files
+ """
+ if file_extension.lower() not in ('.gz', '.bz2'):
+ raise NotImplementedError("Received {} format. Only gz and bz2 "
+ "files can currently be uncompressed."
+ .format(file_extension))
+ if file_extension.lower() == '.gz':
+ fmodule = gzip.GzipFile
+ elif file_extension.lower() == '.bz2':
+ fmodule = bz2.BZ2File
+ with fmodule(input_file_name, mode='rb') as f_compressed,\
+ NamedTemporaryFile(dir=dest_dir,
+ mode='wb',
+ delete=False) as f_uncompressed:
+ shutil.copyfileobj(f_compressed, f_uncompressed)
+ return f_uncompressed.name
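A quick usage sketch of the new helper; the paths below are illustrative:

from airflow.utils.compression import uncompress_file

# Copies the decompressed bytes into a NamedTemporaryFile under /tmp and
# returns that file's path. Passing '.txt' instead of '.gz' or '.bz2'
# raises NotImplementedError.
plain_path = uncompress_file('/tmp/events.tsv.gz', '.gz', '/tmp')
print(plain_path)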
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1c231333/tests/operators/__init__.py
----------------------------------------------------------------------
diff --git a/tests/operators/__init__.py b/tests/operators/__init__.py
index 63ff2a0..1fb0e5e 100644
--- a/tests/operators/__init__.py
+++ b/tests/operators/__init__.py
@@ -17,3 +17,4 @@ from .subdag_operator import *
from .operators import *
from .sensors import *
from .hive_operator import *
+from .s3_to_hive_operator import *
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1c231333/tests/operators/s3_to_hive_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/s3_to_hive_operator.py b/tests/operators/s3_to_hive_operator.py
new file mode 100644
index 0000000..faab11e
--- /dev/null
+++ b/tests/operators/s3_to_hive_operator.py
@@ -0,0 +1,247 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+try:
+ from unittest import mock
+except ImportError:
+ try:
+ import mock
+ except ImportError:
+ mock = None
+import logging
+from itertools import product
+from airflow.operators.s3_to_hive_operator import S3ToHiveTransfer
+from collections import OrderedDict
+from airflow.exceptions import AirflowException
+from tempfile import NamedTemporaryFile, mkdtemp
+import gzip
+import bz2
+import shutil
+import filecmp
+import errno
+
+
+class S3ToHiveTransferTest(unittest.TestCase):
+
+ def setUp(self):
+ self.fn = {}
+ self.task_id = 'S3ToHiveTransferTest'
+ self.s3_key = 'S32hive_test_file'
+ self.field_dict = OrderedDict([('Sno', 'BIGINT'), ('Some,Text', 'STRING')])
+ self.hive_table = 'S32hive_test_table'
+ self.delimiter = '\t'
+ self.create = True
+ self.recreate = True
+ self.partition = {'ds': 'STRING'}
+ self.headers = True
+ self.check_headers = True
+ self.wildcard_match = False
+ self.input_compressed = False
+ self.kwargs = {'task_id': self.task_id,
+ 's3_key': self.s3_key,
+ 'field_dict': self.field_dict,
+ 'hive_table': self.hive_table,
+ 'delimiter': self.delimiter,
+ 'create': self.create,
+ 'recreate': self.recreate,
+ 'partition': self.partition,
+ 'headers': self.headers,
+ 'check_headers': self.check_headers,
+ 'wildcard_match': self.wildcard_match,
+ 'input_compressed': self.input_compressed
+ }
+ try:
+ header = "Sno\tSome,Text \n".encode()
+ line1 = "1\tAirflow Test\n".encode()
+ line2 = "2\tS32HiveTransfer\n".encode()
+ self.tmp_dir = mkdtemp(prefix='test_tmps32hive_')
+ # create sample txt, gz and bz2 with and without headers
+ with NamedTemporaryFile(mode='wb+',
+ dir=self.tmp_dir,
+ delete=False) as f_txt_h:
+ self._set_fn(f_txt_h.name, '.txt', True)
+ f_txt_h.writelines([header, line1, line2])
+ fn_gz = self._get_fn('.txt', True) + ".gz"
+ with gzip.GzipFile(filename=fn_gz,
+ mode="wb") as f_gz_h:
+ self._set_fn(fn_gz, '.gz', True)
+ f_gz_h.writelines([header, line1, line2])
+ fn_bz2 = self._get_fn('.txt', True) + '.bz2'
+ with bz2.BZ2File(filename=fn_bz2,
+ mode="wb") as f_bz2_h:
+ self._set_fn(fn_bz2, '.bz2', True)
+ f_bz2_h.writelines([header, line1, line2])
+ # create sample txt, bz and bz2 without header
+ with NamedTemporaryFile(mode='wb+',
+ dir=self.tmp_dir,
+ delete=False) as f_txt_nh:
+ self._set_fn(f_txt_nh.name, '.txt', False)
+ f_txt_nh.writelines([line1, line2])
+ fn_gz = self._get_fn('.txt', False) + ".gz"
+ with gzip.GzipFile(filename=fn_gz,
+ mode="wb") as f_gz_nh:
+ self._set_fn(fn_gz, '.gz', False)
+ f_gz_nh.writelines([line1, line2])
+ fn_bz2 = self._get_fn('.txt', False) + '.bz2'
+ with bz2.BZ2File(filename=fn_bz2,
+ mode="wb") as f_bz2_nh:
+ self._set_fn(fn_bz2, '.bz2', False)
+ f_bz2_nh.writelines([line1, line2])
+ # Catch BaseException so that KeyboardInterrupt also triggers cleanup
+ except BaseException as e:
+ logging.error(e)
+ self.tearDown()
+
+ def tearDown(self):
+ try:
+ shutil.rmtree(self.tmp_dir)
+ except OSError as e:
+ # ENOENT - no such file or directory
+ if e.errno != errno.ENOENT:
+ raise e
+
+ # Helper method to create a dictionary of file names and
+ # file types (file extension and header)
+ def _set_fn(self, fn, ext, header):
+ key = self._get_key(ext, header)
+ self.fn[key] = fn
+
+ # Helper method to fetch a file of a
+ # certain format (file extension and header)
+ def _get_fn(self, ext, header):
+ key = self._get_key(ext, header)
+ return self.fn[key]
+
+ def _get_key(self, ext, header):
+ key = ext + "_" + ('h' if header else 'nh')
+ return key
+
+ def _cp_file_contents(self, fn_src, fn_dest):
+ with open(fn_src, 'rb') as f_src, open(fn_dest, 'wb') as f_dest:
+ shutil.copyfileobj(f_src, f_dest)
+
+ def _check_file_equality(self, fn_1, fn_2, ext):
+ # gz files contain mtime and filename in the header that
+ # causes filecmp to return False even if contents are identical
+ # Hence decompress to test for equality
+ if ext == '.gz':
+ with gzip.GzipFile(fn_1, 'rb') as f_1,\
+ NamedTemporaryFile(mode='wb') as f_txt_1,\
+ gzip.GzipFile(fn_2, 'rb') as f_2,\
+ NamedTemporaryFile(mode='wb') as f_txt_2:
+ shutil.copyfileobj(f_1, f_txt_1)
+ shutil.copyfileobj(f_2, f_txt_2)
+ f_txt_1.flush()
+ f_txt_2.flush()
+ return filecmp.cmp(f_txt_1.name, f_txt_2.name, shallow=False)
+ else:
+ return filecmp.cmp(fn_1, fn_2, shallow=False)
+
+ def test_bad_parameters(self):
+ self.kwargs['check_headers'] = True
+ self.kwargs['headers'] = False
+ self.assertRaisesRegexp(AirflowException,
+ "To check_headers.*",
+ S3ToHiveTransfer,
+ **self.kwargs)
+
+ def test__get_top_row_as_list(self):
+ self.kwargs['delimiter'] = '\t'
+ fn_txt = self._get_fn('.txt', True)
+ header_list = S3ToHiveTransfer(**self.kwargs).\
+ _get_top_row_as_list(fn_txt)
+ self.assertEqual(header_list, ['Sno', 'Some,Text'],
+ msg="Top row from file doesn't match expected value")
+
+ self.kwargs['delimiter'] = ','
+ header_list = S3ToHiveTransfer(**self.kwargs).\
+ _get_top_row_as_list(fn_txt)
+ self.assertEqual(header_list, ['Sno\tSome', 'Text'],
+ msg="Top row from file doesn't match expected value")
+
+ def test__match_headers(self):
+ self.kwargs['field_dict'] = OrderedDict([('Sno', 'BIGINT'),
+ ('Some,Text', 'STRING')])
+ self.assertTrue(S3ToHiveTransfer(**self.kwargs).
+ _match_headers(['Sno', 'Some,Text']),
+ msg="Header row doesn't match expected value")
+ # Testing with different column order
+ self.assertFalse(S3ToHiveTransfer(**self.kwargs).
+ _match_headers(['Some,Text', 'Sno']),
+ msg="Header row with different column order should not match")
+ # Testing with extra column in header
+ self.assertFalse(S3ToHiveTransfer(**self.kwargs).
+ _match_headers(['Sno', 'Some,Text', 'ExtraColumn']),
+ msg="Header row with an extra column should not match")
+
+ def test__delete_top_row_and_compress(self):
+ s32hive = S3ToHiveTransfer(**self.kwargs)
+ # Testing gz file type
+ fn_txt = self._get_fn('.txt', True)
+ gz_txt_nh = s32hive._delete_top_row_and_compress(fn_txt,
+ '.gz',
+ self.tmp_dir)
+ fn_gz = self._get_fn('.gz', False)
+ self.assertTrue(self._check_file_equality(gz_txt_nh, fn_gz, '.gz'),
+ msg="gz Compressed file not as expected")
+ # Testing bz2 file type
+ bz2_txt_nh = s32hive._delete_top_row_and_compress(fn_txt,
+ '.bz2',
+ self.tmp_dir)
+ fn_bz2 = self._get_fn('.bz2', False)
+ self.assertTrue(self._check_file_equality(bz2_txt_nh, fn_bz2, '.bz2'),
+ msg="bz2 Compressed file not as expected")
+
+ @unittest.skipIf(mock is None, 'mock package not present')
+ @mock.patch('airflow.operators.s3_to_hive_operator.HiveCliHook')
+ @mock.patch('airflow.operators.s3_to_hive_operator.S3Hook')
+ def test_execute(self, mock_s3hook, mock_hiveclihook):
+ # Testing txt, gz, bz2 files with and without header row
+ for test in product(['.txt', '.gz', '.bz2'], [True, False]):
+ ext = test[0]
+ has_header = test[1]
+ self.kwargs['headers'] = has_header
+ self.kwargs['check_headers'] = has_header
+ logging.info("Testing {0} format {1} header".
+ format(ext,
+ ('with' if has_header else 'without'))
+ )
+ self.kwargs['input_compressed'] = (ext != '.txt')
+ self.kwargs['s3_key'] = self.s3_key + ext
+ ip_fn = self._get_fn(ext, self.kwargs['headers'])
+ op_fn = self._get_fn(ext, False)
+ # Mock s3 object returned by S3Hook
+ mock_s3_object = mock.Mock(key=self.kwargs['s3_key'])
+ mock_s3_object.get_contents_to_file.side_effect = \
+ lambda dest_file: \
+ self._cp_file_contents(ip_fn, dest_file.name)
+ mock_s3hook().get_key.return_value = mock_s3_object
+ # file parameter to HiveCliHook.load_file is compared
+ # against expected file output
+ mock_hiveclihook().load_file.side_effect = \
+ lambda *args, **kwargs: \
+ self.assertTrue(
+ self._check_file_equality(args[0],
+ op_fn,
+ ext
+ ),
+ msg='{0} output file not as expected'.format(ext))
+ # Execute S3ToHiveTransfer
+ s32hive = S3ToHiveTransfer(**self.kwargs)
+ s32hive.execute(None)
+
+
+if __name__ == '__main__':
+ unittest.main()
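The decompress-before-compare step in _check_file_equality above is needed because the gzip header embeds a modification time, so two archives of identical payloads can differ byte-for-byte. A minimal standard-library demonstration:

import gzip
import io

buf1, buf2 = io.BytesIO(), io.BytesIO()
with gzip.GzipFile(fileobj=buf1, mode='wb', mtime=0) as f:
    f.write(b'same content')
with gzip.GzipFile(fileobj=buf2, mode='wb', mtime=1) as f:
    f.write(b'same content')

# Same payload, different headers (bytes 4-7 of a gzip stream hold mtime).
assert buf1.getvalue() != buf2.getvalue()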
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1c231333/tests/utils/compression.py
----------------------------------------------------------------------
diff --git a/tests/utils/compression.py b/tests/utils/compression.py
new file mode 100644
index 0000000..f8e0ebb
--- /dev/null
+++ b/tests/utils/compression.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from airflow.utils import compression
+import unittest
+from tempfile import NamedTemporaryFile, mkdtemp
+import bz2
+import gzip
+import shutil
+import logging
+import errno
+import filecmp
+
+
+class Compression(unittest.TestCase):
+
+ def setUp(self):
+ self.fn = {}
+ try:
+ header = "Sno\tSome,Text \n".encode()
+ line1 = "1\tAirflow Test\n".encode()
+ line2 = "2\tCompressionUtil\n".encode()
+ self.tmp_dir = mkdtemp(prefix='test_utils_compression_')
+ # create sample txt, gz and bz2 files
+ with NamedTemporaryFile(mode='wb+',
+ dir=self.tmp_dir,
+ delete=False) as f_txt:
+ self._set_fn(f_txt.name, '.txt')
+ f_txt.writelines([header, line1, line2])
+ fn_gz = self._get_fn('.txt') + ".gz"
+ with gzip.GzipFile(filename=fn_gz,
+ mode="wb") as f_gz:
+ self._set_fn(fn_gz, '.gz')
+ f_gz.writelines([header, line1, line2])
+ fn_bz2 = self._get_fn('.txt') + '.bz2'
+ with bz2.BZ2File(filename=fn_bz2,
+ mode="wb") as f_bz2:
+ self._set_fn(fn_bz2, '.bz2')
+ f_bz2.writelines([header, line1, line2])
+ # Catch BaseException so that KeyboardInterrupt also triggers cleanup
+ except BaseException as e:
+ logging.error(e)
+ self.tearDown()
+
+ def tearDown(self):
+ try:
+ shutil.rmtree(self.tmp_dir)
+ except OSError as e:
+ # ENOENT - no such file or directory
+ if e.errno != errno.ENOENT:
+ raise e
+
+ # Helper method to create a dictionary of file names and
+ # file extension
+ def _set_fn(self, fn, ext):
+ self.fn[ext] = fn
+
+ # Helper method to fetch a file of a
+ # certain extension
+ def _get_fn(self, ext):
+ return self.fn[ext]
+
+ def test_uncompress_file(self):
+ # Testing txt file type
+ self.assertRaisesRegexp(NotImplementedError,
+ "^Received .txt format. Only gz and bz2.*",
+ compression.uncompress_file,
+ **{'input_file_name': None,
+ 'file_extension': '.txt',
+ 'dest_dir': None
+ })
+ # Testing gz file type
+ fn_txt = self._get_fn('.txt')
+ fn_gz = self._get_fn('.gz')
+ txt_gz = compression.uncompress_file(fn_gz, '.gz', self.tmp_dir)
+ self.assertTrue(filecmp.cmp(txt_gz, fn_txt, shallow=False),
+ msg="Uncompressed file doesn't match original")
+ # Testing bz2 file type
+ fn_bz2 = self._get_fn('.bz2')
+ txt_bz2 = compression.uncompress_file(fn_bz2, '.bz2', self.tmp_dir)
+ self.assertTrue(filecmp.cmp(txt_bz2, fn_txt, shallow=False),
+ msg="Uncompressed file doesn't match original")
+
+
+if __name__ == '__main__':
+ unittest.main()