You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by pa...@apache.org on 2020/03/04 17:49:58 UTC

[beam] branch master updated: ReverseTestStream Implementation

This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 773528b  ReverseTestStream Implementation
     new 7b583e4  Merge pull request #10497 from [BEAM-8335] Add the ReverseTestStream
773528b is described below

commit 773528b083315daec52d8f1ea310401e00327ca6
Author: Sam Rohde <ro...@gmail.com>
AuthorDate: Wed Feb 19 14:47:17 2020 -0800

    ReverseTestStream Implementation
    
    Change-Id: Ie59b9483f4a36796efa203f811610c7fa6cc318c
---
 sdks/python/apache_beam/runners/direct/clock.py    |   8 +-
 .../apache_beam/runners/direct/test_stream_impl.py |   1 +
 .../runners/direct/transform_evaluator.py          |  67 +++-
 sdks/python/apache_beam/testing/test_stream.py     | 340 ++++++++++++++++-
 .../python/apache_beam/testing/test_stream_test.py | 409 +++++++++++++++++++--
 5 files changed, 785 insertions(+), 40 deletions(-)

diff --git a/sdks/python/apache_beam/runners/direct/clock.py b/sdks/python/apache_beam/runners/direct/clock.py
index 54b5701..e1c9b20 100644
--- a/sdks/python/apache_beam/runners/direct/clock.py
+++ b/sdks/python/apache_beam/runners/direct/clock.py
@@ -26,6 +26,8 @@ from __future__ import absolute_import
 import time
 from builtins import object
 
+from apache_beam.utils.timestamp import Timestamp
+
 
 class Clock(object):
   def time(self):
@@ -44,11 +46,11 @@ class RealClock(object):
 
 class TestClock(object):
   """Clock used for Testing"""
-  def __init__(self, current_time=0):
-    self._current_time = current_time
+  def __init__(self, current_time=None):
+    self._current_time = current_time if current_time else Timestamp()
 
   def time(self):
-    return self._current_time
+    return float(self._current_time)
 
   def advance_time(self, advance_by):
     self._current_time += advance_by
diff --git a/sdks/python/apache_beam/runners/direct/test_stream_impl.py b/sdks/python/apache_beam/runners/direct/test_stream_impl.py
index 8cee1fc..5f325d5 100644
--- a/sdks/python/apache_beam/runners/direct/test_stream_impl.py
+++ b/sdks/python/apache_beam/runners/direct/test_stream_impl.py
@@ -136,6 +136,7 @@ class _TestStream(PTransform):
     self.coder = coder
     self._raw_events = events
     self._events = self._add_watermark_advancements(output_tags, events)
+    self.output_tags = output_tags
 
   def _watermark_starts(self, output_tags):
     """Sentinel values to hold the watermark of outputs to -inf.
diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
index b1a7d83..acec67a 100644
--- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
@@ -57,7 +57,9 @@ from apache_beam.runners.direct.util import KeyedWorkItem
 from apache_beam.runners.direct.util import TransformResult
 from apache_beam.runners.direct.watermark_manager import WatermarkManager
 from apache_beam.testing.test_stream import ElementEvent
+from apache_beam.testing.test_stream import PairWithTiming
 from apache_beam.testing.test_stream import ProcessingTimeEvent
+from apache_beam.testing.test_stream import TimingInfo
 from apache_beam.testing.test_stream import WatermarkEvent
 from apache_beam.transforms import core
 from apache_beam.transforms.trigger import InMemoryUnmergedState
@@ -110,6 +112,7 @@ class TransformEvaluatorRegistry(object):
         _TestStream: _TestStreamEvaluator,
         ProcessElements: _ProcessElementsEvaluator,
         _WatermarkController: _WatermarkControllerEvaluator,
+        PairWithTiming: _PairWithTimingEvaluator,
     }  # type: Dict[Type[core.PTransform], Type[_TransformEvaluator]]
     self._evaluators.update(self._test_evaluators_overrides)
     self._root_bundle_providers = {
@@ -420,8 +423,12 @@ class _WatermarkControllerEvaluator(_TransformEvaluator):
       main_output = list(self._outputs)[0]
       bundle = self._evaluation_context.create_bundle(main_output)
       for tv in event.timestamped_values:
-        bundle.output(
-            GlobalWindows.windowed_value(tv.value, timestamp=tv.timestamp))
+        # Unreify the value into the correct window.
+        try:
+          bundle.output(WindowedValue(**tv.value))
+        except TypeError:
+          bundle.output(
+              GlobalWindows.windowed_value(tv.value, timestamp=tv.timestamp))
       self.bundles.append(bundle)
 
   def finish_bundle(self):
@@ -431,6 +438,45 @@ class _WatermarkControllerEvaluator(_TransformEvaluator):
         self, self.bundles, [], None, {None: self._watermark})
 
 
+class _PairWithTimingEvaluator(_TransformEvaluator):
+  """TransformEvaluator for the PairWithTiming transform.
+
+  This transform takes an element as an input and outputs
+  KV(element, `TimingInfo`), where the `TimingInfo` contains both the
+  processing time timestamp and the watermark.
+  """
+  def __init__(
+      self,
+      evaluation_context,
+      applied_ptransform,
+      input_committed_bundle,
+      side_inputs):
+    assert not side_inputs
+    super(_PairWithTimingEvaluator, self).__init__(
+        evaluation_context,
+        applied_ptransform,
+        input_committed_bundle,
+        side_inputs)
+
+  def start_bundle(self):
+    main_output = list(self._outputs)[0]
+    self.bundle = self._evaluation_context.create_bundle(main_output)
+
+    watermark_manager = self._evaluation_context._watermark_manager
+    watermarks = watermark_manager.get_watermarks(self._applied_ptransform)
+
+    output_watermark = watermarks.output_watermark
+    now = Timestamp(seconds=watermark_manager._clock.time())
+    self.timing_info = TimingInfo(now, output_watermark)
+
+  def process_element(self, element):
+    element.value = (element.value, self.timing_info)
+    self.bundle.output(element)
+
+  def finish_bundle(self):
+    return TransformResult(self, [self.bundle], [], None, {})
+
+
 class _TestStreamEvaluator(_TransformEvaluator):
   """TransformEvaluator for the TestStream transform.
 
