Posted to commits@beam.apache.org by bh...@apache.org on 2022/11/01 17:32:44 UTC
[beam] branch master updated: Add WriteParquetBatched (#23030)
This is an automated email from the ASF dual-hosted git repository.
bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new ef7c0c9a476 Add WriteParquetBatched (#23030)
ef7c0c9a476 is described below
commit ef7c0c9a476d5b9b09a2ac6c7ccae3c7b5135faa
Author: peridotml <10...@users.noreply.github.com>
AuthorDate: Tue Nov 1 13:32:38 2022 -0400
Add WriteParquetBatched (#23030)
* initial draft
* reuse parquet sink for writing rows and pa.Tables
* lint
* fix imports
* cr - rename value to table
* add doc string and example test
* fix doc strings
* specify doctest group to separate tests
Co-authored-by: Evan Sadler <ea...@gmail.com>
---
sdks/python/apache_beam/io/parquetio.py | 234 +++++++++++++++++++++------
sdks/python/apache_beam/io/parquetio_test.py | 36 ++++-
2 files changed, 216 insertions(+), 54 deletions(-)
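For context, WriteToParquetBatched consumes a PCollection of pyarrow
Tables instead of the row dictionaries that WriteToParquet takes. A
minimal usage sketch, adapted from the docstring example added in this
commit (the output path is illustrative):

    import pyarrow
    import apache_beam as beam

    table = pyarrow.Table.from_pylist([{'name': 'foo', 'age': 10},
                                       {'name': 'bar', 'age': 20}])
    schema = pyarrow.schema([('name', pyarrow.string()),
                             ('age', pyarrow.int64())])

    with beam.Pipeline() as p:
        # Each element is a whole pyarrow.Table, not a single row dict.
        _ = (p
             | beam.Create([table])
             | beam.io.WriteToParquetBatched('/tmp/output', schema))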
diff --git a/sdks/python/apache_beam/io/parquetio.py b/sdks/python/apache_beam/io/parquetio.py
index acbf1e23f20..dfcc1abec29 100644
--- a/sdks/python/apache_beam/io/parquetio.py
+++ b/sdks/python/apache_beam/io/parquetio.py
@@ -43,6 +43,7 @@ from apache_beam.io.iobase import Write
from apache_beam.transforms import DoFn
from apache_beam.transforms import ParDo
from apache_beam.transforms import PTransform
+from apache_beam.transforms import window
try:
import pyarrow as pa
@@ -60,7 +61,8 @@ __all__ = [
'ReadAllFromParquet',
'ReadFromParquetBatched',
'ReadAllFromParquetBatched',
- 'WriteToParquet'
+ 'WriteToParquet',
+ 'WriteToParquetBatched'
]
@@ -83,6 +85,67 @@ class _ArrowTableToRowDictionaries(DoFn):
yield row
+class _RowDictionariesToArrowTable(DoFn):
+ """ A DoFn that consumes python dictionarys and yields a pyarrow table."""
+ def __init__(
+ self,
+ schema,
+ row_group_buffer_size=64 * 1024 * 1024,
+ record_batch_size=1000):
+ self._schema = schema
+ self._row_group_buffer_size = row_group_buffer_size
+ self._buffer = [[] for _ in range(len(schema.names))]
+ self._buffer_size = record_batch_size
+ self._record_batches = []
+ self._record_batches_byte_size = 0
+
+ def process(self, row):
+ if len(self._buffer[0]) >= self._buffer_size:
+ self._flush_buffer()
+
+ if self._record_batches_byte_size >= self._row_group_buffer_size:
+ table = self._create_table()
+ yield table
+
+ # reorder the data in columnar format.
+ for i, n in enumerate(self._schema.names):
+ self._buffer[i].append(row[n])
+
+ def finish_bundle(self):
+ if len(self._buffer[0]) > 0:
+ self._flush_buffer()
+ if self._record_batches_byte_size > 0:
+ table = self._create_table()
+ yield window.GlobalWindows.windowed_value_at_end_of_window(table)
+
+ def display_data(self):
+ res = super().display_data()
+ res['row_group_buffer_size'] = str(self._row_group_buffer_size)
+ res['buffer_size'] = str(self._buffer_size)
+
+ return res
+
+ def _create_table(self):
+ table = pa.Table.from_batches(self._record_batches, schema=self._schema)
+ self._record_batches = []
+ self._record_batches_byte_size = 0
+ return table
+
+ def _flush_buffer(self):
+ arrays = [[] for _ in range(len(self._schema.names))]
+ for x, y in enumerate(self._buffer):
+ arrays[x] = pa.array(y, type=self._schema.types[x])
+ self._buffer[x] = []
+ rb = pa.RecordBatch.from_arrays(arrays, schema=self._schema)
+ self._record_batches.append(rb)
+ size = 0
+ for x in arrays:
+ for b in x.buffers():
+ if b is not None:
+ size = size + b.size
+ self._record_batches_byte_size = self._record_batches_byte_size + size
+
+
class ReadFromParquetBatched(PTransform):
"""A :class:`~apache_beam.transforms.ptransform.PTransform` for reading
Parquet files as a `PCollection` of `pyarrow.Table`. This `PTransform` is
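The _RowDictionariesToArrowTable DoFn above moves the row buffering that
previously lived inside the sink into a reusable transform. A standalone
sketch of the same columnar buffering using only pyarrow (the sample
schema and rows are illustrative):

    import pyarrow as pa

    schema = pa.schema([('name', pa.string()), ('age', pa.int64())])
    rows = [{'name': 'foo', 'age': 10}, {'name': 'bar', 'age': 20}]

    # Reorder the row dicts into one value list per column, as
    # _flush_buffer does.
    columns = [[row[name] for row in rows] for name in schema.names]
    arrays = [pa.array(col, type=typ)
              for col, typ in zip(columns, schema.types)]
    batch = pa.RecordBatch.from_arrays(arrays, schema=schema)

    # Accumulate the batch's size from the arrays' underlying buffers,
    # as the DoFn does to decide when a row group is full.
    size = sum(b.size for a in arrays for b in a.buffers()
               if b is not None)

    # Buffered batches are later combined into one table per row group.
    table = pa.Table.from_batches([batch], schema=schema)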
@@ -453,13 +516,127 @@ class WriteToParquet(PTransform):
A WriteToParquet transform usable for writing.
"""
super().__init__()
+ self._schema = schema
+ self._row_group_buffer_size = row_group_buffer_size
+ self._record_batch_size = record_batch_size
+
+ self._sink = \
+ _create_parquet_sink(
+ file_path_prefix,
+ schema,
+ codec,
+ use_deprecated_int96_timestamps,
+ use_compliant_nested_type,
+ file_name_suffix,
+ num_shards,
+ shard_name_template,
+ mime_type
+ )
+
+ def expand(self, pcoll):
+ return pcoll | ParDo(
+ _RowDictionariesToArrowTable(
+ self._schema, self._row_group_buffer_size,
+ self._record_batch_size)) | Write(self._sink)
+
+ def display_data(self):
+ return {
+ 'sink_dd': self._sink,
+ 'row_group_buffer_size': str(self._row_group_buffer_size)
+ }
+
+
+class WriteToParquetBatched(PTransform):
+ """A ``PTransform`` for writing parquet files from a `PCollection` of
+ `pyarrow.Table`.
+
+ This ``PTransform`` is currently experimental. No backward-compatibility
+ guarantees.
+ """
+ def __init__(
+ self,
+ file_path_prefix,
+ schema=None,
+ codec='none',
+ use_deprecated_int96_timestamps=False,
+ use_compliant_nested_type=False,
+ file_name_suffix='',
+ num_shards=0,
+ shard_name_template=None,
+ mime_type='application/x-parquet',
+ ):
+ """Initialize a WriteToParquetBatched transform.
+
+ Writes parquet files from a :class:`~apache_beam.pvalue.PCollection` of
+ records. Each record is a pa.Table. The schema must be specified, as in
+ the example below.
+
+ .. testsetup:: batched
+
+ from tempfile import NamedTemporaryFile
+ import glob
+ import os
+ import pyarrow
+
+ filename = NamedTemporaryFile(delete=False).name
+
+ .. testcode:: batched
+
+ table = pyarrow.Table.from_pylist([{'name': 'foo', 'age': 10},
+ {'name': 'bar', 'age': 20}])
+ with beam.Pipeline() as p:
+ records = p | 'Read' >> beam.Create([table])
+ _ = records | 'Write' >> beam.io.WriteToParquetBatched(filename,
+ pyarrow.schema(
+ [('name', pyarrow.string()), ('age', pyarrow.int64())]
+ )
+ )
+
+ .. testcleanup:: batched
+
+ for output in glob.glob('{}*'.format(filename)):
+ os.remove(output)
+
+ For more information on supported types and schema, please see the pyarrow
+ documentation.
+
+ Args:
+ file_path_prefix: The file path to write to. The files written will begin
+ with this prefix, followed by a shard identifier (see num_shards), and
+ end in a common extension, if given by file_name_suffix. In most cases,
+ only this argument is specified and num_shards, shard_name_template, and
+ file_name_suffix use default values.
+ schema: The schema to use, as type of ``pyarrow.Schema``.
+ codec: The codec to use for block-level compression. Any string supported
+ by the pyarrow specification is accepted.
+ use_deprecated_int96_timestamps: Write nanosecond resolution timestamps to
+ INT96 Parquet format. Defaults to False.
+ use_compliant_nested_type: Write compliant Parquet nested type (lists).
+ file_name_suffix: Suffix for the files written.
+ num_shards: The number of files (shards) used for output. If not set, the
+ service will decide on the optimal number of shards.
+ Constraining the number of shards is likely to reduce
+ the performance of a pipeline. Setting this value is not recommended
+ unless you require a specific number of output files.
+ shard_name_template: A template string containing placeholders for
+ the shard number and shard count. When constructing a filename for a
+ particular shard number, the upper-case letters 'S' and 'N' are
+ replaced with the 0-padded shard number and shard count respectively.
+ This argument can be '' in which case it behaves as if num_shards was
+ set to 1 and only one file will be generated. The default pattern used
+ is '-SSSSS-of-NNNNN' if None is passed as the shard_name_template.
+ mime_type: The MIME type to use for the produced files, if the filesystem
+ supports specifying MIME types.
+
+ Returns:
+ A WriteToParquetBatched transform usable for writing.
+ """
+ super().__init__()
self._sink = \
_create_parquet_sink(
file_path_prefix,
schema,
codec,
- row_group_buffer_size,
- record_batch_size,
use_deprecated_int96_timestamps,
use_compliant_nested_type,
file_name_suffix,
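WriteToParquet now expands to this DoFn followed by a Write on the shared
sink, so both transforms funnel pyarrow Tables into the same
_create_parquet_sink. Tables written with the batched transform should
round-trip through the existing batched reader as well; a hedged sketch
(paths and values illustrative):

    import pyarrow
    import apache_beam as beam

    schema = pyarrow.schema([('name', pyarrow.string()),
                             ('age', pyarrow.int64())])
    table = pyarrow.Table.from_pylist([{'name': 'foo', 'age': 10}],
                                      schema=schema)

    with beam.Pipeline() as p:
        _ = (p
             | beam.Create([table])
             | beam.io.WriteToParquetBatched(
                 '/tmp/out', schema, num_shards=1, shard_name_template=''))

    with beam.Pipeline() as p:
        # ReadFromParquetBatched yields one pyarrow.Table per row group.
        tables = p | beam.io.ReadFromParquetBatched('/tmp/out')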
@@ -479,8 +656,6 @@ def _create_parquet_sink(
file_path_prefix,
schema,
codec,
- row_group_buffer_size,
- record_batch_size,
use_deprecated_int96_timestamps,
use_compliant_nested_type,
file_name_suffix,
@@ -492,8 +667,6 @@ def _create_parquet_sink(
file_path_prefix,
schema,
codec,
- row_group_buffer_size,
- record_batch_size,
use_deprecated_int96_timestamps,
use_compliant_nested_type,
file_name_suffix,
@@ -504,14 +677,12 @@ def _create_parquet_sink(
class _ParquetSink(filebasedsink.FileBasedSink):
- """A sink for parquet files."""
+ """A sink for parquet files from batches."""
def __init__(
self,
file_path_prefix,
schema,
codec,
- row_group_buffer_size,
- record_batch_size,
use_deprecated_int96_timestamps,
use_compliant_nested_type,
file_name_suffix,
@@ -535,7 +706,6 @@ class _ParquetSink(filebasedsink.FileBasedSink):
"Due to ARROW-9424, writing with LZ4 compression is not supported in "
"pyarrow 1.x, please use a different pyarrow version or a different "
f"codec. Your pyarrow version: {pa.__version__}")
- self._row_group_buffer_size = row_group_buffer_size
self._use_deprecated_int96_timestamps = use_deprecated_int96_timestamps
if use_compliant_nested_type and ARROW_MAJOR_VERSION < 4:
raise ValueError(
@@ -543,10 +713,6 @@ class _ParquetSink(filebasedsink.FileBasedSink):
"pyarrow version >= 4.x, please use a different pyarrow version. "
f"Your pyarrow version: {pa.__version__}")
self._use_compliant_nested_type = use_compliant_nested_type
- self._buffer = [[] for _ in range(len(schema.names))]
- self._buffer_size = record_batch_size
- self._record_batches = []
- self._record_batches_byte_size = 0
self._file_handle = None
def open(self, temp_path):
@@ -564,23 +730,10 @@ class _ParquetSink(filebasedsink.FileBasedSink):
use_deprecated_int96_timestamps=self._use_deprecated_int96_timestamps,
use_compliant_nested_type=self._use_compliant_nested_type)
- def write_record(self, writer, value):
- if len(self._buffer[0]) >= self._buffer_size:
- self._flush_buffer()
-
- if self._record_batches_byte_size >= self._row_group_buffer_size:
- self._write_batches(writer)
-
- # reorder the data in columnar format.
- for i, n in enumerate(self._schema.names):
- self._buffer[i].append(value[n])
+ def write_record(self, writer, table: pa.Table):
+ writer.write_table(table)
def close(self, writer):
- if len(self._buffer[0]) > 0:
- self._flush_buffer()
- if self._record_batches_byte_size > 0:
- self._write_batches(writer)
-
writer.close()
if self._file_handle:
self._file_handle.close()
@@ -590,25 +743,4 @@ class _ParquetSink(filebasedsink.FileBasedSink):
res = super().display_data()
res['codec'] = str(self._codec)
res['schema'] = str(self._schema)
- res['row_group_buffer_size'] = str(self._row_group_buffer_size)
return res
-
- def _write_batches(self, writer):
- table = pa.Table.from_batches(self._record_batches, schema=self._schema)
- self._record_batches = []
- self._record_batches_byte_size = 0
- writer.write_table(table)
-
- def _flush_buffer(self):
- arrays = [[] for _ in range(len(self._schema.names))]
- for x, y in enumerate(self._buffer):
- arrays[x] = pa.array(y, type=self._schema.types[x])
- self._buffer[x] = []
- rb = pa.RecordBatch.from_arrays(arrays, schema=self._schema)
- self._record_batches.append(rb)
- size = 0
- for x in arrays:
- for b in x.buffers():
- if b is not None:
- size = size + b.size
- self._record_batches_byte_size = self._record_batches_byte_size + size
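With the buffering removed, the slimmed-down _ParquetSink is little more
than a wrapper around pyarrow's writer. Roughly what write_record and
close now do, sketched directly against pyarrow.parquet (file handling
simplified, path illustrative):

    import pyarrow as pa
    import pyarrow.parquet as pq

    schema = pa.schema([('name', pa.string()), ('age', pa.int64())])
    writer = pq.ParquetWriter('/tmp/example.parquet', schema,
                              compression='none')

    # write_record now forwards each incoming table straight to the
    # writer; each write_table call emits at least one row group.
    table = pa.Table.from_pylist([{'name': 'foo', 'age': 10}],
                                 schema=schema)
    writer.write_table(table)

    # close() no longer has buffered record batches to flush first.
    writer.close()

Row-group sizing is therefore controlled upstream by
_RowDictionariesToArrowTable rather than inside the sink.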
diff --git a/sdks/python/apache_beam/io/parquetio_test.py b/sdks/python/apache_beam/io/parquetio_test.py
index 454a45493c4..df018a3a776 100644
--- a/sdks/python/apache_beam/io/parquetio_test.py
+++ b/sdks/python/apache_beam/io/parquetio_test.py
@@ -40,6 +40,7 @@ from apache_beam.io.parquetio import ReadAllFromParquetBatched
from apache_beam.io.parquetio import ReadFromParquet
from apache_beam.io.parquetio import ReadFromParquetBatched
from apache_beam.io.parquetio import WriteToParquet
+from apache_beam.io.parquetio import WriteToParquetBatched
from apache_beam.io.parquetio import _create_parquet_sink
from apache_beam.io.parquetio import _create_parquet_source
from apache_beam.testing.test_pipeline import TestPipeline
@@ -284,8 +285,6 @@ class TestParquet(unittest.TestCase):
file_name,
self.SCHEMA,
'none',
- 1024 * 1024,
- 1000,
False,
False,
'.end',
@@ -299,7 +298,6 @@ class TestParquet(unittest.TestCase):
'file_pattern',
'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d.end'),
DisplayDataItemMatcher('codec', 'none'),
- DisplayDataItemMatcher('row_group_buffer_size', str(1024 * 1024)),
DisplayDataItemMatcher('compression', 'uncompressed')
]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
@@ -308,6 +306,7 @@ class TestParquet(unittest.TestCase):
file_name = 'some_parquet_sink'
write = WriteToParquet(file_name, self.SCHEMA)
dd = DisplayData.create_from(write)
+
expected_items = [
DisplayDataItemMatcher('codec', 'none'),
DisplayDataItemMatcher('schema', str(self.SCHEMA)),
@@ -319,6 +318,21 @@ class TestParquet(unittest.TestCase):
]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
+ def test_write_batched_display_data(self):
+ file_name = 'some_parquet_sink'
+ write = WriteToParquetBatched(file_name, self.SCHEMA)
+ dd = DisplayData.create_from(write)
+
+ expected_items = [
+ DisplayDataItemMatcher('codec', 'none'),
+ DisplayDataItemMatcher('schema', str(self.SCHEMA)),
+ DisplayDataItemMatcher(
+ 'file_pattern',
+ 'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d'),
+ DisplayDataItemMatcher('compression', 'uncompressed')
+ ]
+ hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
+
def test_sink_transform_int96(self):
with tempfile.NamedTemporaryFile() as dst:
path = dst.name
@@ -348,6 +362,22 @@ class TestParquet(unittest.TestCase):
| Map(json.dumps)
assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))
+ def test_sink_transform_batched(self):
+ with TemporaryDirectory() as tmp_dirname:
+ path = os.path.join(tmp_dirname, "tmp_filename")
+ with TestPipeline() as p:
+ _ = p \
+ | Create([self._records_as_arrow()]) \
+ | WriteToParquetBatched(
+ path, self.SCHEMA, num_shards=1, shard_name_template='')
+ with TestPipeline() as p:
+ # json used for stable sortability
+ readback = \
+ p \
+ | ReadFromParquet(path) \
+ | Map(json.dumps)
+ assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))
+
def test_sink_transform_compliant_nested_type(self):
if ARROW_MAJOR_VERSION < 4:
return unittest.skip(