You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by GitBox <gi...@apache.org> on 2019/01/15 23:36:40 UTC

[beam] Diff for: [GitHub] aaltay merged pull request #7503: [BEAM-5315] Python 3 port io.tfrecordio module

diff --git a/sdks/python/apache_beam/io/tfrecordio_test.py b/sdks/python/apache_beam/io/tfrecordio_test.py
index e85842436b22..49956ea6f3a6 100644
--- a/sdks/python/apache_beam/io/tfrecordio_test.py
+++ b/sdks/python/apache_beam/io/tfrecordio_test.py
@@ -59,15 +59,15 @@
 # >>> import tensorflow as tf
 # >>> import base64
 # >>> writer = tf.python_io.TFRecordWriter('/tmp/python_foo.tfrecord')
-# >>> writer.write('foo')
+# >>> writer.write(b'foo')
 # >>> writer.close()
 # >>> with open('/tmp/python_foo.tfrecord', 'rb') as f:
 # ...   data =  base64.b64encode(f.read())
 # ...   print(data)
-FOO_RECORD_BASE64 = 'AwAAAAAAAACwmUkOZm9vYYq+/g=='
+FOO_RECORD_BASE64 = b'AwAAAAAAAACwmUkOZm9vYYq+/g=='
 
-# Same as above but containing two records ['foo', 'bar']
-FOO_BAR_RECORD_BASE64 = 'AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg='
+# Same as above but containing two records [b'foo', b'bar']
+FOO_BAR_RECORD_BASE64 = b'AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg='
 
 
 def _write_file(path, base64_records):
@@ -95,42 +95,46 @@ def _as_file_handle(self, contents):
 
   def _increment_value_at_index(self, value, index):
     l = list(value)
-    l[index] = bytes(ord(l[index]) + 1)
-    return "".join(l)
+    if sys.version_info[0] <= 2:
+      l[index] = bytes(ord(l[index]) + 1)
+      return b"".join(l)
+    else:
+      l[index] = l[index] + 1
+      return bytes(l)
 
   def _test_error(self, record, error_text):
     with self.assertRaisesRegexp(ValueError, re.escape(error_text)):
       _TFRecordUtil.read_record(self._as_file_handle(record))
 
   def test_masked_crc32c(self):
-    self.assertEqual(0xfd7fffa, _TFRecordUtil._masked_crc32c('\x00' * 32))
-    self.assertEqual(0xf909b029, _TFRecordUtil._masked_crc32c('\xff' * 32))
-    self.assertEqual(0xfebe8a61, _TFRecordUtil._masked_crc32c('foo'))
+    self.assertEqual(0xfd7fffa, _TFRecordUtil._masked_crc32c(b'\x00' * 32))
+    self.assertEqual(0xf909b029, _TFRecordUtil._masked_crc32c(b'\xff' * 32))
+    self.assertEqual(0xfebe8a61, _TFRecordUtil._masked_crc32c(b'foo'))
     self.assertEqual(
         0xe4999b0,
-        _TFRecordUtil._masked_crc32c('\x03\x00\x00\x00\x00\x00\x00\x00'))
+        _TFRecordUtil._masked_crc32c(b'\x03\x00\x00\x00\x00\x00\x00\x00'))
 
   def test_masked_crc32c_crcmod(self):
     crc32c_fn = crcmod.predefined.mkPredefinedCrcFun('crc-32c')
     self.assertEqual(
         0xfd7fffa,
         _TFRecordUtil._masked_crc32c(
-            '\x00' * 32, crc32c_fn=crc32c_fn))
+            b'\x00' * 32, crc32c_fn=crc32c_fn))
     self.assertEqual(
         0xf909b029,
         _TFRecordUtil._masked_crc32c(
-            '\xff' * 32, crc32c_fn=crc32c_fn))
+            b'\xff' * 32, crc32c_fn=crc32c_fn))
     self.assertEqual(
         0xfebe8a61, _TFRecordUtil._masked_crc32c(
-            'foo', crc32c_fn=crc32c_fn))
+            b'foo', crc32c_fn=crc32c_fn))
     self.assertEqual(
         0xe4999b0,
         _TFRecordUtil._masked_crc32c(
-            '\x03\x00\x00\x00\x00\x00\x00\x00', crc32c_fn=crc32c_fn))
+            b'\x03\x00\x00\x00\x00\x00\x00\x00', crc32c_fn=crc32c_fn))
 
   def test_write_record(self):
     file_handle = io.BytesIO()
