Posted to commits@beam.apache.org by da...@apache.org on 2016/06/14 23:12:47 UTC
[12/50] [abbrv] incubator-beam git commit: Move all files to apache_beam folder
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/google/cloud/dataflow/io/fileio.py
----------------------------------------------------------------------
diff --git a/sdks/python/google/cloud/dataflow/io/fileio.py b/sdks/python/google/cloud/dataflow/io/fileio.py
deleted file mode 100644
index 9a003f0..0000000
--- a/sdks/python/google/cloud/dataflow/io/fileio.py
+++ /dev/null
@@ -1,747 +0,0 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""File-based sources and sinks."""
-
-from __future__ import absolute_import
-
-import glob
-import gzip
-import logging
-from multiprocessing.pool import ThreadPool
-import os
-import re
-import shutil
-import tempfile
-import time
-
-from google.cloud.dataflow import coders
-from google.cloud.dataflow.io import iobase
-from google.cloud.dataflow.io import range_trackers
-from google.cloud.dataflow.utils import processes
-from google.cloud.dataflow.utils import retry
-
-
-__all__ = ['TextFileSource', 'TextFileSink']
-
-DEFAULT_SHARD_NAME_TEMPLATE = '-SSSSS-of-NNNNN'
-
-
-# Retrying is needed because transient errors can occur.
-@retry.with_exponential_backoff(num_retries=4, retry_filter=lambda _: True)
-def _gcs_file_copy(from_path, to_path, encoding=''):
- """Copy a local file to a GCS location with retries for transient errors."""
- if not encoding:
- command_args = ['gsutil', '-m', '-q', 'cp', from_path, to_path]
- else:
- encoding = 'Content-Type:' + encoding
- command_args = ['gsutil', '-m', '-q', '-h', encoding, 'cp', from_path,
- to_path]
- logging.info('Executing command: %s', command_args)
- popen = processes.Popen(command_args, stdout=processes.PIPE,
- stderr=processes.PIPE)
- stdoutdata, stderrdata = popen.communicate()
- if popen.returncode != 0:
- raise ValueError(
- 'Failed to copy GCS file from %s to %s (stdout=%s, stderr=%s).' % (
- from_path, to_path, stdoutdata, stderrdata))
-
-
-# -----------------------------------------------------------------------------
-# TextFileSource, TextFileSink.
-
-
-class TextFileSource(iobase.NativeSource):
- """A source for a GCS or local text file.
-
- Parses a text file as newline-delimited elements, by default assuming
- UTF-8 encoding.
- """
-
- def __init__(self, file_path, start_offset=None, end_offset=None,
- compression_type='AUTO', strip_trailing_newlines=True,
- coder=coders.StrUtf8Coder()):
- """Initialize a TextSource.
-
- Args:
- file_path: The file path to read from as a local file path or a GCS
- gs:// path. The path can contain glob characters (*, ?, and [...]
- sets).
-      start_offset: The byte offset in the source text file at which the
-        reader should start reading. Defaults to 0 (the beginning of the file).
-      end_offset: The byte offset in the file at which the reader should stop
-        reading. Defaults to the end of the file.
- compression_type: Used to handle compressed input files. Typical value
- is 'AUTO'.
- strip_trailing_newlines: Indicates whether this source should remove
- the newline char in each line it reads before decoding that line.
- coder: Coder used to decode each line.
-
- Raises:
- TypeError: if file_path is not a string.
-
- If the file_path contains glob characters then the start_offset and
- end_offset must not be specified.
-
-    The 'start_offset' and 'end_offset' pair provide a mechanism to divide the
-    text file into multiple pieces for individual sources. Because offsets are
-    measured in bytes, an offset may fall in the middle of a text line. To
-    avoid the scenario where two adjacent sources each get a fraction of a
-    line, we adopt the following rules:
-
-    If start_offset falls inside a line (any character except the first one)
-    then the source will skip that line and start with the next one.
-
-    If end_offset falls inside a line (any character except the first one) then
-    the source will contain that entire line.
- """
- if not isinstance(file_path, basestring):
- raise TypeError(
- '%s: file_path must be a string; got %r instead' %
- (self.__class__.__name__, file_path))
-
- self.file_path = file_path
- self.start_offset = start_offset
- self.end_offset = end_offset
- self.compression_type = compression_type
- self.strip_trailing_newlines = strip_trailing_newlines
- self.coder = coder
-
- self.is_gcs_source = file_path.startswith('gs://')
-
- @property
- def format(self):
- """Source format name required for remote execution."""
- return 'text'
-
- def __eq__(self, other):
- return (self.file_path == other.file_path and
- self.start_offset == other.start_offset and
- self.end_offset == other.end_offset and
- self.strip_trailing_newlines == other.strip_trailing_newlines and
- self.coder == other.coder)
-
- @property
- def path(self):
- return self.file_path
-
- def reader(self):
- # If a multi-file pattern was specified as a source then make sure the
- # start/end offsets use the default values for reading the entire file.
- if re.search(r'[*?\[\]]', self.file_path) is not None:
- if self.start_offset is not None:
- raise ValueError(
-            'Start offset cannot be specified for a multi-file source: '
- '%s' % self.file_path)
- if self.end_offset is not None:
- raise ValueError(
- 'End offset cannot be specified for a multi-file source: '
- '%s' % self.file_path)
- return TextMultiFileReader(self)
- else:
- return TextFileReader(self)
-
-
-class ChannelFactory(object):
- # TODO(robertwb): Generalize into extensible framework.
-
- @staticmethod
- def mkdir(path):
- if path.startswith('gs://'):
- return
- else:
- try:
- os.makedirs(path)
- except OSError as err:
- raise IOError(err)
-
- @staticmethod
- def open(path, mode, mime_type):
- if path.startswith('gs://'):
- # pylint: disable=g-import-not-at-top
- from google.cloud.dataflow.io import gcsio
- return gcsio.GcsIO().open(path, mode, mime_type=mime_type)
- else:
- return open(path, mode)
-
- @staticmethod
- def rename(src, dst):
- if src.startswith('gs://'):
- assert dst.startswith('gs://'), dst
- # pylint: disable=g-import-not-at-top
- from google.cloud.dataflow.io import gcsio
- gcsio.GcsIO().rename(src, dst)
- else:
- try:
- os.rename(src, dst)
- except OSError as err:
- raise IOError(err)
-
- @staticmethod
- def copytree(src, dst):
- if src.startswith('gs://'):
- assert dst.startswith('gs://'), dst
- assert src.endswith('/'), src
- assert dst.endswith('/'), dst
- # pylint: disable=g-import-not-at-top
- from google.cloud.dataflow.io import gcsio
- gcsio.GcsIO().copytree(src, dst)
- else:
- try:
- if os.path.exists(dst):
- shutil.rmtree(dst)
- shutil.copytree(src, dst)
- except OSError as err:
- raise IOError(err)
-
- @staticmethod
- def exists(path):
- if path.startswith('gs://'):
- # pylint: disable=g-import-not-at-top
- from google.cloud.dataflow.io import gcsio
-      return gcsio.GcsIO().exists(path)
- else:
- return os.path.exists(path)
-
- @staticmethod
- def rmdir(path):
- if path.startswith('gs://'):
- # pylint: disable=g-import-not-at-top
- from google.cloud.dataflow.io import gcsio
- gcs = gcsio.GcsIO()
- if not path.endswith('/'):
- path += '/'
- # TODO(robertwb): Threadpool?
- for entry in gcs.glob(path + '*'):
- gcs.delete(entry)
- else:
- try:
- shutil.rmtree(path)
- except OSError as err:
- raise IOError(err)
-
- @staticmethod
- def rm(path):
- if path.startswith('gs://'):
- # pylint: disable=g-import-not-at-top
- from google.cloud.dataflow.io import gcsio
- gcsio.GcsIO().delete(path)
- else:
- try:
- os.remove(path)
- except OSError as err:
- raise IOError(err)
-
- @staticmethod
- def glob(path):
- if path.startswith('gs://'):
- # pylint: disable=g-import-not-at-top
- from google.cloud.dataflow.io import gcsio
- return gcsio.GcsIO().glob(path)
- else:
- return glob.glob(path)
-
-
-class _CompressionType(object):
- """Object representing single compression type."""
-
- def __init__(self, identifier):
- self.identifier = identifier
-
- def __eq__(self, other):
- return self.identifier == other.identifier
-
-
-class CompressionTypes(object):
- """Enum-like class representing known compression types."""
- NO_COMPRESSION = _CompressionType(1) # No compression.
-  DEFLATE = _CompressionType(2)  # 'Deflate' compression, i.e. gzip.
-
- @staticmethod
- def valid_compression_type(compression_type):
- """Returns true for valid compression types, false otherwise."""
- return isinstance(compression_type, _CompressionType)
-
-
-class FileSink(iobase.Sink):
- """A sink to a GCS or local files.
-
- To implement a file-based sink, extend this class and override
- either ``write_record()`` or ``write_encoded_record()``.
-
-  If needed, also override ``open()`` and/or ``close()`` to customize the
- file handling or write headers and footers.
-
- The output of this write is a PCollection of all written shards.
- """
-
-  # Approximate number of write results to be assigned to each rename thread.
- _WRITE_RESULTS_PER_RENAME_THREAD = 100
-
- # Max number of threads to be used for renaming even if it means each thread
- # will process more write results.
- _MAX_RENAME_THREADS = 64
-
- def __init__(self,
- file_path_prefix,
- coder,
- file_name_suffix='',
- num_shards=0,
- shard_name_template=None,
- mime_type='application/octet-stream'):
- if shard_name_template is None:
- shard_name_template = DEFAULT_SHARD_NAME_TEMPLATE
-    elif shard_name_template == '':
- num_shards = 1
- self.file_path_prefix = file_path_prefix
- self.file_name_suffix = file_name_suffix
- self.num_shards = num_shards
- self.coder = coder
- self.mime_type = mime_type
- self.shard_name_format = self._template_to_format(shard_name_template)
-
- def open(self, temp_path):
- """Opens ``temp_path``, returning an opaque file handle object.
-
- The returned file handle is passed to ``write_[encoded_]record`` and
- ``close``.
- """
- return ChannelFactory.open(temp_path, 'wb', self.mime_type)
-
- def write_record(self, file_handle, value):
- """Writes a single record go the file handle returned by ``open()``.
-
- By default, calls ``write_encoded_record`` after encoding the record with
- this sink's Coder.
- """
- self.write_encoded_record(file_handle, self.coder.encode(value))
-
- def write_encoded_record(self, file_handle, encoded_value):
- """Writes a single encoded record to the file handle returned by ``open()``.
- """
- raise NotImplementedError
-
- def close(self, file_handle):
- """Finalize and close the file handle returned from ``open()``.
-
- Called after all records are written.
-
- By default, calls ``file_handle.close()`` iff it is not None.
- """
- if file_handle is not None:
- file_handle.close()
-
- def initialize_write(self):
- tmp_dir = self.file_path_prefix + self.file_name_suffix + time.strftime(
- '-temp-%Y-%m-%d_%H-%M-%S')
- ChannelFactory().mkdir(tmp_dir)
- return tmp_dir
-
- def open_writer(self, init_result, uid):
- return FileSinkWriter(self, os.path.join(init_result, uid))
-
- def finalize_write(self, init_result, writer_results):
- writer_results = sorted(writer_results)
- num_shards = len(writer_results)
- channel_factory = ChannelFactory()
- num_threads = max(1, min(
- num_shards / FileSink._WRITE_RESULTS_PER_RENAME_THREAD,
- FileSink._MAX_RENAME_THREADS))
-
- rename_ops = []
- for shard_num, shard in enumerate(writer_results):
- final_name = ''.join([
- self.file_path_prefix,
- self.shard_name_format % dict(shard_num=shard_num,
- num_shards=num_shards),
- self.file_name_suffix])
- rename_ops.append((shard, final_name))
-
- logging.info(
- 'Starting finalize_write threads with num_shards: %d, num_threads: %d',
- num_shards, num_threads)
- start_time = time.time()
-
- # Use a thread pool for renaming operations.
- def _rename_file(rename_op):
- """_rename_file executes single (old_name, new_name) rename operation."""
- old_name, final_name = rename_op
- try:
- channel_factory.rename(old_name, final_name)
- except IOError as e:
- # May have already been copied.
- exists = channel_factory.exists(final_name)
- if not exists:
- logging.warning(('IOError in _rename_file. old_name: %s, '
- 'final_name: %s, err: %s'), old_name, final_name, e)
-        return (None, e)
- except Exception as e: # pylint: disable=broad-except
- logging.warning(('Exception in _rename_file. old_name: %s, '
- 'final_name: %s, err: %s'), old_name, final_name, e)
-        return (None, e)
- return (final_name, None)
-
- rename_results = ThreadPool(num_threads).map(_rename_file, rename_ops)
-
- for final_name, err in rename_results:
- if err:
- logging.warning('Error when processing rename_results: %s', err)
- raise err
- else:
- yield final_name
-
- logging.info('Renamed %d shards in %.2f seconds.',
- num_shards, time.time() - start_time)
-
- try:
- channel_factory.rmdir(init_result)
- except IOError:
- # May have already been removed.
- pass
-
- @staticmethod
- def _template_to_format(shard_name_template):
- if not shard_name_template:
- return ''
- m = re.search('S+', shard_name_template)
- if m is None:
- raise ValueError("Shard number pattern S+ not found in template '%s'"
- % shard_name_template)
- shard_name_format = shard_name_template.replace(
- m.group(0), '%%(shard_num)0%dd' % len(m.group(0)))
- m = re.search('N+', shard_name_format)
- if m:
- shard_name_format = shard_name_format.replace(
- m.group(0), '%%(num_shards)0%dd' % len(m.group(0)))
- return shard_name_format
-
- def __eq__(self, other):
- # TODO(robertwb): Clean up workitem_test which uses this.
- # pylint: disable=unidiomatic-typecheck
- return type(self) == type(other) and self.__dict__ == other.__dict__
-
-
-class FileSinkWriter(iobase.Writer):
- """The writer for FileSink.
- """
-
- def __init__(self, sink, temp_shard_path):
- self.sink = sink
- self.temp_shard_path = temp_shard_path
- self.temp_handle = self.sink.open(temp_shard_path)
-
- def write(self, value):
- self.sink.write_record(self.temp_handle, value)
-
- def close(self):
- self.sink.close(self.temp_handle)
- return self.temp_shard_path
-
-
-class TextFileSink(FileSink):
- """A sink to a GCS or local text file or files."""
-
- def __init__(self,
- file_path_prefix,
- file_name_suffix='',
- append_trailing_newlines=True,
- num_shards=0,
- shard_name_template=None,
- coder=coders.ToStringCoder(),
- compression_type=CompressionTypes.NO_COMPRESSION,
- ):
- """Initialize a TextFileSink.
-
- Args:
- file_path_prefix: The file path to write to. The files written will begin
- with this prefix, followed by a shard identifier (see num_shards), and
- end in a common extension, if given by file_name_suffix. In most cases,
- only this argument is specified and num_shards, shard_name_template, and
- file_name_suffix use default values.
- file_name_suffix: Suffix for the files written.
-      append_trailing_newlines: Indicates whether this sink should write an
- additional newline char after writing each element.
- num_shards: The number of files (shards) used for output. If not set, the
- service will decide on the optimal number of shards.
- Constraining the number of shards is likely to reduce
- the performance of a pipeline. Setting this value is not recommended
- unless you require a specific number of output files.
- shard_name_template: A template string containing placeholders for
- the shard number and shard count. Currently only '' and
- '-SSSSS-of-NNNNN' are patterns accepted by the service.
- When constructing a filename for a particular shard number, the
- upper-case letters 'S' and 'N' are replaced with the 0-padded shard
- number and shard count respectively. This argument can be '' in which
- case it behaves as if num_shards was set to 1 and only one file will be
- generated. The default pattern used is '-SSSSS-of-NNNNN'.
- coder: Coder used to encode each line.
- compression_type: Type of compression to use for this sink.
-
- Raises:
-      TypeError: if file path parameters are not strings or if compression_type
-        is not a member of CompressionTypes.
- ValueError: if shard_name_template is not of expected format.
-
- Returns:
- A TextFileSink object usable for writing.
- """
- if not isinstance(file_path_prefix, basestring):
- raise TypeError(
- 'TextFileSink: file_path_prefix must be a string; got %r instead' %
- file_path_prefix)
- if not isinstance(file_name_suffix, basestring):
- raise TypeError(
- 'TextFileSink: file_name_suffix must be a string; got %r instead' %
- file_name_suffix)
-
- if not CompressionTypes.valid_compression_type(compression_type):
- raise TypeError('compression_type must be CompressionType object but '
- 'was %s' % type(compression_type))
- if compression_type == CompressionTypes.DEFLATE:
- mime_type = 'application/x-gzip'
- else:
- mime_type = 'text/plain'
-
- super(TextFileSink, self).__init__(file_path_prefix,
- file_name_suffix=file_name_suffix,
- num_shards=num_shards,
- shard_name_template=shard_name_template,
- coder=coder,
- mime_type=mime_type)
-
- self.compression_type = compression_type
- self.append_trailing_newlines = append_trailing_newlines
-
- def open(self, temp_path):
- """Opens ''temp_path'', returning a writeable file object."""
- fobj = ChannelFactory.open(temp_path, 'wb', self.mime_type)
- if self.compression_type == CompressionTypes.DEFLATE:
- return gzip.GzipFile(fileobj=fobj)
- return fobj
-
- def write_encoded_record(self, file_handle, encoded_value):
- file_handle.write(encoded_value)
- if self.append_trailing_newlines:
- file_handle.write('\n')
-
-
-class NativeTextFileSink(iobase.NativeSink):
- """A sink to a GCS or local text file or files."""
-
- def __init__(self, file_path_prefix,
- append_trailing_newlines=True,
- file_name_suffix='',
- num_shards=0,
- shard_name_template=None,
- validate=True,
- coder=coders.ToStringCoder()):
- # We initialize a file_path attribute containing just the prefix part for
- # local runner environment. For now, sharding is not supported in the local
- # runner and sharding options (template, num, suffix) are ignored.
- # The attribute is also used in the worker environment when we just write
- # to a specific file.
- # TODO(silviuc): Add support for file sharding in the local runner.
- self.file_path = file_path_prefix
- self.append_trailing_newlines = append_trailing_newlines
- self.coder = coder
-
- self.is_gcs_sink = self.file_path.startswith('gs://')
-
- self.file_name_prefix = file_path_prefix
- self.file_name_suffix = file_name_suffix
- self.num_shards = num_shards
- # TODO(silviuc): Update this when the service supports more patterns.
- self.shard_name_template = ('-SSSSS-of-NNNNN' if shard_name_template is None
- else shard_name_template)
- # TODO(silviuc): Implement sink validation.
- self.validate = validate
-
- @property
- def format(self):
- """Sink format name required for remote execution."""
- return 'text'
-
- @property
- def path(self):
- return self.file_path
-
- def writer(self):
- return TextFileWriter(self)
-
- def __eq__(self, other):
- return (self.file_path == other.file_path and
- self.append_trailing_newlines == other.append_trailing_newlines and
- self.coder == other.coder and
- self.file_name_prefix == other.file_name_prefix and
- self.file_name_suffix == other.file_name_suffix and
- self.num_shards == other.num_shards and
- self.shard_name_template == other.shard_name_template and
- self.validate == other.validate)
-
-
-# -----------------------------------------------------------------------------
-# TextFileReader, TextMultiFileReader.
-
-
-class TextFileReader(iobase.NativeSourceReader):
- """A reader for a text file source."""
-
- def __init__(self, source):
- self.source = source
- self.start_offset = self.source.start_offset or 0
- self.end_offset = self.source.end_offset
- self.current_offset = self.start_offset
-
- def __enter__(self):
- if self.source.is_gcs_source:
- # pylint: disable=g-import-not-at-top
- from google.cloud.dataflow.io import gcsio
- self._file = gcsio.GcsIO().open(self.source.file_path, 'rb')
- else:
- self._file = open(self.source.file_path, 'rb')
- # Determine the real end_offset.
- # If not specified it will be the length of the file.
- if self.end_offset is None:
- self._file.seek(0, os.SEEK_END)
- self.end_offset = self._file.tell()
-
- if self.start_offset is None:
- self.start_offset = 0
- self.current_offset = self.start_offset
- if self.start_offset > 0:
- # Read one byte before. This operation will either consume a previous
- # newline if start_offset was at the beginning of a line or consume the
- # line if we were in the middle of it. Either way we get the read position
-      # exactly where we wanted: at the beginning of the first full line.
- self._file.seek(self.start_offset - 1)
- self.current_offset -= 1
- line = self._file.readline()
- self.current_offset += len(line)
- else:
- self._file.seek(self.start_offset)
-
- # Initializing range tracker after start and end offsets are finalized.
- self.range_tracker = range_trackers.OffsetRangeTracker(self.start_offset,
- self.end_offset)
-
- return self
-
- def __exit__(self, exception_type, exception_value, traceback):
- self._file.close()
-
- def __iter__(self):
- while True:
- if not self.range_tracker.try_claim(
- record_start=self.current_offset):
- # Reader has completed reading the set of records in its range. Note
- # that the end offset of the range may be smaller than the original
-        # end offset defined when creating the reader due to the reader
-        # accepting a dynamic split request from the service.
- return
- line = self._file.readline()
- self.current_offset += len(line)
- if self.source.strip_trailing_newlines:
- line = line.rstrip('\n')
- yield self.source.coder.decode(line)
-
- def get_progress(self):
- return iobase.ReaderProgress(position=iobase.ReaderPosition(
- byte_offset=self.range_tracker.last_record_start))
-
- def request_dynamic_split(self, dynamic_split_request):
- assert dynamic_split_request is not None
- progress = dynamic_split_request.progress
- split_position = progress.position
- if split_position is None:
- percent_complete = progress.percent_complete
- if percent_complete is not None:
- if percent_complete <= 0 or percent_complete >= 1:
- logging.warning(
- 'FileBasedReader cannot be split since the provided percentage '
- 'of work to be completed is out of the valid range (0, '
- '1). Requested: %r',
- dynamic_split_request)
- return
- split_position = iobase.ReaderPosition()
- split_position.byte_offset = (
- self.range_tracker.position_at_fraction(percent_complete))
- else:
- logging.warning(
- 'TextReader requires either a position or a percentage of work to '
- 'be complete to perform a dynamic split request. Requested: %r',
- dynamic_split_request)
- return
-
- if self.range_tracker.try_split(split_position.byte_offset):
- return iobase.DynamicSplitResultWithPosition(split_position)
- else:
- return
-
-
-class TextMultiFileReader(iobase.NativeSourceReader):
- """A reader for a multi-file text source."""
-
- def __init__(self, source):
- self.source = source
- self.file_paths = ChannelFactory.glob(self.source.file_path)
- if not self.file_paths:
- raise RuntimeError(
- 'No files found for path: %s' % self.source.file_path)
-
- def __enter__(self):
- return self
-
- def __exit__(self, exception_type, exception_value, traceback):
- pass
-
- def __iter__(self):
- index = 0
- for path in self.file_paths:
- index += 1
- logging.info('Reading from %s (%d/%d)', path, index, len(self.file_paths))
- with TextFileSource(
- path, strip_trailing_newlines=self.source.strip_trailing_newlines,
- coder=self.source.coder).reader() as reader:
- for line in reader:
- yield line
-
-
-# -----------------------------------------------------------------------------
-# TextFileWriter.
-
-
-class TextFileWriter(iobase.NativeSinkWriter):
- """The sink writer for a TextFileSink."""
-
- def __init__(self, sink):
- self.sink = sink
-
- def __enter__(self):
- if self.sink.is_gcs_sink:
- # TODO(silviuc): Use the storage library instead of gsutil for writes.
- self.temp_path = os.path.join(tempfile.mkdtemp(), 'gcsfile')
- self._file = open(self.temp_path, 'wb')
- else:
- self._file = open(self.sink.file_path, 'wb')
- return self
-
- def __exit__(self, exception_type, exception_value, traceback):
- self._file.close()
- if hasattr(self, 'temp_path'):
- _gcs_file_copy(self.temp_path, self.sink.file_path, 'text/plain')
-
- def Write(self, line):
- self._file.write(self.sink.coder.encode(line))
- if self.sink.append_trailing_newlines:
- self._file.write('\n')
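
For reference, the shard naming performed by FileSink.finalize_write() above can be
reproduced with a short standalone sketch. The bucket, prefix and suffix below are
hypothetical and the snippet is not part of the deleted module; it only illustrates
how a shard_name_template of '-SSSSS-of-NNNNN' expands into final file names.

# Standalone illustration (hypothetical paths); mirrors how finalize_write()
# combines file_path_prefix, the expanded shard_name_format and
# file_name_suffix into final shard names.
file_path_prefix = 'gs://my-bucket/output'
file_name_suffix = '.txt'
# '-SSSSS-of-NNNNN' expands to this %-format via _template_to_format():
shard_name_format = '-%(shard_num)05d-of-%(num_shards)05d'

num_shards = 3
final_names = [
    file_path_prefix +
    shard_name_format % dict(shard_num=shard_num, num_shards=num_shards) +
    file_name_suffix
    for shard_num in range(num_shards)]
assert final_names == ['gs://my-bucket/output-00000-of-00003.txt',
                       'gs://my-bucket/output-00001-of-00003.txt',
                       'gs://my-bucket/output-00002-of-00003.txt']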
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/google/cloud/dataflow/io/fileio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/google/cloud/dataflow/io/fileio_test.py b/sdks/python/google/cloud/dataflow/io/fileio_test.py
deleted file mode 100644
index 70192d1..0000000
--- a/sdks/python/google/cloud/dataflow/io/fileio_test.py
+++ /dev/null
@@ -1,522 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2016 Google Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Unit tests for local and GCS sources and sinks."""
-
-import glob
-import gzip
-import logging
-import os
-import tempfile
-import unittest
-
-import google.cloud.dataflow as df
-from google.cloud.dataflow import coders
-from google.cloud.dataflow.io import fileio
-from google.cloud.dataflow.io import iobase
-
-
-class TestTextFileSource(unittest.TestCase):
-
- def create_temp_file(self, text):
- temp = tempfile.NamedTemporaryFile(delete=False)
- with temp.file as tmp:
- tmp.write(text)
- return temp.name
-
- def read_with_offsets(self, input_lines, output_lines,
- start_offset=None, end_offset=None):
- source = fileio.TextFileSource(
- file_path=self.create_temp_file('\n'.join(input_lines)),
- start_offset=start_offset, end_offset=end_offset)
- read_lines = []
- with source.reader() as reader:
- for line in reader:
- read_lines.append(line)
- self.assertEqual(read_lines, output_lines)
-
- def progress_with_offsets(self, input_lines,
- start_offset=None, end_offset=None):
- source = fileio.TextFileSource(
- file_path=self.create_temp_file('\n'.join(input_lines)),
- start_offset=start_offset, end_offset=end_offset)
- progress_record = []
- with source.reader() as reader:
- self.assertEqual(reader.get_progress().position.byte_offset, -1)
- for line in reader:
- self.assertIsNotNone(line)
- progress_record.append(reader.get_progress().position.byte_offset)
-
- previous = 0
- for current in progress_record:
- self.assertGreater(current, previous)
- previous = current
-
- def test_read_entire_file(self):
- lines = ['First', 'Second', 'Third']
- source = fileio.TextFileSource(
- file_path=self.create_temp_file('\n'.join(lines)))
- read_lines = []
- with source.reader() as reader:
- for line in reader:
- read_lines.append(line)
- self.assertEqual(read_lines, lines)
-
- def test_progress_entire_file(self):
- lines = ['First', 'Second', 'Third']
- source = fileio.TextFileSource(
- file_path=self.create_temp_file('\n'.join(lines)))
- progress_record = []
- with source.reader() as reader:
- self.assertEqual(-1, reader.get_progress().position.byte_offset)
- for line in reader:
- self.assertIsNotNone(line)
- progress_record.append(reader.get_progress().position.byte_offset)
- self.assertEqual(13, reader.get_progress().position.byte_offset)
-
- self.assertEqual(len(progress_record), 3)
- self.assertEqual(progress_record, [0, 6, 13])
-
- def try_splitting_reader_at(self, reader, split_request, expected_response):
- actual_response = reader.request_dynamic_split(split_request)
-
- if expected_response is None:
- self.assertIsNone(actual_response)
- else:
- self.assertIsNotNone(actual_response.stop_position)
- self.assertIsInstance(actual_response.stop_position,
- iobase.ReaderPosition)
- self.assertIsNotNone(actual_response.stop_position.byte_offset)
- self.assertEqual(expected_response.stop_position.byte_offset,
- actual_response.stop_position.byte_offset)
-
- return actual_response
-
- def test_update_stop_position_for_percent_complete(self):
- lines = ['aaaa', 'bbbb', 'cccc', 'dddd', 'eeee']
- source = fileio.TextFileSource(
- file_path=self.create_temp_file('\n'.join(lines)))
- with source.reader() as reader:
-      # Read three lines
- reader_iter = iter(reader)
- next(reader_iter)
- next(reader_iter)
- next(reader_iter)
-
-      # Splitting at either end of the range should be unsuccessful
- self.try_splitting_reader_at(
- reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete=0)),
- None)
- self.try_splitting_reader_at(
- reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete=1)),
- None)
-
- # Splitting at positions on or before start offset of the last record
- self.try_splitting_reader_at(
- reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete=
- 0.2)),
- None)
- self.try_splitting_reader_at(
- reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete=
- 0.4)),
- None)
-
- # Splitting at a position after the start offset of the last record should
- # be successful
- self.try_splitting_reader_at(
- reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete=
- 0.6)),
- iobase.DynamicSplitResultWithPosition(iobase.ReaderPosition(
- byte_offset=15)))
-
- def test_update_stop_position_percent_complete_for_position(self):
- lines = ['aaaa', 'bbbb', 'cccc', 'dddd', 'eeee']
- source = fileio.TextFileSource(
- file_path=self.create_temp_file('\n'.join(lines)))
- with source.reader() as reader:
-      # Read three lines
- reader_iter = iter(reader)
- next(reader_iter)
- next(reader_iter)
- next(reader_iter)
-
-      # Splitting at either end of the range should be unsuccessful
- self.try_splitting_reader_at(
- reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(
- position=iobase.ReaderPosition(byte_offset=0))),
- None)
- self.try_splitting_reader_at(
- reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(
- position=iobase.ReaderPosition(byte_offset=25))),
- None)
-
- # Splitting at positions on or before start offset of the last record
- self.try_splitting_reader_at(
- reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(
- position=iobase.ReaderPosition(byte_offset=5))),
- None)
- self.try_splitting_reader_at(
- reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(
- position=iobase.ReaderPosition(byte_offset=10))),
- None)
-
- # Splitting at a position after the start offset of the last record should
- # be successful
- self.try_splitting_reader_at(
- reader,
- iobase.DynamicSplitRequest(iobase.ReaderProgress(
- position=iobase.ReaderPosition(byte_offset=15))),
- iobase.DynamicSplitResultWithPosition(iobase.ReaderPosition(
- byte_offset=15)))
-
- def run_update_stop_position_exhaustive(self, lines, newline):
- """An exhaustive test for dynamic splitting.
-
- For the given set of data items, try to perform a split at all possible
-    combinations of the following:
-
- * start position
- * original stop position
- * updated stop position
- * number of items read
-
- Args:
- lines: set of data items to be used to create the file
-      newline: separator to be used when writing the given set of lines to a
-        text file.
- """
-
- file_path = self.create_temp_file(newline.join(lines))
-
- total_records = len(lines)
- total_bytes = 0
-
- for line in lines:
- total_bytes += len(line)
- total_bytes += len(newline) * (total_records - 1)
-
- for start in xrange(0, total_bytes - 1):
- for end in xrange(start + 1, total_bytes):
- for stop in xrange(start, end):
- for records_to_read in range(0, total_records):
- self.run_update_stop_position(start, end, stop, records_to_read,
- file_path)
-
- def test_update_stop_position_exhaustive(self):
- self.run_update_stop_position_exhaustive(
- ['aaaa', 'bbbb', 'cccc', 'dddd', 'eeee'], '\n')
-
- def test_update_stop_position_exhaustive_with_empty_lines(self):
- self.run_update_stop_position_exhaustive(
- ['', 'aaaa', '', 'bbbb', 'cccc', '', 'dddd', 'eeee', ''], '\n')
-
- def test_update_stop_position_exhaustive_windows_newline(self):
- self.run_update_stop_position_exhaustive(
- ['aaaa', 'bbbb', 'cccc', 'dddd', 'eeee'], '\r\n')
-
- def test_update_stop_position_exhaustive_multi_byte(self):
- self.run_update_stop_position_exhaustive(
- [u'\u0d85\u0d85\u0d85\u0d85'.encode('utf-8'), u'\u0db6\u0db6\u0db6\u0db6'.encode('utf-8'),
- u'\u0d9a\u0d9a\u0d9a\u0d9a'.encode('utf-8')], '\n')
-
- def run_update_stop_position(self, start_offset, end_offset, stop_offset,
- records_to_read,
- file_path):
- source = fileio.TextFileSource(file_path, start_offset, end_offset)
-
- records_of_first_split = ''
-
- with source.reader() as reader:
- reader_iter = iter(reader)
- i = 0
-
- try:
- while i < records_to_read:
- records_of_first_split += next(reader_iter)
- i += 1
- except StopIteration:
-        # Invalid case: the given source does not contain this many records.
- return
-
- last_record_start_after_reading = reader.range_tracker.last_record_start
-
- if stop_offset <= last_record_start_after_reading:
- expected_split_response = None
- elif stop_offset == start_offset or stop_offset == end_offset:
- expected_split_response = None
- elif records_to_read == 0:
- expected_split_response = None # unstarted
- else:
- expected_split_response = iobase.DynamicSplitResultWithPosition(
- stop_position=iobase.ReaderPosition(byte_offset=stop_offset))
-
- split_response = self.try_splitting_reader_at(
- reader,
- iobase.DynamicSplitRequest(progress=iobase.ReaderProgress(
- iobase.ReaderPosition(byte_offset=stop_offset))),
- expected_split_response)
-
- # Reading remaining records from the updated reader.
- for line in reader:
- records_of_first_split += line
-
- if split_response is not None:
- # Total contents received by reading the two splits should be equal to the
- # result obtained by reading the original source.
- records_of_original = ''
- records_of_second_split = ''
-
- with source.reader() as original_reader:
- for line in original_reader:
- records_of_original += line
-
- new_source = fileio.TextFileSource(
- file_path,
- split_response.stop_position.byte_offset,
- end_offset)
- with new_source.reader() as reader:
- for line in reader:
- records_of_second_split += line
-
- self.assertEqual(records_of_original,
- records_of_first_split + records_of_second_split)
-
- def test_various_offset_combination_with_local_file_for_read(self):
- lines = ['01234', '6789012', '456789012']
- self.read_with_offsets(lines, lines[1:], start_offset=5)
- self.read_with_offsets(lines, lines[1:], start_offset=6)
- self.read_with_offsets(lines, lines[2:], start_offset=7)
- self.read_with_offsets(lines, lines[1:2], start_offset=5, end_offset=13)
- self.read_with_offsets(lines, lines[1:2], start_offset=5, end_offset=14)
- self.read_with_offsets(lines, lines[1:], start_offset=5, end_offset=16)
- self.read_with_offsets(lines, lines[2:], start_offset=14, end_offset=20)
- self.read_with_offsets(lines, lines[2:], start_offset=14)
- self.read_with_offsets(lines, [], start_offset=20, end_offset=20)
-
- def test_various_offset_combination_with_local_file_for_progress(self):
- lines = ['01234', '6789012', '456789012']
- self.progress_with_offsets(lines, start_offset=5)
- self.progress_with_offsets(lines, start_offset=6)
- self.progress_with_offsets(lines, start_offset=7)
- self.progress_with_offsets(lines, start_offset=5, end_offset=13)
- self.progress_with_offsets(lines, start_offset=5, end_offset=14)
- self.progress_with_offsets(lines, start_offset=5, end_offset=16)
- self.progress_with_offsets(lines, start_offset=14, end_offset=20)
- self.progress_with_offsets(lines, start_offset=14)
- self.progress_with_offsets(lines, start_offset=20, end_offset=20)
-
-
-class NativeTestTextFileSink(unittest.TestCase):
-
- def create_temp_file(self):
- temp = tempfile.NamedTemporaryFile(delete=False)
- return temp.name
-
- def test_write_entire_file(self):
- lines = ['First', 'Second', 'Third']
- file_path = self.create_temp_file()
- sink = fileio.NativeTextFileSink(file_path)
- with sink.writer() as writer:
- for line in lines:
- writer.Write(line)
- with open(file_path, 'r') as f:
- self.assertEqual(f.read().splitlines(), lines)
-
-
-class TestPureTextFileSink(unittest.TestCase):
-
- def setUp(self):
- self.lines = ['Line %d' % d for d in range(100)]
- self.path = tempfile.NamedTemporaryFile().name
-
- def _write_lines(self, sink, lines):
- f = sink.open(self.path)
- for line in lines:
- sink.write_record(f, line)
- sink.close(f)
-
- def test_write_text_file(self):
- sink = fileio.TextFileSink(self.path)
- self._write_lines(sink, self.lines)
-
- with open(self.path, 'r') as f:
- self.assertEqual(f.read().splitlines(), self.lines)
-
- def test_write_gzip_file(self):
- sink = fileio.TextFileSink(
- self.path, compression_type=fileio.CompressionTypes.DEFLATE)
- self._write_lines(sink, self.lines)
-
- with gzip.GzipFile(self.path, 'r') as f:
- self.assertEqual(f.read().splitlines(), self.lines)
-
-
-class MyFileSink(fileio.FileSink):
-
- def open(self, temp_path):
- # TODO(robertwb): Fix main session pickling.
- # file_handle = super(MyFileSink, self).open(temp_path)
- file_handle = fileio.FileSink.open(self, temp_path)
- file_handle.write('[start]')
- return file_handle
-
- def write_encoded_record(self, file_handle, encoded_value):
- file_handle.write('[')
- file_handle.write(encoded_value)
- file_handle.write(']')
-
- def close(self, file_handle):
- file_handle.write('[end]')
- # TODO(robertwb): Fix main session pickling.
- # file_handle = super(MyFileSink, self).close(file_handle)
- file_handle = fileio.FileSink.close(self, file_handle)
-
-
-class TestFileSink(unittest.TestCase):
-
- def test_file_sink_writing(self):
- temp_path = tempfile.NamedTemporaryFile().name
- sink = MyFileSink(temp_path,
- file_name_suffix='.foo',
- coder=coders.ToStringCoder())
-
- # Manually invoke the generic Sink API.
- init_token = sink.initialize_write()
-
- writer1 = sink.open_writer(init_token, '1')
- writer1.write('a')
- writer1.write('b')
- res1 = writer1.close()
-
- writer2 = sink.open_writer(init_token, '2')
- writer2.write('x')
- writer2.write('y')
- writer2.write('z')
- res2 = writer2.close()
-
- res = list(sink.finalize_write(init_token, [res1, res2]))
- # Retry the finalize operation (as if the first attempt was lost).
- res = list(sink.finalize_write(init_token, [res1, res2]))
-
- # Check the results.
- shard1 = temp_path + '-00000-of-00002.foo'
- shard2 = temp_path + '-00001-of-00002.foo'
- self.assertEqual(res, [shard1, shard2])
- self.assertEqual(open(shard1).read(), '[start][a][b][end]')
- self.assertEqual(open(shard2).read(), '[start][x][y][z][end]')
-
- # Check that any temp files are deleted.
- self.assertItemsEqual([shard1, shard2], glob.glob(temp_path + '*'))
-
- def test_empty_write(self):
- temp_path = tempfile.NamedTemporaryFile().name
- sink = MyFileSink(temp_path,
- file_name_suffix='.foo',
- coder=coders.ToStringCoder())
- p = df.Pipeline('DirectPipelineRunner')
- p | df.Create([]) | df.io.Write(sink) # pylint: disable=expression-not-assigned
- p.run()
- self.assertEqual(open(temp_path + '-00000-of-00001.foo').read(),
- '[start][end]')
-
- def test_fixed_shard_write(self):
- temp_path = tempfile.NamedTemporaryFile().name
- sink = MyFileSink(temp_path,
- file_name_suffix='.foo',
- num_shards=3,
- shard_name_template='_NN_SSS_',
- coder=coders.ToStringCoder())
- p = df.Pipeline('DirectPipelineRunner')
- p | df.Create(['a', 'b']) | df.io.Write(sink) # pylint: disable=expression-not-assigned
-
- p.run()
-
- concat = ''.join(open(temp_path + '_03_%03d_.foo' % shard_num).read()
- for shard_num in range(3))
- self.assertTrue('][a][' in concat, concat)
- self.assertTrue('][b][' in concat, concat)
-
- def test_file_sink_multi_shards(self):
- temp_path = tempfile.NamedTemporaryFile().name
- sink = MyFileSink(temp_path,
- file_name_suffix='.foo',
- coder=coders.ToStringCoder())
-
- # Manually invoke the generic Sink API.
- init_token = sink.initialize_write()
-
- num_shards = 1000
- writer_results = []
- for i in range(num_shards):
- uuid = 'uuid-%05d' % i
- writer = sink.open_writer(init_token, uuid)
- writer.write('a')
- writer.write('b')
- writer.write(uuid)
- writer_results.append(writer.close())
-
- res_first = list(sink.finalize_write(init_token, writer_results))
- # Retry the finalize operation (as if the first attempt was lost).
- res_second = list(sink.finalize_write(init_token, writer_results))
-
- self.assertItemsEqual(res_first, res_second)
-
- res = sorted(res_second)
- for i in range(num_shards):
- shard_name = '%s-%05d-of-%05d.foo' % (temp_path, i, num_shards)
- uuid = 'uuid-%05d' % i
- self.assertEqual(res[i], shard_name)
- self.assertEqual(
- open(shard_name).read(), ('[start][a][b][%s][end]' % uuid))
-
- # Check that any temp files are deleted.
- self.assertItemsEqual(res, glob.glob(temp_path + '*'))
-
- def test_file_sink_io_error(self):
- temp_path = tempfile.NamedTemporaryFile().name
- sink = MyFileSink(temp_path,
- file_name_suffix='.foo',
- coder=coders.ToStringCoder())
-
- # Manually invoke the generic Sink API.
- init_token = sink.initialize_write()
-
- writer1 = sink.open_writer(init_token, '1')
- writer1.write('a')
- writer1.write('b')
- res1 = writer1.close()
-
- writer2 = sink.open_writer(init_token, '2')
- writer2.write('x')
- writer2.write('y')
- writer2.write('z')
- res2 = writer2.close()
-
- os.remove(res2)
- with self.assertRaises(IOError):
- list(sink.finalize_write(init_token, [res1, res2]))
-
-if __name__ == '__main__':
- logging.getLogger().setLevel(logging.INFO)
- unittest.main()
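
For reference, the offset semantics exercised by
test_various_offset_combination_with_local_file_for_read above can be summarized
with a small standalone sketch: a [start_offset, end_offset) byte range is
responsible for exactly those lines whose first byte falls inside it. The helper
below is hypothetical (it is not part of fileio or its tests) and reuses the same
sample lines as that test.

# Standalone illustration of the TextFileSource offset rules; the helper is
# hypothetical and only models which lines a [start, end) byte range owns.
content = '01234\n6789012\n456789012'  # same data as the offset tests

def lines_for_range(data, start, end):
  """Return the lines whose first byte lies in [start, end)."""
  result, offset = [], 0
  for line in data.split('\n'):
    line_start = offset
    offset += len(line) + 1  # +1 for the '\n' separator
    # A line starting before `start` belongs to the previous range, even if
    # it extends past `start`; a line starting at or after `end` belongs to
    # the next range.
    if start <= line_start < end:
      result.append(line)
  return result

assert lines_for_range(content, 5, len(content)) == ['6789012', '456789012']
assert lines_for_range(content, 6, len(content)) == ['6789012', '456789012']
assert lines_for_range(content, 14, 20) == ['456789012']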
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/google/cloud/dataflow/io/gcsio.py
----------------------------------------------------------------------
diff --git a/sdks/python/google/cloud/dataflow/io/gcsio.py b/sdks/python/google/cloud/dataflow/io/gcsio.py
deleted file mode 100644
index 8157b76..0000000
--- a/sdks/python/google/cloud/dataflow/io/gcsio.py
+++ /dev/null
@@ -1,602 +0,0 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Google Cloud Storage client.
-
-This library evolved from the Google App Engine GCS client available at
-https://github.com/GoogleCloudPlatform/appengine-gcs-client.
-"""
-
-import errno
-import fnmatch
-import logging
-import multiprocessing
-import os
-import re
-import StringIO
-import threading
-
-from google.cloud.dataflow.internal import auth
-from google.cloud.dataflow.utils import retry
-
-from apitools.base.py.exceptions import HttpError
-import apitools.base.py.transfer as transfer
-
-# Issue a friendlier error message if the storage library is not available.
-# TODO(silviuc): Remove this guard when storage is available everywhere.
-try:
- # pylint: disable=g-import-not-at-top
- from google.cloud.dataflow.internal.clients import storage
-except ImportError:
- raise RuntimeError(
- 'Google Cloud Storage I/O not supported for this execution environment '
- '(could not import storage API client).')
-
-
-DEFAULT_READ_BUFFER_SIZE = 1024 * 1024
-
-
-def parse_gcs_path(gcs_path):
- """Return the bucket and object names of the given gs:// path."""
- match = re.match('^gs://([^/]+)/(.+)$', gcs_path)
- if match is None:
- raise ValueError('GCS path must be in the form gs://<bucket>/<object>.')
- return match.group(1), match.group(2)
-
-
-class GcsIOError(IOError, retry.PermanentException):
- """GCS IO error that should not be retried."""
- pass
-
-
-class GcsIO(object):
- """Google Cloud Storage I/O client."""
-
- def __new__(cls, storage_client=None):
- if storage_client:
- return super(GcsIO, cls).__new__(cls, storage_client)
- else:
- # Create a single storage client for each thread. We would like to avoid
- # creating more than one storage client for each thread, since each
- # initialization requires the relatively expensive step of initializing
-      # credentials.
- local_state = threading.local()
- if getattr(local_state, 'gcsio_instance', None) is None:
- credentials = auth.get_service_credentials()
- storage_client = storage.StorageV1(credentials=credentials)
- local_state.gcsio_instance = (
- super(GcsIO, cls).__new__(cls, storage_client))
- local_state.gcsio_instance.client = storage_client
- return local_state.gcsio_instance
-
- def __init__(self, storage_client=None):
- # We must do this check on storage_client because the client attribute may
- # have already been set in __new__ for the singleton case when
- # storage_client is None.
- if storage_client is not None:
- self.client = storage_client
-
- def open(self, filename, mode='r',
- read_buffer_size=DEFAULT_READ_BUFFER_SIZE,
- mime_type='application/octet-stream'):
- """Open a GCS file path for reading or writing.
-
- Args:
- filename: GCS file path in the form gs://<bucket>/<object>.
- mode: 'r' for reading or 'w' for writing.
- read_buffer_size: Buffer size to use during read operations.
- mime_type: Mime type to set for write operations.
-
- Returns:
- file object.
-
- Raises:
- ValueError: Invalid open file mode.
- """
- if mode == 'r' or mode == 'rb':
- return GcsBufferedReader(self.client, filename,
- buffer_size=read_buffer_size)
- elif mode == 'w' or mode == 'wb':
- return GcsBufferedWriter(self.client, filename, mime_type=mime_type)
- else:
- raise ValueError('Invalid file open mode: %s.' % mode)
-
- @retry.with_exponential_backoff(
- retry_filter=retry.retry_on_server_errors_and_timeout_filter)
- def glob(self, pattern):
- """Return the GCS path names matching a given path name pattern.
-
- Path name patterns are those recognized by fnmatch.fnmatch(). The path
- can contain glob characters (*, ?, and [...] sets).
-
- Args:
- pattern: GCS file path pattern in the form gs://<bucket>/<name_pattern>.
-
- Returns:
- list of GCS file paths matching the given pattern.
- """
- bucket, name_pattern = parse_gcs_path(pattern)
- # Get the prefix with which we can list objects in the given bucket.
- prefix = re.match('^[^[*?]*', name_pattern).group(0)
- request = storage.StorageObjectsListRequest(bucket=bucket, prefix=prefix)
- object_paths = []
- while True:
- response = self.client.objects.List(request)
- for item in response.items:
- if fnmatch.fnmatch(item.name, name_pattern):
- object_paths.append('gs://%s/%s' % (item.bucket, item.name))
- if response.nextPageToken:
- request.pageToken = response.nextPageToken
- else:
- break
- return object_paths
-
- @retry.with_exponential_backoff(
- retry_filter=retry.retry_on_server_errors_and_timeout_filter)
- def delete(self, path):
- """Deletes the object at the given GCS path.
-
- Args:
- path: GCS file path pattern in the form gs://<bucket>/<name>.
- """
- bucket, object_path = parse_gcs_path(path)
- request = storage.StorageObjectsDeleteRequest(bucket=bucket,
- object=object_path)
- try:
- self.client.objects.Delete(request)
- except HttpError as http_error:
- if http_error.status_code == 404:
- # Return success when the file doesn't exist anymore for idempotency.
- return
- raise
-
- @retry.with_exponential_backoff(
- retry_filter=retry.retry_on_server_errors_and_timeout_filter)
- def copy(self, src, dest):
- """Copies the given GCS object from src to dest.
-
- Args:
- src: GCS file path pattern in the form gs://<bucket>/<name>.
- dest: GCS file path pattern in the form gs://<bucket>/<name>.
- """
- src_bucket, src_path = parse_gcs_path(src)
- dest_bucket, dest_path = parse_gcs_path(dest)
- request = storage.StorageObjectsCopyRequest(sourceBucket=src_bucket,
- sourceObject=src_path,
- destinationBucket=dest_bucket,
- destinationObject=dest_path)
- try:
- self.client.objects.Copy(request)
- except HttpError as http_error:
- if http_error.status_code == 404:
- # This is a permanent error that should not be retried. Note that
- # FileSink.finalize_write expects an IOError when the source file does
- # not exist.
- raise GcsIOError(errno.ENOENT, 'Source file not found: %s' % src)
- raise
-
- # We intentionally do not decorate this method with a retry, since the
- # underlying copy and delete operations are already idempotent operations
- # protected by retry decorators.
- def copytree(self, src, dest):
- """Renames the given GCS "directory" recursively from src to dest.
-
- Args:
- src: GCS file path pattern in the form gs://<bucket>/<name>/.
- dest: GCS file path pattern in the form gs://<bucket>/<name>/.
- """
- assert src.endswith('/')
- assert dest.endswith('/')
- for entry in self.glob(src + '*'):
- rel_path = entry[len(src):]
- self.copy(entry, dest + rel_path)
-
- # We intentionally do not decorate this method with a retry, since the
- # underlying copy and delete operations are already idempotent operations
- # protected by retry decorators.
- def rename(self, src, dest):
- """Renames the given GCS object from src to dest.
-
- Args:
- src: GCS file path pattern in the form gs://<bucket>/<name>.
- dest: GCS file path pattern in the form gs://<bucket>/<name>.
- """
- self.copy(src, dest)
- self.delete(src)
-
- @retry.with_exponential_backoff(
- retry_filter=retry.retry_on_server_errors_and_timeout_filter)
- def exists(self, path):
- """Returns whether the given GCS object exists.
-
- Args:
- path: GCS file path pattern in the form gs://<bucket>/<name>.
- """
- bucket, object_path = parse_gcs_path(path)
- try:
- request = storage.StorageObjectsGetRequest(bucket=bucket,
- object=object_path)
- self.client.objects.Get(request) # metadata
- return True
-    except HttpError as http_error:
-      if http_error.status_code == 404:
-        return False
-      raise
-
-
-class GcsBufferedReader(object):
- """A class for reading Google Cloud Storage files."""
-
- def __init__(self, client, path, buffer_size=DEFAULT_READ_BUFFER_SIZE):
- self.client = client
- self.path = path
- self.bucket, self.name = parse_gcs_path(path)
- self.buffer_size = buffer_size
-
- # Get object state.
- get_request = (
- storage.StorageObjectsGetRequest(
- bucket=self.bucket,
- object=self.name))
- try:
- metadata = self._get_object_metadata(get_request)
- except HttpError as http_error:
- if http_error.status_code == 404:
- raise IOError(errno.ENOENT, 'Not found: %s' % self.path)
- else:
- logging.error(
- 'HTTP error while requesting file %s: %s', self.path, http_error)
- raise
- self.size = metadata.size
-
- # Ensure read is from file of the correct generation.
- get_request.generation = metadata.generation
-
- # Initialize read buffer state.
- self.download_stream = StringIO.StringIO()
- self.downloader = transfer.Download(
- self.download_stream, auto_transfer=False)
- self.client.objects.Get(get_request, download=self.downloader)
- self.position = 0
- self.buffer = ''
- self.buffer_start_position = 0
- self.closed = False
-
- @retry.with_exponential_backoff(
- retry_filter=retry.retry_on_server_errors_and_timeout_filter)
- def _get_object_metadata(self, get_request):
- return self.client.objects.Get(get_request)
-
- def read(self, size=-1):
- """Read data from a GCS file.
-
- Args:
- size: Number of bytes to read. Actual number of bytes read is always
- equal to size unless EOF is reached. If size is negative or
- unspecified, read the entire file.
-
- Returns:
- data read as str.
-
- Raises:
- IOError: When this buffer is closed.
- """
- return self._read_inner(size=size, readline=False)
-
- def readline(self, size=-1):
- """Read one line delimited by '\\n' from the file.
-
- Mimics behavior of the readline() method on standard file objects.
-
- A trailing newline character is kept in the string. It may be absent when a
- file ends with an incomplete line. If the size argument is non-negative,
- it specifies the maximum string size (counting the newline) to return.
- A negative size is the same as unspecified. Empty string is returned
- only when EOF is encountered immediately.
-
- Args:
- size: Maximum number of bytes to read. If not specified, readline stops
- only on '\\n' or EOF.
-
- Returns:
- The data read as a string.
-
- Raises:
- IOError: When this buffer is closed.
- """
- return self._read_inner(size=size, readline=True)
-
- def _read_inner(self, size=-1, readline=False):
- """Shared implementation of read() and readline()."""
- self._check_open()
- if not self._remaining():
- return ''
-
- # Prepare to read.
- data_list = []
- if size is None:
- size = -1
- to_read = min(size, self._remaining())
- if to_read < 0:
- to_read = self._remaining()
- break_after = False
-
- while to_read > 0:
- # If we have exhausted the buffer, get the next segment.
- # TODO(ccy): We should consider prefetching the next block in another
- # thread.
- self._fetch_next_if_buffer_exhausted()
-
- # Determine number of bytes to read from buffer.
- buffer_bytes_read = self.position - self.buffer_start_position
- bytes_to_read_from_buffer = min(
- len(self.buffer) - buffer_bytes_read, to_read)
-
- # If readline is set, we only want to read up to and including the next
- # newline character.
- if readline:
- next_newline_position = self.buffer.find(
- '\n', buffer_bytes_read, len(self.buffer))
- if next_newline_position != -1:
- bytes_to_read_from_buffer = (1 + next_newline_position -
- buffer_bytes_read)
- break_after = True
-
- # Read bytes.
- data_list.append(
- self.buffer[buffer_bytes_read:buffer_bytes_read +
- bytes_to_read_from_buffer])
- self.position += bytes_to_read_from_buffer
- to_read -= bytes_to_read_from_buffer
-
- if break_after:
- break
-
- return ''.join(data_list)
-
- def _fetch_next_if_buffer_exhausted(self):
- if not self.buffer or (self.buffer_start_position + len(self.buffer)
- <= self.position):
- bytes_to_request = min(self._remaining(), self.buffer_size)
- self.buffer_start_position = self.position
- self.buffer = self._get_segment(self.position, bytes_to_request)
-
- def _remaining(self):
- return self.size - self.position
-
- def close(self):
- """Close the current GCS file."""
- self.closed = True
- self.download_stream = None
- self.downloader = None
- self.buffer = None
-
- def _get_segment(self, start, size):
- """Get the given segment of the current GCS file."""
- if size == 0:
- return ''
- end = start + size - 1
- self.downloader.GetRange(start, end)
- value = self.download_stream.getvalue()
- # Clear the StringIO object after we've read its contents.
- self.download_stream.truncate(0)
- assert len(value) == size
- return value
-
- def __enter__(self):
- return self
-
- def __exit__(self, exception_type, exception_value, traceback):
- self.close()
-
- def seek(self, offset, whence=os.SEEK_SET):
- """Set the file's current offset.
-
-    Note if the new offset is out of bounds, it is adjusted to either 0 or EOF.
-
- Args:
- offset: seek offset as number.
- whence: seek mode. Supported modes are os.SEEK_SET (absolute seek),
- os.SEEK_CUR (seek relative to the current position), and os.SEEK_END
- (seek relative to the end, offset should be negative).
-
- Raises:
- IOError: When this buffer is closed.
- ValueError: When whence is invalid.
- """
- self._check_open()
-
- self.buffer = ''
- self.buffer_start_position = -1
-
- if whence == os.SEEK_SET:
- self.position = offset
- elif whence == os.SEEK_CUR:
- self.position += offset
- elif whence == os.SEEK_END:
- self.position = self.size + offset
- else:
- raise ValueError('Whence mode %r is invalid.' % whence)
-
- self.position = min(self.position, self.size)
- self.position = max(self.position, 0)
-
- def tell(self):
- """Tell the file's current offset.
-
- Returns:
- current offset in reading this file.
-
- Raises:
- IOError: When this buffer is closed.
- """
- self._check_open()
- return self.position
-
- def _check_open(self):
- if self.closed:
- raise IOError('Buffer is closed.')
-
- def seekable(self):
- return True
-
- def readable(self):
- return True
-
- def writable(self):
- return False
-
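The readline path in _read_inner above reads from the current segment buffer up
to and including the first newline, refilling the buffer from the next segment
when it is exhausted. A minimal, self-contained sketch of that idea (all names
here are illustrative, not SDK API; an in-memory string stands in for the
remote object):

    class _SegmentReader(object):
      """Reads newline-terminated lines from fixed-size segments."""

      def __init__(self, data, segment_size=4):
        self.data = data
        self.segment_size = segment_size
        self.position = 0
        self.buffer = ''
        self.buffer_start = 0

      def _fill(self):
        # Fetch the next segment once the current buffer is exhausted.
        if not self.buffer or (
            self.buffer_start + len(self.buffer) <= self.position):
          self.buffer_start = self.position
          self.buffer = self.data[self.position:
                                  self.position + self.segment_size]

      def readline(self):
        parts = []
        while self.position < len(self.data):
          self._fill()
          offset = self.position - self.buffer_start
          newline = self.buffer.find('\n', offset)
          # Read to the newline if the buffer has one, else to its end.
          end = len(self.buffer) if newline == -1 else newline + 1
          parts.append(self.buffer[offset:end])
          self.position += end - offset
          if newline != -1:
            break
        return ''.join(parts)

    reader = _SegmentReader('ab\ncdefg\n')
    assert reader.readline() == 'ab\n'
    assert reader.readline() == 'cdefg\n'
    assert reader.readline() == ''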
-
-class GcsBufferedWriter(object):
- """A class for writing Google Cloud Storage files."""
-
- class PipeStream(object):
- """A class that presents a pipe connection as a readable stream."""
-
- def __init__(self, recv_pipe):
- self.conn = recv_pipe
- self.closed = False
- self.position = 0
- self.remaining = ''
-
- def read(self, size):
- """Read data from the wrapped pipe connection.
-
- Args:
- size: Number of bytes to read. Actual number of bytes read is always
- equal to size unless EOF is reached.
-
- Returns:
- data read as str.
- """
- data_list = []
- bytes_read = 0
- while bytes_read < size:
- bytes_from_remaining = min(size - bytes_read, len(self.remaining))
- data_list.append(self.remaining[0:bytes_from_remaining])
- self.remaining = self.remaining[bytes_from_remaining:]
- self.position += bytes_from_remaining
- bytes_read += bytes_from_remaining
- if not self.remaining:
- try:
- self.remaining = self.conn.recv_bytes()
- except EOFError:
- break
- return ''.join(data_list)
-
- def tell(self):
- """Tell the file's current offset.
-
- Returns:
- current offset in reading this file.
-
- Raises:
- IOError: When this stream is closed.
- """
- self._check_open()
- return self.position
-
- def seek(self, offset, whence=os.SEEK_SET):
- # The apitools.base.py.transfer.Upload class insists on seeking to the end
- # of a stream to do a check before completing an upload, so we treat that
- # case (and a seek to the current position) as a no-op; any other seek
- # raises NotImplementedError.
- if whence == os.SEEK_END and offset == 0:
- return
- elif whence == os.SEEK_SET and offset == self.position:
- return
- raise NotImplementedError
-
- def _check_open(self):
- if self.closed:
- raise IOError('Stream is closed.')
-
- def __init__(self, client, path, mime_type='application/octet-stream'):
- self.client = client
- self.path = path
- self.bucket, self.name = parse_gcs_path(path)
-
- self.closed = False
- self.position = 0
-
- # Set up communication with uploading thread.
- parent_conn, child_conn = multiprocessing.Pipe()
- self.conn = parent_conn
-
- # Set up uploader.
- self.insert_request = (
- storage.StorageObjectsInsertRequest(
- bucket=self.bucket,
- name=self.name))
- self.upload = transfer.Upload(GcsBufferedWriter.PipeStream(child_conn),
- mime_type)
- self.upload.strategy = transfer.RESUMABLE_UPLOAD
-
- # Start uploading thread.
- self.upload_thread = threading.Thread(target=self._start_upload)
- self.upload_thread.daemon = True
- self.upload_thread.start()
-
- # TODO(silviuc): Refactor so that retry logic can be applied.
- # There is retry logic in the underlying transfer library but we should make
- # it more explicit so we can control the retry parameters.
- @retry.no_retries # Using no_retries marks this as an integration point.
- def _start_upload(self):
- # This starts the uploader thread. We are forced to run the uploader in
- # another thread because the apitools uploader insists on taking a stream
- # as input. Happily, this also means we get asynchronous I/O to GCS.
- #
- # The uploader by default transfers data in chunks of 1024 * 1024 bytes at
- # a time, buffering writes until that size is reached.
- self.client.objects.Insert(self.insert_request, upload=self.upload)
-
- def write(self, data):
- """Write data to a GCS file.
-
- Args:
- data: data to write as str.
-
- Raises:
- IOError: When this buffer is closed.
- """
- self._check_open()
- if not data:
- return
- self.conn.send_bytes(data)
- self.position += len(data)
-
- def tell(self):
- """Return the total number of bytes passed to write() so far."""
- return self.position
-
- def close(self):
- """Close the current GCS file."""
- self.closed = True  # Reject further writes (see _check_open).
- self.conn.close()
- self.upload_thread.join()
-
- def __enter__(self):
- return self
-
- def __exit__(self, exception_type, exception_value, traceback):
- self.close()
-
- def _check_open(self):
- if self.closed:
- raise IOError('Buffer is closed.')
-
- def seekable(self):
- return False
-
- def readable(self):
- return False
-
- def writable(self):
- return True
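GcsBufferedWriter above hands each write() to a background uploader thread
through a multiprocessing.Pipe, with PipeStream presenting the receiving end
as a readable stream. A hedged, self-contained sketch of that producer/consumer
pattern (no GCS involved; the consumer simply collects the bytes it receives):

    import multiprocessing
    import threading

    def _drain(recv_conn, chunks):
      # Consumer: keep receiving byte blocks until the sending end is closed,
      # which surfaces as EOFError on the receiving end.
      while True:
        try:
          chunks.append(recv_conn.recv_bytes())
        except EOFError:
          return

    parent_conn, child_conn = multiprocessing.Pipe()
    received = []
    consumer = threading.Thread(target=_drain, args=(child_conn, received))
    consumer.daemon = True
    consumer.start()

    for block in (b'hello ', b'buffered ', b'writer'):
      parent_conn.send_bytes(block)   # analogous to GcsBufferedWriter.write()
    parent_conn.close()               # analogous to close(): signals EOF
    consumer.join()
    assert b''.join(received) == b'hello buffered writer'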
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/google/cloud/dataflow/io/gcsio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/google/cloud/dataflow/io/gcsio_test.py b/sdks/python/google/cloud/dataflow/io/gcsio_test.py
deleted file mode 100644
index 702c834..0000000
--- a/sdks/python/google/cloud/dataflow/io/gcsio_test.py
+++ /dev/null
@@ -1,503 +0,0 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Tests for Google Cloud Storage client."""
-
-import logging
-import multiprocessing
-import os
-import random
-import threading
-import unittest
-
-
-import httplib2
-
-from google.cloud.dataflow.io import gcsio
-from apitools.base.py.exceptions import HttpError
-from google.cloud.dataflow.internal.clients import storage
-
-
-class FakeGcsClient(object):
- # Fake storage client. Usage in gcsio.py is client.objects.Get(...) and
- # client.objects.Insert(...).
-
- def __init__(self):
- self.objects = FakeGcsObjects()
-
-
-class FakeFile(object):
-
- def __init__(self, bucket, obj, contents, generation):
- self.bucket = bucket
- self.object = obj
- self.contents = contents
- self.generation = generation
-
- def get_metadata(self):
- return storage.Object(bucket=self.bucket,
- name=self.object,
- generation=self.generation,
- size=len(self.contents))
-
-
-class FakeGcsObjects(object):
-
- def __init__(self):
- self.files = {}
- # Store the last generation used for a given object name. Note that this
- # has to persist even past the deletion of the object.
- self.last_generation = {}
- self.list_page_tokens = {}
-
- def add_file(self, f):
- self.files[(f.bucket, f.object)] = f
- self.last_generation[(f.bucket, f.object)] = f.generation
-
- def get_file(self, bucket, obj):
- return self.files.get((bucket, obj), None)
-
- def delete_file(self, bucket, obj):
- del self.files[(bucket, obj)]
-
- def get_last_generation(self, bucket, obj):
- return self.last_generation.get((bucket, obj), 0)
-
- def Get(self, get_request, download=None): # pylint: disable=invalid-name
- f = self.get_file(get_request.bucket, get_request.object)
- if f is None:
- raise ValueError('Specified object does not exist.')
- if download is None:
- return f.get_metadata()
- else:
- stream = download.stream
-
- def get_range_callback(start, end):
- assert start >= 0 and end >= start and end < len(f.contents)
- stream.write(f.contents[start:end + 1])
- download.GetRange = get_range_callback
-
- def Insert(self, insert_request, upload=None): # pylint: disable=invalid-name
- assert upload is not None
- generation = self.get_last_generation(insert_request.bucket,
- insert_request.name) + 1
- f = FakeFile(insert_request.bucket, insert_request.name, '', generation)
-
- # Stream data into file.
- stream = upload.stream
- data_list = []
- while True:
- data = stream.read(1024 * 1024)
- if not data:
- break
- data_list.append(data)
- f.contents = ''.join(data_list)
-
- self.add_file(f)
-
- def Copy(self, copy_request): # pylint: disable=invalid-name
- src_file = self.get_file(copy_request.sourceBucket,
- copy_request.sourceObject)
- if not src_file:
- raise HttpError(httplib2.Response({'status': '404'}), '404 Not Found',
- 'https://fake/url')
- generation = self.get_last_generation(copy_request.destinationBucket,
- copy_request.destinationObject) + 1
- dest_file = FakeFile(copy_request.destinationBucket,
- copy_request.destinationObject,
- src_file.contents, generation)
- self.add_file(dest_file)
-
- def Delete(self, delete_request): # pylint: disable=invalid-name
- # Here, we emulate the behavior of the GCS service in raising a 404 error
- # if this object does not exist.
- if self.get_file(delete_request.bucket, delete_request.object):
- self.delete_file(delete_request.bucket, delete_request.object)
- else:
- raise HttpError(httplib2.Response({'status': '404'}), '404 Not Found',
- 'https://fake/url')
-
- def List(self, list_request): # pylint: disable=invalid-name
- bucket = list_request.bucket
- prefix = list_request.prefix or ''
- matching_files = []
- for file_bucket, file_name in sorted(iter(self.files)):
- if bucket == file_bucket and file_name.startswith(prefix):
- file_object = self.files[(file_bucket, file_name)].get_metadata()
- matching_files.append(file_object)
-
- # Handle pagination.
- items_per_page = 5
- if not list_request.pageToken:
- range_start = 0
- else:
- if list_request.pageToken not in self.list_page_tokens:
- raise ValueError('Invalid page token.')
- range_start = self.list_page_tokens[list_request.pageToken]
- del self.list_page_tokens[list_request.pageToken]
-
- result = storage.Objects(
- items=matching_files[range_start:range_start + items_per_page])
- if range_start + items_per_page < len(matching_files):
- next_range_start = range_start + items_per_page
- next_page_token = '_page_token_%s_%s_%d' % (bucket, prefix,
- next_range_start)
- self.list_page_tokens[next_page_token] = next_range_start
- result.nextPageToken = next_page_token
- return result
-
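FakeGcsObjects.List above pages its results five at a time and hands back an
opaque nextPageToken when more items remain. A hedged sketch of how a caller
would walk every page (assuming the fake classes above are available;
_ListRequest is a hypothetical stand-in carrying only the three attributes
List actually reads):

    import collections

    # Stand-in for the real list request message; List() only reads these
    # three attributes.
    _ListRequest = collections.namedtuple(
        '_ListRequest', ['bucket', 'prefix', 'pageToken'])

    def list_all_object_names(fake_objects, bucket, prefix=''):
      names = []
      page_token = None
      while True:
        response = fake_objects.List(_ListRequest(bucket, prefix, page_token))
        names.extend(item.name for item in response.items)
        page_token = getattr(response, 'nextPageToken', None)
        if not page_token:
          break
      return names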
-
-class TestGCSPathParser(unittest.TestCase):
-
- def test_gcs_path(self):
- self.assertEqual(
- gcsio.parse_gcs_path('gs://bucket/name'), ('bucket', 'name'))
- self.assertEqual(
- gcsio.parse_gcs_path('gs://bucket/name/sub'), ('bucket', 'name/sub'))
-
- def test_bad_gcs_path(self):
- self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs://')
- self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs://bucket')
- self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs://bucket/')
- self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs:///name')
- self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs:///')
- self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs:/blah/bucket/name')
-
-
-class TestGCSIO(unittest.TestCase):
-
- def _insert_random_file(self, client, path, size, generation=1):
- bucket, name = gcsio.parse_gcs_path(path)
- f = FakeFile(bucket, name, os.urandom(size), generation)
- client.objects.add_file(f)
- return f
-
- def setUp(self):
- self.client = FakeGcsClient()
- self.gcs = gcsio.GcsIO(self.client)
-
- def test_delete(self):
- file_name = 'gs://gcsio-test/delete_me'
- file_size = 1024
-
- # Test deletion of non-existent file.
- self.gcs.delete(file_name)
-
- self._insert_random_file(self.client, file_name, file_size)
- self.assertTrue(gcsio.parse_gcs_path(file_name) in
- self.client.objects.files)
-
- self.gcs.delete(file_name)
-
- self.assertFalse(gcsio.parse_gcs_path(file_name) in
- self.client.objects.files)
-
- def test_copy(self):
- src_file_name = 'gs://gcsio-test/source'
- dest_file_name = 'gs://gcsio-test/dest'
- file_size = 1024
- self._insert_random_file(self.client, src_file_name,
- file_size)
- self.assertTrue(gcsio.parse_gcs_path(src_file_name) in
- self.client.objects.files)
- self.assertFalse(gcsio.parse_gcs_path(dest_file_name) in
- self.client.objects.files)
-
- self.gcs.copy(src_file_name, dest_file_name)
-
- self.assertTrue(gcsio.parse_gcs_path(src_file_name) in
- self.client.objects.files)
- self.assertTrue(gcsio.parse_gcs_path(dest_file_name) in
- self.client.objects.files)
-
- self.assertRaises(IOError, self.gcs.copy,
- 'gs://gcsio-test/non-existent',
- 'gs://gcsio-test/non-existent-destination')
-
- def test_copytree(self):
- src_dir_name = 'gs://gcsio-test/source/'
- dest_dir_name = 'gs://gcsio-test/dest/'
- file_size = 1024
- paths = ['a', 'b/c', 'b/d']
- for path in paths:
- src_file_name = src_dir_name + path
- dest_file_name = dest_dir_name + path
- self._insert_random_file(self.client, src_file_name,
- file_size)
- self.assertTrue(gcsio.parse_gcs_path(src_file_name) in
- self.client.objects.files)
- self.assertFalse(gcsio.parse_gcs_path(dest_file_name) in
- self.client.objects.files)
-
- self.gcs.copytree(src_dir_name, dest_dir_name)
-
- for path in paths:
- src_file_name = src_dir_name + path
- dest_file_name = dest_dir_name + path
- self.assertTrue(gcsio.parse_gcs_path(src_file_name) in
- self.client.objects.files)
- self.assertTrue(gcsio.parse_gcs_path(dest_file_name) in
- self.client.objects.files)
-
- def test_rename(self):
- src_file_name = 'gs://gcsio-test/source'
- dest_file_name = 'gs://gcsio-test/dest'
- file_size = 1024
- self._insert_random_file(self.client, src_file_name,
- file_size)
- self.assertTrue(gcsio.parse_gcs_path(src_file_name) in
- self.client.objects.files)
- self.assertFalse(gcsio.parse_gcs_path(dest_file_name) in
- self.client.objects.files)
-
- self.gcs.rename(src_file_name, dest_file_name)
-
- self.assertFalse(gcsio.parse_gcs_path(src_file_name) in
- self.client.objects.files)
- self.assertTrue(gcsio.parse_gcs_path(dest_file_name) in
- self.client.objects.files)
-
- def test_full_file_read(self):
- file_name = 'gs://gcsio-test/full_file'
- file_size = 5 * 1024 * 1024 + 100
- random_file = self._insert_random_file(self.client, file_name, file_size)
- f = self.gcs.open(file_name)
- f.seek(0, os.SEEK_END)
- self.assertEqual(f.tell(), file_size)
- self.assertEqual(f.read(), '')
- f.seek(0)
- self.assertEqual(f.read(), random_file.contents)
-
- def test_file_random_seek(self):
- file_name = 'gs://gcsio-test/seek_file'
- file_size = 5 * 1024 * 1024 - 100
- random_file = self._insert_random_file(self.client, file_name, file_size)
-
- f = self.gcs.open(file_name)
- random.seed(0)
- for _ in range(0, 10):
- a = random.randint(0, file_size - 1)
- b = random.randint(0, file_size - 1)
- start, end = min(a, b), max(a, b)
- f.seek(start)
- self.assertEqual(f.tell(), start)
- self.assertEqual(f.read(end - start + 1),
- random_file.contents[start:end + 1])
- self.assertEqual(f.tell(), end + 1)
-
- def test_file_read_line(self):
- file_name = 'gs://gcsio-test/read_line_file'
- lines = []
-
- # Set a small buffer size to exercise refilling the buffer.
- # First line is carefully crafted so the newline falls as the last character
- # of the buffer to exercise this code path.
- read_buffer_size = 1024
- lines.append('x' * 1023 + '\n')
-
- for _ in range(1, 1000):
- line_length = random.randint(100, 500)
- line = os.urandom(line_length).replace('\n', ' ') + '\n'
- lines.append(line)
- contents = ''.join(lines)
-
- file_size = len(contents)
- bucket, name = gcsio.parse_gcs_path(file_name)
- self.client.objects.add_file(FakeFile(bucket, name, contents, 1))
-
- f = self.gcs.open(file_name, read_buffer_size=read_buffer_size)
-
- # Test read of first two lines.
- f.seek(0)
- self.assertEqual(f.readline(), lines[0])
- self.assertEqual(f.tell(), len(lines[0]))
- self.assertEqual(f.readline(), lines[1])
-
- # Test read at line boundary.
- f.seek(file_size - len(lines[-1]) - 1)
- self.assertEqual(f.readline(), '\n')
-
- # Test read at end of file.
- f.seek(file_size)
- self.assertEqual(f.readline(), '')
-
- # Test reads at random positions.
- random.seed(0)
- for _ in range(0, 10):
- start = random.randint(0, file_size - 1)
- line_index = 0
- # Find line corresponding to start index.
- chars_left = start
- while True:
- next_line_length = len(lines[line_index])
- if chars_left - next_line_length < 0:
- break
- chars_left -= next_line_length
- line_index += 1
- f.seek(start)
- self.assertEqual(f.readline(), lines[line_index][chars_left:])
-
- def test_file_write(self):
- file_name = 'gs://gcsio-test/write_file'
- file_size = 5 * 1024 * 1024 + 2000
- contents = os.urandom(file_size)
- f = self.gcs.open(file_name, 'w')
- f.write(contents[0:1000])
- f.write(contents[1000:1024 * 1024])
- f.write(contents[1024 * 1024:])
- f.close()
- bucket, name = gcsio.parse_gcs_path(file_name)
- self.assertEqual(
- self.client.objects.get_file(bucket, name).contents, contents)
-
- def test_context_manager(self):
- # Test writing with a context manager.
- file_name = 'gs://gcsio-test/context_manager_file'
- file_size = 1024
- contents = os.urandom(file_size)
- with self.gcs.open(file_name, 'w') as f:
- f.write(contents)
- bucket, name = gcsio.parse_gcs_path(file_name)
- self.assertEqual(
- self.client.objects.get_file(bucket, name).contents, contents)
-
- # Test reading with a context manager.
- with self.gcs.open(file_name) as f:
- self.assertEqual(f.read(), contents)
-
- # Test that exceptions are not swallowed by the context manager.
- with self.assertRaises(ZeroDivisionError):
- with self.gcs.open(file_name) as f:
- f.read(0 / 0)
-
- def test_glob(self):
- bucket_name = 'gcsio-test'
- object_names = [
- 'cow/cat/fish',
- 'cow/cat/blubber',
- 'cow/dog/blubber',
- 'apple/dog/blubber',
- 'apple/fish/blubber',
- 'apple/fish/blowfish',
- 'apple/fish/bambi',
- 'apple/fish/balloon',
- 'apple/fish/cat',
- 'apple/fish/cart',
- 'apple/fish/carl',
- 'apple/dish/bat',
- 'apple/dish/cat',
- 'apple/dish/carl',
- ]
- for object_name in object_names:
- file_name = 'gs://%s/%s' % (bucket_name, object_name)
- self._insert_random_file(self.client, file_name, 0)
- test_cases = [
- ('gs://gcsio-test/*', [
- 'cow/cat/fish',
- 'cow/cat/blubber',
- 'cow/dog/blubber',
- 'apple/dog/blubber',
- 'apple/fish/blubber',
- 'apple/fish/blowfish',
- 'apple/fish/bambi',
- 'apple/fish/balloon',
- 'apple/fish/cat',
- 'apple/fish/cart',
- 'apple/fish/carl',
- 'apple/dish/bat',
- 'apple/dish/cat',
- 'apple/dish/carl',
- ]),
- ('gs://gcsio-test/cow/*', [
- 'cow/cat/fish',
- 'cow/cat/blubber',
- 'cow/dog/blubber',
- ]),
- ('gs://gcsio-test/cow/ca*', [
- 'cow/cat/fish',
- 'cow/cat/blubber',
- ]),
- ('gs://gcsio-test/apple/[df]ish/ca*', [
- 'apple/fish/cat',
- 'apple/fish/cart',
- 'apple/fish/carl',
- 'apple/dish/cat',
- 'apple/dish/carl',
- ]),
- ('gs://gcsio-test/apple/fish/car?', [
- 'apple/fish/cart',
- 'apple/fish/carl',
- ]),
- ('gs://gcsio-test/apple/fish/b*', [
- 'apple/fish/blubber',
- 'apple/fish/blowfish',
- 'apple/fish/bambi',
- 'apple/fish/balloon',
- ]),
- ('gs://gcsio-test/apple/dish/[cb]at', [
- 'apple/dish/bat',
- 'apple/dish/cat',
- ]),
- ]
- for file_pattern, expected_object_names in test_cases:
- expected_file_names = ['gs://%s/%s' % (bucket_name, o) for o in
- expected_object_names]
- self.assertEqual(set(self.gcs.glob(file_pattern)),
- set(expected_file_names))
-
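test_glob above exercises GCS-style patterns in which '*' spans path
separators, '?' matches a single character, and '[...]' matches a character
set. A small illustrative sketch of those matching semantics using fnmatch
(this is not gcsio.glob's actual implementation, just the behavior the test
expects):

    import fnmatch
    import re

    def match_object_names(pattern, object_names):
      # Split 'gs://bucket/name-pattern' and keep only the object pattern.
      name_pattern = re.match(r'^gs://[^/]+/(.+)$', pattern).group(1)
      return [name for name in object_names
              if fnmatch.fnmatchcase(name, name_pattern)]

    names = ['cow/cat/fish', 'cow/dog/blubber', 'apple/fish/cart']
    assert match_object_names('gs://gcsio-test/cow/*', names) == [
        'cow/cat/fish', 'cow/dog/blubber']
    assert match_object_names('gs://gcsio-test/apple/[df]ish/car?', names) == [
        'apple/fish/cart']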
-
-class TestPipeStream(unittest.TestCase):
-
- def _read_and_verify(self, stream, expected, buffer_size):
- data_list = []
- bytes_read = 0
- seen_last_block = False
- while True:
- data = stream.read(buffer_size)
- self.assertLessEqual(len(data), buffer_size)
- if len(data) < buffer_size:
- # Test the constraint that the pipe stream returns less than the buffer
- # size only when at the end of the stream.
- if data:
- self.assertFalse(seen_last_block)
- seen_last_block = True
- if not data:
- break
- data_list.append(data)
- bytes_read += len(data)
- self.assertEqual(stream.tell(), bytes_read)
- self.assertEqual(''.join(data_list), expected)
-
- def test_pipe_stream(self):
- block_sizes = list(4 ** i for i in range(0, 12))
- data_blocks = list(os.urandom(size) for size in block_sizes)
- expected = ''.join(data_blocks)
-
- buffer_sizes = [100001, 512 * 1024, 1024 * 1024]
-
- for buffer_size in buffer_sizes:
- parent_conn, child_conn = multiprocessing.Pipe()
- stream = gcsio.GcsBufferedWriter.PipeStream(child_conn)
- child_thread = threading.Thread(target=self._read_and_verify,
- args=(stream, expected, buffer_size))
- child_thread.start()
- for data in data_blocks:
- parent_conn.send_bytes(data)
- parent_conn.close()
- child_thread.join()
-
-
-if __name__ == '__main__':
- logging.getLogger().setLevel(logging.INFO)
- unittest.main()