You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by bh...@apache.org on 2022/11/01 17:32:44 UTC

[beam] branch master updated: Add WriteParquetBatched (#23030)

This is an automated email from the ASF dual-hosted git repository.

bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new ef7c0c9a476 Add WriteParquetBatched  (#23030)
ef7c0c9a476 is described below

commit ef7c0c9a476d5b9b09a2ac6c7ccae3c7b5135faa
Author: peridotml <10...@users.noreply.github.com>
AuthorDate: Tue Nov 1 13:32:38 2022 -0400

    Add WriteParquetBatched  (#23030)
    
    * initial draft
    
    * reuse parquet sink for writing rows and pa.Tables
    
    * lint
    
    * fix imports
    
    * cr - rename value to table
    
    * add doc string and example test
    
    * fix doc strings
    
    * specify doctest group to separate tests
    
    Co-authored-by: Evan Sadler <ea...@gmail.com>
---
 sdks/python/apache_beam/io/parquetio.py      | 234 +++++++++++++++++++++------
 sdks/python/apache_beam/io/parquetio_test.py |  36 ++++-
 2 files changed, 216 insertions(+), 54 deletions(-)

diff --git a/sdks/python/apache_beam/io/parquetio.py b/sdks/python/apache_beam/io/parquetio.py
index acbf1e23f20..dfcc1abec29 100644
--- a/sdks/python/apache_beam/io/parquetio.py
+++ b/sdks/python/apache_beam/io/parquetio.py
@@ -43,6 +43,7 @@ from apache_beam.io.iobase import Write
 from apache_beam.transforms import DoFn
 from apache_beam.transforms import ParDo
 from apache_beam.transforms import PTransform
+from apache_beam.transforms import window
 
 try:
   import pyarrow as pa
@@ -60,7 +61,8 @@ __all__ = [
     'ReadAllFromParquet',
     'ReadFromParquetBatched',
     'ReadAllFromParquetBatched',
-    'WriteToParquet'
+    'WriteToParquet',
+    'WriteToParquetBatched'
 ]
 
 
@@ -83,6 +85,67 @@ class _ArrowTableToRowDictionaries(DoFn):
         yield row
 
 
+class _RowDictionariesToArrowTable(DoFn):
+  """Buffers row dicts into RecordBatches and yields pyarrow Tables."""
+  def __init__(
+      self,
+      schema,
+      row_group_buffer_size=64 * 1024 * 1024,
+      record_batch_size=1000):
+    self._schema = schema
+    self._row_group_buffer_size = row_group_buffer_size
+    self._buffer = [[] for _ in range(len(schema.names))]  # one per column
+    self._buffer_size = record_batch_size  # rows per RecordBatch
+    self._record_batches = []
+    self._record_batches_byte_size = 0  # summed Arrow buffer sizes
+
+  def process(self, row):
+    if len(self._buffer[0]) >= self._buffer_size:
+      self._flush_buffer()
+
+    if self._record_batches_byte_size >= self._row_group_buffer_size:
+      table = self._create_table()
+      yield table
+
+    # Transpose the row dict into the per-column buffers.
+    for i, n in enumerate(self._schema.names):
+      self._buffer[i].append(row[n])
+
+  def finish_bundle(self):
+    if len(self._buffer[0]) > 0:
+      self._flush_buffer()
+    if self._record_batches_byte_size > 0:
+      table = self._create_table()  # emit whatever is left in the bundle
+      yield window.GlobalWindows.windowed_value_at_end_of_window(table)
+
+  def display_data(self):
+    res = super().display_data()
+    res['row_group_buffer_size'] = str(self._row_group_buffer_size)
+    res['buffer_size'] = str(self._buffer_size)
+
+    return res
+
+  def _create_table(self):
+    table = pa.Table.from_batches(self._record_batches, schema=self._schema)
+    self._record_batches = []
+    self._record_batches_byte_size = 0
+    return table
+
+  def _flush_buffer(self):
+    arrays = [[] for _ in range(len(self._schema.names))]
+    for x, y in enumerate(self._buffer):
+      arrays[x] = pa.array(y, type=self._schema.types[x])
+      self._buffer[x] = []
+    rb = pa.RecordBatch.from_arrays(arrays, schema=self._schema)
+    self._record_batches.append(rb)
+    size = 0
+    for x in arrays:
+      for b in x.buffers():
+        if b is not None:
+          size = size + b.size
+    self._record_batches_byte_size = self._record_batches_byte_size + size
+
+
 class ReadFromParquetBatched(PTransform):
   """A :class:`~apache_beam.transforms.ptransform.PTransform` for reading
      Parquet files as a `PCollection` of `pyarrow.Table`. This `PTransform` is
@@ -453,13 +516,127 @@ class WriteToParquet(PTransform):
       A WriteToParquet transform usable for writing.
     """
     super().__init__()
+    self._schema = schema
+    self._row_group_buffer_size = row_group_buffer_size
+    self._record_batch_size = record_batch_size
+
+    self._sink = \
+      _create_parquet_sink(
+          file_path_prefix,
+          schema,
+          codec,
+          use_deprecated_int96_timestamps,
+          use_compliant_nested_type,
+          file_name_suffix,
+          num_shards,
+          shard_name_template,
+          mime_type
+      )
+
+  def expand(self, pcoll):
+    return pcoll | ParDo(  # rows -> pyarrow Tables -> sink
+        _RowDictionariesToArrowTable(
+            self._schema, self._row_group_buffer_size,
+            self._record_batch_size)) | Write(self._sink)
+
+  def display_data(self):
+    return {
+        'sink_dd': self._sink,  # sink items (codec, schema, ...)
+        'row_group_buffer_size': str(self._row_group_buffer_size)
+    }
+
+
+class WriteToParquetBatched(PTransform):
+  """A ``PTransform`` for writing parquet files from a `PCollection` of
+    `pyarrow.Table`.
+
+    This ``PTransform`` is currently experimental. No backward-compatibility
+    guarantees.
+  """
+  def __init__(
+      self,
+      file_path_prefix,
+      schema=None,
+      codec='none',
+      use_deprecated_int96_timestamps=False,
+      use_compliant_nested_type=False,
+      file_name_suffix='',
+      num_shards=0,
+      shard_name_template=None,
+      mime_type='application/x-parquet',
+  ):
+    """Initialize a WriteToParquetBatched transform.
+
+    Writes parquet files from a :class:`~apache_beam.pvalue.PCollection` of
+    records. Each record is a pa.Table Schema must be specified like the
+    example below.
+
+    .. testsetup:: batched
+
+      from tempfile import NamedTemporaryFile
+      import glob
+      import os
+      import pyarrow
+
+      filename = NamedTemporaryFile(delete=False).name
+
+    .. testcode:: batched
+
+      table = pyarrow.Table.from_pylist([{'name': 'foo', 'age': 10},
+                                         {'name': 'bar', 'age': 20}])
+      with beam.Pipeline() as p:
+        records = p | 'Read' >> beam.Create([table])
+        _ = records | 'Write' >> beam.io.WriteToParquetBatched(filename,
+            pyarrow.schema(
+                [('name', pyarrow.string()), ('age', pyarrow.int64())]
+            )
+        )
+
+    .. testcleanup:: batched
+
+      for output in glob.glob('{}*'.format(filename)):
+        os.remove(output)
+
+    For more information on supported types and schema, please see the pyarrow
+    document.
+
+    Args:
+      file_path_prefix: The file path to write to. The files written will begin
+        with this prefix, followed by a shard identifier (see num_shards), and
+        end in a common extension, if given by file_name_suffix. In most cases,
+        only this argument is specified and num_shards, shard_name_template, and
+        file_name_suffix use default values.
+      schema: The schema to use, as type of ``pyarrow.Schema``.
+      codec: The codec to use for block-level compression. Any string supported
+        by the pyarrow specification is accepted.
+      use_deprecated_int96_timestamps: Write nanosecond resolution timestamps to
+        INT96 Parquet format. Defaults to False.
+      use_compliant_nested_type: Write compliant Parquet nested type (lists).
+      file_name_suffix: Suffix for the files written.
+      num_shards: The number of files (shards) used for output. If not set, the
+        service will decide on the optimal number of shards.
+        Constraining the number of shards is likely to reduce
+        the performance of a pipeline.  Setting this value is not recommended
+        unless you require a specific number of output files.
+      shard_name_template: A template string containing placeholders for
+        the shard number and shard count. When constructing a filename for a
+        particular shard number, the upper-case letters 'S' and 'N' are
+        replaced with the 0-padded shard number and shard count respectively.
+        This argument can be '' in which case it behaves as if num_shards was
+        set to 1 and only one file will be generated. The default pattern used
+        is '-SSSSS-of-NNNNN' if None is passed as the shard_name_template.
+      mime_type: The MIME type to use for the produced files, if the filesystem
+        supports specifying MIME types.
+
+    Returns:
+      A WriteToParquetBatched transform usable for writing.
+    """
+    super().__init__()
     self._sink = \
       _create_parquet_sink(
           file_path_prefix,
           schema,
           codec,
-          row_group_buffer_size,
-          record_batch_size,
           use_deprecated_int96_timestamps,
           use_compliant_nested_type,
           file_name_suffix,
@@ -479,8 +656,6 @@ def _create_parquet_sink(
     file_path_prefix,
     schema,
     codec,
-    row_group_buffer_size,
-    record_batch_size,
     use_deprecated_int96_timestamps,
     use_compliant_nested_type,
     file_name_suffix,
@@ -492,8 +667,6 @@ def _create_parquet_sink(
         file_path_prefix,
         schema,
         codec,
-        row_group_buffer_size,
-        record_batch_size,
         use_deprecated_int96_timestamps,
         use_compliant_nested_type,
         file_name_suffix,
@@ -504,14 +677,12 @@ def _create_parquet_sink(
 
 
 class _ParquetSink(filebasedsink.FileBasedSink):
-  """A sink for parquet files."""
+  """A sink for parquet files from batches."""
   def __init__(
       self,
       file_path_prefix,
       schema,
       codec,
-      row_group_buffer_size,
-      record_batch_size,
       use_deprecated_int96_timestamps,
       use_compliant_nested_type,
       file_name_suffix,
@@ -535,7 +706,6 @@ class _ParquetSink(filebasedsink.FileBasedSink):
           "Due to ARROW-9424, writing with LZ4 compression is not supported in "
           "pyarrow 1.x, please use a different pyarrow version or a different "
           f"codec. Your pyarrow version: {pa.__version__}")
-    self._row_group_buffer_size = row_group_buffer_size
     self._use_deprecated_int96_timestamps = use_deprecated_int96_timestamps
     if use_compliant_nested_type and ARROW_MAJOR_VERSION < 4:
       raise ValueError(
@@ -543,10 +713,6 @@ class _ParquetSink(filebasedsink.FileBasedSink):
           "pyarrow version >= 4.x, please use a different pyarrow version. "
           f"Your pyarrow version: {pa.__version__}")
     self._use_compliant_nested_type = use_compliant_nested_type
-    self._buffer = [[] for _ in range(len(schema.names))]
-    self._buffer_size = record_batch_size
-    self._record_batches = []
-    self._record_batches_byte_size = 0
     self._file_handle = None
 
   def open(self, temp_path):
@@ -564,23 +730,10 @@ class _ParquetSink(filebasedsink.FileBasedSink):
         use_deprecated_int96_timestamps=self._use_deprecated_int96_timestamps,
         use_compliant_nested_type=self._use_compliant_nested_type)
 
-  def write_record(self, writer, value):
-    if len(self._buffer[0]) >= self._buffer_size:
-      self._flush_buffer()
-
-    if self._record_batches_byte_size >= self._row_group_buffer_size:
-      self._write_batches(writer)
-
-    # reorder the data in columnar format.
-    for i, n in enumerate(self._schema.names):
-      self._buffer[i].append(value[n])
+  def write_record(self, writer, table: 'pa.Table'):  # quoted: pa is optional
+    writer.write_table(table)
 
   def close(self, writer):
-    if len(self._buffer[0]) > 0:
-      self._flush_buffer()
-    if self._record_batches_byte_size > 0:
-      self._write_batches(writer)
-
     writer.close()
     if self._file_handle:
       self._file_handle.close()
@@ -590,25 +743,4 @@ class _ParquetSink(filebasedsink.FileBasedSink):
     res = super().display_data()
     res['codec'] = str(self._codec)
     res['schema'] = str(self._schema)
-    res['row_group_buffer_size'] = str(self._row_group_buffer_size)
     return res
-
-  def _write_batches(self, writer):
-    table = pa.Table.from_batches(self._record_batches, schema=self._schema)
-    self._record_batches = []
-    self._record_batches_byte_size = 0
-    writer.write_table(table)
-
-  def _flush_buffer(self):
-    arrays = [[] for _ in range(len(self._schema.names))]
-    for x, y in enumerate(self._buffer):
-      arrays[x] = pa.array(y, type=self._schema.types[x])
-      self._buffer[x] = []
-    rb = pa.RecordBatch.from_arrays(arrays, schema=self._schema)
-    self._record_batches.append(rb)
-    size = 0
-    for x in arrays:
-      for b in x.buffers():
-        if b is not None:
-          size = size + b.size
-    self._record_batches_byte_size = self._record_batches_byte_size + size
diff --git a/sdks/python/apache_beam/io/parquetio_test.py b/sdks/python/apache_beam/io/parquetio_test.py
index 454a45493c4..df018a3a776 100644
--- a/sdks/python/apache_beam/io/parquetio_test.py
+++ b/sdks/python/apache_beam/io/parquetio_test.py
@@ -40,6 +40,7 @@ from apache_beam.io.parquetio import ReadAllFromParquetBatched
 from apache_beam.io.parquetio import ReadFromParquet
 from apache_beam.io.parquetio import ReadFromParquetBatched
 from apache_beam.io.parquetio import WriteToParquet
+from apache_beam.io.parquetio import WriteToParquetBatched
 from apache_beam.io.parquetio import _create_parquet_sink
 from apache_beam.io.parquetio import _create_parquet_source
 from apache_beam.testing.test_pipeline import TestPipeline
@@ -284,8 +285,6 @@ class TestParquet(unittest.TestCase):
         file_name,
         self.SCHEMA,
         'none',
-        1024 * 1024,
-        1000,
         False,
         False,
         '.end',
@@ -299,7 +298,6 @@ class TestParquet(unittest.TestCase):
             'file_pattern',
             'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d.end'),
         DisplayDataItemMatcher('codec', 'none'),
-        DisplayDataItemMatcher('row_group_buffer_size', str(1024 * 1024)),
         DisplayDataItemMatcher('compression', 'uncompressed')
     ]
     hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
@@ -308,6 +306,7 @@ class TestParquet(unittest.TestCase):
     file_name = 'some_parquet_sink'
     write = WriteToParquet(file_name, self.SCHEMA)
     dd = DisplayData.create_from(write)
+
     expected_items = [
         DisplayDataItemMatcher('codec', 'none'),
         DisplayDataItemMatcher('schema', str(self.SCHEMA)),
@@ -319,6 +318,21 @@ class TestParquet(unittest.TestCase):
     ]
     hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
 
+  def test_write_batched_display_data(self):
+    file_name = 'some_parquet_sink'
+    write = WriteToParquetBatched(file_name, self.SCHEMA)
+    dd = DisplayData.create_from(write)
+    # Exactly the sink's items; batched writes report no buffer sizes.
+    expected_items = [
+        DisplayDataItemMatcher('codec', 'none'),
+        DisplayDataItemMatcher('schema', str(self.SCHEMA)),
+        DisplayDataItemMatcher(
+            'file_pattern',
+            'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d'),
+        DisplayDataItemMatcher('compression', 'uncompressed')
+    ]
+    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
+
   def test_sink_transform_int96(self):
     with tempfile.NamedTemporaryFile() as dst:
       path = dst.name
@@ -348,6 +362,22 @@ class TestParquet(unittest.TestCase):
             | Map(json.dumps)
         assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))
 
+  def test_sink_transform_batched(self):
+    with TemporaryDirectory() as tmp_dirname:
+      path = os.path.join(tmp_dirname, "tmp_filename")
+      with TestPipeline() as p:
+        _ = p \
+        | Create([self._records_as_arrow()]) \
+        | WriteToParquetBatched(
+            path, self.SCHEMA, num_shards=1, shard_name_template='')
+      with TestPipeline() as p:
+        # json used for stable sortability
+        readback = \
+            p \
+            | ReadFromParquet(path) \
+            | Map(json.dumps)
+        assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))
+
   def test_sink_transform_compliant_nested_type(self):
     if ARROW_MAJOR_VERSION < 4:
       return unittest.skip(