-    _TFRecordUtil.write_record(file_handle, 'foo')
+    _TFRecordUtil.write_record(file_handle, b'foo')
     self.assertEqual(self.record, file_handle.getvalue())
 
   def test_read_record(self):
@@ -138,7 +142,7 @@ def test_read_record(self):
     self.assertEqual(b'foo', actual)
 
   def test_read_record_invalid_record(self):
-    self._test_error('bar', 'Not a valid TFRecord. Fewer than 12 bytes')
+    self._test_error(b'bar', 'Not a valid TFRecord. Fewer than 12 bytes')
 
   def test_read_record_invalid_length_mask(self):
     record = self._increment_value_at_index(self.record, 9)
@@ -149,7 +153,7 @@ def test_read_record_invalid_data_mask(self):
     self._test_error(record, 'Mismatch of data mask')
 
   def test_compatibility_read_write(self):
-    for record in ['', 'blah', 'another blah']:
+    for record in [b'', b'blah', b'another blah']:
       file_handle = io.BytesIO()
       _TFRecordUtil.write_record(file_handle, record)
       file_handle.seek(0)
@@ -176,9 +180,9 @@ def test_write_record_single(self):
           num_shards=0,
           shard_name_template=None,
           compression_type=CompressionTypes.UNCOMPRESSED)
-      self._write_lines(sink, path, ['foo'])
+      self._write_lines(sink, path, [b'foo'])
 
-      with open(path, 'r') as f:
+      with open(path, 'rb') as f:
         self.assertEqual(f.read(), record)
 
   def test_write_record_multiple(self):
@@ -192,9 +196,9 @@ def test_write_record_multiple(self):
           num_shards=0,
           shard_name_template=None,
           compression_type=CompressionTypes.UNCOMPRESSED)
-      self._write_lines(sink, path, ['foo', 'bar'])
+      self._write_lines(sink, path, [b'foo', b'bar'])
 
-      with open(path, 'r') as f:
+      with open(path, 'rb') as f:
         self.assertEqual(f.read(), record)
 
 
@@ -247,7 +251,7 @@ def test_process_single(self):
                       coder=coders.BytesCoder(),
                       compression_type=CompressionTypes.AUTO,
                       validate=True))
-        assert_that(result, equal_to(['foo']))
+        assert_that(result, equal_to([b'foo']))
 
   def test_process_multiple(self):
     with TempDir() as temp_dir:
@@ -260,7 +264,7 @@ def test_process_multiple(self):
                       coder=coders.BytesCoder(),
                       compression_type=CompressionTypes.AUTO,
                       validate=True))
-        assert_that(result, equal_to(['foo', 'bar']))
+        assert_that(result, equal_to([b'foo', b'bar']))
 
   def test_process_gzip(self):
     with TempDir() as temp_dir:
@@ -273,11 +277,8 @@ def test_process_gzip(self):
                       coder=coders.BytesCoder(),
                       compression_type=CompressionTypes.GZIP,
                       validate=True))
-        assert_that(result, equal_to(['foo', 'bar']))
+        assert_that(result, equal_to([b'foo', b'bar']))
 
-  @unittest.skipIf(sys.version_info[0] == 3,
-                   'This test halts test suite execution on Python 3. '
-                   'TODO: BEAM-5623')
   def test_process_auto(self):
     with TempDir() as temp_dir:
       path = temp_dir.create_temp_file('result.gz')
