You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@beam.apache.org by tv...@apache.org on 2020/10/02 01:31:24 UTC
[beam] branch master updated: [BEAM-10862] Handle empty tfrecord files within a glob (#12790)
This is an automated email from the ASF dual-hosted git repository.
tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new ea3b987 [BEAM-10862] Handle empty tfrecord files within a glob (#12790)
ea3b987 is described below
commit ea3b98712de3c2cb390d42581d25454906094aad
Author: Curtis "Fjord" Hawthorne <cg...@gmail.com>
AuthorDate: Thu Oct 1 18:30:52 2020 -0700
[BEAM-10862] Handle empty tfrecord files within a glob (#12790)
---
sdks/python/apache_beam/io/filebasedsource.py | 8 ++++++--
sdks/python/apache_beam/io/tfrecordio_test.py | 18 +++++++++++++++++-
2 files changed, 23 insertions(+), 3 deletions(-)
diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py
index 105050f..a957b2e 100644
--- a/sdks/python/apache_beam/io/filebasedsource.py
+++ b/sdks/python/apache_beam/io/filebasedsource.py
@@ -375,9 +375,13 @@ class _ReadRange(DoFn):
source = self._source_from_file(metadata.path)
# Following split() operation has to be performed to create a proper
# _SingleFileSource. Otherwise what we have is a ConcatSource that contains
- # a single _SingleFileSource. ConcatSource.read() expects a RangeTraker for
+ # a single _SingleFileSource. ConcatSource.read() expects a RangeTracker for
# sub-source range and reads full sub-sources (not byte ranges).
- source = list(source.split(float('inf')))[0].source
+ source_list = list(source.split(float('inf')))
+ # Handle the case of an empty source.
+ if not source_list:
+ return
+ source = source_list[0].source
for record in source.read(range.new_tracker()):
yield record
diff --git a/sdks/python/apache_beam/io/tfrecordio_test.py b/sdks/python/apache_beam/io/tfrecordio_test.py
index f97df56..dca5a5e 100644
--- a/sdks/python/apache_beam/io/tfrecordio_test.py
+++ b/sdks/python/apache_beam/io/tfrecordio_test.py
@@ -338,10 +338,13 @@ class TestReadFromTFRecord(unittest.TestCase):
class TestReadAllFromTFRecord(unittest.TestCase):
- def _write_glob(self, temp_dir, suffix):
+ def _write_glob(self, temp_dir, suffix, include_empty=False):
for _ in range(3):
path = temp_dir.create_temp_file(suffix)
_write_file(path, FOO_BAR_RECORD_BASE64)
+ if include_empty:
+ path = temp_dir.create_temp_file(suffix)
+ _write_file(path, '')
def test_process_single(self):
with TempDir() as temp_dir:
@@ -382,6 +385,19 @@ class TestReadAllFromTFRecord(unittest.TestCase):
compression_type=CompressionTypes.AUTO))
assert_that(result, equal_to([b'foo', b'bar'] * 3))
+ def test_process_glob_with_empty_file(self):
+ with TempDir() as temp_dir:
+ self._write_glob(temp_dir, 'result', include_empty=True)
+ glob = temp_dir.get_path() + os.path.sep + '*result'
+ with TestPipeline() as p:
+ result = (
+ p
+ | Create([glob])
+ | ReadAllFromTFRecord(
+ coder=coders.BytesCoder(),
+ compression_type=CompressionTypes.AUTO))
+ assert_that(result, equal_to([b'foo', b'bar'] * 3))
+
def test_process_multiple_globs(self):
with TempDir() as temp_dir:
globs = []