Posted to commits@beam.apache.org by da...@apache.org on 2016/06/14 23:13:11 UTC
[36/50] [abbrv] incubator-beam git commit: Move all files to apache_beam folder
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/io/fileio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py
new file mode 100644
index 0000000..9a003f0
--- /dev/null
+++ b/sdks/python/apache_beam/io/fileio.py
@@ -0,0 +1,747 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""File-based sources and sinks."""
+
+from __future__ import absolute_import
+
+import glob
+import gzip
+import logging
+from multiprocessing.pool import ThreadPool
+import os
+import re
+import shutil
+import tempfile
+import time
+
+from google.cloud.dataflow import coders
+from google.cloud.dataflow.io import iobase
+from google.cloud.dataflow.io import range_trackers
+from google.cloud.dataflow.utils import processes
+from google.cloud.dataflow.utils import retry
+
+
+__all__ = ['TextFileSource', 'TextFileSink']
+
+DEFAULT_SHARD_NAME_TEMPLATE = '-SSSSS-of-NNNNN'
+
+
+# Retrying is needed because there are transient errors that can happen.
+@retry.with_exponential_backoff(num_retries=4, retry_filter=lambda _: True)
+def _gcs_file_copy(from_path, to_path, encoding=''):
+ """Copy a local file to a GCS location with retries for transient errors."""
+ if not encoding:
+ command_args = ['gsutil', '-m', '-q', 'cp', from_path, to_path]
+ else:
+ encoding = 'Content-Type:' + encoding
+ command_args = ['gsutil', '-m', '-q', '-h', encoding, 'cp', from_path,
+ to_path]
+ logging.info('Executing command: %s', command_args)
+ popen = processes.Popen(command_args, stdout=processes.PIPE,
+ stderr=processes.PIPE)
+ stdoutdata, stderrdata = popen.communicate()
+ if popen.returncode != 0:
+ raise ValueError(
+ 'Failed to copy GCS file from %s to %s (stdout=%s, stderr=%s).' % (
+ from_path, to_path, stdoutdata, stderrdata))
+
+
+# -----------------------------------------------------------------------------
+# TextFileSource, TextFileSink.
+
+
+class TextFileSource(iobase.NativeSource):
+ """A source for a GCS or local text file.
+
+ Parses a text file as newline-delimited elements, by default assuming
+ UTF-8 encoding.
+ """
+
+ def __init__(self, file_path, start_offset=None, end_offset=None,
+ compression_type='AUTO', strip_trailing_newlines=True,
+ coder=coders.StrUtf8Coder()):
+ """Initialize a TextSource.
+
+ Args:
+ file_path: The file path to read from as a local file path or a GCS
+ gs:// path. The path can contain glob characters (*, ?, and [...]
+ sets).
+ start_offset: The byte offset in the source text file that the reader
+      should start reading. By default it is 0 (beginning of file).
+ end_offset: The byte offset in the file that the reader should stop
+ reading. By default it is the end of the file.
+ compression_type: Used to handle compressed input files. Typical value
+ is 'AUTO'.
+ strip_trailing_newlines: Indicates whether this source should remove
+ the newline char in each line it reads before decoding that line.
+ coder: Coder used to decode each line.
+
+ Raises:
+ TypeError: if file_path is not a string.
+
+ If the file_path contains glob characters then the start_offset and
+ end_offset must not be specified.
+
+    The 'start_offset' and 'end_offset' pair provides a mechanism to divide
+    the text file into multiple pieces for individual sources. Because the
+    offset is measured in bytes, some complication arises when an offset falls
+    in the middle of a text line. To avoid the scenario where two adjacent
+    sources each get a fraction of a line, we adopt the following rules:
+
+    If start_offset falls inside a line (any character except the first one)
+    then the source will skip the line and start with the next one.
+
+ If end_offset falls inside a line (any character except the first one) then
+ the source will contain that entire line.
+ """
+ if not isinstance(file_path, basestring):
+ raise TypeError(
+ '%s: file_path must be a string; got %r instead' %
+ (self.__class__.__name__, file_path))
+
+ self.file_path = file_path
+ self.start_offset = start_offset
+ self.end_offset = end_offset
+ self.compression_type = compression_type
+ self.strip_trailing_newlines = strip_trailing_newlines
+ self.coder = coder
+
+ self.is_gcs_source = file_path.startswith('gs://')
+
+ @property
+ def format(self):
+ """Source format name required for remote execution."""
+ return 'text'
+
+ def __eq__(self, other):
+ return (self.file_path == other.file_path and
+ self.start_offset == other.start_offset and
+ self.end_offset == other.end_offset and
+ self.strip_trailing_newlines == other.strip_trailing_newlines and
+ self.coder == other.coder)
+
+ @property
+ def path(self):
+ return self.file_path
+
+ def reader(self):
+ # If a multi-file pattern was specified as a source then make sure the
+ # start/end offsets use the default values for reading the entire file.
+ if re.search(r'[*?\[\]]', self.file_path) is not None:
+ if self.start_offset is not None:
+ raise ValueError(
+            'Start offset cannot be specified for a multi-file source: '
+ '%s' % self.file_path)
+ if self.end_offset is not None:
+ raise ValueError(
+ 'End offset cannot be specified for a multi-file source: '
+ '%s' % self.file_path)
+ return TextMultiFileReader(self)
+ else:
+ return TextFileReader(self)
+
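+# A minimal usage sketch of the offset rules documented in TextFileSource
+# above. The file contents and offsets below are assumptions chosen purely
+# for illustration:
+#
+#   # File: 'aaaa\nbbbb\ncccc\ndddd\n' -- lines start at bytes 0, 5, 10, 15.
+#   first = TextFileSource('/tmp/example.txt', start_offset=0, end_offset=6)
+#   second = TextFileSource('/tmp/example.txt', start_offset=6, end_offset=20)
+#   # end_offset=6 falls inside 'bbbb', so `first` reads 'aaaa' and 'bbbb';
+#   # start_offset=6 also falls inside 'bbbb', so `second` skips it and reads
+#   # 'cccc' and 'dddd'. Every line is read by exactly one of the two sources.
+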
+
+class ChannelFactory(object):
+ # TODO(robertwb): Generalize into extensible framework.
+
+ @staticmethod
+ def mkdir(path):
+ if path.startswith('gs://'):
+ return
+ else:
+ try:
+ os.makedirs(path)
+ except OSError as err:
+ raise IOError(err)
+
+ @staticmethod
+ def open(path, mode, mime_type):
+ if path.startswith('gs://'):
+ # pylint: disable=g-import-not-at-top
+ from google.cloud.dataflow.io import gcsio
+ return gcsio.GcsIO().open(path, mode, mime_type=mime_type)
+ else:
+ return open(path, mode)
+
+ @staticmethod
+ def rename(src, dst):
+ if src.startswith('gs://'):
+ assert dst.startswith('gs://'), dst
+ # pylint: disable=g-import-not-at-top
+ from google.cloud.dataflow.io import gcsio
+ gcsio.GcsIO().rename(src, dst)
+ else:
+ try:
+ os.rename(src, dst)
+ except OSError as err:
+ raise IOError(err)
+
+ @staticmethod
+ def copytree(src, dst):
+ if src.startswith('gs://'):
+ assert dst.startswith('gs://'), dst
+ assert src.endswith('/'), src
+ assert dst.endswith('/'), dst
+ # pylint: disable=g-import-not-at-top
+ from google.cloud.dataflow.io import gcsio
+ gcsio.GcsIO().copytree(src, dst)
+ else:
+ try:
+ if os.path.exists(dst):
+ shutil.rmtree(dst)
+ shutil.copytree(src, dst)
+ except OSError as err:
+ raise IOError(err)
+
+ @staticmethod
+ def exists(path):
+ if path.startswith('gs://'):
+ # pylint: disable=g-import-not-at-top
+ from google.cloud.dataflow.io import gcsio
+      return gcsio.GcsIO().exists(path)
+ else:
+ return os.path.exists(path)
+
+ @staticmethod
+ def rmdir(path):
+ if path.startswith('gs://'):
+ # pylint: disable=g-import-not-at-top
+ from google.cloud.dataflow.io import gcsio
+ gcs = gcsio.GcsIO()
+ if not path.endswith('/'):
+ path += '/'
+ # TODO(robertwb): Threadpool?
+ for entry in gcs.glob(path + '*'):
+ gcs.delete(entry)
+ else:
+ try:
+ shutil.rmtree(path)
+ except OSError as err:
+ raise IOError(err)
+
+ @staticmethod
+ def rm(path):
+ if path.startswith('gs://'):
+ # pylint: disable=g-import-not-at-top
+ from google.cloud.dataflow.io import gcsio
+ gcsio.GcsIO().delete(path)
+ else:
+ try:
+ os.remove(path)
+ except OSError as err:
+ raise IOError(err)
+
+ @staticmethod
+ def glob(path):
+ if path.startswith('gs://'):
+ # pylint: disable=g-import-not-at-top
+ from google.cloud.dataflow.io import gcsio
+ return gcsio.GcsIO().glob(path)
+ else:
+ return glob.glob(path)
+
+
+class _CompressionType(object):
+ """Object representing single compression type."""
+
+ def __init__(self, identifier):
+ self.identifier = identifier
+
+ def __eq__(self, other):
+ return self.identifier == other.identifier
+
+
+class CompressionTypes(object):
+ """Enum-like class representing known compression types."""
+ NO_COMPRESSION = _CompressionType(1) # No compression.
+  DEFLATE = _CompressionType(2)  # 'Deflate', i.e. gzip compression.
+
+ @staticmethod
+ def valid_compression_type(compression_type):
+ """Returns true for valid compression types, false otherwise."""
+ return isinstance(compression_type, _CompressionType)
+
+
+class FileSink(iobase.Sink):
+ """A sink to a GCS or local files.
+
+ To implement a file-based sink, extend this class and override
+ either ``write_record()`` or ``write_encoded_record()``.
+
+  If needed, also override ``open()`` and/or ``close()`` to customize the
+ file handling or write headers and footers.
+
+ The output of this write is a PCollection of all written shards.
+ """
+
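+  # A minimal subclass sketch (illustrative only; MyFileSink in fileio_test.py
+  # in this same commit is a fuller example):
+  #
+  #   class LineDelimitedSink(FileSink):
+  #     def write_encoded_record(self, file_handle, encoded_value):
+  #       file_handle.write(encoded_value + '\n')
+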
+  # Approximate number of write results to be assigned to each rename thread.
+ _WRITE_RESULTS_PER_RENAME_THREAD = 100
+
+ # Max number of threads to be used for renaming even if it means each thread
+ # will process more write results.
+ _MAX_RENAME_THREADS = 64
+
+ def __init__(self,
+ file_path_prefix,
+ coder,
+ file_name_suffix='',
+ num_shards=0,
+ shard_name_template=None,
+ mime_type='application/octet-stream'):
+ if shard_name_template is None:
+ shard_name_template = DEFAULT_SHARD_NAME_TEMPLATE
+    elif shard_name_template == '':
+ num_shards = 1
+ self.file_path_prefix = file_path_prefix
+ self.file_name_suffix = file_name_suffix
+ self.num_shards = num_shards
+ self.coder = coder
+ self.mime_type = mime_type
+ self.shard_name_format = self._template_to_format(shard_name_template)
+
+ def open(self, temp_path):
+ """Opens ``temp_path``, returning an opaque file handle object.
+
+ The returned file handle is passed to ``write_[encoded_]record`` and
+ ``close``.
+ """
+ return ChannelFactory.open(temp_path, 'wb', self.mime_type)
+
+ def write_record(self, file_handle, value):
+ """Writes a single record go the file handle returned by ``open()``.
+
+ By default, calls ``write_encoded_record`` after encoding the record with
+ this sink's Coder.
+ """
+ self.write_encoded_record(file_handle, self.coder.encode(value))
+
+ def write_encoded_record(self, file_handle, encoded_value):
+ """Writes a single encoded record to the file handle returned by ``open()``.
+ """
+ raise NotImplementedError
+
+ def close(self, file_handle):
+ """Finalize and close the file handle returned from ``open()``.
+
+ Called after all records are written.
+
+ By default, calls ``file_handle.close()`` iff it is not None.
+ """
+ if file_handle is not None:
+ file_handle.close()
+
+ def initialize_write(self):
+ tmp_dir = self.file_path_prefix + self.file_name_suffix + time.strftime(
+ '-temp-%Y-%m-%d_%H-%M-%S')
+ ChannelFactory().mkdir(tmp_dir)
+ return tmp_dir
+
+ def open_writer(self, init_result, uid):
+ return FileSinkWriter(self, os.path.join(init_result, uid))
+
+ def finalize_write(self, init_result, writer_results):
+ writer_results = sorted(writer_results)
+ num_shards = len(writer_results)
+ channel_factory = ChannelFactory()
+ num_threads = max(1, min(
+ num_shards / FileSink._WRITE_RESULTS_PER_RENAME_THREAD,
+ FileSink._MAX_RENAME_THREADS))
+
+ rename_ops = []
+ for shard_num, shard in enumerate(writer_results):
+ final_name = ''.join([
+ self.file_path_prefix,
+ self.shard_name_format % dict(shard_num=shard_num,
+ num_shards=num_shards),
+ self.file_name_suffix])
+ rename_ops.append((shard, final_name))
+
+ logging.info(
+ 'Starting finalize_write threads with num_shards: %d, num_threads: %d',
+ num_shards, num_threads)
+ start_time = time.time()
+
+ # Use a thread pool for renaming operations.
+ def _rename_file(rename_op):
+ """_rename_file executes single (old_name, new_name) rename operation."""
+ old_name, final_name = rename_op
+ try:
+ channel_factory.rename(old_name, final_name)
+ except IOError as e:
+ # May have already been copied.
+ exists = channel_factory.exists(final_name)
+ if not exists:
+ logging.warning(('IOError in _rename_file. old_name: %s, '
+ 'final_name: %s, err: %s'), old_name, final_name, e)
+          return (None, e)
+ except Exception as e: # pylint: disable=broad-except
+ logging.warning(('Exception in _rename_file. old_name: %s, '
+ 'final_name: %s, err: %s'), old_name, final_name, e)
+        return (None, e)
+ return (final_name, None)
+
+ rename_results = ThreadPool(num_threads).map(_rename_file, rename_ops)
+
+ for final_name, err in rename_results:
+ if err:
+ logging.warning('Error when processing rename_results: %s', err)
+ raise err
+ else:
+ yield final_name
+
+ logging.info('Renamed %d shards in %.2f seconds.',
+ num_shards, time.time() - start_time)
+
+ try:
+ channel_factory.rmdir(init_result)
+ except IOError:
+ # May have already been removed.
+ pass
+
+ @staticmethod
+ def _template_to_format(shard_name_template):
+ if not shard_name_template:
+ return ''
+ m = re.search('S+', shard_name_template)
+ if m is None:
+ raise ValueError("Shard number pattern S+ not found in template '%s'"
+ % shard_name_template)
+ shard_name_format = shard_name_template.replace(
+ m.group(0), '%%(shard_num)0%dd' % len(m.group(0)))
+ m = re.search('N+', shard_name_format)
+ if m:
+ shard_name_format = shard_name_format.replace(
+ m.group(0), '%%(num_shards)0%dd' % len(m.group(0)))
+ return shard_name_format
+
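+  # A worked example of the conversion above: the default template
+  # '-SSSSS-of-NNNNN' becomes '-%(shard_num)05d-of-%(num_shards)05d', so
+  # shard 3 of 10 is rendered as '-00003-of-00010'.
+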
+ def __eq__(self, other):
+ # TODO(robertwb): Clean up workitem_test which uses this.
+ # pylint: disable=unidiomatic-typecheck
+ return type(self) == type(other) and self.__dict__ == other.__dict__
+
+
+class FileSinkWriter(iobase.Writer):
+ """The writer for FileSink.
+ """
+
+ def __init__(self, sink, temp_shard_path):
+ self.sink = sink
+ self.temp_shard_path = temp_shard_path
+ self.temp_handle = self.sink.open(temp_shard_path)
+
+ def write(self, value):
+ self.sink.write_record(self.temp_handle, value)
+
+ def close(self):
+ self.sink.close(self.temp_handle)
+ return self.temp_shard_path
+
+
+class TextFileSink(FileSink):
+ """A sink to a GCS or local text file or files."""
+
+ def __init__(self,
+ file_path_prefix,
+ file_name_suffix='',
+ append_trailing_newlines=True,
+ num_shards=0,
+ shard_name_template=None,
+ coder=coders.ToStringCoder(),
+ compression_type=CompressionTypes.NO_COMPRESSION,
+ ):
+ """Initialize a TextFileSink.
+
+ Args:
+ file_path_prefix: The file path to write to. The files written will begin
+ with this prefix, followed by a shard identifier (see num_shards), and
+ end in a common extension, if given by file_name_suffix. In most cases,
+ only this argument is specified and num_shards, shard_name_template, and
+ file_name_suffix use default values.
+ file_name_suffix: Suffix for the files written.
+      append_trailing_newlines: indicates whether this sink should write an
+ additional newline char after writing each element.
+ num_shards: The number of files (shards) used for output. If not set, the
+ service will decide on the optimal number of shards.
+ Constraining the number of shards is likely to reduce
+ the performance of a pipeline. Setting this value is not recommended
+ unless you require a specific number of output files.
+ shard_name_template: A template string containing placeholders for
+ the shard number and shard count. Currently only '' and
+ '-SSSSS-of-NNNNN' are patterns accepted by the service.
+ When constructing a filename for a particular shard number, the
+ upper-case letters 'S' and 'N' are replaced with the 0-padded shard
+ number and shard count respectively. This argument can be '' in which
+ case it behaves as if num_shards was set to 1 and only one file will be
+ generated. The default pattern used is '-SSSSS-of-NNNNN'.
+ coder: Coder used to encode each line.
+ compression_type: Type of compression to use for this sink.
+
+ Raises:
+      TypeError: if file path parameters are not strings or if compression_type
+        is not a member of CompressionTypes.
+ ValueError: if shard_name_template is not of expected format.
+
+ Returns:
+ A TextFileSink object usable for writing.
+ """
+ if not isinstance(file_path_prefix, basestring):
+ raise TypeError(
+ 'TextFileSink: file_path_prefix must be a string; got %r instead' %
+ file_path_prefix)
+ if not isinstance(file_name_suffix, basestring):
+ raise TypeError(
+ 'TextFileSink: file_name_suffix must be a string; got %r instead' %
+ file_name_suffix)
+
+ if not CompressionTypes.valid_compression_type(compression_type):
+      raise TypeError('compression_type must be a CompressionType object but '
+ 'was %s' % type(compression_type))
+ if compression_type == CompressionTypes.DEFLATE:
+ mime_type = 'application/x-gzip'
+ else:
+ mime_type = 'text/plain'
+
+ super(TextFileSink, self).__init__(file_path_prefix,
+ file_name_suffix=file_name_suffix,
+ num_shards=num_shards,
+ shard_name_template=shard_name_template,
+ coder=coder,
+ mime_type=mime_type)
+
+ self.compression_type = compression_type
+ self.append_trailing_newlines = append_trailing_newlines
+
+ def open(self, temp_path):
+ """Opens ''temp_path'', returning a writeable file object."""
+ fobj = ChannelFactory.open(temp_path, 'wb', self.mime_type)
+ if self.compression_type == CompressionTypes.DEFLATE:
+ return gzip.GzipFile(fileobj=fobj)
+ return fobj
+
+ def write_encoded_record(self, file_handle, encoded_value):
+ file_handle.write(encoded_value)
+ if self.append_trailing_newlines:
+ file_handle.write('\n')
+
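+# A minimal naming sketch (the prefix below is an assumption, purely for
+# illustration): TextFileSink('gs://my-bucket/output', file_name_suffix='.txt')
+# with the default shard_name_template writes files such as
+# 'gs://my-bucket/output-00000-of-00003.txt' when three shards are produced.
+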
+
+class NativeTextFileSink(iobase.NativeSink):
+ """A sink to a GCS or local text file or files."""
+
+ def __init__(self, file_path_prefix,
+ append_trailing_newlines=True,
+ file_name_suffix='',
+ num_shards=0,
+ shard_name_template=None,
+ validate=True,
+ coder=coders.ToStringCoder()):
+ # We initialize a file_path attribute containing just the prefix part for
+ # local runner environment. For now, sharding is not supported in the local
+ # runner and sharding options (template, num, suffix) are ignored.
+ # The attribute is also used in the worker environment when we just write
+ # to a specific file.
+ # TODO(silviuc): Add support for file sharding in the local runner.
+ self.file_path = file_path_prefix
+ self.append_trailing_newlines = append_trailing_newlines
+ self.coder = coder
+
+ self.is_gcs_sink = self.file_path.startswith('gs://')
+
+ self.file_name_prefix = file_path_prefix
+ self.file_name_suffix = file_name_suffix
+ self.num_shards = num_shards
+ # TODO(silviuc): Update this when the service supports more patterns.
+ self.shard_name_template = ('-SSSSS-of-NNNNN' if shard_name_template is None
+ else shard_name_template)
+ # TODO(silviuc): Implement sink validation.
+ self.validate = validate
+
+ @property
+ def format(self):
+ """Sink format name required for remote execution."""
+ return 'text'
+
+ @property
+ def path(self):
+ return self.file_path
+
+ def writer(self):
+ return TextFileWriter(self)
+
+ def __eq__(self, other):
+ return (self.file_path == other.file_path and
+ self.append_trailing_newlines == other.append_trailing_newlines and
+ self.coder == other.coder and
+ self.file_name_prefix == other.file_name_prefix and
+ self.file_name_suffix == other.file_name_suffix and
+ self.num_shards == other.num_shards and
+ self.shard_name_template == other.shard_name_template and
+ self.validate == other.validate)
+
+
+# -----------------------------------------------------------------------------
+# TextFileReader, TextMultiFileReader.
+
+
+class TextFileReader(iobase.NativeSourceReader):
+ """A reader for a text file source."""
+
+ def __init__(self, source):
+ self.source = source
+ self.start_offset = self.source.start_offset or 0
+ self.end_offset = self.source.end_offset
+ self.current_offset = self.start_offset
+
+ def __enter__(self):
+ if self.source.is_gcs_source:
+ # pylint: disable=g-import-not-at-top
+ from google.cloud.dataflow.io import gcsio
+ self._file = gcsio.GcsIO().open(self.source.file_path, 'rb')
+ else:
+ self._file = open(self.source.file_path, 'rb')
+ # Determine the real end_offset.
+ # If not specified it will be the length of the file.
+ if self.end_offset is None:
+ self._file.seek(0, os.SEEK_END)
+ self.end_offset = self._file.tell()
+
+ if self.start_offset is None:
+ self.start_offset = 0
+ self.current_offset = self.start_offset
+ if self.start_offset > 0:
+ # Read one byte before. This operation will either consume a previous
+ # newline if start_offset was at the beginning of a line or consume the
+ # line if we were in the middle of it. Either way we get the read position
+        # exactly where we wanted: at the beginning of the first full line.
+ self._file.seek(self.start_offset - 1)
+ self.current_offset -= 1
+ line = self._file.readline()
+ self.current_offset += len(line)
+ else:
+ self._file.seek(self.start_offset)
+
+ # Initializing range tracker after start and end offsets are finalized.
+ self.range_tracker = range_trackers.OffsetRangeTracker(self.start_offset,
+ self.end_offset)
+
+ return self
+
+ def __exit__(self, exception_type, exception_value, traceback):
+ self._file.close()
+
+ def __iter__(self):
+ while True:
+ if not self.range_tracker.try_claim(
+ record_start=self.current_offset):
+ # Reader has completed reading the set of records in its range. Note
+ # that the end offset of the range may be smaller than the original
+ # end offset defined when creating the reader due to reader accepting
+ # a dynamic split request from the service.
+ return
+ line = self._file.readline()
+ self.current_offset += len(line)
+ if self.source.strip_trailing_newlines:
+ line = line.rstrip('\n')
+ yield self.source.coder.decode(line)
+
+ def get_progress(self):
+ return iobase.ReaderProgress(position=iobase.ReaderPosition(
+ byte_offset=self.range_tracker.last_record_start))
+
+ def request_dynamic_split(self, dynamic_split_request):
+ assert dynamic_split_request is not None
+ progress = dynamic_split_request.progress
+ split_position = progress.position
+ if split_position is None:
+ percent_complete = progress.percent_complete
+ if percent_complete is not None:
+ if percent_complete <= 0 or percent_complete >= 1:
+ logging.warning(
+ 'FileBasedReader cannot be split since the provided percentage '
+ 'of work to be completed is out of the valid range (0, '
+ '1). Requested: %r',
+ dynamic_split_request)
+ return
+ split_position = iobase.ReaderPosition()
+ split_position.byte_offset = (
+ self.range_tracker.position_at_fraction(percent_complete))
+ else:
+ logging.warning(
+ 'TextReader requires either a position or a percentage of work to '
+ 'be complete to perform a dynamic split request. Requested: %r',
+ dynamic_split_request)
+ return
+
+ if self.range_tracker.try_split(split_position.byte_offset):
+ return iobase.DynamicSplitResultWithPosition(split_position)
+ else:
+ return
+
+
+class TextMultiFileReader(iobase.NativeSourceReader):
+ """A reader for a multi-file text source."""
+
+ def __init__(self, source):
+ self.source = source
+ self.file_paths = ChannelFactory.glob(self.source.file_path)
+ if not self.file_paths:
+ raise RuntimeError(
+ 'No files found for path: %s' % self.source.file_path)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exception_type, exception_value, traceback):
+ pass
+
+ def __iter__(self):
+ index = 0
+ for path in self.file_paths:
+ index += 1
+ logging.info('Reading from %s (%d/%d)', path, index, len(self.file_paths))
+ with TextFileSource(
+ path, strip_trailing_newlines=self.source.strip_trailing_newlines,
+ coder=self.source.coder).reader() as reader:
+ for line in reader:
+ yield line
+
+
+# -----------------------------------------------------------------------------
+# TextFileWriter.
+
+
+class TextFileWriter(iobase.NativeSinkWriter):
+ """The sink writer for a TextFileSink."""
+
+ def __init__(self, sink):
+ self.sink = sink
+
+ def __enter__(self):
+ if self.sink.is_gcs_sink:
+ # TODO(silviuc): Use the storage library instead of gsutil for writes.
+ self.temp_path = os.path.join(tempfile.mkdtemp(), 'gcsfile')
+ self._file = open(self.temp_path, 'wb')
+ else:
+ self._file = open(self.sink.file_path, 'wb')
+ return self
+
+ def __exit__(self, exception_type, exception_value, traceback):
+ self._file.close()
+ if hasattr(self, 'temp_path'):
+ _gcs_file_copy(self.temp_path, self.sink.file_path, 'text/plain')
+
+ def Write(self, line):
+ self._file.write(self.sink.coder.encode(line))
+ if self.sink.append_trailing_newlines:
+ self._file.write('\n')
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/io/fileio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/fileio_test.py b/sdks/python/apache_beam/io/fileio_test.py
new file mode 100644
index 0000000..70192d1
--- /dev/null
+++ b/sdks/python/apache_beam/io/fileio_test.py
@@ -0,0 +1,522 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for local and GCS sources and sinks."""
+
+import glob
+import gzip
+import logging
+import os
+import tempfile
+import unittest
+
+import google.cloud.dataflow as df
+from google.cloud.dataflow import coders
+from google.cloud.dataflow.io import fileio
+from google.cloud.dataflow.io import iobase
+
+
+class TestTextFileSource(unittest.TestCase):
+
+ def create_temp_file(self, text):
+ temp = tempfile.NamedTemporaryFile(delete=False)
+ with temp.file as tmp:
+ tmp.write(text)
+ return temp.name
+
+ def read_with_offsets(self, input_lines, output_lines,
+ start_offset=None, end_offset=None):
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file('\n'.join(input_lines)),
+ start_offset=start_offset, end_offset=end_offset)
+ read_lines = []
+ with source.reader() as reader:
+ for line in reader:
+ read_lines.append(line)
+ self.assertEqual(read_lines, output_lines)
+
+ def progress_with_offsets(self, input_lines,
+ start_offset=None, end_offset=None):
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file('\n'.join(input_lines)),
+ start_offset=start_offset, end_offset=end_offset)
+ progress_record = []
+ with source.reader() as reader:
+ self.assertEqual(reader.get_progress().position.byte_offset, -1)
+ for line in reader:
+ self.assertIsNotNone(line)
+ progress_record.append(reader.get_progress().position.byte_offset)
+
+ previous = 0
+ for current in progress_record:
+ self.assertGreater(current, previous)
+ previous = current
+
+ def test_read_entire_file(self):
+ lines = ['First', 'Second', 'Third']
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file('\n'.join(lines)))
+ read_lines = []
+ with source.reader() as reader:
+ for line in reader:
+ read_lines.append(line)
+ self.assertEqual(read_lines, lines)
+
+ def test_progress_entire_file(self):
+ lines = ['First', 'Second', 'Third']
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file('\n'.join(lines)))
+ progress_record = []
+ with source.reader() as reader:
+ self.assertEqual(-1, reader.get_progress().position.byte_offset)
+ for line in reader:
+ self.assertIsNotNone(line)
+ progress_record.append(reader.get_progress().position.byte_offset)
+ self.assertEqual(13, reader.get_progress().position.byte_offset)
+
+ self.assertEqual(len(progress_record), 3)
+ self.assertEqual(progress_record, [0, 6, 13])
+
+ def try_splitting_reader_at(self, reader, split_request, expected_response):
+ actual_response = reader.request_dynamic_split(split_request)
+
+ if expected_response is None:
+ self.assertIsNone(actual_response)
+ else:
+ self.assertIsNotNone(actual_response.stop_position)
+ self.assertIsInstance(actual_response.stop_position,
+ iobase.ReaderPosition)
+ self.assertIsNotNone(actual_response.stop_position.byte_offset)
+ self.assertEqual(expected_response.stop_position.byte_offset,
+ actual_response.stop_position.byte_offset)
+
+ return actual_response
+
+ def test_update_stop_position_for_percent_complete(self):
+ lines = ['aaaa', 'bbbb', 'cccc', 'dddd', 'eeee']
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file('\n'.join(lines)))
+ with source.reader() as reader:
+      # Reading the first three lines
+ reader_iter = iter(reader)
+ next(reader_iter)
+ next(reader_iter)
+ next(reader_iter)
+
+      # Splitting at the start or end of the range should be unsuccessful
+ self.try_splitting_reader_at(
+ reader,
+ iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete=0)),
+ None)
+ self.try_splitting_reader_at(
+ reader,
+ iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete=1)),
+ None)
+
+      # Splitting at positions on or before the start offset of the last
+      # record read should be unsuccessful
+ self.try_splitting_reader_at(
+ reader,
+ iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete=
+ 0.2)),
+ None)
+ self.try_splitting_reader_at(
+ reader,
+ iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete=
+ 0.4)),
+ None)
+
+ # Splitting at a position after the start offset of the last record should
+ # be successful
+ self.try_splitting_reader_at(
+ reader,
+ iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete=
+ 0.6)),
+ iobase.DynamicSplitResultWithPosition(iobase.ReaderPosition(
+ byte_offset=15)))
+
+ def test_update_stop_position_percent_complete_for_position(self):
+ lines = ['aaaa', 'bbbb', 'cccc', 'dddd', 'eeee']
+ source = fileio.TextFileSource(
+ file_path=self.create_temp_file('\n'.join(lines)))
+ with source.reader() as reader:
+      # Reading the first three lines
+ reader_iter = iter(reader)
+ next(reader_iter)
+ next(reader_iter)
+ next(reader_iter)
+
+      # Splitting at the start or end of the range should be unsuccessful
+ self.try_splitting_reader_at(
+ reader,
+ iobase.DynamicSplitRequest(iobase.ReaderProgress(
+ position=iobase.ReaderPosition(byte_offset=0))),
+ None)
+ self.try_splitting_reader_at(
+ reader,
+ iobase.DynamicSplitRequest(iobase.ReaderProgress(
+ position=iobase.ReaderPosition(byte_offset=25))),
+ None)
+
+      # Splitting at positions on or before the start offset of the last
+      # record read should be unsuccessful
+ self.try_splitting_reader_at(
+ reader,
+ iobase.DynamicSplitRequest(iobase.ReaderProgress(
+ position=iobase.ReaderPosition(byte_offset=5))),
+ None)
+ self.try_splitting_reader_at(
+ reader,
+ iobase.DynamicSplitRequest(iobase.ReaderProgress(
+ position=iobase.ReaderPosition(byte_offset=10))),
+ None)
+
+ # Splitting at a position after the start offset of the last record should
+ # be successful
+ self.try_splitting_reader_at(
+ reader,
+ iobase.DynamicSplitRequest(iobase.ReaderProgress(
+ position=iobase.ReaderPosition(byte_offset=15))),
+ iobase.DynamicSplitResultWithPosition(iobase.ReaderPosition(
+ byte_offset=15)))
+
+ def run_update_stop_position_exhaustive(self, lines, newline):
+ """An exhaustive test for dynamic splitting.
+
+ For the given set of data items, try to perform a split at all possible
+    combinations of the following.
+
+ * start position
+ * original stop position
+ * updated stop position
+ * number of items read
+
+ Args:
+ lines: set of data items to be used to create the file
+      newline: separator to be used when writing the given set of lines to a
+        text file.
+ """
+
+ file_path = self.create_temp_file(newline.join(lines))
+
+ total_records = len(lines)
+ total_bytes = 0
+
+ for line in lines:
+ total_bytes += len(line)
+ total_bytes += len(newline) * (total_records - 1)
+
+ for start in xrange(0, total_bytes - 1):
+ for end in xrange(start + 1, total_bytes):
+ for stop in xrange(start, end):
+ for records_to_read in range(0, total_records):
+ self.run_update_stop_position(start, end, stop, records_to_read,
+ file_path)
+
+ def test_update_stop_position_exhaustive(self):
+ self.run_update_stop_position_exhaustive(
+ ['aaaa', 'bbbb', 'cccc', 'dddd', 'eeee'], '\n')
+
+ def test_update_stop_position_exhaustive_with_empty_lines(self):
+ self.run_update_stop_position_exhaustive(
+ ['', 'aaaa', '', 'bbbb', 'cccc', '', 'dddd', 'eeee', ''], '\n')
+
+ def test_update_stop_position_exhaustive_windows_newline(self):
+ self.run_update_stop_position_exhaustive(
+ ['aaaa', 'bbbb', 'cccc', 'dddd', 'eeee'], '\r\n')
+
+ def test_update_stop_position_exhaustive_multi_byte(self):
+ self.run_update_stop_position_exhaustive(
+ [u'\u0d85\u0d85\u0d85\u0d85'.encode('utf-8'), u'\u0db6\u0db6\u0db6\u0db6'.encode('utf-8'),
+ u'\u0d9a\u0d9a\u0d9a\u0d9a'.encode('utf-8')], '\n')
+
+ def run_update_stop_position(self, start_offset, end_offset, stop_offset,
+ records_to_read,
+ file_path):
+ source = fileio.TextFileSource(file_path, start_offset, end_offset)
+
+ records_of_first_split = ''
+
+ with source.reader() as reader:
+ reader_iter = iter(reader)
+ i = 0
+
+ try:
+ while i < records_to_read:
+ records_of_first_split += next(reader_iter)
+ i += 1
+ except StopIteration:
+ # Invalid case, given source does not contain this many records.
+ return
+
+ last_record_start_after_reading = reader.range_tracker.last_record_start
+
+ if stop_offset <= last_record_start_after_reading:
+ expected_split_response = None
+ elif stop_offset == start_offset or stop_offset == end_offset:
+ expected_split_response = None
+ elif records_to_read == 0:
+ expected_split_response = None # unstarted
+ else:
+ expected_split_response = iobase.DynamicSplitResultWithPosition(
+ stop_position=iobase.ReaderPosition(byte_offset=stop_offset))
+
+ split_response = self.try_splitting_reader_at(
+ reader,
+ iobase.DynamicSplitRequest(progress=iobase.ReaderProgress(
+ iobase.ReaderPosition(byte_offset=stop_offset))),
+ expected_split_response)
+
+ # Reading remaining records from the updated reader.
+ for line in reader:
+ records_of_first_split += line
+
+ if split_response is not None:
+ # Total contents received by reading the two splits should be equal to the
+ # result obtained by reading the original source.
+ records_of_original = ''
+ records_of_second_split = ''
+
+ with source.reader() as original_reader:
+ for line in original_reader:
+ records_of_original += line
+
+ new_source = fileio.TextFileSource(
+ file_path,
+ split_response.stop_position.byte_offset,
+ end_offset)
+ with new_source.reader() as reader:
+ for line in reader:
+ records_of_second_split += line
+
+ self.assertEqual(records_of_original,
+ records_of_first_split + records_of_second_split)
+
+ def test_various_offset_combination_with_local_file_for_read(self):
+ lines = ['01234', '6789012', '456789012']
+ self.read_with_offsets(lines, lines[1:], start_offset=5)
+ self.read_with_offsets(lines, lines[1:], start_offset=6)
+ self.read_with_offsets(lines, lines[2:], start_offset=7)
+ self.read_with_offsets(lines, lines[1:2], start_offset=5, end_offset=13)
+ self.read_with_offsets(lines, lines[1:2], start_offset=5, end_offset=14)
+ self.read_with_offsets(lines, lines[1:], start_offset=5, end_offset=16)
+ self.read_with_offsets(lines, lines[2:], start_offset=14, end_offset=20)
+ self.read_with_offsets(lines, lines[2:], start_offset=14)
+ self.read_with_offsets(lines, [], start_offset=20, end_offset=20)
+
+ def test_various_offset_combination_with_local_file_for_progress(self):
+ lines = ['01234', '6789012', '456789012']
+ self.progress_with_offsets(lines, start_offset=5)
+ self.progress_with_offsets(lines, start_offset=6)
+ self.progress_with_offsets(lines, start_offset=7)
+ self.progress_with_offsets(lines, start_offset=5, end_offset=13)
+ self.progress_with_offsets(lines, start_offset=5, end_offset=14)
+ self.progress_with_offsets(lines, start_offset=5, end_offset=16)
+ self.progress_with_offsets(lines, start_offset=14, end_offset=20)
+ self.progress_with_offsets(lines, start_offset=14)
+ self.progress_with_offsets(lines, start_offset=20, end_offset=20)
+
+
+class NativeTestTextFileSink(unittest.TestCase):
+
+ def create_temp_file(self):
+ temp = tempfile.NamedTemporaryFile(delete=False)
+ return temp.name
+
+ def test_write_entire_file(self):
+ lines = ['First', 'Second', 'Third']
+ file_path = self.create_temp_file()
+ sink = fileio.NativeTextFileSink(file_path)
+ with sink.writer() as writer:
+ for line in lines:
+ writer.Write(line)
+ with open(file_path, 'r') as f:
+ self.assertEqual(f.read().splitlines(), lines)
+
+
+class TestPureTextFileSink(unittest.TestCase):
+
+ def setUp(self):
+ self.lines = ['Line %d' % d for d in range(100)]
+ self.path = tempfile.NamedTemporaryFile().name
+
+ def _write_lines(self, sink, lines):
+ f = sink.open(self.path)
+ for line in lines:
+ sink.write_record(f, line)
+ sink.close(f)
+
+ def test_write_text_file(self):
+ sink = fileio.TextFileSink(self.path)
+ self._write_lines(sink, self.lines)
+
+ with open(self.path, 'r') as f:
+ self.assertEqual(f.read().splitlines(), self.lines)
+
+ def test_write_gzip_file(self):
+ sink = fileio.TextFileSink(
+ self.path, compression_type=fileio.CompressionTypes.DEFLATE)
+ self._write_lines(sink, self.lines)
+
+ with gzip.GzipFile(self.path, 'r') as f:
+ self.assertEqual(f.read().splitlines(), self.lines)
+
+
+class MyFileSink(fileio.FileSink):
+
+ def open(self, temp_path):
+ # TODO(robertwb): Fix main session pickling.
+ # file_handle = super(MyFileSink, self).open(temp_path)
+ file_handle = fileio.FileSink.open(self, temp_path)
+ file_handle.write('[start]')
+ return file_handle
+
+ def write_encoded_record(self, file_handle, encoded_value):
+ file_handle.write('[')
+ file_handle.write(encoded_value)
+ file_handle.write(']')
+
+ def close(self, file_handle):
+ file_handle.write('[end]')
+ # TODO(robertwb): Fix main session pickling.
+ # file_handle = super(MyFileSink, self).close(file_handle)
+ file_handle = fileio.FileSink.close(self, file_handle)
+
+
+class TestFileSink(unittest.TestCase):
+
+ def test_file_sink_writing(self):
+ temp_path = tempfile.NamedTemporaryFile().name
+ sink = MyFileSink(temp_path,
+ file_name_suffix='.foo',
+ coder=coders.ToStringCoder())
+
+ # Manually invoke the generic Sink API.
+ init_token = sink.initialize_write()
+
+ writer1 = sink.open_writer(init_token, '1')
+ writer1.write('a')
+ writer1.write('b')
+ res1 = writer1.close()
+
+ writer2 = sink.open_writer(init_token, '2')
+ writer2.write('x')
+ writer2.write('y')
+ writer2.write('z')
+ res2 = writer2.close()
+
+ res = list(sink.finalize_write(init_token, [res1, res2]))
+ # Retry the finalize operation (as if the first attempt was lost).
+ res = list(sink.finalize_write(init_token, [res1, res2]))
+
+ # Check the results.
+ shard1 = temp_path + '-00000-of-00002.foo'
+ shard2 = temp_path + '-00001-of-00002.foo'
+ self.assertEqual(res, [shard1, shard2])
+ self.assertEqual(open(shard1).read(), '[start][a][b][end]')
+ self.assertEqual(open(shard2).read(), '[start][x][y][z][end]')
+
+ # Check that any temp files are deleted.
+ self.assertItemsEqual([shard1, shard2], glob.glob(temp_path + '*'))
+
+ def test_empty_write(self):
+ temp_path = tempfile.NamedTemporaryFile().name
+ sink = MyFileSink(temp_path,
+ file_name_suffix='.foo',
+ coder=coders.ToStringCoder())
+ p = df.Pipeline('DirectPipelineRunner')
+ p | df.Create([]) | df.io.Write(sink) # pylint: disable=expression-not-assigned
+ p.run()
+ self.assertEqual(open(temp_path + '-00000-of-00001.foo').read(),
+ '[start][end]')
+
+ def test_fixed_shard_write(self):
+ temp_path = tempfile.NamedTemporaryFile().name
+ sink = MyFileSink(temp_path,
+ file_name_suffix='.foo',
+ num_shards=3,
+ shard_name_template='_NN_SSS_',
+ coder=coders.ToStringCoder())
+ p = df.Pipeline('DirectPipelineRunner')
+ p | df.Create(['a', 'b']) | df.io.Write(sink) # pylint: disable=expression-not-assigned
+
+ p.run()
+
+ concat = ''.join(open(temp_path + '_03_%03d_.foo' % shard_num).read()
+ for shard_num in range(3))
+ self.assertTrue('][a][' in concat, concat)
+ self.assertTrue('][b][' in concat, concat)
+
+ def test_file_sink_multi_shards(self):
+ temp_path = tempfile.NamedTemporaryFile().name
+ sink = MyFileSink(temp_path,
+ file_name_suffix='.foo',
+ coder=coders.ToStringCoder())
+
+ # Manually invoke the generic Sink API.
+ init_token = sink.initialize_write()
+
+ num_shards = 1000
+ writer_results = []
+ for i in range(num_shards):
+ uuid = 'uuid-%05d' % i
+ writer = sink.open_writer(init_token, uuid)
+ writer.write('a')
+ writer.write('b')
+ writer.write(uuid)
+ writer_results.append(writer.close())
+
+ res_first = list(sink.finalize_write(init_token, writer_results))
+ # Retry the finalize operation (as if the first attempt was lost).
+ res_second = list(sink.finalize_write(init_token, writer_results))
+
+ self.assertItemsEqual(res_first, res_second)
+
+ res = sorted(res_second)
+ for i in range(num_shards):
+ shard_name = '%s-%05d-of-%05d.foo' % (temp_path, i, num_shards)
+ uuid = 'uuid-%05d' % i
+ self.assertEqual(res[i], shard_name)
+ self.assertEqual(
+ open(shard_name).read(), ('[start][a][b][%s][end]' % uuid))
+
+ # Check that any temp files are deleted.
+ self.assertItemsEqual(res, glob.glob(temp_path + '*'))
+
+ def test_file_sink_io_error(self):
+ temp_path = tempfile.NamedTemporaryFile().name
+ sink = MyFileSink(temp_path,
+ file_name_suffix='.foo',
+ coder=coders.ToStringCoder())
+
+ # Manually invoke the generic Sink API.
+ init_token = sink.initialize_write()
+
+ writer1 = sink.open_writer(init_token, '1')
+ writer1.write('a')
+ writer1.write('b')
+ res1 = writer1.close()
+
+ writer2 = sink.open_writer(init_token, '2')
+ writer2.write('x')
+ writer2.write('y')
+ writer2.write('z')
+ res2 = writer2.close()
+
+ os.remove(res2)
+ with self.assertRaises(IOError):
+ list(sink.finalize_write(init_token, [res1, res2]))
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/io/gcsio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcsio.py b/sdks/python/apache_beam/io/gcsio.py
new file mode 100644
index 0000000..8157b76
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcsio.py
@@ -0,0 +1,602 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Google Cloud Storage client.
+
+This library evolved from the Google App Engine GCS client available at
+https://github.com/GoogleCloudPlatform/appengine-gcs-client.
+"""
+
+import errno
+import fnmatch
+import logging
+import multiprocessing
+import os
+import re
+import StringIO
+import threading
+
+from google.cloud.dataflow.internal import auth
+from google.cloud.dataflow.utils import retry
+
+from apitools.base.py.exceptions import HttpError
+import apitools.base.py.transfer as transfer
+
+# Issue a friendlier error message if the storage library is not available.
+# TODO(silviuc): Remove this guard when storage is available everywhere.
+try:
+ # pylint: disable=g-import-not-at-top
+ from google.cloud.dataflow.internal.clients import storage
+except ImportError:
+ raise RuntimeError(
+ 'Google Cloud Storage I/O not supported for this execution environment '
+ '(could not import storage API client).')
+
+
+DEFAULT_READ_BUFFER_SIZE = 1024 * 1024
+
+
+def parse_gcs_path(gcs_path):
+ """Return the bucket and object names of the given gs:// path."""
+ match = re.match('^gs://([^/]+)/(.+)$', gcs_path)
+ if match is None:
+ raise ValueError('GCS path must be in the form gs://<bucket>/<object>.')
+ return match.group(1), match.group(2)
+
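+# For example (illustrative path):
+#   parse_gcs_path('gs://my-bucket/path/to/object')
+# returns ('my-bucket', 'path/to/object').
+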
+
+class GcsIOError(IOError, retry.PermanentException):
+ """GCS IO error that should not be retried."""
+ pass
+
+
+class GcsIO(object):
+ """Google Cloud Storage I/O client."""
+
+ def __new__(cls, storage_client=None):
+ if storage_client:
+ return super(GcsIO, cls).__new__(cls, storage_client)
+ else:
+ # Create a single storage client for each thread. We would like to avoid
+ # creating more than one storage client for each thread, since each
+ # initialization requires the relatively expensive step of initializing
+      # credentials.
+ local_state = threading.local()
+ if getattr(local_state, 'gcsio_instance', None) is None:
+ credentials = auth.get_service_credentials()
+ storage_client = storage.StorageV1(credentials=credentials)
+ local_state.gcsio_instance = (
+ super(GcsIO, cls).__new__(cls, storage_client))
+ local_state.gcsio_instance.client = storage_client
+ return local_state.gcsio_instance
+
+ def __init__(self, storage_client=None):
+ # We must do this check on storage_client because the client attribute may
+ # have already been set in __new__ for the singleton case when
+ # storage_client is None.
+ if storage_client is not None:
+ self.client = storage_client
+
+ def open(self, filename, mode='r',
+ read_buffer_size=DEFAULT_READ_BUFFER_SIZE,
+ mime_type='application/octet-stream'):
+ """Open a GCS file path for reading or writing.
+
+ Args:
+ filename: GCS file path in the form gs://<bucket>/<object>.
+ mode: 'r' for reading or 'w' for writing.
+ read_buffer_size: Buffer size to use during read operations.
+ mime_type: Mime type to set for write operations.
+
+ Returns:
+ file object.
+
+ Raises:
+ ValueError: Invalid open file mode.
+ """
+ if mode == 'r' or mode == 'rb':
+ return GcsBufferedReader(self.client, filename,
+ buffer_size=read_buffer_size)
+ elif mode == 'w' or mode == 'wb':
+ return GcsBufferedWriter(self.client, filename, mime_type=mime_type)
+ else:
+ raise ValueError('Invalid file open mode: %s.' % mode)
+
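+  # A minimal read sketch (the bucket and object below are assumptions, for
+  # illustration only):
+  #
+  #   with GcsIO().open('gs://my-bucket/data.txt', 'r') as f:
+  #     contents = f.read()
+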
+ @retry.with_exponential_backoff(
+ retry_filter=retry.retry_on_server_errors_and_timeout_filter)
+ def glob(self, pattern):
+ """Return the GCS path names matching a given path name pattern.
+
+ Path name patterns are those recognized by fnmatch.fnmatch(). The path
+ can contain glob characters (*, ?, and [...] sets).
+
+ Args:
+ pattern: GCS file path pattern in the form gs://<bucket>/<name_pattern>.
+
+ Returns:
+ list of GCS file paths matching the given pattern.
+ """
+ bucket, name_pattern = parse_gcs_path(pattern)
+ # Get the prefix with which we can list objects in the given bucket.
+ prefix = re.match('^[^[*?]*', name_pattern).group(0)
+ request = storage.StorageObjectsListRequest(bucket=bucket, prefix=prefix)
+ object_paths = []
+ while True:
+ response = self.client.objects.List(request)
+ for item in response.items:
+ if fnmatch.fnmatch(item.name, name_pattern):
+ object_paths.append('gs://%s/%s' % (item.bucket, item.name))
+ if response.nextPageToken:
+ request.pageToken = response.nextPageToken
+ else:
+ break
+ return object_paths
+
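+  # For example (illustrative pattern): glob('gs://my-bucket/staging/*.txt')
+  # returns the full gs:// paths of every object in my-bucket whose name
+  # matches 'staging/*.txt'.
+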
+ @retry.with_exponential_backoff(
+ retry_filter=retry.retry_on_server_errors_and_timeout_filter)
+ def delete(self, path):
+ """Deletes the object at the given GCS path.
+
+ Args:
+ path: GCS file path pattern in the form gs://<bucket>/<name>.
+ """
+ bucket, object_path = parse_gcs_path(path)
+ request = storage.StorageObjectsDeleteRequest(bucket=bucket,
+ object=object_path)
+ try:
+ self.client.objects.Delete(request)
+ except HttpError as http_error:
+ if http_error.status_code == 404:
+ # Return success when the file doesn't exist anymore for idempotency.
+ return
+ raise
+
+ @retry.with_exponential_backoff(
+ retry_filter=retry.retry_on_server_errors_and_timeout_filter)
+ def copy(self, src, dest):
+ """Copies the given GCS object from src to dest.
+
+ Args:
+ src: GCS file path pattern in the form gs://<bucket>/<name>.
+ dest: GCS file path pattern in the form gs://<bucket>/<name>.
+ """
+ src_bucket, src_path = parse_gcs_path(src)
+ dest_bucket, dest_path = parse_gcs_path(dest)
+ request = storage.StorageObjectsCopyRequest(sourceBucket=src_bucket,
+ sourceObject=src_path,
+ destinationBucket=dest_bucket,
+ destinationObject=dest_path)
+ try:
+ self.client.objects.Copy(request)
+ except HttpError as http_error:
+ if http_error.status_code == 404:
+ # This is a permanent error that should not be retried. Note that
+ # FileSink.finalize_write expects an IOError when the source file does
+ # not exist.
+ raise GcsIOError(errno.ENOENT, 'Source file not found: %s' % src)
+ raise
+
+ # We intentionally do not decorate this method with a retry, since the
+ # underlying copy and delete operations are already idempotent operations
+ # protected by retry decorators.
+ def copytree(self, src, dest):
+ """Renames the given GCS "directory" recursively from src to dest.
+
+ Args:
+ src: GCS file path pattern in the form gs://<bucket>/<name>/.
+ dest: GCS file path pattern in the form gs://<bucket>/<name>/.
+ """
+ assert src.endswith('/')
+ assert dest.endswith('/')
+ for entry in self.glob(src + '*'):
+ rel_path = entry[len(src):]
+ self.copy(entry, dest + rel_path)
+
+ # We intentionally do not decorate this method with a retry, since the
+ # underlying copy and delete operations are already idempotent operations
+ # protected by retry decorators.
+ def rename(self, src, dest):
+ """Renames the given GCS object from src to dest.
+
+ Args:
+ src: GCS file path pattern in the form gs://<bucket>/<name>.
+ dest: GCS file path pattern in the form gs://<bucket>/<name>.
+ """
+ self.copy(src, dest)
+ self.delete(src)
+
+ @retry.with_exponential_backoff(
+ retry_filter=retry.retry_on_server_errors_and_timeout_filter)
+ def exists(self, path):
+ """Returns whether the given GCS object exists.
+
+ Args:
+ path: GCS file path pattern in the form gs://<bucket>/<name>.
+ """
+ bucket, object_path = parse_gcs_path(path)
+ try:
+ request = storage.StorageObjectsGetRequest(bucket=bucket,
+ object=object_path)
+ self.client.objects.Get(request) # metadata
+ return True
+ except IOError:
+ return False
+
+
+class GcsBufferedReader(object):
+ """A class for reading Google Cloud Storage files."""
+
+ def __init__(self, client, path, buffer_size=DEFAULT_READ_BUFFER_SIZE):
+ self.client = client
+ self.path = path
+ self.bucket, self.name = parse_gcs_path(path)
+ self.buffer_size = buffer_size
+
+ # Get object state.
+ get_request = (
+ storage.StorageObjectsGetRequest(
+ bucket=self.bucket,
+ object=self.name))
+ try:
+ metadata = self._get_object_metadata(get_request)
+ except HttpError as http_error:
+ if http_error.status_code == 404:
+ raise IOError(errno.ENOENT, 'Not found: %s' % self.path)
+ else:
+ logging.error(
+ 'HTTP error while requesting file %s: %s', self.path, http_error)
+ raise
+ self.size = metadata.size
+
+ # Ensure read is from file of the correct generation.
+ get_request.generation = metadata.generation
+
+ # Initialize read buffer state.
+ self.download_stream = StringIO.StringIO()
+ self.downloader = transfer.Download(
+ self.download_stream, auto_transfer=False)
+ self.client.objects.Get(get_request, download=self.downloader)
+ self.position = 0
+ self.buffer = ''
+ self.buffer_start_position = 0
+ self.closed = False
+
+ @retry.with_exponential_backoff(
+ retry_filter=retry.retry_on_server_errors_and_timeout_filter)
+ def _get_object_metadata(self, get_request):
+ return self.client.objects.Get(get_request)
+
+ def read(self, size=-1):
+ """Read data from a GCS file.
+
+ Args:
+ size: Number of bytes to read. Actual number of bytes read is always
+ equal to size unless EOF is reached. If size is negative or
+ unspecified, read the entire file.
+
+ Returns:
+ data read as str.
+
+ Raises:
+ IOError: When this buffer is closed.
+ """
+ return self._read_inner(size=size, readline=False)
+
+ def readline(self, size=-1):
+ """Read one line delimited by '\\n' from the file.
+
+ Mimics behavior of the readline() method on standard file objects.
+
+ A trailing newline character is kept in the string. It may be absent when a
+ file ends with an incomplete line. If the size argument is non-negative,
+ it specifies the maximum string size (counting the newline) to return.
+ A negative size is the same as unspecified. Empty string is returned
+ only when EOF is encountered immediately.
+
+ Args:
+ size: Maximum number of bytes to read. If not specified, readline stops
+ only on '\\n' or EOF.
+
+ Returns:
+ The data read as a string.
+
+ Raises:
+ IOError: When this buffer is closed.
+ """
+ return self._read_inner(size=size, readline=True)
+
+ def _read_inner(self, size=-1, readline=False):
+ """Shared implementation of read() and readline()."""
+ self._check_open()
+ if not self._remaining():
+ return ''
+
+ # Prepare to read.
+ data_list = []
+ if size is None:
+ size = -1
+ to_read = min(size, self._remaining())
+ if to_read < 0:
+ to_read = self._remaining()
+ break_after = False
+
+ while to_read > 0:
+ # If we have exhausted the buffer, get the next segment.
+ # TODO(ccy): We should consider prefetching the next block in another
+ # thread.
+ self._fetch_next_if_buffer_exhausted()
+
+ # Determine number of bytes to read from buffer.
+ buffer_bytes_read = self.position - self.buffer_start_position
+ bytes_to_read_from_buffer = min(
+ len(self.buffer) - buffer_bytes_read, to_read)
+
+ # If readline is set, we only want to read up to and including the next
+ # newline character.
+ if readline:
+ next_newline_position = self.buffer.find(
+ '\n', buffer_bytes_read, len(self.buffer))
+ if next_newline_position != -1:
+ bytes_to_read_from_buffer = (1 + next_newline_position -
+ buffer_bytes_read)
+ break_after = True
+
+ # Read bytes.
+ data_list.append(
+ self.buffer[buffer_bytes_read:buffer_bytes_read +
+ bytes_to_read_from_buffer])
+ self.position += bytes_to_read_from_buffer
+ to_read -= bytes_to_read_from_buffer
+
+ if break_after:
+ break
+
+ return ''.join(data_list)
+
+ def _fetch_next_if_buffer_exhausted(self):
+ if not self.buffer or (self.buffer_start_position + len(self.buffer)
+ <= self.position):
+ bytes_to_request = min(self._remaining(), self.buffer_size)
+ self.buffer_start_position = self.position
+ self.buffer = self._get_segment(self.position, bytes_to_request)
+
+ def _remaining(self):
+ return self.size - self.position
+
+ def close(self):
+ """Close the current GCS file."""
+ self.closed = True
+ self.download_stream = None
+ self.downloader = None
+ self.buffer = None
+
+ def _get_segment(self, start, size):
+ """Get the given segment of the current GCS file."""
+ if size == 0:
+ return ''
+ end = start + size - 1
+ self.downloader.GetRange(start, end)
+ value = self.download_stream.getvalue()
+ # Clear the StringIO object after we've read its contents.
+ self.download_stream.truncate(0)
+ assert len(value) == size
+ return value
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exception_type, exception_value, traceback):
+ self.close()
+
+ def seek(self, offset, whence=os.SEEK_SET):
+ """Set the file's current offset.
+
+    Note if the new offset is out of bounds, it is adjusted to either 0 or EOF.
+
+ Args:
+ offset: seek offset as number.
+ whence: seek mode. Supported modes are os.SEEK_SET (absolute seek),
+ os.SEEK_CUR (seek relative to the current position), and os.SEEK_END
+ (seek relative to the end, offset should be negative).
+
+ Raises:
+ IOError: When this buffer is closed.
+ ValueError: When whence is invalid.
+ """
+ self._check_open()
+
+ self.buffer = ''
+ self.buffer_start_position = -1
+
+ if whence == os.SEEK_SET:
+ self.position = offset
+ elif whence == os.SEEK_CUR:
+ self.position += offset
+ elif whence == os.SEEK_END:
+ self.position = self.size + offset
+ else:
+ raise ValueError('Whence mode %r is invalid.' % whence)
+
+ self.position = min(self.position, self.size)
+ self.position = max(self.position, 0)
+
+ def tell(self):
+ """Tell the file's current offset.
+
+ Returns:
+ current offset in reading this file.
+
+ Raises:
+ IOError: When this buffer is closed.
+ """
+ self._check_open()
+ return self.position
+
+ def _check_open(self):
+ if self.closed:
+ raise IOError('Buffer is closed.')
+
+ def seekable(self):
+ return True
+
+ def readable(self):
+ return True
+
+ def writable(self):
+ return False
+
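For orientation, here is a minimal usage sketch of the buffered reader above, as it is exposed through GcsIO.open() in the tests further down. The gs:// path, the buffer size, and the no-argument GcsIO() constructor are illustrative assumptions, not part of this commit:

import os

from google.cloud.dataflow.io import gcsio

gcs = gcsio.GcsIO()  # assumed to construct a default storage client
# Hypothetical object; read_buffer_size mirrors the keyword used in
# gcsio_test.py below.
with gcs.open('gs://example-bucket/example-object',
              read_buffer_size=1024) as f:
  first_line = f.readline()  # reads up to and including the next '\n'
  f.seek(0, os.SEEK_END)     # seek() clamps the offset to [0, size]
  assert f.read() == ''      # read() at EOF returns the empty string
  f.seek(0)
  contents = f.read()        # everything from the current offset to EOF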
+
+class GcsBufferedWriter(object):
+ """A class for writing Google Cloud Storage files."""
+
+ class PipeStream(object):
+ """A class that presents a pipe connection as a readable stream."""
+
+ def __init__(self, recv_pipe):
+ self.conn = recv_pipe
+ self.closed = False
+ self.position = 0
+ self.remaining = ''
+
+ def read(self, size):
+ """Read data from the wrapped pipe connection.
+
+ Args:
+ size: Number of bytes to read. Actual number of bytes read is always
+ equal to size unless EOF is reached.
+
+ Returns:
+ data read as str.
+ """
+ data_list = []
+ bytes_read = 0
+ while bytes_read < size:
+ bytes_from_remaining = min(size - bytes_read, len(self.remaining))
+ data_list.append(self.remaining[0:bytes_from_remaining])
+ self.remaining = self.remaining[bytes_from_remaining:]
+ self.position += bytes_from_remaining
+ bytes_read += bytes_from_remaining
+ if not self.remaining:
+ try:
+ self.remaining = self.conn.recv_bytes()
+ except EOFError:
+ break
+ return ''.join(data_list)
+
+ def tell(self):
+ """Tell the file's current offset.
+
+ Returns:
+ current offset in reading this file.
+
+ Raises:
+ IOError: When this stream is closed.
+ """
+ self._check_open()
+ return self.position
+
+ def seek(self, offset, whence=os.SEEK_SET):
+ # The apitools.base.py.transfer.Upload class insists on seeking to the end
+ # of a stream to do a check before completing an upload, so we treat that
+ # seek (and a seek to the current position) as a no-op; any other seek is
+ # unsupported and raises NotImplementedError.
+ if whence == os.SEEK_END and offset == 0:
+ return
+ elif whence == os.SEEK_SET and offset == self.position:
+ return
+ raise NotImplementedError
+
+ def _check_open(self):
+ if self.closed:
+ raise IOError('Stream is closed.')
+
+ def __init__(self, client, path, mime_type='application/octet-stream'):
+ self.client = client
+ self.path = path
+ self.bucket, self.name = parse_gcs_path(path)
+
+ self.closed = False
+ self.position = 0
+
+ # Set up communication with uploading thread.
+ parent_conn, child_conn = multiprocessing.Pipe()
+ self.conn = parent_conn
+
+ # Set up uploader.
+ self.insert_request = (
+ storage.StorageObjectsInsertRequest(
+ bucket=self.bucket,
+ name=self.name))
+ self.upload = transfer.Upload(GcsBufferedWriter.PipeStream(child_conn),
+ mime_type)
+ self.upload.strategy = transfer.RESUMABLE_UPLOAD
+
+ # Start uploading thread.
+ self.upload_thread = threading.Thread(target=self._start_upload)
+ self.upload_thread.daemon = True
+ self.upload_thread.start()
+
+ # TODO(silviuc): Refactor so that retry logic can be applied.
+ # There is retry logic in the underlying transfer library but we should make
+ # it more explicit so we can control the retry parameters.
+ @retry.no_retries # Using no_retries marks this as an integration point.
+ def _start_upload(self):
+ # This method runs on the uploader thread started in __init__. We are
+ # forced to run the uploader in a separate thread because the apitools
+ # uploader insists on taking a stream as input. Happily, this also means
+ # we get asynchronous I/O to GCS.
+ #
+ # The uploader by default transfers data in chunks of 1024 * 1024 bytes at
+ # a time, buffering writes until that size is reached.
+ self.client.objects.Insert(self.insert_request, upload=self.upload)
+
+ def write(self, data):
+ """Write data to a GCS file.
+
+ Args:
+ data: data to write as str.
+
+ Raises:
+ IOError: When this buffer is closed.
+ """
+ self._check_open()
+ if not data:
+ return
+ self.conn.send_bytes(data)
+ self.position += len(data)
+
+ def tell(self):
+ """Return the total number of bytes passed to write() so far."""
+ return self.position
+
+ def close(self):
+ """Close the current GCS file."""
+ # Mark the writer closed so _check_open() rejects further use, then close
+ # the pipe and wait for the uploader thread to finish the upload.
+ self.closed = True
+ self.conn.close()
+ self.upload_thread.join()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exception_type, exception_value, traceback):
+ self.close()
+
+ def _check_open(self):
+ if self.closed:
+ raise IOError('Buffer is closed.')
+
+ def seekable(self):
+ return False
+
+ def readable(self):
+ return False
+
+ def writable(self):
+ return True
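
A similarly hedged sketch of the writer, mirroring test_file_write and test_context_manager in the tests below; the path is hypothetical and the no-argument GcsIO() constructor is assumed as above:

from google.cloud.dataflow.io import gcsio

gcs = gcsio.GcsIO()  # assumed default storage client
# Each write() pushes bytes through the pipe to the uploader thread; leaving
# the with-block closes the pipe and joins the thread, which completes the
# resumable upload.
with gcs.open('gs://example-bucket/example-output', 'w') as f:
  f.write('hello ')
  f.write('world\n')
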
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/io/gcsio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcsio_test.py b/sdks/python/apache_beam/io/gcsio_test.py
new file mode 100644
index 0000000..702c834
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcsio_test.py
@@ -0,0 +1,503 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for Google Cloud Storage client."""
+
+import logging
+import multiprocessing
+import os
+import random
+import threading
+import unittest
+
+
+import httplib2
+
+from google.cloud.dataflow.io import gcsio
+from apitools.base.py.exceptions import HttpError
+from google.cloud.dataflow.internal.clients import storage
+
+
+class FakeGcsClient(object):
+ # Fake storage client. Usage in gcsio.py is client.objects.Get(...) and
+ # client.objects.Insert(...).
+
+ def __init__(self):
+ self.objects = FakeGcsObjects()
+
+
+class FakeFile(object):
+
+ def __init__(self, bucket, obj, contents, generation):
+ self.bucket = bucket
+ self.object = obj
+ self.contents = contents
+ self.generation = generation
+
+ def get_metadata(self):
+ return storage.Object(bucket=self.bucket,
+ name=self.object,
+ generation=self.generation,
+ size=len(self.contents))
+
+
+class FakeGcsObjects(object):
+
+ def __init__(self):
+ self.files = {}
+ # Store the last generation used for a given object name. Note that this
+ # has to persist even past the deletion of the object.
+ self.last_generation = {}
+ self.list_page_tokens = {}
+
+ def add_file(self, f):
+ self.files[(f.bucket, f.object)] = f
+ self.last_generation[(f.bucket, f.object)] = f.generation
+
+ def get_file(self, bucket, obj):
+ return self.files.get((bucket, obj), None)
+
+ def delete_file(self, bucket, obj):
+ del self.files[(bucket, obj)]
+
+ def get_last_generation(self, bucket, obj):
+ return self.last_generation.get((bucket, obj), 0)
+
+ def Get(self, get_request, download=None): # pylint: disable=invalid-name
+ f = self.get_file(get_request.bucket, get_request.object)
+ if f is None:
+ raise ValueError('Specified object does not exist.')
+ if download is None:
+ return f.get_metadata()
+ else:
+ stream = download.stream
+
+ def get_range_callback(start, end):
+ assert start >= 0 and end >= start and end < len(f.contents)
+ stream.write(f.contents[start:end + 1])
+ download.GetRange = get_range_callback
+
+ def Insert(self, insert_request, upload=None): # pylint: disable=invalid-name
+ assert upload is not None
+ generation = self.get_last_generation(insert_request.bucket,
+ insert_request.name) + 1
+ f = FakeFile(insert_request.bucket, insert_request.name, '', generation)
+
+ # Stream data into file.
+ stream = upload.stream
+ data_list = []
+ while True:
+ data = stream.read(1024 * 1024)
+ if not data:
+ break
+ data_list.append(data)
+ f.contents = ''.join(data_list)
+
+ self.add_file(f)
+
+ def Copy(self, copy_request): # pylint: disable=invalid-name
+ src_file = self.get_file(copy_request.sourceBucket,
+ copy_request.sourceObject)
+ if not src_file:
+ raise HttpError(httplib2.Response({'status': '404'}), '404 Not Found',
+ 'https://fake/url')
+ generation = self.get_last_generation(copy_request.destinationBucket,
+ copy_request.destinationObject) + 1
+ dest_file = FakeFile(copy_request.destinationBucket,
+ copy_request.destinationObject,
+ src_file.contents, generation)
+ self.add_file(dest_file)
+
+ def Delete(self, delete_request): # pylint: disable=invalid-name
+ # Here, we emulate the behavior of the GCS service in raising a 404 error
+ # if this object does not exist.
+ if self.get_file(delete_request.bucket, delete_request.object):
+ self.delete_file(delete_request.bucket, delete_request.object)
+ else:
+ raise HttpError(httplib2.Response({'status': '404'}), '404 Not Found',
+ 'https://fake/url')
+
+ def List(self, list_request): # pylint: disable=invalid-name
+ bucket = list_request.bucket
+ prefix = list_request.prefix or ''
+ matching_files = []
+ for file_bucket, file_name in sorted(iter(self.files)):
+ if bucket == file_bucket and file_name.startswith(prefix):
+ file_object = self.files[(file_bucket, file_name)].get_metadata()
+ matching_files.append(file_object)
+
+ # Handle pagination.
+ items_per_page = 5
+ if not list_request.pageToken:
+ range_start = 0
+ else:
+ if list_request.pageToken not in self.list_page_tokens:
+ raise ValueError('Invalid page token.')
+ range_start = self.list_page_tokens[list_request.pageToken]
+ del self.list_page_tokens[list_request.pageToken]
+
+ result = storage.Objects(
+ items=matching_files[range_start:range_start + items_per_page])
+ if range_start + items_per_page < len(matching_files):
+ next_range_start = range_start + items_per_page
+ next_page_token = '_page_token_%s_%s_%d' % (bucket, prefix,
+ next_range_start)
+ self.list_page_tokens[next_page_token] = next_range_start
+ result.nextPageToken = next_page_token
+ return result
+
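The fake List() above serves results five per page and hands out single-use page tokens, so a caller follows nextPageToken until it is unset. A small sketch of that loop follows; the StorageObjectsListRequest message name is an assumption inferred from the StorageObjectsInsertRequest used earlier, not something shown in this commit:

objects = FakeGcsObjects()
# objects.add_file(...) calls would populate the fake bucket here.
request = storage.StorageObjectsListRequest(bucket='example-bucket', prefix='')
all_items = []
while True:
  response = objects.List(request)
  all_items.extend(response.items)
  if not response.nextPageToken:
    break
  # Each token is valid exactly once; the fake deletes it when consumed.
  request.pageToken = response.nextPageToken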
+
+class TestGCSPathParser(unittest.TestCase):
+
+ def test_gcs_path(self):
+ self.assertEqual(
+ gcsio.parse_gcs_path('gs://bucket/name'), ('bucket', 'name'))
+ self.assertEqual(
+ gcsio.parse_gcs_path('gs://bucket/name/sub'), ('bucket', 'name/sub'))
+
+ def test_bad_gcs_path(self):
+ self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs://')
+ self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs://bucket')
+ self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs://bucket/')
+ self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs:///name')
+ self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs:///')
+ self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs:/blah/bucket/name')
+
+
+class TestGCSIO(unittest.TestCase):
+
+ def _insert_random_file(self, client, path, size, generation=1):
+ bucket, name = gcsio.parse_gcs_path(path)
+ f = FakeFile(bucket, name, os.urandom(size), generation)
+ client.objects.add_file(f)
+ return f
+
+ def setUp(self):
+ self.client = FakeGcsClient()
+ self.gcs = gcsio.GcsIO(self.client)
+
+ def test_delete(self):
+ file_name = 'gs://gcsio-test/delete_me'
+ file_size = 1024
+
+ # Test deletion of non-existent file.
+ self.gcs.delete(file_name)
+
+ self._insert_random_file(self.client, file_name, file_size)
+ self.assertTrue(gcsio.parse_gcs_path(file_name) in
+ self.client.objects.files)
+
+ self.gcs.delete(file_name)
+
+ self.assertFalse(gcsio.parse_gcs_path(file_name) in
+ self.client.objects.files)
+
+ def test_copy(self):
+ src_file_name = 'gs://gcsio-test/source'
+ dest_file_name = 'gs://gcsio-test/dest'
+ file_size = 1024
+ self._insert_random_file(self.client, src_file_name,
+ file_size)
+ self.assertTrue(gcsio.parse_gcs_path(src_file_name) in
+ self.client.objects.files)
+ self.assertFalse(gcsio.parse_gcs_path(dest_file_name) in
+ self.client.objects.files)
+
+ self.gcs.copy(src_file_name, dest_file_name)
+
+ self.assertTrue(gcsio.parse_gcs_path(src_file_name) in
+ self.client.objects.files)
+ self.assertTrue(gcsio.parse_gcs_path(dest_file_name) in
+ self.client.objects.files)
+
+ self.assertRaises(IOError, self.gcs.copy,
+ 'gs://gcsio-test/non-existent',
+ 'gs://gcsio-test/non-existent-destination')
+
+ def test_copytree(self):
+ src_dir_name = 'gs://gcsio-test/source/'
+ dest_dir_name = 'gs://gcsio-test/dest/'
+ file_size = 1024
+ paths = ['a', 'b/c', 'b/d']
+ for path in paths:
+ src_file_name = src_dir_name + path
+ dest_file_name = dest_dir_name + path
+ self._insert_random_file(self.client, src_file_name,
+ file_size)
+ self.assertTrue(gcsio.parse_gcs_path(src_file_name) in
+ self.client.objects.files)
+ self.assertFalse(gcsio.parse_gcs_path(dest_file_name) in
+ self.client.objects.files)
+
+ self.gcs.copytree(src_dir_name, dest_dir_name)
+
+ for path in paths:
+ src_file_name = src_dir_name + path
+ dest_file_name = dest_dir_name + path
+ self.assertTrue(gcsio.parse_gcs_path(src_file_name) in
+ self.client.objects.files)
+ self.assertTrue(gcsio.parse_gcs_path(dest_file_name) in
+ self.client.objects.files)
+
+ def test_rename(self):
+ src_file_name = 'gs://gcsio-test/source'
+ dest_file_name = 'gs://gcsio-test/dest'
+ file_size = 1024
+ self._insert_random_file(self.client, src_file_name,
+ file_size)
+ self.assertTrue(gcsio.parse_gcs_path(src_file_name) in
+ self.client.objects.files)
+ self.assertFalse(gcsio.parse_gcs_path(dest_file_name) in
+ self.client.objects.files)
+
+ self.gcs.rename(src_file_name, dest_file_name)
+
+ self.assertFalse(gcsio.parse_gcs_path(src_file_name) in
+ self.client.objects.files)
+ self.assertTrue(gcsio.parse_gcs_path(dest_file_name) in
+ self.client.objects.files)
+
+ def test_full_file_read(self):
+ file_name = 'gs://gcsio-test/full_file'
+ file_size = 5 * 1024 * 1024 + 100
+ random_file = self._insert_random_file(self.client, file_name, file_size)
+ f = self.gcs.open(file_name)
+ f.seek(0, os.SEEK_END)
+ self.assertEqual(f.tell(), file_size)
+ self.assertEqual(f.read(), '')
+ f.seek(0)
+ self.assertEqual(f.read(), random_file.contents)
+
+ def test_file_random_seek(self):
+ file_name = 'gs://gcsio-test/seek_file'
+ file_size = 5 * 1024 * 1024 - 100
+ random_file = self._insert_random_file(self.client, file_name, file_size)
+
+ f = self.gcs.open(file_name)
+ random.seed(0)
+ for _ in range(0, 10):
+ a = random.randint(0, file_size - 1)
+ b = random.randint(0, file_size - 1)
+ start, end = min(a, b), max(a, b)
+ f.seek(start)
+ self.assertEqual(f.tell(), start)
+ self.assertEqual(f.read(end - start + 1),
+ random_file.contents[start:end + 1])
+ self.assertEqual(f.tell(), end + 1)
+
+ def test_file_read_line(self):
+ file_name = 'gs://gcsio-test/read_line_file'
+ lines = []
+
+ # Set a small buffer size to exercise refilling the buffer.
+ # First line is carefully crafted so the newline falls as the last character
+ # of the buffer to exercise this code path.
+ read_buffer_size = 1024
+ lines.append('x' * 1023 + '\n')
+
+ for _ in range(1, 1000):
+ line_length = random.randint(100, 500)
+ line = os.urandom(line_length).replace('\n', ' ') + '\n'
+ lines.append(line)
+ contents = ''.join(lines)
+
+ file_size = len(contents)
+ bucket, name = gcsio.parse_gcs_path(file_name)
+ self.client.objects.add_file(FakeFile(bucket, name, contents, 1))
+
+ f = self.gcs.open(file_name, read_buffer_size=read_buffer_size)
+
+ # Test read of first two lines.
+ f.seek(0)
+ self.assertEqual(f.readline(), lines[0])
+ self.assertEqual(f.tell(), len(lines[0]))
+ self.assertEqual(f.readline(), lines[1])
+
+ # Test read at line boundary.
+ f.seek(file_size - len(lines[-1]) - 1)
+ self.assertEqual(f.readline(), '\n')
+
+ # Test read at end of file.
+ f.seek(file_size)
+ self.assertEqual(f.readline(), '')
+
+ # Test reads at random positions.
+ random.seed(0)
+ for _ in range(0, 10):
+ start = random.randint(0, file_size - 1)
+ line_index = 0
+ # Find line corresponding to start index.
+ chars_left = start
+ while True:
+ next_line_length = len(lines[line_index])
+ if chars_left - next_line_length < 0:
+ break
+ chars_left -= next_line_length
+ line_index += 1
+ f.seek(start)
+ self.assertEqual(f.readline(), lines[line_index][chars_left:])
+
+ def test_file_write(self):
+ file_name = 'gs://gcsio-test/write_file'
+ file_size = 5 * 1024 * 1024 + 2000
+ contents = os.urandom(file_size)
+ f = self.gcs.open(file_name, 'w')
+ f.write(contents[0:1000])
+ f.write(contents[1000:1024 * 1024])
+ f.write(contents[1024 * 1024:])
+ f.close()
+ bucket, name = gcsio.parse_gcs_path(file_name)
+ self.assertEqual(
+ self.client.objects.get_file(bucket, name).contents, contents)
+
+ def test_context_manager(self):
+ # Test writing with a context manager.
+ file_name = 'gs://gcsio-test/context_manager_file'
+ file_size = 1024
+ contents = os.urandom(file_size)
+ with self.gcs.open(file_name, 'w') as f:
+ f.write(contents)
+ bucket, name = gcsio.parse_gcs_path(file_name)
+ self.assertEqual(
+ self.client.objects.get_file(bucket, name).contents, contents)
+
+ # Test reading with a context manager.
+ with self.gcs.open(file_name) as f:
+ self.assertEqual(f.read(), contents)
+
+ # Test that exceptions are not swallowed by the context manager.
+ with self.assertRaises(ZeroDivisionError):
+ with self.gcs.open(file_name) as f:
+ f.read(0 / 0)
+
+ def test_glob(self):
+ bucket_name = 'gcsio-test'
+ object_names = [
+ 'cow/cat/fish',
+ 'cow/cat/blubber',
+ 'cow/dog/blubber',
+ 'apple/dog/blubber',
+ 'apple/fish/blubber',
+ 'apple/fish/blowfish',
+ 'apple/fish/bambi',
+ 'apple/fish/balloon',
+ 'apple/fish/cat',
+ 'apple/fish/cart',
+ 'apple/fish/carl',
+ 'apple/dish/bat',
+ 'apple/dish/cat',
+ 'apple/dish/carl',
+ ]
+ for object_name in object_names:
+ file_name = 'gs://%s/%s' % (bucket_name, object_name)
+ self._insert_random_file(self.client, file_name, 0)
+ test_cases = [
+ ('gs://gcsio-test/*', [
+ 'cow/cat/fish',
+ 'cow/cat/blubber',
+ 'cow/dog/blubber',
+ 'apple/dog/blubber',
+ 'apple/fish/blubber',
+ 'apple/fish/blowfish',
+ 'apple/fish/bambi',
+ 'apple/fish/balloon',
+ 'apple/fish/cat',
+ 'apple/fish/cart',
+ 'apple/fish/carl',
+ 'apple/dish/bat',
+ 'apple/dish/cat',
+ 'apple/dish/carl',
+ ]),
+ ('gs://gcsio-test/cow/*', [
+ 'cow/cat/fish',
+ 'cow/cat/blubber',
+ 'cow/dog/blubber',
+ ]),
+ ('gs://gcsio-test/cow/ca*', [
+ 'cow/cat/fish',
+ 'cow/cat/blubber',
+ ]),
+ ('gs://gcsio-test/apple/[df]ish/ca*', [
+ 'apple/fish/cat',
+ 'apple/fish/cart',
+ 'apple/fish/carl',
+ 'apple/dish/cat',
+ 'apple/dish/carl',
+ ]),
+ ('gs://gcsio-test/apple/fish/car?', [
+ 'apple/fish/cart',
+ 'apple/fish/carl',
+ ]),
+ ('gs://gcsio-test/apple/fish/b*', [
+ 'apple/fish/blubber',
+ 'apple/fish/blowfish',
+ 'apple/fish/bambi',
+ 'apple/fish/balloon',
+ ]),
+ ('gs://gcsio-test/apple/dish/[cb]at', [
+ 'apple/dish/bat',
+ 'apple/dish/cat',
+ ]),
+ ]
+ for file_pattern, expected_object_names in test_cases:
+ expected_file_names = ['gs://%s/%s' % (bucket_name, o) for o in
+ expected_object_names]
+ self.assertEqual(set(self.gcs.glob(file_pattern)),
+ set(expected_file_names))
+
+
+class TestPipeStream(unittest.TestCase):
+
+ def _read_and_verify(self, stream, expected, buffer_size):
+ data_list = []
+ bytes_read = 0
+ seen_last_block = False
+ while True:
+ data = stream.read(buffer_size)
+ self.assertLessEqual(len(data), buffer_size)
+ if len(data) < buffer_size:
+ # Test the constraint that the pipe stream returns less than the buffer
+ # size only when at the end of the stream.
+ if data:
+ self.assertFalse(seen_last_block)
+ seen_last_block = True
+ if not data:
+ break
+ data_list.append(data)
+ bytes_read += len(data)
+ self.assertEqual(stream.tell(), bytes_read)
+ self.assertEqual(''.join(data_list), expected)
+
+ def test_pipe_stream(self):
+ block_sizes = list(4 ** i for i in range(0, 12))
+ data_blocks = list(os.urandom(size) for size in block_sizes)
+ expected = ''.join(data_blocks)
+
+ buffer_sizes = [100001, 512 * 1024, 1024 * 1024]
+
+ for buffer_size in buffer_sizes:
+ parent_conn, child_conn = multiprocessing.Pipe()
+ stream = gcsio.GcsBufferedWriter.PipeStream(child_conn)
+ child_thread = threading.Thread(target=self._read_and_verify,
+ args=(stream, expected, buffer_size))
+ child_thread.start()
+ for data in data_blocks:
+ parent_conn.send_bytes(data)
+ parent_conn.close()
+ child_thread.join()
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()