You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ro...@apache.org on 2017/01/09 21:14:25 UTC

[1/3] beam git commit: Create TFRecordIO, which provides source/sink for TFRecords, the dedicated record format for Tensorflow.

Repository: beam
Updated Branches:
  refs/heads/python-sdk 69d8f2bf1 -> a25515171


Create TFRecordIO, which provides source/sink for TFRecords, the dedicated record format for Tensorflow.

For more about TFRecords, refer to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/g3doc/api_docs/python/python_io.md


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/88833ba5
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/88833ba5
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/88833ba5

Branch: refs/heads/python-sdk
Commit: 88833ba52bf0a3ac6668bcaa73ca383771d5e1b7
Parents: 69d8f2b
Author: Younghee Kwon <yo...@gmail.com>
Authored: Fri Jan 6 18:05:56 2017 -0800
Committer: Robert Bradshaw <ro...@google.com>
Committed: Mon Jan 9 13:13:45 2017 -0800

----------------------------------------------------------------------
 sdks/python/apache_beam/io/__init__.py        |   1 +
 sdks/python/apache_beam/io/tfrecordio.py      | 271 +++++++++++++++
 sdks/python/apache_beam/io/tfrecordio_test.py | 365 +++++++++++++++++++++
 sdks/python/setup.py                          |   1 +
 4 files changed, 638 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/88833ba5/sdks/python/apache_beam/io/__init__.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/__init__.py b/sdks/python/apache_beam/io/__init__.py
index 4ce9872..13ce36f 100644
--- a/sdks/python/apache_beam/io/__init__.py
+++ b/sdks/python/apache_beam/io/__init__.py
@@ -27,4 +27,5 @@ from apache_beam.io.iobase import Write
 from apache_beam.io.iobase import Writer
 from apache_beam.io.pubsub import *
 from apache_beam.io.textio import *
+from apache_beam.io.tfrecordio import *
 from apache_beam.io.range_trackers import *

