Posted to commits@beam.apache.org by pa...@apache.org on 2020/08/29 00:47:37 UTC

[beam] branch master updated: Remove unnecessary limiters and add a method to get the size of a PCollection on disk.

This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new fe474b5  Remove unnecessary limiters and add a method to get the size of a PCollection on disk.
     new 1d25e2e  Merge pull request #12700 from [BEAM-10603] Remove unnecessary limiters and add a method to get the size of a PCollection on disk.
fe474b5 is described below

commit fe474b506277341890aad1aadf1572797a685050
Author: Sam Rohde <sr...@google.com>
AuthorDate: Thu Aug 27 11:02:38 2020 -0700

    Remove unnecessary limiters and add a method to get the size of a PCollection on disk.
    
    Change-Id: I3b9ff420817c78988a283bc7b046d85cc732a09b
---
 .../runners/interactive/cache_manager.py           |  38 +++----
 .../runners/interactive/cache_manager_test.py      |  42 ++++---
 .../runners/interactive/caching/streaming_cache.py |  56 +++++-----
 .../interactive/caching/streaming_cache_test.py    | 122 ++++-----------------
 .../runners/interactive/recording_manager.py       |  10 +-
 .../interactive/testing/test_cache_manager.py      |  20 ++--
 6 files changed, 100 insertions(+), 188 deletions(-)
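
For orientation, the headline addition is a size(*labels) method on the cache managers that reports how many bytes a cached PCollection occupies on disk. A minimal usage sketch under the default text cache format (the labels below are illustrative, not taken from the commit):

    from apache_beam.runners.interactive import cache_manager as cache

    cm = cache.FileBasedCacheManager(cache_format='text')

    # Nothing has been written for these labels yet, so the reported size is 0.
    assert cm.size('full', 'some-cache-label') == 0

    # After a write, size() sums os.path.getsize over every file that matches
    # the given labels (see the FileBasedCacheManager hunk below).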

diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager.py b/sdks/python/apache_beam/runners/interactive/cache_manager.py
index 48f1fc5..d168a76 100644
--- a/sdks/python/apache_beam/runners/interactive/cache_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/cache_manager.py
@@ -74,10 +74,9 @@ class CacheManager(object):
 
     Args:
       *labels: List of labels for PCollection instance.
-      **args: Dict of additional arguments. Currently only supports 'limiters'
-        as a list of ElementLimiters, and 'tail' as a boolean. Limiters limits
-        the amount of elements read and duration with respect to processing
-        time.
+      **args: Dict of additional arguments. Currently only supports 'tail' as
+        a boolean. When tail is True, the read will wait for and read new
+        elements until the cache is complete.
 
     Returns:
       A tuple containing an iterator for the items in the PCollection and the
@@ -157,6 +156,12 @@ class CacheManager(object):
     """Cleans up all the PCollection caches."""
     raise NotImplementedError
 
+  def size(self, *labels):
+    # type: (*str) -> int
+
+    """Returns the size of the PCollection on disk in bytes."""
+    raise NotImplementedError
+
 
 class FileBasedCacheManager(CacheManager):
   """Maps PCollections to local temp files for materialization."""
@@ -173,6 +178,7 @@ class FileBasedCacheManager(CacheManager):
       self._cache_dir = tempfile.mkdtemp(
           prefix='it-', dir=os.environ.get('TEST_TMPDIR', None))
     self._versions = collections.defaultdict(lambda: self._CacheVersion())
+    self.cache_format = cache_format
 
     if cache_format not in self._available_formats:
       raise ValueError("Unsupported cache format: '%s'." % cache_format)
@@ -193,6 +199,11 @@ class FileBasedCacheManager(CacheManager):
     # and its PCoder type.
     self._saved_pcoders = {}
 
+  def size(self, *labels):
+    if self.exists(*labels):
+      return sum(os.path.getsize(path) for path in self._match(*labels))
+    return 0
+
   def exists(self, *labels):
     return bool(self._match(*labels))
 
@@ -216,30 +227,13 @@ class FileBasedCacheManager(CacheManager):
     if not self.exists(*labels):
       return iter([]), -1
 
-    limiters = args.pop('limiters', [])
-
     # Otherwise, return a generator to the cached PCollection.
     source = self.source(*labels)._source
     range_tracker = source.get_range_tracker(None, None)
     reader = source.read(range_tracker)
     version = self._latest_version(*labels)
 
-    # The return type is a generator, so in order to implement the limiter for
-    # the FileBasedCacheManager we wrap the original generator with the logic
-    # to limit yielded elements.
-    def limit_reader(r):
-      for e in r:
-        # Update the limiters and break early out of reading from cache if any
-        # are triggered.
-        for l in limiters:
-          l.update(e)
-
-        if any(l.is_triggered() for l in limiters):
-          break
-
-        yield e
-
-    return limit_reader(reader), version
+    return reader, version
 
   def write(self, values, *labels):
     sink = self.sink(labels)._sink
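
With the limiters removed from FileBasedCacheManager.read(), callers that still need to cap how much they consume apply the limiting themselves around the returned reader; the recording_manager.py hunk further down takes exactly this approach. A rough sketch of that caller-side pattern, using a hypothetical helper name:

    from apache_beam.runners.interactive.options.capture_limiters import CountLimiter

    def read_at_most(cache_manager, cache_key, n):
      # Hypothetical helper: read() now returns the raw reader, so the caller
      # updates a limiter per element and stops once it triggers.
      reader, _version = cache_manager.read('full', cache_key)
      limiter = CountLimiter(n)
      for element in reader:
        limiter.update(element)
        yield element
        if limiter.is_triggered():
          break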
diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
index e7dc936..12c644c 100644
--- a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
+++ b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
@@ -30,7 +30,6 @@ import unittest
 from apache_beam import coders
 from apache_beam.io import filesystems
 from apache_beam.runners.interactive import cache_manager as cache
-from apache_beam.runners.interactive.options.capture_limiters import CountLimiter
 
 
 class FileBasedCacheManagerTest(object):
@@ -92,6 +91,32 @@ class FileBasedCacheManagerTest(object):
     self.mock_write_cache(cache_version_one, prefix, cache_label)
     self.assertTrue(self.cache_manager.exists(prefix, cache_label))
 
+  def test_size(self):
+    """Test getting the size of some cache label."""
+
+    # The Beam API for writing doesn't return the number of bytes that were
+    # written to disk. So this test is only possible when the coder encodes the
+    # exact bytes that will be written to disk, which the WriteToText transform
+    # does but the WriteToTFRecord transform does not.
+    if self.cache_manager.cache_format != 'text':
+      return
+
+    prefix = 'full'
+    cache_label = 'some-cache-label'
+
+    # Test that if nothing is written the size is 0.
+    self.assertEqual(self.cache_manager.size(prefix, cache_label), 0)
+
+    value = 'a'
+    self.mock_write_cache([value], prefix, cache_label)
+    coder = self.cache_manager.load_pcoder(prefix, cache_label)
+    encoded = coder.encode(value)
+
+    # Add one to the size on disk because of the extra new-line character when
+    # writing to file.
+    self.assertEqual(
+        self.cache_manager.size(prefix, cache_label), len(encoded) + 1)
+
   def test_clear(self):
     """Test that CacheManager can correctly tell if the cache exists or not."""
     prefix = 'full'
@@ -193,21 +218,6 @@ class FileBasedCacheManagerTest(object):
     self.assertTrue(
         self.cache_manager.is_latest_version(version, prefix, cache_label))
 
-  def test_read_with_count_limiter(self):
-    """Test the condition where the cache is read once after written once."""
-    prefix = 'full'
-    cache_label = 'some-cache-label'
-    cache_version_one = ['cache', 'version', 'one']
-
-    self.mock_write_cache(cache_version_one, prefix, cache_label)
-    reader, version = self.cache_manager.read(
-        prefix, cache_label, limiters=[CountLimiter(2)])
-    pcoll_list = list(reader)
-    self.assertListEqual(pcoll_list, ['cache', 'version'])
-    self.assertEqual(version, 0)
-    self.assertTrue(
-        self.cache_manager.is_latest_version(version, prefix, cache_label))
-
 
 class TextFileBasedCacheManagerTest(
     FileBasedCacheManagerTest,
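
The test_size test above checks a single element; because the text sink writes each encoded element followed by a newline, the same accounting extends to several elements. A hypothetical follow-up test for FileBasedCacheManagerTest under that assumption (not part of the commit):

    def test_size_multiple_elements(self):
      """Hypothetical: size() over several elements in the text cache format."""
      if self.cache_manager.cache_format != 'text':
        return

      prefix = 'full'
      cache_label = 'some-cache-label'
      values = ['cache', 'version', 'one']

      self.mock_write_cache(values, prefix, cache_label)
      coder = self.cache_manager.load_pcoder(prefix, cache_label)

      # One extra byte per element for the trailing new-line character.
      expected = sum(len(coder.encode(v)) + 1 for v in values)
      self.assertEqual(self.cache_manager.size(prefix, cache_label), expected)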
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
index 77f976d..52e8388 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
@@ -148,28 +148,18 @@ class StreamingCacheSource:
   This class is used to read from file and send it to the TestStream via the
   StreamingCacheManager.Reader.
   """
-  def __init__(
-      self,
-      cache_dir,
-      labels,
-      is_cache_complete=None,
-      coder=None,
-      limiters=None):
+  def __init__(self, cache_dir, labels, is_cache_complete=None, coder=None):
     if not coder:
       coder = SafeFastPrimitivesCoder()
 
     if not is_cache_complete:
       is_cache_complete = lambda _: True
 
-    if not limiters:
-      limiters = []
-
     self._cache_dir = cache_dir
     self._coder = coder
     self._labels = labels
     self._path = os.path.join(self._cache_dir, *self._labels)
     self._is_cache_complete = is_cache_complete
-    self._limiters = limiters
 
     from apache_beam.runners.interactive.pipeline_instrument import CacheKey
     self._pipeline_id = CacheKey.from_str(labels[-1]).pipeline_id
@@ -221,16 +211,10 @@ class StreamingCacheSource:
         proto_cls = TestStreamFileHeader if pos == 0 else TestStreamFileRecord
         msg = self._try_parse_as(proto_cls, to_decode)
         if msg:
-          for l in self._limiters:
-            l.update(msg)
-
-          if any(l.is_triggered() for l in self._limiters):
-            break
+          yield msg
         else:
           break
 
-        yield msg
-
   def _try_parse_as(self, proto_cls, to_decode):
     try:
       msg = proto_cls()
@@ -288,6 +272,12 @@ class StreamingCache(CacheManager):
     # The sinks to capture data from capturable sources.
     # Dict([str, StreamingCacheSink])
     self._capture_sinks = {}
+    self._capture_keys = set()
+
+  def size(self, *labels):
+    if self.exists(*labels):
+      return os.path.getsize(os.path.join(self._cache_dir, *labels))
+    return 0
 
   @property
   def capture_size(self):
@@ -297,25 +287,26 @@ class StreamingCache(CacheManager):
   def capture_paths(self):
     return list(self._capture_sinks.keys())
 
+  @property
+  def capture_keys(self):
+    return self._capture_keys
+
   def exists(self, *labels):
     path = os.path.join(self._cache_dir, *labels)
     return os.path.exists(path)
 
   # TODO(srohde): Modify this to return the correct version.
   def read(self, *labels, **args):
-    """Returns a generator to read all records from file.
+    """Returns a generator to read all records from file."""
+    tail = args.pop('tail', False)
 
-    Does not tail.
-    """
-    if not self.exists(*labels):
+    # Only return immediately when the file doesn't exist and the user wants a
+    # snapshot of the cache (i.e. when tail is False).
+    if not self.exists(*labels) and not tail:
       return iter([]), -1
 
-    limiters = args.pop('limiters', [])
-    tail = args.pop('tail', False)
-
     reader = StreamingCacheSource(
-        self._cache_dir, labels, self._is_cache_complete,
-        limiters=limiters).read(tail=tail)
+        self._cache_dir, labels, self._is_cache_complete).read(tail=tail)
 
     # Return an empty iterator if there is nothing in the file yet. This can
     # only happen when tail is False.
@@ -325,7 +316,7 @@ class StreamingCache(CacheManager):
       return iter([]), -1
     return StreamingCache.Reader([header], [reader]).read(), 1
 
-  def read_multiple(self, labels, limiters=None, tail=True):
+  def read_multiple(self, labels, tail=True):
     """Returns a generator to read all records from file.
 
     Does tail until the cache is complete. This is because it is used in the
@@ -333,9 +324,9 @@ class StreamingCache(CacheManager):
     pipeline runtime which needs to block.
     """
     readers = [
-        StreamingCacheSource(
-            self._cache_dir, l, self._is_cache_complete,
-            limiters=limiters).read(tail=tail) for l in labels
+        StreamingCacheSource(self._cache_dir, l,
+                             self._is_cache_complete).read(tail=tail)
+        for l in labels
     ]
     headers = [next(r) for r in readers]
     return StreamingCache.Reader(headers, readers).read()
@@ -358,6 +349,7 @@ class StreamingCache(CacheManager):
   def clear(self, *labels):
     directory = os.path.join(self._cache_dir, *labels[:-1])
     filepath = os.path.join(directory, labels[-1])
+    self._capture_keys.discard(labels[-1])
     if os.path.exists(filepath):
       os.remove(filepath)
       return True
@@ -383,6 +375,7 @@ class StreamingCache(CacheManager):
     sink = StreamingCacheSink(cache_dir, filename, self._sample_resolution_sec)
     if is_capture:
       self._capture_sinks[sink.path] = sink
+      self._capture_keys.add(filename)
     return sink
 
   def save_pcoder(self, pcoder, *labels):
@@ -398,6 +391,7 @@ class StreamingCache(CacheManager):
       shutil.rmtree(self._cache_dir)
     self._saved_pcoders = {}
     self._capture_sinks = {}
+    self._capture_keys = set()
 
   class Reader(object):
     """Abstraction that reads from PCollection readers.
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
index 23390cc..2238e0d 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
@@ -28,8 +28,6 @@ from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileR
 from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
 from apache_beam.runners.interactive.cache_manager import SafeFastPrimitivesCoder
 from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
-from apache_beam.runners.interactive.options.capture_limiters import CountLimiter
-from apache_beam.runners.interactive.options.capture_limiters import ProcessingTimeLimiter
 from apache_beam.runners.interactive.pipeline_instrument import CacheKey
 from apache_beam.runners.interactive.testing.test_cache_manager import FileRecordsBuilder
 from apache_beam.testing.test_pipeline import TestPipeline
@@ -66,13 +64,25 @@ class StreamingCacheTest(unittest.TestCase):
     # Assert that an empty reader returns an empty list.
     self.assertFalse([e for e in reader])
 
+  def test_size(self):
+    cache = StreamingCache(cache_dir=None)
+    cache.write([TestStreamFileRecord()], 'my_label')
+    coder = cache.load_pcoder('my_label')
+
+    # Add one because of the new-line character that is also written.
+    size = len(coder.encode(TestStreamFileRecord().SerializeToString())) + 1
+    self.assertEqual(cache.size('my_label'), size)
+
   def test_clear(self):
     cache = StreamingCache(cache_dir=None)
     self.assertFalse(cache.exists('my_label'))
+    cache.sink(['my_label'], is_capture=True)
     cache.write([TestStreamFileRecord()], 'my_label')
     self.assertTrue(cache.exists('my_label'))
+    self.assertEqual(cache.capture_keys, set(['my_label']))
     self.assertTrue(cache.clear('my_label'))
     self.assertFalse(cache.exists('my_label'))
+    self.assertFalse(cache.capture_keys)
 
   def test_single_reader(self):
     """Tests that we expect to see all the correctly emitted TestStreamPayloads.
@@ -277,16 +287,22 @@ class StreamingCacheTest(unittest.TestCase):
     coder = SafeFastPrimitivesCoder()
     cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)
 
+    # Assert that there are no capture keys at first.
+    self.assertEqual(cache.capture_keys, set())
+
     options = StandardOptions(streaming=True)
     with TestPipeline(options=options) as p:
       records = (p | test_stream)[CACHED_RECORDS]
 
       # pylint: disable=expression-not-assigned
-      records | cache.sink([CACHED_RECORDS])
+      records | cache.sink([CACHED_RECORDS], is_capture=True)
 
     reader, _ = cache.read(CACHED_RECORDS)
     actual_events = list(reader)
 
+    # Assert that the capture keys are forwarded correctly.
+    self.assertEqual(cache.capture_keys, set([CACHED_RECORDS]))
+
     # Units here are in microseconds.
     expected_events = [
         TestStreamPayload.Event(
@@ -413,106 +429,6 @@ class StreamingCacheTest(unittest.TestCase):
 
     self.assertListEqual(actual_events, expected_events)
 
-  def test_single_reader_with_count_limiter(self):
-    """Tests that we expect to see all the correctly emitted TestStreamPayloads.
-    """
-    CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))
-
-    values = (FileRecordsBuilder(tag=CACHED_PCOLLECTION_KEY)
-              .add_element(element=0, event_time_secs=0)
-              .advance_processing_time(1)
-              .add_element(element=1, event_time_secs=1)
-              .advance_processing_time(1)
-              .add_element(element=2, event_time_secs=2)
-              .build()) # yapf: disable
-
-    cache = StreamingCache(cache_dir=None)
-    cache.write(values, CACHED_PCOLLECTION_KEY)
-
-    reader, _ = cache.read(CACHED_PCOLLECTION_KEY, limiters=[CountLimiter(2)])
-    coder = coders.FastPrimitivesCoder()
-    events = list(reader)
-
-    # Units here are in microseconds.
-    # These are a slice of the original values such that we only get two
-    # elements.
-    expected = [
-        TestStreamPayload.Event(
-            element_event=TestStreamPayload.Event.AddElements(
-                elements=[
-                    TestStreamPayload.TimestampedElement(
-                        encoded_element=coder.encode(0), timestamp=0)
-                ],
-                tag=CACHED_PCOLLECTION_KEY)),
-        TestStreamPayload.Event(
-            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
-                advance_duration=1 * 10**6)),
-        TestStreamPayload.Event(
-            element_event=TestStreamPayload.Event.AddElements(
-                elements=[
-                    TestStreamPayload.TimestampedElement(
-                        encoded_element=coder.encode(1), timestamp=1 * 10**6)
-                ],
-                tag=CACHED_PCOLLECTION_KEY)),
-        TestStreamPayload.Event(
-            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
-                advance_duration=1 * 10**6)),
-    ]
-    self.assertSequenceEqual(events, expected)
-
-  def test_single_reader_with_processing_time_limiter(self):
-    """Tests that we expect to see all the correctly emitted TestStreamPayloads.
-    """
-    CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))
-
-    values = (FileRecordsBuilder(tag=CACHED_PCOLLECTION_KEY)
-              .advance_processing_time(1e-6)
-              .add_element(element=0, event_time_secs=0)
-              .advance_processing_time(1)
-              .add_element(element=1, event_time_secs=1)
-              .advance_processing_time(1)
-              .add_element(element=2, event_time_secs=2)
-              .advance_processing_time(1)
-              .add_element(element=3, event_time_secs=2)
-              .advance_processing_time(1)
-              .add_element(element=4, event_time_secs=2)
-              .build()) # yapf: disable
-
-    cache = StreamingCache(cache_dir=None)
-    cache.write(values, CACHED_PCOLLECTION_KEY)
-
-    reader, _ = cache.read(
-        CACHED_PCOLLECTION_KEY, limiters=[ProcessingTimeLimiter(2)])
-    coder = coders.FastPrimitivesCoder()
-    events = list(reader)
-
-    # Units here are in microseconds.
-    # Expects that the elements are a slice of the original values where all
-    # processing time is less than the duration.
-    expected = [
-        TestStreamPayload.Event(
-            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
-                advance_duration=1)),
-        TestStreamPayload.Event(
-            element_event=TestStreamPayload.Event.AddElements(
-                elements=[
-                    TestStreamPayload.TimestampedElement(
-                        encoded_element=coder.encode(0), timestamp=0)
-                ],
-                tag=CACHED_PCOLLECTION_KEY)),
-        TestStreamPayload.Event(
-            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
-                advance_duration=1 * 10**6)),
-        TestStreamPayload.Event(
-            element_event=TestStreamPayload.Event.AddElements(
-                elements=[
-                    TestStreamPayload.TimestampedElement(
-                        encoded_element=coder.encode(1), timestamp=1 * 10**6)
-                ],
-                tag=CACHED_PCOLLECTION_KEY)),
-    ]
-    self.assertSequenceEqual(events, expected)
-
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager.py b/sdks/python/apache_beam/runners/interactive/recording_manager.py
index 6810de1..82610e2 100644
--- a/sdks/python/apache_beam/runners/interactive/recording_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/recording_manager.py
@@ -97,9 +97,7 @@ class ElementStream:
     limiters = [
         CountLimiter(self._n), ProcessingTimeLimiter(self._duration_secs)
     ]
-    reader, _ = cache_manager.read('full', self._cache_key,
-                                   limiters=limiters,
-                                   tail=tail)
+    reader, _ = cache_manager.read('full', self._cache_key, tail=tail)
 
     # Because a single TestStreamFileRecord can yield multiple elements, we
     # limit the count again here in the to_element_list call.
@@ -112,8 +110,14 @@ class ElementStream:
                                    coder,
                                    include_window_info=True,
                                    n=self._n):
+      for l in limiters:
+        l.update(e)
+
       yield e
 
+      if any(l.is_triggered() for l in limiters):
+        break
+
     # A limiter being triggered means that we have fulfilled the user's request.
     # This implies that reading from the cache again won't yield any new
     # elements. WLOG, this applies to the user pipeline being terminated.
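
The loop above only relies on the update()/is_triggered() contract of the limiters it was handed. A minimal hypothetical limiter satisfying that contract (essentially what CountLimiter already provides) could look like:

    class FirstNLimiter(object):
      """Hypothetical limiter: triggers once n elements have been seen."""
      def __init__(self, n):
        self._n = n
        self._seen = 0

      def update(self, unused_element):
        self._seen += 1

      def is_triggered(self):
        return self._seen >= self._n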
diff --git a/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py b/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
index 098f249..2078818 100644
--- a/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 
 import collections
 import itertools
+import sys
 
 import apache_beam as beam
 from apache_beam import coders
@@ -49,19 +50,7 @@ class InMemoryCache(CacheManager):
     if not self.exists(*labels):
       return itertools.chain([]), -1
 
-    limiters = args.pop('limiters', [])
-
-    def limit_reader(r):
-      for e in r:
-        for l in limiters:
-          l.update(e)
-
-        if any(l.is_triggered() for l in limiters):
-          break
-
-        yield e
-
-    return limit_reader(itertools.chain(self._cached[self._key(*labels)])), None
+    return itertools.chain(self._cached[self._key(*labels)]), None
 
   def write(self, value, *labels):
     if not self.exists(*labels):
@@ -85,6 +74,11 @@ class InMemoryCache(CacheManager):
   def sink(self, labels, is_capture=False):
     return beam.Map(lambda _: _)
 
+  def size(self, *labels):
+    if self.exists(*labels):
+      return sys.getsizeof(self._cached[self._key(*labels)])
+    return 0
+
   def _key(self, *labels):
     return '/'.join([l for l in labels])
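
Finally, note that the in-memory test double approximates size with sys.getsizeof, which measures the shallow size of the backing list rather than serialized bytes on disk; that is enough for tests that only need a plausible, non-zero figure. A rough usage sketch, assuming write() accepts an iterable of values as the other cache managers do:

    from apache_beam.runners.interactive.testing.test_cache_manager import InMemoryCache

    cache = InMemoryCache()
    assert cache.size('full', 'some-cache-label') == 0

    cache.write(['cache', 'version', 'one'], 'full', 'some-cache-label')
    # Shallow, in-memory size of the cached list, not bytes on disk.
    assert cache.size('full', 'some-cache-label') > 0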