Posted to commits@beam.apache.org by pa...@apache.org on 2020/08/29 00:47:37 UTC
[beam] branch master updated: Remove unnecessary limiters and add a method to get the size of a PCollection on disk.
This is an automated email from the ASF dual-hosted git repository.
pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new fe474b5 Remove unnecessary limiters and add a method to get the size of a PCollection on disk.
new 1d25e2e Merge pull request #12700 from [BEAM-10603] Remove unnecessary limiters and add a method to get the size of a PCollection on disk.
fe474b5 is described below
commit fe474b506277341890aad1aadf1572797a685050
Author: Sam Rohde <sr...@google.com>
AuthorDate: Thu Aug 27 11:02:38 2020 -0700
Remove unnecessary limiters and add a method to get the size of a PCollection on disk.
Change-Id: I3b9ff420817c78988a283bc7b046d85cc732a09b
---
.../runners/interactive/cache_manager.py | 38 +++----
.../runners/interactive/cache_manager_test.py | 42 ++++---
.../runners/interactive/caching/streaming_cache.py | 56 +++++-----
.../interactive/caching/streaming_cache_test.py | 122 ++++-----------------
.../runners/interactive/recording_manager.py | 10 +-
.../interactive/testing/test_cache_manager.py | 20 ++--
6 files changed, 100 insertions(+), 188 deletions(-)
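The interesting API change is the new CacheManager.size(*labels) method: it returns the number of bytes a cached PCollection occupies on disk, or 0 when nothing has been cached under those labels. A minimal usage sketch (the labels here are hypothetical and this snippet is not part of the commit; it assumes the FileBasedCacheManager constructor accepts cache_format, as the diff below shows):

    from apache_beam.runners.interactive import cache_manager as cache

    cm = cache.FileBasedCacheManager(cache_format='text')
    # Nothing has been written yet, so the reported size is 0.
    assert cm.size('full', 'some-cache-label') == 0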
diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager.py b/sdks/python/apache_beam/runners/interactive/cache_manager.py
index 48f1fc5..d168a76 100644
--- a/sdks/python/apache_beam/runners/interactive/cache_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/cache_manager.py
@@ -74,10 +74,9 @@ class CacheManager(object):
Args:
*labels: List of labels for PCollection instance.
- **args: Dict of additional arguments. Currently only supports 'limiters'
- as a list of ElementLimiters, and 'tail' as a boolean. Limiters limits
- the amount of elements read and duration with respect to processing
- time.
+ **args: Dict of additional arguments. Currently only supports 'tail' as a
+ boolean. When tail is True, the reader waits for and reads new elements
+ until the cache is complete.
Returns:
A tuple containing an iterator for the items in the PCollection and the
@@ -157,6 +156,12 @@ class CacheManager(object):
"""Cleans up all the PCollection caches."""
raise NotImplementedError
+ def size(self, *labels):
+ # type: (*str) -> int
+
+ """Returns the size of the PCollection on disk in bytes."""
+ raise NotImplementedError
+
class FileBasedCacheManager(CacheManager):
"""Maps PCollections to local temp files for materialization."""
@@ -173,6 +178,7 @@ class FileBasedCacheManager(CacheManager):
self._cache_dir = tempfile.mkdtemp(
prefix='it-', dir=os.environ.get('TEST_TMPDIR', None))
self._versions = collections.defaultdict(lambda: self._CacheVersion())
+ self.cache_format = cache_format
if cache_format not in self._available_formats:
raise ValueError("Unsupported cache format: '%s'." % cache_format)
@@ -193,6 +199,11 @@ class FileBasedCacheManager(CacheManager):
# and its PCoder type.
self._saved_pcoders = {}
+ def size(self, *labels):
+ if self.exists(*labels):
+ return sum(os.path.getsize(path) for path in self._match(*labels))
+ return 0
+
def exists(self, *labels):
return bool(self._match(*labels))
@@ -216,30 +227,13 @@ class FileBasedCacheManager(CacheManager):
if not self.exists(*labels):
return iter([]), -1
- limiters = args.pop('limiters', [])
-
# Otherwise, return a generator to the cached PCollection.
source = self.source(*labels)._source
range_tracker = source.get_range_tracker(None, None)
reader = source.read(range_tracker)
version = self._latest_version(*labels)
- # The return type is a generator, so in order to implement the limiter for
- # the FileBasedCacheManager we wrap the original generator with the logic
- # to limit yielded elements.
- def limit_reader(r):
- for e in r:
- # Update the limiters and break early out of reading from cache if any
- # are triggered.
- for l in limiters:
- l.update(e)
-
- if any(l.is_triggered() for l in limiters):
- break
-
- yield e
-
- return limit_reader(reader), version
+ return reader, version
def write(self, values, *labels):
sink = self.sink(labels)._sink
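The FileBasedCacheManager.size implementation above just sums os.path.getsize over every file the cache's glob matches. A self-contained sketch of the same idea, with a hypothetical glob pattern standing in for what _match(*labels) resolves to:

    import glob
    import os

    def total_cache_size(pattern):
        # Sum the on-disk size of every shard matching the cache glob,
        # e.g. '/tmp/it-abc123/full/some-cache-label-*-of-*'.
        return sum(os.path.getsize(path) for path in glob.glob(pattern))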
diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
index e7dc936..12c644c 100644
--- a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
+++ b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
@@ -30,7 +30,6 @@ import unittest
from apache_beam import coders
from apache_beam.io import filesystems
from apache_beam.runners.interactive import cache_manager as cache
-from apache_beam.runners.interactive.options.capture_limiters import CountLimiter
class FileBasedCacheManagerTest(object):
@@ -92,6 +91,32 @@ class FileBasedCacheManagerTest(object):
self.mock_write_cache(cache_version_one, prefix, cache_label)
self.assertTrue(self.cache_manager.exists(prefix, cache_label))
+ def test_size(self):
+ """Test getting the size of some cache label."""
+
+ # The Beam API for writing doesn't return the number of bytes that were
+ # written to disk. So this test is only possible when the coder encodes the
+ # bytes that will be written directly to disk, which only the WriteToText
+ # transform does (unlike the WriteToTFRecord transform).
+ if self.cache_manager.cache_format != 'text':
+ return
+
+ prefix = 'full'
+ cache_label = 'some-cache-label'
+
+ # Test that if nothing is written the size is 0.
+ self.assertEqual(self.cache_manager.size(prefix, cache_label), 0)
+
+ value = 'a'
+ self.mock_write_cache([value], prefix, cache_label)
+ coder = self.cache_manager.load_pcoder(prefix, cache_label)
+ encoded = coder.encode(value)
+
+ # Add one to the size on disk because of the extra new-line character when
+ # writing to file.
+ self.assertEqual(
+ self.cache_manager.size(prefix, cache_label), len(encoded) + 1)
+
def test_clear(self):
"""Test that CacheManager can correctly tell if the cache exists or not."""
prefix = 'full'
@@ -193,21 +218,6 @@ class FileBasedCacheManagerTest(object):
self.assertTrue(
self.cache_manager.is_latest_version(version, prefix, cache_label))
- def test_read_with_count_limiter(self):
- """Test the condition where the cache is read once after written once."""
- prefix = 'full'
- cache_label = 'some-cache-label'
- cache_version_one = ['cache', 'version', 'one']
-
- self.mock_write_cache(cache_version_one, prefix, cache_label)
- reader, version = self.cache_manager.read(
- prefix, cache_label, limiters=[CountLimiter(2)])
- pcoll_list = list(reader)
- self.assertListEqual(pcoll_list, ['cache', 'version'])
- self.assertEqual(version, 0)
- self.assertTrue(
- self.cache_manager.is_latest_version(version, prefix, cache_label))
-
class TextFileBasedCacheManagerTest(
FileBasedCacheManagerTest,
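The new test_size test leans on a detail of the text cache format: each element is written as its encoded bytes plus a trailing newline, so the expected on-disk size is len(coder.encode(value)) + 1 per element. A standalone sketch of that accounting (FastPrimitivesCoder is used here only for illustration; load_pcoder may return a different coder):

    from apache_beam import coders

    coder = coders.FastPrimitivesCoder()
    values = ['cache', 'version', 'one']
    # One extra byte per element for the newline the text sink appends.
    expected_size = sum(len(coder.encode(v)) + 1 for v in values)
    print(expected_size)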
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
index 77f976d..52e8388 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
@@ -148,28 +148,18 @@ class StreamingCacheSource:
This class is used to read from a file and send it to the TestStream via the
StreamingCacheManager.Reader.
"""
- def __init__(
- self,
- cache_dir,
- labels,
- is_cache_complete=None,
- coder=None,
- limiters=None):
+ def __init__(self, cache_dir, labels, is_cache_complete=None, coder=None):
if not coder:
coder = SafeFastPrimitivesCoder()
if not is_cache_complete:
is_cache_complete = lambda _: True
- if not limiters:
- limiters = []
-
self._cache_dir = cache_dir
self._coder = coder
self._labels = labels
self._path = os.path.join(self._cache_dir, *self._labels)
self._is_cache_complete = is_cache_complete
- self._limiters = limiters
from apache_beam.runners.interactive.pipeline_instrument import CacheKey
self._pipeline_id = CacheKey.from_str(labels[-1]).pipeline_id
@@ -221,16 +211,10 @@ class StreamingCacheSource:
proto_cls = TestStreamFileHeader if pos == 0 else TestStreamFileRecord
msg = self._try_parse_as(proto_cls, to_decode)
if msg:
- for l in self._limiters:
- l.update(msg)
-
- if any(l.is_triggered() for l in self._limiters):
- break
+ yield msg
else:
break
- yield msg
-
def _try_parse_as(self, proto_cls, to_decode):
try:
msg = proto_cls()
@@ -288,6 +272,12 @@ class StreamingCache(CacheManager):
# The sinks to capture data from capturable sources.
# Dict([str, StreamingCacheSink])
self._capture_sinks = {}
+ self._capture_keys = set()
+
+ def size(self, *labels):
+ if self.exists(*labels):
+ return os.path.getsize(os.path.join(self._cache_dir, *labels))
+ return 0
@property
def capture_size(self):
@@ -297,25 +287,26 @@ class StreamingCache(CacheManager):
def capture_paths(self):
return list(self._capture_sinks.keys())
+ @property
+ def capture_keys(self):
+ return self._capture_keys
+
def exists(self, *labels):
path = os.path.join(self._cache_dir, *labels)
return os.path.exists(path)
# TODO(srohde): Modify this to return the correct version.
def read(self, *labels, **args):
- """Returns a generator to read all records from file.
+ """Returns a generator to read all records from file."""
+ tail = args.pop('tail', False)
- Does not tail.
- """
- if not self.exists(*labels):
+ # Only return immediately when the file doesn't exist and the user wants a
+ # snapshot of the cache (i.e. when tail is False).
+ if not self.exists(*labels) and not tail:
return iter([]), -1
- limiters = args.pop('limiters', [])
- tail = args.pop('tail', False)
-
reader = StreamingCacheSource(
- self._cache_dir, labels, self._is_cache_complete,
- limiters=limiters).read(tail=tail)
+ self._cache_dir, labels, self._is_cache_complete).read(tail=tail)
# Return an empty iterator if there is nothing in the file yet. This can
# only happen when tail is False.
@@ -325,7 +316,7 @@ class StreamingCache(CacheManager):
return iter([]), -1
return StreamingCache.Reader([header], [reader]).read(), 1
- def read_multiple(self, labels, limiters=None, tail=True):
+ def read_multiple(self, labels, tail=True):
"""Returns a generator to read all records from file.
Does tail until the cache is complete. This is because it is used in the
@@ -333,9 +324,9 @@ class StreamingCache(CacheManager):
pipeline runtime which needs to block.
"""
readers = [
- StreamingCacheSource(
- self._cache_dir, l, self._is_cache_complete,
- limiters=limiters).read(tail=tail) for l in labels
+ StreamingCacheSource(self._cache_dir, l,
+ self._is_cache_complete).read(tail=tail)
+ for l in labels
]
headers = [next(r) for r in readers]
return StreamingCache.Reader(headers, readers).read()
@@ -358,6 +349,7 @@ class StreamingCache(CacheManager):
def clear(self, *labels):
directory = os.path.join(self._cache_dir, *labels[:-1])
filepath = os.path.join(directory, labels[-1])
+ self._capture_keys.discard(labels[-1])
if os.path.exists(filepath):
os.remove(filepath)
return True
@@ -383,6 +375,7 @@ class StreamingCache(CacheManager):
sink = StreamingCacheSink(cache_dir, filename, self._sample_resolution_sec)
if is_capture:
self._capture_sinks[sink.path] = sink
+ self._capture_keys.add(filename)
return sink
def save_pcoder(self, pcoder, *labels):
@@ -398,6 +391,7 @@ class StreamingCache(CacheManager):
shutil.rmtree(self._cache_dir)
self._saved_pcoders = {}
self._capture_sinks = {}
+ self._capture_keys = set()
class Reader(object):
"""Abstraction that reads from PCollection readers.
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
index 23390cc..2238e0d 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
@@ -28,8 +28,6 @@ from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileR
from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
from apache_beam.runners.interactive.cache_manager import SafeFastPrimitivesCoder
from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
-from apache_beam.runners.interactive.options.capture_limiters import CountLimiter
-from apache_beam.runners.interactive.options.capture_limiters import ProcessingTimeLimiter
from apache_beam.runners.interactive.pipeline_instrument import CacheKey
from apache_beam.runners.interactive.testing.test_cache_manager import FileRecordsBuilder
from apache_beam.testing.test_pipeline import TestPipeline
@@ -66,13 +64,25 @@ class StreamingCacheTest(unittest.TestCase):
# Assert that an empty reader returns an empty list.
self.assertFalse([e for e in reader])
+ def test_size(self):
+ cache = StreamingCache(cache_dir=None)
+ cache.write([TestStreamFileRecord()], 'my_label')
+ coder = cache.load_pcoder('my_label')
+
+ # Add one because of the new-line character that is also written.
+ size = len(coder.encode(TestStreamFileRecord().SerializeToString())) + 1
+ self.assertEqual(cache.size('my_label'), size)
+
def test_clear(self):
cache = StreamingCache(cache_dir=None)
self.assertFalse(cache.exists('my_label'))
+ cache.sink(['my_label'], is_capture=True)
cache.write([TestStreamFileRecord()], 'my_label')
self.assertTrue(cache.exists('my_label'))
+ self.assertEqual(cache.capture_keys, set(['my_label']))
self.assertTrue(cache.clear('my_label'))
self.assertFalse(cache.exists('my_label'))
+ self.assertFalse(cache.capture_keys)
def test_single_reader(self):
"""Tests that we expect to see all the correctly emitted TestStreamPayloads.
@@ -277,16 +287,22 @@ class StreamingCacheTest(unittest.TestCase):
coder = SafeFastPrimitivesCoder()
cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)
+ # Assert that there are no capture keys at first.
+ self.assertEqual(cache.capture_keys, set())
+
options = StandardOptions(streaming=True)
with TestPipeline(options=options) as p:
records = (p | test_stream)[CACHED_RECORDS]
# pylint: disable=expression-not-assigned
- records | cache.sink([CACHED_RECORDS])
+ records | cache.sink([CACHED_RECORDS], is_capture=True)
reader, _ = cache.read(CACHED_RECORDS)
actual_events = list(reader)
+ # Assert that the capture keys are forwarded correctly.
+ self.assertEqual(cache.capture_keys, set([CACHED_RECORDS]))
+
# Units here are in microseconds.
expected_events = [
TestStreamPayload.Event(
@@ -413,106 +429,6 @@ class StreamingCacheTest(unittest.TestCase):
self.assertListEqual(actual_events, expected_events)
- def test_single_reader_with_count_limiter(self):
- """Tests that we expect to see all the correctly emitted TestStreamPayloads.
- """
- CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))
-
- values = (FileRecordsBuilder(tag=CACHED_PCOLLECTION_KEY)
- .add_element(element=0, event_time_secs=0)
- .advance_processing_time(1)
- .add_element(element=1, event_time_secs=1)
- .advance_processing_time(1)
- .add_element(element=2, event_time_secs=2)
- .build()) # yapf: disable
-
- cache = StreamingCache(cache_dir=None)
- cache.write(values, CACHED_PCOLLECTION_KEY)
-
- reader, _ = cache.read(CACHED_PCOLLECTION_KEY, limiters=[CountLimiter(2)])
- coder = coders.FastPrimitivesCoder()
- events = list(reader)
-
- # Units here are in microseconds.
- # These are a slice of the original values such that we only get two
- # elements.
- expected = [
- TestStreamPayload.Event(
- element_event=TestStreamPayload.Event.AddElements(
- elements=[
- TestStreamPayload.TimestampedElement(
- encoded_element=coder.encode(0), timestamp=0)
- ],
- tag=CACHED_PCOLLECTION_KEY)),
- TestStreamPayload.Event(
- processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
- advance_duration=1 * 10**6)),
- TestStreamPayload.Event(
- element_event=TestStreamPayload.Event.AddElements(
- elements=[
- TestStreamPayload.TimestampedElement(
- encoded_element=coder.encode(1), timestamp=1 * 10**6)
- ],
- tag=CACHED_PCOLLECTION_KEY)),
- TestStreamPayload.Event(
- processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
- advance_duration=1 * 10**6)),
- ]
- self.assertSequenceEqual(events, expected)
-
- def test_single_reader_with_processing_time_limiter(self):
- """Tests that we expect to see all the correctly emitted TestStreamPayloads.
- """
- CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))
-
- values = (FileRecordsBuilder(tag=CACHED_PCOLLECTION_KEY)
- .advance_processing_time(1e-6)
- .add_element(element=0, event_time_secs=0)
- .advance_processing_time(1)
- .add_element(element=1, event_time_secs=1)
- .advance_processing_time(1)
- .add_element(element=2, event_time_secs=2)
- .advance_processing_time(1)
- .add_element(element=3, event_time_secs=2)
- .advance_processing_time(1)
- .add_element(element=4, event_time_secs=2)
- .build()) # yapf: disable
-
- cache = StreamingCache(cache_dir=None)
- cache.write(values, CACHED_PCOLLECTION_KEY)
-
- reader, _ = cache.read(
- CACHED_PCOLLECTION_KEY, limiters=[ProcessingTimeLimiter(2)])
- coder = coders.FastPrimitivesCoder()
- events = list(reader)
-
- # Units here are in microseconds.
- # Expects that the elements are a slice of the original values where all
- # processing time is less than the duration.
- expected = [
- TestStreamPayload.Event(
- processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
- advance_duration=1)),
- TestStreamPayload.Event(
- element_event=TestStreamPayload.Event.AddElements(
- elements=[
- TestStreamPayload.TimestampedElement(
- encoded_element=coder.encode(0), timestamp=0)
- ],
- tag=CACHED_PCOLLECTION_KEY)),
- TestStreamPayload.Event(
- processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
- advance_duration=1 * 10**6)),
- TestStreamPayload.Event(
- element_event=TestStreamPayload.Event.AddElements(
- elements=[
- TestStreamPayload.TimestampedElement(
- encoded_element=coder.encode(1), timestamp=1 * 10**6)
- ],
- tag=CACHED_PCOLLECTION_KEY)),
- ]
- self.assertSequenceEqual(events, expected)
-
if __name__ == '__main__':
unittest.main()
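Condensed from the updated test_clear above, the capture-key lifecycle is: sink(..., is_capture=True) registers the key, and clear() discards it again. A sketch using the same API the tests exercise:

    from apache_beam.portability.api.beam_interactive_api_pb2 import (
        TestStreamFileRecord)
    from apache_beam.runners.interactive.caching.streaming_cache import (
        StreamingCache)

    cache = StreamingCache(cache_dir=None)
    cache.sink(['my_label'], is_capture=True)  # registers the capture key
    cache.write([TestStreamFileRecord()], 'my_label')
    assert cache.capture_keys == set(['my_label'])
    cache.clear('my_label')                    # also discards the key
    assert not cache.capture_keys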
diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager.py b/sdks/python/apache_beam/runners/interactive/recording_manager.py
index 6810de1..82610e2 100644
--- a/sdks/python/apache_beam/runners/interactive/recording_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/recording_manager.py
@@ -97,9 +97,7 @@ class ElementStream:
limiters = [
CountLimiter(self._n), ProcessingTimeLimiter(self._duration_secs)
]
- reader, _ = cache_manager.read('full', self._cache_key,
- limiters=limiters,
- tail=tail)
+ reader, _ = cache_manager.read('full', self._cache_key, tail=tail)
# Because a single TestStreamFileRecord can yield multiple elements, we
# limit the count again here in the to_element_list call.
@@ -112,8 +110,14 @@ class ElementStream:
coder,
include_window_info=True,
n=self._n):
+ for l in limiters:
+ l.update(e)
+
yield e
+ if any(l.is_triggered() for l in limiters):
+ break
+
# A limiter being triggered means that we have fulfilled the user's request.
# This implies that reading from the cache again won't yield any new
# elements. The same applies when the user pipeline has been terminated.
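The net effect of this hunk is that limiting moves out of the cache readers and into this single consumer loop. Abstracted into a reusable generator, the pattern looks like this (limiters are any objects with update(e) and is_triggered(), as in capture_limiters):

    def limit(elements, limiters):
        # Yield elements until any limiter triggers. Note the order: the
        # triggering element is still yielded before iteration stops,
        # matching the loop in ElementStream.read above.
        for e in elements:
            for l in limiters:
                l.update(e)
            yield e
            if any(l.is_triggered() for l in limiters):
                return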
diff --git a/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py b/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
index 098f249..2078818 100644
--- a/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
import collections
import itertools
+import sys
import apache_beam as beam
from apache_beam import coders
@@ -49,19 +50,7 @@ class InMemoryCache(CacheManager):
if not self.exists(*labels):
return itertools.chain([]), -1
- limiters = args.pop('limiters', [])
-
- def limit_reader(r):
- for e in r:
- for l in limiters:
- l.update(e)
-
- if any(l.is_triggered() for l in limiters):
- break
-
- yield e
-
- return limit_reader(itertools.chain(self._cached[self._key(*labels)])), None
+ return itertools.chain(self._cached[self._key(*labels)]), None
def write(self, value, *labels):
if not self.exists(*labels):
@@ -85,6 +74,11 @@ class InMemoryCache(CacheManager):
def sink(self, labels, is_capture=False):
return beam.Map(lambda _: _)
+ def size(self, *labels):
+ if self.exists(*labels):
+ return sys.getsizeof(self._cached[self._key(*labels)])
+ return 0
+
def _key(self, *labels):
return '/'.join([l for l in labels])
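One caveat worth noting on InMemoryCache.size above: sys.getsizeof reports the shallow size of the cached list object, not the encoded byte count that the file-based managers report, so the two are not directly comparable; that is fine for a test-only cache. For example:

    import sys

    cached = ['cache', 'version', 'one']
    # Shallow size of the list object itself; the strings it references
    # are not included.
    print(sys.getsizeof(cached))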