http://git-wip-us.apache.org/repos/asf/beam/blob/88833ba5/sdks/python/apache_beam/io/tfrecordio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/tfrecordio.py b/sdks/python/apache_beam/io/tfrecordio.py
new file mode 100644
index 0000000..be9f839
--- /dev/null
+++ b/sdks/python/apache_beam/io/tfrecordio.py
@@ -0,0 +1,271 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""TFRecord sources and sinks."""
+
+from __future__ import absolute_import
+
+import logging
+import struct
+
+from apache_beam import coders
+from apache_beam.io import filebasedsource
+from apache_beam.io import fileio
+from apache_beam.io.iobase import Read
+from apache_beam.io.iobase import Write
+from apache_beam.transforms import PTransform
+import crcmod
+
+__all__ = ['ReadFromTFRecord', 'WriteToTFRecord']
+
+
+def _default_crc32c_fn(value):
+  """Calculates crc32c of `value` using python-snappy or crcmod.
+
+  The implementation is picked on the first call and cached on the function
+  object as `_default_crc32c_fn.fn`. python-snappy's crc32c routine is
+  preferred; crcmod's predefined 'crc-32c' function is the fallback.
+
+  Args:
+    value: A byte string to checksum.
+  Returns:
+    The (unmasked) crc32c of `value` as an int.
+  """
+
+  if not _default_crc32c_fn.fn:
+    try:
+      import snappy  # pylint: disable=import-error
+      # `_crc32c` is a private python-snappy attribute and is not guaranteed
+      # to exist in every release, so a missing attribute must fall back to
+      # crcmod just like a missing module does.
+      _default_crc32c_fn.fn = snappy._crc32c  # pylint: disable=protected-access
+    except (ImportError, AttributeError):
+      logging.warning('Couldn\'t find a usable python-snappy so the '
+                      'implementation of _TFRecordUtil._masked_crc32c is not '
+                      'as fast as it could be.')
+      _default_crc32c_fn.fn = crcmod.predefined.mkPredefinedCrcFun('crc-32c')
+  return _default_crc32c_fn.fn(value)
+_default_crc32c_fn.fn = None
+
+
+class _TFRecordUtil(object):
+  """Provides basic TFRecord encoding/decoding with consistency checks.
+
+  For detailed TFRecord format description see:
+    https://www.tensorflow.org/versions/master/api_docs/python/python_io.html#tfrecords-format-details
+
+  Each encoded record is laid out as:
+    8 bytes   length N of the data
+    4 bytes   masked crc32c of the 8 length bytes
+    N bytes   the data itself
+    4 bytes   masked crc32c of the data
+
+  Note that masks and length are represented in LittleEndian order.
+  """
+
+  @classmethod
+  def _masked_crc32c(cls, value, crc32c_fn=_default_crc32c_fn):
+    """Compute a masked crc32c checksum for a value.
+
+    Args:
+      value: A string for which we compute the crc.
+      crc32c_fn: A function that can compute a crc32c.
+        This is a performance hook that also helps with testing. Callers are
+        not expected to make use of it directly.
+    Returns:
+      Masked crc32c checksum.
+    """
+
+    crc = crc32c_fn(value)
+    # Rotate the 32-bit crc right by 15 bits, add the masking constant and
+    # truncate the sum back to 32 bits.
+    return (((crc >> 15) | (crc << 17)) + 0xa282ead8) & 0xffffffff
+
+  @staticmethod
+  def encoded_num_bytes(record):
+    """Return the number of bytes consumed by a record in its encoded form."""
+    # 16 = 8 (Length) + 4 (crc of length) + 4 (crc of data)
+    return len(record) + 16
+
+  @classmethod
+  def write_record(cls, file_handle, value):
+    """Encode a value as a TFRecord and write it to a file.
+
+    Args:
+      file_handle: The file to write to.
+      value: A byte string holding the content of the record.
+    """
+    encoded_length = struct.pack('<Q', len(value))
+    # Concatenate as bytes: str.format() would render bytes objects as
+    # "b'...'" on Python 3 and is needless formatting overhead on Python 2.
+    file_handle.write(b''.join([
+        encoded_length,
+        struct.pack('<I', cls._masked_crc32c(encoded_length)),
+        value,
+        struct.pack('<I', cls._masked_crc32c(value)),
+    ]))
+
+  @classmethod
+  def read_record(cls, file_handle):
+    """Read a record from a TFRecords file.
+
+    Args:
+      file_handle: The file to read from.
+    Returns:
+      None if EOF is reached; the payload of the record otherwise.
+    Raises:
+      ValueError: If file appears to not be a valid TFRecords file.
+    """
+    # Header: 8-byte length followed by the 4-byte masked crc of those bytes.
+    buf_length_expected = 12
+    buf = file_handle.read(buf_length_expected)
+    if not buf:
+      return None  # EOF Reached.
+
+    # Validate all length related payloads.
+    if len(buf) != buf_length_expected:
+      raise ValueError('Not a valid TFRecord. Fewer than %d bytes: %s' %
+                       (buf_length_expected, buf.encode('hex')))
+    length, length_mask_expected = struct.unpack('<QI', buf)
+    length_mask_actual = cls._masked_crc32c(buf[:8])
+    if length_mask_actual != length_mask_expected:
+      raise ValueError('Not a valid TFRecord. Mismatch of length mask: %s' %
+                       buf.encode('hex'))
+
+    # Body: `length` data bytes followed by the 4-byte masked crc of the data.
+    buf_length_expected = length + 4
+    buf = file_handle.read(buf_length_expected)
+    if len(buf) != buf_length_expected:
+      raise ValueError('Not a valid TFRecord. Fewer than %d bytes: %s' %
+                       (buf_length_expected, buf.encode('hex')))
+    data, data_mask_expected = struct.unpack('<%dsI' % length, buf)
+    data_mask_actual = cls._masked_crc32c(data)
+    if data_mask_actual != data_mask_expected:
+      raise ValueError('Not a valid TFRecord. Mismatch of data mask: %s' %
+                       buf.encode('hex'))
+
+    # All validation checks passed.
+    return data
+
+
+class _TFRecordSource(filebasedsource.FileBasedSource):
+  """A File source for reading files of TFRecords.
+
+  Files are not split (splittable=False): each file is read start to end by
+  a single read_records() call, and every record is decoded with `coder`.
+
+  For detailed TFRecords format description see:
+    https://www.tensorflow.org/versions/master/api_docs/python/python_io.html#tfrecords-format-details
+  """
+
+  def __init__(self,
+               file_pattern,
+               coder,
+               compression_type):
+    """Initialize a TFRecordSource.  See ReadFromTFRecord for details."""
+    super(_TFRecordSource, self).__init__(
+        file_pattern=file_pattern,
+        compression_type=compression_type,
+        splittable=False)
+    self._coder = coder
+
+  def read_records(self, file_name, offset_range_tracker):
+    """Yields the decoded records of `file_name` in file order.
+
+    Args:
+      file_name: Path of the file to read.
+      offset_range_tracker: Tracker for the range being read; must start at 0
+        because the source is not splittable.
+    Raises:
+      ValueError: If the tracker's start position is not 0, or (from
+        _TFRecordUtil.read_record) if the file is not a valid TFRecords file.
+      RuntimeError: If the tracker refuses to let a position be claimed.
+    """
+    if offset_range_tracker.start_position():
+      raise ValueError('Start position not 0:%s' %
+                       offset_range_tracker.start_position())
+
+    # Offsets count encoded record bytes read from the handle returned by
+    # open_file (presumably the decompressed stream when compression_type
+    # applies -- confirm against FileBasedSource.open_file).
+    current_offset = offset_range_tracker.start_position()
+    with self.open_file(file_name) as file_handle:
+      while True:
+        # Each position is claimed before reading, including the final
+        # position at which EOF is detected.
+        if not offset_range_tracker.try_claim(current_offset):
+          raise RuntimeError('Unable to claim position: %s' % current_offset)
+        record = _TFRecordUtil.read_record(file_handle)
+        if record is None:
+          return  # Reached EOF
+        else:
+          current_offset += _TFRecordUtil.encoded_num_bytes(record)
+          yield self._coder.decode(record)
+
+
+class ReadFromTFRecord(PTransform):
+  """Transform for reading TFRecord sources.
+
+  Produces one element per record found in the matched files, decoded with
+  the given coder.
+  """
+
+  def __init__(self,
+               file_pattern,
+               coder=coders.BytesCoder(),
+               compression_type=fileio.CompressionTypes.AUTO,
+               **kwargs):
+    """Initialize a ReadFromTFRecord transform.
+
+    Args:
+      file_pattern: A file glob pattern to read TFRecords from.
+      coder: Coder used to decode each record. Defaults to BytesCoder, which
+        yields each record's raw payload.
+      compression_type: Used to handle compressed input files. Default value
+          is CompressionTypes.AUTO, in which case the file_path's extension will
+          be used to detect the compression.
+      **kwargs: Optional args dictionary. These are passed through to parent
+        constructor.
+    """
+    super(ReadFromTFRecord, self).__init__(**kwargs)
+    # The source itself is only constructed when the transform is expanded.
+    self._args = (file_pattern, coder, compression_type)
+
+  def expand(self, pvalue):
+    """Reads a _TFRecordSource built from the constructor arguments."""
+    return pvalue.pipeline | Read(_TFRecordSource(*self._args))
+
+
+class _TFRecordSink(fileio.FileSink):
+  """A FileSink that frames every element as one TFRecord.
+
+  Output files are created with MIME type 'application/octet-stream'. The
+  on-disk record layout is described at:
+    https://www.tensorflow.org/versions/master/api_docs/python/python_io.html#tfrecords-format-details
+  """
+
+  def __init__(self, file_path_prefix, coder, file_name_suffix, num_shards,
+               shard_name_template, compression_type):
+    """Initialize a TFRecordSink. See WriteToTFRecord for details."""
+    super(_TFRecordSink, self).__init__(
+        file_path_prefix=file_path_prefix,
+        file_name_suffix=file_name_suffix,
+        shard_name_template=shard_name_template,
+        num_shards=num_shards,
+        coder=coder,
+        compression_type=compression_type,
+        mime_type='application/octet-stream')
+
+  def write_encoded_record(self, file_handle, value):
+    """Writes `value` (presumably already coder-encoded) as one TFRecord."""
+    return _TFRecordUtil.write_record(file_handle, value)
+
+
+class WriteToTFRecord(PTransform):
+  """Transform for writing to TFRecord sinks.
+
+  Each input element is encoded with the given coder and written as one
+  TFRecord.
+  """
+
+  def __init__(self,
+               file_path_prefix,
+               coder=coders.BytesCoder(),
+               file_name_suffix='',
+               num_shards=0,
+               shard_name_template=fileio.DEFAULT_SHARD_NAME_TEMPLATE,
+               compression_type=fileio.CompressionTypes.AUTO,
+               **kwargs):
+    """Initialize WriteToTFRecord transform.
+
+    Args:
+      file_path_prefix: The file path to write to. The files written will begin
+        with this prefix, followed by a shard identifier (see num_shards), and
+        end in a common extension, if given by file_name_suffix.
+      coder: Coder used to encode each record.
+      file_name_suffix: Suffix for the files written.
+      num_shards: The number of files (shards) used for output. If not set, the
+        default value will be used.
+      shard_name_template: A template string containing placeholders for
+        the shard number and shard count. Currently only '' and
+        '-SSSSS-of-NNNNN' are patterns allowed.
+        When constructing a filename for a particular shard number, the
+        upper-case letters 'S' and 'N' are replaced with the 0-padded shard
+        number and shard count respectively.  This argument can be '' in which
+        case it behaves as if num_shards was set to 1 and only one file will be
+        generated. The default pattern is '-SSSSS-of-NNNNN'.
+      compression_type: Used to handle compressed output files. Default value
+          is CompressionTypes.AUTO, in which case the extension of the output
+          file name (e.g. a '.gz' file_name_suffix) will be used to detect the
+          compression.
+      **kwargs: Optional args dictionary. These are passed through to parent
+        constructor.
+    """
+    super(WriteToTFRecord, self).__init__(**kwargs)
+    # The sink itself is only constructed when the transform is expanded.
+    self._args = (file_path_prefix, coder, file_name_suffix, num_shards,
+                  shard_name_template, compression_type)
+
+  def expand(self, pcoll):
+    """Writes `pcoll` through a _TFRecordSink built from the init args."""
+    return pcoll | Write(_TFRecordSink(*self._args))

http://git-wip-us.apache.org/repos/asf/beam/blob/88833ba5/sdks/python/apache_beam/io/tfrecordio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/tfrecordio_test.py b/sdks/python/apache_beam/io/tfrecordio_test.py
new file mode 100644
index 0000000..ee287b3
--- /dev/null
+++ b/sdks/python/apache_beam/io/tfrecordio_test.py
@@ -0,0 +1,365 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import binascii
+import cStringIO
+import glob
+import gzip
+import logging
+import pickle
+import random
+import tempfile
+import unittest
+
+import apache_beam as beam
+from apache_beam import coders
+from apache_beam.io import fileio
+from apache_beam.io.tfrecordio import _TFRecordSink
+from apache_beam.io.tfrecordio import _TFRecordSource
+from apache_beam.io.tfrecordio import _TFRecordUtil
+from apache_beam.io.tfrecordio import ReadFromTFRecord
+from apache_beam.io.tfrecordio import WriteToTFRecord
+from apache_beam.runners import DirectRunner
+import crcmod
+
+
+# TensorFlow is optional: tests decorated with skipIf(tf is None, ...) are
+# skipped when it cannot be imported.
+try:
+  import tensorflow as tf  # pylint: disable=import-error
+except ImportError:
+  tf = None  # pylint: disable=invalid-name
+  logging.warning('Tensorflow is not installed, so skipping some tests.')
+
+# Created by running the following code in python:
+# >>> import tensorflow as tf
+# >>> import base64
+# >>> writer = tf.python_io.TFRecordWriter('/tmp/python_foo.tfrecord')
+# >>> writer.write('foo')
+# >>> writer.close()
+# >>> with open('/tmp/python_foo.tfrecord', 'rb') as f:
+# ...   data = base64.b64encode(f.read())
+# ...   print data
+# Decoded, this is: 8-byte length (3), 4-byte length crc, 'foo', 4-byte data
+# crc.
+FOO_RECORD_BASE64 = 'AwAAAAAAAACwmUkOZm9vYYq+/g=='
+
+# Same as above but containing two records ['foo', 'bar']
+FOO_BAR_RECORD_BASE64 = 'AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg='
+
+
+class TestTFRecordUtil(unittest.TestCase):
+  """Unit tests for the low-level TFRecord encode/decode helpers."""
+
+  def setUp(self):
+    self.record = binascii.a2b_base64(FOO_RECORD_BASE64)
+
+  def _as_file_handle(self, contents):
+    """Returns an in-memory file positioned at the start of `contents`."""
+    result = cStringIO.StringIO()
+    result.write(contents)
+    result.seek(0)
+    return result
+
+  def _increment_value_at_index(self, value, index):
+    """Returns `value` with the byte at `index` incremented by one."""
+    chars = list(value)
+    chars[index] = chr(ord(chars[index]) + 1)
+    return ''.join(chars)
+
+  def _test_error(self, record, error_text):
+    """Asserts that reading `record` raises ValueError mentioning text."""
+    with self.assertRaises(ValueError) as context:
+      _TFRecordUtil.read_record(self._as_file_handle(record))
+    # str(exception) is portable; BaseException.message is deprecated since
+    # Python 2.6 and does not exist in Python 3.
+    self.assertIn(error_text, str(context.exception))
+
+  def test_masked_crc32c(self):
+    self.assertEqual(0xfd7fffa, _TFRecordUtil._masked_crc32c('\x00' * 32))
+    self.assertEqual(0xf909b029, _TFRecordUtil._masked_crc32c('\xff' * 32))
+    self.assertEqual(0xfebe8a61, _TFRecordUtil._masked_crc32c('foo'))
+    self.assertEqual(
+        0xe4999b0,
+        _TFRecordUtil._masked_crc32c('\x03\x00\x00\x00\x00\x00\x00\x00'))
+
+  def test_masked_crc32c_crcmod(self):
+    crc32c_fn = crcmod.predefined.mkPredefinedCrcFun('crc-32c')
+    self.assertEqual(
+        0xfd7fffa,
+        _TFRecordUtil._masked_crc32c(
+            '\x00' * 32, crc32c_fn=crc32c_fn))
+    self.assertEqual(
+        0xf909b029,
+        _TFRecordUtil._masked_crc32c(
+            '\xff' * 32, crc32c_fn=crc32c_fn))
+    self.assertEqual(
+        0xfebe8a61, _TFRecordUtil._masked_crc32c(
+            'foo', crc32c_fn=crc32c_fn))
+    self.assertEqual(
+        0xe4999b0,
+        _TFRecordUtil._masked_crc32c(
+            '\x03\x00\x00\x00\x00\x00\x00\x00', crc32c_fn=crc32c_fn))
+
+  def test_write_record(self):
+    file_handle = cStringIO.StringIO()
+    _TFRecordUtil.write_record(file_handle, 'foo')
+    self.assertEqual(self.record, file_handle.getvalue())
+
+  def test_read_record(self):
+    actual = _TFRecordUtil.read_record(self._as_file_handle(self.record))
+    self.assertEqual('foo', actual)
+
+  def test_read_record_invalid_record(self):
+    self._test_error('bar', 'Not a valid TFRecord. Fewer than 12 bytes')
+
+  def test_read_record_invalid_length_mask(self):
+    # Byte 9 lies inside the 4-byte masked crc of the length (bytes 8-11).
+    record = self._increment_value_at_index(self.record, 9)
+    self._test_error(record, 'Mismatch of length mask')
+
+  def test_read_record_invalid_data_mask(self):
+    # Byte 16 lies inside the masked crc of the 3-byte 'foo' payload.
+    record = self._increment_value_at_index(self.record, 16)
+    self._test_error(record, 'Mismatch of data mask')
+
+  def test_compatibility_read_write(self):
+    for record in ['', 'blah', 'another blah']:
+      file_handle = cStringIO.StringIO()
+      _TFRecordUtil.write_record(file_handle, record)
+      file_handle.seek(0)
+      actual = _TFRecordUtil.read_record(file_handle)
+      self.assertEqual(record, actual)
+
+
+class TestTFRecordSink(unittest.TestCase):
+  """Checks the exact bytes _TFRecordSink writes for known records."""
+
+  def _write_lines(self, sink, path, lines):
+    """Writes every element of `lines` as one record of `sink` at `path`."""
+    f = sink.open(path)
+    for line in lines:
+      sink.write_record(f, line)
+    sink.close(f)
+
+  def _assert_sink_output(self, lines, expected_base64):
+    """Writes `lines` uncompressed and compares the file to the fixture."""
+    path = tempfile.NamedTemporaryFile().name
+    sink = _TFRecordSink(
+        path,
+        coder=coders.BytesCoder(),
+        file_name_suffix='',
+        num_shards=0,
+        shard_name_template=None,
+        compression_type=fileio.CompressionTypes.UNCOMPRESSED)
+    self._write_lines(sink, path, lines)
+
+    # TFRecords are binary: read in 'rb' so no newline translation can alter
+    # the bytes being compared.
+    with open(path, 'rb') as f:
+      self.assertEqual(f.read(), binascii.a2b_base64(expected_base64))
+
+  def test_write_record_single(self):
+    self._assert_sink_output(['foo'], FOO_RECORD_BASE64)
+
+  def test_write_record_multiple(self):
+    self._assert_sink_output(['foo', 'bar'], FOO_BAR_RECORD_BASE64)
+
+
+@unittest.skipIf(tf is None, 'tensorflow not installed.')
+class TestWriteToTFRecord(unittest.TestCase):
+  """Checks WriteToTFRecord output with TensorFlow's own record reader.
+
+  Derives from unittest.TestCase rather than TestTFRecordSink so that the
+  sink tests are not collected and run a second time under this class.
+  """
+
+  def _read_gzip_records_with_tf(self, file_name):
+    """Returns all records of a GZIP-compressed TFRecords file, via TF."""
+    options = tf.python_io.TFRecordOptions(
+        tf.python_io.TFRecordCompressionType.GZIP)
+    return list(tf.python_io.tf_record_iterator(file_name, options=options))
+
+  def test_write_record_gzip(self):
+    file_path_prefix = tempfile.NamedTemporaryFile().name
+    input_data = ['foo', 'bar']
+    with beam.Pipeline(DirectRunner()) as p:
+      _ = p | beam.Create(input_data) | WriteToTFRecord(
+          file_path_prefix, compression_type=fileio.CompressionTypes.GZIP)
+
+    file_name = glob.glob(file_path_prefix + '-*')[0]
+    self.assertEqual(self._read_gzip_records_with_tf(file_name), input_data)
+
+  def test_write_record_auto(self):
+    file_path_prefix = tempfile.NamedTemporaryFile().name
+    input_data = ['foo', 'bar']
+    with beam.Pipeline(DirectRunner()) as p:
+      # AUTO compression should infer GZIP from the '.gz' suffix.
+      _ = p | beam.Create(input_data) | WriteToTFRecord(
+          file_path_prefix, file_name_suffix='.gz')
+
+    file_name = glob.glob(file_path_prefix + '-*.gz')[0]
+    self.assertEqual(self._read_gzip_records_with_tf(file_name), input_data)
+
+
+class TestTFRecordSource(unittest.TestCase):
+  """Tests reading TFRecords files through _TFRecordSource."""
+
+  def _write_file(self, path, base64_records):
+    record = binascii.a2b_base64(base64_records)
+    with open(path, 'wb') as f:
+      f.write(record)
+
+  def _write_file_gzip(self, path, base64_records):
+    record = binascii.a2b_base64(base64_records)
+    with gzip.GzipFile(path, 'wb') as f:
+      f.write(record)
+
+  def _new_temp_path(self, suffix=''):
+    """Returns the path of a freshly created, already closed temp file.
+
+    Unlike tempfile.mkstemp()[1] this does not leak the open descriptor, and
+    unlike tempfile.NamedTemporaryFile().name the file is not deleted (and
+    its name freed) as soon as the temporary object is garbage-collected.
+    """
+    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
+      return f.name
+
+  def _assert_source_reads(self, path, compression_type, expected):
+    """Reads `path` through a _TFRecordSource and checks the records."""
+    with beam.Pipeline(DirectRunner()) as p:
+      result = (p
+                | beam.Read(
+                    _TFRecordSource(
+                        path,
+                        coder=coders.BytesCoder(),
+                        compression_type=compression_type)))
+      beam.assert_that(result, beam.equal_to(expected))
+
+  def test_process_single(self):
+    path = self._new_temp_path()
+    self._write_file(path, FOO_RECORD_BASE64)
+    self._assert_source_reads(path, fileio.CompressionTypes.AUTO, ['foo'])
+
+  def test_process_multiple(self):
+    path = self._new_temp_path()
+    self._write_file(path, FOO_BAR_RECORD_BASE64)
+    self._assert_source_reads(
+        path, fileio.CompressionTypes.AUTO, ['foo', 'bar'])
+
+  def test_process_gzip(self):
+    path = self._new_temp_path()
+    self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
+    self._assert_source_reads(
+        path, fileio.CompressionTypes.GZIP, ['foo', 'bar'])
+
+  def test_process_auto(self):
+    # AUTO compression should infer GZIP from the '.gz' suffix.
+    path = self._new_temp_path(suffix='.gz')
+    self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
+    self._assert_source_reads(
+        path, fileio.CompressionTypes.AUTO, ['foo', 'bar'])
+
+
+class TestReadFromTFRecordSource(TestTFRecordSource):
+  """Adds gzip cases that go through the public ReadFromTFRecord transform.
+
+  Inherits the _TFRecordSource tests; test_process_gzip overrides the
+  parent's version to use ReadFromTFRecord instead.
+  """
+
+  def test_process_gzip(self):
+    path = tempfile.NamedTemporaryFile().name
+    self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
+    with beam.Pipeline(DirectRunner()) as p:
+      result = (p
+                | ReadFromTFRecord(
+                    path, compression_type=fileio.CompressionTypes.GZIP))
+      beam.assert_that(result, beam.equal_to(['foo', 'bar']))
+
+  def test_process_gzip_auto(self):
+    # Create the '.gz' file and close it immediately; tempfile.mkstemp()
+    # would leave its returned file descriptor open for the whole run.
+    with tempfile.NamedTemporaryFile(suffix='.gz', delete=False) as f:
+      path = f.name
+    self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
+    with beam.Pipeline(DirectRunner()) as p:
+      result = (p
+                | ReadFromTFRecord(
+                    path, compression_type=fileio.CompressionTypes.AUTO))
+      beam.assert_that(result, beam.equal_to(['foo', 'bar']))
+
+
+class TestEnd2EndWriteAndRead(unittest.TestCase):
+  """Round-trips data through WriteToTFRecord and back via ReadFromTFRecord."""
+
+  def create_inputs(self):
+    """Returns a pickled 12x15 matrix of random floats in [-0.5, 0.5)."""
+    matrix = []
+    for _ in xrange(12):
+      matrix.append([random.random() - 0.5 for _ in xrange(15)])
+    buf = cStringIO.StringIO()
+    pickle.dump(matrix, buf)
+    return buf.getvalue()
+
+  def _round_trip(self, write_path, read_pattern, **write_kwargs):
+    """Writes ten fresh inputs, reads them back and checks they match."""
+    expected = [self.create_inputs() for _ in range(0, 10)]
+
+    with beam.Pipeline(DirectRunner()) as p:
+      writer = WriteToTFRecord(write_path, **write_kwargs)
+      _ = p | beam.Create(expected) | writer
+
+    with beam.Pipeline(DirectRunner()) as p:
+      actual = p | ReadFromTFRecord(read_pattern)
+      beam.assert_that(actual, beam.equal_to(expected))
+
+  def test_end2end(self):
+    prefix = tempfile.NamedTemporaryFile().name
+    self._round_trip(prefix, prefix + '-*')
+
+  def test_end2end_auto_compression(self):
+    prefix = tempfile.NamedTemporaryFile().name
+    self._round_trip(prefix, prefix + '-*', file_name_suffix='.gz')
+
+  def test_end2end_auto_compression_unsharded(self):
+    prefix = tempfile.NamedTemporaryFile().name
+    self._round_trip(prefix + '.gz', prefix + '.gz', shard_name_template='')
+
+  @unittest.skipIf(tf is None, 'tensorflow not installed.')
+  def test_end2end_example_proto(self):
+    prefix = tempfile.NamedTemporaryFile().name
+
+    example = tf.train.Example()
+    feature_map = example.features.feature
+    feature_map['int'].int64_list.value.extend(range(3))
+    feature_map['bytes'].bytes_list.value.extend([b'foo', b'bar'])
+
+    with beam.Pipeline(DirectRunner()) as p:
+      writer = WriteToTFRecord(
+          prefix, coder=beam.coders.ProtoCoder(example.__class__))
+      _ = p | beam.Create([example]) | writer
+
+    # Read the file back and compare.
+    with beam.Pipeline(DirectRunner()) as p:
+      reader = ReadFromTFRecord(
+          prefix + '-*', coder=beam.coders.ProtoCoder(example.__class__))
+      beam.assert_that(p | reader, beam.equal_to([example]))
+
+
+if __name__ == '__main__':
+  # When run as a script, raise root-logger verbosity to INFO before the
+  # unittest runner starts.
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/88833ba5/sdks/python/setup.py
----------------------------------------------------------------------
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index f6357b6..1fd622f 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -85,6 +85,7 @@ else:
 
 REQUIRED_PACKAGES = [
     'avro>=1.7.7,<2.0.0',
+    'crcmod>=1.7,<2.0',
     'dill>=0.2.5,<0.3',
     'google-apitools>=0.5.6,<1.0.0',
     'googledatastore==6.4.1',


[3/3] beam git commit: Closes #1749

Posted by ro...@apache.org.
Closes #1749


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/a2551517
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/a2551517
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/a2551517

Branch: refs/heads/python-sdk
Commit: a2551517184e55ecf01d07dffd921b105f7dc448
Parents: 69d8f2b 93e8d19
Author: Robert Bradshaw <ro...@google.com>
Authored: Mon Jan 9 13:13:55 2017 -0800
Committer: Robert Bradshaw <ro...@google.com>
Committed: Mon Jan 9 13:13:55 2017 -0800

----------------------------------------------------------------------
 sdks/python/apache_beam/io/__init__.py        |   1 +
 sdks/python/apache_beam/io/tfrecordio.py      | 271 ++++++++++++++
 sdks/python/apache_beam/io/tfrecordio_test.py | 389 +++++++++++++++++++++
 sdks/python/setup.py                          |   1 +
 4 files changed, 662 insertions(+)
----------------------------------------------------------------------



[2/3] beam git commit: Provided temporary directory management for test cases.

Posted by ro...@apache.org.
Provided temporary directory management for test cases.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/93e8d19e
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/93e8d19e
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/93e8d19e

Branch: refs/heads/python-sdk
Commit: 93e8d19e32807fb5279ed711f0f06c3123adfb2e
Parents: 88833ba
Author: Younghee Kwon <yo...@gmail.com>
Authored: Mon Jan 9 11:50:57 2017 -0800
Committer: Robert Bradshaw <ro...@google.com>
Committed: Mon Jan 9 13:13:46 2017 -0800

----------------------------------------------------------------------
 sdks/python/apache_beam/io/tfrecordio_test.py | 58 +++++++++++++++-------
 1 file changed, 41 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/93e8d19e/sdks/python/apache_beam/io/tfrecordio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/tfrecordio_test.py b/sdks/python/apache_beam/io/tfrecordio_test.py
index ee287b3..ecd58f5 100644
--- a/sdks/python/apache_beam/io/tfrecordio_test.py
+++ b/sdks/python/apache_beam/io/tfrecordio_test.py
@@ -20,8 +20,10 @@ import cStringIO
 import glob
 import gzip
 import logging
+import os
 import pickle
 import random
+import shutil
 import tempfile
 import unittest
 
@@ -134,7 +136,29 @@ class TestTFRecordUtil(unittest.TestCase):
       self.assertEqual(record, actual)
 
 
-class TestTFRecordSink(unittest.TestCase):
+class _TestCaseWithTempDirCleanUp(unittest.TestCase):
+  """Base TestCase that removes its temporary directories after each test.
+
+  Subclasses obtain a fresh directory from self._new_tempdir(); every
+  directory handed out that way is deleted again in tearDown().
+  """
+
+  def setUp(self):
+    self._tempdirs = []
+
+  def tearDown(self):
+    for tempdir in self._tempdirs:
+      if not os.path.exists(tempdir):
+        continue
+      shutil.rmtree(tempdir)
+    self._tempdirs = []
+
+  def _new_tempdir(self):
+    """Creates, records and returns a new temporary directory."""
+    tempdir = tempfile.mkdtemp()
+    self._tempdirs.append(tempdir)
+    return tempdir
+
+class TestTFRecordSink(_TestCaseWithTempDirCleanUp):
 
   def _write_lines(self, sink, path, lines):
     f = sink.open(path)
@@ -143,7 +167,7 @@ class TestTFRecordSink(unittest.TestCase):
     sink.close(f)
 
   def test_write_record_single(self):
-    path = tempfile.NamedTemporaryFile().name
+    path = os.path.join(self._new_tempdir(), 'result')
     record = binascii.a2b_base64(FOO_RECORD_BASE64)
     sink = _TFRecordSink(
         path,
@@ -158,7 +182,7 @@ class TestTFRecordSink(unittest.TestCase):
       self.assertEqual(f.read(), record)
 
   def test_write_record_multiple(self):
-    path = tempfile.NamedTemporaryFile().name
+    path = os.path.join(self._new_tempdir(), 'result')
     record = binascii.a2b_base64(FOO_BAR_RECORD_BASE64)
     sink = _TFRecordSink(
         path,
@@ -177,8 +201,8 @@ class TestTFRecordSink(unittest.TestCase):
 class TestWriteToTFRecord(TestTFRecordSink):
 
   def test_write_record_gzip(self):
+    file_path_prefix = os.path.join(self._new_tempdir(), 'result')
     with beam.Pipeline(DirectRunner()) as p:
-      file_path_prefix = tempfile.NamedTemporaryFile().name
       input_data = ['foo', 'bar']
       _ = p | beam.Create(input_data) | WriteToTFRecord(
           file_path_prefix, compression_type=fileio.CompressionTypes.GZIP)
@@ -192,8 +216,8 @@ class TestWriteToTFRecord(TestTFRecordSink):
     self.assertEqual(actual, input_data)
 
   def test_write_record_auto(self):
+    file_path_prefix = os.path.join(self._new_tempdir(), 'result')
     with beam.Pipeline(DirectRunner()) as p:
-      file_path_prefix = tempfile.NamedTemporaryFile().name
       input_data = ['foo', 'bar']
       _ = p | beam.Create(input_data) | WriteToTFRecord(
           file_path_prefix, file_name_suffix='.gz')
@@ -207,7 +231,7 @@ class TestWriteToTFRecord(TestTFRecordSink):
     self.assertEqual(actual, input_data)
 
 
-class TestTFRecordSource(unittest.TestCase):
+class TestTFRecordSource(_TestCaseWithTempDirCleanUp):
 
   def _write_file(self, path, base64_records):
     record = binascii.a2b_base64(base64_records)
@@ -220,7 +244,7 @@ class TestTFRecordSource(unittest.TestCase):
       f.write(record)
 
   def test_process_single(self):
-    path = tempfile.NamedTemporaryFile().name
+    path = os.path.join(self._new_tempdir(), 'result')
     self._write_file(path, FOO_RECORD_BASE64)
     with beam.Pipeline(DirectRunner()) as p:
       result = (p
@@ -232,7 +256,7 @@ class TestTFRecordSource(unittest.TestCase):
       beam.assert_that(result, beam.equal_to(['foo']))
 
   def test_process_multiple(self):
-    path = tempfile.NamedTemporaryFile().name
+    path = os.path.join(self._new_tempdir(), 'result')
     self._write_file(path, FOO_BAR_RECORD_BASE64)
     with beam.Pipeline(DirectRunner()) as p:
       result = (p
@@ -244,7 +268,7 @@ class TestTFRecordSource(unittest.TestCase):
       beam.assert_that(result, beam.equal_to(['foo', 'bar']))
 
   def test_process_gzip(self):
-    path = tempfile.NamedTemporaryFile().name
+    path = os.path.join(self._new_tempdir(), 'result')
     self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
     with beam.Pipeline(DirectRunner()) as p:
       result = (p
@@ -256,7 +280,7 @@ class TestTFRecordSource(unittest.TestCase):
       beam.assert_that(result, beam.equal_to(['foo', 'bar']))
 
   def test_process_auto(self):
-    path = tempfile.mkstemp(suffix='.gz')[1]
+    path = os.path.join(self._new_tempdir(), 'result.gz')
     self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
     with beam.Pipeline(DirectRunner()) as p:
       result = (p
@@ -271,7 +295,7 @@ class TestTFRecordSource(unittest.TestCase):
 class TestReadFromTFRecordSource(TestTFRecordSource):
 
   def test_process_gzip(self):
-    path = tempfile.NamedTemporaryFile().name
+    path = os.path.join(self._new_tempdir(), 'result')
     self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
     with beam.Pipeline(DirectRunner()) as p:
       result = (p
@@ -280,7 +304,7 @@ class TestReadFromTFRecordSource(TestTFRecordSource):
       beam.assert_that(result, beam.equal_to(['foo', 'bar']))
 
   def test_process_gzip_auto(self):
-    path = tempfile.mkstemp(suffix='.gz')[1]
+    path = os.path.join(self._new_tempdir(), 'result.gz')
     self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
     with beam.Pipeline(DirectRunner()) as p:
       result = (p
@@ -289,7 +313,7 @@ class TestReadFromTFRecordSource(TestTFRecordSource):
       beam.assert_that(result, beam.equal_to(['foo', 'bar']))
 
 
-class TestEnd2EndWriteAndRead(unittest.TestCase):
+class TestEnd2EndWriteAndRead(_TestCaseWithTempDirCleanUp):
 
   def create_inputs(self):
     input_array = [[random.random() - 0.5 for _ in xrange(15)]
@@ -299,7 +323,7 @@ class TestEnd2EndWriteAndRead(unittest.TestCase):
     return memfile.getvalue()
 
   def test_end2end(self):
-    file_path_prefix = tempfile.NamedTemporaryFile().name
+    file_path_prefix = os.path.join(self._new_tempdir(), 'result')
 
     # Generate a TFRecord file.
     with beam.Pipeline(DirectRunner()) as p:
@@ -312,7 +336,7 @@ class TestEnd2EndWriteAndRead(unittest.TestCase):
       beam.assert_that(actual_data, beam.equal_to(expected_data))
 
   def test_end2end_auto_compression(self):
-    file_path_prefix = tempfile.NamedTemporaryFile().name
+    file_path_prefix = os.path.join(self._new_tempdir(), 'result')
 
     # Generate a TFRecord file.
     with beam.Pipeline(DirectRunner()) as p:
@@ -326,7 +350,7 @@ class TestEnd2EndWriteAndRead(unittest.TestCase):
       beam.assert_that(actual_data, beam.equal_to(expected_data))
 
   def test_end2end_auto_compression_unsharded(self):
-    file_path_prefix = tempfile.NamedTemporaryFile().name
+    file_path_prefix = os.path.join(self._new_tempdir(), 'result')
 
     # Generate a TFRecord file.
     with beam.Pipeline(DirectRunner()) as p:
@@ -341,7 +365,7 @@ class TestEnd2EndWriteAndRead(unittest.TestCase):
 
   @unittest.skipIf(tf is None, 'tensorflow not installed.')
   def test_end2end_example_proto(self):
-    file_path_prefix = tempfile.NamedTemporaryFile().name
+    file_path_prefix = os.path.join(self._new_tempdir(), 'result')
 
     example = tf.train.Example()
     example.features.feature['int'].int64_list.value.extend(range(3))