You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text rendering.)
Posted to commits@beam.apache.org by ro...@apache.org on 2017/01/09 21:14:26 UTC

[2/3] beam git commit: Provided temporary directory management for test cases.

Provided temporary directory management for test cases.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/93e8d19e
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/93e8d19e
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/93e8d19e

Branch: refs/heads/python-sdk
Commit: 93e8d19e32807fb5279ed711f0f06c3123adfb2e
Parents: 88833ba
Author: Younghee Kwon <yo...@gmail.com>
Authored: Mon Jan 9 11:50:57 2017 -0800
Committer: Robert Bradshaw <ro...@google.com>
Committed: Mon Jan 9 13:13:46 2017 -0800

----------------------------------------------------------------------
 sdks/python/apache_beam/io/tfrecordio_test.py | 58 +++++++++++++++-------
 1 file changed, 41 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/93e8d19e/sdks/python/apache_beam/io/tfrecordio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/tfrecordio_test.py b/sdks/python/apache_beam/io/tfrecordio_test.py
index ee287b3..ecd58f5 100644
--- a/sdks/python/apache_beam/io/tfrecordio_test.py
+++ b/sdks/python/apache_beam/io/tfrecordio_test.py
@@ -20,8 +20,10 @@ import cStringIO
 import glob
 import gzip
 import logging
+import os
 import pickle
 import random
+import shutil
 import tempfile
 import unittest
 
@@ -134,7 +136,29 @@ class TestTFRecordUtil(unittest.TestCase):
       self.assertEqual(record, actual)
 
 
-class TestTFRecordSink(unittest.TestCase):
+class _TestCaseWithTempDirCleanUp(unittest.TestCase):
+  """Base class for TestCases that deals with TempDir clean-up.
+
+  Inherited test cases will call self._new_tempdir() to start a temporary dir
+  which will be deleted at the end of the tests (when tearDown() is called).
+  """
+
+  def setUp(self):
+    self._tempdirs = []
+
+  def tearDown(self):
+    for path in self._tempdirs:
+      if os.path.exists(path):
+        shutil.rmtree(path)
+    self._tempdirs = []
+
+  def _new_tempdir(self):
+    result = tempfile.mkdtemp()
+    self._tempdirs.append(result)
+    return result
+
+
+class TestTFRecordSink(_TestCaseWithTempDirCleanUp):
 
   def _write_lines(self, sink, path, lines):
     f = sink.open(path)
@@ -143,7 +167,7 @@ class TestTFRecordSink(unittest.TestCase):
     sink.close(f)
 
   def test_write_record_single(self):
