You are viewing a plain-text version of this content; the hyperlink to the canonical version ("here") was not preserved in this copy.
Posted to commits@beam.apache.org by ro...@apache.org on 2022/07/11 17:15:03 UTC

[beam] branch master updated: Allow one to bound the size of output shards when writing to files. (#22130)

This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new abc8099d71b Allow one to bound the size of output shards when writing to files. (#22130)
abc8099d71b is described below

commit abc8099d71b5edad9493c669b5f467e46013b204
Author: Robert Bradshaw <ro...@gmail.com>
AuthorDate: Mon Jul 11 10:14:56 2022 -0700

    Allow one to bound the size of output shards when writing to files. (#22130)
    
    This fixes #22129.
---
 sdks/python/apache_beam/io/filebasedsink.py | 39 ++++++++++++++++++++++++++-
 sdks/python/apache_beam/io/iobase.py        | 12 ++++++++-
 sdks/python/apache_beam/io/textio.py        | 28 ++++++++++++++++++-
 sdks/python/apache_beam/io/textio_test.py   | 42 +++++++++++++++++++++++++++++
 4 files changed, 118 insertions(+), 3 deletions(-)

diff --git a/sdks/python/apache_beam/io/filebasedsink.py b/sdks/python/apache_beam/io/filebasedsink.py
index a75e2c77443..6d8c6f8846f 100644
--- a/sdks/python/apache_beam/io/filebasedsink.py
+++ b/sdks/python/apache_beam/io/filebasedsink.py
@@ -68,6 +68,9 @@ class FileBasedSink(iobase.Sink):
       shard_name_template=None,
       mime_type='application/octet-stream',
       compression_type=CompressionTypes.AUTO,
+      *,
+      max_records_per_shard=None,
+      max_bytes_per_shard=None,
       skip_if_empty=False):
     """
      Raises:
@@ -108,6 +111,8 @@ class FileBasedSink(iobase.Sink):
         shard_name_template)
     self.compression_type = compression_type
     self.mime_type = mime_type
+    self.max_records_per_shard = max_records_per_shard
+    self.max_bytes_per_shard = max_bytes_per_shard
     self.skip_if_empty = skip_if_empty
 
   def display_data(self):
@@ -130,7 +135,13 @@ class FileBasedSink(iobase.Sink):
     The returned file handle is passed to ``write_[encoded_]record`` and
     ``close``.
     """
-    return FileSystems.create(temp_path, self.mime_type, self.compression_type)
+    writer = FileSystems.create(
+        temp_path, self.mime_type, self.compression_type)
+    if self.max_bytes_per_shard:
+      self.byte_counter = _ByteCountingWriter(writer)
+      return self.byte_counter
+    else:
+      return writer
 
   def write_record(self, file_handle, value):
     """Writes a single record go the file handle returned by ``open()``.
@@ -406,10 +417,54 @@ class FileBasedSinkWriter(iobase.Writer):
     self.sink = sink
     self.temp_shard_path = temp_shard_path
     self.temp_handle = self.sink.open(temp_shard_path)
+    self.num_records_written = 0
 
   def write(self, value):
+    self.num_records_written += 1
     self.sink.write_record(self.temp_handle, value)
 
+  def at_capacity(self):
+    """Returns whether this shard has reached a configured size limit.
+
+    Returns True once either ``max_records_per_shard`` records have been
+    written or the sink's byte counter has reached ``max_bytes_per_shard``;
+    a limit that is unset (or zero) is ignored. Always returns a ``bool``,
+    as declared by ``iobase.Writer.at_capacity``.
+    """
+    sink = self.sink
+    if (sink.max_records_per_shard and
+        self.num_records_written >= sink.max_records_per_shard):
+      return True
+    # NOTE(review): ``byte_counter`` is sink state reassigned by every
+    # ``sink.open`` call, so it reflects this writer's handle only if no
+    # other handle has been opened from the same sink in the meantime.
+    return bool(
+        sink.max_bytes_per_shard and
+        sink.byte_counter.bytes_written >= sink.max_bytes_per_shard)
+
   def close(self):
     self.sink.close(self.temp_handle)
     return self.temp_shard_path
+
+
+class _ByteCountingWriter:
+  """Wraps a writable file handle and counts the bytes passed to ``write``.
+
+  ``FileBasedSink.open`` installs this wrapper when ``max_bytes_per_shard``
+  is set. Bytes are counted as handed to the wrapped handle, i.e. before any
+  compression that handle may apply.
+  """
+  def __init__(self, writer):
+    self.writer = writer
+    self.bytes_written = 0
+
+  def write(self, bs):
+    self.bytes_written += len(bs)
+    # Propagate the wrapped handle's result, as a file-like ``write`` would.
+    return self.writer.write(bs)
+
+  def flush(self):
+    self.writer.flush()
+
+  def close(self):
+    self.writer.close()
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index fe46671aaa8..6d75d520af5 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -849,7 +849,8 @@ class Writer(object):
   writing to a sink.
   """
   def write(self, value):
-    """Writes a value to the sink using the current writer."""
+    """Writes a value to the sink using the current writer; the caller may
+    then call ``at_capacity()`` to decide whether to switch to a new writer."""
     raise NotImplementedError
 
   def close(self):
@@ -863,6 +864,12 @@ class Writer(object):
     """
     raise NotImplementedError
 
+  def at_capacity(self) -> bool:
+    """Returns whether this writer should be closed and a new one created.
+
+    The base implementation always returns False."""
+    return False
+
 
 class Read(ptransform.PTransform):
   """A transform that reads a PCollection."""
@@ -1185,6 +1192,9 @@ class _WriteBundleDoFn(core.DoFn):
       # We ignore UUID collisions here since they are extremely rare.
       self.writer = self.sink.open_writer(init_result, str(uuid.uuid4()))
     self.writer.write(element)
+    if self.writer.at_capacity():
+      yield self.writer.close()
+      self.writer = None
 
   def finish_bundle(self):
     if self.writer is not None:
diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py
index 81d75bbe66f..289c91e23b0 100644
--- a/sdks/python/apache_beam/io/textio.py
+++ b/sdks/python/apache_beam/io/textio.py
@@ -435,6 +435,9 @@ class _TextSink(filebasedsink.FileBasedSink):
                compression_type=CompressionTypes.AUTO,
                header=None,
                footer=None,
+               *,
+               max_records_per_shard=None,
+               max_bytes_per_shard=None,
                skip_if_empty=False):
     """Initialize a _TextSink.
 
@@ -469,6 +472,14 @@ class _TextSink(filebasedsink.FileBasedSink):
         append_trailing_newlines is set, '\n' will be added.
       footer: String to write at the end of file as a footer. If not None and
         append_trailing_newlines is set, '\n' will be added.
+      max_records_per_shard: Maximum number of records to write to any
+        individual shard.
+      max_bytes_per_shard: Target maximum number of bytes to write to any
+        individual shard. The limit is checked after each record is written,
+        so a shard may exceed it by up to one record, any newline following
+        that record, and the footer. Bytes are counted before compression,
+        i.e. this bounds the uncompressed size of the shard rather than its
+        compressed size.
       skip_if_empty: Don't write any shards if the PCollection is empty.
 
     Returns:
@@ -482,6 +493,8 @@ class _TextSink(filebasedsink.FileBasedSink):
         coder=coder,
         mime_type='text/plain',
         compression_type=compression_type,
+        max_records_per_shard=max_records_per_shard,
+        max_bytes_per_shard=max_bytes_per_shard,
         skip_if_empty=skip_if_empty)
     self._append_trailing_newlines = append_trailing_newlines
     self._header = header
@@ -791,6 +804,9 @@ class WriteToText(PTransform):
       compression_type=CompressionTypes.AUTO,
       header=None,
       footer=None,
+      *,
+      max_records_per_shard=None,
+      max_bytes_per_shard=None,
       skip_if_empty=False):
     r"""Initialize a :class:`WriteToText` transform.
 
@@ -830,6 +846,14 @@ class WriteToText(PTransform):
       footer (str): String to write at the end of file as a footer.
         If not :data:`None` and **append_trailing_newlines** is set, ``\n`` will
         be added.
+      max_records_per_shard: Maximum number of records to write to any
+        individual shard.
+      max_bytes_per_shard: Target maximum number of bytes to write to any
+        individual shard. The limit is checked after each record is written,
+        so a shard may exceed it by up to one record, any newline following
+        that record, and the footer. Bytes are counted before compression,
+        i.e. this bounds the uncompressed size of the shard rather than its
+        compressed size.
       skip_if_empty: Don't write any shards if the PCollection is empty.
     """
 