@@ -448,12 +494,12 @@ class _TestStreamEvaluator(_TransformEvaluator):
       input_committed_bundle,
       side_inputs):
     assert not side_inputs
-    self.test_stream = applied_ptransform.transform
     super(_TestStreamEvaluator, self).__init__(
         evaluation_context,
         applied_ptransform,
         input_committed_bundle,
         side_inputs)
+    self.test_stream = applied_ptransform.transform
 
   def start_bundle(self):
     self.current_index = 0
@@ -470,7 +516,20 @@ class _TestStreamEvaluator(_TransformEvaluator):
     # We can either have the _TestStream or the _WatermarkController to emit
     # the elements. We chose to emit in the _WatermarkController so that the
     # element is emitted at the correct watermark value.
-    for event in self.test_stream.events(self.current_index):
+
+    # Set up the correct watermark holds in the Watermark controllers and the
+    # TestStream so that the watermarks will not automatically advance to +inf
+    # when elements start streaming. This can happen multiple times in the first
+    # bundle, but the operations are idempotent and adding state to keep track
+    # of this would add unnecessary code complexity.
+    events = []
+    if self.watermark == MIN_TIMESTAMP:
+      for event in self.test_stream._set_up(self.test_stream.output_tags):
+        events.append(event)
+
+    events += [e for e in self.test_stream.events(self.current_index)]
+
+    for event in events:
       if isinstance(event, (ElementEvent, WatermarkEvent)):
         # The WATERMARK_CONTROL_TAG is used to hold the _TestStream's
         # watermark to -inf, then +inf-1, then +inf. This watermark progression
diff --git a/sdks/python/apache_beam/testing/test_stream.py b/sdks/python/apache_beam/testing/test_stream.py
index 0f75be6..6d95ede 100644
--- a/sdks/python/apache_beam/testing/test_stream.py
+++ b/sdks/python/apache_beam/testing/test_stream.py
@@ -26,19 +26,30 @@ from __future__ import absolute_import
 from abc import ABCMeta
 from abc import abstractmethod
 from builtins import object
+from enum import Enum
 from functools import total_ordering
 
 from future.utils import with_metaclass
 
+import apache_beam as beam
 from apache_beam import coders
 from apache_beam import pvalue
 from apache_beam.portability import common_urns
 from apache_beam.portability.api import beam_runner_api_pb2
+from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileHeader
+from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileRecord
+from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
 from apache_beam.transforms import PTransform
 from apache_beam.transforms import core
 from apache_beam.transforms import window
+from apache_beam.transforms.timeutil import TimeDomain
+from apache_beam.transforms.userstate import TimerSpec
+from apache_beam.transforms.userstate import on_timer
 from apache_beam.transforms.window import TimestampedValue
 from apache_beam.utils import timestamp
+from apache_beam.utils.timestamp import MIN_TIMESTAMP
+from apache_beam.utils.timestamp import Duration
+from apache_beam.utils.timestamp import Timestamp
 from apache_beam.utils.windowed_value import WindowedValue
 
 __all__ = [
@@ -107,6 +118,9 @@ class ElementEvent(Event):
     self.tag = tag
 
   def __eq__(self, other):
+    if not isinstance(other, ElementEvent):
+      return False
+
     return (
         self.timestamped_values == other.timestamped_values and
         self.tag == other.tag)
@@ -115,6 +129,9 @@ class ElementEvent(Event):
     return hash(self.timestamped_values)
 
   def __lt__(self, other):
+    if not isinstance(other, ElementEvent):
+      raise TypeError
+
     return self.timestamped_values < other.timestamped_values
 
   def to_runner_api(self, element_coder):
@@ -129,6 +146,11 @@ class ElementEvent(Event):
             ],
             tag=tag))
 
+  def __repr__(self):
+    return 'ElementEvent: <{}, {}>'.format([(e.value, e.timestamp)
+                                            for e in self.timestamped_values],
+                                           self.tag)
+
 
 class WatermarkEvent(Event):
   """Watermark-advancing test stream event."""
@@ -137,12 +159,18 @@ class WatermarkEvent(Event):
     self.tag = tag
 
   def __eq__(self, other):
+    if not isinstance(other, WatermarkEvent):
+      return False
+
     return self.new_watermark == other.new_watermark and self.tag == other.tag
 
   def __hash__(self):
     return hash(str(self.new_watermark) + str(self.tag))
 
   def __lt__(self, other):
+    if not isinstance(other, WatermarkEvent):
+      raise TypeError
+
     return self.new_watermark < other.new_watermark
 
   def to_runner_api(self, unused_element_coder):
