You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@beam.apache.org by tv...@apache.org on 2020/10/02 01:31:24 UTC

[beam] branch master updated: [BEAM-10862] Handle empty tfrecord files within a glob (#12790)

This is an automated email from the ASF dual-hosted git repository.

tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new ea3b987  [BEAM-10862] Handle empty tfrecord files within a glob (#12790)
ea3b987 is described below

commit ea3b98712de3c2cb390d42581d25454906094aad
Author: Curtis "Fjord" Hawthorne <cg...@gmail.com>
AuthorDate: Thu Oct 1 18:30:52 2020 -0700

    [BEAM-10862] Handle empty tfrecord files within a glob (#12790)
---
 sdks/python/apache_beam/io/filebasedsource.py |  8 ++++++--
 sdks/python/apache_beam/io/tfrecordio_test.py | 18 +++++++++++++++++-
 2 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py
index 105050f..a957b2e 100644
--- a/sdks/python/apache_beam/io/filebasedsource.py
+++ b/sdks/python/apache_beam/io/filebasedsource.py
@@ -375,9 +375,13 @@ class _ReadRange(DoFn):
     source = self._source_from_file(metadata.path)
     # Following split() operation has to be performed to create a proper
     # _SingleFileSource. Otherwise what we have is a ConcatSource that contains
-    # a single _SingleFileSource. ConcatSource.read() expects a RangeTraker for
+    # a single _SingleFileSource. ConcatSource.read() expects a RangeTracker for
     # sub-source range and reads full sub-sources (not byte ranges).
-    source = list(source.split(float('inf')))[0].source
+    source_list = list(source.split(float('inf')))
+    # Handle the case of an empty source.
+    if not source_list:
+      return
+    source = source_list[0].source
     for record in source.read(range.new_tracker()):
       yield record
 
diff --git a/sdks/python/apache_beam/io/tfrecordio_test.py b/sdks/python/apache_beam/io/tfrecordio_test.py
index f97df56..dca5a5e 100644
--- a/sdks/python/apache_beam/io/tfrecordio_test.py
+++ b/sdks/python/apache_beam/io/tfrecordio_test.py
@@ -338,10 +338,13 @@ class TestReadFromTFRecord(unittest.TestCase):
 
 
 class TestReadAllFromTFRecord(unittest.TestCase):
-  def _write_glob(self, temp_dir, suffix):
+  def _write_glob(self, temp_dir, suffix, include_empty=False):
     for _ in range(3):
       path = temp_dir.create_temp_file(suffix)
       _write_file(path, FOO_BAR_RECORD_BASE64)
+    if include_empty:
+      path = temp_dir.create_temp_file(suffix)
+      _write_file(path, '')
 
   def test_process_single(self):
     with TempDir() as temp_dir:
@@ -382,6 +385,19 @@ class TestReadAllFromTFRecord(unittest.TestCase):
                 compression_type=CompressionTypes.AUTO))
         assert_that(result, equal_to([b'foo', b'bar'] * 3))
 
+  def test_process_glob_with_empty_file(self):
+    with TempDir() as temp_dir:
+      self._write_glob(temp_dir, 'result', include_empty=True)
+      glob = temp_dir.get_path() + os.path.sep + '*result'
+      with TestPipeline() as p:
+        result = (
+            p
+            | Create([glob])
+            | ReadAllFromTFRecord(
+                coder=coders.BytesCoder(),
+                compression_type=CompressionTypes.AUTO))
+        assert_that(result, equal_to([b'foo', b'bar'] * 3))
+
   def test_process_multiple_globs(self):
     with TempDir() as temp_dir:
       globs = []