@@ -843,7 +867,9 @@ class WriteToText(PTransform):
         compression_type,
         header,
         footer,
-        skip_if_empty)
+        max_records_per_shard=max_records_per_shard,
+        max_bytes_per_shard=max_bytes_per_shard,
+        skip_if_empty=skip_if_empty)
 
   def expand(self, pcoll):
     return pcoll | Write(self._sink)
diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py
index 6b4d6d2bb7e..6fb8d6ccb36 100644
--- a/sdks/python/apache_beam/io/textio_test.py
+++ b/sdks/python/apache_beam/io/textio_test.py
@@ -1668,6 +1668,48 @@ class TextSinkTest(unittest.TestCase):
     outputs = list(glob.glob(self.path + '*'))
     self.assertEqual(outputs, [])
 
+  def test_write_max_records_per_shard(self):
+    """Each shard holds at most max_records_per_shard lines; none are lost."""
+    limit, expected = 13, [str(n).encode('utf-8') for n in range(100)]
+    with TestPipeline() as pipeline:
+      # pylint: disable=expression-not-assigned
+      pipeline | beam.core.Create(expected) | WriteToText(
+          self.path, max_records_per_shard=limit)
+
+    collected = []
+    for shard_path in glob.glob(self.path + '*'):
+      with open(shard_path, 'rb') as shard:
+        shard_records = shard.read().splitlines()
+        self.assertLessEqual(len(shard_records), limit)
+        collected.extend(shard_records)
+    self.assertEqual(sorted(collected), sorted(expected))
+
+  def test_write_max_bytes_per_shard(self):
+    """Shards exceed max_bytes_per_shard by at most one record plus footer."""
+    limit, longest = 300, 100
+    expected = [b'x' * n for n in range(longest)]
+    header = b'a' * 20
+    footer = b'b' * 30
+    with TestPipeline() as pipeline:
+      # pylint: disable=expression-not-assigned
+      pipeline | beam.core.Create(expected) | WriteToText(
+          self.path,
+          header=header,
+          footer=footer,
+          max_bytes_per_shard=limit)
+
+    collected = []
+    for shard_path in glob.glob(self.path + '*'):
+      with open(shard_path, 'rb') as shard:
+        data = shard.read()
+        self.assertLessEqual(
+            len(data), limit + longest + len(footer) + 2)
+        shard_lines = data.splitlines()
+        self.assertEqual(shard_lines[0], header)
+        self.assertEqual(shard_lines[-1], footer)
+        collected.extend(shard_lines[1:-1])
+    self.assertEqual(sorted(collected), sorted(expected))
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)