You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ro...@apache.org on 2017/01/09 21:14:25 UTC
[1/3] beam git commit: Create TFRecordIO,
which provides source/sink for TFRecords,
the dedicated record format for Tensorflow.
Repository: beam
Updated Branches:
refs/heads/python-sdk 69d8f2bf1 -> a25515171
Create TFRecordIO, which provides source/sink for TFRecords, the dedicated record format for Tensorflow.
For more about TFRecords, refer to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/g3doc/api_docs/python/python_io.md
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/88833ba5
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/88833ba5
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/88833ba5
Branch: refs/heads/python-sdk
Commit: 88833ba52bf0a3ac6668bcaa73ca383771d5e1b7
Parents: 69d8f2b
Author: Younghee Kwon <yo...@gmail.com>
Authored: Fri Jan 6 18:05:56 2017 -0800
Committer: Robert Bradshaw <ro...@google.com>
Committed: Mon Jan 9 13:13:45 2017 -0800
----------------------------------------------------------------------
sdks/python/apache_beam/io/__init__.py | 1 +
sdks/python/apache_beam/io/tfrecordio.py | 271 +++++++++++++++
sdks/python/apache_beam/io/tfrecordio_test.py | 365 +++++++++++++++++++++
sdks/python/setup.py | 1 +
4 files changed, 638 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/beam/blob/88833ba5/sdks/python/apache_beam/io/__init__.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/__init__.py b/sdks/python/apache_beam/io/__init__.py
index 4ce9872..13ce36f 100644
--- a/sdks/python/apache_beam/io/__init__.py
+++ b/sdks/python/apache_beam/io/__init__.py
@@ -27,4 +27,5 @@ from apache_beam.io.iobase import Write
from apache_beam.io.iobase import Writer
from apache_beam.io.pubsub import *
from apache_beam.io.textio import *
+from apache_beam.io.tfrecordio import *
from apache_beam.io.range_trackers import *
http://git-wip-us.apache.org/repos/asf/beam/blob/88833ba5/sdks/python/apache_beam/io/tfrecordio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/tfrecordio.py b/sdks/python/apache_beam/io/tfrecordio.py
new file mode 100644
index 0000000..be9f839
--- /dev/null
+++ b/sdks/python/apache_beam/io/tfrecordio.py
@@ -0,0 +1,271 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""TFRecord sources and sinks."""
+
+from __future__ import absolute_import
+
+import logging
+import struct
+
+from apache_beam import coders
+from apache_beam.io import filebasedsource
+from apache_beam.io import fileio
+from apache_beam.io.iobase import Read
+from apache_beam.io.iobase import Write
+from apache_beam.transforms import PTransform
+import crcmod
+
+__all__ = ['ReadFromTFRecord', 'WriteToTFRecord']
+
+
+def _default_crc32c_fn(value):
+ """Calculates crc32c by either snappy or crcmod based on installation."""
+
+ if not _default_crc32c_fn.fn:
+ try:
+ import snappy # pylint: disable=import-error
+ _default_crc32c_fn.fn = snappy._crc32c # pylint: disable=protected-access
+ except ImportError:
+ logging.warning('Couldn\'t find python-snappy so the implementation of '
+ '_TFRecordUtil._masked_crc32c is not as fast as it could '
+ 'be.')
+ _default_crc32c_fn.fn = crcmod.predefined.mkPredefinedCrcFun('crc-32c')
+ return _default_crc32c_fn.fn(value)
+_default_crc32c_fn.fn = None
+
+
+class _TFRecordUtil(object):
+ """Provides basic TFRecord encoding/decoding with consistency checks.
+
+ For detailed TFRecord format description see:
+ https://www.tensorflow.org/versions/master/api_docs/python/python_io.html#tfrecords-format-details
+
+ Note that masks and lengths are represented in little-endian order.
+ """
+
+ @classmethod
+ def _masked_crc32c(cls, value, crc32c_fn=_default_crc32c_fn):
+ """Compute a masked crc32c checksum for a value.
+
+ Args:
+ value: A string for which we compute the crc.
+ crc32c_fn: A function that can compute a crc32c.
+ This is a performance hook that also helps with testing. Callers are
+ not expected to make use of it directly.
+ Returns:
+ Masked crc32c checksum.
+ """
+
+ crc = crc32c_fn(value)
+ return (((crc >> 15) | (crc << 17)) + 0xa282ead8) & 0xffffffff
+
+ @staticmethod
+ def encoded_num_bytes(record):
+ """Return the number of bytes consumed by a record in its encoded form."""
+ # 16 = 8 (Length) + 4 (crc of length) + 4 (crc of data)
+ return len(record) + 16
+
+ @classmethod
+ def write_record(cls, file_handle, value):
+ """Encode a value as a TFRecord.
+
+ Args:
+ file_handle: The file to write to.
+ value: A string content of the record.
+ """
+ encoded_length = struct.pack('<Q', len(value))
+ file_handle.write('{}{}{}{}'.format(
+ encoded_length,
+ struct.pack('<I', cls._masked_crc32c(encoded_length)), #
+ value,
+ struct.pack('<I', cls._masked_crc32c(value))))
+
+ @classmethod
+ def read_record(cls, file_handle):
+ """Read a record from a TFRecords file.
+
+ Args:
+ file_handle: The file to read from.
+ Returns:
+ None if EOF is reached; the payload of the record otherwise.
+ Raises:
+ ValueError: If file appears to not be a valid TFRecords file.
+ """
+ buf_length_expected = 12
+ buf = file_handle.read(buf_length_expected)
+ if not buf:
+ return None # EOF Reached.
+
+ # Validate all length related payloads.
+ if len(buf) != buf_length_expected:
+ raise ValueError('Not a valid TFRecord. Fewer than %d bytes: %s' %
+ (buf_length_expected, buf.encode('hex')))
+ length, length_mask_expected = struct.unpack('<QI', buf)
+ length_mask_actual = cls._masked_crc32c(buf[:8])
+ if length_mask_actual != length_mask_expected:
+ raise ValueError('Not a valid TFRecord. Mismatch of length mask: %s' %
+ buf.encode('hex'))
+
+ # Validate all data related payloads.
+ buf_length_expected = length + 4
+ buf = file_handle.read(buf_length_expected)
+ if len(buf) != buf_length_expected:
+ raise ValueError('Not a valid TFRecord. Fewer than %d bytes: %s' %
+ (buf_length_expected, buf.encode('hex')))
+ data, data_mask_expected = struct.unpack('<%dsI' % length, buf)
+ data_mask_actual = cls._masked_crc32c(data)
+ if data_mask_actual != data_mask_expected:
+ raise ValueError('Not a valid TFRecord. Mismatch of data mask: %s' %
+ buf.encode('hex'))
+
+ # All validation checks passed.
+ return data
+
+
+class _TFRecordSource(filebasedsource.FileBasedSource):
+ """A File source for reading files of TFRecords.
+
+ For detailed TFRecords format description see:
+ https://www.tensorflow.org/versions/master/api_docs/python/python_io.html#tfrecords-format-details
+ """
+
+ def __init__(self,
+ file_pattern,
+ coder,
+ compression_type):
+ """Initialize a TFRecordSource. See ReadFromTFRecord for details."""
+ super(_TFRecordSource, self).__init__(
+ file_pattern=file_pattern,
+ compression_type=compression_type,
+ splittable=False)
+ self._coder = coder
+
+ def read_records(self, file_name, offset_range_tracker):
+ if offset_range_tracker.start_position():
+ raise ValueError('Start position not 0:%s' %
+ offset_range_tracker.start_position())
+
+ current_offset = offset_range_tracker.start_position()
+ with self.open_file(file_name) as file_handle:
+ while True:
+ if not offset_range_tracker.try_claim(current_offset):
+ raise RuntimeError('Unable to claim position: %s' % current_offset)
+ record = _TFRecordUtil.read_record(file_handle)
+ if record is None:
+ return # Reached EOF
+ else:
+ current_offset += _TFRecordUtil.encoded_num_bytes(record)
+ yield self._coder.decode(record)
+
+
+class ReadFromTFRecord(PTransform):
+ """Transform for reading TFRecord sources."""
+
+ def __init__(self,
+ file_pattern,
+ coder=coders.BytesCoder(),
+ compression_type=fileio.CompressionTypes.AUTO,
+ **kwargs):
+ """Initialize a ReadFromTFRecord transform.
+
+ Args:
+ file_pattern: A file glob pattern to read TFRecords from.
+ coder: Coder used to decode each record.
+ compression_type: Used to handle compressed input files. Default value
+ is CompressionTypes.AUTO, in which case the file_path's extension will
+ be used to detect the compression.
+ **kwargs: optional args dictionary. These are passed through to parent
+ constructor.
+
+ Returns:
+ A ReadFromTFRecord transform object.
+ """
+ super(ReadFromTFRecord, self).__init__(**kwargs)
+ self._args = (file_pattern, coder, compression_type)
+
+ def expand(self, pvalue):
+ return pvalue.pipeline | Read(_TFRecordSource(*self._args))
+
+
+class _TFRecordSink(fileio.FileSink):
+ """Sink for writing TFRecords files.
+
+ For detailed TFRecord format description see:
+ https://www.tensorflow.org/versions/master/api_docs/python/python_io.html#tfrecords-format-details
+ """
+
+ def __init__(self, file_path_prefix, coder, file_name_suffix, num_shards,
+ shard_name_template, compression_type):
+ """Initialize a TFRecordSink. See WriteToTFRecord for details."""
+
+ super(_TFRecordSink, self).__init__(
+ file_path_prefix=file_path_prefix,
+ coder=coder,
+ file_name_suffix=file_name_suffix,
+ num_shards=num_shards,
+ shard_name_template=shard_name_template,
+ mime_type='application/octet-stream',
+ compression_type=compression_type)
+
+ def write_encoded_record(self, file_handle, value):
+ _TFRecordUtil.write_record(file_handle, value)
+
+
+class WriteToTFRecord(PTransform):
+ """Transform for writing to TFRecord sinks."""
+
+ def __init__(self,
+ file_path_prefix,
+ coder=coders.BytesCoder(),
+ file_name_suffix='',
+ num_shards=0,
+ shard_name_template=fileio.DEFAULT_SHARD_NAME_TEMPLATE,
+ compression_type=fileio.CompressionTypes.AUTO,
+ **kwargs):
+ """Initialize WriteToTFRecord transform.
+
+ Args:
+ file_path_prefix: The file path to write to. The files written will begin
+ with this prefix, followed by a shard identifier (see num_shards), and
+ end in a common extension, if given by file_name_suffix.
+ coder: Coder used to encode each record.
+ file_name_suffix: Suffix for the files written.
+ num_shards: The number of files (shards) used for output. If not set, the
+ default value will be used.
+ shard_name_template: A template string containing placeholders for
+ the shard number and shard count. Currently only '' and
+ '-SSSSS-of-NNNNN' are patterns allowed.
+ When constructing a filename for a particular shard number, the
+ upper-case letters 'S' and 'N' are replaced with the 0-padded shard
+ number and shard count respectively. This argument can be '' in which
+ case it behaves as if num_shards was set to 1 and only one file will be
+ generated. The default pattern is '-SSSSS-of-NNNNN'.
+ compression_type: Used to handle compressed output files. Typical value
+ is CompressionTypes.AUTO, in which case the file_path's extension will
+ be used to detect the compression.
+ **kwargs: Optional args dictionary. These are passed through to parent
+ constructor.
+
+ Returns:
+ A WriteToTFRecord transform object.
+ """
+ super(WriteToTFRecord, self).__init__(**kwargs)
+ self._args = (file_path_prefix, coder, file_name_suffix, num_shards,
+ shard_name_template, compression_type)
+
+ def expand(self, pcoll):
+ return pcoll | Write(_TFRecordSink(*self._args))
http://git-wip-us.apache.org/repos/asf/beam/blob/88833ba5/sdks/python/apache_beam/io/tfrecordio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/tfrecordio_test.py b/sdks/python/apache_beam/io/tfrecordio_test.py
new file mode 100644
index 0000000..ee287b3
--- /dev/null
+++ b/sdks/python/apache_beam/io/tfrecordio_test.py
@@ -0,0 +1,365 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import binascii
+import cStringIO
+import glob
+import gzip
+import logging
+import pickle
+import random
+import tempfile
+import unittest
+
+import apache_beam as beam
+from apache_beam import coders
+from apache_beam.io import fileio
+from apache_beam.io.tfrecordio import _TFRecordSink
+from apache_beam.io.tfrecordio import _TFRecordSource
+from apache_beam.io.tfrecordio import _TFRecordUtil
+from apache_beam.io.tfrecordio import ReadFromTFRecord
+from apache_beam.io.tfrecordio import WriteToTFRecord
+from apache_beam.runners import DirectRunner
+import crcmod
+
+
+try:
+ import tensorflow as tf # pylint: disable=import-error
+except ImportError:
+ tf = None # pylint: disable=invalid-name
+ logging.warning('Tensorflow is not installed, so skipping some tests.')
+
+# Created by running the following code in Python:
+# >>> import tensorflow as tf
+# >>> import base64
+# >>> writer = tf.python_io.TFRecordWriter('/tmp/python_foo.tfrecord')
+# >>> writer.write('foo')
+# >>> writer.close()
+# >>> with open('/tmp/python_foo.tfrecord', 'rb') as f:
+# ... data = base64.b64encode(f.read())
+# ... print data
+FOO_RECORD_BASE64 = 'AwAAAAAAAACwmUkOZm9vYYq+/g=='
+
+# Same as above but containing two records ['foo', 'bar']
+FOO_BAR_RECORD_BASE64 = 'AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg='
+
+
+class TestTFRecordUtil(unittest.TestCase):
+
+ def setUp(self):
+ self.record = binascii.a2b_base64(FOO_RECORD_BASE64)
+
+ def _as_file_handle(self, contents):
+ result = cStringIO.StringIO()
+ result.write(contents)
+ result.reset()
+ return result
+
+ def _increment_value_at_index(self, value, index):
+ l = list(value)
+ l[index] = chr(ord(l[index]) + 1)
+ return ''.join(l)
+
+ def _test_error(self, record, error_text):
+ with self.assertRaises(ValueError) as context:
+ _TFRecordUtil.read_record(self._as_file_handle(record))
+ self.assertIn(error_text, context.exception.message)
+
+ def test_masked_crc32c(self):
+ self.assertEqual(0xfd7fffa, _TFRecordUtil._masked_crc32c('\x00' * 32))
+ self.assertEqual(0xf909b029, _TFRecordUtil._masked_crc32c('\xff' * 32))
+ self.assertEqual(0xfebe8a61, _TFRecordUtil._masked_crc32c('foo'))
+ self.assertEqual(
+ 0xe4999b0,
+ _TFRecordUtil._masked_crc32c('\x03\x00\x00\x00\x00\x00\x00\x00'))
+
+ def test_masked_crc32c_crcmod(self):
+ crc32c_fn = crcmod.predefined.mkPredefinedCrcFun('crc-32c')
+ self.assertEqual(
+ 0xfd7fffa,
+ _TFRecordUtil._masked_crc32c(
+ '\x00' * 32, crc32c_fn=crc32c_fn))
+ self.assertEqual(
+ 0xf909b029,
+ _TFRecordUtil._masked_crc32c(
+ '\xff' * 32, crc32c_fn=crc32c_fn))
+ self.assertEqual(
+ 0xfebe8a61, _TFRecordUtil._masked_crc32c(
+ 'foo', crc32c_fn=crc32c_fn))
+ self.assertEqual(
+ 0xe4999b0,
+ _TFRecordUtil._masked_crc32c(
+ '\x03\x00\x00\x00\x00\x00\x00\x00', crc32c_fn=crc32c_fn))
+
+ def test_write_record(self):
+ file_handle = cStringIO.StringIO()
+ _TFRecordUtil.write_record(file_handle, 'foo')
+ self.assertEqual(self.record, file_handle.getvalue())
+
+ def test_read_record(self):
+ actual = _TFRecordUtil.read_record(self._as_file_handle(self.record))
+ self.assertEqual('foo', actual)
+
+ def test_read_record_invalid_record(self):
+ self._test_error('bar', 'Not a valid TFRecord. Fewer than 12 bytes')
+
+ def test_read_record_invalid_length_mask(self):
+ record = self._increment_value_at_index(self.record, 9)
+ self._test_error(record, 'Mismatch of length mask')
+
+ def test_read_record_invalid_data_mask(self):
+ record = self._increment_value_at_index(self.record, 16)
+ self._test_error(record, 'Mismatch of data mask')
+
+ def test_compatibility_read_write(self):
+ for record in ['', 'blah', 'another blah']:
+ file_handle = cStringIO.StringIO()
+ _TFRecordUtil.write_record(file_handle, record)
+ file_handle.reset()
+ actual = _TFRecordUtil.read_record(file_handle)
+ self.assertEqual(record, actual)
+
+
+class TestTFRecordSink(unittest.TestCase):
+
+ def _write_lines(self, sink, path, lines):
+ f = sink.open(path)
+ for l in lines:
+ sink.write_record(f, l)
+ sink.close(f)
+
+ def test_write_record_single(self):
+ path = tempfile.NamedTemporaryFile().name
+ record = binascii.a2b_base64(FOO_RECORD_BASE64)
+ sink = _TFRecordSink(
+ path,
+ coder=coders.BytesCoder(),
+ file_name_suffix='',
+ num_shards=0,
+ shard_name_template=None,
+ compression_type=fileio.CompressionTypes.UNCOMPRESSED)
+ self._write_lines(sink, path, ['foo'])
+
+ with open(path, 'r') as f:
+ self.assertEqual(f.read(), record)
+
+ def test_write_record_multiple(self):
+ path = tempfile.NamedTemporaryFile().name
+ record = binascii.a2b_base64(FOO_BAR_RECORD_BASE64)
+ sink = _TFRecordSink(
+ path,
+ coder=coders.BytesCoder(),
+ file_name_suffix='',
+ num_shards=0,
+ shard_name_template=None,
+ compression_type=fileio.CompressionTypes.UNCOMPRESSED)
+ self._write_lines(sink, path, ['foo', 'bar'])
+
+ with open(path, 'r') as f:
+ self.assertEqual(f.read(), record)
+
+
+@unittest.skipIf(tf is None, 'tensorflow not installed.')
+class TestWriteToTFRecord(TestTFRecordSink):
+
+ def test_write_record_gzip(self):
+ with beam.Pipeline(DirectRunner()) as p:
+ file_path_prefix = tempfile.NamedTemporaryFile().name
+ input_data = ['foo', 'bar']
+ _ = p | beam.Create(input_data) | WriteToTFRecord(
+ file_path_prefix, compression_type=fileio.CompressionTypes.GZIP)
+
+ actual = []
+ file_name = glob.glob(file_path_prefix + '-*')[0]
+ for r in tf.python_io.tf_record_iterator(
+ file_name, options=tf.python_io.TFRecordOptions(
+ tf.python_io.TFRecordCompressionType.GZIP)):
+ actual.append(r)
+ self.assertEqual(actual, input_data)
+
+ def test_write_record_auto(self):
+ with beam.Pipeline(DirectRunner()) as p:
+ file_path_prefix = tempfile.NamedTemporaryFile().name
+ input_data = ['foo', 'bar']
+ _ = p | beam.Create(input_data) | WriteToTFRecord(
+ file_path_prefix, file_name_suffix='.gz')
+
+ actual = []
+ file_name = glob.glob(file_path_prefix + '-*.gz')[0]
+ for r in tf.python_io.tf_record_iterator(
+ file_name, options=tf.python_io.TFRecordOptions(
+ tf.python_io.TFRecordCompressionType.GZIP)):
+ actual.append(r)
+ self.assertEqual(actual, input_data)
+
+
+class TestTFRecordSource(unittest.TestCase):
+
+ def _write_file(self, path, base64_records):
+ record = binascii.a2b_base64(base64_records)
+ with open(path, 'wb') as f:
+ f.write(record)
+
+ def _write_file_gzip(self, path, base64_records):
+ record = binascii.a2b_base64(base64_records)
+ with gzip.GzipFile(path, 'wb') as f:
+ f.write(record)
+
+ def test_process_single(self):
+ path = tempfile.NamedTemporaryFile().name
+ self._write_file(path, FOO_RECORD_BASE64)
+ with beam.Pipeline(DirectRunner()) as p:
+ result = (p
+ | beam.Read(
+ _TFRecordSource(
+ path,
+ coder=coders.BytesCoder(),
+ compression_type=fileio.CompressionTypes.AUTO)))
+ beam.assert_that(result, beam.equal_to(['foo']))
+
+ def test_process_multiple(self):
+ path = tempfile.NamedTemporaryFile().name
+ self._write_file(path, FOO_BAR_RECORD_BASE64)
+ with beam.Pipeline(DirectRunner()) as p:
+ result = (p
+ | beam.Read(
+ _TFRecordSource(
+ path,
+ coder=coders.BytesCoder(),
+ compression_type=fileio.CompressionTypes.AUTO)))
+ beam.assert_that(result, beam.equal_to(['foo', 'bar']))
+
+ def test_process_gzip(self):
+ path = tempfile.NamedTemporaryFile().name
+ self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
+ with beam.Pipeline(DirectRunner()) as p:
+ result = (p
+ | beam.Read(
+ _TFRecordSource(
+ path,
+ coder=coders.BytesCoder(),
+ compression_type=fileio.CompressionTypes.GZIP)))
+ beam.assert_that(result, beam.equal_to(['foo', 'bar']))
+
+ def test_process_auto(self):
+ path = tempfile.mkstemp(suffix='.gz')[1]
+ self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
+ with beam.Pipeline(DirectRunner()) as p:
+ result = (p
+ | beam.Read(
+ _TFRecordSource(
+ path,
+ coder=coders.BytesCoder(),
+ compression_type=fileio.CompressionTypes.AUTO)))
+ beam.assert_that(result, beam.equal_to(['foo', 'bar']))
+
+
+class TestReadFromTFRecordSource(TestTFRecordSource):
+
+ def test_process_gzip(self):
+ path = tempfile.NamedTemporaryFile().name
+ self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
+ with beam.Pipeline(DirectRunner()) as p:
+ result = (p
+ | ReadFromTFRecord(
+ path, compression_type=fileio.CompressionTypes.GZIP))
+ beam.assert_that(result, beam.equal_to(['foo', 'bar']))
+
+ def test_process_gzip_auto(self):
+ path = tempfile.mkstemp(suffix='.gz')[1]
+ self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
+ with beam.Pipeline(DirectRunner()) as p:
+ result = (p
+ | ReadFromTFRecord(
+ path, compression_type=fileio.CompressionTypes.AUTO))
+ beam.assert_that(result, beam.equal_to(['foo', 'bar']))
+
+
+class TestEnd2EndWriteAndRead(unittest.TestCase):
+
+ def create_inputs(self):
+ input_array = [[random.random() - 0.5 for _ in xrange(15)]
+ for _ in xrange(12)]
+ memfile = cStringIO.StringIO()
+ pickle.dump(input_array, memfile)
+ return memfile.getvalue()
+
+ def test_end2end(self):
+ file_path_prefix = tempfile.NamedTemporaryFile().name
+
+ # Generate a TFRecord file.
+ with beam.Pipeline(DirectRunner()) as p:
+ expected_data = [self.create_inputs() for _ in range(0, 10)]
+ _ = p | beam.Create(expected_data) | WriteToTFRecord(file_path_prefix)
+
+ # Read the file back and compare.
+ with beam.Pipeline(DirectRunner()) as p:
+ actual_data = p | ReadFromTFRecord(file_path_prefix + '-*')
+ beam.assert_that(actual_data, beam.equal_to(expected_data))
+
+ def test_end2end_auto_compression(self):
+ file_path_prefix = tempfile.NamedTemporaryFile().name
+
+ # Generate a TFRecord file.
+ with beam.Pipeline(DirectRunner()) as p:
+ expected_data = [self.create_inputs() for _ in range(0, 10)]
+ _ = p | beam.Create(expected_data) | WriteToTFRecord(
+ file_path_prefix, file_name_suffix='.gz')
+
+ # Read the file back and compare.
+ with beam.Pipeline(DirectRunner()) as p:
+ actual_data = p | ReadFromTFRecord(file_path_prefix + '-*')
+ beam.assert_that(actual_data, beam.equal_to(expected_data))
+
+ def test_end2end_auto_compression_unsharded(self):
+ file_path_prefix = tempfile.NamedTemporaryFile().name
+
+ # Generate a TFRecord file.
+ with beam.Pipeline(DirectRunner()) as p:
+ expected_data = [self.create_inputs() for _ in range(0, 10)]
+ _ = p | beam.Create(expected_data) | WriteToTFRecord(
+ file_path_prefix + '.gz', shard_name_template='')
+
+ # Read the file back and compare.
+ with beam.Pipeline(DirectRunner()) as p:
+ actual_data = p | ReadFromTFRecord(file_path_prefix + '.gz')
+ beam.assert_that(actual_data, beam.equal_to(expected_data))
+
+ @unittest.skipIf(tf is None, 'tensorflow not installed.')
+ def test_end2end_example_proto(self):
+ file_path_prefix = tempfile.NamedTemporaryFile().name
+
+ example = tf.train.Example()
+ example.features.feature['int'].int64_list.value.extend(range(3))
+ example.features.feature['bytes'].bytes_list.value.extend(
+ [b'foo', b'bar'])
+
+ with beam.Pipeline(DirectRunner()) as p:
+ _ = p | beam.Create([example]) | WriteToTFRecord(
+ file_path_prefix, coder=beam.coders.ProtoCoder(example.__class__))
+
+ # Read the file back and compare.
+ with beam.Pipeline(DirectRunner()) as p:
+ actual_data = (p | ReadFromTFRecord(
+ file_path_prefix + '-*',
+ coder=beam.coders.ProtoCoder(example.__class__)))
+ beam.assert_that(actual_data, beam.equal_to([example]))
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/beam/blob/88833ba5/sdks/python/setup.py
----------------------------------------------------------------------
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index f6357b6..1fd622f 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -85,6 +85,7 @@ else:
REQUIRED_PACKAGES = [
'avro>=1.7.7,<2.0.0',
+ 'crcmod>=1.7,<2.0',
'dill>=0.2.5,<0.3',
'google-apitools>=0.5.6,<1.0.0',
'googledatastore==6.4.1',
[3/3] beam git commit: Closes #1749
Posted by ro...@apache.org.
Closes #1749
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/a2551517
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/a2551517
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/a2551517
Branch: refs/heads/python-sdk
Commit: a2551517184e55ecf01d07dffd921b105f7dc448
Parents: 69d8f2b 93e8d19
Author: Robert Bradshaw <ro...@google.com>
Authored: Mon Jan 9 13:13:55 2017 -0800
Committer: Robert Bradshaw <ro...@google.com>
Committed: Mon Jan 9 13:13:55 2017 -0800
----------------------------------------------------------------------
sdks/python/apache_beam/io/__init__.py | 1 +
sdks/python/apache_beam/io/tfrecordio.py | 271 ++++++++++++++
sdks/python/apache_beam/io/tfrecordio_test.py | 389 +++++++++++++++++++++
sdks/python/setup.py | 1 +
4 files changed, 662 insertions(+)
----------------------------------------------------------------------
[2/3] beam git commit: Provided temporary directory management for
test cases.
Posted by ro...@apache.org.
Provided temporary directory management for test cases.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/93e8d19e
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/93e8d19e
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/93e8d19e
Branch: refs/heads/python-sdk
Commit: 93e8d19e32807fb5279ed711f0f06c3123adfb2e
Parents: 88833ba
Author: Younghee Kwon <yo...@gmail.com>
Authored: Mon Jan 9 11:50:57 2017 -0800
Committer: Robert Bradshaw <ro...@google.com>
Committed: Mon Jan 9 13:13:46 2017 -0800
----------------------------------------------------------------------
sdks/python/apache_beam/io/tfrecordio_test.py | 58 +++++++++++++++-------
1 file changed, 41 insertions(+), 17 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/beam/blob/93e8d19e/sdks/python/apache_beam/io/tfrecordio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/tfrecordio_test.py b/sdks/python/apache_beam/io/tfrecordio_test.py
index ee287b3..ecd58f5 100644
--- a/sdks/python/apache_beam/io/tfrecordio_test.py
+++ b/sdks/python/apache_beam/io/tfrecordio_test.py
@@ -20,8 +20,10 @@ import cStringIO
import glob
import gzip
import logging
+import os
import pickle
import random
+import shutil
import tempfile
import unittest
@@ -134,7 +136,29 @@ class TestTFRecordUtil(unittest.TestCase):
self.assertEqual(record, actual)
-class TestTFRecordSink(unittest.TestCase):
+class _TestCaseWithTempDirCleanUp(unittest.TestCase):
+ """Base class for TestCases that deal with TempDir clean-up.
+
+ Inheriting test cases call self._new_tempdir() to create a temporary dir
+ which is deleted at the end of each test (when tearDown() is called).
+ """
+
+ def setUp(self):
+ self._tempdirs = []
+
+ def tearDown(self):
+ for path in self._tempdirs:
+ if os.path.exists(path):
+ shutil.rmtree(path)
+ self._tempdirs = []
+
+ def _new_tempdir(self):
+ result = tempfile.mkdtemp()
+ self._tempdirs.append(result)
+ return result
+
+
+class TestTFRecordSink(_TestCaseWithTempDirCleanUp):
def _write_lines(self, sink, path, lines):
f = sink.open(path)
@@ -143,7 +167,7 @@ class TestTFRecordSink(unittest.TestCase):
sink.close(f)
def test_write_record_single(self):
- path = tempfile.NamedTemporaryFile().name
+ path = os.path.join(self._new_tempdir(), 'result')
record = binascii.a2b_base64(FOO_RECORD_BASE64)
sink = _TFRecordSink(
path,
@@ -158,7 +182,7 @@ class TestTFRecordSink(unittest.TestCase):
self.assertEqual(f.read(), record)
def test_write_record_multiple(self):
- path = tempfile.NamedTemporaryFile().name
+ path = os.path.join(self._new_tempdir(), 'result')
record = binascii.a2b_base64(FOO_BAR_RECORD_BASE64)
sink = _TFRecordSink(
path,
@@ -177,8 +201,8 @@ class TestTFRecordSink(unittest.TestCase):
class TestWriteToTFRecord(TestTFRecordSink):
def test_write_record_gzip(self):
+ file_path_prefix = os.path.join(self._new_tempdir(), 'result')
with beam.Pipeline(DirectRunner()) as p:
- file_path_prefix = tempfile.NamedTemporaryFile().name
input_data = ['foo', 'bar']
_ = p | beam.Create(input_data) | WriteToTFRecord(
file_path_prefix, compression_type=fileio.CompressionTypes.GZIP)
@@ -192,8 +216,8 @@ class TestWriteToTFRecord(TestTFRecordSink):
self.assertEqual(actual, input_data)
def test_write_record_auto(self):
+ file_path_prefix = os.path.join(self._new_tempdir(), 'result')
with beam.Pipeline(DirectRunner()) as p:
- file_path_prefix = tempfile.NamedTemporaryFile().name
input_data = ['foo', 'bar']
_ = p | beam.Create(input_data) | WriteToTFRecord(
file_path_prefix, file_name_suffix='.gz')
@@ -207,7 +231,7 @@ class TestWriteToTFRecord(TestTFRecordSink):
self.assertEqual(actual, input_data)
-class TestTFRecordSource(unittest.TestCase):
+class TestTFRecordSource(_TestCaseWithTempDirCleanUp):
def _write_file(self, path, base64_records):
record = binascii.a2b_base64(base64_records)
@@ -220,7 +244,7 @@ class TestTFRecordSource(unittest.TestCase):
f.write(record)
def test_process_single(self):
- path = tempfile.NamedTemporaryFile().name
+ path = os.path.join(self._new_tempdir(), 'result')
self._write_file(path, FOO_RECORD_BASE64)
with beam.Pipeline(DirectRunner()) as p:
result = (p
@@ -232,7 +256,7 @@ class TestTFRecordSource(unittest.TestCase):
beam.assert_that(result, beam.equal_to(['foo']))
def test_process_multiple(self):
- path = tempfile.NamedTemporaryFile().name
+ path = os.path.join(self._new_tempdir(), 'result')
self._write_file(path, FOO_BAR_RECORD_BASE64)
with beam.Pipeline(DirectRunner()) as p:
result = (p
@@ -244,7 +268,7 @@ class TestTFRecordSource(unittest.TestCase):
beam.assert_that(result, beam.equal_to(['foo', 'bar']))
def test_process_gzip(self):
- path = tempfile.NamedTemporaryFile().name
+ path = os.path.join(self._new_tempdir(), 'result')
self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
with beam.Pipeline(DirectRunner()) as p:
result = (p
@@ -256,7 +280,7 @@ class TestTFRecordSource(unittest.TestCase):
beam.assert_that(result, beam.equal_to(['foo', 'bar']))
def test_process_auto(self):
- path = tempfile.mkstemp(suffix='.gz')[1]
+ path = os.path.join(self._new_tempdir(), 'result.gz')
self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
with beam.Pipeline(DirectRunner()) as p:
result = (p
@@ -271,7 +295,7 @@ class TestTFRecordSource(unittest.TestCase):
class TestReadFromTFRecordSource(TestTFRecordSource):
def test_process_gzip(self):
- path = tempfile.NamedTemporaryFile().name
+ path = os.path.join(self._new_tempdir(), 'result')
self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
with beam.Pipeline(DirectRunner()) as p:
result = (p
@@ -280,7 +304,7 @@ class TestReadFromTFRecordSource(TestTFRecordSource):
beam.assert_that(result, beam.equal_to(['foo', 'bar']))
def test_process_gzip_auto(self):
- path = tempfile.mkstemp(suffix='.gz')[1]
+ path = os.path.join(self._new_tempdir(), 'result.gz')
self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
with beam.Pipeline(DirectRunner()) as p:
result = (p
@@ -289,7 +313,7 @@ class TestReadFromTFRecordSource(TestTFRecordSource):
beam.assert_that(result, beam.equal_to(['foo', 'bar']))
-class TestEnd2EndWriteAndRead(unittest.TestCase):
+class TestEnd2EndWriteAndRead(_TestCaseWithTempDirCleanUp):
def create_inputs(self):
input_array = [[random.random() - 0.5 for _ in xrange(15)]
@@ -299,7 +323,7 @@ class TestEnd2EndWriteAndRead(unittest.TestCase):
return memfile.getvalue()
def test_end2end(self):
- file_path_prefix = tempfile.NamedTemporaryFile().name
+ file_path_prefix = os.path.join(self._new_tempdir(), 'result')
# Generate a TFRecord file.
with beam.Pipeline(DirectRunner()) as p:
@@ -312,7 +336,7 @@ class TestEnd2EndWriteAndRead(unittest.TestCase):
beam.assert_that(actual_data, beam.equal_to(expected_data))
def test_end2end_auto_compression(self):
- file_path_prefix = tempfile.NamedTemporaryFile().name
+ file_path_prefix = os.path.join(self._new_tempdir(), 'result')
# Generate a TFRecord file.
with beam.Pipeline(DirectRunner()) as p:
@@ -326,7 +350,7 @@ class TestEnd2EndWriteAndRead(unittest.TestCase):
beam.assert_that(actual_data, beam.equal_to(expected_data))
def test_end2end_auto_compression_unsharded(self):
- file_path_prefix = tempfile.NamedTemporaryFile().name
+ file_path_prefix = os.path.join(self._new_tempdir(), 'result')
# Generate a TFRecord file.
with beam.Pipeline(DirectRunner()) as p:
@@ -341,7 +365,7 @@ class TestEnd2EndWriteAndRead(unittest.TestCase):
@unittest.skipIf(tf is None, 'tensorflow not installed.')
def test_end2end_example_proto(self):
- file_path_prefix = tempfile.NamedTemporaryFile().name
+ file_path_prefix = os.path.join(self._new_tempdir(), 'result')
example = tf.train.Example()
example.features.feature['int'].int64_list.value.extend(range(3))