@@ -156,19 +184,28 @@ class WatermarkEvent(Event):
         AdvanceWatermark(
             new_watermark=self.new_watermark.micros // 1000, tag=tag))
 
+  def __repr__(self):
+    return 'WatermarkEvent: <{}, {}>'.format(self.new_watermark, self.tag)
+
 
 class ProcessingTimeEvent(Event):
   """Processing time-advancing test stream event."""
   def __init__(self, advance_by):
-    self.advance_by = timestamp.Duration.of(advance_by)
+    self.advance_by = Duration.of(advance_by)
 
   def __eq__(self, other):
+    if not isinstance(other, ProcessingTimeEvent):
+      return False
+
     return self.advance_by == other.advance_by
 
   def __hash__(self):
     return hash(self.advance_by)
 
   def __lt__(self, other):
+    if not isinstance(other, ProcessingTimeEvent):
+      raise TypeError
+
     return self.advance_by < other.advance_by
 
   def to_runner_api(self, unused_element_coder):
@@ -176,6 +213,9 @@ class ProcessingTimeEvent(Event):
         processing_time_event=beam_runner_api_pb2.TestStreamPayload.Event.
         AdvanceProcessingTime(advance_duration=self.advance_by.micros // 1000))
 
+  def __repr__(self):
+    return 'ProcessingTimeEvent: <{}>'.format(self.advance_by)
+
 
 class TestStream(PTransform):
   """Test stream that generates events on an unbounded PCollection of elements.
@@ -308,6 +348,8 @@ class TestStream(PTransform):
     return self
 
   def to_runner_api_parameter(self, context):
+    # Sort the output tags so that the order is deterministic and we are able
+    # to test equality on a roundtrip through the to/from proto apis.
     return (
         common_urns.primitives.TEST_STREAM.urn,
         beam_runner_api_pb2.TestStreamPayload(
@@ -326,3 +368,299 @@ class TestStream(PTransform):
         coder=coder,
         events=[Event.from_runner_api(e, coder) for e in payload.events],
         output_tags=output_tags)
+
+
+class TimingInfo(object):
+  def __init__(self, processing_time, watermark):
+    self._processing_time = timestamp.Timestamp.of(processing_time)
+    self._watermark = timestamp.Timestamp.of(watermark)
+
+  @property
+  def processing_time(self):
+    return self._processing_time
+
+  @property
+  def watermark(self):
+    return self._watermark
+
+  def __repr__(self):
+    return '({}, {})'.format(self.processing_time, self.watermark)
+
+
+class PairWithTiming(PTransform):
+  """Pairs the input element with timing information.
+
+  Input: element; output: KV(element, timing information)
+  Where timing information := (processing time, watermark)
+
+  This is used in the ReverseTestStream implementation to replay watermark
+  advancements.
+  """
+
+  URN = "beam:transform:pair_with_timing:v1"
+
+  def expand(self, pcoll):
+    return pvalue.PCollection.from_(pcoll)
+
+
+class OutputFormat(Enum):
+  TEST_STREAM_EVENTS = 1
+  TEST_STREAM_FILE_RECORDS = 2
+  SERIALIZED_TEST_STREAM_FILE_RECORDS = 3
+
+
+class ReverseTestStream(PTransform):
+  """A Transform that can create TestStream events from a stream of elements.
+
+  This currently assumes that the pipeline is being run on a single machine
+  and elements come in order and are outputted in the same order that they came
+  in.
+  """
+  def __init__(
+      self, sample_resolution_sec, output_tag, coder=None, output_format=None):
+    self._sample_resolution_sec = sample_resolution_sec
+    self._output_tag = output_tag
+    self._output_format = output_format if output_format \
+                          else OutputFormat.TEST_STREAM_EVENTS
+    self._coder = coder if coder else beam.coders.FastPrimitivesCoder()
+
+  def expand(self, pcoll):
+    ret = (
+        pcoll
+        | beam.WindowInto(beam.window.GlobalWindows())
+
+        # First get the initial timing information. This will be used to start
+        # the periodic timers which will generate processing time and watermark
+        # advancements every `sample_resolution_sec`.
+        | 'initial timing' >> PairWithTiming()
+
+        # Next, map every element to the same key so that only a single timer is
+        # started for this given ReverseTestStream.
+        | 'first key' >> beam.Map(lambda x: (0, x))
+
+        # Next, pass-through each element which will be paired with its timing
+        # info in the next step. Also, start the periodic timers. We use timers
+        # in this situation to capture watermark advancements that occur when
+        # there are no elements being produced upstream.
+        | beam.ParDo(
+            _TimingEventGenerator(
+                output_tag=self._output_tag,
+                sample_resolution_sec=self._sample_resolution_sec))
+
+        # Next, retrieve the timing information for watermark events that were
+        # generated in the previous step. This is because elements generated
+        # through the timers don't have their timing information yet.
+        | 'timing info for watermarks' >> PairWithTiming()
+
+        # Re-key to the same key to keep global state.
+        | 'second key' >> beam.Map(lambda x: (0, x))
+
+        # Format the events properly.
+        | beam.ParDo(_TestStreamFormatter(self._coder, self._output_format)))
+
+    if self._output_format == OutputFormat.SERIALIZED_TEST_STREAM_FILE_RECORDS:
+
+      def serializer(e):
+        return e.SerializeToString()
+
+      ret = ret | 'serializer' >> beam.Map(serializer)
+
+    return ret
+
+
+class _TimingEventGenerator(beam.DoFn):
+  """Generates ProcessingTimeEvents and WatermarkEvents at a regular cadence.
+
+  The runner keeps the state of the clock (which may be faked) and the
+  watermarks, which are inaccessible to SDKs. This DoFn generates
+  ProcessingTimeEvents and WatermarkEvents at a specified sampling rate to
+  capture any clock or watermark advancements between elements.
+  """
+
+  # Used to return the initial timing information.
+  EXECUTE_ONCE_STATE = beam.transforms.userstate.BagStateSpec(
+      name='execute_once_state', coder=beam.coders.FastPrimitivesCoder())
+
+  # A processing time timer in an infinite loop that generates the events that
+  # will be paired with the TimingInfo from the runner.
+  TIMING_SAMPLER = TimerSpec('timing_sampler', TimeDomain.REAL_TIME)
+
+  def __init__(self, output_tag, sample_resolution_sec=0.1):
+    self._output_tag = output_tag
+    self._sample_resolution_sec = sample_resolution_sec
+
+  @on_timer(TIMING_SAMPLER)
+  def on_timing_sampler(
+      self,
+      timestamp=beam.DoFn.TimestampParam,
+      window=beam.DoFn.WindowParam,
+      timing_sampler=beam.DoFn.TimerParam(TIMING_SAMPLER)):
+    """Yields an unbounded stream of ProcessingTimeEvents and WatermarkEvents.
+
+    The returned events will be paired with the TimingInfo. This loop's only
+    purpose is to generate these events even when there are no elements.
+    """
+    next_sample_time = (timestamp.micros * 1e-6) + self._sample_resolution_sec
+    timing_sampler.set(next_sample_time)
+
+    # Generate two events: the delta since the last sample and a placeholder
+    # WatermarkEvent. This is a placeholder because we can't otherwise add the
+    # watermark from the runner to the event.
+    yield ProcessingTimeEvent(self._sample_resolution_sec)
+    yield WatermarkEvent(MIN_TIMESTAMP)
+
+  def process(
+      self,
+      e,
+      timestamp=beam.DoFn.TimestampParam,
+      window=beam.DoFn.WindowParam,
+      timing_sampler=beam.DoFn.TimerParam(TIMING_SAMPLER),
+      execute_once_state=beam.DoFn.StateParam(EXECUTE_ONCE_STATE)):
+
+    _, (element, timing_info) = e
+
+    # Only set the timers once and only send the header once.
+    first_time = next(execute_once_state.read(), True)
+    if first_time:
+      # Generate the initial timing events.
+      execute_once_state.add(False)
+      now_sec = timing_info.processing_time.micros * 1e-6
+      timing_sampler.set(now_sec + self._sample_resolution_sec)
+
+      # Here we capture the initial time offset and initial watermark. This is
+      # where we emit the TestStreamFileHeader.
+      yield TestStreamFileHeader(tag=self._output_tag)
+      yield ProcessingTimeEvent(
+          Duration(micros=timing_info.processing_time.micros))
+      yield WatermarkEvent(MIN_TIMESTAMP)
+    yield element
+
+
+class _TestStreamFormatter(beam.DoFn):
+  """Formats the events to the specified output format.
+  """
+
+  # In order to generate the processing time deltas, we need to keep track of
+  # the previous clock time we got from the runner.
+  PREV_SAMPLE_TIME_STATE = beam.transforms.userstate.BagStateSpec(
+      name='prev_sample_time_state', coder=beam.coders.FastPrimitivesCoder())
+
+  def __init__(self, coder, output_format):
+    self._coder = coder
+    self._output_format = output_format
+
+  def start_bundle(self):
+    self.elements = []
+    self.timing_events = []
+    self.header = None
+
+  def finish_bundle(self):
+    """Outputs all the buffered elements.
+    """
+    if self._output_format == OutputFormat.TEST_STREAM_EVENTS:
+      return self._output_as_events()
+    return self._output_as_records()
+
+  def process(
+      self,
+      e,
+      timestamp=beam.DoFn.TimestampParam,
+      prev_sample_time_state=beam.DoFn.StateParam(PREV_SAMPLE_TIME_STATE)):
+    """Buffers elements until the end of the bundle.
+
+    This buffers elements instead of emitting them immediately so that elements
+    that arrive in the same bundle are also output in the same bundle.
+    """
+    _, (element, timing_info) = e
+
+    if isinstance(element, TestStreamFileHeader):
+      self.header = element
+    elif isinstance(element, WatermarkEvent):
+      # WatermarkEvents come in with a watermark of MIN_TIMESTAMP. Fill in the
+      # correct watermark from the runner here.
+      element.new_watermark = timing_info.watermark.micros
+      if element not in self.timing_events:
+        self.timing_events.append(element)
+
+    elif isinstance(element, ProcessingTimeEvent):
+      # Because the runner holds the clock, calculate the processing time delta
+      # here. The TestStream may have faked out the clock, and thus the
+      # delta calculated in the SDK with time.time() will be wrong.
+      prev_sample = next(prev_sample_time_state.read(), Timestamp())
+      prev_sample_time_state.clear()
+      prev_sample_time_state.add(timing_info.processing_time)
+
+      advance_by = timing_info.processing_time - prev_sample
+
+      element.advance_by = advance_by
+      self.timing_events.append(element)
+    else:
+      self.elements.append(TimestampedValue(element, timestamp))
+
+  def _output_as_events(self):
+    """Outputs buffered elements as TestStream events.
+    """
+    if self.timing_events:
+      yield WindowedValue(
+          self.timing_events, timestamp=0, windows=[beam.window.GlobalWindow()])
+
+    if self.elements:
+      yield WindowedValue([ElementEvent(self.elements)],
+                          timestamp=0,
+                          windows=[beam.window.GlobalWindow()])
+
+  def _output_as_records(self):
+    """Outputs buffered elements as TestStreamFileRecords.
+    """
+    if self.header:
+      yield WindowedValue(
+          self.header, timestamp=0, windows=[beam.window.GlobalWindow()])
+
+    if self.timing_events:
+      timing_events = self._timing_events_to_records(self.timing_events)
+      for r in timing_events:
+        yield WindowedValue(
+            r, timestamp=0, windows=[beam.window.GlobalWindow()])
+
+    if self.elements:
+      elements = self._elements_to_record(self.elements)
+      yield WindowedValue(
+          elements, timestamp=0, windows=[beam.window.GlobalWindow()])
+
+  def _timing_events_to_records(self, timing_events):
+    """Returns given timing_events as TestStreamFileRecords.
+    """
+    records = []
+    for e in self.timing_events:
+      if isinstance(e, ProcessingTimeEvent):
+        processing_time_event = TestStreamPayload.Event.AdvanceProcessingTime(
+            advance_duration=e.advance_by.micros)
+        records.append(
+            TestStreamFileRecord(
+                recorded_event=TestStreamPayload.Event(
+                    processing_time_event=processing_time_event)))
+
+      elif isinstance(e, WatermarkEvent):
+        watermark_event = TestStreamPayload.Event.AdvanceWatermark(
+            new_watermark=int(e.new_watermark))
+        records.append(
+            TestStreamFileRecord(
+                recorded_event=TestStreamPayload.Event(
+                    watermark_event=watermark_event)))
+
+    return records
+
+  def _elements_to_record(self, elements):
+    """Returns elements as TestStreamFileRecords.
+    """
+    elements = []
+    for tv in self.elements:
+      element_timestamp = tv.timestamp.micros
+      element = beam_runner_api_pb2.TestStreamPayload.TimestampedElement(
+          encoded_element=self._coder.encode(tv.value),
+          timestamp=element_timestamp)
+      elements.append(element)
+
+    element_event = TestStreamPayload.Event.AddElements(elements=elements)
+    return TestStreamFileRecord(
+        recorded_event=TestStreamPayload.Event(element_event=element_event))
diff --git a/sdks/python/apache_beam/testing/test_stream_test.py b/sdks/python/apache_beam/testing/test_stream_test.py
index 3ab7421..650953f 100644
--- a/sdks/python/apache_beam/testing/test_stream_test.py
+++ b/sdks/python/apache_beam/testing/test_stream_test.py
@@ -28,9 +28,14 @@ from apache_beam.options.pipeline_options import DebugOptions
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.portability import common_urns
+from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileHeader
+from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileRecord
+from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.test_stream import ElementEvent
+from apache_beam.testing.test_stream import OutputFormat
 from apache_beam.testing.test_stream import ProcessingTimeEvent
+from apache_beam.testing.test_stream import ReverseTestStream
 from apache_beam.testing.test_stream import TestStream
 from apache_beam.testing.test_stream import WatermarkEvent
 from apache_beam.testing.util import assert_that
@@ -42,16 +47,23 @@ from apache_beam.transforms.window import FixedWindows
 from apache_beam.transforms.window import TimestampedValue
 from apache_beam.utils import timestamp
 from apache_beam.utils.timestamp import Timestamp
+from apache_beam.utils.windowed_value import PaneInfo
+from apache_beam.utils.windowed_value import PaneInfoTiming
 from apache_beam.utils.windowed_value import WindowedValue
 
 
 class TestStreamTest(unittest.TestCase):
   def test_basic_test_stream(self):
-    test_stream = (
-        TestStream().advance_watermark_to(0).add_elements([
-            'a', WindowedValue('b', 3, []), TimestampedValue('c', 6)
-        ]).advance_processing_time(10).advance_watermark_to(8).add_elements(
-            ['d']).advance_watermark_to_infinity())
+    test_stream = (TestStream()
+                   .advance_watermark_to(0)
+                   .add_elements([
+                       'a',
+                       WindowedValue('b', 3, []),
+                       TimestampedValue('c', 6)])
+                   .advance_processing_time(10)
+                   .advance_watermark_to(8)
+                   .add_elements(['d'])
+                   .advance_watermark_to_infinity())  # yapf: disable
     self.assertEqual(
         test_stream._events,
         [
@@ -87,15 +99,17 @@ class TestStreamTest(unittest.TestCase):
               [TimestampedValue('a', timestamp.MAX_TIMESTAMP)]))
 
   def test_basic_execution(self):
-    test_stream = (
-        TestStream().advance_watermark_to(10).add_elements([
-            'a', 'b', 'c'
-        ]).advance_watermark_to(20).add_elements(['d']).add_elements([
-            'e'
-        ]).advance_processing_time(10).advance_watermark_to(300).add_elements([
-            TimestampedValue('late', 12)
-        ]).add_elements([TimestampedValue('last', 310)
-                         ]).advance_watermark_to_infinity())
+    test_stream = (TestStream()
+                   .advance_watermark_to(10)
+                   .add_elements(['a', 'b', 'c'])
+                   .advance_watermark_to(20)
+                   .add_elements(['d'])
+                   .add_elements(['e'])
+                   .advance_processing_time(10)
+                   .advance_watermark_to(300)
+                   .add_elements([TimestampedValue('late', 12)])
+                   .add_elements([TimestampedValue('last', 310)])
+                   .advance_watermark_to_infinity())  # yapf: disable
 
     class RecordFn(beam.DoFn):
       def process(
@@ -134,12 +148,11 @@ class TestStreamTest(unittest.TestCase):
         TimestampedValue('2', 12),
         TimestampedValue('3', 13),
     ]
-    test_stream = \
-        (TestStream()
-             .advance_watermark_to(5, tag='letters')
-             .add_elements(letters_elements, tag='letters')
-             .advance_watermark_to(10, tag='numbers')
-         .add_elements(numbers_elements, tag='numbers')) # yapf: disable
+    test_stream = (TestStream()
+        .advance_watermark_to(5, tag='letters')
+        .add_elements(letters_elements, tag='letters')
+        .advance_watermark_to(10, tag='numbers')
+        .add_elements(numbers_elements, tag='numbers'))  # yapf: disable
 
     class RecordFn(beam.DoFn):
       def process(
@@ -263,6 +276,67 @@ class TestStreamTest(unittest.TestCase):
 
     p.run()
 
+  def test_dicts_not_interpreted_as_windowed_values(self):
+    test_stream = (TestStream()
+                   .advance_processing_time(10)
+                   .advance_watermark_to(10)
+                   .add_elements([{'a': 0, 'b': 1, 'c': 2}])
+                   .advance_watermark_to_infinity())  # yapf: disable
+
+    class RecordFn(beam.DoFn):
+      def process(
+          self,
+          element=beam.DoFn.ElementParam,
+          timestamp=beam.DoFn.TimestampParam):
+        yield (element, timestamp)
+
+    options = PipelineOptions()
+    options.view_as(StandardOptions).streaming = True
+    with TestPipeline(options=options) as p:
+      my_record_fn = RecordFn()
+      records = p | test_stream | beam.ParDo(my_record_fn)
+
+      assert_that(
+          records,
+          equal_to([
+              ({
+                  'a': 0, 'b': 1, 'c': 2
+              }, timestamp.Timestamp(10)),
+          ]))
+
+  def test_windowed_values_interpreted_correctly(self):
+    windowed_value_args = {
+        'value': 'a',
+        'timestamp': Timestamp(5),
+        'windows': [beam.window.IntervalWindow(5, 10)],
+        'pane_info': PaneInfo(True, True, PaneInfoTiming.ON_TIME, 0, 0)
+    }
+    test_stream = (TestStream()
+                   .advance_processing_time(10)
+                   .advance_watermark_to(10)
+                   .add_elements([windowed_value_args])
+                   .advance_watermark_to_infinity())  # yapf: disable
+
+    class RecordFn(beam.DoFn):
+      def process(
+          self,
+          element=beam.DoFn.ElementParam,
+          timestamp=beam.DoFn.TimestampParam,
+          window=beam.DoFn.WindowParam):
+        yield (element, timestamp, window)
+
+    options = PipelineOptions()
+    options.view_as(StandardOptions).streaming = True
+    with TestPipeline(options=options) as p:
+      my_record_fn = RecordFn()
+      records = p | test_stream | beam.ParDo(my_record_fn)
+
+      assert_that(
+          records,
+          equal_to([
+              ('a', timestamp.Timestamp(5), beam.window.IntervalWindow(5, 10)),
+          ]))
+
   def test_gbk_execution_no_triggers(self):
     test_stream = (
         TestStream().advance_watermark_to(10).add_elements([
@@ -309,10 +383,12 @@ class TestStreamTest(unittest.TestCase):
     p.run()
 
   def test_gbk_execution_after_watermark_trigger(self):
-    test_stream = (
-        TestStream().advance_watermark_to(10).add_elements(
-            [TimestampedValue('a', 11)]).advance_watermark_to(20).add_elements(
-                [TimestampedValue('b', 21)]).advance_watermark_to_infinity())
+    test_stream = (TestStream()
+        .advance_watermark_to(10)
+        .add_elements([TimestampedValue('a', 11)])
+        .advance_watermark_to(20)
+        .add_elements([TimestampedValue('b', 21)])
+        .advance_watermark_to_infinity())  # yapf: disable
 
     options = PipelineOptions()
     options.view_as(StandardOptions).streaming = True
@@ -349,9 +425,11 @@ class TestStreamTest(unittest.TestCase):
     # Advance TestClock to (X + delta) and see the pipeline does finish
     # Possibly to the framework trigger_transcripts.yaml
 
-    test_stream = (
-        TestStream().advance_watermark_to(10).add_elements(
-            ['a']).advance_processing_time(5.1).advance_watermark_to_infinity())
+    test_stream = (TestStream()
+        .advance_watermark_to(10)
+        .add_elements(['a'])
+        .advance_processing_time(5.1)
+        .advance_watermark_to_infinity())  # yapf: disable
 
     options = PipelineOptions()
     options.view_as(StandardOptions).streaming = True
@@ -385,11 +463,11 @@ class TestStreamTest(unittest.TestCase):
     options.view_as(StandardOptions).streaming = True
     p = TestPipeline(options=options)
 
-    main_stream = (
-        p
-        |
-        'main TestStream' >> TestStream().advance_watermark_to(10).add_elements(
-            ['e']).advance_watermark_to_infinity())
+    main_stream = (p
+                   | 'main TestStream' >> TestStream()
+                   .advance_watermark_to(10)
+                   .add_elements(['e'])
+                   .advance_watermark_to_infinity())  # yapf: disable
     side = (
         p
         | beam.Create([2, 1, 4])
@@ -587,5 +665,272 @@ class TestStreamTest(unittest.TestCase):
     self.assertEqual(test_stream.coder, roundtrip_test_stream.coder)
 
 
+class ReverseTestStreamTest(unittest.TestCase):  # Tests for ReverseTestStream.
+  def test_basic_execution(self):  # TestStream -> ReverseTestStream yields events.
+    test_stream = (TestStream()
+                   .advance_watermark_to(0)
+                   .advance_processing_time(5)
+                   .add_elements(['a', 'b', 'c'])
+                   .advance_watermark_to(2)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(4)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(6)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(8)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(10)
+                   .advance_processing_time(1)
+                   .add_elements([TimestampedValue('1', 15),
+                                  TimestampedValue('2', 15),
+                                  TimestampedValue('3', 15)]))  # yapf: disable
+
+    options = StandardOptions(streaming=True)
+    p = TestPipeline(options=options)
+
+    records = (
+        p
+        | test_stream
+        | ReverseTestStream(sample_resolution_sec=1, output_tag=None))
+
+    assert_that(  # All output is expected in the global window.
+        records,
+        equal_to_per_window({
+            beam.window.GlobalWindow(): [
+                [ProcessingTimeEvent(5), WatermarkEvent(0)],
+                [
+                    ElementEvent([
+                        TimestampedValue('a', 0),  # Added without a timestamp.
+                        TimestampedValue('b', 0),
+                        TimestampedValue('c', 0)
+                    ])
+                ],
+                [ProcessingTimeEvent(1), WatermarkEvent(2000000)],  # Stream used 2; presumably micros.
+                [ProcessingTimeEvent(1), WatermarkEvent(4000000)],
+                [ProcessingTimeEvent(1), WatermarkEvent(6000000)],
+                [ProcessingTimeEvent(1), WatermarkEvent(8000000)],
+                [ProcessingTimeEvent(1), WatermarkEvent(10000000)],
+                [
+                    ElementEvent([
+                        TimestampedValue('1', 15),
+                        TimestampedValue('2', 15),
+                        TimestampedValue('3', 15)
+                    ])
+                ],
+            ],
+        }))
+
+    p.run()
+
+  def test_windowing(self):  # ReverseTestStream applied after FixedWindows + GBK.
+    test_stream = (TestStream()
+                   .advance_watermark_to(0)
+                   .add_elements(['a', 'b', 'c'])
+                   .advance_processing_time(1)
+                   .advance_processing_time(1)
+                   .advance_processing_time(1)
+                   .advance_processing_time(1)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(5)
+                   .add_elements(['1', '2', '3'])
+                   .advance_processing_time(1)
+                   .advance_watermark_to(6)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(7)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(8)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(9)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(10)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(11)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(12)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(13)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(14)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(15)
+                   .advance_processing_time(1)
+                   )  # yapf: disable
+
+    options = StandardOptions(streaming=True)
+    p = TestPipeline(options=options)
+
+    records = (
+        p
+        | test_stream
+        | 'letter windows' >> beam.WindowInto(
+            FixedWindows(5),
+            accumulation_mode=trigger.AccumulationMode.DISCARDING)
+        | 'letter with key' >> beam.Map(lambda x: ('k', x))  # Single key 'k'.
+        | 'letter gbk' >> beam.GroupByKey()
+        | ReverseTestStream(sample_resolution_sec=1, output_tag=None))
+
+    assert_that(  # Grouped lists are expected as single elements.
+        records,
+        equal_to_per_window({
+            beam.window.GlobalWindow(): [
+                [ProcessingTimeEvent(5), WatermarkEvent(4999998)],  # NOTE(review): origin of 4999998 not visible here.
+                [
+                    ElementEvent(
+                        [TimestampedValue(('k', ['a', 'b', 'c']), 4.999999)])  # Just below 5, the window size.
+                ],
+                [ProcessingTimeEvent(1), WatermarkEvent(5000000)],
+                [ProcessingTimeEvent(1), WatermarkEvent(6000000)],
+                [ProcessingTimeEvent(1), WatermarkEvent(7000000)],
+                [ProcessingTimeEvent(1), WatermarkEvent(8000000)],
+                [ProcessingTimeEvent(1), WatermarkEvent(9000000)],
+                [
+                    ElementEvent(
+                        [TimestampedValue(('k', ['1', '2', '3']), 9.999999)])  # Just below 10.
+                ],
+                [ProcessingTimeEvent(1), WatermarkEvent(10000000)],
+                [ProcessingTimeEvent(1), WatermarkEvent(11000000)],
+                [ProcessingTimeEvent(1), WatermarkEvent(12000000)],
+                [ProcessingTimeEvent(1), WatermarkEvent(13000000)],
+                [ProcessingTimeEvent(1), WatermarkEvent(14000000)],
+                [ProcessingTimeEvent(1), WatermarkEvent(15000000)],
+            ],
+        }))
+
+    p.run()
+
+  def test_basic_execution_in_records_format(self):  # Same stream as test_basic_execution; file-record output.
+    test_stream = (TestStream()
+                   .advance_watermark_to(0)
+                   .advance_processing_time(5)
+                   .add_elements(['a', 'b', 'c'])
+                   .advance_watermark_to(2)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(4)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(6)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(8)
+                   .advance_processing_time(1)
+                   .advance_watermark_to(10)
+                   .advance_processing_time(1)
+                   .add_elements([TimestampedValue('1', 15),
+                                  TimestampedValue('2', 15),
+                                  TimestampedValue('3', 15)]))  # yapf: disable
+
+    options = StandardOptions(streaming=True)
+    p = TestPipeline(options=options)
+
+    coder = beam.coders.FastPrimitivesCoder()  # Also encodes the expected elements below.
+    records = (
+        p
+        | test_stream
+        | ReverseTestStream(
+            sample_resolution_sec=1,
+            coder=coder,
+            output_format=OutputFormat.TEST_STREAM_FILE_RECORDS,
+            output_tag=None)
+        | 'stringify' >> beam.Map(str))  # Records are compared via str().
+
+    assert_that(
+        records,
+        equal_to_per_window({
+            beam.window.GlobalWindow(): [
+                str(TestStreamFileHeader()),  # Expected list starts with a header.
+                str(
+                    TestStreamFileRecord(
+                        recorded_event=TestStreamPayload.Event(
+                            processing_time_event=TestStreamPayload.Event.
+                            AdvanceProcessingTime(advance_duration=5000000)))),  # Stream used 5; presumably micros.
+                str(
+                    TestStreamFileRecord(
+                        recorded_event=TestStreamPayload.Event(
+                            watermark_event=TestStreamPayload.Event.
+                            AdvanceWatermark(new_watermark=0)))),
+                str(
+                    TestStreamFileRecord(
+                        recorded_event=TestStreamPayload.Event(
+                            element_event=TestStreamPayload.Event.AddElements(
+                                elements=[
+                                    TestStreamPayload.TimestampedElement(
+                                        encoded_element=coder.encode('a'),
+                                        timestamp=0),
+                                    TestStreamPayload.TimestampedElement(
+                                        encoded_element=coder.encode('b'),
+                                        timestamp=0),
+                                    TestStreamPayload.TimestampedElement(
+                                        encoded_element=coder.encode('c'),
+                                        timestamp=0),
+                                ])))),
+                str(
+                    TestStreamFileRecord(
+                        recorded_event=TestStreamPayload.Event(
+                            watermark_event=TestStreamPayload.Event.
+                            AdvanceWatermark(new_watermark=2000000)))),
+                str(
+                    TestStreamFileRecord(
+                        recorded_event=TestStreamPayload.Event(
+                            processing_time_event=TestStreamPayload.Event.
+                            AdvanceProcessingTime(advance_duration=1000000)))),
+                str(
+                    TestStreamFileRecord(
+                        recorded_event=TestStreamPayload.Event(
+                            watermark_event=TestStreamPayload.Event.
+                            AdvanceWatermark(new_watermark=4000000)))),
+                str(
+                    TestStreamFileRecord(
+                        recorded_event=TestStreamPayload.Event(
+                            processing_time_event=TestStreamPayload.Event.
+                            AdvanceProcessingTime(advance_duration=1000000)))),
+                str(
+                    TestStreamFileRecord(
+                        recorded_event=TestStreamPayload.Event(
+                            watermark_event=TestStreamPayload.Event.
+                            AdvanceWatermark(new_watermark=6000000)))),
+                str(
+                    TestStreamFileRecord(
+                        recorded_event=TestStreamPayload.Event(
+                            processing_time_event=TestStreamPayload.Event.
+                            AdvanceProcessingTime(advance_duration=1000000)))),
+                str(
+                    TestStreamFileRecord(
+                        recorded_event=TestStreamPayload.Event(
+                            watermark_event=TestStreamPayload.Event.
+                            AdvanceWatermark(new_watermark=8000000)))),
+                str(
+                    TestStreamFileRecord(
+                        recorded_event=TestStreamPayload.Event(
+                            processing_time_event=TestStreamPayload.Event.
+                            AdvanceProcessingTime(advance_duration=1000000)))),
+                str(
+                    TestStreamFileRecord(
+                        recorded_event=TestStreamPayload.Event(
+                            watermark_event=TestStreamPayload.Event.
+                            AdvanceWatermark(new_watermark=10000000)))),
+                str(
+                    TestStreamFileRecord(
+                        recorded_event=TestStreamPayload.Event(
+                            processing_time_event=TestStreamPayload.Event.
+                            AdvanceProcessingTime(advance_duration=1000000)))),
+                str(
+                    TestStreamFileRecord(
+                        recorded_event=TestStreamPayload.Event(
+                            element_event=TestStreamPayload.Event.AddElements(
+                                elements=[
+                                    TestStreamPayload.TimestampedElement(
+                                        encoded_element=coder.encode('1'),
+                                        timestamp=15000000),  # Stream used 15.
+                                    TestStreamPayload.TimestampedElement(
+                                        encoded_element=coder.encode('2'),
+                                        timestamp=15000000),
+                                    TestStreamPayload.TimestampedElement(
+                                        encoded_element=coder.encode('3'),
+                                        timestamp=15000000),
+                                ])))),
+            ],
+        }))
+
+    p.run()
+
+
 if __name__ == '__main__':
   unittest.main()