You are viewing a plain-text version of this content. The canonical link for it (a hyperlink in the original) is not preserved in this version.
Posted to commits@beam.apache.org by ro...@apache.org on 2017/12/06 21:31:01 UTC

[beam] branch master updated: [BEAM-1872] Add IdentityWindowFn for use in Reshuffle (#4040)

This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c1eeb2  [BEAM-1872] Add IdentityWindowFn for use in Reshuffle (#4040)
9c1eeb2 is described below

commit 9c1eeb2d7efc65978d189ebc259fad6fbe2176ce
Author: Udi Meiri <ud...@users.noreply.github.com>
AuthorDate: Wed Dec 6 13:30:58 2017 -0800

    [BEAM-1872] Add IdentityWindowFn for use in Reshuffle (#4040)
    
    * Implement Reshuffle for Python SDK.
    
    Two flavors of Reshuffle: ReshufflePerKey operates on key-value pairs,
    while Reshuffle adds a temporary random key to each element (key-value or
    other), applies ReshufflePerKey, and then removes the temporary key.
    
    Add _IdentityWindowFn, for internal use in Reshuffle.
    Add a window field to WindowFn.AssignContext and pass the current window
    to it, for use by the _IdentityWindowFn implementation.
    
    testing/util.py:
    - Extend assert_that with a reify_windows keyword, allowing verification of
    each element's timestamp and windows.
    - Add contains_in_any_order matcher.
---
 .../apache_beam/runners/worker/bundle_processor.py |   5 +-
 sdks/python/apache_beam/testing/util.py            |  41 +++-
 sdks/python/apache_beam/testing/util_test.py       |  42 ++++
 sdks/python/apache_beam/transforms/core.py         |   6 +-
 sdks/python/apache_beam/transforms/util.py         | 112 +++++++++-
 sdks/python/apache_beam/transforms/util_test.py    | 230 +++++++++++++++++++++
 sdks/python/apache_beam/transforms/window.py       |  12 +-
 sdks/python/apache_beam/utils/urns.py              |   5 +-
 8 files changed, 442 insertions(+), 11 deletions(-)

diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 0c46b81..94dca8b 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -529,9 +529,10 @@ def create(factory, transform_id, transform_proto, parameter, consumers):
     def __init__(self, windowing):
       self.windowing = windowing
 
-    def process(self, element, timestamp=beam.DoFn.TimestampParam):
+    def process(self, element, timestamp=beam.DoFn.TimestampParam,
+                window=beam.DoFn.WindowParam):
       new_windows = self.windowing.windowfn.assign(
-          WindowFn.AssignContext(timestamp, element=element))
+          WindowFn.AssignContext(timestamp, element=element, window=window))
       yield WindowedValue(element, timestamp, new_windows)
   from apache_beam.transforms.core import Windowing
   from apache_beam.transforms.window import WindowFn, WindowedValue
diff --git a/sdks/python/apache_beam/testing/util.py b/sdks/python/apache_beam/testing/util.py
index 34c15f9..2f18bde 100644
--- a/sdks/python/apache_beam/testing/util.py
+++ b/sdks/python/apache_beam/testing/util.py
@@ -19,13 +19,16 @@
 
 from __future__ import absolute_import
 
+import collections
 import glob
 import tempfile
 
 from apache_beam import pvalue
 from apache_beam.transforms import window
 from apache_beam.transforms.core import Create
+from apache_beam.transforms.core import DoFn
 from apache_beam.transforms.core import Map
+from apache_beam.transforms.core import ParDo
 from apache_beam.transforms.core import WindowInto
 from apache_beam.transforms.ptransform import PTransform
 from apache_beam.transforms.util import CoGroupByKey
@@ -37,6 +40,7 @@ __all__ = [
     'is_empty',
     # open_shards is internal and has no backwards compatibility guarantees.
     'open_shards',
+    'TestWindowedValue',
     ]
 
 
@@ -46,11 +50,32 @@ class BeamAssertException(Exception):
   pass
 
 
+# Used for reifying timestamps and windows for assert_that matchers.
+TestWindowedValue = collections.namedtuple(
+    'TestWindowedValue', 'value timestamp windows')
+
+
+def contains_in_any_order(iterable):
+  """Creates an object that matches another iterable if they both have the
+  same count of items.
+
+  Arguments:
+    iterable: An iterable of hashable objects.
+  """
+  class InAnyOrder(object):
+    def __init__(self, iterable):
+      self._counter = collections.Counter(iterable)
+
+    def __eq__(self, other):
+      return self._counter == collections.Counter(other)
+
+  return InAnyOrder(iterable)
+
+
 # Note that equal_to always sorts the expected and actual since what we
 # compare are PCollections for which there is no guaranteed order.
 # However the sorting does not go beyond top level therefore [1,2] and [2,1]
 # are considered equal and [[1,2]] and [[2,1]] are not.
-# TODO(silviuc): Add contains_in_any_order-style matchers.
 def equal_to(expected):
   expected = list(expected)
 
@@ -72,7 +97,7 @@ def is_empty():
   return _empty
 
 
-def assert_that(actual, matcher, label='assert_that'):
+def assert_that(actual, matcher, label='assert_that', reify_windows=False):
   """A PTransform that checks a PCollection has an expected value.
 
   Note that assert_that should be used only for testing pipelines since the
@@ -85,15 +110,27 @@ def assert_that(actual, matcher, label='assert_that'):
       expectations and raises BeamAssertException if they are not met.
     label: Optional string label. This is needed in case several assert_that
       transforms are introduced in the same pipeline.
+    reify_windows: If True, matcher is passed a list of TestWindowedValue.
 
   Returns:
     Ignored.
   """
   assert isinstance(actual, pvalue.PCollection)
 
+  class ReifyTimestampWindow(DoFn):
+    def process(self, element, timestamp=DoFn.TimestampParam,
+                window=DoFn.WindowParam):
+      # This returns TestWindowedValue instead of
+      # beam.utils.windowed_value.WindowedValue because ParDo will extract
+      # the timestamp and window out of the latter.
+      return [TestWindowedValue(element, timestamp, [window])]
+
   class AssertThat(PTransform):
 
     def expand(self, pcoll):
+      if reify_windows:
+        pcoll = pcoll | ParDo(ReifyTimestampWindow())
+
       # We must have at least a single element to ensure the matcher
       # code gets run even if the input pcollection is empty.
       keyed_singleton = pcoll.pipeline | Create([(None, None)])
diff --git a/sdks/python/apache_beam/testing/util_test.py b/sdks/python/apache_beam/testing/util_test.py
index 9d38693..e4e8694 100644
--- a/sdks/python/apache_beam/testing/util_test.py
+++ b/sdks/python/apache_beam/testing/util_test.py
@@ -21,9 +21,13 @@ import unittest
 
 from apache_beam import Create
 from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import TestWindowedValue
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 from apache_beam.testing.util import is_empty
+from apache_beam.transforms.window import GlobalWindow
+from apache_beam.transforms.window import IntervalWindow
+from apache_beam.utils.timestamp import MIN_TIMESTAMP
 
 
 class UtilTest(unittest.TestCase):
@@ -32,11 +36,49 @@ class UtilTest(unittest.TestCase):
     with TestPipeline() as p:
       assert_that(p | Create([1, 2, 3]), equal_to([1, 2, 3]))
 
+  def test_assert_that_passes_empty_equal_to(self):
+    with TestPipeline() as p:
+      assert_that(p | Create([]), equal_to([]))
+
+  def test_assert_that_passes_empty_is_empty(self):
+    with TestPipeline() as p:
+      assert_that(p | Create([]), is_empty())
+
+  def test_windowed_value_passes(self):
+    expected = [TestWindowedValue(v, MIN_TIMESTAMP, [GlobalWindow()])
+                for v in [1, 2, 3]]
+    with TestPipeline() as p:
+      assert_that(p | Create([2, 3, 1]), equal_to(expected), reify_windows=True)
+
   def test_assert_that_fails(self):
     with self.assertRaises(Exception):
       with TestPipeline() as p:
         assert_that(p | Create([1, 10, 100]), equal_to([1, 2, 3]))
 
+  def test_windowed_value_assert_fail_unmatched_value(self):
+    expected = [TestWindowedValue(v + 1, MIN_TIMESTAMP, [GlobalWindow()])
+                for v in [1, 2, 3]]
+    with self.assertRaises(Exception):
+      with TestPipeline() as p:
+        assert_that(p | Create([2, 3, 1]), equal_to(expected),
+                    reify_windows=True)
+
+  def test_windowed_value_assert_fail_unmatched_timestamp(self):
+    expected = [TestWindowedValue(v, 1, [GlobalWindow()])
+                for v in [1, 2, 3]]
+    with self.assertRaises(Exception):
+      with TestPipeline() as p:
+        assert_that(p | Create([2, 3, 1]), equal_to(expected),
+                    reify_windows=True)
+
+  def test_windowed_value_assert_fail_unmatched_window(self):
+    expected = [TestWindowedValue(v, MIN_TIMESTAMP, [IntervalWindow(0, 1)])
+                for v in [1, 2, 3]]
+    with self.assertRaises(Exception):
+      with TestPipeline() as p:
+        assert_that(p | Create([2, 3, 1]), equal_to(expected),
+                    reify_windows=True)
+
   def test_assert_that_fails_on_empty_input(self):
     with self.assertRaises(Exception):
       with TestPipeline() as p:
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index e650b39..533634d 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -1579,8 +1579,10 @@ class WindowInto(ParDo):
     def __init__(self, windowing):
       self.windowing = windowing
 
-    def process(self, element, timestamp=DoFn.TimestampParam):
-      context = WindowFn.AssignContext(timestamp, element=element)
+    def process(self, element, timestamp=DoFn.TimestampParam,
+                window=DoFn.WindowParam):
+      context = WindowFn.AssignContext(timestamp, element=element,
+                                       window=window)
       new_windows = self.windowing.windowfn.assign(context)
       yield WindowedValue(element, context.timestamp, new_windows)
 
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 85d4975..332387a 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -22,6 +22,7 @@ from __future__ import absolute_import
 
 import collections
 import contextlib
+import random
 import time
 
 from apache_beam import typehints
@@ -29,12 +30,20 @@ from apache_beam.metrics import Metrics
 from apache_beam.transforms import window
 from apache_beam.transforms.core import CombinePerKey
 from apache_beam.transforms.core import DoFn
+from apache_beam.transforms.core import FlatMap
 from apache_beam.transforms.core import Flatten
 from apache_beam.transforms.core import GroupByKey
 from apache_beam.transforms.core import Map
 from apache_beam.transforms.core import ParDo
+from apache_beam.transforms.core import WindowInto
 from apache_beam.transforms.ptransform import PTransform
 from apache_beam.transforms.ptransform import ptransform_fn
+from apache_beam.transforms.trigger import AccumulationMode
+from apache_beam.transforms.trigger import AfterCount
+from apache_beam.transforms.window import NonMergingWindowFn
+from apache_beam.transforms.window import TimestampCombiner
+from apache_beam.transforms.window import TimestampedValue
+from apache_beam.utils import urns
 from apache_beam.utils import windowed_value
 
 __all__ = [
@@ -43,10 +52,12 @@ __all__ = [
     'Keys',
     'KvSwap',
     'RemoveDuplicates',
+    'Reshuffle',
     'Values',
     ]
 
-
+K = typehints.TypeVariable('K')
+V = typehints.TypeVariable('V')
 T = typehints.TypeVariable('T')
 
 
@@ -423,3 +434,102 @@ class BatchElements(PTransform):
           self._batch_size_estimator))
     else:
       return pcoll | ParDo(_WindowAwareBatchingDoFn(self._batch_size_estimator))
+
+
+class _IdentityWindowFn(NonMergingWindowFn):
+  """Windowing function that preserves existing windows.
+
+  To be used internally with the Reshuffle transform.
+  Will raise an exception when used after DoFns that return TimestampedValue
+  elements.
+  """
+
+  def __init__(self, window_coder):
+    """Create a new WindowFn with compatible coder.
+    To be applied to PCollections with windows that are compatible with the
+    given coder.
+
+    Arguments:
+      window_coder: coders.Coder object to be used on windows.
+    """
+    super(_IdentityWindowFn, self).__init__()
+    if window_coder is None:
+      raise ValueError('window_coder should not be None')
+    self._window_coder = window_coder
+
+  def assign(self, assign_context):
+    if assign_context.window is None:
+      raise ValueError(
+          'assign_context.window should not be None. '
+          'This might be due to a DoFn returning a TimestampedValue.')
+    return [assign_context.window]
+
+  def get_window_coder(self):
+    return self._window_coder
+
+  def to_runner_api_parameter(self, unused_context):
+    pass  # Overridden by register_pickle_urn below.
+
+  urns.RunnerApiFn.register_pickle_urn(urns.RESHUFFLE_TRANSFORM)
+
+
+@typehints.with_input_types(typehints.KV[K, V])
+@typehints.with_output_types(typehints.KV[K, V])
+class ReshufflePerKey(PTransform):
+  """PTransform that returns a PCollection equivalent to its input,
+  but operationally provides some of the side effects of a GroupByKey,
+  in particular preventing fusion of the surrounding transforms,
+  checkpointing, and deduplication by id.
+
+  ReshufflePerKey is experimental. No backwards compatibility guarantees.
+  """
+
+  def expand(self, pcoll):
+    class ReifyTimestamps(DoFn):
+      def process(self, element, timestamp=DoFn.TimestampParam):
+        yield element[0], TimestampedValue(element[1], timestamp)
+
+    class RestoreTimestamps(DoFn):
+      def process(self, element, window=DoFn.WindowParam):
+        # Pass the current window since _IdentityWindowFn wouldn't know how
+        # to generate it.
+        yield windowed_value.WindowedValue(
+            (element[0], element[1].value), element[1].timestamp, [window])
+
+    windowing_saved = pcoll.windowing
+    result = (pcoll
+              | ParDo(ReifyTimestamps())
+              | 'IdentityWindow' >> WindowInto(
+                  _IdentityWindowFn(
+                      windowing_saved.windowfn.get_window_coder()),
+                  trigger=AfterCount(1),
+                  accumulation_mode=AccumulationMode.DISCARDING,
+                  timestamp_combiner=TimestampCombiner.OUTPUT_AT_EARLIEST,
+                  )
+              | GroupByKey()
+              | 'ExpandIterable' >> FlatMap(
+                  lambda e: [(e[0], value) for value in e[1]])
+              | ParDo(RestoreTimestamps()))
+    result._windowing = windowing_saved
+    return result
+
+
+@typehints.with_input_types(T)
+@typehints.with_output_types(T)
+class Reshuffle(PTransform):
+  """PTransform that returns a PCollection equivalent to its input,
+  but operationally provides some of the side effects of a GroupByKey,
+  in particular preventing fusion of the surrounding transforms,
+  checkpointing, and deduplication by id.
+
+  Reshuffle adds a temporary random key to each element, performs a
+  ReshufflePerKey, and finally removes the temporary key.
+
+  Reshuffle is experimental. No backwards compatibility guarantees.
+  """
+
+  def expand(self, pcoll):
+    return (pcoll
+            | 'AddRandomKeys' >> Map(lambda t: (random.getrandbits(32), t))
+            | ReshufflePerKey()
+            | 'RemoveRandomKeys' >> Map(lambda t: t[1]))
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index 6064e2c..0be4180 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -21,11 +21,24 @@ import time
 import unittest
 
 import apache_beam as beam
+from apache_beam.coders import coders
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import TestWindowedValue
 from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import contains_in_any_order
 from apache_beam.testing.util import equal_to
 from apache_beam.transforms import util
 from apache_beam.transforms import window
+from apache_beam.transforms.window import GlobalWindow
+from apache_beam.transforms.window import GlobalWindows
+from apache_beam.transforms.window import IntervalWindow
+from apache_beam.transforms.window import Sessions
+from apache_beam.transforms.window import SlidingWindows
+from apache_beam.transforms.window import TimestampedValue
+from apache_beam.utils import timestamp
+from apache_beam.utils.windowed_value import WindowedValue
 
 
 class FakeClock(object):
@@ -106,3 +119,220 @@ class BatchElementsTest(unittest.TestCase):
       with batch_estimator.record_time(actual_sizes[-1]):
         clock.sleep(batch_duration(actual_sizes[-1]))
     self.assertEqual(expected_sizes, actual_sizes)
+
+
+class IdentityWindowTest(unittest.TestCase):
+
+  def test_window_preserved(self):
+    expected_timestamp = timestamp.Timestamp(5)
+    expected_window = window.IntervalWindow(1.0, 2.0)
+
+    class AddWindowDoFn(beam.DoFn):
+      def process(self, element):
+        yield WindowedValue(
+            element, expected_timestamp, [expected_window])
+
+    pipeline = TestPipeline()
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+    expected_windows = [
+        TestWindowedValue(kv, expected_timestamp, [expected_window])
+        for kv in data]
+    before_identity = (pipeline
+                       | 'start' >> beam.Create(data)
+                       | 'add_windows' >> beam.ParDo(AddWindowDoFn()))
+    assert_that(before_identity, equal_to(expected_windows),
+                label='before_identity', reify_windows=True)
+    after_identity = (before_identity
+                      | 'window' >> beam.WindowInto(
+                          beam.transforms.util._IdentityWindowFn(
+                              coders.IntervalWindowCoder())))
+    assert_that(after_identity, equal_to(expected_windows),
+                label='after_identity', reify_windows=True)
+    pipeline.run()
+
+  def test_no_window_context_fails(self):
+    expected_timestamp = timestamp.Timestamp(5)
+    # Assuming the default window function is window.GlobalWindows.
+    expected_window = window.GlobalWindow()
+
+    class AddTimestampDoFn(beam.DoFn):
+      def process(self, element):
+        yield window.TimestampedValue(element, expected_timestamp)
+
+    pipeline = TestPipeline()
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+    expected_windows = [
+        TestWindowedValue(kv, expected_timestamp, [expected_window])
+        for kv in data]
+    before_identity = (pipeline
+                       | 'start' >> beam.Create(data)
+                       | 'add_timestamps' >> beam.ParDo(AddTimestampDoFn()))
+    assert_that(before_identity, equal_to(expected_windows),
+                label='before_identity', reify_windows=True)
+    after_identity = (before_identity
+                      | 'window' >> beam.WindowInto(
+                          beam.transforms.util._IdentityWindowFn(
+                              coders.GlobalWindowCoder()))
+                      # This DoFn will return TimestampedValues, making
+                      # WindowFn.AssignContext passed to IdentityWindowFn
+                      # contain a window of None. IdentityWindowFn should
+                      # raise an exception.
+                      | 'add_timestamps2' >> beam.ParDo(AddTimestampDoFn()))
+    assert_that(after_identity, equal_to(expected_windows),
+                label='after_identity', reify_windows=True)
+    with self.assertRaisesRegexp(ValueError, r'window.*None.*add_timestamps2'):
+      pipeline.run()
+
+
+class ReshuffleTest(unittest.TestCase):
+
+  def test_reshuffle_contents_unchanged(self):
+    pipeline = TestPipeline()
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
+    result = (pipeline
+              | 'start' >> beam.Create(data)
+              | 'reshuffle' >> beam.Reshuffle())
+    assert_that(result, equal_to(data))
+    pipeline.run()
+
+  def test_reshuffle_after_gbk_contents_unchanged(self):
+    pipeline = TestPipeline()
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
+    expected_result = [(1, [1, 2, 3]), (2, [1, 2]), (3, [1])]
+
+    after_gbk = (pipeline
+                 | 'start' >> beam.Create(data)
+                 | 'group_by_key' >> beam.GroupByKey())
+    assert_that(after_gbk, equal_to(expected_result), label='after_gbk')
+    after_reshuffle = (after_gbk
+                       | 'reshuffle' >> beam.Reshuffle())
+    assert_that(after_reshuffle, equal_to(expected_result),
+                label='after_reshuffle')
+    pipeline.run()
+
+  def test_reshuffle_timestamps_unchanged(self):
+    pipeline = TestPipeline()
+    timestamp = 5
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 3)]
+    expected_result = [TestWindowedValue(v, timestamp, [GlobalWindow()])
+                       for v in data]
+    before_reshuffle = (pipeline
+                        | 'start' >> beam.Create(data)
+                        | 'add_timestamp' >> beam.Map(
+                            lambda v: beam.window.TimestampedValue(v,
+                                                                   timestamp)))
+    assert_that(before_reshuffle, equal_to(expected_result),
+                label='before_reshuffle', reify_windows=True)
+    after_reshuffle = (before_reshuffle
+                       | 'reshuffle' >> beam.Reshuffle())
+    assert_that(after_reshuffle, equal_to(expected_result),
+                label='after_reshuffle', reify_windows=True)
+    pipeline.run()
+
+  def test_reshuffle_windows_unchanged(self):
+    pipeline = TestPipeline()
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+    expected_data = [TestWindowedValue(v, t, [w]) for (v, t, w) in
+                     [((1, [2, 1]), 4.0, IntervalWindow(1.0, 4.0)),
+                      ((2, [2, 1]), 4.0, IntervalWindow(1.0, 4.0)),
+                      ((3, [1]), 3.0, IntervalWindow(1.0, 3.0)),
+                      ((1, [4]), 6.0, IntervalWindow(4.0, 6.0))]]
+    before_reshuffle = (pipeline
+                        | 'start' >> beam.Create(data)
+                        | 'add_timestamp' >> beam.Map(
+                            lambda v: beam.window.TimestampedValue(v, v[1]))
+                        | 'window' >> beam.WindowInto(Sessions(gap_size=2))
+                        | 'group_by_key' >> beam.GroupByKey())
+    assert_that(before_reshuffle, equal_to(expected_data),
+                label='before_reshuffle', reify_windows=True)
+    after_reshuffle = (before_reshuffle
+                       | 'reshuffle' >> beam.Reshuffle())
+    assert_that(after_reshuffle, equal_to(expected_data),
+                label='after reshuffle', reify_windows=True)
+    pipeline.run()
+
+  def test_reshuffle_window_fn_preserved(self):
+    pipeline = TestPipeline()
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+    expected_windows = [TestWindowedValue(v, t, [w]) for (v, t, w) in [
+        ((1, 1), 1.0, IntervalWindow(1.0, 3.0)),
+        ((2, 1), 1.0, IntervalWindow(1.0, 3.0)),
+        ((3, 1), 1.0, IntervalWindow(1.0, 3.0)),
+        ((1, 2), 2.0, IntervalWindow(2.0, 4.0)),
+        ((2, 2), 2.0, IntervalWindow(2.0, 4.0)),
+        ((1, 4), 4.0, IntervalWindow(4.0, 6.0))]]
+    expected_merged_windows = [TestWindowedValue(v, t, [w]) for (v, t, w) in [
+        ((1, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
+        ((2, contains_in_any_order([2, 1])), 4.0, IntervalWindow(1.0, 4.0)),
+        ((3, [1]), 3.0, IntervalWindow(1.0, 3.0)),
+        ((1, [4]), 6.0, IntervalWindow(4.0, 6.0))]]
+    before_reshuffle = (pipeline
+                        | 'start' >> beam.Create(data)
+                        | 'add_timestamp' >> beam.Map(
+                            lambda v: TimestampedValue(v, v[1]))
+                        | 'window' >> beam.WindowInto(Sessions(gap_size=2)))
+    assert_that(before_reshuffle, equal_to(expected_windows),
+                label='before_reshuffle', reify_windows=True)
+    after_reshuffle = (before_reshuffle
+                       | 'reshuffle' >> beam.Reshuffle())
+    assert_that(after_reshuffle, equal_to(expected_windows),
+                label='after_reshuffle', reify_windows=True)
+    after_group = (after_reshuffle
+                   | 'group_by_key' >> beam.GroupByKey())
+    assert_that(after_group, equal_to(expected_merged_windows),
+                label='after_group', reify_windows=True)
+    pipeline.run()
+
+  def test_reshuffle_global_window(self):
+    pipeline = TestPipeline()
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+    expected_data = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])]
+    before_reshuffle = (pipeline
+                        | 'start' >> beam.Create(data)
+                        | 'window' >> beam.WindowInto(GlobalWindows())
+                        | 'group_by_key' >> beam.GroupByKey())
+    assert_that(before_reshuffle, equal_to(expected_data),
+                label='before_reshuffle')
+    after_reshuffle = (before_reshuffle
+                       | 'reshuffle' >> beam.Reshuffle())
+    assert_that(after_reshuffle, equal_to(expected_data),
+                label='after reshuffle')
+    pipeline.run()
+
+  def test_reshuffle_sliding_window(self):
+    pipeline = TestPipeline()
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+    window_size = 2
+    expected_data = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])] * window_size
+    before_reshuffle = (pipeline
+                        | 'start' >> beam.Create(data)
+                        | 'window' >> beam.WindowInto(SlidingWindows(
+                            size=window_size, period=1))
+                        | 'group_by_key' >> beam.GroupByKey())
+    assert_that(before_reshuffle, equal_to(expected_data),
+                label='before_reshuffle')
+    after_reshuffle = (before_reshuffle
+                       | 'reshuffle' >> beam.Reshuffle())
+    # If Reshuffle applies the sliding window function a second time there
+    # should be extra values for each key.
+    assert_that(after_reshuffle, equal_to(expected_data),
+                label='after reshuffle')
+    pipeline.run()
+
+  def test_reshuffle_streaming_global_window(self):
+    options = PipelineOptions()
+    options.view_as(StandardOptions).streaming = True
+    pipeline = TestPipeline(options=options)
+    data = [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (1, 4)]
+    expected_data = [(1, [1, 2, 4]), (2, [1, 2]), (3, [1])]
+    before_reshuffle = (pipeline
+                        | 'start' >> beam.Create(data)
+                        | 'window' >> beam.WindowInto(GlobalWindows())
+                        | 'group_by_key' >> beam.GroupByKey())
+    assert_that(before_reshuffle, equal_to(expected_data),
+                label='before_reshuffle')
+    after_reshuffle = (before_reshuffle
+                       | 'reshuffle' >> beam.Reshuffle())
+    assert_that(after_reshuffle, equal_to(expected_data),
+                label='after reshuffle')
+    pipeline.run()
diff --git a/sdks/python/apache_beam/transforms/window.py b/sdks/python/apache_beam/transforms/window.py
index 8c8bf33..ee9d6f9 100644
--- a/sdks/python/apache_beam/transforms/window.py
+++ b/sdks/python/apache_beam/transforms/window.py
@@ -114,13 +114,21 @@ class WindowFn(urns.RunnerApiFn):
   class AssignContext(object):
     """Context passed to WindowFn.assign()."""
 
