Posted to commits@beam.apache.org by dh...@apache.org on 2016/06/22 05:09:05 UTC
[1/2] incubator-beam git commit: Closes #507
Repository: incubator-beam
Updated Branches:
refs/heads/python-sdk 4840f5275 -> e3a43fb5c
Closes #507
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/e3a43fb5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/e3a43fb5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/e3a43fb5
Branch: refs/heads/python-sdk
Commit: e3a43fb5c0530fcfb0e06ef86b78fedd912bbc54
Parents: 4840f52 2ebd137
Author: Dan Halperin <dh...@google.com>
Authored: Tue Jun 21 22:08:47 2016 -0700
Committer: Dan Halperin <dh...@google.com>
Committed: Tue Jun 21 22:08:47 2016 -0700
----------------------------------------------------------------------
sdks/python/apache_beam/io/avroio.py | 205 +++++++++
sdks/python/apache_beam/io/avroio_test.py | 147 +++++++
sdks/python/apache_beam/io/filebasedsource.py | 246 +++++++++++
.../apache_beam/io/filebasedsource_test.py | 416 +++++++++++++++++++
sdks/python/apache_beam/io/iobase.py | 26 +-
sdks/python/apache_beam/io/range_trackers.py | 14 +-
sdks/python/apache_beam/io/sources_test.py | 46 +-
.../python/apache_beam/runners/direct_runner.py | 5 +-
sdks/python/setup.py | 1 +
9 files changed, 1094 insertions(+), 12 deletions(-)
----------------------------------------------------------------------
[2/2] incubator-beam git commit: Implements a framework for
developing sources for new file types.
Posted by dh...@apache.org.
Implements a framework for developing sources for new file types.
Module 'filebasedsource' provides a framework for creating sources for new file types. The framework provides ready implementations of several features common to file-based sources, such as file-pattern expansion and splitting into bundles for parallel processing.
Additionally, module 'avroio' contains a new source, 'AvroSource', for reading Avro files; it is implemented using the framework described above.
Also adds unit tests for the 'filebasedsource' and 'avroio' modules.
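For orientation before reading the diff: a minimal sketch of how a new file type plugs into this framework. The class name 'NewlineSource' and its line-per-record format are hypothetical; the real extension point is FileBasedSource.read_records(), defined in 'filebasedsource' below.

    from apache_beam.io import filebasedsource

    class NewlineSource(filebasedsource.FileBasedSource):
      """Hypothetical source emitting one record per line of each file."""

      def read_records(self, file_name, range_tracker):
        f = self.open_file(file_name)  # seekable file object from the framework
        try:
          position = range_tracker.start_position()
          f.seek(position)
          # A production source must also handle bundles that start in the
          # middle of a record; see LineSource in filebasedsource_test.py.
          for line in f:
            # Claim each record's start offset so that reading stays within
            # the bundle boundaries chosen during splitting.
            if not range_tracker.try_claim(position):
              return
            yield line.rstrip('\n')
            position += len(line)
        finally:
          f.close()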
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/2ebd137b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/2ebd137b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/2ebd137b
Branch: refs/heads/python-sdk
Commit: 2ebd137b28acd6f5c8bfd6994973b2ebbde5c84f
Parents: 4840f52
Author: Chamikara Jayalath <ch...@apache.org>
Authored: Mon Jun 20 18:09:50 2016 -0700
Committer: Dan Halperin <dh...@google.com>
Committed: Tue Jun 21 22:08:47 2016 -0700
----------------------------------------------------------------------
sdks/python/apache_beam/io/avroio.py | 205 +++++++++
sdks/python/apache_beam/io/avroio_test.py | 147 +++++++
sdks/python/apache_beam/io/filebasedsource.py | 246 +++++++++++
.../apache_beam/io/filebasedsource_test.py | 416 +++++++++++++++++++
sdks/python/apache_beam/io/iobase.py | 26 +-
sdks/python/apache_beam/io/range_trackers.py | 14 +-
sdks/python/apache_beam/io/sources_test.py | 46 +-
.../python/apache_beam/runners/direct_runner.py | 5 +-
sdks/python/setup.py | 1 +
9 files changed, 1094 insertions(+), 12 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2ebd137b/sdks/python/apache_beam/io/avroio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/avroio.py b/sdks/python/apache_beam/io/avroio.py
new file mode 100644
index 0000000..022a68d
--- /dev/null
+++ b/sdks/python/apache_beam/io/avroio.py
@@ -0,0 +1,205 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Implements a source for reading Avro files."""
+
+import os
+import StringIO
+import zlib
+
+from apache_beam.io import filebasedsource
+from avro import datafile
+from avro import io as avro_io
+from avro import schema
+
+
+class AvroSource(filebasedsource.FileBasedSource):
+ """A source for reading Avro files.
+
+ ``AvroSource`` is implemented using the file-based source framework available
+ in module 'filebasedsource'. Hence please refer to module 'filebasedsource'
+ to fully understand how this source implements operations common to all
+ file-based sources such as file-pattern expansion and splitting into bundles
+ for parallel processing.
+
+ If '/mypath/myavrofiles*' is a file-pattern that points to a set of Avro
+ files, a ``PCollection`` for the records in these Avro files can be created in
+ the following manner.
+
+ p = beam.Pipeline(argv=pipeline_args)
+ records = p | beam.Read('Read', avroio.AvroSource('/mypath/myavrofiles*'))
+
+ Each record of this ``PCollection`` will be a Python dictionary that
+ conforms to the schema of the Avro file containing that particular record.
+ Keys of each dictionary are the corresponding field names and are of type
+ ``string``. Values of each dictionary are of the types defined in the
+ corresponding Avro schema.
+
+ For example, if the schema of the Avro file is the following:
+ {"namespace": "example.avro","type": "record","name": "User","fields":
+ [{"name": "name", "type": "string"},
+ {"name": "favorite_number", "type": ["int", "null"]},
+ {"name": "favorite_color", "type": ["string", "null"]}]}
+
+ Then records generated by ``AvroSource`` will be dictionaries of the
+ following form:
+ {u'name': u'Alyssa', u'favorite_number': 256, u'favorite_color': None}.
+ """
+
+ def __init__(self, file_pattern, min_bundle_size=0):
+ super(AvroSource, self).__init__(file_pattern, min_bundle_size)
+ self._avro_schema = None
+ self._codec = None
+ self._sync_marker = None
+
+ class AvroBlock(object):
+ """Represents a block of an Avro file."""
+
+ def __init__(self, block_bytes, num_records, avro_schema, avro_codec,
+ offset):
+ self._block_bytes = block_bytes
+ self._num_records = num_records
+ self._avro_schema = avro_schema
+ self._avro_codec = avro_codec
+ self._offset = offset
+
+ def size(self):
+ return len(self._block_bytes)
+
+ def _decompress_bytes(self, encoding, data):
+ if encoding == 'null':
+ return data
+ elif encoding == 'deflate':
+ # zlib.MAX_WBITS is the window size. The negative sign indicates raw
+ # DEFLATE data, i.e. without zlib headers. See the zlib and Avro
+ # documentation for more details.
+ return zlib.decompress(data, -zlib.MAX_WBITS)
+ else:
+ raise ValueError('Unsupported compression type: %r' % encoding)
+
+ def records(self):
+ decompressed_bytes = self._decompress_bytes(self._avro_codec,
+ self._block_bytes)
+ decoder = avro_io.BinaryDecoder(StringIO.StringIO(decompressed_bytes))
+ reader = avro_io.DatumReader(
+ writers_schema=schema.parse(self._avro_schema),
+ readers_schema=schema.parse(self._avro_schema))
+
+ current_record = 0
+ while current_record < self._num_records:
+ yield reader.read(decoder)
+ current_record += 1
+
+ def offset(self):
+ return self._offset
+
+ def read_records(self, file_name, range_tracker):
+ start_offset = range_tracker.start_position()
+ if start_offset is None:
+ start_offset = 0
+
+ f = self.open_file(file_name)
+ try:
+ self._codec, self._avro_schema, self._sync_marker = (
+ AvroUtils.read_meta_data_from_file(f))
+
+ # Rewind by the length of the sync marker so that the marker is still
+ # found if the previous bundle ended exactly at the end of it.
+ start_offset = max(0, start_offset - len(self._sync_marker))
+
+ f.seek(start_offset)
+ while self.advance_pass_next_sync_marker(f):
+ if not range_tracker.try_claim(f.tell()):
+ return
+ next_block = self.read_next_block(f)
+ if next_block:
+ for record in next_block.records():
+ yield record
+ else:
+ return
+ finally:
+ f.close()
+
+ def advance_pass_next_sync_marker(self, f):
+ buf_size = 10000
+
+ data = f.read(buf_size)
+ while data:
+ pos = data.find(self._sync_marker)
+ if pos >= 0:
+ # Adjusting the current position to the ending position of the sync
+ # marker.
+ backtrack = len(data) - pos - len(self._sync_marker)
+ f.seek(-1 * backtrack, os.SEEK_CUR)
+ return True
+ else:
+ if f.tell() >= len(self._sync_marker):
+ # Backtracking in case we partially read the sync marker during the
+ # previous read. We only have to backtrack if there are at least
+ # len(sync_marker) bytes before current position. We only have to
+ # backtrack (len(sync_marker) - 1) bytes.
+ f.seek(-1 * (len(self._sync_marker) - 1), os.SEEK_CUR)
+ data = f.read(buf_size)
+
+ def read_next_block(self, f):
+ decoder = avro_io.BinaryDecoder(f)
+ num_records = decoder.read_long()
+ block_size = decoder.read_long()
+
+ block_bytes = decoder.read(block_size)
+ return AvroSource.AvroBlock(block_bytes, num_records,
+ self._avro_schema,
+ self._codec, f.tell()) if block_bytes else None
+
+
+class AvroUtils(object):
+
+ @staticmethod
+ def read_meta_data_from_file(f):
+ """Reads metadata from a given Avro file.
+
+ Args:
+ f: Avro file to read.
+ Returns:
+ a tuple containing the codec, schema, and the sync marker of the Avro
+ file.
+
+ Raises:
+ ValueError: if the file does not start with the byte sequence defined in
+ the specification.
+ """
+ f.seek(0, 0)
+ header = avro_io.DatumReader().read_data(datafile.META_SCHEMA,
+ datafile.META_SCHEMA,
+ avro_io.BinaryDecoder(f))
+ if header.get('magic') != datafile.MAGIC:
+ raise ValueError('Not an Avro file. File header should start with %s '
+ 'but started with %s instead.'
+ % (datafile.MAGIC, header.get('magic')))
+
+ meta = header['meta']
+
+ if datafile.CODEC_KEY in meta:
+ codec = meta[datafile.CODEC_KEY]
+ else:
+ codec = 'null'
+
+ schema_string = meta[datafile.SCHEMA_KEY]
+ sync_marker = header['sync']
+
+ return codec, schema_string, sync_marker
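A note on the 'deflate' handling in AvroBlock._decompress_bytes above: Avro's 'deflate' codec stores raw DEFLATE data (RFC 1951) without the zlib header, hence the negative window size. A minimal standard-library sketch of the round trip, with an illustrative payload:

    import zlib

    payload = 'some avro block bytes'
    compressor = zlib.compressobj(
        zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
    raw_deflate = compressor.compress(payload) + compressor.flush()
    # Decompressing with -zlib.MAX_WBITS mirrors AvroBlock._decompress_bytes.
    assert zlib.decompress(raw_deflate, -zlib.MAX_WBITS) == payload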
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2ebd137b/sdks/python/apache_beam/io/avroio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py
new file mode 100644
index 0000000..d70b3d1
--- /dev/null
+++ b/sdks/python/apache_beam/io/avroio_test.py
@@ -0,0 +1,147 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import os
+import tempfile
+import unittest
+
+from apache_beam.io import avroio
+from apache_beam.io import filebasedsource
+from avro.datafile import DataFileWriter
+from avro.io import DatumWriter
+import avro.schema as avro_schema
+
+
+class TestAvro(unittest.TestCase):
+
+ def setUp(self):
+ # Reducing the size of thread pools. Without this, test execution may
+ # fail in environments with a limited amount of resources.
+ filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2
+
+ RECORDS = [{'name': 'Thomas', 'favorite_number': 1, 'favorite_color': 'blue'},
+ {'name': 'Henry', 'favorite_number': 3, 'favorite_color': 'green'},
+ {'name': 'Toby', 'favorite_number': 7, 'favorite_color': 'brown'},
+ {'name': 'Gordon', 'favorite_number': 4, 'favorite_color': 'blue'},
+ {'name': 'Emily', 'favorite_number': -1, 'favorite_color': 'Red'},
+ {'name': 'Percy', 'favorite_number': 6, 'favorite_color': 'Green'}]
+
+ def _write_data(self, directory=None,
+ prefix=tempfile.template,
+ codec='null',
+ count=len(RECORDS)):
+ schema = ('{\"namespace\": \"example.avro\",'
+ '\"type\": \"record\",'
+ '\"name\": \"User\",'
+ '\"fields\": ['
+ '{\"name\": \"name\", \"type\": \"string\"},'
+ '{\"name\": \"favorite_number\", \"type\": [\"int\", \"null\"]},'
+ '{\"name\": \"favorite_color\", \"type\": [\"string\", \"null\"]}'
+ ']}')
+
+ schema = avro_schema.parse(schema)
+
+ with tempfile.NamedTemporaryFile(
+ delete=False, dir=directory, prefix=prefix) as f:
+ writer = DataFileWriter(f, DatumWriter(), schema, codec=codec)
+ len_records = len(self.RECORDS)
+ for i in range(count):
+ writer.append(self.RECORDS[i % len_records])
+ writer.close()
+
+ return f.name
+
+ def _write_pattern(self, num_files):
+ assert num_files > 0
+ temp_dir = tempfile.mkdtemp()
+
+ file_name = None
+ for _ in range(num_files):
+ file_name = self._write_data(directory=temp_dir, prefix='mytemp')
+
+ assert file_name
+ file_name_prefix = file_name[:file_name.rfind(os.path.sep)]
+ return file_name_prefix + os.path.sep + 'mytemp*'
+
+ def _run_avro_test(
+ self, pattern, desired_bundle_size, perform_splitting, expected_result):
+ source = avroio.AvroSource(pattern)
+
+ read_records = []
+ if perform_splitting:
+ assert desired_bundle_size
+ splits = [split for split in source.split(
+ desired_bundle_size=desired_bundle_size)]
+ if len(splits) < 2:
+ raise ValueError('Test is trivial. Please adjust it so that at least '
+ 'two splits get generated.')
+ for split in splits:
+ records = [record for record in split.source.read(
+ split.source.get_range_tracker(split.start_position,
+ split.stop_position))]
+ read_records.extend(records)
+ else:
+ range_tracker = source.get_range_tracker(None, None)
+ read_records = [record for record in source.read(range_tracker)]
+
+ self.assertItemsEqual(expected_result, read_records)
+
+ def test_read_without_splitting(self):
+ file_name = self._write_data()
+ expected_result = self.RECORDS
+ self._run_avro_test(file_name, None, False, expected_result)
+
+ def test_read_with_splitting(self):
+ file_name = self._write_data()
+ expected_result = self.RECORDS
+ self._run_avro_test(file_name, 100, True, expected_result)
+
+ def test_read_without_splitting_multiple_blocks(self):
+ file_name = self._write_data(count=12000)
+ expected_result = self.RECORDS * 2000
+ self._run_avro_test(file_name, None, False, expected_result)
+
+ def test_read_with_splitting_multiple_blocks(self):
+ file_name = self._write_data(count=12000)
+ expected_result = self.RECORDS * 2000
+ self._run_avro_test(file_name, 10000, True, expected_result)
+
+ def test_read_without_splitting_compressed_deflate(self):
+ file_name = self._write_data(codec='deflate')
+ expected_result = self.RECORDS
+ self._run_avro_test(file_name, None, False, expected_result)
+
+ def test_read_with_splitting_compressed_deflate(self):
+ file_name = self._write_data(codec='deflate')
+ expected_result = self.RECORDS
+ self._run_avro_test(file_name, 100, True, expected_result)
+
+ def test_read_without_splitting_pattern(self):
+ pattern = self._write_pattern(3)
+ expected_result = self.RECORDS * 3
+ self._run_avro_test(pattern, None, False, expected_result)
+
+ def test_read_with_splitting_pattern(self):
+ pattern = self._write_pattern(3)
+ expected_result = self.RECORDS * 3
+ self._run_avro_test(pattern, 100, True, expected_result)
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2ebd137b/sdks/python/apache_beam/io/filebasedsource.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py
new file mode 100644
index 0000000..c877e44
--- /dev/null
+++ b/sdks/python/apache_beam/io/filebasedsource.py
@@ -0,0 +1,246 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A framework for developing sources for new file types.
+
+To create a source for a new file type, create a sub-class of
+``FileBasedSource``. Sub-classes of ``FileBasedSource`` must implement the
+method ``FileBasedSource.read_records()``; please read that method's
+documentation for more details.
+
+For an example implementation of ``FileBasedSource`` see ``avroio.AvroSource``.
+"""
+
+from multiprocessing.pool import ThreadPool
+import os
+import range_trackers
+
+from apache_beam.io import fileio
+from apache_beam.io import iobase
+
+MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 25
+
+
+class _ConcatSource(iobase.BoundedSource):
+ """A ``BoundedSource`` that can group a set of ``BoundedSources``."""
+
+ def __init__(self, sources):
+ self._sources = sources
+
+ @property
+ def sources(self):
+ return self._sources
+
+ def estimate_size(self):
+ return sum(s.estimate_size() for s in self._sources)
+
+ def split(
+ self, desired_bundle_size=None, start_position=None, stop_position=None):
+ if start_position or stop_position:
+ raise ValueError(
+ 'Multi-level initial splitting is not supported. Expected start and '
+ 'stop positions to be None. Received %r and %r respectively.'
+ % (start_position, stop_position))
+
+ for source in self._sources:
+ # We assume all sub-sources produce bundles that specify their weight
+ # using the same unit. For example, all sub-sources may specify the
+ # size in bytes as their weight.
+ for bundle in source.split(desired_bundle_size, None, None):
+ yield bundle
+
+ def get_range_tracker(self, start_position, stop_position):
+ assert start_position is None
+ assert stop_position is None
+ # This will be invoked only when FileBasedSource is read without splitting.
+ # For that case, we only support reading the whole source.
+ return range_trackers.OffsetRangeTracker(0, len(self.sources))
+
+ def read(self, range_tracker):
+ for index, sub_source in enumerate(self.sources):
+ if not range_tracker.try_claim(index):
+ return
+
+ sub_source_tracker = sub_source.get_range_tracker(None, None)
+ for record in sub_source.read(sub_source_tracker):
+ yield record
+
+ def default_output_coder(self):
+ if self._sources:
+ # Getting the coder from the first sub-source. This assumes all
+ # sub-sources produce the same coder.
+ return self._sources[0].default_output_coder()
+ else:
+ # Defaulting to PickleCoder.
+ return super(_ConcatSource, self).default_output_coder()
+
+
+class FileBasedSource(iobase.BoundedSource):
+ """A ``BoundedSource`` for reading a file glob of a given type."""
+
+ def __init__(self, file_pattern, min_bundle_size=0):
+ """Initializes ``FileBasedSource``.
+
+ Args:
+ file_pattern: the file glob to read.
+ min_bundle_size: minimum size of bundles that should be generated when
+ performing initial splitting on this source.
+ """
+ self._pattern = file_pattern
+ self._concat_source = None
+ self._min_bundle_size = min_bundle_size
+
+ def _get_concat_source(self):
+ if self._concat_source is None:
+ single_file_sources = []
+ file_names = list(fileio.ChannelFactory.glob(self._pattern))
+ sizes = FileBasedSource._estimate_sizes_in_parallel(file_names)
+
+ for index, file_name in enumerate(file_names):
+ if sizes[index] == 0:
+ continue # Ignoring empty file.
+
+ single_file_source = _SingleFileSource(
+ self, file_name,
+ 0,
+ sizes[index],
+ min_bundle_size=self._min_bundle_size)
+ single_file_sources.append(single_file_source)
+ self._concat_source = _ConcatSource(single_file_sources)
+ return self._concat_source
+
+ def open_file(self, file_name):
+ return fileio.ChannelFactory.open(
+ file_name, 'rb', 'application/octet-stream')
+
+ @staticmethod
+ def _estimate_sizes_in_parallel(file_names):
+
+ def _calculate_size_of_file(file_name):
+ f = fileio.ChannelFactory.open(
+ file_name, 'rb', 'application/octet-stream')
+ try:
+ f.seek(0, os.SEEK_END)
+ return f.tell()
+ finally:
+ f.close()
+
+ return ThreadPool(MAX_NUM_THREADS_FOR_SIZE_ESTIMATION).map(
+ _calculate_size_of_file, file_names)
+
+ def split(
+ self, desired_bundle_size=None, start_position=None, stop_position=None):
+ return self._get_concat_source().split(
+ desired_bundle_size=desired_bundle_size,
+ start_position=start_position,
+ stop_position=stop_position)
+
+ def estimate_size(self):
+ return self._get_concat_source().estimate_size()
+
+ def read(self, range_tracker):
+ return self._get_concat_source().read(range_tracker)
+
+ def get_range_tracker(self, start_position, stop_position):
+ return self._get_concat_source().get_range_tracker(start_position,
+ stop_position)
+
+ def default_output_coder(self):
+ return self._get_concat_source().default_output_coder()
+
+ def read_records(self, file_name, offset_range_tracker):
+ """Returns a generator of records created by reading file 'file_name'.
+
+ Args:
+ file_name: a ``string`` that gives the name of the file to be read. Method
+ ``FileBasedSource.open_file()`` must be used to open the file
+ and create a seekable file object.
+ offset_range_tracker: an object of type ``OffsetRangeTracker``. This
+ defines the byte range of the file that should be
+ read. See documentation in
+ ``iobase.BoundedSource.read()`` for more information
+ on reading records while complying to the range
+ defined by a given ``RangeTracker``.
+
+ Returns:
+ an iterator that gives the records read from the given file.
+ """
+ raise NotImplementedError
+
+
+class _SingleFileSource(iobase.BoundedSource):
+ """Denotes a source for a specific file type.
+
+ This should be sub-classed to add support for reading a new file type.
+ """
+
+ def __init__(self, file_based_source, file_name, start_offset, stop_offset,
+ min_bundle_size=0):
+ if not (isinstance(start_offset, int) or isinstance(start_offset, long)):
+ raise ValueError(
+ 'start_offset must be a number. Received: %r' % start_offset)
+ if not (isinstance(stop_offset, int) or isinstance(stop_offset, long)):
+ raise ValueError(
+ 'stop_offset must be a number. Received: %r' % stop_offset)
+ if start_offset >= stop_offset:
+ raise ValueError(
+ 'start_offset must be smaller than stop_offset. Received %d and %d '
+ 'for start and stop offsets respectively.'
+ % (start_offset, stop_offset))
+
+ self._file_name = file_name
+ self._is_gcs_file = file_name.startswith('gs://') if file_name else False
+ self._start_offset = start_offset
+ self._stop_offset = stop_offset
+ self._min_bundle_size = min_bundle_size
+ self._file_based_source = file_based_source
+
+ def split(self, desired_bundle_size, start_offset=None, stop_offset=None):
+ if start_offset is None:
+ start_offset = self._start_offset
+ if stop_offset is None:
+ stop_offset = self._stop_offset
+
+ bundle_size = max(desired_bundle_size, self._min_bundle_size)
+
+ bundle_start = start_offset
+ while bundle_start < stop_offset:
+ bundle_stop = min(bundle_start + bundle_size, stop_offset)
+ yield iobase.SourceBundle(
+ bundle_stop - bundle_start,
+ _SingleFileSource(
+ self._file_based_source,
+ self._file_name,
+ bundle_start,
+ bundle_stop,
+ min_bundle_size=self._min_bundle_size),
+ bundle_start,
+ bundle_stop)
+ bundle_start = bundle_stop
+
+ def estimate_size(self):
+ return self._stop_offset - self._start_offset
+
+ def get_range_tracker(self, start_position, stop_position):
+ if start_position is None:
+ start_position = self._start_offset
+ if stop_position is None:
+ stop_position = self._stop_offset
+
+ return range_trackers.OffsetRangeTracker(start_position, stop_position)
+
+ def read(self, range_tracker):
+ return self._file_based_source.read_records(self._file_name, range_tracker)
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2ebd137b/sdks/python/apache_beam/io/filebasedsource_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/filebasedsource_test.py b/sdks/python/apache_beam/io/filebasedsource_test.py
new file mode 100644
index 0000000..c7837ec
--- /dev/null
+++ b/sdks/python/apache_beam/io/filebasedsource_test.py
@@ -0,0 +1,416 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import math
+import os
+import tempfile
+import unittest
+
+import apache_beam as beam
+from apache_beam.io import filebasedsource
+from apache_beam.io import iobase
+from apache_beam.io import range_trackers
+
+# Importing the following private classes for testing.
+from apache_beam.io.filebasedsource import _ConcatSource as ConcatSource
+from apache_beam.io.filebasedsource import _SingleFileSource as SingleFileSource
+
+from apache_beam.io.filebasedsource import FileBasedSource
+from apache_beam.transforms.util import assert_that
+from apache_beam.transforms.util import equal_to
+
+
+class LineSource(FileBasedSource):
+
+ def read_records(self, file_name, range_tracker):
+ f = self.open_file(file_name)
+ try:
+ start = range_tracker.start_position()
+ f.seek(start)
+ if start > 0:
+ f.seek(-1, os.SEEK_CUR)
+ start -= 1
+ line = f.readline()
+ start += len(line)
+ current = start
+ for line in f:
+ if not range_tracker.try_claim(current):
+ return
+ yield line.rstrip('\n')
+ current += len(line)
+ finally:
+ f.close()
+
+
+def _write_data(num_lines, directory=None, prefix=tempfile.template):
+ all_data = []
+ with tempfile.NamedTemporaryFile(
+ delete=False, dir=directory, prefix=prefix) as f:
+ for i in range(num_lines):
+ data = 'line' + str(i)
+ all_data.append(data)
+ f.write(data + '\n')
+
+ return f.name, all_data
+
+
+def _write_pattern(lines_per_file):
+ temp_dir = tempfile.mkdtemp()
+
+ all_data = []
+ file_name = None
+ start_index = 0
+ for i in range(len(lines_per_file)):
+ file_name, data = _write_data(lines_per_file[i],
+ directory=temp_dir, prefix='mytemp')
+ all_data.extend(data)
+ start_index += lines_per_file[i]
+
+ assert file_name
+ return (
+ file_name[:file_name.rfind(os.path.sep)] + os.path.sep + 'mytemp*',
+ all_data)
+
+
+class TestConcatSource(unittest.TestCase):
+
+ class DummySource(iobase.BoundedSource):
+
+ def __init__(self, values):
+ self._values = values
+
+ def split(self, desired_bundle_size, start_position=None,
+ stop_position=None):
+ # Simply divides the values into two bundles.
+ middle = len(self._values) / 2
+ yield iobase.SourceBundle(0.5, TestConcatSource.DummySource(
+ self._values[:middle]), None, None)
+ yield iobase.SourceBundle(0.5, TestConcatSource.DummySource(
+ self._values[middle:]), None, None)
+
+ def get_range_tracker(self, start_position, stop_position):
+ if start_position is None:
+ start_position = 0
+ if stop_position is None:
+ stop_position = len(self._values)
+
+ return range_trackers.OffsetRangeTracker(start_position, stop_position)
+
+ def read(self, range_tracker):
+ for index, value in enumerate(self._values):
+ if not range_tracker.try_claim(index):
+ return
+
+ yield value
+
+ def estimate_size(self):
+ return len(self._values) # Assuming each value to be 1 byte.
+
+ def setUp(self):
+ # Reducing the size of thread pools. Without this, test execution may
+ # fail in environments with a limited amount of resources.
+ filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2
+
+ def test_read(self):
+ sources = [TestConcatSource.DummySource(range(start, start + 10)) for start
+ in [0, 10, 20]]
+ concat = ConcatSource(sources)
+ range_tracker = concat.get_range_tracker(None, None)
+ read_data = [value for value in concat.read(range_tracker)]
+ self.assertItemsEqual(range(30), read_data)
+
+ def test_split(self):
+ sources = [TestConcatSource.DummySource(range(start, start + 10)) for start
+ in [0, 10, 20]]
+ concat = ConcatSource(sources)
+ splits = [split for split in concat.split()]
+ self.assertEquals(6, len(splits))
+
+ # Reading all splits
+ read_data = []
+ for split in splits:
+ range_tracker_for_split = split.source.get_range_tracker(
+ split.start_position,
+ split.stop_position)
+ read_data.extend([value for value in split.source.read(
+ range_tracker_for_split)])
+ self.assertItemsEqual(range(30), read_data)
+
+ def test_estimate_size(self):
+ sources = [TestConcatSource.DummySource(range(start, start + 10)) for start
+ in [0, 10, 20]]
+ concat = ConcatSource(sources)
+ self.assertEquals(30, concat.estimate_size())
+
+
+class TestFileBasedSource(unittest.TestCase):
+
+ def setUp(self):
+ # Reducing the size of thread pools. Without this, test execution may
+ # fail in environments with a limited amount of resources.
+ filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2
+
+ def test_fully_read_single_file(self):
+ file_name, expected_data = _write_data(10)
+ assert len(expected_data) == 10
+ fbs = LineSource(file_name)
+ range_tracker = fbs.get_range_tracker(None, None)
+ read_data = [record for record in fbs.read(range_tracker)]
+ self.assertItemsEqual(expected_data, read_data)
+
+ def test_fully_read_file_pattern(self):
+ pattern, expected_data = _write_pattern([5, 3, 12, 8, 8, 4])
+ assert len(expected_data) == 40
+ fbs = LineSource(pattern)
+ range_tracker = fbs.get_range_tracker(None, None)
+ read_data = [record for record in fbs.read(range_tracker)]
+ self.assertItemsEqual(expected_data, read_data)
+
+ def test_fully_read_file_pattern_with_empty_files(self):
+ pattern, expected_data = _write_pattern([5, 0, 12, 0, 8, 0])
+ assert len(expected_data) == 25
+ fbs = LineSource(pattern)
+ range_tracker = fbs.get_range_tracker(None, None)
+ read_data = [record for record in fbs.read(range_tracker)]
+ self.assertItemsEqual(expected_data, read_data)
+
+ def test_estimate_size_of_file(self):
+ file_name, expected_data = _write_data(10)
+ assert len(expected_data) == 10
+ fbs = LineSource(file_name)
+ self.assertEquals(10 * 6, fbs.estimate_size())
+
+ def test_estimate_size_of_pattern(self):
+ pattern, expected_data = _write_pattern([5, 3, 10, 8, 8, 4])
+ assert len(expected_data) == 38
+ fbs = LineSource(pattern)
+ self.assertEquals(38 * 6, fbs.estimate_size())
+
+ pattern, expected_data = _write_pattern([5, 3, 9])
+ assert len(expected_data) == 17
+ fbs = LineSource(pattern)
+ self.assertEquals(17 * 6, fbs.estimate_size())
+
+ def test_splits_into_subranges(self):
+ pattern, expected_data = _write_pattern([5, 9, 6])
+ assert len(expected_data) == 20
+ fbs = LineSource(pattern)
+ splits = [split for split in fbs.split(desired_bundle_size=15)]
+ expected_num_splits = (
+ math.ceil(float(6 * 5) / 15) +
+ math.ceil(float(6 * 9) / 15) +
+ math.ceil(float(6 * 6) / 15))
+ assert len(splits) == expected_num_splits
+
+ def test_read_splits_single_file(self):
+ file_name, expected_data = _write_data(100)
+ assert len(expected_data) == 100
+ fbs = LineSource(file_name)
+ splits = [split for split in fbs.split(desired_bundle_size=33)]
+
+ # Reading all splits
+ read_data = []
+ for split in splits:
+ source = split.source
+ range_tracker = source.get_range_tracker(split.start_position,
+ split.stop_position)
+ data_from_split = [data for data in source.read(range_tracker)]
+ read_data.extend(data_from_split)
+
+ self.assertItemsEqual(expected_data, read_data)
+
+ def test_read_splits_file_pattern(self):
+ pattern, expected_data = _write_pattern([34, 66, 40, 24, 24, 12])
+ assert len(expected_data) == 200
+ fbs = LineSource(pattern)
+ splits = [split for split in fbs.split(desired_bundle_size=50)]
+
+ # Reading all splits
+ read_data = []
+ for split in splits:
+ source = split.source
+ range_tracker = source.get_range_tracker(split.start_position,
+ split.stop_position)
+ data_from_split = [data for data in source.read(range_tracker)]
+ read_data.extend(data_from_split)
+
+ self.assertItemsEqual(expected_data, read_data)
+
+ def test_dataflow_file(self):
+ file_name, expected_data = _write_data(100)
+ assert len(expected_data) == 100
+ pipeline = beam.Pipeline('DirectPipelineRunner')
+ pcoll = pipeline | beam.Read('Read', LineSource(file_name))
+ assert_that(pcoll, equal_to(expected_data))
+ pipeline.run()
+
+ def test_dataflow_pattern(self):
+ pattern, expected_data = _write_pattern([34, 66, 40, 24, 24, 12])
+ assert len(expected_data) == 200
+ pipeline = beam.Pipeline('DirectPipelineRunner')
+ pcoll = pipeline | beam.Read('Read', LineSource(pattern))
+ assert_that(pcoll, equal_to(expected_data))
+ pipeline.run()
+
+
+class TestSingleFileSource(unittest.TestCase):
+
+ def setUp(self):
+ # Reducing the size of thread pools. Without this, test execution may
+ # fail in environments with a limited amount of resources.
+ filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2
+
+ def test_source_creation_fails_for_non_number_offsets(self):
+ start_not_a_number_error = 'start_offset must be a number*'
+ stop_not_a_number_error = 'stop_offset must be a number*'
+
+ fbs = LineSource('dummy_pattern')
+
+ with self.assertRaisesRegexp(ValueError, start_not_a_number_error):
+ SingleFileSource(
+ fbs, file_name='dummy_file', start_offset='aaa', stop_offset='bbb')
+ with self.assertRaisesRegexp(ValueError, start_not_a_number_error):
+ SingleFileSource(
+ fbs, file_name='dummy_file', start_offset='aaa', stop_offset=100)
+ with self.assertRaisesRegexp(ValueError, stop_not_a_number_error):
+ SingleFileSource(
+ fbs, file_name='dummy_file', start_offset=100, stop_offset='bbb')
+ with self.assertRaisesRegexp(ValueError, stop_not_a_number_error):
+ SingleFileSource(
+ fbs, file_name='dummy_file', start_offset=100, stop_offset=None)
+ with self.assertRaisesRegexp(ValueError, start_not_a_number_error):
+ SingleFileSource(
+ fbs, file_name='dummy_file', start_offset=None, stop_offset=100)
+
+ def test_source_creation_fails_if_start_lg_stop(self):
+ start_larger_than_stop_error = (
+ 'start_offset must be smaller than stop_offset*')
+
+ fbs = LineSource('dummy_pattern')
+ SingleFileSource(
+ fbs, file_name='dummy_file', start_offset=99, stop_offset=100)
+ with self.assertRaisesRegexp(ValueError, start_larger_than_stop_error):
+ SingleFileSource(
+ fbs, file_name='dummy_file', start_offset=100, stop_offset=99)
+ with self.assertRaisesRegexp(ValueError, start_larger_than_stop_error):
+ SingleFileSource(
+ fbs, file_name='dummy_file', start_offset=100, stop_offset=100)
+
+ def test_estimates_size(self):
+ fbs = LineSource('dummy_pattern')
+
+ # Should simply return stop_offset - start_offset
+ source = SingleFileSource(
+ fbs, file_name='dummy_file', start_offset=0, stop_offset=100)
+ self.assertEquals(100, source.estimate_size())
+
+ source = SingleFileSource(fbs, file_name='dummy_file', start_offset=10,
+ stop_offset=100)
+ self.assertEquals(90, source.estimate_size())
+
+ def test_read_range_at_beginning(self):
+ fbs = LineSource('dummy_pattern')
+
+ file_name, expected_data = _write_data(10)
+ assert len(expected_data) == 10
+
+ source = SingleFileSource(fbs, file_name, 0, 10 * 6)
+ range_tracker = source.get_range_tracker(0, 20)
+ read_data = [value for value in source.read(range_tracker)]
+ self.assertItemsEqual(expected_data[:4], read_data)
+
+ def test_read_range_at_end(self):
+ fbs = LineSource('dummy_pattern')
+
+ file_name, expected_data = _write_data(10)
+ assert len(expected_data) == 10
+
+ source = SingleFileSource(fbs, file_name, 0, 10 * 6)
+ range_tracker = source.get_range_tracker(40, 60)
+ read_data = [value for value in source.read(range_tracker)]
+ self.assertItemsEqual(expected_data[-3:], read_data)
+
+ def test_read_range_at_middle(self):
+ fbs = LineSource('dummy_pattern')
+
+ file_name, expected_data = _write_data(10)
+ assert len(expected_data) == 10
+
+ source = SingleFileSource(fbs, file_name, 0, 10 * 6)
+ range_tracker = source.get_range_tracker(20, 40)
+ read_data = [value for value in source.read(range_tracker)]
+ self.assertItemsEqual(expected_data[4:7], read_data)
+
+ def test_produces_splits_desiredsize_large_than_size(self):
+ fbs = LineSource('dummy_pattern')
+
+ file_name, expected_data = _write_data(10)
+ assert len(expected_data) == 10
+ source = SingleFileSource(fbs, file_name, 0, 10 * 6)
+ splits = [split for split in source.split(desired_bundle_size=100)]
+ self.assertEquals(1, len(splits))
+ self.assertEquals(60, splits[0].weight)
+ self.assertEquals(0, splits[0].start_position)
+ self.assertEquals(60, splits[0].stop_position)
+
+ range_tracker = splits[0].source.get_range_tracker(None, None)
+ read_data = [value for value in splits[0].source.read(range_tracker)]
+ self.assertItemsEqual(expected_data, read_data)
+
+ def test_produces_splits_desiredsize_smaller_than_size(self):
+ fbs = LineSource('dummy_pattern')
+
+ file_name, expected_data = _write_data(10)
+ assert len(expected_data) == 10
+ source = SingleFileSource(fbs, file_name, 0, 10 * 6)
+ splits = [split for split in source.split(desired_bundle_size=25)]
+ self.assertEquals(3, len(splits))
+
+ read_data = []
+ for split in splits:
+ source = split.source
+ range_tracker = source.get_range_tracker(split.start_position,
+ split.stop_position)
+ data_from_split = [data for data in source.read(range_tracker)]
+ read_data.extend(data_from_split)
+ self.assertItemsEqual(expected_data, read_data)
+
+ def test_produce_split_with_start_and_end_positions(self):
+ fbs = LineSource('dummy_pattern')
+
+ file_name, expected_data = _write_data(10)
+ assert len(expected_data) == 10
+ source = SingleFileSource(fbs, file_name, 0, 10 * 6)
+ splits = [split for split in
+ source.split(desired_bundle_size=15, start_offset=10,
+ stop_offset=50)]
+ self.assertEquals(3, len(splits))
+
+ read_data = []
+ for split in splits:
+ source = split.source
+ range_tracker = source.get_range_tracker(split.start_position,
+ split.stop_position)
+ data_from_split = [data for data in source.read(range_tracker)]
+ read_data.extend(data_from_split)
+ self.assertItemsEqual(expected_data[2:9], read_data)
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
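A note on the arithmetic used in the tests above: _write_data(n) writes lines 'line0', 'line1', ..., and for n <= 10 each line is exactly 6 bytes including the newline, which is where sizes such as 10 * 6 and 38 * 6 come from:

    lines = ['line' + str(i) + '\n' for i in range(10)]
    assert all(len(line) == 6 for line in lines)  # 'line0\n' through 'line9\n'
    assert sum(len(line) for line in lines) == 10 * 6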
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2ebd137b/sdks/python/apache_beam/io/iobase.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index da7e007..71ef46b 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -324,7 +324,7 @@ SourceBundle = namedtuple(
class BoundedSource(object):
- """A Dataflow source that reads a finite amount of input records.
+ """A source that reads a finite amount of input records.
This class defines following operations which can be used to read the source
efficiently.
@@ -334,12 +334,21 @@ class BoundedSource(object):
* Splitting into bundles of a given size - method ``split()`` can be used to
split the source into a set of sub-sources (bundles) based on a desired
bundle size.
- * Getting a RangeTracker - method ``get_range_tracker() should return a
+ * Getting a RangeTracker - method ``get_range_tracker()`` should return a
``RangeTracker`` object for a given position range for the position type
of the records returned by the source.
* Reading the data - method ``read()`` can be used to read data from the
source while respecting the boundaries defined by a given
``RangeTracker``.
+
+ A runner will read a source in two steps.
+ (1) Method ``get_range_tracker()`` will be invoked with start and end
+ positions to obtain a ``RangeTracker`` for the range of positions the
+ runner intends to read. The source must define a default initial start
+ and end position range; these positions must be used if the start and/or
+ end positions passed to ``get_range_tracker()`` are ``None``.
+ (2) Method ``read()`` will be invoked with the ``RangeTracker`` obtained
+ in the previous step.
"""
def estimate_size(self):
@@ -378,8 +387,10 @@ class BoundedSource(object):
Framework may invoke ``read()`` method with the RangeTracker object returned
here to read data from the source.
Args:
- start_position: starting position of the range.
- stop_position: ending position of the range.
+ start_position: starting position of the range. If ``None``, the
+ default start position of the source must be used.
+ stop_position: ending position of the range. If ``None``, the
+ default stop position of the source must be used.
Returns:
a ``RangeTracker`` for the given position range.
"""
@@ -406,8 +417,9 @@ class BoundedSource(object):
Args:
range_tracker: a ``RangeTracker`` whose boundaries must be respected
- when reading data from the source. If 'None' all records
- represented by the current source should be read.
+ when reading data from the source. A runner that reads this
+ source must pass a ``RangeTracker`` object that is not
+ ``None``.
Returns:
an iterator of data read by the source.
"""
@@ -578,7 +590,7 @@ class RangeTracker(object):
"""Returns the position at the given fraction.
Given a fraction within the range [0.0, 1.0) this method will return the
- position at the given fraction compared the the position range
+ position at the given fraction compared to the position range
[self.start_position, self.stop_position).
** Thread safety **
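The two-step protocol documented above amounts to the following sketch, for any iobase.BoundedSource (it mirrors the DirectPipelineRunner change further down in this diff):

    def read_full_source(source):
      # Step 1: passing None for both positions asks the source for a
      # RangeTracker covering its default (full) range.
      range_tracker = source.get_range_tracker(None, None)
      # Step 2: read() must respect the boundaries the tracker defines.
      return list(source.read(range_tracker))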
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2ebd137b/sdks/python/apache_beam/io/range_trackers.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/range_trackers.py b/sdks/python/apache_beam/io/range_trackers.py
index 36c5dde..c3481de 100644
--- a/sdks/python/apache_beam/io/range_trackers.py
+++ b/sdks/python/apache_beam/io/range_trackers.py
@@ -37,8 +37,20 @@ class OffsetRangeTracker(iobase.RangeTracker):
def __init__(self, start, end):
super(OffsetRangeTracker, self).__init__()
+
+ if start is None:
+ raise ValueError('Start offset must not be \'None\'')
+ if end is None:
+ raise ValueError('End offset must not be \'None\'')
+ assert isinstance(start, int) or isinstance(start, long)
+ if end != self.OFFSET_INFINITY:
+ assert isinstance(end, int) or isinstance(end, long)
+
+ assert start <= end
+
self._start_offset = start
self._stop_offset = end
+
self._last_record_start = -1
self._offset_of_last_split_point = -1
self._lock = threading.Lock()
@@ -270,4 +282,4 @@ class GroupedShuffleRangeTracker(iobase.RangeTracker):
# service will estimate progress from positions for us.
raise RuntimeError('GroupedShuffleRangeTracker does not measure fraction'
' consumed due to positions being opaque strings'
- ' that are interpretted by the service')
+ ' that are interpreted by the service')
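With the added validation, OffsetRangeTracker rejects unusable ranges at construction time. A small sketch of the intended behavior (the claimed offsets are illustrative):

    from apache_beam.io import range_trackers

    tracker = range_trackers.OffsetRangeTracker(0, 100)
    assert tracker.try_claim(0)        # a record starting inside [0, 100)
    assert tracker.try_claim(60)       # claims must be non-decreasing
    assert not tracker.try_claim(100)  # at or beyond the stop offset

    try:
      range_trackers.OffsetRangeTracker(None, 100)
    except ValueError:
      pass  # 'Start offset must not be None', per the new checks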
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2ebd137b/sdks/python/apache_beam/io/sources_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/sources_test.py b/sdks/python/apache_beam/io/sources_test.py
index 21e9797..45c5ab8 100644
--- a/sdks/python/apache_beam/io/sources_test.py
+++ b/sdks/python/apache_beam/io/sources_test.py
@@ -18,12 +18,14 @@
"""Unit tests for the sources framework."""
import logging
+import os
import tempfile
import unittest
import apache_beam as beam
from apache_beam.io import iobase
+from apache_beam.io import range_trackers
from apache_beam.transforms.util import assert_that
from apache_beam.transforms.util import equal_to
@@ -31,13 +33,50 @@ from apache_beam.transforms.util import equal_to
class LineSource(iobase.BoundedSource):
"""A simple source that reads lines from a given file."""
+ TEST_BUNDLE_SIZE = 10
+
def __init__(self, file_name):
self._file_name = file_name
- def read(self, _):
- with open(self._file_name) as f:
+ def read(self, range_tracker):
+ with open(self._file_name, 'rb') as f:
+ start = range_tracker.start_position()
+ f.seek(start)
+ if start > 0:
+ f.seek(-1, os.SEEK_CUR)
+ start -= 1
+ start += len(f.readline())
+ current = start
for line in f:
+ if not range_tracker.try_claim(current):
+ return
yield line.rstrip('\n')
+ current += len(line)
+
+ def split(self, desired_bundle_size, start_position=None, stop_position=None):
+ assert start_position is None
+ assert stop_position is None
+ with open(self._file_name, 'rb') as f:
+ f.seek(0, os.SEEK_END)
+ size = f.tell()
+
+ bundle_start = 0
+ while bundle_start < size:
+ bundle_stop = min(bundle_start + LineSource.TEST_BUNDLE_SIZE, size)
+ yield iobase.SourceBundle(1, self, bundle_start, bundle_stop)
+ bundle_start = bundle_stop
+
+ def get_range_tracker(self, start_position, stop_position):
+ if start_position is None:
+ start_position = 0
+ if stop_position is None:
+ with open(self._file_name, 'rb') as f:
+ f.seek(0, os.SEEK_END)
+ stop_position = f.tell()
+ return range_trackers.OffsetRangeTracker(start_position, stop_position)
+
+ def default_output_coder(self):
+ return beam.coders.ToStringCoder()
class SourcesTest(unittest.TestCase):
@@ -51,7 +90,8 @@ class SourcesTest(unittest.TestCase):
file_name = self._create_temp_file('aaaa\nbbbb\ncccc\ndddd')
source = LineSource(file_name)
- result = [line for line in source.read(None)]
+ range_tracker = source.get_range_tracker(None, None)
+ result = [line for line in source.read(range_tracker)]
self.assertItemsEqual(['aaaa', 'bbbb', 'cccc', 'dddd'], result)
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2ebd137b/sdks/python/apache_beam/runners/direct_runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/direct_runner.py b/sdks/python/apache_beam/runners/direct_runner.py
index 8478822..2c73394 100644
--- a/sdks/python/apache_beam/runners/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct_runner.py
@@ -248,7 +248,10 @@ class DirectPipelineRunner(PipelineRunner):
self._cache.cache_output(transform_node, read_result)
if isinstance(source, iobase.BoundedSource):
- reader = source.read(None)
+ # Getting a RangeTracker for the default range of the source and reading
+ # the full source using that.
+ range_tracker = source.get_range_tracker(None, None)
+ reader = source.read(range_tracker)
read_values(reader)
else:
with source.reader() as reader:
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2ebd137b/sdks/python/setup.py
----------------------------------------------------------------------
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index a87c4f0..029226e 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -44,6 +44,7 @@ def get_version():
# Configure the required packages and scripts to install.
REQUIRED_PACKAGES = [
+ 'avro>=1.7.7',
'dill>=0.2.5',
'google-apitools>=0.5.2',
# TODO(silviuc): Reenable api client package dependencies when we can