Posted to commits@beam.apache.org by pa...@apache.org on 2020/03/12 06:20:13 UTC
[beam] branch master updated: [BEAM-8335] Modify the StreamingCache to subclass the CacheManager
This is an automated email from the ASF dual-hosted git repository.
pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new c69d409 [BEAM-8335] Modify the StreamingCache to subclass the CacheManager
new bb9826c Merge pull request #11005 from [BEAM-8335] Modify the StreamingCache to subclass the CacheManager
c69d409 is described below
commit c69d409429cf4dd3234667150458f7777b5b7f4b
Author: Sam Rohde <ro...@gmail.com>
AuthorDate: Fri Feb 28 13:32:30 2020 -0800
[BEAM-8335] Modify the StreamingCache to subclass the CacheManager
Change-Id: Ib61aa3fac53d9109178744e11eeebe5c5da0929c
---
.../runners/interactive/cache_manager.py | 71 ++++-
.../runners/interactive/cache_manager_test.py | 30 +-
.../runners/interactive/caching/streaming_cache.py | 303 +++++++++++++++++-
.../interactive/caching/streaming_cache_test.py | 338 +++++++++++++++------
.../runners/interactive/display/display_manager.py | 2 +-
.../runners/interactive/interactive_runner_test.py | 6 +-
.../runners/interactive/pipeline_fragment_test.py | 2 +-
.../interactive/pipeline_instrument_test.py | 28 +-
.../interactive/testing/test_cache_manager.py | 119 ++++++++
9 files changed, 733 insertions(+), 166 deletions(-)
diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager.py b/sdks/python/apache_beam/runners/interactive/cache_manager.py
index 7274015..09c44d6 100644
--- a/sdks/python/apache_beam/runners/interactive/cache_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/cache_manager.py
@@ -51,26 +51,34 @@ class CacheManager(object):
derivation.
"""
def exists(self, *labels):
+ # type: (*str) -> bool
+
"""Returns if the PCollection cache exists."""
raise NotImplementedError
def is_latest_version(self, version, *labels):
+ # type: (str, *str) -> bool
+
"""Returns if the given version number is the latest."""
return version == self._latest_version(*labels)
def _latest_version(self, *labels):
+ # type: (*str) -> str
+
"""Returns the latest version number of the PCollection cache."""
raise NotImplementedError
def read(self, *labels):
+ # type: (*str) -> Tuple[Generator[Any], str]
+
"""Return the PCollection as a list as well as the version number.
Args:
*labels: List of labels for PCollection instance.
Returns:
- Tuple[List[Any], int]: A tuple containing a list of items in the
- PCollection and the version number.
+ A tuple containing an iterator for the items in the PCollection and the
+ version number.
It is possible that the version numbers from read() and _latest_version()
are different. This usually means that the cache has been evicted (thus
@@ -79,15 +87,32 @@ class CacheManager(object):
"""
raise NotImplementedError
+ def write(self, value, *labels):
+ # type: (Any, *str) -> None
+
+ """Writes the value to the given cache.
+
+ Args:
+ value: An encodable (with corresponding PCoder) value
+ *labels: List of labels for PCollection instance
+ """
+ raise NotImplementedError
+
def source(self, *labels):
- """Returns a beam.io.Source that reads the PCollection cache."""
+ # type: (*str) -> ptransform.PTransform
+
+ """Returns a PTransform that reads the PCollection cache."""
raise NotImplementedError
- def sink(self, *labels):
- """Returns a beam.io.Sink that writes the PCollection cache."""
+ def sink(self, labels):
+ # type: (List[str]) -> ptransform.PTransform
+
+ """Returns a PTransform that writes the PCollection cache."""
raise NotImplementedError
def save_pcoder(self, pcoder, *labels):
+ # type: (coders.Coder, *str) -> None
+
"""Saves pcoder for given PCollection.
Correct reading of PCollection from Cache requires PCoder to be known.
@@ -103,10 +128,14 @@ class CacheManager(object):
raise NotImplementedError
def load_pcoder(self, *labels):
+ # type: (*str) -> coders.Coder
+
"""Returns previously saved PCoder for reading and writing PCollection."""
raise NotImplementedError
def cleanup(self):
+ # type: () -> None
+
"""Cleans up all the PCollection caches."""
raise NotImplementedError
@@ -167,22 +196,34 @@ class FileBasedCacheManager(CacheManager):
self._saved_pcoders[self._path(*labels)])
def read(self, *labels):
+ # Return an iterator to an empty list if it doesn't exist.
if not self.exists(*labels):
- return [], -1
+ return iter([]), -1
- source = self.source(*labels)
+ # Otherwise, return a generator to the cached PCollection.
+ source = self.source(*labels)._source
range_tracker = source.get_range_tracker(None, None)
- result = list(source.read(range_tracker))
+ reader = source.read(range_tracker)
version = self._latest_version(*labels)
- return result, version
+ return reader, version
+
+ def write(self, values, *labels):
+ sink = self.sink(labels)._sink
+ path = self._path(*labels)
+
+ init_result = sink.initialize_write()
+ writer = sink.open_writer(init_result, path)
+ for v in values:
+ writer.write(v)
+ writer.close()
def source(self, *labels):
return self._reader_class(
- self._glob_path(*labels), coder=self.load_pcoder(*labels))._source
+ self._glob_path(*labels), coder=self.load_pcoder(*labels))
- def sink(self, *labels):
+ def sink(self, labels):
return self._writer_class(
- self._path(*labels), coder=self.load_pcoder(*labels))._sink
+ self._path(*labels), coder=self.load_pcoder(*labels))
def cleanup(self):
if filesystems.FileSystems.exists(self._cache_dir):
@@ -229,8 +270,7 @@ class ReadCache(beam.PTransform):
def expand(self, pbegin):
# pylint: disable=expression-not-assigned
- return pbegin | 'Read' >> beam.io.Read(
- self._cache_manager.source('full', self._label))
+ return pbegin | 'Read' >> self._cache_manager.source('full', self._label)
class WriteCache(beam.PTransform):
@@ -255,8 +295,7 @@ class WriteCache(beam.PTransform):
combiners.Sample.FixedSizeGlobally(self._sample_size)
| beam.FlatMap(lambda sample: sample))
# pylint: disable=expression-not-assigned
- return pcoll | 'Write' >> beam.io.Write(
- self._cache_manager.sink(prefix, self._label))
+ return pcoll | 'Write' >> self._cache_manager.sink((prefix, self._label))
class SafeFastPrimitivesCoder(coders.Coder):
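For reference, here is a minimal sketch of the reworked CacheManager API after this change, assuming the default text-based FileBasedCacheManager; the labels ('full', 'my_pcoll') and the elements are illustrative. Note that read() now returns a lazy iterator plus a version rather than a materialized list:

    from apache_beam import coders
    from apache_beam.runners.interactive import cache_manager as cache

    cm = cache.FileBasedCacheManager()
    labels = ('full', 'my_pcoll')  # illustrative labels
    # Register a pcoder so the source/sink know how to decode/encode elements.
    cm.save_pcoder(coders.registry.get_coder(object), *labels)
    cm.write(['a', 'b', 'c'], *labels)
    reader, version = cm.read(*labels)  # reader is an iterator, not a list
    elements = list(reader)  # materialize explicitly when needed
    cm.cleanup()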
diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
index 6ab51b4..7868e90 100644
--- a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
+++ b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
@@ -71,12 +71,11 @@ class FileBasedCacheManagerTest(object):
# Usually, the pcoder will be inferred from `pcoll.element_type`
pcoder = coders.registry.get_coder(object)
+ # Save a pcoder for reading.
self.cache_manager.save_pcoder(pcoder, *labels)
- sink = self.cache_manager.sink(*labels)
-
- with open(self.cache_manager._path(prefix, cache_file), 'wb') as f:
- for line in pcoll_list:
- sink.write_record(f, line)
+ # Save a pcoder for the fake write to the file.
+ self.cache_manager.save_pcoder(pcoder, prefix, cache_file)
+ self.cache_manager.write(pcoll_list, prefix, cache_file)
def test_exists(self):
"""Test that CacheManager can correctly tell if the cache exists or not."""
@@ -99,7 +98,8 @@ class FileBasedCacheManagerTest(object):
cache_version_one = ['cache', 'version', 'one']
self.mock_write_cache(cache_version_one, prefix, cache_label)
- pcoll_list, version = self.cache_manager.read(prefix, cache_label)
+ reader, version = self.cache_manager.read(prefix, cache_label)
+ pcoll_list = list(reader)
self.assertListEqual(pcoll_list, cache_version_one)
self.assertEqual(version, 0)
self.assertTrue(
@@ -113,13 +113,15 @@ class FileBasedCacheManagerTest(object):
cache_version_two = ['cache', 'version', 'two']
self.mock_write_cache(cache_version_one, prefix, cache_label)
- pcoll_list, version = self.cache_manager.read(prefix, cache_label)
+ reader, version = self.cache_manager.read(prefix, cache_label)
+ pcoll_list = list(reader)
self.mock_write_cache(cache_version_two, prefix, cache_label)
self.assertFalse(
self.cache_manager.is_latest_version(version, prefix, cache_label))
- pcoll_list, version = self.cache_manager.read(prefix, cache_label)
+ reader, version = self.cache_manager.read(prefix, cache_label)
+ pcoll_list = list(reader)
self.assertListEqual(pcoll_list, cache_version_two)
self.assertEqual(version, 1)
self.assertTrue(
@@ -132,7 +134,8 @@ class FileBasedCacheManagerTest(object):
self.assertFalse(self.cache_manager.exists(prefix, cache_label))
- pcoll_list, version = self.cache_manager.read(prefix, cache_label)
+ reader, version = self.cache_manager.read(prefix, cache_label)
+ pcoll_list = list(reader)
self.assertListEqual(pcoll_list, [])
self.assertEqual(version, -1)
self.assertTrue(
@@ -147,7 +150,8 @@ class FileBasedCacheManagerTest(object):
# The initial write and read.
self.mock_write_cache(cache_version_one, prefix, cache_label)
- pcoll_list, version = self.cache_manager.read(prefix, cache_label)
+ reader, version = self.cache_manager.read(prefix, cache_label)
+ pcoll_list = list(reader)
# Cache cleanup.
self.cache_manager.cleanup()
@@ -155,7 +159,8 @@ class FileBasedCacheManagerTest(object):
self.assertTrue(
self.cache_manager.is_latest_version(version, prefix, cache_label))
- pcoll_list, version = self.cache_manager.read(prefix, cache_label)
+ reader, version = self.cache_manager.read(prefix, cache_label)
+ pcoll_list = list(reader)
self.assertListEqual(pcoll_list, [])
self.assertEqual(version, -1)
self.assertFalse(
@@ -166,7 +171,8 @@ class FileBasedCacheManagerTest(object):
self.assertFalse(
self.cache_manager.is_latest_version(version, prefix, cache_label))
- pcoll_list, version = self.cache_manager.read(prefix, cache_label)
+ reader, version = self.cache_manager.read(prefix, cache_label)
+ pcoll_list = list(reader)
self.assertListEqual(pcoll_list, cache_version_two)
# Check that version continues from the previous value instead of starting
# from 0 again.
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
index 0aabda3..c313404 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
@@ -19,15 +19,295 @@
from __future__ import absolute_import
+import os
+import shutil
+import tempfile
+import time
+from collections import OrderedDict
+
+import apache_beam as beam
+from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileHeader
+from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileRecord
from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
+from apache_beam.runners.interactive.cache_manager import CacheManager
+from apache_beam.runners.interactive.cache_manager import SafeFastPrimitivesCoder
+from apache_beam.testing.test_stream import OutputFormat
+from apache_beam.testing.test_stream import ReverseTestStream
from apache_beam.utils import timestamp
+try:
+ from pathlib import Path
+except ImportError:
+ from pathlib2 import Path # python 2 backport
+
+
+class StreamingCacheSink(beam.PTransform):
+ """A PTransform that writes TestStreamFile(Header|Records)s to file.
+
+ This transform takes in an arbitrary element stream and writes the list of
+ TestStream events (as TestStreamFileRecords) to file. When replayed, this
+ will produce the best-effort replay of the original job (e.g. some elements
+ may be produced slightly out of order from the original stream).
+
+ Note that this PTransform is assumed to run only on a single machine where
+ the following assumptions are correct: elements arrive in order, and no two
+ transforms write to the same file. This PTransform is assumed to only run
+ correctly with the DirectRunner.
+
+ TODO(BEAM-9447): Generalize this to more source/sink types aside from file
+ based. Also, generalize to cases where there might be multiple workers
+ writing to the same sink.
+ """
+ def __init__(
+ self,
+ cache_dir,
+ filename,
+ sample_resolution_sec,
+ coder=SafeFastPrimitivesCoder()):
+ self._cache_dir = cache_dir
+ self._filename = filename
+ self._sample_resolution_sec = sample_resolution_sec
+ self._coder = coder
+ self._path = os.path.join(self._cache_dir, self._filename)
+
+ @property
+ def path(self):
+ """Returns the path the sink leads to."""
+ return self._path
+
+ def expand(self, pcoll):
+ class StreamingWriteToText(beam.DoFn):
+ """DoFn that performs the writing.
+
+ Note that the other file writing methods cannot be used in streaming
+ contexts.
+ """
+ def __init__(self, full_path, coder=SafeFastPrimitivesCoder()):
+ self._full_path = full_path
+ self._coder = coder
+
+ # Create the directory containing the file if it doesn't already exist.
+ Path(os.path.dirname(full_path)).mkdir(exist_ok=True)
+
+ def start_bundle(self):
+ # Open the file in append mode for writing bytes.
+ self._fh = open(self._full_path, 'ab')
+
+ def finish_bundle(self):
+ self._fh.close()
+
+ def process(self, e):
+ """Appends the given element to the file.
+ """
+ self._fh.write(self._coder.encode(e) + b'\n')
+
+ return (
+ pcoll
+ | ReverseTestStream(
+ output_tag=self._filename,
+ sample_resolution_sec=self._sample_resolution_sec,
+ output_format=OutputFormat.SERIALIZED_TEST_STREAM_FILE_RECORDS,
+ coder=self._coder)
+ | beam.ParDo(
+ StreamingWriteToText(full_path=self._path, coder=self._coder)))
-class StreamingCache(object):
+
+class StreamingCacheSource:
+ """A class that reads and parses TestStreamFile(Header|Reader)s.
+
+ This source operates in the following way:
+
+ 1. Wait for up to `timeout_secs` for the file to be available.
+ 2. Read, parse, and emit the entire contents of the file.
+ 3. Wait for more events to come or until `is_cache_complete` returns True.
+ 4. If there are more events, then go to 2.
+ 5. Otherwise, stop emitting.
+
+ This class is used to read from file and send the events to the TestStream
+ via the StreamingCache.Reader.
+ """
+ def __init__(
+ self,
+ cache_dir,
+ labels,
+ is_cache_complete=None,
+ coder=SafeFastPrimitivesCoder()):
+ self._cache_dir = cache_dir
+ self._coder = coder
+ self._labels = labels
+ self._is_cache_complete = (
+ is_cache_complete if is_cache_complete else lambda: True)
+
+ def _wait_until_file_exists(self, timeout_secs=30):
+ """Blocks until the file exists for a maximum of timeout_secs.
+ """
+ # Wait for up to `timeout_secs` for the file to be available.
+ start = time.time()
+ path = os.path.join(self._cache_dir, *self._labels)
+ while not os.path.exists(path):
+ time.sleep(1)
+ if time.time() - start > timeout_secs:
+ raise RuntimeError(
+ "Timed out waiting for file '{}' to be available".format(path))
+ return open(path, mode='rb')
+
+ def _emit_from_file(self, fh, tail):
+ """Emits the TestStreamFile(Header|Record)s from file.
+
+ This returns a generator that reads all lines from the given file.
+ If `tail` is True, then it will wait until the cache is complete to exit.
+ Otherwise, it will read the file only once.
+ """
+ # Always read at least once to read the whole file.
+ while True:
+ pos = fh.tell()
+ line = fh.readline()
+
+ # Check if we are at EOF or if we have an incomplete line.
+ if not line or line[-1] != b'\n'[0]:
+ # Complete reading only when the cache is complete.
+ if self._is_cache_complete():
+ break
+
+ if not tail:
+ break
+
+ # Otherwise wait for new data in the file to be written.
+ time.sleep(0.5)
+ fh.seek(pos)
+ else:
+ # The first line at pos = 0 is always the header. Read the line without
+ # the trailing newline.
+ to_decode = line[:-1]
+ if pos == 0:
+ header = TestStreamFileHeader()
+ header.ParseFromString(self._coder.decode(to_decode))
+ yield header
+ else:
+ record = TestStreamFileRecord()
+ record.ParseFromString(self._coder.decode(to_decode))
+ yield record
+
+ def read(self, tail):
+ """Reads all TestStreamFile(Header|TestStreamFileRecord)s from file.
+
+ This returns a generator that reads all lines from the given file.
+ If `tail` is True, then it will wait until the cache is complete to exit.
+ Otherwise, it will read the file only once.
+ """
+ with self._wait_until_file_exists() as f:
+ for e in self._emit_from_file(f, tail):
+ yield e
+
+
+class StreamingCache(CacheManager):
"""Abstraction that holds the logic for reading and writing to cache.
"""
- def __init__(self, readers):
- self._readers = readers
+ def __init__(
+ self, cache_dir, is_cache_complete=None, sample_resolution_sec=0.1):
+ self._sample_resolution_sec = sample_resolution_sec
+ self._is_cache_complete = is_cache_complete
+
+ if cache_dir:
+ self._cache_dir = cache_dir
+ else:
+ self._cache_dir = tempfile.mkdtemp(
+ prefix='interactive-temp-', dir=os.environ.get('TEST_TMPDIR', None))
+
+ # Map of saved pcoders keyed by PCollection path. It is OK to keep this
+ # map in memory because once the StreamingCache object is
+ # destroyed/re-created it loses access to previously written cache
+ # objects anyway even if cache_dir already exists. In other words,
+ # it is not possible to resume execution of a Beam pipeline from the
+ # saved cache if the StreamingCache has been reset.
+ #
+ # However, if we are to implement better cache persistence, one needs
+ # to take care of keeping consistency between the cached PCollection
+ # and its PCoder type.
+ self._saved_pcoders = {}
+ self._default_pcoder = SafeFastPrimitivesCoder()
+
+ def exists(self, *labels):
+ path = os.path.join(self._cache_dir, *labels)
+ return os.path.exists(path)
+
+ # TODO(srohde): Modify this to return the correct version.
+ def read(self, *labels):
+ """Returns a generator to read all records from file.
+
+ Does not tail.
+ """
+ if not self.exists(*labels):
+ return iter([]), -1
+
+ reader = StreamingCacheSource(
+ self._cache_dir, labels,
+ is_cache_complete=self._is_cache_complete).read(tail=False)
+ header = next(reader)
+ return StreamingCache.Reader([header], [reader]).read(), 1
+
+ def read_multiple(self, labels):
+ """Returns a generator to read all records from file.
+
+ Tails until the cache is complete, because this method is used by the
+ TestStreamServiceController to read from file during pipeline runtime,
+ which needs to block for incoming events.
+ """
+ readers = [
+ StreamingCacheSource(
+ self._cache_dir, l,
+ is_cache_complete=self._is_cache_complete).read(tail=True)
+ for l in labels
+ ]
+ headers = [next(r) for r in readers]
+ return StreamingCache.Reader(headers, readers).read()
+
+ def write(self, values, *labels):
+ """Writes the given values to cache.
+ """
+ directory = os.path.join(self._cache_dir, *labels[:-1])
+ filepath = os.path.join(directory, labels[-1])
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+ with open(filepath, 'ab') as f:
+ for v in values:
+ f.write(self._default_pcoder.encode(v.SerializeToString()) + b'\n')
+
+ def source(self, *labels):
+ """Returns the StreamingCacheManager source.
+
+ This is beam.Impulse() because unbounded sources will be marked with it,
+ and the PipelineInstrument will then replace them with a TestStream.
+ """
+ return beam.Impulse()
+
+ def sink(self, labels):
+ """Returns a StreamingCacheSink to write elements to file.
+
+ Note that this is assumed to only work in the DirectRunner as the underlying
+ StreamingCacheSink assumes a single machine to have correct element
+ ordering.
+ """
+ filename = labels[-1]
+ cache_dir = os.path.join(self._cache_dir, *labels[:-1])
+ return StreamingCacheSink(cache_dir, filename, self._sample_resolution_sec)
+
+ def save_pcoder(self, pcoder, *labels):
+ self._saved_pcoders[os.path.join(*labels)] = pcoder
+
+ def load_pcoder(self, *labels):
+ return (
+ self._default_pcoder if self._default_pcoder is not None else
+ self._saved_pcoders[os.path.join(*labels)])
+
+ def cleanup(self):
+ if os.path.exists(self._cache_dir):
+ shutil.rmtree(self._cache_dir)
+ self._saved_pcoders = {}
+ self._capture_sinks = {}
class Reader(object):
"""Abstraction that reads from PCollection readers.
@@ -38,7 +318,7 @@ class StreamingCache(object):
This class is also responsible for holding the state of the clock, injecting
clock advancement events, and watermark advancement events.
"""
- def __init__(self, readers):
+ def __init__(self, headers, readers):
# This timestamp is used as the monotonic clock to order events in the
# replay.
self._monotonic_clock = timestamp.Timestamp.of(0)
@@ -49,8 +329,9 @@ class StreamingCache(object):
# The file headers that are metadata for that particular PCollection.
# The header allows for metadata about an entire stream, so that the data
# isn't copied per record.
- self._headers = {r.header().tag: r.header() for r in readers}
- self._readers = {r.header().tag: r.read() for r in readers}
+ self._headers = {header.tag: header for header in headers}
+ self._readers = OrderedDict(
+ ((h.tag, r) for (h, r) in zip(headers, readers)))
# The most recently read timestamp per tag.
self._stream_times = {
@@ -141,13 +422,6 @@ class StreamingCache(object):
unsent_events = events_to_send
target_timestamp = self._min_timestamp_of(unsent_events)
- def _add_element(self, element, tag):
- """Constructs an AddElement event for the specified element and tag.
- """
- return TestStreamPayload.Event(
- element_event=TestStreamPayload.Event.AddElements(
- elements=[element], tag=tag))
-
def _advance_processing_time(self, new_timestamp):
"""Advances the internal clock and returns an AdvanceProcessingTime event.
"""
@@ -157,6 +431,3 @@ class StreamingCache(object):
advance_duration=advancy_by))
self._monotonic_clock = new_timestamp
return e
-
- def reader(self):
- return StreamingCache.Reader(self._readers)
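For reference, a minimal sketch of a StreamingCache write/read round trip, mirroring the test_single_reader test below and using the FileRecordsBuilder helper added in this commit; 'my_label' is an illustrative label:

    from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
    from apache_beam.runners.interactive.testing.test_cache_manager import FileRecordsBuilder

    # Build a header plus records, the same on-disk format StreamingCacheSink writes.
    records = (FileRecordsBuilder(tag='my_label')
               .add_element(element=0, event_time_secs=0)
               .advance_processing_time(1)
               .build())

    cache = StreamingCache(cache_dir=None)  # creates its own temp directory
    cache.write(records, 'my_label')

    # read() replays the file once (no tailing) as TestStreamPayload.Events.
    events, _ = cache.read('my_label')
    for event in events:
      print(event)
    cache.cleanup()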
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
index f9f450e..002a05a 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
@@ -22,12 +22,18 @@ from __future__ import absolute_import
import unittest
from apache_beam import coders
+from apache_beam.options.pipeline_options import DebugOptions
+from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileHeader
from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileRecord
from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
+from apache_beam.runners.interactive.cache_manager import SafeFastPrimitivesCoder
from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
-from apache_beam.utils.timestamp import Duration
-from apache_beam.utils.timestamp import Timestamp
+from apache_beam.runners.interactive.testing.test_cache_manager import FileRecordsBuilder
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.test_stream import TestStream
+from apache_beam.testing.util import *
+from apache_beam.transforms.window import TimestampedValue
# Nose automatically detects tests if they match a regex. Here, it mistakes
# these protos for tests. For more info see the Nose docs at:
@@ -37,76 +43,45 @@ TestStreamFileHeader.__test__ = False # type: ignore[attr-defined]
TestStreamFileRecord.__test__ = False # type: ignore[attr-defined]
-class InMemoryReader(object):
- def __init__(self, tag=None):
- self._header = TestStreamFileHeader(tag=tag)
- self._records = []
- self._coder = coders.FastPrimitivesCoder()
-
- def add_element(self, element, event_time):
- element_payload = TestStreamPayload.TimestampedElement(
- encoded_element=self._coder.encode(element),
- timestamp=Timestamp.of(event_time).micros)
- record = TestStreamFileRecord(
- recorded_event=TestStreamPayload.Event(
- element_event=TestStreamPayload.Event.AddElements(
- elements=[element_payload])))
- self._records.append(record)
-
- def advance_watermark(self, watermark):
- record = TestStreamFileRecord(
- recorded_event=TestStreamPayload.Event(
- watermark_event=TestStreamPayload.Event.AdvanceWatermark(
- new_watermark=Timestamp.of(watermark).micros)))
- self._records.append(record)
-
- def advance_processing_time(self, processing_time_delta):
- record = TestStreamFileRecord(
- recorded_event=TestStreamPayload.Event(
- processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
- advance_duration=Duration.of(processing_time_delta).micros)))
- self._records.append(record)
-
- def header(self):
- return self._header
-
- def read(self):
- for r in self._records:
- yield r
-
-
-def all_events(reader):
- events = []
- for e in reader.read():
- events.append(e)
- return events
-
-
class StreamingCacheTest(unittest.TestCase):
def setUp(self):
pass
+ def test_exists(self):
+ cache = StreamingCache(cache_dir=None)
+ self.assertFalse(cache.exists('my_label'))
+ cache.write([TestStreamFileRecord()], 'my_label')
+ self.assertTrue(cache.exists('my_label'))
+
def test_single_reader(self):
"""Tests that we expect to see all the correctly emitted TestStreamPayloads.
"""
- in_memory_reader = InMemoryReader()
- in_memory_reader.add_element(element=0, event_time=0)
- in_memory_reader.advance_processing_time(1)
- in_memory_reader.add_element(element=1, event_time=1)
- in_memory_reader.advance_processing_time(1)
- in_memory_reader.add_element(element=2, event_time=2)
- cache = StreamingCache([in_memory_reader])
- reader = cache.reader()
+ CACHED_PCOLLECTION_KEY = 'arbitrary_key'
+
+ values = (FileRecordsBuilder(tag=CACHED_PCOLLECTION_KEY)
+ .add_element(element=0, event_time_secs=0)
+ .advance_processing_time(1)
+ .add_element(element=1, event_time_secs=1)
+ .advance_processing_time(1)
+ .add_element(element=2, event_time_secs=2)
+ .build()) # yapf: disable
+
+ cache = StreamingCache(cache_dir=None)
+ cache.write(values, CACHED_PCOLLECTION_KEY)
+
+ reader, _ = cache.read(CACHED_PCOLLECTION_KEY)
coder = coders.FastPrimitivesCoder()
- events = all_events(reader)
+ events = list(reader)
+ # Units here are in microseconds.
expected = [
TestStreamPayload.Event(
element_event=TestStreamPayload.Event.AddElements(
elements=[
TestStreamPayload.TimestampedElement(
encoded_element=coder.encode(0), timestamp=0)
- ])),
+ ],
+ tag=CACHED_PCOLLECTION_KEY)),
TestStreamPayload.Event(
processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
advance_duration=1 * 10**6)),
@@ -115,7 +90,8 @@ class StreamingCacheTest(unittest.TestCase):
elements=[
TestStreamPayload.TimestampedElement(
encoded_element=coder.encode(1), timestamp=1 * 10**6)
- ])),
+ ],
+ tag=CACHED_PCOLLECTION_KEY)),
TestStreamPayload.Event(
processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
advance_duration=1 * 10**6)),
@@ -124,38 +100,53 @@ class StreamingCacheTest(unittest.TestCase):
elements=[
TestStreamPayload.TimestampedElement(
encoded_element=coder.encode(2), timestamp=2 * 10**6)
- ])),
+ ],
+ tag=CACHED_PCOLLECTION_KEY)),
]
self.assertSequenceEqual(events, expected)
def test_multiple_readers(self):
- """Tests that the service advances the clock with multiple outputs."""
-
- letters = InMemoryReader('letters')
- letters.advance_processing_time(1)
- letters.advance_watermark(0)
- letters.add_element(element='a', event_time=0)
- letters.advance_processing_time(10)
- letters.advance_watermark(10)
- letters.add_element(element='b', event_time=10)
-
- numbers = InMemoryReader('numbers')
- numbers.advance_processing_time(2)
- numbers.add_element(element=1, event_time=0)
- numbers.advance_processing_time(1)
- numbers.add_element(element=2, event_time=0)
- numbers.advance_processing_time(1)
- numbers.add_element(element=2, event_time=0)
-
- late = InMemoryReader('late')
- late.advance_processing_time(101)
- late.add_element(element='late', event_time=0)
-
- cache = StreamingCache([letters, numbers, late])
- reader = cache.reader()
+ """Tests that the service advances the clock with multiple outputs.
+ """
+
+ CACHED_LETTERS = 'letters'
+ CACHED_NUMBERS = 'numbers'
+ CACHED_LATE = 'late'
+
+ letters = (FileRecordsBuilder(CACHED_LETTERS)
+ .advance_processing_time(1)
+ .advance_watermark(watermark_secs=0)
+ .add_element(element='a', event_time_secs=0)
+ .advance_processing_time(10)
+ .advance_watermark(watermark_secs=10)
+ .add_element(element='b', event_time_secs=10)
+ .build()) # yapf: disable
+
+ numbers = (FileRecordsBuilder(CACHED_NUMBERS)
+ .advance_processing_time(2)
+ .add_element(element=1, event_time_secs=0)
+ .advance_processing_time(1)
+ .add_element(element=2, event_time_secs=0)
+ .advance_processing_time(1)
+ .add_element(element=2, event_time_secs=0)
+ .build()) # yapf: disable
+
+ late = (FileRecordsBuilder(CACHED_LATE)
+ .advance_processing_time(101)
+ .add_element(element='late', event_time_secs=0)
+ .build()) # yapf: disable
+
+ cache = StreamingCache(cache_dir=None)
+ cache.write(letters, CACHED_LETTERS)
+ cache.write(numbers, CACHED_NUMBERS)
+ cache.write(late, CACHED_LATE)
+
+ reader = cache.read_multiple([[CACHED_LETTERS], [CACHED_NUMBERS],
+ [CACHED_LATE]])
coder = coders.FastPrimitivesCoder()
- events = all_events(reader)
+ events = list(reader)
+ # Units here are in microseconds.
expected = [
# Advances clock from 0 to 1
TestStreamPayload.Event(
@@ -163,14 +154,14 @@ class StreamingCacheTest(unittest.TestCase):
advance_duration=1 * 10**6)),
TestStreamPayload.Event(
watermark_event=TestStreamPayload.Event.AdvanceWatermark(
- new_watermark=0, tag='letters')),
+ new_watermark=0, tag=CACHED_LETTERS)),
TestStreamPayload.Event(
element_event=TestStreamPayload.Event.AddElements(
elements=[
TestStreamPayload.TimestampedElement(
encoded_element=coder.encode('a'), timestamp=0)
],
- tag='letters')),
+ tag=CACHED_LETTERS)),
# Advances clock from 1 to 2
TestStreamPayload.Event(
@@ -182,7 +173,7 @@ class StreamingCacheTest(unittest.TestCase):
TestStreamPayload.TimestampedElement(
encoded_element=coder.encode(1), timestamp=0)
],
- tag='numbers')),
+ tag=CACHED_NUMBERS)),
# Advances clock from 2 to 3
TestStreamPayload.Event(
@@ -194,7 +185,7 @@ class StreamingCacheTest(unittest.TestCase):
TestStreamPayload.TimestampedElement(
encoded_element=coder.encode(2), timestamp=0)
],
- tag='numbers')),
+ tag=CACHED_NUMBERS)),
# Advances clock from 3 to 4
TestStreamPayload.Event(
@@ -206,7 +197,7 @@ class StreamingCacheTest(unittest.TestCase):
TestStreamPayload.TimestampedElement(
encoded_element=coder.encode(2), timestamp=0)
],
- tag='numbers')),
+ tag=CACHED_NUMBERS)),
# Advances clock from 4 to 11
TestStreamPayload.Event(
@@ -214,14 +205,14 @@ class StreamingCacheTest(unittest.TestCase):
advance_duration=7 * 10**6)),
TestStreamPayload.Event(
watermark_event=TestStreamPayload.Event.AdvanceWatermark(
- new_watermark=10 * 10**6, tag='letters')),
+ new_watermark=10 * 10**6, tag=CACHED_LETTERS)),
TestStreamPayload.Event(
element_event=TestStreamPayload.Event.AddElements(
elements=[
TestStreamPayload.TimestampedElement(
encoded_element=coder.encode('b'), timestamp=10 * 10**6)
],
- tag='letters')),
+ tag=CACHED_LETTERS)),
# Advances clock from 11 to 101
TestStreamPayload.Event(
@@ -233,11 +224,174 @@ class StreamingCacheTest(unittest.TestCase):
TestStreamPayload.TimestampedElement(
encoded_element=coder.encode('late'), timestamp=0)
],
- tag='late')),
+ tag=CACHED_LATE)),
]
self.assertSequenceEqual(events, expected)
+ def test_read_and_write(self):
+ """An integration test between the Sink and Source.
+
+ This ensures that the sink and source speak the same language in terms of
+ coders, protos, order, and units.
+ """
+
+ # Units here are in seconds.
+ test_stream = (TestStream()
+ .advance_watermark_to(0, tag='records')
+ .advance_processing_time(5)
+ .add_elements(['a', 'b', 'c'], tag='records')
+ .advance_watermark_to(10, tag='records')
+ .advance_processing_time(1)
+ .add_elements(
+ [
+ TimestampedValue('1', 15),
+ TimestampedValue('2', 15),
+ TimestampedValue('3', 15)
+ ],
+ tag='records')) # yapf: disable
+
+ coder = SafeFastPrimitivesCoder()
+ cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)
+
+ options = StandardOptions(streaming=True)
+ options.view_as(DebugOptions).add_experiment(
+ 'passthrough_pcollection_output_ids')
+ with TestPipeline(options=options) as p:
+ # pylint: disable=expression-not-assigned
+ p | test_stream | cache.sink(['records'])
+
+ reader, _ = cache.read('records')
+ actual_events = list(reader)
+
+ # Units here are in microseconds.
+ expected_events = [
+ TestStreamPayload.Event(
+ processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+ advance_duration=5 * 10**6)),
+ TestStreamPayload.Event(
+ watermark_event=TestStreamPayload.Event.AdvanceWatermark(
+ new_watermark=0, tag='records')),
+ TestStreamPayload.Event(
+ element_event=TestStreamPayload.Event.AddElements(
+ elements=[
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode('a'), timestamp=0),
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode('b'), timestamp=0),
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode('c'), timestamp=0),
+ ],
+ tag='records')),
+ TestStreamPayload.Event(
+ processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+ advance_duration=1 * 10**6)),
+ TestStreamPayload.Event(
+ watermark_event=TestStreamPayload.Event.AdvanceWatermark(
+ new_watermark=10 * 10**6, tag='records')),
+ TestStreamPayload.Event(
+ element_event=TestStreamPayload.Event.AddElements(
+ elements=[
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode('1'), timestamp=15 *
+ 10**6),
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode('2'), timestamp=15 *
+ 10**6),
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode('3'), timestamp=15 *
+ 10**6),
+ ],
+ tag='records')),
+ ]
+ self.assertEqual(actual_events, expected_events)
+
+ def test_read_and_write_multiple_outputs(self):
+ """An integration test between the Sink and Source with multiple outputs.
+
+ This tests that the StreamingCache reads from multiple files and combines
+ them into a single sorted output.
+ """
+ LETTERS_TAG = 'letters'
+ NUMBERS_TAG = 'numbers'
+
+ # Units here are in seconds.
+ test_stream = (TestStream()
+ .advance_watermark_to(0, tag=LETTERS_TAG)
+ .advance_processing_time(5)
+ .add_elements(['a', 'b', 'c'], tag=LETTERS_TAG)
+ .advance_watermark_to(10, tag=NUMBERS_TAG)
+ .advance_processing_time(1)
+ .add_elements(
+ [
+ TimestampedValue('1', 15),
+ TimestampedValue('2', 15),
+ TimestampedValue('3', 15)
+ ],
+ tag=NUMBERS_TAG)) # yapf: disable
+
+ cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)
+
+ coder = SafeFastPrimitivesCoder()
+
+ options = StandardOptions(streaming=True)
+ options.view_as(DebugOptions).add_experiment(
+ 'passthrough_pcollection_output_ids')
+ with TestPipeline(options=options) as p:
+ # pylint: disable=expression-not-assigned
+ events = p | test_stream
+ events[LETTERS_TAG] | 'Letters sink' >> cache.sink([LETTERS_TAG])
+ events[NUMBERS_TAG] | 'Numbers sink' >> cache.sink([NUMBERS_TAG])
+
+ reader = cache.read_multiple([[LETTERS_TAG], [NUMBERS_TAG]])
+ actual_events = list(reader)
+
+ # Units here are in microseconds.
+ expected_events = [
+ TestStreamPayload.Event(
+ processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+ advance_duration=5 * 10**6)),
+ TestStreamPayload.Event(
+ watermark_event=TestStreamPayload.Event.AdvanceWatermark(
+ new_watermark=0, tag=LETTERS_TAG)),
+ TestStreamPayload.Event(
+ element_event=TestStreamPayload.Event.AddElements(
+ elements=[
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode('a'), timestamp=0),
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode('b'), timestamp=0),
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode('c'), timestamp=0),
+ ],
+ tag=LETTERS_TAG)),
+ TestStreamPayload.Event(
+ processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+ advance_duration=1 * 10**6)),
+ TestStreamPayload.Event(
+ watermark_event=TestStreamPayload.Event.AdvanceWatermark(
+ new_watermark=10 * 10**6, tag=NUMBERS_TAG)),
+ TestStreamPayload.Event(
+ watermark_event=TestStreamPayload.Event.AdvanceWatermark(
+ new_watermark=0, tag=LETTERS_TAG)),
+ TestStreamPayload.Event(
+ element_event=TestStreamPayload.Event.AddElements(
+ elements=[
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode('1'), timestamp=15 *
+ 10**6),
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode('2'), timestamp=15 *
+ 10**6),
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode('3'), timestamp=15 *
+ 10**6),
+ ],
+ tag=NUMBERS_TAG)),
+ ]
+
+ self.assertListEqual(actual_events, expected_events)
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/display/display_manager.py b/sdks/python/apache_beam/runners/interactive/display/display_manager.py
index 1b7af26..3c58ce2 100644
--- a/sdks/python/apache_beam/runners/interactive/display/display_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/display/display_manager.py
@@ -142,7 +142,7 @@ class DisplayManager(object):
if force or not self._cache_manager.is_latest_version(
version, 'sample', cache_label):
pcoll_list, version = self._cache_manager.read('sample', cache_label)
- stats['sample'] = pcoll_list
+ stats['sample'] = list(pcoll_list)
stats['version'] = version
stats_updated = True
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
index 7146067..3e1ad80 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
@@ -150,11 +150,11 @@ class InteractiveRunnerTest(unittest.TestCase):
ib.watch(locals())
result = p.run()
self.assertTrue(init in ie.current_env().computed_pcollections)
- self.assertEqual([0, 1, 2, 3, 4], result.get(init))
+ self.assertEqual([0, 1, 2, 3, 4], list(result.get(init)))
self.assertTrue(square in ie.current_env().computed_pcollections)
- self.assertEqual([0, 1, 4, 9, 16], result.get(square))
+ self.assertEqual([0, 1, 4, 9, 16], list(result.get(square)))
self.assertTrue(cube in ie.current_env().computed_pcollections)
- self.assertEqual([0, 1, 8, 27, 64], result.get(cube))
+ self.assertEqual([0, 1, 8, 27, 64], list(result.get(cube)))
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_fragment_test.py b/sdks/python/apache_beam/runners/interactive/pipeline_fragment_test.py
index cc5f5de..d8ab43c 100644
--- a/sdks/python/apache_beam/runners/interactive/pipeline_fragment_test.py
+++ b/sdks/python/apache_beam/runners/interactive/pipeline_fragment_test.py
@@ -115,7 +115,7 @@ class PipelineFragmentTest(unittest.TestCase):
ib.watch(locals())
result = pf.PipelineFragment([square]).run()
- self.assertEqual([0, 1, 4, 9, 16], result.get(square))
+ self.assertEqual([0, 1, 4, 9, 16], list(result.get(square)))
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py b/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
index a936850..771f015 100644
--- a/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
+++ b/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
@@ -21,12 +21,10 @@
from __future__ import absolute_import
import tempfile
-import time
import unittest
import apache_beam as beam
from apache_beam import coders
-from apache_beam.io import filesystems
from apache_beam.pipeline import PipelineVisitor
from apache_beam.runners.interactive import cache_manager as cache
from apache_beam.runners.interactive import interactive_beam as ib
@@ -35,17 +33,12 @@ from apache_beam.runners.interactive import pipeline_instrument as instr
from apache_beam.runners.interactive import interactive_runner
from apache_beam.runners.interactive.testing.pipeline_assertion import assert_pipeline_equal
from apache_beam.runners.interactive.testing.pipeline_assertion import assert_pipeline_proto_equal
-
-# Work around nose tests using Python2 without unittest.mock module.
-try:
- from unittest.mock import MagicMock
-except ImportError:
- from mock import MagicMock
+from apache_beam.runners.interactive.testing.test_cache_manager import InMemoryCache
class PipelineInstrumentTest(unittest.TestCase):
def setUp(self):
- ie.new_env(cache_manager=cache.FileBasedCacheManager())
+ ie.new_env(cache_manager=InMemoryCache())
def test_pcolls_to_pcoll_id(self):
p = beam.Pipeline(interactive_runner.InteractiveRunner())
@@ -205,26 +198,12 @@ class PipelineInstrumentTest(unittest.TestCase):
def _mock_write_cache(self, pcoll, cache_key):
"""Cache the PCollection where cache.WriteCache would write to."""
- cache_path = filesystems.FileSystems.join(
- ie.current_env().cache_manager()._cache_dir, 'full')
- if not filesystems.FileSystems.exists(cache_path):
- filesystems.FileSystems.mkdirs(cache_path)
-
- # Pause for 0.1 sec, because the Jenkins test runs so fast that the file
- # writes happen at the same timestamp.
- time.sleep(0.1)
-
- cache_file = cache_key + '-1-of-2'
labels = ['full', cache_key]
# Usually, the pcoder will be inferred from `pcoll.element_type`
pcoder = coders.registry.get_coder(object)
ie.current_env().cache_manager().save_pcoder(pcoder, *labels)
- sink = ie.current_env().cache_manager().sink(*labels)
-
- with open(ie.current_env().cache_manager()._path('full', cache_file),
- 'wb') as f:
- sink.write_record(f, pcoll)
+ ie.current_env().cache_manager().write([b''], *labels)
def test_instrument_example_pipeline_to_write_cache(self):
# Original instance defined by user code has all variables handlers.
@@ -256,7 +235,6 @@ class PipelineInstrumentTest(unittest.TestCase):
second_pcoll_cache_key = 'second_pcoll_' + str(
id(second_pcoll)) + '_' + str(id(second_pcoll.producer))
self._mock_write_cache(second_pcoll, second_pcoll_cache_key)
- ie.current_env().cache_manager().exists = MagicMock(return_value=True)
# Mark the completeness of PCollections from the original(user) pipeline.
ie.current_env().mark_pcollection_computed(
(p_origin, init_pcoll, second_pcoll))
diff --git a/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py b/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
new file mode 100644
index 0000000..f39f016
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import absolute_import
+
+import collections
+import itertools
+
+import apache_beam as beam
+from apache_beam import coders
+from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileHeader
+from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileRecord
+from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
+from apache_beam.runners.interactive.cache_manager import CacheManager
+from apache_beam.utils.timestamp import Duration
+from apache_beam.utils.timestamp import Timestamp
+
+
+class InMemoryCache(CacheManager):
+ """A cache that stores all PCollections in an in-memory map.
+
+ This is only used for checking the pipeline shape. This can't be used for
+ running the pipeline because the cache isn't shared between the SDK and
+ the Runner.
+ """
+ def __init__(self):
+ self._cached = {}
+ self._pcoders = {}
+
+ def exists(self, *labels):
+ return self._key(*labels) in self._cached
+
+ def _latest_version(self, *labels):
+ return True
+
+ def read(self, *labels):
+ if not self.exists(*labels):
+ return itertools.chain([]), -1
+ ret = itertools.chain(self._cached[self._key(*labels)])
+ return ret, None
+
+ def write(self, value, *labels):
+ if not self.exists(*labels):
+ self._cached[self._key(*labels)] = []
+ self._cached[self._key(*labels)] += value
+
+ def save_pcoder(self, pcoder, *labels):
+ self._pcoders[self._key(*labels)] = pcoder
+
+ def load_pcoder(self, *labels):
+ return self._pcoders[self._key(*labels)]
+
+ def cleanup(self):
+ self._cached = collections.defaultdict(list)
+ self._pcoders = {}
+
+ def source(self, *labels):
+ vals = self._cached[self._key(*labels)]
+ return beam.Create(vals)
+
+ def sink(self, labels, is_capture=False):
+ return beam.Map(lambda _: _)
+
+ def _key(self, *labels):
+ return '/'.join(labels)
+
+
+class NoopSink(beam.PTransform):
+ def expand(self, pcoll):
+ return pcoll | beam.Map(lambda x: x)
+
+
+class FileRecordsBuilder(object):
+ def __init__(self, tag=None):
+ self._header = TestStreamFileHeader(tag=tag)
+ self._records = []
+ self._coder = coders.FastPrimitivesCoder()
+
+ def add_element(self, element, event_time_secs):
+ element_payload = TestStreamPayload.TimestampedElement(
+ encoded_element=self._coder.encode(element),
+ timestamp=Timestamp.of(event_time_secs).micros)
+ record = TestStreamFileRecord(
+ recorded_event=TestStreamPayload.Event(
+ element_event=TestStreamPayload.Event.AddElements(
+ elements=[element_payload])))
+ self._records.append(record)
+ return self
+
+ def advance_watermark(self, watermark_secs):
+ record = TestStreamFileRecord(
+ recorded_event=TestStreamPayload.Event(
+ watermark_event=TestStreamPayload.Event.AdvanceWatermark(
+ new_watermark=Timestamp.of(watermark_secs).micros)))
+ self._records.append(record)
+ return self
+
+ def advance_processing_time(self, delta_secs):
+ record = TestStreamFileRecord(
+ recorded_event=TestStreamPayload.Event(
+ processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+ advance_duration=Duration.of(delta_secs).micros)))
+ self._records.append(record)
+ return self
+
+ def build(self):
+ return [self._header] + self._records
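For reference, a minimal sketch of how the new InMemoryCache stands in for a real cache manager in tests; the labels and elements are illustrative:

    from apache_beam.runners.interactive.testing.test_cache_manager import InMemoryCache

    cache = InMemoryCache()
    cache.write([1, 2, 3], 'full', 'my_pcoll')
    reader, _ = cache.read('full', 'my_pcoll')  # (iterator, version)
    assert list(reader) == [1, 2, 3]
    cache.cleanup()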