Posted to commits@beam.apache.org by pa...@apache.org on 2020/03/12 06:20:13 UTC

[beam] branch master updated: [BEAM-8335] Modify the StreamingCache to subclass the CacheManager

This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new c69d409  [BEAM-8335] Modify the StreamingCache to subclass the CacheManager
     new bb9826c  Merge pull request #11005 from [BEAM-8335] Modify the StreamingCache to subclass the CacheManager
c69d409 is described below

commit c69d409429cf4dd3234667150458f7777b5b7f4b
Author: Sam Rohde <ro...@gmail.com>
AuthorDate: Fri Feb 28 13:32:30 2020 -0800

    [BEAM-8335] Modify the StreamingCache to subclass the CacheManager
    
    Change-Id: Ib61aa3fac53d9109178744e11eeebe5c5da0929c
---
 .../runners/interactive/cache_manager.py           |  71 ++++-
 .../runners/interactive/cache_manager_test.py      |  30 +-
 .../runners/interactive/caching/streaming_cache.py | 303 +++++++++++++++++-
 .../interactive/caching/streaming_cache_test.py    | 338 +++++++++++++++------
 .../runners/interactive/display/display_manager.py |   2 +-
 .../runners/interactive/interactive_runner_test.py |   6 +-
 .../runners/interactive/pipeline_fragment_test.py  |   2 +-
 .../interactive/pipeline_instrument_test.py        |  28 +-
 .../interactive/testing/test_cache_manager.py      | 119 ++++++++
 9 files changed, 733 insertions(+), 166 deletions(-)
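
For orientation (this summary and the sketch below are not part of the commit): the change gives the interactive caches a single shared abstraction, roughly as follows.

    # Rough shape after this change (illustrative summary, not committed code):
    #
    #   CacheManager              abstract base in cache_manager.py
    #     FileBasedCacheManager   existing file-backed implementation
    #     StreamingCache          caching/streaming_cache.py, now a CacheManager
    #     InMemoryCache           new test helper in testing/test_cache_manager.py
    #
    # read(*labels) now returns (iterator, version) instead of (list, version),
    # write(values, *labels) is added to the base class, and source()/sink()
    # return PTransforms rather than beam.io source/sink objects.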

diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager.py b/sdks/python/apache_beam/runners/interactive/cache_manager.py
index 7274015..09c44d6 100644
--- a/sdks/python/apache_beam/runners/interactive/cache_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/cache_manager.py
@@ -51,26 +51,34 @@ class CacheManager(object):
   derivation.
   """
   def exists(self, *labels):
+    # type: (*str) -> bool
+
     """Returns if the PCollection cache exists."""
     raise NotImplementedError
 
   def is_latest_version(self, version, *labels):
+    # type: (str, *str) -> bool
+
     """Returns if the given version number is the latest."""
     return version == self._latest_version(*labels)
 
   def _latest_version(self, *labels):
+    # type: (*str) -> str
+
     """Returns the latest version number of the PCollection cache."""
     raise NotImplementedError
 
   def read(self, *labels):
+    # type: (*str) -> Tuple[Generator[Any], str]
+
     """Return the PCollection as a list as well as the version number.
 
     Args:
       *labels: List of labels for PCollection instance.
 
     Returns:
-      Tuple[List[Any], int]: A tuple containing a list of items in the
-        PCollection and the version number.
+      A tuple containing an iterator for the items in the PCollection and the
+        version number.
 
     It is possible that the version numbers from read() and _latest_version()
     are different. This usually means that the cache's been evicted (thus
@@ -79,15 +87,32 @@ class CacheManager(object):
     """
     raise NotImplementedError
 
+  def write(self, value, *labels):
+    # type: (Any, *str) -> None
+
+    """Writes the value to the given cache.
+
+    Args:
+      value: An encodable (with corresponding PCoder) value
+      *labels: List of labels for PCollection instance
+    """
+    raise NotImplementedError
+
   def source(self, *labels):
-    """Returns a beam.io.Source that reads the PCollection cache."""
+    # type: (*str) -> ptransform.PTransform
+
+    """Returns a PTransform that reads the PCollection cache."""
     raise NotImplementedError
 
-  def sink(self, *labels):
-    """Returns a beam.io.Sink that writes the PCollection cache."""
+  def sink(self, labels):
+    # type: (List[str]) -> ptransform.PTransform
+
+    """Returns a PTransform that writes the PCollection cache."""
     raise NotImplementedError
 
   def save_pcoder(self, pcoder, *labels):
+    # type: (coders.Coder, *str) -> None
+
     """Saves pcoder for given PCollection.
 
     Correct reading of PCollection from Cache requires PCoder to be known.
@@ -103,10 +128,14 @@ class CacheManager(object):
     raise NotImplementedError
 
   def load_pcoder(self, *labels):
+    # type: (*str) -> coders.Coder
+
     """Returns previously saved PCoder for reading and writing PCollection."""
     raise NotImplementedError
 
   def cleanup(self):
+    # type: () -> None
+
     """Cleans up all the PCollection caches."""
     raise NotImplementedError
 
@@ -167,22 +196,34 @@ class FileBasedCacheManager(CacheManager):
         self._saved_pcoders[self._path(*labels)])
 
   def read(self, *labels):
+    # Return an iterator over an empty list if the cache doesn't exist.
     if not self.exists(*labels):
-      return [], -1
+      return iter([]), -1
 
-    source = self.source(*labels)
+    # Otherwise, return a generator over the cached PCollection.
+    source = self.source(*labels)._source
     range_tracker = source.get_range_tracker(None, None)
-    result = list(source.read(range_tracker))
+    reader = source.read(range_tracker)
     version = self._latest_version(*labels)
-    return result, version
+    return reader, version
+
+  def write(self, values, *labels):
+    sink = self.sink(labels)._sink
+    path = self._path(*labels)
+
+    init_result = sink.initialize_write()
+    writer = sink.open_writer(init_result, path)
+    for v in values:
+      writer.write(v)
+    writer.close()
 
   def source(self, *labels):
     return self._reader_class(
-        self._glob_path(*labels), coder=self.load_pcoder(*labels))._source
+        self._glob_path(*labels), coder=self.load_pcoder(*labels))
 
-  def sink(self, *labels):
+  def sink(self, labels):
     return self._writer_class(
-        self._path(*labels), coder=self.load_pcoder(*labels))._sink
+        self._path(*labels), coder=self.load_pcoder(*labels))
 
   def cleanup(self):
     if filesystems.FileSystems.exists(self._cache_dir):
@@ -229,8 +270,7 @@ class ReadCache(beam.PTransform):
 
   def expand(self, pbegin):
     # pylint: disable=expression-not-assigned
-    return pbegin | 'Read' >> beam.io.Read(
-        self._cache_manager.source('full', self._label))
+    return pbegin | 'Read' >> self._cache_manager.source('full', self._label)
 
 
 class WriteCache(beam.PTransform):
@@ -255,8 +295,7 @@ class WriteCache(beam.PTransform):
           combiners.Sample.FixedSizeGlobally(self._sample_size)
           | beam.FlatMap(lambda sample: sample))
     # pylint: disable=expression-not-assigned
-    return pcoll | 'Write' >> beam.io.Write(
-        self._cache_manager.sink(prefix, self._label))
+    return pcoll | 'Write' >> self._cache_manager.sink((prefix, self._label))
 
 
 class SafeFastPrimitivesCoder(coders.Coder):
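
A minimal usage sketch of the reworked FileBasedCacheManager API (illustrative only, not part of the commit; the 'full'/'my_pcoll' labels are hypothetical, and the write-then-read pattern mirrors mock_write_cache in cache_manager_test.py below):

    from apache_beam import coders
    from apache_beam.runners.interactive import cache_manager as cache

    cm = cache.FileBasedCacheManager()
    pcoder = coders.registry.get_coder(object)
    # Save a pcoder both for reading and for the shard file being written.
    cm.save_pcoder(pcoder, 'full', 'my_pcoll')
    cm.save_pcoder(pcoder, 'full', 'my_pcoll-1-of-2')
    cm.write(['a', 'b', 'c'], 'full', 'my_pcoll-1-of-2')
    reader, version = cm.read('full', 'my_pcoll')  # now returns (iterator, version)
    assert list(reader) == ['a', 'b', 'c']
    cm.cleanup()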
diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
index 6ab51b4..7868e90 100644
--- a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
+++ b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
@@ -71,12 +71,11 @@ class FileBasedCacheManagerTest(object):
 
     # Usually, the pcoder will be inferred from `pcoll.element_type`
     pcoder = coders.registry.get_coder(object)
+    # Save a pcoder for reading.
     self.cache_manager.save_pcoder(pcoder, *labels)
-    sink = self.cache_manager.sink(*labels)
-
-    with open(self.cache_manager._path(prefix, cache_file), 'wb') as f:
-      for line in pcoll_list:
-        sink.write_record(f, line)
+    # Save a pcoder for the fake write to the file.
+    self.cache_manager.save_pcoder(pcoder, prefix, cache_file)
+    self.cache_manager.write(pcoll_list, prefix, cache_file)
 
   def test_exists(self):
     """Test that CacheManager can correctly tell if the cache exists or not."""
@@ -99,7 +98,8 @@ class FileBasedCacheManagerTest(object):
     cache_version_one = ['cache', 'version', 'one']
 
     self.mock_write_cache(cache_version_one, prefix, cache_label)
-    pcoll_list, version = self.cache_manager.read(prefix, cache_label)
+    reader, version = self.cache_manager.read(prefix, cache_label)
+    pcoll_list = list(reader)
     self.assertListEqual(pcoll_list, cache_version_one)
     self.assertEqual(version, 0)
     self.assertTrue(
@@ -113,13 +113,15 @@ class FileBasedCacheManagerTest(object):
     cache_version_two = ['cache', 'version', 'two']
 
     self.mock_write_cache(cache_version_one, prefix, cache_label)
-    pcoll_list, version = self.cache_manager.read(prefix, cache_label)
+    reader, version = self.cache_manager.read(prefix, cache_label)
+    pcoll_list = list(reader)
 
     self.mock_write_cache(cache_version_two, prefix, cache_label)
     self.assertFalse(
         self.cache_manager.is_latest_version(version, prefix, cache_label))
 
-    pcoll_list, version = self.cache_manager.read(prefix, cache_label)
+    reader, version = self.cache_manager.read(prefix, cache_label)
+    pcoll_list = list(reader)
     self.assertListEqual(pcoll_list, cache_version_two)
     self.assertEqual(version, 1)
     self.assertTrue(
@@ -132,7 +134,8 @@ class FileBasedCacheManagerTest(object):
 
     self.assertFalse(self.cache_manager.exists(prefix, cache_label))
 
-    pcoll_list, version = self.cache_manager.read(prefix, cache_label)
+    reader, version = self.cache_manager.read(prefix, cache_label)
+    pcoll_list = list(reader)
     self.assertListEqual(pcoll_list, [])
     self.assertEqual(version, -1)
     self.assertTrue(
@@ -147,7 +150,8 @@ class FileBasedCacheManagerTest(object):
 
     # The initial write and read.
     self.mock_write_cache(cache_version_one, prefix, cache_label)
-    pcoll_list, version = self.cache_manager.read(prefix, cache_label)
+    reader, version = self.cache_manager.read(prefix, cache_label)
+    pcoll_list = list(reader)
 
     # Cache cleanup.
     self.cache_manager.cleanup()
@@ -155,7 +159,8 @@ class FileBasedCacheManagerTest(object):
     self.assertTrue(
         self.cache_manager.is_latest_version(version, prefix, cache_label))
 
-    pcoll_list, version = self.cache_manager.read(prefix, cache_label)
+    reader, version = self.cache_manager.read(prefix, cache_label)
+    pcoll_list = list(reader)
     self.assertListEqual(pcoll_list, [])
     self.assertEqual(version, -1)
     self.assertFalse(
@@ -166,7 +171,8 @@ class FileBasedCacheManagerTest(object):
     self.assertFalse(
         self.cache_manager.is_latest_version(version, prefix, cache_label))
 
-    pcoll_list, version = self.cache_manager.read(prefix, cache_label)
+    reader, version = self.cache_manager.read(prefix, cache_label)
+    pcoll_list = list(reader)
     self.assertListEqual(pcoll_list, cache_version_two)
     # Check that version continues from the previous value instead of starting
     # from 0 again.
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
index 0aabda3..c313404 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
@@ -19,15 +19,295 @@
 
 from __future__ import absolute_import
 
+import os
+import shutil
+import tempfile
+import time
+from collections import OrderedDict
+
+import apache_beam as beam
+from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileHeader
+from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileRecord
 from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
+from apache_beam.runners.interactive.cache_manager import CacheManager
+from apache_beam.runners.interactive.cache_manager import SafeFastPrimitivesCoder
+from apache_beam.testing.test_stream import OutputFormat
+from apache_beam.testing.test_stream import ReverseTestStream
 from apache_beam.utils import timestamp
 
+try:
+  from pathlib import Path
+except ImportError:
+  from pathlib2 import Path  # python 2 backport
+
+
+class StreamingCacheSink(beam.PTransform):
+  """A PTransform that writes TestStreamFile(Header|Records)s to file.
+
+  This transform takes in an arbitrary element stream and writes the list of
+  TestStream events (as TestStreamFileRecords) to file. When replayed, this
+  will produce a best-effort replay of the original job (e.g. some elements
+  may be produced slightly out of order from the original stream).
+
+  Note that this PTransform is assumed to run on a single machine where the
+  following assumptions hold: elements arrive in order, and no two transforms
+  write to the same file. This PTransform is therefore only expected to run
+  correctly with the DirectRunner.
+
+  TODO(BEAM-9447): Generalize this to more source/sink types aside from file
+  based. Also, generalize to cases where there might be multiple workers
+  writing to the same sink.
+  """
+  def __init__(
+      self,
+      cache_dir,
+      filename,
+      sample_resolution_sec,
+      coder=SafeFastPrimitivesCoder()):
+    self._cache_dir = cache_dir
+    self._filename = filename
+    self._sample_resolution_sec = sample_resolution_sec
+    self._coder = coder
+    self._path = os.path.join(self._cache_dir, self._filename)
+
+  @property
+  def path(self):
+    """Returns the path the sink leads to."""
+    return self._path
+
+  def expand(self, pcoll):
+    class StreamingWriteToText(beam.DoFn):
+      """DoFn that performs the writing.
+
+      Note that the other file writing methods cannot be used in streaming
+      contexts.
+      """
+      def __init__(self, full_path, coder=SafeFastPrimitivesCoder()):
+        self._full_path = full_path
+        self._coder = coder
+
+        # Create the directory for the file if it doesn't already exist.
+        Path(os.path.dirname(full_path)).mkdir(exist_ok=True)
+
+      def start_bundle(self):
+        # Open the file in append mode for writing bytes.
+        self._fh = open(self._full_path, 'ab')
+
+      def finish_bundle(self):
+        self._fh.close()
+
+      def process(self, e):
+        """Appends the given element to the file.
+        """
+        self._fh.write(self._coder.encode(e) + b'\n')
+
+    return (
+        pcoll
+        | ReverseTestStream(
+            output_tag=self._filename,
+            sample_resolution_sec=self._sample_resolution_sec,
+            output_format=OutputFormat.SERIALIZED_TEST_STREAM_FILE_RECORDS,
+            coder=self._coder)
+        | beam.ParDo(
+            StreamingWriteToText(full_path=self._path, coder=self._coder)))
 
-class StreamingCache(object):
+
+class StreamingCacheSource:
+  """A class that reads and parses TestStreamFile(Header|Reader)s.
+
+  This source operates in the following way:
+
+    1. Wait for up to `timeout_secs` for the file to be available.
+    2. Read, parse, and emit the entire contents of the file
+    3. Wait for more events to come or until `is_cache_complete` returns True
+    4. If there are more events, then go to 2
+    5. Otherwise, stop emitting.
+
+  This class is used to read from file and send its events to the TestStream
+  via the StreamingCache.Reader.
+  """
+  def __init__(
+      self,
+      cache_dir,
+      labels,
+      is_cache_complete=None,
+      coder=SafeFastPrimitivesCoder()):
+    self._cache_dir = cache_dir
+    self._coder = coder
+    self._labels = labels
+    self._is_cache_complete = (
+        is_cache_complete if is_cache_complete else lambda: True)
+
+  def _wait_until_file_exists(self, timeout_secs=30):
+    """Blocks until the file exists for a maximum of timeout_secs.
+    """
+    now_secs = time.time()
+    timeout_timestamp_secs = now_secs + timeout_secs
+
+    # Wait for up to `timeout_secs` for the file to be available.
+    start = time.time()
+    path = os.path.join(self._cache_dir, *self._labels)
+    while not os.path.exists(path):
+      time.sleep(1)
+      if time.time() > timeout_timestamp_secs:
+        raise RuntimeError(
+            "Timed out waiting for file '{}' to be available".format(path))
+    return open(path, mode='rb')
+
+  def _emit_from_file(self, fh, tail):
+    """Emits the TestStreamFile(Header|Record)s from file.
+
+    This returns a generator to be able to read all lines from the given file.
+    If `tail` is True, then it will wait until the cache is complete to exit.
+    Otherwise, it will read the file only once.
+    """
+    # Always read at least once to read the whole file.
+    while True:
+      pos = fh.tell()
+      line = fh.readline()
+
+      # Check if we are at EOF or if we have an incomplete line.
+      if not line or (line and line[-1] != b'\n'[0]):
+        # Complete reading only when the cache is complete.
+        if self._is_cache_complete():
+          break
+
+        if not tail:
+          break
+
+        # Otherwise wait for new data in the file to be written.
+        time.sleep(0.5)
+        fh.seek(pos)
+      else:
+        # The first line at pos = 0 is always the header. Read the line without
+        # the trailing newline.
+        to_decode = line[:-1]
+        if pos == 0:
+          header = TestStreamFileHeader()
+          header.ParseFromString(self._coder.decode(to_decode))
+          yield header
+        else:
+          record = TestStreamFileRecord()
+          record.ParseFromString(self._coder.decode(to_decode))
+          yield record
+
+  def read(self, tail):
+    """Reads all TestStreamFile(Header|TestStreamFileRecord)s from file.
+
+    This returns a generator to be able to read all lines from the given file.
+    If `tail` is True, then it will wait until the cache is complete to exit.
+    Otherwise, it will read the file only once.
+    """
+    with self._wait_until_file_exists() as f:
+      for e in self._emit_from_file(f, tail):
+        yield e
+
+
+class StreamingCache(CacheManager):
   """Abstraction that holds the logic for reading and writing to cache.
   """
-  def __init__(self, readers):
-    self._readers = readers
+  def __init__(
+      self, cache_dir, is_cache_complete=None, sample_resolution_sec=0.1):
+    self._sample_resolution_sec = sample_resolution_sec
+    self._is_cache_complete = is_cache_complete
+
+    if cache_dir:
+      self._cache_dir = cache_dir
+    else:
+      self._cache_dir = tempfile.mkdtemp(
+          prefix='interactive-temp-', dir=os.environ.get('TEST_TMPDIR', None))
+
+    # Dict of saved pcoders keyed by PCollection path. It is OK to keep this
+    # mapping in memory because once the cache manager object is
+    # destroyed/re-created it loses access to previously written cache
+    # objects anyway, even if cache_dir already exists. In other words,
+    # it is not possible to resume execution of a Beam pipeline from the
+    # saved cache if the cache manager has been reset.
+    #
+    # However, if we are to implement better cache persistence, care must be
+    # taken to keep the cached PCollection consistent with its
+    # PCoder type.
+    self._saved_pcoders = {}
+    self._default_pcoder = SafeFastPrimitivesCoder()
+
+  def exists(self, *labels):
+    path = os.path.join(self._cache_dir, *labels)
+    return os.path.exists(path)
+
+  # TODO(srohde): Modify this to return the correct version.
+  def read(self, *labels):
+    """Returns a generator to read all records from file.
+
+    Does not tail.
+    """
+    if not self.exists(*labels):
+      return iter([]), -1
+
+    reader = StreamingCacheSource(
+        self._cache_dir, labels,
+        is_cache_complete=self._is_cache_complete).read(tail=False)
+    header = next(reader)
+    return StreamingCache.Reader([header], [reader]).read(), 1
+
+  def read_multiple(self, labels):
+    """Returns a generator to read all records from file.
+
+    Does tail until the cache is complete. This is because it is used by the
+    TestStreamServiceController to read from file at pipeline runtime, which
+    needs to block.
+    """
+    readers = [
+        StreamingCacheSource(
+            self._cache_dir, l,
+            is_cache_complete=self._is_cache_complete).read(tail=True)
+        for l in labels
+    ]
+    headers = [next(r) for r in readers]
+    return StreamingCache.Reader(headers, readers).read()
+
+  def write(self, values, *labels):
+    """Writes the given values to cache.
+    """
+    directory = os.path.join(self._cache_dir, *labels[:-1])
+    filepath = os.path.join(directory, labels[-1])
+    if not os.path.exists(directory):
+      os.makedirs(directory)
+    with open(filepath, 'ab') as f:
+      for v in values:
+        f.write(self._default_pcoder.encode(v.SerializeToString()) + b'\n')
+
+  def source(self, *labels):
+    """Returns the StreamingCacheManager source.
+
+    This is beam.Impulse() because unbounded sources will be marked with this
+    and then the PipelineInstrument will replace these with a TestStream.
+    """
+    return beam.Impulse()
+
+  def sink(self, labels):
+    """Returns a StreamingCacheSink to write elements to file.
+
+    Note that this is assumed to only work with the DirectRunner, as the
+    underlying StreamingCacheSink relies on running on a single machine for
+    correct element ordering.
+    """
+    filename = labels[-1]
+    cache_dir = os.path.join(self._cache_dir, *labels[:-1])
+    return StreamingCacheSink(cache_dir, filename, self._sample_resolution_sec)
+
+  def save_pcoder(self, pcoder, *labels):
+    self._saved_pcoders[os.path.join(*labels)] = pcoder
+
+  def load_pcoder(self, *labels):
+    return (
+        self._default_pcoder if self._default_pcoder is not None else
+        self._saved_pcoders[os.path.join(*labels)])
+
+  def cleanup(self):
+    if os.path.exists(self._cache_dir):
+      shutil.rmtree(self._cache_dir)
+    self._saved_pcoders = {}
+    self._capture_sinks = {}
 
   class Reader(object):
     """Abstraction that reads from PCollection readers.
@@ -38,7 +318,7 @@ class StreamingCache(object):
     This class is also responsible for holding the state of the clock, injecting
     clock advancement events, and watermark advancement events.
     """
-    def __init__(self, readers):
+    def __init__(self, headers, readers):
       # This timestamp is used as the monotonic clock to order events in the
       # replay.
       self._monotonic_clock = timestamp.Timestamp.of(0)
@@ -49,8 +329,9 @@ class StreamingCache(object):
       # The file headers that are metadata for that particular PCollection.
       # The header allows for metadata about an entire stream, so that the data
       # isn't copied per record.
-      self._headers = {r.header().tag: r.header() for r in readers}
-      self._readers = {r.header().tag: r.read() for r in readers}
+      self._headers = {header.tag: header for header in headers}
+      self._readers = OrderedDict(
+          ((h.tag, r) for (h, r) in zip(headers, readers)))
 
       # The most recently read timestamp per tag.
       self._stream_times = {
@@ -141,13 +422,6 @@ class StreamingCache(object):
         unsent_events = events_to_send
         target_timestamp = self._min_timestamp_of(unsent_events)
 
-    def _add_element(self, element, tag):
-      """Constructs an AddElement event for the specified element and tag.
-      """
-      return TestStreamPayload.Event(
-          element_event=TestStreamPayload.Event.AddElements(
-              elements=[element], tag=tag))
-
     def _advance_processing_time(self, new_timestamp):
       """Advances the internal clock and returns an AdvanceProcessingTime event.
       """
@@ -157,6 +431,3 @@ class StreamingCache(object):
               advance_duration=advancy_by))
       self._monotonic_clock = new_timestamp
       return e
-
-  def reader(self):
-    return StreamingCache.Reader(self._readers)
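
A short sketch of the new StreamingCache surface (illustrative only, not part of the commit; 'my_key' is an arbitrary label, and the pattern mirrors streaming_cache_test.py below):

    from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
    from apache_beam.runners.interactive.testing.test_cache_manager import FileRecordsBuilder

    cache = StreamingCache(cache_dir=None)  # None -> a temporary directory is created
    records = (FileRecordsBuilder(tag='my_key')
               .add_element(element=0, event_time_secs=0)
               .advance_processing_time(1)
               .add_element(element=1, event_time_secs=1)
               .build())
    cache.write(records, 'my_key')    # appends encoded TestStreamFileRecords to file
    reader, _ = cache.read('my_key')  # replays the recorded events once, no tailing
    events = list(reader)             # a list of TestStreamPayload.Event protos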
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
index f9f450e..002a05a 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
@@ -22,12 +22,18 @@ from __future__ import absolute_import
 import unittest
 
 from apache_beam import coders
+from apache_beam.options.pipeline_options import DebugOptions
+from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileHeader
 from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileRecord
 from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
+from apache_beam.runners.interactive.cache_manager import SafeFastPrimitivesCoder
 from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
-from apache_beam.utils.timestamp import Duration
-from apache_beam.utils.timestamp import Timestamp
+from apache_beam.runners.interactive.testing.test_cache_manager import FileRecordsBuilder
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.test_stream import TestStream
+from apache_beam.testing.util import *
+from apache_beam.transforms.window import TimestampedValue
 
 # Nose automatically detects tests if they match a regex. Here, it mistakes
 # these protos for tests. For more info see the Nose docs at:
@@ -37,76 +43,45 @@ TestStreamFileHeader.__test__ = False  # type: ignore[attr-defined]
 TestStreamFileRecord.__test__ = False  # type: ignore[attr-defined]
 
 
-class InMemoryReader(object):
-  def __init__(self, tag=None):
-    self._header = TestStreamFileHeader(tag=tag)
-    self._records = []
-    self._coder = coders.FastPrimitivesCoder()
-
-  def add_element(self, element, event_time):
-    element_payload = TestStreamPayload.TimestampedElement(
-        encoded_element=self._coder.encode(element),
-        timestamp=Timestamp.of(event_time).micros)
-    record = TestStreamFileRecord(
-        recorded_event=TestStreamPayload.Event(
-            element_event=TestStreamPayload.Event.AddElements(
-                elements=[element_payload])))
-    self._records.append(record)
-
-  def advance_watermark(self, watermark):
-    record = TestStreamFileRecord(
-        recorded_event=TestStreamPayload.Event(
-            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
-                new_watermark=Timestamp.of(watermark).micros)))
-    self._records.append(record)
-
-  def advance_processing_time(self, processing_time_delta):
-    record = TestStreamFileRecord(
-        recorded_event=TestStreamPayload.Event(
-            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
-                advance_duration=Duration.of(processing_time_delta).micros)))
-    self._records.append(record)
-
-  def header(self):
-    return self._header
-
-  def read(self):
-    for r in self._records:
-      yield r
-
-
-def all_events(reader):
-  events = []
-  for e in reader.read():
-    events.append(e)
-  return events
-
-
 class StreamingCacheTest(unittest.TestCase):
   def setUp(self):
     pass
 
+  def test_exists(self):
+    cache = StreamingCache(cache_dir=None)
+    self.assertFalse(cache.exists('my_label'))
+    cache.write([TestStreamFileRecord()], 'my_label')
+    self.assertTrue(cache.exists('my_label'))
+
   def test_single_reader(self):
     """Tests that we expect to see all the correctly emitted TestStreamPayloads.
     """
-    in_memory_reader = InMemoryReader()
-    in_memory_reader.add_element(element=0, event_time=0)
-    in_memory_reader.advance_processing_time(1)
-    in_memory_reader.add_element(element=1, event_time=1)
-    in_memory_reader.advance_processing_time(1)
-    in_memory_reader.add_element(element=2, event_time=2)
-    cache = StreamingCache([in_memory_reader])
-    reader = cache.reader()
+    CACHED_PCOLLECTION_KEY = 'arbitrary_key'
+
+    values = (FileRecordsBuilder(tag=CACHED_PCOLLECTION_KEY)
+              .add_element(element=0, event_time_secs=0)
+              .advance_processing_time(1)
+              .add_element(element=1, event_time_secs=1)
+              .advance_processing_time(1)
+              .add_element(element=2, event_time_secs=2)
+              .build()) # yapf: disable
+
+    cache = StreamingCache(cache_dir=None)
+    cache.write(values, CACHED_PCOLLECTION_KEY)
+
+    reader, _ = cache.read(CACHED_PCOLLECTION_KEY)
     coder = coders.FastPrimitivesCoder()
-    events = all_events(reader)
+    events = list(reader)
 
+    # Units here are in microseconds.
     expected = [
         TestStreamPayload.Event(
             element_event=TestStreamPayload.Event.AddElements(
                 elements=[
                     TestStreamPayload.TimestampedElement(
                         encoded_element=coder.encode(0), timestamp=0)
-                ])),
+                ],
+                tag=CACHED_PCOLLECTION_KEY)),
         TestStreamPayload.Event(
             processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                 advance_duration=1 * 10**6)),
@@ -115,7 +90,8 @@ class StreamingCacheTest(unittest.TestCase):
                 elements=[
                     TestStreamPayload.TimestampedElement(
                         encoded_element=coder.encode(1), timestamp=1 * 10**6)
-                ])),
+                ],
+                tag=CACHED_PCOLLECTION_KEY)),
         TestStreamPayload.Event(
             processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                 advance_duration=1 * 10**6)),
@@ -124,38 +100,53 @@ class StreamingCacheTest(unittest.TestCase):
                 elements=[
                     TestStreamPayload.TimestampedElement(
                         encoded_element=coder.encode(2), timestamp=2 * 10**6)
-                ])),
+                ],
+                tag=CACHED_PCOLLECTION_KEY)),
     ]
     self.assertSequenceEqual(events, expected)
 
   def test_multiple_readers(self):
-    """Tests that the service advances the clock with multiple outputs."""
-
-    letters = InMemoryReader('letters')
-    letters.advance_processing_time(1)
-    letters.advance_watermark(0)
-    letters.add_element(element='a', event_time=0)
-    letters.advance_processing_time(10)
-    letters.advance_watermark(10)
-    letters.add_element(element='b', event_time=10)
-
-    numbers = InMemoryReader('numbers')
-    numbers.advance_processing_time(2)
-    numbers.add_element(element=1, event_time=0)
-    numbers.advance_processing_time(1)
-    numbers.add_element(element=2, event_time=0)
-    numbers.advance_processing_time(1)
-    numbers.add_element(element=2, event_time=0)
-
-    late = InMemoryReader('late')
-    late.advance_processing_time(101)
-    late.add_element(element='late', event_time=0)
-
-    cache = StreamingCache([letters, numbers, late])
-    reader = cache.reader()
+    """Tests that the service advances the clock with multiple outputs.
+    """
+
+    CACHED_LETTERS = 'letters'
+    CACHED_NUMBERS = 'numbers'
+    CACHED_LATE = 'late'
+
+    letters = (FileRecordsBuilder(CACHED_LETTERS)
+               .advance_processing_time(1)
+               .advance_watermark(watermark_secs=0)
+               .add_element(element='a', event_time_secs=0)
+               .advance_processing_time(10)
+               .advance_watermark(watermark_secs=10)
+               .add_element(element='b', event_time_secs=10)
+               .build()) # yapf: disable
+
+    numbers = (FileRecordsBuilder(CACHED_NUMBERS)
+               .advance_processing_time(2)
+               .add_element(element=1, event_time_secs=0)
+               .advance_processing_time(1)
+               .add_element(element=2, event_time_secs=0)
+               .advance_processing_time(1)
+               .add_element(element=2, event_time_secs=0)
+               .build()) # yapf: disable
+
+    late = (FileRecordsBuilder(CACHED_LATE)
+            .advance_processing_time(101)
+            .add_element(element='late', event_time_secs=0)
+            .build()) # yapf: disable
+
+    cache = StreamingCache(cache_dir=None)
+    cache.write(letters, CACHED_LETTERS)
+    cache.write(numbers, CACHED_NUMBERS)
+    cache.write(late, CACHED_LATE)
+
+    reader = cache.read_multiple([[CACHED_LETTERS], [CACHED_NUMBERS],
+                                  [CACHED_LATE]])
     coder = coders.FastPrimitivesCoder()
-    events = all_events(reader)
+    events = list(reader)
 
+    # Units here are in microseconds.
     expected = [
         # Advances clock from 0 to 1
         TestStreamPayload.Event(
@@ -163,14 +154,14 @@ class StreamingCacheTest(unittest.TestCase):
                 advance_duration=1 * 10**6)),
         TestStreamPayload.Event(
             watermark_event=TestStreamPayload.Event.AdvanceWatermark(
-                new_watermark=0, tag='letters')),
+                new_watermark=0, tag=CACHED_LETTERS)),
         TestStreamPayload.Event(
             element_event=TestStreamPayload.Event.AddElements(
                 elements=[
                     TestStreamPayload.TimestampedElement(
                         encoded_element=coder.encode('a'), timestamp=0)
                 ],
-                tag='letters')),
+                tag=CACHED_LETTERS)),
 
         # Advances clock from 1 to 2
         TestStreamPayload.Event(
@@ -182,7 +173,7 @@ class StreamingCacheTest(unittest.TestCase):
                     TestStreamPayload.TimestampedElement(
                         encoded_element=coder.encode(1), timestamp=0)
                 ],
-                tag='numbers')),
+                tag=CACHED_NUMBERS)),
 
         # Advances clock from 2 to 3
         TestStreamPayload.Event(
@@ -194,7 +185,7 @@ class StreamingCacheTest(unittest.TestCase):
                     TestStreamPayload.TimestampedElement(
                         encoded_element=coder.encode(2), timestamp=0)
                 ],
-                tag='numbers')),
+                tag=CACHED_NUMBERS)),
 
         # Advances clock from 3 to 4
         TestStreamPayload.Event(
@@ -206,7 +197,7 @@ class StreamingCacheTest(unittest.TestCase):
                     TestStreamPayload.TimestampedElement(
                         encoded_element=coder.encode(2), timestamp=0)
                 ],
-                tag='numbers')),
+                tag=CACHED_NUMBERS)),
 
         # Advances clock from 4 to 11
         TestStreamPayload.Event(
@@ -214,14 +205,14 @@ class StreamingCacheTest(unittest.TestCase):
                 advance_duration=7 * 10**6)),
         TestStreamPayload.Event(
             watermark_event=TestStreamPayload.Event.AdvanceWatermark(
-                new_watermark=10 * 10**6, tag='letters')),
+                new_watermark=10 * 10**6, tag=CACHED_LETTERS)),
         TestStreamPayload.Event(
             element_event=TestStreamPayload.Event.AddElements(
                 elements=[
                     TestStreamPayload.TimestampedElement(
                         encoded_element=coder.encode('b'), timestamp=10 * 10**6)
                 ],
-                tag='letters')),
+                tag=CACHED_LETTERS)),
 
         # Advances clock from 11 to 101
         TestStreamPayload.Event(
@@ -233,11 +224,174 @@ class StreamingCacheTest(unittest.TestCase):
                     TestStreamPayload.TimestampedElement(
                         encoded_element=coder.encode('late'), timestamp=0)
                 ],
-                tag='late')),
+                tag=CACHED_LATE)),
     ]
 
     self.assertSequenceEqual(events, expected)
 
+  def test_read_and_write(self):
+    """An integration test between the Sink and Source.
+
+    This ensures that the sink and source speak the same language in terms of
+    coders, protos, order, and units.
+    """
+
+    # Units here are in seconds.
+    test_stream = (TestStream()
+                   .advance_watermark_to(0, tag='records')
+                   .advance_processing_time(5)
+                   .add_elements(['a', 'b', 'c'], tag='records')
+                   .advance_watermark_to(10, tag='records')
+                   .advance_processing_time(1)
+                   .add_elements(
+                       [
+                           TimestampedValue('1', 15),
+                           TimestampedValue('2', 15),
+                           TimestampedValue('3', 15)
+                       ],
+                       tag='records')) # yapf: disable
+
+    coder = SafeFastPrimitivesCoder()
+    cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)
+
+    options = StandardOptions(streaming=True)
+    options.view_as(DebugOptions).add_experiment(
+        'passthrough_pcollection_output_ids')
+    with TestPipeline(options=options) as p:
+      # pylint: disable=expression-not-assigned
+      p | test_stream | cache.sink(['records'])
+
+    reader, _ = cache.read('records')
+    actual_events = list(reader)
+
+    # Units here are in microseconds.
+    expected_events = [
+        TestStreamPayload.Event(
+            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+                advance_duration=5 * 10**6)),
+        TestStreamPayload.Event(
+            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
+                new_watermark=0, tag='records')),
+        TestStreamPayload.Event(
+            element_event=TestStreamPayload.Event.AddElements(
+                elements=[
+                    TestStreamPayload.TimestampedElement(
+                        encoded_element=coder.encode('a'), timestamp=0),
+                    TestStreamPayload.TimestampedElement(
+                        encoded_element=coder.encode('b'), timestamp=0),
+                    TestStreamPayload.TimestampedElement(
+                        encoded_element=coder.encode('c'), timestamp=0),
+                ],
+                tag='records')),
+        TestStreamPayload.Event(
+            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+                advance_duration=1 * 10**6)),
+        TestStreamPayload.Event(
+            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
+                new_watermark=10 * 10**6, tag='records')),
+        TestStreamPayload.Event(
+            element_event=TestStreamPayload.Event.AddElements(
+                elements=[
+                    TestStreamPayload.TimestampedElement(
+                        encoded_element=coder.encode('1'), timestamp=15 *
+                        10**6),
+                    TestStreamPayload.TimestampedElement(
+                        encoded_element=coder.encode('2'), timestamp=15 *
+                        10**6),
+                    TestStreamPayload.TimestampedElement(
+                        encoded_element=coder.encode('3'), timestamp=15 *
+                        10**6),
+                ],
+                tag='records')),
+    ]
+    self.assertEqual(actual_events, expected_events)
+
+  def test_read_and_write_multiple_outputs(self):
+    """An integration test between the Sink and Source with multiple outputs.
+
+    This tests that the StreamingCache reads from multiple files and combines
+    them into a single sorted output.
+    """
+    LETTERS_TAG = 'letters'
+    NUMBERS_TAG = 'numbers'
+
+    # Units here are in seconds.
+    test_stream = (TestStream()
+                   .advance_watermark_to(0, tag=LETTERS_TAG)
+                   .advance_processing_time(5)
+                   .add_elements(['a', 'b', 'c'], tag=LETTERS_TAG)
+                   .advance_watermark_to(10, tag=NUMBERS_TAG)
+                   .advance_processing_time(1)
+                   .add_elements(
+                       [
+                           TimestampedValue('1', 15),
+                           TimestampedValue('2', 15),
+                           TimestampedValue('3', 15)
+                       ],
+                       tag=NUMBERS_TAG)) # yapf: disable
+
+    cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)
+
+    coder = SafeFastPrimitivesCoder()
+
+    options = StandardOptions(streaming=True)
+    options.view_as(DebugOptions).add_experiment(
+        'passthrough_pcollection_output_ids')
+    with TestPipeline(options=options) as p:
+      # pylint: disable=expression-not-assigned
+      events = p | test_stream
+      events[LETTERS_TAG] | 'Letters sink' >> cache.sink([LETTERS_TAG])
+      events[NUMBERS_TAG] | 'Numbers sink' >> cache.sink([NUMBERS_TAG])
+
+    reader = cache.read_multiple([[LETTERS_TAG], [NUMBERS_TAG]])
+    actual_events = list(reader)
+
+    # Units here are in microseconds.
+    expected_events = [
+        TestStreamPayload.Event(
+            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+                advance_duration=5 * 10**6)),
+        TestStreamPayload.Event(
+            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
+                new_watermark=0, tag=LETTERS_TAG)),
+        TestStreamPayload.Event(
+            element_event=TestStreamPayload.Event.AddElements(
+                elements=[
+                    TestStreamPayload.TimestampedElement(
+                        encoded_element=coder.encode('a'), timestamp=0),
+                    TestStreamPayload.TimestampedElement(
+                        encoded_element=coder.encode('b'), timestamp=0),
+                    TestStreamPayload.TimestampedElement(
+                        encoded_element=coder.encode('c'), timestamp=0),
+                ],
+                tag=LETTERS_TAG)),
+        TestStreamPayload.Event(
+            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+                advance_duration=1 * 10**6)),
+        TestStreamPayload.Event(
+            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
+                new_watermark=10 * 10**6, tag=NUMBERS_TAG)),
+        TestStreamPayload.Event(
+            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
+                new_watermark=0, tag=LETTERS_TAG)),
+        TestStreamPayload.Event(
+            element_event=TestStreamPayload.Event.AddElements(
+                elements=[
+                    TestStreamPayload.TimestampedElement(
+                        encoded_element=coder.encode('1'), timestamp=15 *
+                        10**6),
+                    TestStreamPayload.TimestampedElement(
+                        encoded_element=coder.encode('2'), timestamp=15 *
+                        10**6),
+                    TestStreamPayload.TimestampedElement(
+                        encoded_element=coder.encode('3'), timestamp=15 *
+                        10**6),
+                ],
+                tag=NUMBERS_TAG)),
+    ]
+
+    self.assertListEqual(actual_events, expected_events)
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/display/display_manager.py b/sdks/python/apache_beam/runners/interactive/display/display_manager.py
index 1b7af26..3c58ce2 100644
--- a/sdks/python/apache_beam/runners/interactive/display/display_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/display/display_manager.py
@@ -142,7 +142,7 @@ class DisplayManager(object):
         if force or not self._cache_manager.is_latest_version(
             version, 'sample', cache_label):
           pcoll_list, version = self._cache_manager.read('sample', cache_label)
-          stats['sample'] = pcoll_list
+          stats['sample'] = list(pcoll_list)
           stats['version'] = version
           stats_updated = True
 
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
index 7146067..3e1ad80 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
@@ -150,11 +150,11 @@ class InteractiveRunnerTest(unittest.TestCase):
     ib.watch(locals())
     result = p.run()
     self.assertTrue(init in ie.current_env().computed_pcollections)
-    self.assertEqual([0, 1, 2, 3, 4], result.get(init))
+    self.assertEqual([0, 1, 2, 3, 4], list(result.get(init)))
     self.assertTrue(square in ie.current_env().computed_pcollections)
-    self.assertEqual([0, 1, 4, 9, 16], result.get(square))
+    self.assertEqual([0, 1, 4, 9, 16], list(result.get(square)))
     self.assertTrue(cube in ie.current_env().computed_pcollections)
-    self.assertEqual([0, 1, 8, 27, 64], result.get(cube))
+    self.assertEqual([0, 1, 8, 27, 64], list(result.get(cube)))
 
 
 if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_fragment_test.py b/sdks/python/apache_beam/runners/interactive/pipeline_fragment_test.py
index cc5f5de..d8ab43c 100644
--- a/sdks/python/apache_beam/runners/interactive/pipeline_fragment_test.py
+++ b/sdks/python/apache_beam/runners/interactive/pipeline_fragment_test.py
@@ -115,7 +115,7 @@ class PipelineFragmentTest(unittest.TestCase):
 
     ib.watch(locals())
     result = pf.PipelineFragment([square]).run()
-    self.assertEqual([0, 1, 4, 9, 16], result.get(square))
+    self.assertEqual([0, 1, 4, 9, 16], list(result.get(square)))
 
 
 if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py b/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
index a936850..771f015 100644
--- a/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
+++ b/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
@@ -21,12 +21,10 @@
 from __future__ import absolute_import
 
 import tempfile
-import time
 import unittest
 
 import apache_beam as beam
 from apache_beam import coders
-from apache_beam.io import filesystems
 from apache_beam.pipeline import PipelineVisitor
 from apache_beam.runners.interactive import cache_manager as cache
 from apache_beam.runners.interactive import interactive_beam as ib
@@ -35,17 +33,12 @@ from apache_beam.runners.interactive import pipeline_instrument as instr
 from apache_beam.runners.interactive import interactive_runner
 from apache_beam.runners.interactive.testing.pipeline_assertion import assert_pipeline_equal
 from apache_beam.runners.interactive.testing.pipeline_assertion import assert_pipeline_proto_equal
-
-# Work around nose tests using Python2 without unittest.mock module.
-try:
-  from unittest.mock import MagicMock
-except ImportError:
-  from mock import MagicMock
+from apache_beam.runners.interactive.testing.test_cache_manager import InMemoryCache
 
 
 class PipelineInstrumentTest(unittest.TestCase):
   def setUp(self):
-    ie.new_env(cache_manager=cache.FileBasedCacheManager())
+    ie.new_env(cache_manager=InMemoryCache())
 
   def test_pcolls_to_pcoll_id(self):
     p = beam.Pipeline(interactive_runner.InteractiveRunner())
@@ -205,26 +198,12 @@ class PipelineInstrumentTest(unittest.TestCase):
 
   def _mock_write_cache(self, pcoll, cache_key):
     """Cache the PCollection where cache.WriteCache would write to."""
-    cache_path = filesystems.FileSystems.join(
-        ie.current_env().cache_manager()._cache_dir, 'full')
-    if not filesystems.FileSystems.exists(cache_path):
-      filesystems.FileSystems.mkdirs(cache_path)
-
-    # Pause for 0.1 sec, because the Jenkins test runs so fast that the file
-    # writes happen at the same timestamp.
-    time.sleep(0.1)
-
-    cache_file = cache_key + '-1-of-2'
     labels = ['full', cache_key]
 
     # Usually, the pcoder will be inferred from `pcoll.element_type`
     pcoder = coders.registry.get_coder(object)
     ie.current_env().cache_manager().save_pcoder(pcoder, *labels)
-    sink = ie.current_env().cache_manager().sink(*labels)
-
-    with open(ie.current_env().cache_manager()._path('full', cache_file),
-              'wb') as f:
-      sink.write_record(f, pcoll)
+    ie.current_env().cache_manager().write([b''], *labels)
 
   def test_instrument_example_pipeline_to_write_cache(self):
     # Original instance defined by user code has all variables handlers.
@@ -256,7 +235,6 @@ class PipelineInstrumentTest(unittest.TestCase):
     second_pcoll_cache_key = 'second_pcoll_' + str(
         id(second_pcoll)) + '_' + str(id(second_pcoll.producer))
     self._mock_write_cache(second_pcoll, second_pcoll_cache_key)
-    ie.current_env().cache_manager().exists = MagicMock(return_value=True)
     # Mark the completeness of PCollections from the original(user) pipeline.
     ie.current_env().mark_pcollection_computed(
         (p_origin, init_pcoll, second_pcoll))
diff --git a/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py b/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
new file mode 100644
index 0000000..f39f016
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import absolute_import
+
+import collections
+import itertools
+
+import apache_beam as beam
+from apache_beam import coders
+from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileHeader
+from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileRecord
+from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
+from apache_beam.runners.interactive.cache_manager import CacheManager
+from apache_beam.utils.timestamp import Duration
+from apache_beam.utils.timestamp import Timestamp
+
+
+class InMemoryCache(CacheManager):
+  """A cache that stores all PCollections in an in-memory map.
+
+  This is only used for checking the pipeline shape. It can't be used to run
+  the pipeline because the cache isn't shared between the SDK and the Runner.
+  """
+  def __init__(self):
+    self._cached = {}
+    self._pcoders = {}
+
+  def exists(self, *labels):
+    return self._key(*labels) in self._cached
+
+  def _latest_version(self, *labels):
+    return True
+
+  def read(self, *labels):
+    if not self.exists(*labels):
+      return itertools.chain([]), -1
+    ret = itertools.chain(self._cached[self._key(*labels)])
+    return ret, None
+
+  def write(self, value, *labels):
+    if not self.exists(*labels):
+      self._cached[self._key(*labels)] = []
+    self._cached[self._key(*labels)] += value
+
+  def save_pcoder(self, pcoder, *labels):
+    self._pcoders[self._key(*labels)] = pcoder
+
+  def load_pcoder(self, *labels):
+    return self._pcoders[self._key(*labels)]
+
+  def cleanup(self):
+    self._cached = collections.defaultdict(list)
+    self._pcoders = {}
+
+  def source(self, *labels):
+    vals = self._cached[self._key(*labels)]
+    return beam.Create(vals)
+
+  def sink(self, labels, is_capture=False):
+    return beam.Map(lambda _: _)
+
+  def _key(self, *labels):
+    return '/'.join([l for l in labels])
+
+
+class NoopSink(beam.PTransform):
+  def expand(self, pcoll):
+    return pcoll | beam.Map(lambda x: x)
+
+
+class FileRecordsBuilder(object):
+  def __init__(self, tag=None):
+    self._header = TestStreamFileHeader(tag=tag)
+    self._records = []
+    self._coder = coders.FastPrimitivesCoder()
+
+  def add_element(self, element, event_time_secs):
+    element_payload = TestStreamPayload.TimestampedElement(
+        encoded_element=self._coder.encode(element),
+        timestamp=Timestamp.of(event_time_secs).micros)
+    record = TestStreamFileRecord(
+        recorded_event=TestStreamPayload.Event(
+            element_event=TestStreamPayload.Event.AddElements(
+                elements=[element_payload])))
+    self._records.append(record)
+    return self
+
+  def advance_watermark(self, watermark_secs):
+    record = TestStreamFileRecord(
+        recorded_event=TestStreamPayload.Event(
+            watermark_event=TestStreamPayload.Event.AdvanceWatermark(
+                new_watermark=Timestamp.of(watermark_secs).micros)))
+    self._records.append(record)
+    return self
+
+  def advance_processing_time(self, delta_secs):
+    record = TestStreamFileRecord(
+        recorded_event=TestStreamPayload.Event(
+            processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+                advance_duration=Duration.of(delta_secs).micros)))
+    self._records.append(record)
+    return self
+
+  def build(self):
+    return [self._header] + self._records
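
Finally, a sketch of the new InMemoryCache helper standing in for a real cache manager in tests (illustrative only, not part of the commit; the labels are arbitrary):

    from apache_beam.runners.interactive.testing.test_cache_manager import InMemoryCache

    # InMemoryCache keeps everything in a dict, so it is only suitable for
    # inspecting pipeline shape, never for handing to a runner.
    cm = InMemoryCache()
    cm.write([1, 2, 3], 'full', 'my_pcoll')
    reader, _ = cm.read('full', 'my_pcoll')  # (iterator, version) like other caches
    assert list(reader) == [1, 2, 3]
    cm.cleanup()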