@@ -289,11 +290,8 @@ def test_process_auto(self):
                       coder=coders.BytesCoder(),
                       compression_type=CompressionTypes.AUTO,
                       validate=True))
-        assert_that(result, equal_to(['foo', 'bar']))
+        assert_that(result, equal_to([b'foo', b'bar']))
 
-  @unittest.skipIf(sys.version_info[0] == 3,
-                   'This test halts test suite execution on Python 3. '
-                   'TODO: BEAM-5623')
   def test_process_gzip(self):
     with TempDir() as temp_dir:
       path = temp_dir.create_temp_file('result')
@@ -302,11 +300,8 @@ def test_process_gzip(self):
         result = (p
                   | ReadFromTFRecord(
                       path, compression_type=CompressionTypes.GZIP))
-        assert_that(result, equal_to(['foo', 'bar']))
+        assert_that(result, equal_to([b'foo', b'bar']))
 
-  @unittest.skipIf(sys.version_info[0] == 3,
-                   'This test halts test suite execution on Python 3. '
-                   'TODO: BEAM-5623')
   def test_process_gzip_auto(self):
     with TempDir() as temp_dir:
       path = temp_dir.create_temp_file('result.gz')
@@ -315,7 +310,7 @@ def test_process_gzip_auto(self):
         result = (p
                   | ReadFromTFRecord(
                       path, compression_type=CompressionTypes.AUTO))
-        assert_that(result, equal_to(['foo', 'bar']))
+        assert_that(result, equal_to([b'foo', b'bar']))
 
 
 class TestReadAllFromTFRecord(unittest.TestCase):
@@ -335,7 +330,7 @@ def test_process_single(self):
                   | ReadAllFromTFRecord(
                       coder=coders.BytesCoder(),
                       compression_type=CompressionTypes.AUTO))
-        assert_that(result, equal_to(['foo']))
+        assert_that(result, equal_to([b'foo']))
 
   def test_process_multiple(self):
     with TempDir() as temp_dir:
@@ -347,7 +342,7 @@ def test_process_multiple(self):
                   | ReadAllFromTFRecord(
                       coder=coders.BytesCoder(),
                       compression_type=CompressionTypes.AUTO))
-        assert_that(result, equal_to(['foo', 'bar']))
+        assert_that(result, equal_to([b'foo', b'bar']))
 
   def test_process_glob(self):
     with TempDir() as temp_dir:
@@ -359,7 +354,7 @@ def test_process_glob(self):
                   | ReadAllFromTFRecord(
                       coder=coders.BytesCoder(),
                       compression_type=CompressionTypes.AUTO))
-        assert_that(result, equal_to(['foo', 'bar'] * 3))
+        assert_that(result, equal_to([b'foo', b'bar'] * 3))
 
   def test_process_multiple_globs(self):
     with TempDir() as temp_dir:
@@ -375,11 +370,8 @@ def test_process_multiple_globs(self):
                   | ReadAllFromTFRecord(
                       coder=coders.BytesCoder(),
                       compression_type=CompressionTypes.AUTO))
-        assert_that(result, equal_to(['foo', 'bar'] * 9))
+        assert_that(result, equal_to([b'foo', b'bar'] * 9))
 
-  @unittest.skipIf(sys.version_info[0] == 3,
-                   'This test halts test suite execution on Python 3. '
-                   'TODO: BEAM-5623')
   def test_process_gzip(self):
     with TempDir() as temp_dir:
       path = temp_dir.create_temp_file('result')
@@ -390,11 +382,8 @@ def test_process_gzip(self):
                   | ReadAllFromTFRecord(
                       coder=coders.BytesCoder(),
                       compression_type=CompressionTypes.GZIP))
-        assert_that(result, equal_to(['foo', 'bar']))
+        assert_that(result, equal_to([b'foo', b'bar']))
 
-  @unittest.skipIf(sys.version_info[0] == 3,
-                   'This test halts test suite execution on Python 3. '
-                   'TODO: BEAM-5623')
   def test_process_auto(self):
     with TempDir() as temp_dir:
       path = temp_dir.create_temp_file('result.gz')