-    def __init__(self, timestamp, element=None):
+    def __init__(self, timestamp, element=None, window=None):
       self.timestamp = Timestamp.of(timestamp)
       self.element = element
+      self.window = window
 
   @abc.abstractmethod
   def assign(self, assign_context):
-    """Associates a timestamp to an element."""
+    """Associates windows to an element.
+
+    Arguments:
+      assign_context: Instance of AssignContext.
+
+    Returns:
+      An iterable of BoundedWindow.
+    """
     raise NotImplementedError
 
   class MergeContext(object):
diff --git a/sdks/python/apache_beam/utils/urns.py b/sdks/python/apache_beam/utils/urns.py
index 1359f32..387c8d6 100644
--- a/sdks/python/apache_beam/utils/urns.py
+++ b/sdks/python/apache_beam/utils/urns.py
@@ -44,6 +44,7 @@ COMBINE_PER_KEY_TRANSFORM = "beam:ptransform:combine_per_key:v0.1"
 COMBINE_GROUPED_VALUES_TRANSFORM = "beam:ptransform:combine_grouped_values:v0.1"
 FLATTEN_TRANSFORM = "beam:ptransform:flatten:v0.1"
 READ_TRANSFORM = "beam:ptransform:read:v0.1"
+RESHUFFLE_TRANSFORM = "beam:ptransform:reshuffle:v0.1"
 WINDOW_INTO_TRANSFORM = "beam:ptransform:window_into:v0.1"
 
 PICKLED_SOURCE = "beam:source:pickled_python:v0.1"
@@ -90,9 +91,9 @@ class RunnerApiFn(object):
 
   @classmethod
   def register_urn(cls, urn, parameter_type, fn=None):
-    """Registeres a urn with a constructor.
+    """Registers a urn with a constructor.
 
-    For example, if 'beam:fn:foo' had paramter type FooPayload, one could
+    For example, if 'beam:fn:foo' had parameter type FooPayload, one could
     write `RunnerApiFn.register_urn('bean:fn:foo', FooPayload, foo_from_proto)`
     where foo_from_proto took as arguments a FooPayload and a PipelineContext.
     This function can also be used as a decorator rather than passing the

-- 
To stop receiving notification emails like this one, please contact
commits@beam.apache.org.