-    path = tempfile.NamedTemporaryFile().name
+    path = os.path.join(self._new_tempdir(), 'result')
     record = binascii.a2b_base64(FOO_RECORD_BASE64)
     sink = _TFRecordSink(
         path,
@@ -158,7 +182,7 @@ class TestTFRecordSink(unittest.TestCase):
       self.assertEqual(f.read(), record)
 
   def test_write_record_multiple(self):
-    path = tempfile.NamedTemporaryFile().name
+    path = os.path.join(self._new_tempdir(), 'result')
     record = binascii.a2b_base64(FOO_BAR_RECORD_BASE64)
     sink = _TFRecordSink(
         path,
@@ -177,8 +201,8 @@ class TestTFRecordSink(unittest.TestCase):
 class TestWriteToTFRecord(TestTFRecordSink):
 
   def test_write_record_gzip(self):
+    file_path_prefix = os.path.join(self._new_tempdir(), 'result')
     with beam.Pipeline(DirectRunner()) as p:
-      file_path_prefix = tempfile.NamedTemporaryFile().name
       input_data = ['foo', 'bar']
       _ = p | beam.Create(input_data) | WriteToTFRecord(
           file_path_prefix, compression_type=fileio.CompressionTypes.GZIP)
@@ -192,8 +216,8 @@ class TestWriteToTFRecord(TestTFRecordSink):
     self.assertEqual(actual, input_data)
 
   def test_write_record_auto(self):
+    file_path_prefix = os.path.join(self._new_tempdir(), 'result')
     with beam.Pipeline(DirectRunner()) as p:
-      file_path_prefix = tempfile.NamedTemporaryFile().name
       input_data = ['foo', 'bar']
       _ = p | beam.Create(input_data) | WriteToTFRecord(
           file_path_prefix, file_name_suffix='.gz')
@@ -207,7 +231,7 @@ class TestWriteToTFRecord(TestTFRecordSink):
     self.assertEqual(actual, input_data)
 
 
-class TestTFRecordSource(unittest.TestCase):
+class TestTFRecordSource(_TestCaseWithTempDirCleanUp):
 
   def _write_file(self, path, base64_records):
     record = binascii.a2b_base64(base64_records)
@@ -220,7 +244,7 @@ class TestTFRecordSource(unittest.TestCase):
       f.write(record)
 
   def test_process_single(self):
-    path = tempfile.NamedTemporaryFile().name
+    path = os.path.join(self._new_tempdir(), 'result')
     self._write_file(path, FOO_RECORD_BASE64)
     with beam.Pipeline(DirectRunner()) as p:
       result = (p
@@ -232,7 +256,7 @@ class TestTFRecordSource(unittest.TestCase):
       beam.assert_that(result, beam.equal_to(['foo']))
 
   def test_process_multiple(self):
-    path = tempfile.NamedTemporaryFile().name
+    path = os.path.join(self._new_tempdir(), 'result')
     self._write_file(path, FOO_BAR_RECORD_BASE64)
     with beam.Pipeline(DirectRunner()) as p:
       result = (p
@@ -244,7 +268,7 @@ class TestTFRecordSource(unittest.TestCase):
       beam.assert_that(result, beam.equal_to(['foo', 'bar']))
 
   def test_process_gzip(self):
-    path = tempfile.NamedTemporaryFile().name
+    path = os.path.join(self._new_tempdir(), 'result')
     self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
     with beam.Pipeline(DirectRunner()) as p:
       result = (p
@@ -256,7 +280,7 @@ class TestTFRecordSource(unittest.TestCase):
       beam.assert_that(result, beam.equal_to(['foo', 'bar']))
 
   def test_process_auto(self):
-    path = tempfile.mkstemp(suffix='.gz')[1]
+    path = os.path.join(self._new_tempdir(), 'result.gz')
     self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
     with beam.Pipeline(DirectRunner()) as p:
       result = (p
@@ -271,7 +295,7 @@ class TestTFRecordSource(unittest.TestCase):
 class TestReadFromTFRecordSource(TestTFRecordSource):
 
   def test_process_gzip(self):
-    path = tempfile.NamedTemporaryFile().name
+    path = os.path.join(self._new_tempdir(), 'result')
     self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
     with beam.Pipeline(DirectRunner()) as p:
       result = (p
@@ -280,7 +304,7 @@ class TestReadFromTFRecordSource(TestTFRecordSource):
       beam.assert_that(result, beam.equal_to(['foo', 'bar']))
 
   def test_process_gzip_auto(self):
-    path = tempfile.mkstemp(suffix='.gz')[1]
+    path = os.path.join(self._new_tempdir(), 'result.gz')
     self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
     with beam.Pipeline(DirectRunner()) as p:
       result = (p
@@ -289,7 +313,7 @@ class TestReadFromTFRecordSource(TestTFRecordSource):
       beam.assert_that(result, beam.equal_to(['foo', 'bar']))
 
 
-class TestEnd2EndWriteAndRead(unittest.TestCase):
+class TestEnd2EndWriteAndRead(_TestCaseWithTempDirCleanUp):
 
   def create_inputs(self):
     input_array = [[random.random() - 0.5 for _ in xrange(15)]
@@ -299,7 +323,7 @@ class TestEnd2EndWriteAndRead(unittest.TestCase):
     return memfile.getvalue()
 
   def test_end2end(self):
-    file_path_prefix = tempfile.NamedTemporaryFile().name
+    file_path_prefix = os.path.join(self._new_tempdir(), 'result')
 
     # Generate a TFRecord file.
     with beam.Pipeline(DirectRunner()) as p:
@@ -312,7 +336,7 @@ class TestEnd2EndWriteAndRead(unittest.TestCase):
       beam.assert_that(actual_data, beam.equal_to(expected_data))
 
   def test_end2end_auto_compression(self):
-    file_path_prefix = tempfile.NamedTemporaryFile().name
+    file_path_prefix = os.path.join(self._new_tempdir(), 'result')
 
     # Generate a TFRecord file.
     with beam.Pipeline(DirectRunner()) as p:
@@ -326,7 +350,7 @@ class TestEnd2EndWriteAndRead(unittest.TestCase):
       beam.assert_that(actual_data, beam.equal_to(expected_data))
 
   def test_end2end_auto_compression_unsharded(self):
-    file_path_prefix = tempfile.NamedTemporaryFile().name
+    file_path_prefix = os.path.join(self._new_tempdir(), 'result')
 
     # Generate a TFRecord file.
     with beam.Pipeline(DirectRunner()) as p:
@@ -341,7 +365,7 @@ class TestEnd2EndWriteAndRead(unittest.TestCase):
 
   @unittest.skipIf(tf is None, 'tensorflow not installed.')
   def test_end2end_example_proto(self):
-    file_path_prefix = tempfile.NamedTemporaryFile().name
+    file_path_prefix = os.path.join(self._new_tempdir(), 'result')
 
     example = tf.train.Example()
     example.features.feature['int'].int64_list.value.extend(range(3))