@@ -405,12 +394,9 @@ def test_process_auto(self):
                   | ReadAllFromTFRecord(
                       coder=coders.BytesCoder(),
                       compression_type=CompressionTypes.AUTO))
-        assert_that(result, equal_to(['foo', 'bar']))
+        assert_that(result, equal_to([b'foo', b'bar']))
 
 
-@unittest.skipIf(sys.version_info[0] == 3,
-                 'This test still needs to be fixed on Python 3'
-                 'TODO: BEAM-5623 - several IO tests hang indefinitely')
 class TestEnd2EndWriteAndRead(unittest.TestCase):
 
   def create_inputs(self):
diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini
index 77ad5e537c77..40bad48e6a72 100644
--- a/sdks/python/tox.ini
+++ b/sdks/python/tox.ini
@@ -58,7 +58,7 @@ setenv =
   BEAM_EXPERIMENTAL_PY3=1
   RUN_SKIPPED_PY3_TESTS=0
 modules =
-  apache_beam.typehints,apache_beam.coders,apache_beam.options,apache_beam.tools,apache_beam.utils,apache_beam.internal,apache_beam.metrics,apache_beam.portability,apache_beam.pipeline_test,apache_beam.pvalue_test,apache_beam.runners,apache_beam.io.hadoopfilesystem_test,apache_beam.io.hdfs_integration_test,apache_beam.io.gcp.tests.utils_test,apache_beam.io.gcp.big_query_query_to_table_it_test,apache_beam.io.gcp.bigquery_io_read_it_test,apache_beam.io.gcp.bigquery_test,apache_beam.io.gcp.gcsfilesystem_test,apache_beam.io.gcp.gcsio_test,apache_beam.io.gcp.pubsub_integration_test,apache_beam.io.hdfs_integration_test,apache_beam.io.gcp.internal,apache_beam.io.filesystem_test,apache_beam.io.filesystems_test,apache_beam.io.range_trackers_test,apache_beam.io.sources_test,apache_beam.transforms,apache_beam.testing,apache_beam.io.filesystemio_test,apache_beam.io.localfilesystem_test,apache_beam.io.range_trackers_test,apache_beam.io.restriction_trackers_test,apache_beam.io.source_test_utils_test,apache_beam.io.concat_source_test,apache_beam.io.filebasedsink_test,apache_beam.io.filebasedsource_test,apache_beam.io.textio_test
+  apache_beam.typehints,apache_beam.coders,apache_beam.options,apache_beam.tools,apache_beam.utils,apache_beam.internal,apache_beam.metrics,apache_beam.portability,apache_beam.pipeline_test,apache_beam.pvalue_test,apache_beam.runners,apache_beam.io.hadoopfilesystem_test,apache_beam.io.hdfs_integration_test,apache_beam.io.gcp.tests.utils_test,apache_beam.io.gcp.big_query_query_to_table_it_test,apache_beam.io.gcp.bigquery_io_read_it_test,apache_beam.io.gcp.bigquery_test,apache_beam.io.gcp.gcsfilesystem_test,apache_beam.io.gcp.gcsio_test,apache_beam.io.gcp.pubsub_integration_test,apache_beam.io.hdfs_integration_test,apache_beam.io.gcp.internal,apache_beam.io.filesystem_test,apache_beam.io.filesystems_test,apache_beam.io.range_trackers_test,apache_beam.io.sources_test,apache_beam.transforms,apache_beam.testing,apache_beam.io.filesystemio_test,apache_beam.io.localfilesystem_test,apache_beam.io.range_trackers_test,apache_beam.io.restriction_trackers_test,apache_beam.io.source_test_utils_test,apache_beam.io.concat_source_test,apache_beam.io.filebasedsink_test,apache_beam.io.filebasedsource_test,apache_beam.io.textio_test,apache_beam.io.tfrecordio_test
 commands =
   python --version
   pip --version


With regards,
Apache Git Services