Posted to commits@beam.apache.org by dh...@apache.org on 2016/09/13 00:07:29 UTC
[1/3] incubator-beam git commit: Closes #940
Repository: incubator-beam
Updated Branches:
refs/heads/python-sdk 2649372d6 -> bc32bc866
Closes #940
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/bc32bc86
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/bc32bc86
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/bc32bc86
Branch: refs/heads/python-sdk
Commit: bc32bc8661b31036a879cb2b569b1fe976e2ac4b
Parents: 2649372 6a483c1
Author: Dan Halperin <dh...@google.com>
Authored: Mon Sep 12 17:07:22 2016 -0700
Committer: Dan Halperin <dh...@google.com>
Committed: Mon Sep 12 17:07:22 2016 -0700
----------------------------------------------------------------------
sdks/python/apache_beam/io/fileio.py | 892 ++++++++++++-------
sdks/python/apache_beam/io/fileio_test.py | 535 +++++++++--
sdks/python/apache_beam/io/gcsio.py | 93 +-
sdks/python/apache_beam/io/gcsio_test.py | 147 +--
.../runners/inprocess/inprocess_runner_test.py | 9 +-
5 files changed, 1174 insertions(+), 502 deletions(-)
----------------------------------------------------------------------
[2/3] incubator-beam git commit: Making Dataflow Python Materialized PCollection representation more efficient (3 of several).
Posted by dh...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6a483c18/sdks/python/apache_beam/io/gcsio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcsio_test.py b/sdks/python/apache_beam/io/gcsio_test.py
index 1e2c50e..919e9d2 100644
--- a/sdks/python/apache_beam/io/gcsio_test.py
+++ b/sdks/python/apache_beam/io/gcsio_test.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
"""Tests for Google Cloud Storage client."""
import logging
@@ -49,10 +48,11 @@ class FakeFile(object):
self.generation = generation
def get_metadata(self):
- return storage.Object(bucket=self.bucket,
- name=self.object,
- generation=self.generation,
- size=len(self.contents))
+ return storage.Object(
+ bucket=self.bucket,
+ name=self.object,
+ generation=self.generation,
+ size=len(self.contents))
class FakeGcsObjects(object):
@@ -81,7 +81,7 @@ class FakeGcsObjects(object):
f = self.get_file(get_request.bucket, get_request.object)
if f is None:
# Failing with a HTTP 404 if file does not exist.
- raise HttpError({'status':404}, None, None)
+ raise HttpError({'status': 404}, None, None)
if download is None:
return f.get_metadata()
else:
@@ -90,6 +90,7 @@ class FakeGcsObjects(object):
def get_range_callback(start, end):
assert start >= 0 and end >= start and end < len(f.contents)
stream.write(f.contents[start:end + 1])
+
download.GetRange = get_range_callback
def Insert(self, insert_request, upload=None): # pylint: disable=invalid-name
@@ -114,13 +115,14 @@ class FakeGcsObjects(object):
src_file = self.get_file(copy_request.sourceBucket,
copy_request.sourceObject)
if not src_file:
- raise HttpError(httplib2.Response({'status': '404'}), '404 Not Found',
- 'https://fake/url')
+ raise HttpError(
+ httplib2.Response({'status': '404'}), '404 Not Found',
+ 'https://fake/url')
generation = self.get_last_generation(copy_request.destinationBucket,
copy_request.destinationObject) + 1
dest_file = FakeFile(copy_request.destinationBucket,
- copy_request.destinationObject,
- src_file.contents, generation)
+ copy_request.destinationObject, src_file.contents,
+ generation)
self.add_file(dest_file)
def Delete(self, delete_request): # pylint: disable=invalid-name
@@ -129,8 +131,9 @@ class FakeGcsObjects(object):
if self.get_file(delete_request.bucket, delete_request.object):
self.delete_file(delete_request.bucket, delete_request.object)
else:
- raise HttpError(httplib2.Response({'status': '404'}), '404 Not Found',
- 'https://fake/url')
+ raise HttpError(
+ httplib2.Response({'status': '404'}), '404 Not Found',
+ 'https://fake/url')
def List(self, list_request): # pylint: disable=invalid-name
bucket = list_request.bucket
@@ -202,7 +205,7 @@ class TestGCSIO(unittest.TestCase):
def test_exists_failure(self, mock_get):
# Raising an error other than 404. Raising 404 is a valid failure for
# exists() call.
- mock_get.side_effect = HttpError({'status':400}, None, None)
+ mock_get.side_effect = HttpError({'status': 400}, None, None)
file_name = 'gs://gcsio-test/dummy_file'
file_size = 1234
self._insert_random_file(self.client, file_name, file_size)
@@ -240,34 +243,32 @@ class TestGCSIO(unittest.TestCase):
self.gcs.delete(file_name)
self._insert_random_file(self.client, file_name, file_size)
- self.assertTrue(gcsio.parse_gcs_path(file_name) in
- self.client.objects.files)
+ self.assertTrue(
+ gcsio.parse_gcs_path(file_name) in self.client.objects.files)
self.gcs.delete(file_name)
- self.assertFalse(gcsio.parse_gcs_path(file_name) in
- self.client.objects.files)
+ self.assertFalse(
+ gcsio.parse_gcs_path(file_name) in self.client.objects.files)
def test_copy(self):
src_file_name = 'gs://gcsio-test/source'
dest_file_name = 'gs://gcsio-test/dest'
file_size = 1024
- self._insert_random_file(self.client, src_file_name,
- file_size)
- self.assertTrue(gcsio.parse_gcs_path(src_file_name) in
- self.client.objects.files)
- self.assertFalse(gcsio.parse_gcs_path(dest_file_name) in
- self.client.objects.files)
+ self._insert_random_file(self.client, src_file_name, file_size)
+ self.assertTrue(
+ gcsio.parse_gcs_path(src_file_name) in self.client.objects.files)
+ self.assertFalse(
+ gcsio.parse_gcs_path(dest_file_name) in self.client.objects.files)
self.gcs.copy(src_file_name, dest_file_name)
- self.assertTrue(gcsio.parse_gcs_path(src_file_name) in
- self.client.objects.files)
- self.assertTrue(gcsio.parse_gcs_path(dest_file_name) in
- self.client.objects.files)
+ self.assertTrue(
+ gcsio.parse_gcs_path(src_file_name) in self.client.objects.files)
+ self.assertTrue(
+ gcsio.parse_gcs_path(dest_file_name) in self.client.objects.files)
- self.assertRaises(IOError, self.gcs.copy,
- 'gs://gcsio-test/non-existent',
+ self.assertRaises(IOError, self.gcs.copy, 'gs://gcsio-test/non-existent',
'gs://gcsio-test/non-existent-destination')
def test_copytree(self):
@@ -278,46 +279,45 @@ class TestGCSIO(unittest.TestCase):
for path in paths:
src_file_name = src_dir_name + path
dest_file_name = dest_dir_name + path
- self._insert_random_file(self.client, src_file_name,
- file_size)
- self.assertTrue(gcsio.parse_gcs_path(src_file_name) in
- self.client.objects.files)
- self.assertFalse(gcsio.parse_gcs_path(dest_file_name) in
- self.client.objects.files)
+ self._insert_random_file(self.client, src_file_name, file_size)
+ self.assertTrue(
+ gcsio.parse_gcs_path(src_file_name) in self.client.objects.files)
+ self.assertFalse(
+ gcsio.parse_gcs_path(dest_file_name) in self.client.objects.files)
self.gcs.copytree(src_dir_name, dest_dir_name)
for path in paths:
src_file_name = src_dir_name + path
dest_file_name = dest_dir_name + path
- self.assertTrue(gcsio.parse_gcs_path(src_file_name) in
- self.client.objects.files)
- self.assertTrue(gcsio.parse_gcs_path(dest_file_name) in
- self.client.objects.files)
+ self.assertTrue(
+ gcsio.parse_gcs_path(src_file_name) in self.client.objects.files)
+ self.assertTrue(
+ gcsio.parse_gcs_path(dest_file_name) in self.client.objects.files)
def test_rename(self):
src_file_name = 'gs://gcsio-test/source'
dest_file_name = 'gs://gcsio-test/dest'
file_size = 1024
- self._insert_random_file(self.client, src_file_name,
- file_size)
- self.assertTrue(gcsio.parse_gcs_path(src_file_name) in
- self.client.objects.files)
- self.assertFalse(gcsio.parse_gcs_path(dest_file_name) in
- self.client.objects.files)
+ self._insert_random_file(self.client, src_file_name, file_size)
+ self.assertTrue(
+ gcsio.parse_gcs_path(src_file_name) in self.client.objects.files)
+ self.assertFalse(
+ gcsio.parse_gcs_path(dest_file_name) in self.client.objects.files)
self.gcs.rename(src_file_name, dest_file_name)
- self.assertFalse(gcsio.parse_gcs_path(src_file_name) in
- self.client.objects.files)
- self.assertTrue(gcsio.parse_gcs_path(dest_file_name) in
- self.client.objects.files)
+ self.assertFalse(
+ gcsio.parse_gcs_path(src_file_name) in self.client.objects.files)
+ self.assertTrue(
+ gcsio.parse_gcs_path(dest_file_name) in self.client.objects.files)
def test_full_file_read(self):
file_name = 'gs://gcsio-test/full_file'
file_size = 5 * 1024 * 1024 + 100
random_file = self._insert_random_file(self.client, file_name, file_size)
f = self.gcs.open(file_name)
+ self.assertEqual(f.mode, 'r')
f.seek(0, os.SEEK_END)
self.assertEqual(f.tell(), file_size)
self.assertEqual(f.read(), '')
@@ -337,8 +337,8 @@ class TestGCSIO(unittest.TestCase):
start, end = min(a, b), max(a, b)
f.seek(start)
self.assertEqual(f.tell(), start)
- self.assertEqual(f.read(end - start + 1),
- random_file.contents[start:end + 1])
+ self.assertEqual(
+ f.read(end - start + 1), random_file.contents[start:end + 1])
self.assertEqual(f.tell(), end + 1)
def test_file_read_line(self):
@@ -398,6 +398,7 @@ class TestGCSIO(unittest.TestCase):
file_size = 5 * 1024 * 1024 + 2000
contents = os.urandom(file_size)
f = self.gcs.open(file_name, 'w')
+ self.assertEqual(f.mode, 'w')
f.write(contents[0:1000])
f.write(contents[1000:1024 * 1024])
f.write(contents[1024 * 1024:])
@@ -406,6 +407,36 @@ class TestGCSIO(unittest.TestCase):
self.assertEqual(
self.client.objects.get_file(bucket, name).contents, contents)
+ def test_file_close(self):
+ file_name = 'gs://gcsio-test/close_file'
+ file_size = 5 * 1024 * 1024 + 2000
+ contents = os.urandom(file_size)
+ f = self.gcs.open(file_name, 'w')
+ self.assertEqual(f.mode, 'w')
+ f.write(contents)
+ f.close()
+ f.close() # This should not crash.
+ bucket, name = gcsio.parse_gcs_path(file_name)
+ self.assertEqual(
+ self.client.objects.get_file(bucket, name).contents, contents)
+
+ def test_file_flush(self):
+ file_name = 'gs://gcsio-test/flush_file'
+ file_size = 5 * 1024 * 1024 + 2000
+ contents = os.urandom(file_size)
+ bucket, name = gcsio.parse_gcs_path(file_name)
+ f = self.gcs.open(file_name, 'w')
+ self.assertEqual(f.mode, 'w')
+ f.write(contents[0:1000])
+ f.flush()
+ f.write(contents[1000:1024 * 1024])
+ f.flush()
+ f.flush() # Should be a NOOP.
+ f.write(contents[1024 * 1024:])
+ f.close() # This should already call the equivalent of flush() in its body.
+ self.assertEqual(
+ self.client.objects.get_file(bucket, name).contents, contents)
+
def test_context_manager(self):
# Test writing with a context manager.
file_name = 'gs://gcsio-test/context_manager_file'
@@ -496,10 +527,10 @@ class TestGCSIO(unittest.TestCase):
]),
]
for file_pattern, expected_object_names in test_cases:
- expected_file_names = ['gs://%s/%s' % (bucket_name, o) for o in
- expected_object_names]
- self.assertEqual(set(self.gcs.glob(file_pattern)),
- set(expected_file_names))
+ expected_file_names = ['gs://%s/%s' % (bucket_name, o)
+ for o in expected_object_names]
+ self.assertEqual(
+ set(self.gcs.glob(file_pattern)), set(expected_file_names))
class TestPipeStream(unittest.TestCase):
@@ -525,7 +556,7 @@ class TestPipeStream(unittest.TestCase):
self.assertEqual(''.join(data_list), expected)
def test_pipe_stream(self):
- block_sizes = list(4 ** i for i in range(0, 12))
+ block_sizes = list(4**i for i in range(0, 12))
data_blocks = list(os.urandom(size) for size in block_sizes)
expected = ''.join(data_blocks)
@@ -534,8 +565,8 @@ class TestPipeStream(unittest.TestCase):
for buffer_size in buffer_sizes:
parent_conn, child_conn = multiprocessing.Pipe()
stream = gcsio.GcsBufferedWriter.PipeStream(child_conn)
- child_thread = threading.Thread(target=self._read_and_verify,
- args=(stream, expected, buffer_size))
+ child_thread = threading.Thread(
+ target=self._read_and_verify, args=(stream, expected, buffer_size))
child_thread.start()
for data in data_blocks:
parent_conn.send_bytes(data)
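
The new writer tests above (test_file_close, test_file_flush) pin down three behaviors of the GCS writer: it exposes a mode attribute, flush() can be called repeatedly (the second call is a no-op), and close() flushes, finalizes the upload, and tolerates being called twice. A minimal sketch of that contract, assuming configured GCS credentials and a hypothetical object path (illustrative only, not part of the commit):

    from apache_beam.io import gcsio

    gcs = gcsio.GcsIO()
    f = gcs.open('gs://my-bucket/demo-object', 'w')  # hypothetical path
    assert f.mode == 'w'
    f.write('payload')
    f.flush()   # pushes buffered bytes
    f.flush()   # should be a no-op
    f.close()   # close() performs the equivalent of flush() internally
    f.close()   # a second close() must not crash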
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6a483c18/sdks/python/apache_beam/runners/inprocess/inprocess_runner_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/inprocess/inprocess_runner_test.py b/sdks/python/apache_beam/runners/inprocess/inprocess_runner_test.py
index 4e84147..3ab8383 100644
--- a/sdks/python/apache_beam/runners/inprocess/inprocess_runner_test.py
+++ b/sdks/python/apache_beam/runners/inprocess/inprocess_runner_test.py
@@ -87,9 +87,12 @@ class TestTextFileSource(
pass
-class NativeTestTextFileSink(
- TestWithInProcessPipelineRunner, fileio_test.NativeTestTextFileSink):
- pass
+class TestNativeTextFileSink(
+ TestWithInProcessPipelineRunner, fileio_test.TestNativeTextFileSink):
+
+ def setUp(self):
+ TestWithInProcessPipelineRunner.setUp(self)
+ fileio_test.TestNativeTextFileSink.setUp(self)
class TestTextFileSink(
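
The rename above (NativeTestTextFileSink to TestNativeTextFileSink) also adds an explicit setUp that chains both base classes. With multiple inheritance of test fixtures, unittest only runs the first setUp found by the method resolution order, so each base has to be invoked by hand. A generic sketch of the pattern (class names here are illustrative, not from the Beam codebase):

    import unittest

    class RunnerMixin(object):
        def setUp(self):
            self.runner_ready = True  # runner-specific fixtures

    class FileSinkTests(unittest.TestCase):
        def setUp(self):
            self.sink_ready = True  # sink-specific fixtures

    class FileSinkTestsWithRunner(RunnerMixin, FileSinkTests):
        def setUp(self):
            # Neither base calls the other, so chain both explicitly.
            RunnerMixin.setUp(self)
            FileSinkTests.setUp(self)

        def test_fixtures_present(self):
            self.assertTrue(self.runner_ready and self.sink_ready)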
[3/3] incubator-beam git commit: Making Dataflow Python Materialized PCollection representation more efficient (3 of several).
Posted by dh...@apache.org.
Making Dataflow Python Materialized PCollection representation more efficient (3 of several).
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/6a483c18
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/6a483c18
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/6a483c18
Branch: refs/heads/python-sdk
Commit: 6a483c187dff5d436736f1ff505187c727a95b38
Parents: 2649372
Author: Gus Katsiapis <ka...@katsiapis-linux.mtv.corp.google.com>
Authored: Fri Sep 9 15:34:39 2016 -0700
Committer: Dan Halperin <dh...@google.com>
Committed: Mon Sep 12 17:07:22 2016 -0700
----------------------------------------------------------------------
sdks/python/apache_beam/io/fileio.py | 892 ++++++++++++-------
sdks/python/apache_beam/io/fileio_test.py | 535 +++++++++--
sdks/python/apache_beam/io/gcsio.py | 93 +-
sdks/python/apache_beam/io/gcsio_test.py | 147 +--
.../runners/inprocess/inprocess_runner_test.py | 9 +-
5 files changed, 1174 insertions(+), 502 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6a483c18/sdks/python/apache_beam/io/fileio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py
index bfa246f..bc93138 100644
--- a/sdks/python/apache_beam/io/fileio.py
+++ b/sdks/python/apache_beam/io/fileio.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
"""File-based sources and sinks."""
from __future__ import absolute_import
@@ -34,17 +33,308 @@ from apache_beam import coders
from apache_beam.io import iobase
from apache_beam.io import range_trackers
-
__all__ = ['TextFileSource', 'TextFileSink']
DEFAULT_SHARD_NAME_TEMPLATE = '-SSSSS-of-NNNNN'
+class _CompressionType(object):
+ """Object representing single compression type."""
+
+ def __init__(self, identifier):
+ self.identifier = identifier
+
+ def __eq__(self, other):
+ return (isinstance(other, _CompressionType) and
+ self.identifier == other.identifier)
+
+ def __hash__(self):
+ return hash(self.identifier)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __repr__(self):
+ return '_CompressionType(%s)' % self.identifier
+
+
+class CompressionTypes(object):
+ """Enum-like class representing known compression types."""
+ AUTO = _CompressionType(1) # Detect compression based on filename extension.
+ GZIP = _CompressionType(2) # gzip compression (deflate with gzip headers).
+ ZLIB = _CompressionType(3) # zlib compression (deflate with zlib headers).
+ UNCOMPRESSED = _CompressionType(4) # Uncompressed (i.e., may be split).
+
+ # TODO: Remove this backwards-compatibility soon.
+ NO_COMPRESSION = _CompressionType(4) # Deprecated. Use UNCOMPRESSED instead.
+ assert NO_COMPRESSION == UNCOMPRESSED
+
+ @classmethod
+ def is_valid_compression_type(cls, compression_type):
+ """Returns true for valid compression types, false otherwise."""
+ return isinstance(compression_type, _CompressionType)
+
+ @classmethod
+ def mime_type(cls, compression_type, default='application/octet-stream'):
+ mime_types_by_compression_type = {
+ cls.GZIP: 'application/x-gzip',
+ cls.ZLIB: 'application/octet-stream'
+ }
+ return mime_types_by_compression_type.get(compression_type, default)
+
+ @classmethod
+ def detect_compression_type(cls, file_path):
+ """Returns the compression type of a file (based on its suffix)"""
+ compression_types_by_suffix = {'.gz': cls.GZIP, '.z': cls.ZLIB}
+ lowercased_path = file_path.lower()
+ for suffix, compression_type in compression_types_by_suffix.iteritems():
+ if lowercased_path.endswith(suffix):
+ return compression_type
+ return cls.UNCOMPRESSED
+
+
+class NativeFileSource(iobase.NativeSource):
+ """A source implemented by Dataflow service from a GCS or local file or files.
+
+ This class is to be only inherited by sources natively implemented by Cloud
+ Dataflow service, hence should not be sub-classed by users.
+ """
+
+ def __init__(self,
+ file_path,
+ start_offset=None,
+ end_offset=None,
+ coder=coders.BytesCoder(),
+ compression_type=CompressionTypes.AUTO,
+ mime_type='application/octet-stream'):
+ """Initialize a NativeFileSource.
+
+ Args:
+ file_path: The file path to read from as a local file path or a GCS
+ gs:// path. The path can contain glob characters (*, ?, and [...]
+ sets).
+ start_offset: The byte offset in the source file that the reader
+ should start reading. By default it is 0 (beginning of file).
+ end_offset: The byte offset in the file that the reader should stop
+ reading. By default it is the end of the file.
+ compression_type: Used to handle compressed input files. Typical value
+ is CompressionTypes.AUTO, in which case the file_path's extension will
+ be used to detect the compression.
+ coder: Coder used to decode each record.
+
+ Raises:
+ TypeError: if file_path is not a string.
+
+ If the file_path contains glob characters then the start_offset and
+ end_offset must not be specified.
+
+ The 'start_offset' and 'end_offset' pair provide a mechanism to divide the
+ file into multiple pieces for individual sources. Because the offset
+ is measured in bytes, some complication arises when the offset falls in
+ the middle of a record. To avoid the scenario where two adjacent sources
+ each get a fraction of a line we adopt the following rules:
+
+ If start_offset falls inside a record (any character except the first one)
+ then the source will skip the record and start with the next one.
+
+ If end_offset falls inside a record (any character except the first one)
+ then the source will contain that entire record.
+ """
+ if not isinstance(file_path, basestring):
+ raise TypeError('%s: file_path must be a string; got %r instead' %
+ (self.__class__.__name__, file_path))
+
+ self.file_path = file_path
+ self.start_offset = start_offset
+ self.end_offset = end_offset
+ self.compression_type = compression_type
+ self.coder = coder
+ self.mime_type = mime_type
+
+ def __eq__(self, other):
+ # TODO: Remove this backwards-compatibility soon.
+ def equiv_autos(lhs, rhs):
+ return ((lhs == 'AUTO' and rhs == CompressionTypes.AUTO) or
+ (lhs == CompressionTypes.AUTO and rhs == 'AUTO'))
+
+ return (self.file_path == other.file_path and
+ self.start_offset == other.start_offset and
+ self.end_offset == other.end_offset and
+ (self.compression_type == other.compression_type or
+ equiv_autos(self.compression_type, other.compression_type)) and
+ self.coder == other.coder and self.mime_type == other.mime_type)
+
+ @property
+ def path(self):
+ return self.file_path
+
+ def reader(self):
+ return NativeFileSourceReader(self)
+
+
+class NativeFileSourceReader(iobase.NativeSourceReader,
+ coders.observable.ObservableMixin):
+ """The source reader for a NativeFileSource.
+
+ This class is only to be inherited by source readers natively implemented
+ by the Cloud Dataflow service, and hence should not be sub-classed by users.
+ """
+
+ def __init__(self, source):
+ super(NativeFileSourceReader, self).__init__()
+ self.source = source
+ self.start_offset = self.source.start_offset or 0
+ self.end_offset = self.source.end_offset
+ self.current_offset = self.start_offset
+
+ def __enter__(self):
+ self.file = ChannelFactory.open(
+ self.source.file_path,
+ 'rb',
+ mime_type=self.source.mime_type,
+ compression_type=self.source.compression_type)
+
+ # Determine the real end_offset.
+ #
+ # If not specified or if the source is not splittable it will be the length
+ # of the file (or infinity for compressed files) as appropriate.
+ if ChannelFactory.is_compressed(self.file):
+ if not isinstance(self.source, TextFileSource):
+ raise ValueError('Unexpected compressed file for a non-TextFileSource.')
+ self.end_offset = range_trackers.OffsetRangeTracker.OFFSET_INFINITY
+
+ elif self.end_offset is None:
+ self.file.seek(0, os.SEEK_END)
+ self.end_offset = self.file.tell()
+ self.file.seek(self.start_offset)
+
+ # Initializing range tracker after self.end_offset is finalized.
+ self.range_tracker = range_trackers.OffsetRangeTracker(self.start_offset,
+ self.end_offset)
+
+ # Position to the appropriate start_offset.
+ if self.start_offset > 0:
+ if ChannelFactory.is_compressed(self.file):
+ # TODO: Turn this warning into an exception soon.
+ logging.warning(
+ 'Encountered initial split starting at (%s) for compressed source.',
+ self.start_offset)
+ self.seek_to_true_start_offset()
+
+ return self
+
+ def __exit__(self, exception_type, exception_value, traceback):
+ self.file.close()
+
+ def __iter__(self):
+ if self.current_offset > 0 and ChannelFactory.is_compressed(self.file):
+ # When compression is enabled both initial and dynamic splitting should be
+ # prevented. Here we prevent initial splitting by ignoring all splits
+ # other than the split that starts at byte 0.
+ #
+ # TODO: Turn this warning into an exception soon.
+ logging.warning('Ignoring split starting at (%s) for compressed source.',
+ self.current_offset)
+ return
+
+ while True:
+ if not self.range_tracker.try_claim(record_start=self.current_offset):
+ # Reader has completed reading the set of records in its range. Note
+ # that the end offset of the range may be smaller than the original
+ # end offset defined when creating the reader due to reader accepting
+ # a dynamic split request from the service.
+ return
+
+ # Note that for compressed sources, delta_offsets are virtual and don't
+ # actually correspond to byte offsets in the underlying file. They
+ # nonetheless correspond to unique virtual position locations.
+ for eof, record, delta_offset in self.read_records():
+ if eof:
+ # Can't read from this source anymore and the record and delta_offset
+ # are nonsensical; hence we are done.
+ return
+ else:
+ self.notify_observers(record, is_encoded=False)
+ self.current_offset += delta_offset
+ yield record
+
+ def seek_to_true_start_offset(self):
+ """Seeks the underlying file to the appropriate start_offset that is
+ compatible with range-tracking and position models and updates
+ self.current_offset accordingly.
+ """
+ raise NotImplementedError
+
+ def read_records(self):
+ """
+ Yields information about (possibly multiple) records corresponding to
+ self.current_offset.
+
+ If a read_records() invocation returns multiple results, the first record
+ must be at a split point and other records should not be at split points.
+ The first record is assumed to be at self.current_offset and the caller
+ should use the yielded delta_offsets to update self.current_offset
+ accordingly.
+
+ The yielded value is a triplet of the form:
+ eof, record, delta_offset
+ eof: A boolean indicating whether EOF has been reached, in which case
+ the contents of record and delta_offset cannot be trusted or used.
+ record: The (possibly decoded) record (i.e. payload) read from the
+ underlying source.
+ delta_offset: The delta offset (from self.current_offset), in bytes,
+ consumed from the underlying source up to the starting position of the
+ next record (or EOF if no record exists).
+ """
+ raise NotImplementedError
+
+ def get_progress(self):
+ return iobase.ReaderProgress(position=iobase.ReaderPosition(
+ byte_offset=self.range_tracker.last_record_start))
+
+ def request_dynamic_split(self, dynamic_split_request):
+ if ChannelFactory.is_compressed(self.file):
+ # When compression is enabled both initial and dynamic splitting should be
+ # prevented. Here we prevent dynamic splitting by ignoring all dynamic
+ # split requests at the reader.
+ #
+ # TODO: Turn this warning into an exception soon.
+ logging.warning('FileBasedReader cannot be split since it is compressed. '
+ 'Requested: %r', dynamic_split_request)
+ return
+
+ assert dynamic_split_request is not None
+ progress = dynamic_split_request.progress
+ split_position = progress.position
+ if split_position is None:
+ percent_complete = progress.percent_complete
+ if percent_complete is not None:
+ if percent_complete <= 0 or percent_complete >= 1:
+ logging.warning(
+ 'FileBasedReader cannot be split since the provided percentage '
+ 'of work to be completed is out of the valid range (0, '
+ '1). Requested: %r', dynamic_split_request)
+ return
+ split_position = iobase.ReaderPosition()
+ split_position.byte_offset = (
+ self.range_tracker.position_at_fraction(percent_complete))
+ else:
+ logging.warning(
+ 'FileBasedReader requires either a position or a percentage of '
+ 'work to be complete to perform a dynamic split request. '
+ 'Requested: %r', dynamic_split_request)
+ return
+
+ if self.range_tracker.try_split(split_position.byte_offset):
+ return iobase.DynamicSplitResultWithPosition(split_position)
+ else:
+ return
+
# -----------------------------------------------------------------------------
# TextFileSource, TextFileSink.
-class TextFileSource(iobase.NativeSource):
+class TextFileSource(NativeFileSource):
"""A source for a GCS or local text file.
Parses a text file as newline-delimited elements, by default assuming
@@ -54,9 +344,14 @@ class TextFileSource(iobase.NativeSource):
ASCII. This has not been tested for other encodings such as UTF-16 or UTF-32.
"""
- def __init__(self, file_path, start_offset=None, end_offset=None,
- compression_type='AUTO', strip_trailing_newlines=True,
- coder=coders.StrUtf8Coder()):
+ def __init__(self,
+ file_path,
+ start_offset=None,
+ end_offset=None,
+ compression_type=CompressionTypes.AUTO,
+ strip_trailing_newlines=True,
+ coder=coders.StrUtf8Coder(),
+ mime_type='text/plain'):
"""Initialize a TextSource.
Args:
@@ -68,7 +363,8 @@ class TextFileSource(iobase.NativeSource):
end_offset: The byte offset in the file that the reader should stop
reading. By default it is the end of the file.
compression_type: Used to handle compressed input files. Typical value
- is 'AUTO'.
+ is CompressionTypes.AUTO, in which case the file_path's extension will
+ be used to detect the compression.
strip_trailing_newlines: Indicates whether this source should remove
the newline char in each line it reads before decoding that line.
This feature only works for ASCII and UTF-8 encoded input.
@@ -86,25 +382,19 @@ class TextFileSource(iobase.NativeSource):
the middle of a text line. To avoid the scenario where two adjacent sources
each get a fraction of a line we adopt the following rules:
- If start_offset falls inside a line (any character except the firt one)
+ If start_offset falls inside a line (any character except the first one)
then the source will skip the line and start with the next one.
If end_offset falls inside a line (any character except the first one) then
the source will contain that entire line.
"""
- if not isinstance(file_path, basestring):
- raise TypeError(
- '%s: file_path must be a string; got %r instead' %
- (self.__class__.__name__, file_path))
-
- self.file_path = file_path
- self.start_offset = start_offset
- self.end_offset = end_offset
- self.compression_type = compression_type
+ super(TextFileSource, self).__init__(
+ file_path,
+ start_offset=start_offset,
+ end_offset=end_offset,
+ coder=coder,
+ compression_type=compression_type)
self.strip_trailing_newlines = strip_trailing_newlines
- self.coder = coder
-
- self.is_gcs_source = file_path.startswith('gs://')
@property
def format(self):
@@ -112,15 +402,8 @@ class TextFileSource(iobase.NativeSource):
return 'text'
def __eq__(self, other):
- return (self.file_path == other.file_path and
- self.start_offset == other.start_offset and
- self.end_offset == other.end_offset and
- self.strip_trailing_newlines == other.strip_trailing_newlines and
- self.coder == other.coder)
-
- @property
- def path(self):
- return self.file_path
+ return (super(TextFileSource, self).__eq__(other) and
+ self.strip_trailing_newlines == other.strip_trailing_newlines)
def reader(self):
# If a multi-file pattern was specified as a source then make sure the
@@ -140,7 +423,7 @@ class TextFileSource(iobase.NativeSource):
class ChannelFactory(object):
- # TODO(robertwb): Generalize into extensible framework.
+ # TODO: Generalize into extensible framework.
@staticmethod
def mkdir(path):
@@ -153,13 +436,38 @@ class ChannelFactory(object):
raise IOError(err)
@staticmethod
- def open(path, mode, mime_type):
+ def open(path,
+ mode,
+ mime_type='application/octet-stream',
+ compression_type=CompressionTypes.AUTO):
+ if compression_type == CompressionTypes.AUTO:
+ compression_type = CompressionTypes.detect_compression_type(path)
+ elif compression_type == 'AUTO':
+ # TODO: Remove this backwards-compatibility soon.
+ compression_type = CompressionTypes.detect_compression_type(path)
+ else:
+ if not CompressionTypes.is_valid_compression_type(compression_type):
+ raise TypeError('compression_type must be CompressionType object but '
+ 'was %s' % type(compression_type))
+
if path.startswith('gs://'):
# pylint: disable=wrong-import-order, wrong-import-position
from apache_beam.io import gcsio
- return gcsio.GcsIO().open(path, mode, mime_type=mime_type)
+ raw_file = gcsio.GcsIO().open(
+ path,
+ mode,
+ mime_type=CompressionTypes.mime_type(compression_type, mime_type))
+ else:
+ raw_file = open(path, mode)
+
+ if compression_type == CompressionTypes.UNCOMPRESSED:
+ return raw_file
else:
- return open(path, mode)
+ return _CompressedFile(raw_file, compression_type=compression_type)
+
+ @staticmethod
+ def is_compressed(fileobj):
+ return isinstance(fileobj, _CompressedFile)
@staticmethod
def rename(src, dst):
@@ -208,7 +516,7 @@ class ChannelFactory(object):
gcs = gcsio.GcsIO()
if not path.endswith('/'):
path += '/'
- # TODO(robertwb): Threadpool?
+ # TODO: Threadpool?
for entry in gcs.glob(path + '*'):
gcs.delete(entry)
else:
@@ -253,148 +561,121 @@ class ChannelFactory(object):
return os.path.getsize(path)
-class _CompressionType(object):
- """Object representing single compression type."""
-
- def __init__(self, identifier):
- self.identifier = identifier
-
- def __eq__(self, other):
- return (isinstance(other, _CompressionType) and
- self.identifier == other.identifier)
-
- def __hash__(self):
- return hash(self.identifier)
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def __repr__(self):
- return '_CompressionType(%s)' % self.identifier
-
-
-class CompressionTypes(object):
- """Enum-like class representing known compression types."""
- NO_COMPRESSION = _CompressionType(1) # No compression.
- DEFLATE = _CompressionType(2) # 'Deflate' compression (without headers).
- GZIP = _CompressionType(3) # gzip compression (deflate with gzip headers).
- ZLIB = _CompressionType(4) # zlib compression (deflate with zlib headers).
-
- @staticmethod
- def is_valid_compression_type(compression_type):
- """Returns true for valid compression types, false otherwise."""
- return isinstance(compression_type, _CompressionType)
-
- @staticmethod
- def mime_type(compression_type, default='application/octet-stream'):
- if compression_type == CompressionTypes.GZIP:
- return 'application/x-gzip'
- elif compression_type == CompressionTypes.ZLIB:
- return 'application/octet-stream'
- elif compression_type == CompressionTypes.DEFLATE:
- return 'application/octet-stream'
- else:
- return default
-
-
class _CompressedFile(object):
"""Somewhat limited file wrapper for easier handling of compressed files."""
_type_mask = {
- CompressionTypes.ZLIB: zlib.MAX_WBITS,
CompressionTypes.GZIP: zlib.MAX_WBITS | 16,
- CompressionTypes.DEFLATE: -zlib.MAX_WBITS,
+ CompressionTypes.ZLIB: zlib.MAX_WBITS,
}
def __init__(self,
- fileobj=None,
- compression_type=CompressionTypes.ZLIB,
+ fileobj,
+ compression_type=CompressionTypes.GZIP,
read_size=16384):
- self._validate_compression_type(compression_type)
if not fileobj:
raise ValueError('fileobj must be opened file but was %s' % fileobj)
+ self._validate_compression_type(compression_type)
+
+ self._file = fileobj
+ self._data = ''
+ self._read_size = read_size
+ self._compression_type = compression_type
- self.fileobj = fileobj
- self.data = ''
- self.read_size = read_size
- self.compression_type = compression_type
if self._readable():
- self.decompressor = self._create_decompressor(self.compression_type)
+ self._decompressor = zlib.decompressobj(self._type_mask[compression_type])
else:
- self.decompressor = None
+ self._decompressor = None
+
if self._writeable():
- self.compressor = self._create_compressor(self.compression_type)
+ self._compressor = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+ zlib.DEFLATED,
+ self._type_mask[compression_type])
else:
- self.compressor = None
+ self._compressor = None
def _validate_compression_type(self, compression_type):
if not CompressionTypes.is_valid_compression_type(compression_type):
raise TypeError('compression_type must be CompressionType object but '
'was %s' % type(compression_type))
- if compression_type == CompressionTypes.NO_COMPRESSION:
- raise ValueError('cannot create object with no compression')
-
- def _create_compressor(self, compression_type):
- self._validate_compression_type(compression_type)
- return zlib.compressobj(9, zlib.DEFLATED,
- self._type_mask[compression_type])
-
- def _create_decompressor(self, compression_type):
- self._validate_compression_type(compression_type)
- return zlib.decompressobj(self._type_mask[compression_type])
+ if (compression_type == CompressionTypes.AUTO or
+ compression_type == CompressionTypes.UNCOMPRESSED):
+ raise ValueError(
+ 'cannot create object with unspecified or no compression')
def _readable(self):
- mode = self.fileobj.mode
+ mode = self._file.mode
return 'r' in mode or 'a' in mode
def _writeable(self):
- mode = self.fileobj.mode
+ mode = self._file.mode
return 'w' in mode or 'a' in mode
def write(self, data):
"""Write data to file."""
- if not self.compressor:
+ if not self._compressor:
raise ValueError('compressor not initialized')
- compressed = self.compressor.compress(data)
+ compressed = self._compressor.compress(data)
if compressed:
- self.fileobj.write(compressed)
-
- def _read(self, num_bytes):
- """Read num_bytes into internal buffer."""
- while not num_bytes or len(self.data) < num_bytes:
- buf = self.fileobj.read(self.read_size)
- if not buf:
+ self._file.write(compressed)
+
+ def _fetch_to_internal_buffer(self, num_bytes):
+ """Fetch up to num_bytes into the internal buffer."""
+ while len(self._data) < num_bytes:
+ buf = self._file.read(self._read_size)
+ if buf:
+ self._data += self._decompressor.decompress(buf)
+ else:
# EOF reached, flush.
- self.data += self.decompressor.flush()
- break
+ self._data += self._decompressor.flush()
+ return
- self.data += self.decompressor.decompress(buf)
- result = self.data[:num_bytes]
- self.data = self.data[num_bytes:]
+ def _read_from_internal_buffer(self, num_bytes):
+ """Read up to num_bytes from the internal buffer."""
+ result = self._data[:num_bytes]
+ self._data = self._data[num_bytes:]
return result
def read(self, num_bytes):
- if not self.decompressor:
+ if not self._decompressor:
raise ValueError('decompressor not initialized')
- return self._read(num_bytes)
+ self._fetch_to_internal_buffer(num_bytes)
+ return self._read_from_internal_buffer(num_bytes)
+
+ def readline(self):
+ """Equivalent to standard file.readline(). Same return conventions apply."""
+ if not self._decompressor:
+ raise ValueError('decompressor not initialized')
+ result = ''
+ while True:
+ self._fetch_to_internal_buffer(self._read_size)
+ if not self._data:
+ break # EOF reached.
+ index = self._data.find('\n')
+ if index == -1:
+ result += self._read_from_internal_buffer(len(self._data))
+ else:
+ result += self._read_from_internal_buffer(index + 1)
+ break # Newline reached.
+ return result
@property
def closed(self):
- return not self.fileobj or self.fileobj.closed()
+ return not self._file or self._file.closed()
def close(self):
- if self.fileobj is None:
+ if self._file is None:
return
if self._writeable():
- self.fileobj.write(self.compressor.flush())
- self.fileobj.close()
+ self._file.write(self._compressor.flush())
+ self._file.close()
def flush(self):
if self._writeable():
- self.fileobj.write(self.compressor.flush())
- self.fileobj.flush()
+ self._file.write(self._compressor.flush())
+ self._file.flush()
- # TODO(slaven): Add support for seeking to a file position.
+ # TODO: Add support for seeking to a file position.
@property
def seekable(self):
return False
@@ -426,7 +707,24 @@ class FileSink(iobase.Sink):
num_shards=0,
shard_name_template=None,
mime_type='application/octet-stream',
- compression_type=CompressionTypes.NO_COMPRESSION):
+ compression_type=CompressionTypes.AUTO):
+ """
+ Raises:
+ TypeError: if file path parameters are not a string or if compression_type
+ is not member of CompressionTypes.
+ ValueError: if shard_name_template is not of expected format.
+ """
+ if not isinstance(file_path_prefix, basestring):
+ raise TypeError('file_path_prefix must be a string; got %r instead' %
+ file_path_prefix)
+ if not isinstance(file_name_suffix, basestring):
+ raise TypeError('file_name_suffix must be a string; got %r instead' %
+ file_name_suffix)
+
+ if not CompressionTypes.is_valid_compression_type(compression_type):
+ raise TypeError('compression_type must be CompressionType object but '
+ 'was %s' % type(compression_type))
+
if shard_name_template is None:
shard_name_template = DEFAULT_SHARD_NAME_TEMPLATE
elif shard_name_template is '':
@@ -436,11 +734,8 @@ class FileSink(iobase.Sink):
self.num_shards = num_shards
self.coder = coder
self.shard_name_format = self._template_to_format(shard_name_template)
- if not CompressionTypes.is_valid_compression_type(compression_type):
- raise TypeError('compression_type must be CompressionType object but '
- 'was %s' % type(compression_type))
self.compression_type = compression_type
- self.mime_type = CompressionTypes.mime_type(compression_type, mime_type)
+ self.mime_type = mime_type
def open(self, temp_path):
"""Opens ``temp_path``, returning an opaque file handle object.
@@ -448,12 +743,11 @@ class FileSink(iobase.Sink):
The returned file handle is passed to ``write_[encoded_]record`` and
``close``.
"""
- raw_file = ChannelFactory.open(temp_path, 'wb', self.mime_type)
- if self.compression_type == CompressionTypes.NO_COMPRESSION:
- return raw_file
- else:
- return _CompressedFile(fileobj=raw_file,
- compression_type=self.compression_type)
+ return ChannelFactory.open(
+ temp_path,
+ 'wb',
+ mime_type=self.mime_type,
+ compression_type=self.compression_type)
def write_record(self, file_handle, value):
"""Writes a single record go the file handle returned by ``open()``.
@@ -491,17 +785,16 @@ class FileSink(iobase.Sink):
writer_results = sorted(writer_results)
num_shards = len(writer_results)
channel_factory = ChannelFactory()
- num_threads = max(1, min(
- num_shards / FileSink._WRITE_RESULTS_PER_RENAME_THREAD,
- FileSink._MAX_RENAME_THREADS))
+ min_threads = min(num_shards / FileSink._WRITE_RESULTS_PER_RENAME_THREAD,
+ FileSink._MAX_RENAME_THREADS)
+ num_threads = max(1, min_threads)
rename_ops = []
for shard_num, shard in enumerate(writer_results):
final_name = ''.join([
- self.file_path_prefix,
- self.shard_name_format % dict(shard_num=shard_num,
- num_shards=num_shards),
- self.file_name_suffix])
+ self.file_path_prefix, self.shard_name_format % dict(
+ shard_num=shard_num, num_shards=num_shards), self.file_name_suffix
+ ])
rename_ops.append((shard, final_name))
logging.info(
@@ -537,7 +830,7 @@ class FileSink(iobase.Sink):
# ThreadPool crashes in old versions of Python (< 2.7.5) if created from a
# child thread. (http://bugs.python.org/issue10015)
- if not hasattr(threading.current_thread(), "_children"):
+ if not hasattr(threading.current_thread(), '_children'):
threading.current_thread()._children = weakref.WeakKeyDictionary()
rename_results = ThreadPool(num_threads).map(_rename_file, rename_ops)
@@ -548,8 +841,8 @@ class FileSink(iobase.Sink):
else:
yield final_name
- logging.info('Renamed %d shards in %.2f seconds.',
- num_shards, time.time() - start_time)
+ logging.info('Renamed %d shards in %.2f seconds.', num_shards,
+ time.time() - start_time)
try:
channel_factory.rmdir(init_result)
@@ -563,8 +856,8 @@ class FileSink(iobase.Sink):
return ''
m = re.search('S+', shard_name_template)
if m is None:
- raise ValueError("Shard number pattern S+ not found in template '%s'"
- % shard_name_template)
+ raise ValueError("Shard number pattern S+ not found in template '%s'" %
+ shard_name_template)
shard_name_format = shard_name_template.replace(
m.group(0), '%%(shard_num)0%dd' % len(m.group(0)))
m = re.search('N+', shard_name_format)
@@ -574,7 +867,7 @@ class FileSink(iobase.Sink):
return shard_name_format
def __eq__(self, other):
- # TODO(robertwb): Clean up workitem_test which uses this.
+ # TODO: Clean up workitem_test which uses this.
# pylint: disable=unidiomatic-typecheck
return type(self) == type(other) and self.__dict__ == other.__dict__
@@ -606,8 +899,7 @@ class TextFileSink(FileSink):
num_shards=0,
shard_name_template=None,
coder=coders.ToStringCoder(),
- compression_type=CompressionTypes.NO_COMPRESSION,
- ):
+ compression_type=CompressionTypes.AUTO):
"""Initialize a TextFileSink.
Args:
@@ -633,34 +925,23 @@ class TextFileSink(FileSink):
case it behaves as if num_shards was set to 1 and only one file will be
generated. The default pattern used is '-SSSSS-of-NNNNN'.
coder: Coder used to encode each line.
- compression_type: Type of compression to use for this sink.
-
- Raises:
- TypeError: if file path parameters are not a string or if compression_type
- is not member of CompressionTypes.
- ValueError: if shard_name_template is not of expected format.
+ compression_type: Used to handle compressed output files. Typical value
+ is CompressionTypes.AUTO, in which case the final file path's
+ extension (as determined by file_path_prefix, file_name_suffix,
+ num_shards and shard_name_template) will be used to detect the
+ compression.
Returns:
A TextFileSink object usable for writing.
"""
- if not isinstance(file_path_prefix, basestring):
- raise TypeError(
- 'TextFileSink: file_path_prefix must be a string; got %r instead' %
- file_path_prefix)
- if not isinstance(file_name_suffix, basestring):
- raise TypeError(
- 'TextFileSink: file_name_suffix must be a string; got %r instead' %
- file_name_suffix)
-
- super(TextFileSink, self).__init__(file_path_prefix,
- file_name_suffix=file_name_suffix,
- num_shards=num_shards,
- shard_name_template=shard_name_template,
- coder=coder,
- mime_type='text/plain',
- compression_type=compression_type)
-
- self.compression_type = compression_type
+ super(TextFileSink, self).__init__(
+ file_path_prefix,
+ file_name_suffix=file_name_suffix,
+ num_shards=num_shards,
+ shard_name_template=shard_name_template,
+ coder=coder,
+ mime_type='text/plain',
+ compression_type=compression_type)
self.append_trailing_newlines = append_trailing_newlines
def write_encoded_record(self, file_handle, encoded_value):
@@ -669,173 +950,168 @@ class TextFileSink(FileSink):
if self.append_trailing_newlines:
file_handle.write('\n')
- def close(self, file_handle):
- """Finalize and close the file handle returned from ``open()``.
-
- Args:
- file_handle: file handle to be closed.
- Raises:
- ValueError: if file_handle is already closed.
- """
- if file_handle is not None:
- file_handle.close()
+class NativeFileSink(iobase.NativeSink):
+ """A sink implemented by Dataflow service to a GCS or local file or files.
-class NativeTextFileSink(iobase.NativeSink):
- """A sink to a GCS or local text file or files."""
+ This class is only to be inherited by sinks natively implemented by the
+ Cloud Dataflow service, and hence should not be sub-classed by users.
+ """
- def __init__(self, file_path_prefix,
- append_trailing_newlines=True,
+ def __init__(self,
+ file_path_prefix,
file_name_suffix='',
num_shards=0,
shard_name_template=None,
validate=True,
- coder=coders.ToStringCoder()):
+ coder=coders.BytesCoder(),
+ mime_type='application/octet-stream',
+ compression_type=CompressionTypes.AUTO):
+ if not CompressionTypes.is_valid_compression_type(compression_type):
+ raise TypeError('compression_type must be CompressionType object but '
+ 'was %s' % type(compression_type))
+
# We initialize a file_path attribute containing just the prefix part for
# local runner environment. For now, sharding is not supported in the local
# runner and sharding options (template, num, suffix) are ignored.
# The attribute is also used in the worker environment when we just write
# to a specific file.
- # TODO(silviuc): Add support for file sharding in the local runner.
+ # TODO: Add support for file sharding in the local runner.
self.file_path = file_path_prefix
- self.append_trailing_newlines = append_trailing_newlines
self.coder = coder
-
- self.is_gcs_sink = self.file_path.startswith('gs://')
-
self.file_name_prefix = file_path_prefix
self.file_name_suffix = file_name_suffix
self.num_shards = num_shards
- # TODO(silviuc): Update this when the service supports more patterns.
+ # TODO: Update this when the service supports more patterns.
self.shard_name_template = ('-SSSSS-of-NNNNN' if shard_name_template is None
else shard_name_template)
- # TODO(silviuc): Implement sink validation.
+ # TODO: Implement sink validation.
self.validate = validate
-
- @property
- def format(self):
- """Sink format name required for remote execution."""
- return 'text'
+ self.mime_type = mime_type
+ self.compression_type = compression_type
@property
def path(self):
return self.file_path
def writer(self):
- return TextFileWriter(self)
+ return NativeFileSinkWriter(self)
def __eq__(self, other):
- return (self.file_path == other.file_path and
- self.append_trailing_newlines == other.append_trailing_newlines and
- self.coder == other.coder and
+ return (self.file_path == other.file_path and self.coder == other.coder and
self.file_name_prefix == other.file_name_prefix and
self.file_name_suffix == other.file_name_suffix and
self.num_shards == other.num_shards and
self.shard_name_template == other.shard_name_template and
- self.validate == other.validate)
+ self.validate == other.validate and
+ self.mime_type == other.mime_type and
+ self.compression_type == other.compression_type)
+
+
+class NativeFileSinkWriter(iobase.NativeSinkWriter):
+ """The sink writer for a NativeFileSink.
+ This class is only to be inherited by sink writers natively implemented
+ by the Cloud Dataflow service, and hence should not be sub-classed by users.
+ """
+
+ def __init__(self, sink):
+ self.sink = sink
+
+ def __enter__(self):
+ self.file = ChannelFactory.open(
+ self.sink.file_path,
+ 'wb',
+ mime_type=self.sink.mime_type,
+ compression_type=self.sink.compression_type)
+
+ if (ChannelFactory.is_compressed(self.file) and
+ not isinstance(self.sink, NativeTextFileSink)):
+ raise ValueError(
+ 'Unexpected compressed file for a non-NativeTextFileSink.')
+
+ return self
+
+ def __exit__(self, exception_type, exception_value, traceback):
+ self.file.close()
+
+ def Write(self, value):
+ self.file.write(self.sink.coder.encode(value))
+
+
+class NativeTextFileSink(NativeFileSink):
+ """A sink to a GCS or local text file or files."""
+
+ def __init__(self,
+ file_path_prefix,
+ append_trailing_newlines=True,
+ file_name_suffix='',
+ num_shards=0,
+ shard_name_template=None,
+ validate=True,
+ coder=coders.ToStringCoder(),
+ mime_type='text/plain',
+ compression_type=CompressionTypes.AUTO):
+ super(NativeTextFileSink, self).__init__(
+ file_path_prefix,
+ file_name_suffix=file_name_suffix,
+ num_shards=num_shards,
+ shard_name_template=shard_name_template,
+ validate=validate,
+ coder=coder,
+ mime_type=mime_type,
+ compression_type=compression_type)
+ self.append_trailing_newlines = append_trailing_newlines
+
+ @property
+ def format(self):
+ """Sink format name required for remote execution."""
+ return 'text'
+
+ def writer(self):
+ return TextFileWriter(self)
+
+ def __eq__(self, other):
+ return (super(NativeTextFileSink, self).__eq__(other) and
+ self.append_trailing_newlines == other.append_trailing_newlines)
# -----------------------------------------------------------------------------
# TextFileReader, TextMultiFileReader.
-class TextFileReader(iobase.NativeSourceReader,
- coders.observable.ObservableMixin):
+class TextFileReader(NativeFileSourceReader):
"""A reader for a text file source."""
- def __init__(self, source):
- super(TextFileReader, self).__init__()
- self.source = source
- self.start_offset = self.source.start_offset or 0
- self.end_offset = self.source.end_offset
- self.current_offset = self.start_offset
+ def seek_to_true_start_offset(self):
+ if ChannelFactory.is_compressed(self.file):
+ # When compression is enabled both initial and dynamic splitting should be
+ # prevented. Here we don't perform any seeking to a different offset, nor
+ # do we update the current_offset so that the rest of the framework can
+ # properly deal with compressed files.
+ return
- def __enter__(self):
- if self.source.is_gcs_source:
- # pylint: disable=wrong-import-order, wrong-import-position
- from apache_beam.io import gcsio
- self._file = gcsio.GcsIO().open(self.source.file_path, 'rb')
- else:
- self._file = open(self.source.file_path, 'rb')
- # Determine the real end_offset.
- # If not specified it will be the length of the file.
- if self.end_offset is None:
- self._file.seek(0, os.SEEK_END)
- self.end_offset = self._file.tell()
-
- if self.start_offset is None:
- self.start_offset = 0
- self.current_offset = self.start_offset
if self.start_offset > 0:
# Read one byte before. This operation will either consume a previous
# newline if start_offset was at the beginning of a line or consume the
- # line if we were in the middle of it. Either way we get the read position
- # exactly where we wanted: at the begining of the first full line.
- self._file.seek(self.start_offset - 1)
+ # line if we were in the middle of it. Either way we get the read
+ # position exactly where we wanted: at the beginning of the first full
+ # line.
+ self.file.seek(self.start_offset - 1)
self.current_offset -= 1
- line = self._file.readline()
+ line = self.file.readline()
self.notify_observers(line, is_encoded=True)
self.current_offset += len(line)
- else:
- self._file.seek(self.start_offset)
- # Initializing range tracker after start and end offsets are finalized.
- self.range_tracker = range_trackers.OffsetRangeTracker(self.start_offset,
- self.end_offset)
-
- return self
+ def read_records(self):
+ line = self.file.readline()
+ delta_offset = len(line)
- def __exit__(self, exception_type, exception_value, traceback):
- self._file.close()
-
- def __iter__(self):
- while True:
- if not self.range_tracker.try_claim(
- record_start=self.current_offset):
- # Reader has completed reading the set of records in its range. Note
- # that the end offset of the range may be smaller than the original
- # end offset defined when creating the reader due to reader accepting
- # a dynamic split request from the service.
- return
- line = self._file.readline()
- self.notify_observers(line, is_encoded=True)
- self.current_offset += len(line)
+ if delta_offset == 0:
+ yield True, None, None # Reached EOF.
+ else:
if self.source.strip_trailing_newlines:
line = line.rstrip('\n')
- yield self.source.coder.decode(line)
-
- def get_progress(self):
- return iobase.ReaderProgress(position=iobase.ReaderPosition(
- byte_offset=self.range_tracker.last_record_start))
-
- def request_dynamic_split(self, dynamic_split_request):
- assert dynamic_split_request is not None
- progress = dynamic_split_request.progress
- split_position = progress.position
- if split_position is None:
- percent_complete = progress.percent_complete
- if percent_complete is not None:
- if percent_complete <= 0 or percent_complete >= 1:
- logging.warning(
- 'FileBasedReader cannot be split since the provided percentage '
- 'of work to be completed is out of the valid range (0, '
- '1). Requested: %r',
- dynamic_split_request)
- return
- split_position = iobase.ReaderPosition()
- split_position.byte_offset = (
- self.range_tracker.position_at_fraction(percent_complete))
- else:
- logging.warning(
- 'TextReader requires either a position or a percentage of work to '
- 'be complete to perform a dynamic split request. Requested: %r',
- dynamic_split_request)
- return
-
- if self.range_tracker.try_split(split_position.byte_offset):
- return iobase.DynamicSplitResultWithPosition(split_position)
- else:
- return
+ yield False, self.source.coder.decode(line), delta_offset
class TextMultiFileReader(iobase.NativeSourceReader):
@@ -845,8 +1121,7 @@ class TextMultiFileReader(iobase.NativeSourceReader):
self.source = source
self.file_paths = ChannelFactory.glob(self.source.file_path)
if not self.file_paths:
- raise RuntimeError(
- 'No files found for path: %s' % self.source.file_path)
+ raise RuntimeError('No files found for path: %s' % self.source.file_path)
def __enter__(self):
return self
@@ -860,35 +1135,20 @@ class TextMultiFileReader(iobase.NativeSourceReader):
index += 1
logging.info('Reading from %s (%d/%d)', path, index, len(self.file_paths))
with TextFileSource(
- path, strip_trailing_newlines=self.source.strip_trailing_newlines,
+ path,
+ strip_trailing_newlines=self.source.strip_trailing_newlines,
coder=self.source.coder).reader() as reader:
for line in reader:
yield line
-
# -----------------------------------------------------------------------------
# TextFileWriter.
-class TextFileWriter(iobase.NativeSinkWriter):
+class TextFileWriter(NativeFileSinkWriter):
"""The sink writer for a TextFileSink."""
- def __init__(self, sink):
- self.sink = sink
-
- def __enter__(self):
- if self.sink.is_gcs_sink:
- # pylint: disable=wrong-import-order, wrong-import-position
- from apache_beam.io import gcsio
- self._file = gcsio.GcsIO().open(self.sink.file_path, 'wb')
- else:
- self._file = open(self.sink.file_path, 'wb')
- return self
-
- def __exit__(self, exception_type, exception_value, traceback):
- self._file.close()
-
- def Write(self, line):
- self._file.write(self.sink.coder.encode(line))
+ def Write(self, value):
+ super(TextFileWriter, self).Write(value)
if self.sink.append_trailing_newlines:
- self._file.write('\n')
+ self.file.write('\n')
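
Taken together, the fileio.py changes above funnel every open through ChannelFactory.open, which resolves CompressionTypes.AUTO from the path suffix ('.gz' for GZIP, '.z' for ZLIB) and wraps the raw file in _CompressedFile when needed. A minimal sketch of the resulting round trip on a local file, assuming a POSIX filesystem and a hypothetical path (illustrative only, not part of the commit):

    from apache_beam.io import fileio

    path = '/tmp/demo.gz'  # hypothetical; '.gz' makes AUTO resolve to GZIP
    assert (fileio.CompressionTypes.detect_compression_type(path) ==
            fileio.CompressionTypes.GZIP)

    f = fileio.ChannelFactory.open(path, 'w')  # returns a _CompressedFile
    f.write('hello\nworld\n')
    f.close()  # flushes the compressor before closing

    f = fileio.ChannelFactory.open(path, 'r')
    assert f.readline() == 'hello\n'  # decompresses transparently
    assert f.readline() == 'world\n'
    f.close()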
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6a483c18/sdks/python/apache_beam/io/fileio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/fileio_test.py b/sdks/python/apache_beam/io/fileio_test.py
index e71ba6d..6c6fe12 100644
--- a/sdks/python/apache_beam/io/fileio_test.py
+++ b/sdks/python/apache_beam/io/fileio_test.py
@@ -31,31 +31,44 @@ from apache_beam import coders
from apache_beam.io import fileio
from apache_beam.io import iobase
+# TODO: Add tests for file patterns (i.e. not just individual files) for both
+# uncompressed and compressed files.
+
+# TODO: Update code to not use NamedTemporaryFile (or to use it in a way that
+# doesn't violate its assumptions).
+
class TestTextFileSource(unittest.TestCase):
- def create_temp_file(self, text):
- temp = tempfile.NamedTemporaryFile(delete=False)
+ def create_temp_file(self, text, suffix=''):
+ temp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
with temp.file as tmp:
tmp.write(text)
return temp.name
- def read_with_offsets(self, input_lines, output_lines,
- start_offset=None, end_offset=None):
+ def read_with_offsets(self,
+ input_lines,
+ output_lines,
+ start_offset=None,
+ end_offset=None):
source = fileio.TextFileSource(
file_path=self.create_temp_file('\n'.join(input_lines)),
- start_offset=start_offset, end_offset=end_offset)
+ start_offset=start_offset,
+ end_offset=end_offset)
read_lines = []
with source.reader() as reader:
for line in reader:
read_lines.append(line)
self.assertEqual(read_lines, output_lines)
- def progress_with_offsets(self, input_lines,
- start_offset=None, end_offset=None):
+ def progress_with_offsets(self,
+ input_lines,
+ start_offset=None,
+ end_offset=None):
source = fileio.TextFileSource(
file_path=self.create_temp_file('\n'.join(input_lines)),
- start_offset=start_offset, end_offset=end_offset)
+ start_offset=start_offset,
+ end_offset=end_offset)
progress_record = []
with source.reader() as reader:
self.assertEqual(reader.get_progress().position.byte_offset, -1)
@@ -78,6 +91,170 @@ class TestTextFileSource(unittest.TestCase):
read_lines.append(line)
self.assertEqual(read_lines, lines)
+ def test_read_entire_file_empty(self):
+ source = fileio.TextFileSource(file_path=self.create_temp_file(''))
+ read_lines = []
+ with source.reader() as reader:
+ for line in reader:
+ read_lines.append(line)
+ self.assertEqual(read_lines, [])
+
+ def test_read_entire_file_gzip(self):
+ lines = ['First', 'Second', 'Third']
+ compressor = zlib.compressobj(-1, zlib.DEFLATED, zlib.MAX_WBITS | 16)
+ data = compressor.compress('\n'.join(lines)) + compressor.flush()
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file(data),
+ compression_type=fileio.CompressionTypes.GZIP)
+ read_lines = []
+ with source.reader() as reader:
+ for line in reader:
+ read_lines.append(line)
+ self.assertEqual(read_lines, lines)
+
+ def test_read_entire_file_gzip_auto(self):
+ lines = ['First', 'Second', 'Third']
+ compressor = zlib.compressobj(-1, zlib.DEFLATED, zlib.MAX_WBITS | 16)
+ data = compressor.compress('\n'.join(lines)) + compressor.flush()
+ source = fileio.TextFileSource(file_path=self.create_temp_file(
+ data, suffix='.gz'))
+ read_lines = []
+ with source.reader() as reader:
+ for line in reader:
+ read_lines.append(line)
+ self.assertEqual(read_lines, lines)
+
+ def test_read_entire_file_gzip_empty(self):
+ compressor = zlib.compressobj(-1, zlib.DEFLATED, zlib.MAX_WBITS | 16)
+ data = compressor.compress('') + compressor.flush()
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file(data),
+ compression_type=fileio.CompressionTypes.GZIP)
+ read_lines = []
+ with source.reader() as reader:
+ for line in reader:
+ read_lines.append(line)
+ self.assertEqual(read_lines, [])
+
+ def test_read_entire_file_gzip_large(self):
+ lines = ['Line %d' % d for d in range(10 * 1000)]
+ compressor = zlib.compressobj(-1, zlib.DEFLATED, zlib.MAX_WBITS | 16)
+ data = compressor.compress('\n'.join(lines)) + compressor.flush()
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file(data),
+ compression_type=fileio.CompressionTypes.GZIP)
+ read_lines = []
+ with source.reader() as reader:
+ for line in reader:
+ read_lines.append(line)
+ self.assertEqual(read_lines, lines)
+
+ def test_read_entire_file_zlib(self):
+ lines = ['First', 'Second', 'Third']
+ compressor = zlib.compressobj(-1, zlib.DEFLATED, zlib.MAX_WBITS)
+ data = compressor.compress('\n'.join(lines)) + compressor.flush()
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file(data),
+ compression_type=fileio.CompressionTypes.ZLIB)
+ read_lines = []
+ with source.reader() as reader:
+ for line in reader:
+ read_lines.append(line)
+ self.assertEqual(read_lines, lines)
+
+ def test_read_entire_file_zlib_auto(self):
+ lines = ['First', 'Second', 'Third']
+ compressor = zlib.compressobj(-1, zlib.DEFLATED, zlib.MAX_WBITS)
+ data = compressor.compress('\n'.join(lines)) + compressor.flush()
+ source = fileio.TextFileSource(file_path=self.create_temp_file(
+ data, suffix='.Z'))
+ read_lines = []
+ with source.reader() as reader:
+ for line in reader:
+ read_lines.append(line)
+ self.assertEqual(read_lines, lines)
+
+ def test_read_entire_file_zlib_empty(self):
+ compressor = zlib.compressobj(-1, zlib.DEFLATED, zlib.MAX_WBITS)
+ data = compressor.compress('') + compressor.flush()
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file(data),
+ compression_type=fileio.CompressionTypes.ZLIB)
+ read_lines = []
+ with source.reader() as reader:
+ for line in reader:
+ read_lines.append(line)
+ self.assertEqual(read_lines, [])
+
+ def test_read_entire_file_zlib_large(self):
+ lines = ['Line %d' % d for d in range(10 * 1000)]
+ compressor = zlib.compressobj(-1, zlib.DEFLATED, zlib.MAX_WBITS)
+ data = compressor.compress('\n'.join(lines)) + compressor.flush()
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file(data),
+ compression_type=fileio.CompressionTypes.ZLIB)
+ read_lines = []
+ with source.reader() as reader:
+ for line in reader:
+ read_lines.append(line)
+ self.assertEqual(read_lines, lines)
+
+ def test_skip_entire_file_gzip(self):
+ lines = ['First', 'Second', 'Third']
+ compressor = zlib.compressobj(-1, zlib.DEFLATED, zlib.MAX_WBITS | 16)
+ data = compressor.compress('\n'.join(lines)) + compressor.flush()
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file(data),
+ start_offset=1, # Anything other than 0 should lead to a null-read.
+ compression_type=fileio.CompressionTypes.GZIP)
+ read_lines = []
+ with source.reader() as reader:
+ for line in reader:
+ read_lines.append(line)
+ self.assertEqual(read_lines, [])
+
+ def test_skip_entire_file_zlib(self):
+ lines = ['First', 'Second', 'Third']
+ compressor = zlib.compressobj(-1, zlib.DEFLATED, zlib.MAX_WBITS)
+ data = compressor.compress('\n'.join(lines)) + compressor.flush()
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file(data),
+ start_offset=1, # Anything other than 0 should lead to a null-read.
+ compression_type=fileio.CompressionTypes.ZLIB)
+ read_lines = []
+ with source.reader() as reader:
+ for line in reader:
+ read_lines.append(line)
+ self.assertEqual(read_lines, [])
+
+ def test_consume_entire_file_gzip(self):
+ lines = ['First', 'Second', 'Third']
+ compressor = zlib.compressobj(-1, zlib.DEFLATED, zlib.MAX_WBITS | 16)
+ data = compressor.compress('\n'.join(lines)) + compressor.flush()
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file(data),
+ end_offset=1, # Any end_offset should effectively be ignored.
+ compression_type=fileio.CompressionTypes.GZIP)
+ read_lines = []
+ with source.reader() as reader:
+ for line in reader:
+ read_lines.append(line)
+ self.assertEqual(read_lines, lines)
+
+ def test_consume_entire_file_zlib(self):
+ lines = ['First', 'Second', 'Third']
+ compressor = zlib.compressobj(-1, zlib.DEFLATED, zlib.MAX_WBITS)
+ data = compressor.compress('\n'.join(lines)) + compressor.flush()
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file(data),
+ end_offset=1, # Any end_offset should effectively be ignored.
+ compression_type=fileio.CompressionTypes.ZLIB)
+ read_lines = []
+ with source.reader() as reader:
+ for line in reader:
+ read_lines.append(line)
+ self.assertEqual(read_lines, lines)
+
def test_progress_entire_file(self):
lines = ['First', 'Second', 'Third']
source = fileio.TextFileSource(
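The compression fixtures above are built directly with zlib, following the
wbits convention: zlib.MAX_WBITS | 16 produces a gzip container,
zlib.MAX_WBITS a zlib container, and a negative value a raw deflate stream.
A standard-library round trip illustrating the convention (a sketch, not
part of the diff):

    import zlib

    payload = '\n'.join(['First', 'Second', 'Third'])
    compressor = zlib.compressobj(-1, zlib.DEFLATED, zlib.MAX_WBITS | 16)
    data = compressor.compress(payload) + compressor.flush()
    # zlib.MAX_WBITS | 32 lets decompress() auto-detect gzip/zlib headers.
    assert zlib.decompress(data, zlib.MAX_WBITS | 32) == payload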
@@ -93,6 +270,44 @@ class TestTextFileSource(unittest.TestCase):
self.assertEqual(len(progress_record), 3)
self.assertEqual(progress_record, [0, 6, 13])
+ def test_progress_entire_file_gzip(self):
+ lines = ['First', 'Second', 'Third']
+ compressor = zlib.compressobj(-1, zlib.DEFLATED, zlib.MAX_WBITS | 16)
+ data = compressor.compress('\n'.join(lines)) + compressor.flush()
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file(data),
+ compression_type=fileio.CompressionTypes.GZIP)
+ progress_record = []
+ with source.reader() as reader:
+ self.assertEqual(-1, reader.get_progress().position.byte_offset)
+ for line in reader:
+ self.assertIsNotNone(line)
+ progress_record.append(reader.get_progress().position.byte_offset)
+ self.assertEqual(18, # Reading the entire contents before we decide EOF.
+ reader.get_progress().position.byte_offset)
+
+ self.assertEqual(len(progress_record), 3)
+ self.assertEqual(progress_record, [0, 6, 13])
+
+ def test_progress_entire_file_zlib(self):
+ lines = ['First', 'Second', 'Third']
+ compressor = zlib.compressobj(-1, zlib.DEFLATED, zlib.MAX_WBITS)
+ data = compressor.compress('\n'.join(lines)) + compressor.flush()
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file(data),
+ compression_type=fileio.CompressionTypes.ZLIB)
+ progress_record = []
+ with source.reader() as reader:
+ self.assertEqual(-1, reader.get_progress().position.byte_offset)
+ for line in reader:
+ self.assertIsNotNone(line)
+ progress_record.append(reader.get_progress().position.byte_offset)
+ self.assertEqual(18, # Reading the entire contents before we decide EOF.
+ reader.get_progress().position.byte_offset)
+
+ self.assertEqual(len(progress_record), 3)
+ self.assertEqual(progress_record, [0, 6, 13])
+
def try_splitting_reader_at(self, reader, split_request, expected_response):
actual_response = reader.request_dynamic_split(split_request)
@@ -108,6 +323,66 @@ class TestTextFileSource(unittest.TestCase):
return actual_response
+ def test_gzip_file_unsplittable(self):
+ lines = ['aaaa', 'bbbb', 'cccc', 'dddd', 'eeee']
+ compressor = zlib.compressobj(-1, zlib.DEFLATED, zlib.MAX_WBITS | 16)
+ data = compressor.compress('\n'.join(lines)) + compressor.flush()
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file(data),
+ compression_type=fileio.CompressionTypes.GZIP)
+
+ with source.reader() as reader:
+ percents_complete = [x / 100.0 for x in range(101)]
+
+ # Cursor at beginning of file.
+ for percent_complete in percents_complete:
+ self.try_splitting_reader_at(
+ reader,
+ iobase.DynamicSplitRequest(
+ iobase.ReaderProgress(percent_complete=percent_complete)),
+ None)
+
+ # Cursor past the beginning of the file.
+ reader_iter = iter(reader)
+ next(reader_iter)
+ next(reader_iter)
+ for percent_complete in percents_complete:
+ self.try_splitting_reader_at(
+ reader,
+ iobase.DynamicSplitRequest(
+ iobase.ReaderProgress(percent_complete=percent_complete)),
+ None)
+
+ def test_zlib_file_unsplittable(self):
+ lines = ['aaaa', 'bbbb', 'cccc', 'dddd', 'eeee']
+ compressor = zlib.compressobj(-1, zlib.DEFLATED, zlib.MAX_WBITS)
+ data = compressor.compress('\n'.join(lines)) + compressor.flush()
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file(data),
+ compression_type=fileio.CompressionTypes.ZLIB)
+
+ with source.reader() as reader:
+ percents_complete = [x / 100.0 for x in range(101)]
+
+ # Cursor at beginning of file.
+ for percent_complete in percents_complete:
+ self.try_splitting_reader_at(
+ reader,
+ iobase.DynamicSplitRequest(
+ iobase.ReaderProgress(percent_complete=percent_complete)),
+ None)
+
+ # Cursor past the beginning of the file.
+ reader_iter = iter(reader)
+ next(reader_iter)
+ next(reader_iter)
+ for percent_complete in percents_complete:
+ self.try_splitting_reader_at(
+ reader,
+ iobase.DynamicSplitRequest(
+ iobase.ReaderProgress(percent_complete=percent_complete)),
+ None)
+
def test_update_stop_position_for_percent_complete(self):
lines = ['aaaa', 'bbbb', 'cccc', 'dddd', 'eeee']
source = fileio.TextFileSource(
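The *_unsplittable tests above pin down the intended contract: byte offsets
in a compressed stream do not map to record boundaries, so a reader over a
compressed source must decline every dynamic split proposal. Distilled into
a sketch using the names from this diff:

    with source.reader() as reader:
        split_result = reader.request_dynamic_split(
            iobase.DynamicSplitRequest(
                iobase.ReaderProgress(percent_complete=0.5)))
        assert split_result is None  # compressed sources never split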
@@ -132,23 +407,23 @@ class TestTextFileSource(unittest.TestCase):
# Splitting at positions on or before start offset of the last record
self.try_splitting_reader_at(
reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete=
- 0.2)),
+ iobase.DynamicSplitRequest(
+ iobase.ReaderProgress(percent_complete=0.2)),
None)
self.try_splitting_reader_at(
reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete=
- 0.4)),
+ iobase.DynamicSplitRequest(
+ iobase.ReaderProgress(percent_complete=0.4)),
None)
# Splitting at a position after the start offset of the last record should
# be successful
self.try_splitting_reader_at(
reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete=
- 0.6)),
- iobase.DynamicSplitResultWithPosition(iobase.ReaderPosition(
- byte_offset=15)))
+ iobase.DynamicSplitRequest(
+ iobase.ReaderProgress(percent_complete=0.6)),
+ iobase.DynamicSplitResultWithPosition(
+ iobase.ReaderPosition(byte_offset=15)))
def test_update_stop_position_percent_complete_for_position(self):
lines = ['aaaa', 'bbbb', 'cccc', 'dddd', 'eeee']
@@ -164,35 +439,40 @@ class TestTextFileSource(unittest.TestCase):
# Splitting at end of the range should be unsuccessful
self.try_splitting_reader_at(
reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(
- position=iobase.ReaderPosition(byte_offset=0))),
+ iobase.DynamicSplitRequest(
+ iobase.ReaderProgress(position=iobase.ReaderPosition(
+ byte_offset=0))),
None)
self.try_splitting_reader_at(
reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(
- position=iobase.ReaderPosition(byte_offset=25))),
+ iobase.DynamicSplitRequest(
+ iobase.ReaderProgress(position=iobase.ReaderPosition(
+ byte_offset=25))),
None)
# Splitting at positions on or before start offset of the last record
self.try_splitting_reader_at(
reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(
- position=iobase.ReaderPosition(byte_offset=5))),
+ iobase.DynamicSplitRequest(
+ iobase.ReaderProgress(position=iobase.ReaderPosition(
+ byte_offset=5))),
None)
self.try_splitting_reader_at(
reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(
- position=iobase.ReaderPosition(byte_offset=10))),
+ iobase.DynamicSplitRequest(
+ iobase.ReaderProgress(position=iobase.ReaderPosition(
+ byte_offset=10))),
None)
# Splitting at a position after the start offset of the last record should
# be successful
self.try_splitting_reader_at(
reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(
- position=iobase.ReaderPosition(byte_offset=15))),
- iobase.DynamicSplitResultWithPosition(iobase.ReaderPosition(
- byte_offset=15)))
+ iobase.DynamicSplitRequest(
+ iobase.ReaderProgress(position=iobase.ReaderPosition(
+ byte_offset=15))),
+ iobase.DynamicSplitResultWithPosition(
+ iobase.ReaderPosition(byte_offset=15)))
def run_update_stop_position_exhaustive(self, lines, newline):
"""An exhaustive test for dynamic splitting.
@@ -240,13 +520,12 @@ class TestTextFileSource(unittest.TestCase):
['aaaa', 'bbbb', 'cccc', 'dddd', 'eeee'], '\r\n')
def test_update_stop_position_exhaustive_multi_byte(self):
- self.run_update_stop_position_exhaustive(
- [u'\u0d85\u0d85\u0d85\u0d85'.encode('utf-8'), u'\u0db6\u0db6\u0db6\u0db6'.encode('utf-8'),
- u'\u0d9a\u0d9a\u0d9a\u0d9a'.encode('utf-8')], '\n')
+    self.run_update_stop_position_exhaustive(
+        [u'\u0d85\u0d85\u0d85\u0d85'.encode('utf-8'),
+         u'\u0db6\u0db6\u0db6\u0db6'.encode('utf-8'),
+         u'\u0d9a\u0d9a\u0d9a\u0d9a'.encode('utf-8')], '\n')
def run_update_stop_position(self, start_offset, end_offset, stop_offset,
- records_to_read,
- file_path):
+ records_to_read, file_path):
source = fileio.TextFileSource(file_path, start_offset, end_offset)
records_of_first_split = ''
@@ -296,9 +575,7 @@ class TestTextFileSource(unittest.TestCase):
records_of_original += line
new_source = fileio.TextFileSource(
- file_path,
- split_response.stop_position.byte_offset,
- end_offset)
+ file_path, split_response.stop_position.byte_offset, end_offset)
with new_source.reader() as reader:
for line in reader:
records_of_second_split += line
@@ -331,21 +608,81 @@ class TestTextFileSource(unittest.TestCase):
self.progress_with_offsets(lines, start_offset=20, end_offset=20)
-class NativeTestTextFileSink(unittest.TestCase):
+class TestNativeTextFileSink(unittest.TestCase):
- def create_temp_file(self):
- temp = tempfile.NamedTemporaryFile(delete=False)
- return temp.name
+ def setUp(self):
+ self.lines = ['Line %d' % d for d in range(100)]
+ self.path = tempfile.NamedTemporaryFile().name
- def test_write_entire_file(self):
- lines = ['First', 'Second', 'Third']
- file_path = self.create_temp_file()
- sink = fileio.NativeTextFileSink(file_path)
+ def _write_lines(self, sink, lines):
with sink.writer() as writer:
for line in lines:
writer.Write(line)
- with open(file_path, 'r') as f:
- self.assertEqual(f.read().splitlines(), lines)
+
+ def test_write_text_file(self):
+ sink = fileio.NativeTextFileSink(self.path)
+ self._write_lines(sink, self.lines)
+
+ with open(self.path, 'r') as f:
+ self.assertEqual(f.read().splitlines(), self.lines)
+
+ def test_write_text_file_empty(self):
+ sink = fileio.NativeTextFileSink(self.path)
+ self._write_lines(sink, [])
+
+ with open(self.path, 'r') as f:
+ self.assertEqual(f.read().splitlines(), [])
+
+ def test_write_text_gzip_file(self):
+ sink = fileio.NativeTextFileSink(
+ self.path, compression_type=fileio.CompressionTypes.GZIP)
+ self._write_lines(sink, self.lines)
+
+ with gzip.GzipFile(self.path, 'r') as f:
+ self.assertEqual(f.read().splitlines(), self.lines)
+
+ def test_write_text_gzip_file_auto(self):
+ self.path = tempfile.NamedTemporaryFile(suffix='.gz').name
+ sink = fileio.NativeTextFileSink(self.path)
+ self._write_lines(sink, self.lines)
+
+ with gzip.GzipFile(self.path, 'r') as f:
+ self.assertEqual(f.read().splitlines(), self.lines)
+
+ def test_write_text_gzip_file_empty(self):
+ sink = fileio.NativeTextFileSink(
+ self.path, compression_type=fileio.CompressionTypes.GZIP)
+ self._write_lines(sink, [])
+
+ with gzip.GzipFile(self.path, 'r') as f:
+ self.assertEqual(f.read().splitlines(), [])
+
+ def test_write_text_zlib_file(self):
+ sink = fileio.NativeTextFileSink(
+ self.path, compression_type=fileio.CompressionTypes.ZLIB)
+ self._write_lines(sink, self.lines)
+
+ with open(self.path, 'r') as f:
+ self.assertEqual(
+ zlib.decompress(f.read(), zlib.MAX_WBITS).splitlines(), self.lines)
+
+ def test_write_text_zlib_file_auto(self):
+ self.path = tempfile.NamedTemporaryFile(suffix='.Z').name
+ sink = fileio.NativeTextFileSink(self.path)
+ self._write_lines(sink, self.lines)
+
+ with open(self.path, 'r') as f:
+ self.assertEqual(
+ zlib.decompress(f.read(), zlib.MAX_WBITS).splitlines(), self.lines)
+
+ def test_write_text_zlib_file_empty(self):
+ sink = fileio.NativeTextFileSink(
+ self.path, compression_type=fileio.CompressionTypes.ZLIB)
+ self._write_lines(sink, [])
+
+ with open(self.path, 'r') as f:
+ self.assertEqual(
+ zlib.decompress(f.read(), zlib.MAX_WBITS).splitlines(), [])
class TestTextFileSink(unittest.TestCase):
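The *_auto variants above rely on suffix-driven detection: when
compression_type is left at its default, the sink infers the codec from the
file extension. A sketch of the behavior the tests assert (the default enum
value itself is not shown in this hunk):

    fileio.NativeTextFileSink('out.gz')   # written as gzip
    fileio.NativeTextFileSink('out.Z')    # written as zlib
    fileio.NativeTextFileSink('out.txt')  # written uncompressed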
@@ -367,42 +704,72 @@ class TestTextFileSink(unittest.TestCase):
with open(self.path, 'r') as f:
self.assertEqual(f.read().splitlines(), self.lines)
- def test_write_deflate_file(self):
- sink = fileio.TextFileSink(self.path,
- compression_type=fileio.CompressionTypes.DEFLATE)
- self._write_lines(sink, self.lines)
+ def test_write_text_file_empty(self):
+ sink = fileio.TextFileSink(self.path)
+ self._write_lines(sink, [])
with open(self.path, 'r') as f:
- content = f.read()
- self.assertEqual(
- zlib.decompress(content, -zlib.MAX_WBITS).splitlines(), self.lines)
+ self.assertEqual(f.read().splitlines(), [])
def test_write_gzip_file(self):
- sink = fileio.TextFileSink(self.path,
- compression_type=fileio.CompressionTypes.GZIP)
+ sink = fileio.TextFileSink(
+ self.path, compression_type=fileio.CompressionTypes.GZIP)
self._write_lines(sink, self.lines)
with gzip.GzipFile(self.path, 'r') as f:
self.assertEqual(f.read().splitlines(), self.lines)
+ def test_write_gzip_file_auto(self):
+ self.path = tempfile.NamedTemporaryFile(suffix='.gz').name
+ sink = fileio.TextFileSink(self.path)
+ self._write_lines(sink, self.lines)
+
+ with gzip.GzipFile(self.path, 'r') as f:
+ self.assertEqual(f.read().splitlines(), self.lines)
+
+ def test_write_gzip_file_empty(self):
+ sink = fileio.TextFileSink(
+ self.path, compression_type=fileio.CompressionTypes.GZIP)
+ self._write_lines(sink, [])
+
+ with gzip.GzipFile(self.path, 'r') as f:
+ self.assertEqual(f.read().splitlines(), [])
+
def test_write_zlib_file(self):
- sink = fileio.TextFileSink(self.path,
- compression_type=fileio.CompressionTypes.ZLIB)
+ sink = fileio.TextFileSink(
+ self.path, compression_type=fileio.CompressionTypes.ZLIB)
self._write_lines(sink, self.lines)
with open(self.path, 'r') as f:
content = f.read()
- # Below decompress option should work for both zlib/gzip header
- # auto detection.
self.assertEqual(
- zlib.decompress(content, zlib.MAX_WBITS | 32).splitlines(),
- self.lines)
+ zlib.decompress(content, zlib.MAX_WBITS).splitlines(), self.lines)
+
+ def test_write_zlib_file_auto(self):
+ self.path = tempfile.NamedTemporaryFile(suffix='.Z').name
+ sink = fileio.TextFileSink(self.path)
+ self._write_lines(sink, self.lines)
+
+ with open(self.path, 'r') as f:
+ content = f.read()
+ self.assertEqual(
+ zlib.decompress(content, zlib.MAX_WBITS).splitlines(), self.lines)
+
+ def test_write_zlib_file_empty(self):
+ sink = fileio.TextFileSink(
+ self.path, compression_type=fileio.CompressionTypes.ZLIB)
+ self._write_lines(sink, [])
+
+ with open(self.path, 'r') as f:
+ content = f.read()
+ self.assertEqual(
+ zlib.decompress(content, zlib.MAX_WBITS).splitlines(), [])
class MyFileSink(fileio.FileSink):
def open(self, temp_path):
- # TODO(robertwb): Fix main session pickling.
+ # TODO: Fix main session pickling.
# file_handle = super(MyFileSink, self).open(temp_path)
file_handle = fileio.FileSink.open(self, temp_path)
file_handle.write('[start]')
@@ -415,7 +782,7 @@ class MyFileSink(fileio.FileSink):
def close(self, file_handle):
file_handle.write('[end]')
- # TODO(robertwb): Fix main session pickling.
+ # TODO: Fix main session pickling.
# file_handle = super(MyFileSink, self).close(file_handle)
file_handle = fileio.FileSink.close(self, file_handle)
@@ -424,9 +791,8 @@ class TestFileSink(unittest.TestCase):
def test_file_sink_writing(self):
temp_path = tempfile.NamedTemporaryFile().name
- sink = MyFileSink(temp_path,
- file_name_suffix='.foo',
- coder=coders.ToStringCoder())
+ sink = MyFileSink(
+ temp_path, file_name_suffix='.foo', coder=coders.ToStringCoder())
# Manually invoke the generic Sink API.
init_token = sink.initialize_write()
@@ -442,7 +808,7 @@ class TestFileSink(unittest.TestCase):
writer2.write('z')
res2 = writer2.close()
- res = list(sink.finalize_write(init_token, [res1, res2]))
+ _ = list(sink.finalize_write(init_token, [res1, res2]))
# Retry the finalize operation (as if the first attempt was lost).
res = list(sink.finalize_write(init_token, [res1, res2]))
@@ -458,37 +824,37 @@ class TestFileSink(unittest.TestCase):
def test_empty_write(self):
temp_path = tempfile.NamedTemporaryFile().name
- sink = MyFileSink(temp_path,
- file_name_suffix='.foo',
- coder=coders.ToStringCoder())
+ sink = MyFileSink(
+ temp_path, file_name_suffix='.foo', coder=coders.ToStringCoder())
p = beam.Pipeline('DirectPipelineRunner')
p | beam.Create([]) | beam.io.Write(sink) # pylint: disable=expression-not-assigned
p.run()
- self.assertEqual(open(temp_path + '-00000-of-00001.foo').read(),
- '[start][end]')
+ self.assertEqual(
+ open(temp_path + '-00000-of-00001.foo').read(), '[start][end]')
def test_fixed_shard_write(self):
temp_path = tempfile.NamedTemporaryFile().name
- sink = MyFileSink(temp_path,
- file_name_suffix='.foo',
- num_shards=3,
- shard_name_template='_NN_SSS_',
- coder=coders.ToStringCoder())
+ sink = MyFileSink(
+ temp_path,
+ file_name_suffix='.foo',
+ num_shards=3,
+ shard_name_template='_NN_SSS_',
+ coder=coders.ToStringCoder())
p = beam.Pipeline('DirectPipelineRunner')
p | beam.Create(['a', 'b']) | beam.io.Write(sink) # pylint: disable=expression-not-assigned
p.run()
- concat = ''.join(open(temp_path + '_03_%03d_.foo' % shard_num).read()
- for shard_num in range(3))
+ concat = ''.join(
+ open(temp_path + '_03_%03d_.foo' % shard_num).read()
+ for shard_num in range(3))
self.assertTrue('][a][' in concat, concat)
self.assertTrue('][b][' in concat, concat)
def test_file_sink_multi_shards(self):
temp_path = tempfile.NamedTemporaryFile().name
- sink = MyFileSink(temp_path,
- file_name_suffix='.foo',
- coder=coders.ToStringCoder())
+ sink = MyFileSink(
+ temp_path, file_name_suffix='.foo', coder=coders.ToStringCoder())
# Manually invoke the generic Sink API.
init_token = sink.initialize_write()
@@ -522,9 +888,8 @@ class TestFileSink(unittest.TestCase):
def test_file_sink_io_error(self):
temp_path = tempfile.NamedTemporaryFile().name
- sink = MyFileSink(temp_path,
- file_name_suffix='.foo',
- coder=coders.ToStringCoder())
+ sink = MyFileSink(
+ temp_path, file_name_suffix='.foo', coder=coders.ToStringCoder())
# Manually invoke the generic Sink API.
init_token = sink.initialize_write()
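The TestFileSink cases drive the generic Sink API by hand. The lifecycle
they depend on, in outline (a sketch; the open_writer name and signature
are assumed from how the tests obtain their writers, which this excerpt
elides):

    init_token = sink.initialize_write()          # one-time setup
    writer = sink.open_writer(init_token, '1')    # per-bundle writer
    writer.write('a')                             # once per element
    result = writer.close()                       # per-bundle commit token
    # finalize_write is retried verbatim in test_file_sink_writing, so it
    # must be idempotent; it returns a generator, hence the list().
    list(sink.finalize_write(init_token, [result]))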
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6a483c18/sdks/python/apache_beam/io/gcsio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcsio.py b/sdks/python/apache_beam/io/gcsio.py
index 339fd41..5a83004 100644
--- a/sdks/python/apache_beam/io/gcsio.py
+++ b/sdks/python/apache_beam/io/gcsio.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
"""Google Cloud Storage client.
This library evolved from the Google App Engine GCS client available at
@@ -37,7 +36,6 @@ import apitools.base.py.transfer as transfer
from apache_beam.internal import auth
from apache_beam.utils import retry
-
# Issue a friendlier error message if the storage library is not available.
# TODO(silviuc): Remove this guard when storage is available everywhere.
try:
@@ -48,7 +46,6 @@ except ImportError:
'Google Cloud Storage I/O not supported for this execution environment '
'(could not import storage API client).')
-
DEFAULT_READ_BUFFER_SIZE = 1024 * 1024
WRITE_CHUNK_SIZE = 8 * 1024 * 1024
@@ -93,7 +90,9 @@ class GcsIO(object):
if storage_client is not None:
self.client = storage_client
- def open(self, filename, mode='r',
+ def open(self,
+ filename,
+ mode='r',
read_buffer_size=DEFAULT_READ_BUFFER_SIZE,
mime_type='application/octet-stream'):
"""Open a GCS file path for reading or writing.
@@ -158,8 +157,8 @@ class GcsIO(object):
path: GCS file path pattern in the form gs://<bucket>/<name>.
"""
bucket, object_path = parse_gcs_path(path)
- request = storage.StorageObjectsDeleteRequest(bucket=bucket,
- object=object_path)
+ request = storage.StorageObjectsDeleteRequest(
+ bucket=bucket, object=object_path)
try:
self.client.objects.Delete(request)
except HttpError as http_error:
@@ -179,10 +178,11 @@ class GcsIO(object):
"""
src_bucket, src_path = parse_gcs_path(src)
dest_bucket, dest_path = parse_gcs_path(dest)
- request = storage.StorageObjectsCopyRequest(sourceBucket=src_bucket,
- sourceObject=src_path,
- destinationBucket=dest_bucket,
- destinationObject=dest_path)
+ request = storage.StorageObjectsCopyRequest(
+ sourceBucket=src_bucket,
+ sourceObject=src_path,
+ destinationBucket=dest_bucket,
+ destinationObject=dest_path)
try:
self.client.objects.Copy(request)
except HttpError as http_error:
@@ -232,8 +232,8 @@ class GcsIO(object):
"""
bucket, object_path = parse_gcs_path(path)
try:
- request = storage.StorageObjectsGetRequest(bucket=bucket,
- object=object_path)
+ request = storage.StorageObjectsGetRequest(
+ bucket=bucket, object=object_path)
self.client.objects.Get(request) # metadata
return True
except HttpError as http_error:
@@ -255,35 +255,37 @@ class GcsIO(object):
Returns: size of the GCS object in bytes.
"""
bucket, object_path = parse_gcs_path(path)
- request = storage.StorageObjectsGetRequest(bucket=bucket,
- object=object_path)
+ request = storage.StorageObjectsGetRequest(
+ bucket=bucket, object=object_path)
return self.client.objects.Get(request).size
class GcsBufferedReader(object):
"""A class for reading Google Cloud Storage files."""
- def __init__(self, client, path, mode='r',
+ def __init__(self,
+ client,
+ path,
+ mode='r',
buffer_size=DEFAULT_READ_BUFFER_SIZE):
self.client = client
self.path = path
self.bucket, self.name = parse_gcs_path(path)
+ self.mode = mode
+ self.buffer_size = buffer_size
- self.buffer_size = buffer_size
- self.mode = mode
# Get object state.
- get_request = (
- storage.StorageObjectsGetRequest(
- bucket=self.bucket,
- object=self.name))
+ get_request = (storage.StorageObjectsGetRequest(
+ bucket=self.bucket, object=self.name))
try:
metadata = self._get_object_metadata(get_request)
except HttpError as http_error:
if http_error.status_code == 404:
raise IOError(errno.ENOENT, 'Not found: %s' % self.path)
else:
- logging.error(
- 'HTTP error while requesting file %s: %s', self.path, http_error)
+ logging.error('HTTP error while requesting file %s: %s', self.path,
+ http_error)
raise
self.size = metadata.size
@@ -373,17 +375,16 @@ class GcsBufferedReader(object):
# If readline is set, we only want to read up to and including the next
# newline character.
if readline:
- next_newline_position = self.buffer.find(
- '\n', buffer_bytes_read, len(self.buffer))
+ next_newline_position = self.buffer.find('\n', buffer_bytes_read,
+ len(self.buffer))
if next_newline_position != -1:
- bytes_to_read_from_buffer = (1 + next_newline_position -
- buffer_bytes_read)
+ bytes_to_read_from_buffer = (
+ 1 + next_newline_position - buffer_bytes_read)
break_after = True
# Read bytes.
- data_list.append(
- self.buffer[buffer_bytes_read:buffer_bytes_read +
- bytes_to_read_from_buffer])
+ data_list.append(self.buffer[buffer_bytes_read:buffer_bytes_read +
+ bytes_to_read_from_buffer])
self.position += bytes_to_read_from_buffer
to_read -= bytes_to_read_from_buffer
@@ -393,8 +394,8 @@ class GcsBufferedReader(object):
return ''.join(data_list)
def _fetch_next_if_buffer_exhausted(self):
- if not self.buffer or (self.buffer_start_position + len(self.buffer)
- <= self.position):
+ if not self.buffer or (
+ self.buffer_start_position + len(self.buffer) <= self.position):
bytes_to_request = min(self._remaining(), self.buffer_size)
self.buffer_start_position = self.position
self.buffer = self._get_segment(self.position, bytes_to_request)
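The predicate reformatted above decides when the read buffer must be
refilled. Stated on its own (an equivalent sketch):

    def buffer_exhausted(buffer, buffer_start_position, position):
        # Refill when nothing is buffered, or when the read cursor has
        # advanced past the last byte the buffer currently covers.
        return not buffer or buffer_start_position + len(buffer) <= position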
@@ -548,10 +549,14 @@ class GcsBufferedWriter(object):
if self.closed:
raise IOError('Stream is closed.')
- def __init__(self, client, path, mode='w',
+ def __init__(self,
+ client,
+ path,
+ mode='w',
mime_type='application/octet-stream'):
self.client = client
self.path = path
+ self.mode = mode
self.bucket, self.name = parse_gcs_path(path)
- self.mode = mode
@@ -568,12 +573,12 @@ class GcsBufferedWriter(object):
self.conn = parent_conn
# Set up uploader.
- self.insert_request = (
- storage.StorageObjectsInsertRequest(
- bucket=self.bucket,
- name=self.name))
- self.upload = transfer.Upload(GcsBufferedWriter.PipeStream(child_conn),
- mime_type, chunksize=WRITE_CHUNK_SIZE)
+ self.insert_request = (storage.StorageObjectsInsertRequest(
+ bucket=self.bucket, name=self.name))
+ self.upload = transfer.Upload(
+ GcsBufferedWriter.PipeStream(child_conn),
+ mime_type,
+ chunksize=WRITE_CHUNK_SIZE)
self.upload.strategy = transfer.RESUMABLE_UPLOAD
# Start uploading thread.
@@ -596,9 +601,8 @@ class GcsBufferedWriter(object):
try:
self.client.objects.Insert(self.insert_request, upload=self.upload)
except Exception as e: # pylint: disable=broad-except
- logging.error(
- 'Error in _start_upload while inserting file %s: %s', self.path,
- traceback.format_exc())
+ logging.error('Error in _start_upload while inserting file %s: %s',
+ self.path, traceback.format_exc())
self.upload_thread.last_error = e
finally:
self.child_conn.close()
@@ -620,12 +624,21 @@ class GcsBufferedWriter(object):
self._flush_write_buffer()
self.position += len(data)
+ def flush(self):
+ """Flushes any internal buffer to the underlying GCS file."""
+ self._check_open()
+ self._flush_write_buffer()
+
def tell(self):
"""Return the total number of bytes passed to write() so far."""
return self.position
def close(self):
"""Close the current GCS file."""
+ if self.closed:
+ logging.warning('Channel for %s is not open.', self.path)
+ return
+
self._flush_write_buffer()
self.closed = True
self.conn.close()
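Taken together, the GcsBufferedWriter changes (flush() plus the now
idempotent close()) sit on top of a producer/consumer design: Write() feeds
a pipe whose read end backs the resumable upload running on a background
thread. A minimal sketch of that shape, standard library only (the real
class hands the pipe to apitools' transfer.Upload rather than to a plain
consumer loop):

    import multiprocessing
    import threading

    parent_conn, child_conn = multiprocessing.Pipe()

    def consume():
        try:
            while True:
                chunk = child_conn.recv_bytes()
                # ... feed chunk into the upload stream ...
        except EOFError:
            pass  # the writer closed its end: input is complete

    uploader = threading.Thread(target=consume)
    uploader.start()
    parent_conn.send_bytes('buffered data')  # the Write()/flush() path
    parent_conn.close()                      # the close() path
    uploader.join()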