You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ke...@apache.org on 2017/04/21 17:52:53 UTC

[09/50] [abbrv] beam git commit: [BEAM-1994] Remove Flink examples package

http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java
new file mode 100644
index 0000000..4c826d1
--- /dev/null
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java
@@ -0,0 +1,600 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import static org.hamcrest.Matchers.emptyIterable;
+import static org.hamcrest.collection.IsIterableContainingInOrder.contains;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThat;
+
+import com.google.common.base.Function;
+import com.google.common.base.Predicate;
+import com.google.common.collect.FluentIterable;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import java.util.Collections;
+import java.util.HashMap;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.core.StatefulDoFnRunner;
+import org.apache.beam.runners.flink.FlinkPipelineOptions;
+import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.testing.PCollectionViewTesting;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.join.RawUnionValue;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.TimeDomain;
+import org.apache.beam.sdk.util.Timer;
+import org.apache.beam.sdk.util.TimerSpec;
+import org.apache.beam.sdk.util.TimerSpecs;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowingStrategy;
+import org.apache.beam.sdk.util.state.StateSpec;
+import org.apache.beam.sdk.util.state.StateSpecs;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
+import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness;
+import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
+import org.apache.flink.streaming.util.TwoInputStreamOperatorTestHarness;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link DoFnOperator}.
+ */
+@RunWith(JUnit4.class)
+public class DoFnOperatorTest {
+
+  // views and windows for testing side inputs
+  private static final long WINDOW_MSECS_1 = 100;
+  private static final long WINDOW_MSECS_2 = 500;
+
+  private WindowingStrategy<Object, IntervalWindow> windowingStrategy1 =
+      WindowingStrategy.of(FixedWindows.of(new Duration(WINDOW_MSECS_1)));
+
+  private PCollectionView<Iterable<String>> view1 =
+      PCollectionViewTesting.testingView(
+          new TupleTag<Iterable<WindowedValue<String>>>() {},
+          new PCollectionViewTesting.IdentityViewFn<String>(),
+          StringUtf8Coder.of(),
+          windowingStrategy1);
+
+  private WindowingStrategy<Object, IntervalWindow> windowingStrategy2 =
+      WindowingStrategy.of(FixedWindows.of(new Duration(WINDOW_MSECS_2)));
+
+  private PCollectionView<Iterable<String>> view2 =
+      PCollectionViewTesting.testingView(
+          new TupleTag<Iterable<WindowedValue<String>>>() {},
+          new PCollectionViewTesting.IdentityViewFn<String>(),
+          StringUtf8Coder.of(),
+          windowingStrategy2);
+
+  @Test
+  @SuppressWarnings("unchecked")
+  public void testSingleOutput() throws Exception {
+
+    WindowedValue.ValueOnlyWindowedValueCoder<String> windowedValueCoder =
+        WindowedValue.getValueOnlyCoder(StringUtf8Coder.of());
+
+    TupleTag<String> outputTag = new TupleTag<>("main-output");
+
+    DoFnOperator<String, String, String> doFnOperator = new DoFnOperator<>(
+        new IdentityDoFn<String>(),
+        windowedValueCoder,
+        outputTag,
+        Collections.<TupleTag<?>>emptyList(),
+        new DoFnOperator.DefaultOutputManagerFactory(),
+        WindowingStrategy.globalDefault(),
+        new HashMap<Integer, PCollectionView<?>>(), /* side-input mapping */
+        Collections.<PCollectionView<?>>emptyList(), /* side inputs */
+        PipelineOptionsFactory.as(FlinkPipelineOptions.class),
+        null);
+
+    OneInputStreamOperatorTestHarness<WindowedValue<String>, String> testHarness =
+        new OneInputStreamOperatorTestHarness<>(doFnOperator);
+
+    testHarness.open();
+
+    testHarness.processElement(new StreamRecord<>(WindowedValue.valueInGlobalWindow("Hello")));
+
+    assertThat(
+        this.<String>stripStreamRecordFromWindowedValue(testHarness.getOutput()),
+        contains(WindowedValue.valueInGlobalWindow("Hello")));
+
+    testHarness.close();
+  }
+
+  @Test
+  @SuppressWarnings("unchecked")
+  public void testMultiOutputOutput() throws Exception {
+
+    WindowedValue.ValueOnlyWindowedValueCoder<String> windowedValueCoder =
+        WindowedValue.getValueOnlyCoder(StringUtf8Coder.of());
+
+    TupleTag<String> mainOutput = new TupleTag<>("main-output");
+    TupleTag<String> additionalOutput1 = new TupleTag<>("output-1");
+    TupleTag<String> additionalOutput2 = new TupleTag<>("output-2");
+    ImmutableMap<TupleTag<?>, Integer> outputMapping = ImmutableMap.<TupleTag<?>, Integer>builder()
+        .put(mainOutput, 1)
+        .put(additionalOutput1, 2)
+        .put(additionalOutput2, 3)
+        .build();
+
+    DoFnOperator<String, String, RawUnionValue> doFnOperator = new DoFnOperator<>(
+        new MultiOutputDoFn(additionalOutput1, additionalOutput2),
+        windowedValueCoder,
+        mainOutput,
+        ImmutableList.<TupleTag<?>>of(additionalOutput1, additionalOutput2),
+        new DoFnOperator.MultiOutputOutputManagerFactory(outputMapping),
+        WindowingStrategy.globalDefault(),
+        new HashMap<Integer, PCollectionView<?>>(), /* side-input mapping */
+        Collections.<PCollectionView<?>>emptyList(), /* side inputs */
+        PipelineOptionsFactory.as(FlinkPipelineOptions.class),
+        null);
+
+    OneInputStreamOperatorTestHarness<WindowedValue<String>, RawUnionValue> testHarness =
+        new OneInputStreamOperatorTestHarness<>(doFnOperator);
+
+    testHarness.open();
+
+    testHarness.processElement(new StreamRecord<>(WindowedValue.valueInGlobalWindow("one")));
+    testHarness.processElement(new StreamRecord<>(WindowedValue.valueInGlobalWindow("two")));
+    testHarness.processElement(new StreamRecord<>(WindowedValue.valueInGlobalWindow("hello")));
+
+    assertThat(
+        this.stripStreamRecordFromRawUnion(testHarness.getOutput()),
+        contains(
+            new RawUnionValue(2, WindowedValue.valueInGlobalWindow("extra: one")),
+            new RawUnionValue(3, WindowedValue.valueInGlobalWindow("extra: two")),
+            new RawUnionValue(1, WindowedValue.valueInGlobalWindow("got: hello")),
+            new RawUnionValue(2, WindowedValue.valueInGlobalWindow("got: hello")),
+            new RawUnionValue(3, WindowedValue.valueInGlobalWindow("got: hello"))));
+
+    testHarness.close();
+  }
+
+  @Test
+  public void testLateDroppingForStatefulFn() throws Exception {
+
+    WindowingStrategy<Object, IntervalWindow> windowingStrategy =
+        WindowingStrategy.of(FixedWindows.of(new Duration(10)));
+
+    DoFn<Integer, String> fn = new DoFn<Integer, String>() {
+
+      @StateId("state")
+      private final StateSpec<Object, ValueState<String>> stateSpec =
+          StateSpecs.value(StringUtf8Coder.of());
+
+      @ProcessElement
+      public void processElement(ProcessContext context) {
+        context.output(context.element().toString());
+      }
+    };
+
+    WindowedValue.FullWindowedValueCoder<Integer> windowedValueCoder =
+        WindowedValue.getFullCoder(
+            VarIntCoder.of(),
+            windowingStrategy.getWindowFn().windowCoder());
+
+    TupleTag<String> outputTag = new TupleTag<>("main-output");
+
+    DoFnOperator<Integer, String, WindowedValue<String>> doFnOperator = new DoFnOperator<>(
+        fn,
+        windowedValueCoder,
+        outputTag,
+        Collections.<TupleTag<?>>emptyList(),
+        new DoFnOperator.DefaultOutputManagerFactory<WindowedValue<String>>(),
+        windowingStrategy,
+        new HashMap<Integer, PCollectionView<?>>(), /* side-input mapping */
+        Collections.<PCollectionView<?>>emptyList(), /* side inputs */
+        PipelineOptionsFactory.as(FlinkPipelineOptions.class),
+        VarIntCoder.of() /* key coder */);
+
+    OneInputStreamOperatorTestHarness<WindowedValue<Integer>, WindowedValue<String>> testHarness =
+        new KeyedOneInputStreamOperatorTestHarness<>(
+            doFnOperator,
+            new KeySelector<WindowedValue<Integer>, Integer>() {
+              @Override
+              public Integer getKey(WindowedValue<Integer> integerWindowedValue) throws Exception {
+                return integerWindowedValue.getValue();
+              }
+            },
+            new CoderTypeInformation<>(VarIntCoder.of()));
+
+    testHarness.open();
+
+    testHarness.processWatermark(0);
+
+    IntervalWindow window1 = new IntervalWindow(new Instant(0), Duration.millis(10));
+
+    // this should not be late
+    testHarness.processElement(
+        new StreamRecord<>(WindowedValue.of(13, new Instant(0), window1, PaneInfo.NO_FIRING)));
+
+    assertThat(
+        this.<String>stripStreamRecordFromWindowedValue(testHarness.getOutput()),
+        contains(WindowedValue.of("13", new Instant(0), window1, PaneInfo.NO_FIRING)));
+
+    testHarness.getOutput().clear();
+
+    testHarness.processWatermark(9);
+
+    // this should still not be considered late
+    testHarness.processElement(
+        new StreamRecord<>(WindowedValue.of(17, new Instant(0), window1, PaneInfo.NO_FIRING)));
+
+    assertThat(
+        this.<String>stripStreamRecordFromWindowedValue(testHarness.getOutput()),
+        contains(WindowedValue.of("17", new Instant(0), window1, PaneInfo.NO_FIRING)));
+
+    testHarness.getOutput().clear();
+
+    testHarness.processWatermark(10);
+
+    // this should now be considered late
+    testHarness.processElement(
+        new StreamRecord<>(WindowedValue.of(17, new Instant(0), window1, PaneInfo.NO_FIRING)));
+
+    assertThat(
+        this.<String>stripStreamRecordFromWindowedValue(testHarness.getOutput()),
+        emptyIterable());
+
+    testHarness.close();
+  }
+
+  @Test
+  public void testStateGCForStatefulFn() throws Exception {
+
+    WindowingStrategy<Object, IntervalWindow> windowingStrategy =
+        WindowingStrategy.of(FixedWindows.of(new Duration(10))).withAllowedLateness(Duration.ZERO);
+
+    final String timerId = "boo";
+    final String stateId = "dazzle";
+
+    final int offset = 5000;
+    final int timerOutput = 4093;
+
+    DoFn<KV<String, Integer>, KV<String, Integer>> fn =
+        new DoFn<KV<String, Integer>, KV<String, Integer>>() {
+
+          @TimerId(timerId)
+          private final TimerSpec spec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+          @StateId(stateId)
+          private final StateSpec<Object, ValueState<String>> stateSpec =
+              StateSpecs.value(StringUtf8Coder.of());
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext context,
+              @TimerId(timerId) Timer timer,
+              @StateId(stateId) ValueState<String> state,
+              BoundedWindow window) {
+            timer.set(window.maxTimestamp());
+            state.write(context.element().getKey());
+            context.output(
+                KV.of(context.element().getKey(), context.element().getValue() + offset));
+          }
+
+          @OnTimer(timerId)
+          public void onTimer(OnTimerContext context, @StateId(stateId) ValueState<String> state) {
+            context.output(KV.of(state.read(), timerOutput));
+          }
+        };
+
+    WindowedValue.FullWindowedValueCoder<KV<String, Integer>> windowedValueCoder =
+        WindowedValue.getFullCoder(
+            KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()),
+            windowingStrategy.getWindowFn().windowCoder());
+
+    TupleTag<KV<String, Integer>> outputTag = new TupleTag<>("main-output");
+
+    DoFnOperator<
+        KV<String, Integer>, KV<String, Integer>, WindowedValue<KV<String, Integer>>> doFnOperator =
+        new DoFnOperator<>(
+            fn,
+            windowedValueCoder,
+            outputTag,
+            Collections.<TupleTag<?>>emptyList(),
+            new DoFnOperator.DefaultOutputManagerFactory<WindowedValue<KV<String, Integer>>>(),
+            windowingStrategy,
+            new HashMap<Integer, PCollectionView<?>>(), /* side-input mapping */
+            Collections.<PCollectionView<?>>emptyList(), /* side inputs */
+            PipelineOptionsFactory.as(FlinkPipelineOptions.class),
+            StringUtf8Coder.of() /* key coder */);
+
+    KeyedOneInputStreamOperatorTestHarness<
+        String,
+        WindowedValue<KV<String, Integer>>,
+        WindowedValue<KV<String, Integer>>> testHarness =
+        new KeyedOneInputStreamOperatorTestHarness<>(
+            doFnOperator,
+            new KeySelector<WindowedValue<KV<String, Integer>>, String>() {
+              @Override
+              public String getKey(
+                  WindowedValue<KV<String, Integer>> kvWindowedValue) throws Exception {
+                return kvWindowedValue.getValue().getKey();
+              }
+            },
+            new CoderTypeInformation<>(StringUtf8Coder.of()));
+
+    testHarness.open();
+
+    testHarness.processWatermark(0);
+
+    assertEquals(0, testHarness.numKeyedStateEntries());
+
+    IntervalWindow window1 = new IntervalWindow(new Instant(0), Duration.millis(10));
+
+    testHarness.processElement(
+        new StreamRecord<>(
+            WindowedValue.of(KV.of("key1", 5), new Instant(1), window1, PaneInfo.NO_FIRING)));
+
+    testHarness.processElement(
+        new StreamRecord<>(
+            WindowedValue.of(KV.of("key2", 7), new Instant(3), window1, PaneInfo.NO_FIRING)));
+
+    assertThat(
+        this.<KV<String, Integer>>stripStreamRecordFromWindowedValue(testHarness.getOutput()),
+        contains(
+            WindowedValue.of(
+                KV.of("key1", 5 + offset), new Instant(1), window1, PaneInfo.NO_FIRING),
+            WindowedValue.of(
+                KV.of("key2", 7 + offset), new Instant(3), window1, PaneInfo.NO_FIRING)));
+
+    assertEquals(2, testHarness.numKeyedStateEntries());
+
+    testHarness.getOutput().clear();
+
+    // this should trigger both the window.maxTimestamp() timer and the GC timer
+    // this tests that the GC timer fires after the user timer
+    testHarness.processWatermark(
+        window1.maxTimestamp()
+            .plus(windowingStrategy.getAllowedLateness())
+            .plus(StatefulDoFnRunner.TimeInternalsCleanupTimer.GC_DELAY_MS)
+            .getMillis());
+
+    assertThat(
+        this.<KV<String, Integer>>stripStreamRecordFromWindowedValue(testHarness.getOutput()),
+        contains(
+            WindowedValue.of(
+                KV.of("key1", timerOutput), new Instant(9), window1, PaneInfo.NO_FIRING),
+            WindowedValue.of(
+                KV.of("key2", timerOutput), new Instant(9), window1, PaneInfo.NO_FIRING)));
+
+    // ensure the state was garbage collected
+    assertEquals(0, testHarness.numKeyedStateEntries());
+
+    testHarness.close();
+  }
+
+  public void testSideInputs(boolean keyed) throws Exception {
+
+    WindowedValue.ValueOnlyWindowedValueCoder<String> windowedValueCoder =
+        WindowedValue.getValueOnlyCoder(StringUtf8Coder.of());
+
+    TupleTag<String> outputTag = new TupleTag<>("main-output");
+
+    ImmutableMap<Integer, PCollectionView<?>> sideInputMapping =
+        ImmutableMap.<Integer, PCollectionView<?>>builder()
+            .put(1, view1)
+            .put(2, view2)
+            .build();
+
+    Coder<String> keyCoder = null;
+    if (keyed) {
+      keyCoder = StringUtf8Coder.of();
+    }
+
+    DoFnOperator<String, String, String> doFnOperator = new DoFnOperator<>(
+        new IdentityDoFn<String>(),
+        windowedValueCoder,
+        outputTag,
+        Collections.<TupleTag<?>>emptyList(),
+        new DoFnOperator.DefaultOutputManagerFactory<String>(),
+        WindowingStrategy.globalDefault(),
+        sideInputMapping, /* side-input mapping */
+        ImmutableList.<PCollectionView<?>>of(view1, view2), /* side inputs */
+        PipelineOptionsFactory.as(FlinkPipelineOptions.class),
+        keyCoder);
+
+    TwoInputStreamOperatorTestHarness<WindowedValue<String>, RawUnionValue, String> testHarness =
+        new TwoInputStreamOperatorTestHarness<>(doFnOperator);
+
+    if (keyed) {
+      // we use a dummy key for the second input since it is considered to be broadcast
+      testHarness = new KeyedTwoInputStreamOperatorTestHarness<>(
+          doFnOperator,
+          new StringKeySelector(),
+          new DummyKeySelector(),
+          BasicTypeInfo.STRING_TYPE_INFO);
+    }
+
+    testHarness.open();
+
+    IntervalWindow firstWindow = new IntervalWindow(new Instant(0), new Instant(100));
+    IntervalWindow secondWindow = new IntervalWindow(new Instant(0), new Instant(500));
+
+    // push in the side-input elements before any main-input element arrives
+    testHarness.processElement2(
+        new StreamRecord<>(
+            new RawUnionValue(
+                1,
+                valuesInWindow(ImmutableList.of("hello", "ciao"), new Instant(0), firstWindow))));
+    testHarness.processElement2(
+        new StreamRecord<>(
+            new RawUnionValue(
+                2,
+                valuesInWindow(ImmutableList.of("foo", "bar"), new Instant(0), secondWindow))));
+
+    // push in regular (main-input) elements
+    WindowedValue<String> helloElement = valueInWindow("Hello", new Instant(0), firstWindow);
+    WindowedValue<String> worldElement = valueInWindow("World", new Instant(1000), firstWindow);
+    testHarness.processElement1(new StreamRecord<>(helloElement));
+    testHarness.processElement1(new StreamRecord<>(worldElement));
+
+    // push in further side-input elements; intended to test that pushed-back events are kept
+    testHarness.processElement2(
+        new StreamRecord<>(
+            new RawUnionValue(
+                1,
+                valuesInWindow(ImmutableList.of("hello", "ciao"),
+                    new Instant(1000), firstWindow))));
+    testHarness.processElement2(
+        new StreamRecord<>(
+            new RawUnionValue(
+                2,
+                valuesInWindow(ImmutableList.of("foo", "bar"), new Instant(1000), secondWindow))));
+
+    assertThat(
+        this.<String>stripStreamRecordFromWindowedValue(testHarness.getOutput()),
+        contains(helloElement, worldElement));
+
+    testHarness.close();
+
+  }
+
+  /**
+   * {@link TwoInputStreamOperatorTestHarness} supports OperatorStateBackend but not
+   * KeyedStateBackend, so this test covers the side inputs of a normal (non-keyed) ParDo.
+   */
+  @Test
+  @SuppressWarnings("unchecked")
+  public void testNormalParDoSideInputs() throws Exception {
+    testSideInputs(false);
+  }
+
+  @Test
+  public void testKeyedSideInputs() throws Exception {
+    testSideInputs(true);
+  }
+
+  private <T> Iterable<WindowedValue<T>> stripStreamRecordFromWindowedValue(
+      Iterable<Object> input) {
+
+    return FluentIterable.from(input).filter(new Predicate<Object>() {
+      @Override
+      public boolean apply(@Nullable Object o) {
+        return o instanceof StreamRecord && ((StreamRecord) o).getValue() instanceof WindowedValue;
+      }
+    }).transform(new Function<Object, WindowedValue<T>>() {
+      @Nullable
+      @Override
+      @SuppressWarnings({"unchecked", "rawtypes"})
+      public WindowedValue<T> apply(@Nullable Object o) {
+        if (o instanceof StreamRecord && ((StreamRecord) o).getValue() instanceof WindowedValue) {
+          return (WindowedValue) ((StreamRecord) o).getValue();
+        }
+        throw new RuntimeException("unreachable");
+      }
+    });
+  }
+
+  private Iterable<RawUnionValue> stripStreamRecordFromRawUnion(Iterable<Object> input) {
+    return FluentIterable.from(input).filter(new Predicate<Object>() {
+      @Override
+      public boolean apply(@Nullable Object o) {
+        return o instanceof StreamRecord && ((StreamRecord) o).getValue() instanceof RawUnionValue;
+      }
+    }).transform(new Function<Object, RawUnionValue>() {
+      @Nullable
+      @Override
+      @SuppressWarnings({"unchecked", "rawtypes"})
+      public RawUnionValue apply(@Nullable Object o) {
+        if (o instanceof StreamRecord && ((StreamRecord) o).getValue() instanceof RawUnionValue) {
+          return (RawUnionValue) ((StreamRecord) o).getValue();
+        }
+        throw new RuntimeException("unreachable");
+      }
+    });
+  }
+
+  private static class MultiOutputDoFn extends DoFn<String, String> {
+    private TupleTag<String> additionalOutput1;
+    private TupleTag<String> additionalOutput2;
+
+    public MultiOutputDoFn(TupleTag<String> additionalOutput1, TupleTag<String> additionalOutput2) {
+      this.additionalOutput1 = additionalOutput1;
+      this.additionalOutput2 = additionalOutput2;
+    }
+
+    @ProcessElement
+    public void processElement(ProcessContext c) throws Exception {
+      if (c.element().equals("one")) {
+        c.output(additionalOutput1, "extra: one");
+      } else if (c.element().equals("two")) {
+        c.output(additionalOutput2, "extra: two");
+      } else {
+        c.output("got: " + c.element());
+        c.output(additionalOutput1, "got: " + c.element());
+        c.output(additionalOutput2, "got: " + c.element());
+      }
+    }
+  }
+
+  private static class IdentityDoFn<T> extends DoFn<T, T> {
+    @ProcessElement
+    public void processElement(ProcessContext c) throws Exception {
+      c.output(c.element());
+    }
+  }
+
+  @SuppressWarnings({"unchecked", "rawtypes"})
+  private WindowedValue<Iterable<?>> valuesInWindow(
+      Iterable<?> values, Instant timestamp, BoundedWindow window) {
+    return (WindowedValue) WindowedValue.of(values, timestamp, window, PaneInfo.NO_FIRING);
+  }
+
+  @SuppressWarnings({"unchecked", "rawtypes"})
+  private <T> WindowedValue<T> valueInWindow(
+      T value, Instant timestamp, BoundedWindow window) {
+    return WindowedValue.of(value, timestamp, window, PaneInfo.NO_FIRING);
+  }
+
+
+  private static class DummyKeySelector implements KeySelector<RawUnionValue, String> {
+    @Override
+    public String getKey(RawUnionValue stringWindowedValue) throws Exception {
+      return "dummy_key";
+    }
+  }
+
+  private static class StringKeySelector implements KeySelector<WindowedValue<String>, String> {
+    @Override
+    public String getKey(WindowedValue<String> stringWindowedValue) throws Exception {
+      return stringWindowedValue.getValue();
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java
new file mode 100644
index 0000000..7e7d1e1
--- /dev/null
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertThat;
+
+import java.util.Arrays;
+import org.apache.beam.runners.core.StateMerging;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaceForTest;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkBroadcastStateInternals;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.CombiningState;
+import org.apache.beam.sdk.util.state.GroupingState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link FlinkBroadcastStateInternals}. This is based on the tests for
+ * {@code InMemoryStateInternals}.
+ */
+@RunWith(JUnit4.class)
+public class FlinkBroadcastStateInternalsTest {
+  private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1");
+  private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2");
+  private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3");
+
+  private static final StateTag<Object, ValueState<String>> STRING_VALUE_ADDR =
+      StateTags.value("stringValue", StringUtf8Coder.of());
+  private static final StateTag<Object, CombiningState<Integer, int[], Integer>>
+      SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal(
+          "sumInteger", VarIntCoder.of(), Sum.ofIntegers());
+  private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR =
+      StateTags.bag("stringBag", StringUtf8Coder.of());
+
+  FlinkBroadcastStateInternals<String> underTest;
+
+  @Before
+  public void initStateInternals() {
+    MemoryStateBackend backend = new MemoryStateBackend();
+    try {
+      OperatorStateBackend operatorStateBackend =
+          backend.createOperatorStateBackend(new DummyEnvironment("test", 1, 0), "");
+      underTest = new FlinkBroadcastStateInternals<>(1, operatorStateBackend);
+
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Test
+  public void testValue() throws Exception {
+    ValueState<String> value = underTest.state(NAMESPACE_1, STRING_VALUE_ADDR);
+
+    assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value);
+    assertNotEquals(
+        underTest.state(NAMESPACE_2, STRING_VALUE_ADDR),
+        value);
+
+    assertThat(value.read(), Matchers.nullValue());
+    value.write("hello");
+    assertThat(value.read(), Matchers.equalTo("hello"));
+    value.write("world");
+    assertThat(value.read(), Matchers.equalTo("world"));
+
+    value.clear();
+    assertThat(value.read(), Matchers.nullValue());
+    assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value);
+
+  }
+
+  @Test
+  public void testBag() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)));
+
+    assertThat(value.read(), Matchers.emptyIterable());
+    value.add("hello");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello"));
+
+    value.add("world");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world"));
+
+    value.clear();
+    assertThat(value.read(), Matchers.emptyIterable());
+    assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value);
+
+  }
+
+  @Test
+  public void testBagIsEmpty() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add("hello");
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+  @Test
+  public void testMergeBagIntoSource() throws Exception {
+    BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+    BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+
+    bag1.add("Hello");
+    bag2.add("World");
+    bag1.add("!");
+
+    StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1);
+
+    // Reading the merged bag returns the contents of both bags
+    assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+    assertThat(bag2.read(), Matchers.emptyIterable());
+  }
+
+  @Test
+  public void testMergeBagIntoNewNamespace() throws Exception {
+    BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+    BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+    BagState<String> bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR);
+
+    bag1.add("Hello");
+    bag2.add("World");
+    bag1.add("!");
+
+    StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3);
+
+    // Reading the merged bag returns the contents of all source bags
+    assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+    assertThat(bag1.read(), Matchers.emptyIterable());
+    assertThat(bag2.read(), Matchers.emptyIterable());
+  }
+
+  @Test
+  public void testCombiningValue() throws Exception {
+    GroupingState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+
+    // State instances are cached, but depend on the namespace.
+    assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR)));
+
+    assertThat(value.read(), Matchers.equalTo(0));
+    value.add(2);
+    assertThat(value.read(), Matchers.equalTo(2));
+
+    value.add(3);
+    assertThat(value.read(), Matchers.equalTo(5));
+
+    value.clear();
+    assertThat(value.read(), Matchers.equalTo(0));
+    assertEquals(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), value);
+  }
+
+  @Test
+  public void testCombiningIsEmpty() throws Exception {
+    GroupingState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add(5);
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+  @Test
+  public void testMergeCombiningValueIntoSource() throws Exception {
+    CombiningState<Integer, int[], Integer> value1 =
+        underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+    CombiningState<Integer, int[], Integer> value2 =
+        underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
+
+    value1.add(5);
+    value2.add(10);
+    value1.add(6);
+
+    assertThat(value1.read(), Matchers.equalTo(11));
+    assertThat(value2.read(), Matchers.equalTo(10));
+
+    // Merging clears the old values and updates the result value.
+    StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1);
+
+    assertThat(value1.read(), Matchers.equalTo(21));
+    assertThat(value2.read(), Matchers.equalTo(0));
+  }
+
+  @Test
+  public void testMergeCombiningValueIntoNewNamespace() throws Exception {
+    CombiningState<Integer, int[], Integer> value1 =
+        underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+    CombiningState<Integer, int[], Integer> value2 =
+        underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
+    CombiningState<Integer, int[], Integer> value3 =
+        underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR);
+
+    value1.add(5);
+    value2.add(10);
+    value1.add(6);
+
+    StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3);
+
+    // Merging clears the old values and updates the result value.
+    assertThat(value1.read(), Matchers.equalTo(0));
+    assertThat(value2.read(), Matchers.equalTo(0));
+    assertThat(value3.read(), Matchers.equalTo(21));
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java
new file mode 100644
index 0000000..5433d07
--- /dev/null
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import org.apache.beam.runners.core.StateMerging;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaceForTest;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkKeyGroupStateInternals;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.java.typeutils.GenericTypeInfo;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.query.KvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.apache.flink.streaming.api.operators.KeyContext;
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link FlinkKeyGroupStateInternals}. This is based on the tests for
+ * {@code InMemoryStateInternals}.
+ */
+@RunWith(JUnit4.class)
+public class FlinkKeyGroupStateInternalsTest {
+  private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1");
+  private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2");
+  private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3");
+
+  private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR =
+      StateTags.bag("stringBag", StringUtf8Coder.of());
+
+  FlinkKeyGroupStateInternals<String> underTest;
+  private KeyedStateBackend keyedStateBackend;
+
+  @Before
+  public void initStateInternals() {
+    try {
+      keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1));
+      underTest = new FlinkKeyGroupStateInternals<>(StringUtf8Coder.of(), keyedStateBackend);
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  private KeyedStateBackend getKeyedStateBackend(int numberOfKeyGroups,
+                                                   KeyGroupRange keyGroupRange) {
+    MemoryStateBackend backend = new MemoryStateBackend();
+    try {
+      AbstractKeyedStateBackend<ByteBuffer> keyedStateBackend = backend.createKeyedStateBackend(
+          new DummyEnvironment("test", 1, 0),
+          new JobID(),
+          "test_op",
+          new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig()),
+          numberOfKeyGroups,
+          keyGroupRange,
+          new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
+      keyedStateBackend.setCurrentKey(ByteBuffer.wrap(
+          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "1")));
+      return keyedStateBackend;
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Test
+  public void testBag() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)));
+
+    assertThat(value.read(), Matchers.emptyIterable());
+    value.add("hello");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello"));
+
+    value.add("world");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world"));
+
+    value.clear();
+    assertThat(value.read(), Matchers.emptyIterable());
+    assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value);
+
+  }
+
+  @Test
+  public void testBagIsEmpty() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add("hello");
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+  @Test
+  public void testMergeBagIntoSource() throws Exception {
+    BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+    BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+
+    bag1.add("Hello");
+    bag2.add("World");
+    bag1.add("!");
+
+    StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1);
+
+    // The merge target now holds the contents of both bags; the other source is emptied.
+    assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+    assertThat(bag2.read(), Matchers.emptyIterable());
+  }
+
+  @Test
+  public void testMergeBagIntoNewNamespace() throws Exception {
+    BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+    BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+    BagState<String> bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR);
+
+    bag1.add("Hello");
+    bag2.add("World");
+    bag1.add("!");
+
+    StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3);
+
+    // The target bag now holds everything that was added; the other bags are emptied.
+    assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+    assertThat(bag1.read(), Matchers.emptyIterable());
+    assertThat(bag2.read(), Matchers.emptyIterable());
+  }
+
+  @Test
+  public void testKeyGroupAndCheckpoint() throws Exception {
+    // assign to keyGroup 0
+    ByteBuffer key0 = ByteBuffer.wrap(
+        CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "11111111"));
+    // assign to keyGroup 1
+    ByteBuffer key1 = ByteBuffer.wrap(
+        CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "22222222"));
+    FlinkKeyGroupStateInternals<String> allState;
+    {
+      KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1));
+      allState = new FlinkKeyGroupStateInternals<>(
+          StringUtf8Coder.of(), keyedStateBackend);
+      BagState<String> valueForNamespace0 = allState.state(NAMESPACE_1, STRING_BAG_ADDR);
+      BagState<String> valueForNamespace1 = allState.state(NAMESPACE_2, STRING_BAG_ADDR);
+      keyedStateBackend.setCurrentKey(key0);
+      valueForNamespace0.add("0");
+      valueForNamespace1.add("2");
+      keyedStateBackend.setCurrentKey(key1);
+      valueForNamespace0.add("1");
+      valueForNamespace1.add("3");
+      assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1"));
+      assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3"));
+    }
+
+    ClassLoader classLoader = FlinkKeyGroupStateInternalsTest.class.getClassLoader();
+
+    // 1. scale up: restore each key group snapshot into a backend owning only that group
+    ByteArrayOutputStream out0 = new ByteArrayOutputStream();
+    allState.snapshotKeyGroupState(0, new DataOutputStream(out0));
+    DataInputStream in0 = new DataInputStream(
+        new ByteArrayInputStream(out0.toByteArray()));
+    {
+      KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 0));
+      FlinkKeyGroupStateInternals<String> state0 =
+          new FlinkKeyGroupStateInternals<>(
+              StringUtf8Coder.of(), keyedStateBackend);
+      state0.restoreKeyGroupState(0, in0, classLoader);
+      BagState<String> valueForNamespace0 = state0.state(NAMESPACE_1, STRING_BAG_ADDR);
+      BagState<String> valueForNamespace1 = state0.state(NAMESPACE_2, STRING_BAG_ADDR);
+      assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0"));
+      assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2"));
+    }
+
+    ByteArrayOutputStream out1 = new ByteArrayOutputStream();
+    allState.snapshotKeyGroupState(1, new DataOutputStream(out1));
+    DataInputStream in1 = new DataInputStream(
+        new ByteArrayInputStream(out1.toByteArray()));
+    {
+      KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(1, 1));
+      FlinkKeyGroupStateInternals<String> state1 =
+          new FlinkKeyGroupStateInternals<>(
+              StringUtf8Coder.of(), keyedStateBackend);
+      state1.restoreKeyGroupState(1, in1, classLoader);
+      BagState<String> valueForNamespace0 = state1.state(NAMESPACE_1, STRING_BAG_ADDR);
+      BagState<String> valueForNamespace1 = state1.state(NAMESPACE_2, STRING_BAG_ADDR);
+      assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("1"));
+      assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("3"));
+    }
+
+    // 2. scale down: restore both key group snapshots into one backend covering groups 0-1
+    {
+      KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1));
+      FlinkKeyGroupStateInternals<String> newAllState = new FlinkKeyGroupStateInternals<>(
+          StringUtf8Coder.of(), keyedStateBackend);
+      in0.reset();
+      in1.reset();
+      newAllState.restoreKeyGroupState(0, in0, classLoader);
+      newAllState.restoreKeyGroupState(1, in1, classLoader);
+      BagState<String> valueForNamespace0 = newAllState.state(NAMESPACE_1, STRING_BAG_ADDR);
+      BagState<String> valueForNamespace1 = newAllState.state(NAMESPACE_2, STRING_BAG_ADDR);
+      assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1"));
+      assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3"));
+    }
+
+  }
+
+  private static class TestKeyContext implements KeyContext {
+
+    private Object key;
+
+    @Override
+    public void setCurrentKey(Object key) {
+      this.key = key;
+    }
+
+    @Override
+    public Object getCurrentKey() {
+      return key;
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
new file mode 100644
index 0000000..08ae0c4
--- /dev/null
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
+
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaceForTest;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkSplitStateInternals;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link FlinkSplitStateInternals}. This is based on the tests for
+ * {@code InMemoryStateInternals}.
+ */
+@RunWith(JUnit4.class)
+public class FlinkSplitStateInternalsTest {
+  private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1");
+  private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2");
+
+  private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR =
+      StateTags.bag("stringBag", StringUtf8Coder.of());
+
+  FlinkSplitStateInternals<String> underTest;
+
+  @Before
+  public void initStateInternals() {
+    MemoryStateBackend backend = new MemoryStateBackend();
+    try {
+      OperatorStateBackend operatorStateBackend =
+          backend.createOperatorStateBackend(new DummyEnvironment("test", 1, 0), "");
+      underTest = new FlinkSplitStateInternals<>(operatorStateBackend);
+
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Test
+  public void testBag() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)));
+
+    assertThat(value.read(), Matchers.emptyIterable());
+    value.add("hello");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello"));
+
+    value.add("world");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world"));
+
+    value.clear();
+    assertThat(value.read(), Matchers.emptyIterable());
+    assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value);
+
+  }
+
+  @Test
+  public void testBagIsEmpty() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add("hello");
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
new file mode 100644
index 0000000..d140271
--- /dev/null
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertThat;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import org.apache.beam.runners.core.StateMerging;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaceForTest;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.OutputTimeFns;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.CombiningState;
+import org.apache.beam.sdk.util.state.GroupingState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.beam.sdk.util.state.WatermarkHoldState;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.java.typeutils.GenericTypeInfo;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.query.KvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.hamcrest.Matchers;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link FlinkStateInternals}. This is based on the tests for
+ * {@code InMemoryStateInternals}.
+ */
+@RunWith(JUnit4.class)
+public class FlinkStateInternalsTest {
+  private static final BoundedWindow WINDOW_1 = new IntervalWindow(new Instant(0), new Instant(10));
+  private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1");
+  private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2");
+  private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3");
+
+  private static final StateTag<Object, ValueState<String>> STRING_VALUE_ADDR =
+      StateTags.value("stringValue", StringUtf8Coder.of());
+  private static final StateTag<Object, CombiningState<Integer, int[], Integer>>
+      SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal(
+          "sumInteger", VarIntCoder.of(), Sum.ofIntegers());
+  private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR =
+      StateTags.bag("stringBag", StringUtf8Coder.of());
+  private static final StateTag<Object, WatermarkHoldState<BoundedWindow>>
+      WATERMARK_EARLIEST_ADDR =
+      StateTags.watermarkStateInternal("watermark", OutputTimeFns.outputAtEarliestInputTimestamp());
+  private static final StateTag<Object, WatermarkHoldState<BoundedWindow>>
+      WATERMARK_LATEST_ADDR =
+      StateTags.watermarkStateInternal("watermark", OutputTimeFns.outputAtLatestInputTimestamp());
+  private static final StateTag<Object, WatermarkHoldState<BoundedWindow>> WATERMARK_EOW_ADDR =
+      StateTags.watermarkStateInternal("watermark", OutputTimeFns.outputAtEndOfWindow());
+
+  FlinkStateInternals<String> underTest;
+
+  @Before
+  public void initStateInternals() {
+    MemoryStateBackend backend = new MemoryStateBackend();
+    try {
+      AbstractKeyedStateBackend<ByteBuffer> keyedStateBackend = backend.createKeyedStateBackend(
+          new DummyEnvironment("test", 1, 0),
+          new JobID(),
+          "test_op",
+          new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig()),
+          1,
+          new KeyGroupRange(0, 0),
+          new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
+      underTest = new FlinkStateInternals<>(keyedStateBackend, StringUtf8Coder.of());
+
+      keyedStateBackend.setCurrentKey(
+          ByteBuffer.wrap(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "Hello")));
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Test
+  public void testValue() throws Exception {
+    ValueState<String> value = underTest.state(NAMESPACE_1, STRING_VALUE_ADDR);
+
+    assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value);
+    assertNotEquals(
+        underTest.state(NAMESPACE_2, STRING_VALUE_ADDR),
+        value);
+
+    assertThat(value.read(), Matchers.nullValue());
+    value.write("hello");
+    assertThat(value.read(), Matchers.equalTo("hello"));
+    value.write("world");
+    assertThat(value.read(), Matchers.equalTo("world"));
+
+    value.clear();
+    assertThat(value.read(), Matchers.nullValue());
+    assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value);
+
+  }
+
+  @Test
+  public void testBag() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)));
+
+    assertThat(value.read(), Matchers.emptyIterable());
+    value.add("hello");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello"));
+
+    value.add("world");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world"));
+
+    value.clear();
+    assertThat(value.read(), Matchers.emptyIterable());
+    assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value);
+
+  }
+
+  @Test
+  public void testBagIsEmpty() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add("hello");
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+  @Test
+  public void testMergeBagIntoSource() throws Exception {
+    BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+    BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+
+    bag1.add("Hello");
+    bag2.add("World");
+    bag1.add("!");
+
+    StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1);
+
+    // The merge target now holds the contents of both bags; the other source is emptied.
+    assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+    assertThat(bag2.read(), Matchers.emptyIterable());
+  }
+
+  @Test
+  public void testMergeBagIntoNewNamespace() throws Exception {
+    BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+    BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+    BagState<String> bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR);
+
+    bag1.add("Hello");
+    bag2.add("World");
+    bag1.add("!");
+
+    StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3);
+
+    // The target bag now holds everything that was added; the other bags are emptied.
+    assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+    assertThat(bag1.read(), Matchers.emptyIterable());
+    assertThat(bag2.read(), Matchers.emptyIterable());
+  }
+
+  @Test
+  public void testCombiningValue() throws Exception {
+    GroupingState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+
+    // State instances are cached, but depend on the namespace.
+    assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR)));
+
+    assertThat(value.read(), Matchers.equalTo(0));
+    value.add(2);
+    assertThat(value.read(), Matchers.equalTo(2));
+
+    value.add(3);
+    assertThat(value.read(), Matchers.equalTo(5));
+
+    value.clear();
+    assertThat(value.read(), Matchers.equalTo(0));
+    assertEquals(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), value);
+  }
+
+  @Test
+  public void testCombiningIsEmpty() throws Exception {
+    GroupingState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add(5);
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+  @Test
+  public void testMergeCombiningValueIntoSource() throws Exception {
+    CombiningState<Integer, int[], Integer> value1 =
+        underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+    CombiningState<Integer, int[], Integer> value2 =
+        underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
+
+    value1.add(5);
+    value2.add(10);
+    value1.add(6);
+
+    assertThat(value1.read(), Matchers.equalTo(11));
+    assertThat(value2.read(), Matchers.equalTo(10));
+
+    // Merging clears the old values and updates the result value.
+    StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1);
+
+    assertThat(value1.read(), Matchers.equalTo(21));
+    assertThat(value2.read(), Matchers.equalTo(0));
+  }
+
+  @Test
+  public void testMergeCombiningValueIntoNewNamespace() throws Exception {
+    CombiningState<Integer, int[], Integer> value1 =
+        underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+    CombiningState<Integer, int[], Integer> value2 =
+        underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
+    CombiningState<Integer, int[], Integer> value3 =
+        underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR);
+
+    value1.add(5);
+    value2.add(10);
+    value1.add(6);
+
+    StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3);
+
+    // Merging clears the old values and updates the result value.
+    assertThat(value1.read(), Matchers.equalTo(0));
+    assertThat(value2.read(), Matchers.equalTo(0));
+    assertThat(value3.read(), Matchers.equalTo(21));
+  }
+
+  @Test
+  public void testWatermarkEarliestState() throws Exception {
+    WatermarkHoldState<BoundedWindow> value =
+        underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR);
+
+    // State instances are cached, but depend on the namespace.
+    assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR)));
+
+    assertThat(value.read(), Matchers.nullValue());
+    value.add(new Instant(2000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+
+    value.add(new Instant(3000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+
+    value.add(new Instant(1000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(1000)));
+
+    value.clear();
+    assertThat(value.read(), Matchers.equalTo(null));
+    assertEquals(underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR), value);
+  }
+
+  @Test
+  public void testWatermarkLatestState() throws Exception {
+    WatermarkHoldState<BoundedWindow> value =
+        underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR);
+
+    // State instances are cached, but depend on the namespace.
+    assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR)));
+
+    assertThat(value.read(), Matchers.nullValue());
+    value.add(new Instant(2000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+
+    value.add(new Instant(3000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(3000)));
+
+    value.add(new Instant(1000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(3000)));
+
+    value.clear();
+    assertThat(value.read(), Matchers.equalTo(null));
+    assertEquals(underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR), value);
+  }
+
+  @Test
+  public void testWatermarkEndOfWindowState() throws Exception {
+    WatermarkHoldState<BoundedWindow> value = underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR);
+
+    // State instances are cached, but depend on the namespace.
+    assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EOW_ADDR)));
+
+    assertThat(value.read(), Matchers.nullValue());
+    value.add(new Instant(2000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+
+    value.clear();
+    assertThat(value.read(), Matchers.equalTo(null));
+    assertEquals(underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR), value);
+  }
+
+  @Test
+  public void testWatermarkStateIsEmpty() throws Exception {
+    WatermarkHoldState<BoundedWindow> value =
+        underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add(new Instant(1000));
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+  @Test
+  public void testMergeEarliestWatermarkIntoSource() throws Exception {
+    WatermarkHoldState<BoundedWindow> value1 =
+        underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR);
+    WatermarkHoldState<BoundedWindow> value2 =
+        underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR);
+
+    value1.add(new Instant(3000));
+    value2.add(new Instant(5000));
+    value1.add(new Instant(4000));
+    value2.add(new Instant(2000));
+
+    // Merging clears the old values and updates the merged value.
+    StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value1, WINDOW_1);
+
+    assertThat(value1.read(), Matchers.equalTo(new Instant(2000)));
+    assertThat(value2.read(), Matchers.equalTo(null));
+  }
+
+  @Test
+  public void testMergeLatestWatermarkIntoSource() throws Exception {
+    WatermarkHoldState<BoundedWindow> value1 =
+        underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR);
+    WatermarkHoldState<BoundedWindow> value2 =
+        underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR);
+    WatermarkHoldState<BoundedWindow> value3 =
+        underTest.state(NAMESPACE_3, WATERMARK_LATEST_ADDR);
+
+    value1.add(new Instant(3000));
+    value2.add(new Instant(5000));
+    value1.add(new Instant(4000));
+    value2.add(new Instant(2000));
+
+    // Merge the two sources into value3, which lives in a separate namespace.
+    StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value3, WINDOW_1);
+
+    // Merging clears the old values and updates the result value.
+    assertThat(value3.read(), Matchers.equalTo(new Instant(5000)));
+    assertThat(value1.read(), Matchers.equalTo(null));
+    assertThat(value2.read(), Matchers.equalTo(null));
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/GroupByNullKeyTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/GroupByNullKeyTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/GroupByNullKeyTest.java
new file mode 100644
index 0000000..663b910
--- /dev/null
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/GroupByNullKeyTest.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import com.google.common.base.Joiner;
+import java.io.Serializable;
+import java.util.Arrays;
+import org.apache.beam.runners.flink.FlinkTestPipeline;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.io.TextIO;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.windowing.AfterWatermark;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.flink.streaming.util.StreamingProgramTestBase;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+
+/**
+ * Streaming test that runs a GroupByKey over elements which all carry a {@code null} key.
+ */
+public class GroupByNullKeyTest extends StreamingProgramTestBase implements Serializable {
+
+  static final String[] EXPECTED_RESULT = {
+      "k: null v: user1 user1 user1 user2 user2 user2 user2 user3"
+  };
+
+  // Location the pipeline writes its output to; assigned in preSubmit().
+  protected String resultPath;
+
+  public GroupByNullKeyTest() {
+  }
+
+  @Override
+  protected void preSubmit() throws Exception {
+    this.resultPath = getTempDirPath("result");
+  }
+
+  @Override
+  protected void postSubmit() throws Exception {
+    compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), this.resultPath);
+  }
+
+  /**
+   * Emits each record's value (the user name) stamped with the record's key as its timestamp.
+   */
+  private static class ExtractUserAndTimestamp extends DoFn<KV<Integer, String>, String> {
+    @ProcessElement
+    public void processElement(ProcessContext c) {
+      int eventTimeMillis = c.element().getKey();
+      String user = c.element().getValue();
+      if (user == null) {
+        return;
+      }
+      // The integer key becomes the element timestamp that windowing operates on.
+      c.outputWithTimestamp(user, new Instant(eventTimeMillis));
+    }
+  }
+
+  @Override
+  protected void testProgram() throws Exception {
+    // Input records are (timestamp, user name) pairs.
+    Pipeline pipeline = FlinkTestPipeline.createForStreaming();
+
+    PCollection<String> groupedLines =
+        pipeline.apply(Create.of(Arrays.asList(
+            KV.<Integer, String>of(0, "user1"),
+            KV.<Integer, String>of(1, "user1"),
+            KV.<Integer, String>of(2, "user1"),
+            KV.<Integer, String>of(10, "user2"),
+            KV.<Integer, String>of(1, "user2"),
+            KV.<Integer, String>of(15000, "user2"),
+            KV.<Integer, String>of(12000, "user2"),
+            KV.<Integer, String>of(25000, "user3"))))
+            .apply(ParDo.of(new ExtractUserAndTimestamp()))
+            .apply(Window.<String>into(FixedWindows.of(Duration.standardHours(1)))
+                .triggering(AfterWatermark.pastEndOfWindow())
+                .withAllowedLateness(Duration.ZERO)
+                .discardingFiredPanes())
+            // Re-key every element with a null key so that GroupByKey sees only that key.
+            .apply(ParDo.of(new DoFn<String, KV<Void, String>>() {
+              @ProcessElement
+              public void processElement(ProcessContext c) throws Exception {
+                KV<Void, String> nullKeyed = KV.<Void, String>of(null, c.element());
+                c.output(nullKeyed);
+              }
+            }))
+            .apply(GroupByKey.<Void, String>create())
+            .apply(ParDo.of(new DoFn<KV<Void, Iterable<String>>, String>() {
+              @ProcessElement
+              public void processElement(ProcessContext c) throws Exception {
+                KV<Void, Iterable<String>> grouped = c.element();
+                StringBuilder line = new StringBuilder("k: ");
+                line.append(grouped.getKey()).append(" v:");
+                for (String user : grouped.getValue()) {
+                  line.append(" ").append(user);
+                }
+                c.output(line.toString());
+              }
+            }));
+    groupedLines.apply(TextIO.Write.to(resultPath));
+    pipeline.run();
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TestCountingSource.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TestCountingSource.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TestCountingSource.java
new file mode 100644
index 0000000..3a08088
--- /dev/null
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TestCountingSource.java
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import static org.apache.beam.sdk.util.CoderUtils.encodeToByteArray;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ThreadLocalRandom;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.DelegateCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.io.UnboundedSource;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.values.KV;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * An unbounded source for testing the unbounded sources framework code.
+ *
+ * <p>Each split of this source produces records of the form KV(split_id, i),
+ * where i counts up from 0.  Each record has a timestamp of i, and the watermark
+ * accurately tracks these timestamps.  {@code advance} returns false once i has
+ * reached {@code numMessagesPerShard - 1}; with dedup enabled it occasionally
+ * re-emits the current record instead of moving on to the next one.
+ */
+public class TestCountingSource
+    extends UnboundedSource<KV<Integer, Integer>, TestCountingSource.CounterMark> {
+  private static final Logger LOG = LoggerFactory.getLogger(TestCountingSource.class);
+
+  private static List<Integer> finalizeTracker;
+  private final int numMessagesPerShard;
+  private final int shardNumber;
+  private final boolean dedup;
+  private final boolean throwOnFirstSnapshot;
+  private final boolean allowSplitting;
+
+  /**
+   * We only allow an exception to be thrown from getCheckpointMark
+   * at most once. This must be static since the entire TestCountingSource
+   * instance may be re-serialized when the pipeline recovers and retries.
+   */
+  private static boolean thrown = false;
+
+  public static void setFinalizeTracker(List<Integer> finalizeTracker) {
+    TestCountingSource.finalizeTracker = finalizeTracker;
+  }
+
+  public TestCountingSource(int numMessagesPerShard) {
+    this(numMessagesPerShard, 0, false, false, true);
+  }
+
+  public TestCountingSource withDedup() {
+    return new TestCountingSource(
+        numMessagesPerShard, shardNumber, true, throwOnFirstSnapshot, true);
+  }
+
+  private TestCountingSource withShardNumber(int shardNumber) {
+    return new TestCountingSource(
+        numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, true);
+  }
+
+  public TestCountingSource withThrowOnFirstSnapshot(boolean throwOnFirstSnapshot) {
+    return new TestCountingSource(
+        numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, true);
+  }
+
+  public TestCountingSource withoutSplitting() {
+    return new TestCountingSource(
+        numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, false);
+  }
+
+  private TestCountingSource(int numMessagesPerShard, int shardNumber, boolean dedup,
+                             boolean throwOnFirstSnapshot, boolean allowSplitting) {
+    this.numMessagesPerShard = numMessagesPerShard;
+    this.shardNumber = shardNumber;
+    this.dedup = dedup;
+    this.throwOnFirstSnapshot = throwOnFirstSnapshot;
+    this.allowSplitting = allowSplitting;
+  }
+
+  public int getShardNumber() {
+    return shardNumber;
+  }
+
+  @Override
+  public List<TestCountingSource> split(
+      int desiredNumSplits, PipelineOptions options) {
+    List<TestCountingSource> splits = new ArrayList<>();
+    int numSplits = allowSplitting ? desiredNumSplits : 1;
+    for (int i = 0; i < numSplits; i++) {
+      splits.add(withShardNumber(i));
+    }
+    return splits;
+  }
+
+  static class CounterMark implements UnboundedSource.CheckpointMark {
+    int current;
+
+    public CounterMark(int current) {
+      this.current = current;
+    }
+
+    @Override
+    public void finalizeCheckpoint() {
+      if (finalizeTracker != null) {
+        finalizeTracker.add(current);
+      }
+    }
+  }
+
+  @Override
+  public Coder<CounterMark> getCheckpointMarkCoder() {
+    return DelegateCoder.of(
+        VarIntCoder.of(),
+        new DelegateCoder.CodingFunction<CounterMark, Integer>() {
+          @Override
+          public Integer apply(CounterMark input) {
+            return input.current;
+          }
+        },
+        new DelegateCoder.CodingFunction<Integer, CounterMark>() {
+          @Override
+          public CounterMark apply(Integer input) {
+            return new CounterMark(input);
+          }
+        });
+  }
+
+  @Override
+  public boolean requiresDeduping() {
+    return dedup;
+  }
+
+  /**
+   * Public only so that the checkpoint can be conveyed from {@link #getCheckpointMark()} to
+   * {@link TestCountingSource#createReader(PipelineOptions, CounterMark)} without cast.
+   */
+  public class CountingSourceReader extends UnboundedReader<KV<Integer, Integer>> {
+    private int current;
+
+    public CountingSourceReader(int startingPoint) {
+      this.current = startingPoint;
+    }
+
+    @Override
+    public boolean start() {
+      return advance();
+    }
+
+    @Override
+    public boolean advance() {
+      if (current >= numMessagesPerShard - 1) {
+        return false;
+      }
+      // When testing dedup, re-emit the current value about 1 time in 5 (never before the first).
+      if (current >= 0 && dedup && ThreadLocalRandom.current().nextInt(5) == 0) {
+        return true;
+      }
+      current++;
+      return true;
+    }
+
+    @Override
+    public KV<Integer, Integer> getCurrent() {
+      return KV.of(shardNumber, current);
+    }
+
+    @Override
+    public Instant getCurrentTimestamp() {
+      return new Instant(current);
+    }
+
+    @Override
+    public byte[] getCurrentRecordId() {
+      try {
+        return encodeToByteArray(KvCoder.of(VarIntCoder.of(), VarIntCoder.of()), getCurrent());
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public void close() {}
+
+    @Override
+    public TestCountingSource getCurrentSource() {
+      return TestCountingSource.this;
+    }
+
+    @Override
+    public Instant getWatermark() {
+      // The watermark is a promise about future elements, and the timestamps of elements are
+      // strictly increasing for this source.
+      return new Instant(current + 1);
+    }
+
+    @Override
+    public CounterMark getCheckpointMark() {
+      if (throwOnFirstSnapshot && !thrown) {
+        thrown = true;
+        LOG.error("Throwing exception while checkpointing counter");
+        throw new RuntimeException("failed during checkpoint");
+      }
+      // The checkpoint can assume all records read, including the current, have
+      // been committed.
+      return new CounterMark(current);
+    }
+
+    @Override
+    public long getSplitBacklogBytes() {
+      return 7L;
+    }
+  }
+
+  @Override
+  public CountingSourceReader createReader(
+      PipelineOptions options, @Nullable CounterMark checkpointMark) {
+    if (checkpointMark == null) {
+      LOG.debug("creating reader");
+    } else {
+      LOG.debug("restoring reader from checkpoint with current = {}", checkpointMark.current);
+    }
+    return new CountingSourceReader(checkpointMark != null ? checkpointMark.current : -1);
+  }
+
+  @Override
+  public void validate() {}
+
+  @Override
+  public Coder<KV<Integer, Integer>> getDefaultOutputCoder() {
+    return KvCoder.of(VarIntCoder.of(), VarIntCoder.of());
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TopWikipediaSessionsITCase.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TopWikipediaSessionsITCase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TopWikipediaSessionsITCase.java
new file mode 100644
index 0000000..9e6bba8
--- /dev/null
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TopWikipediaSessionsITCase.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import com.google.api.services.bigquery.model.TableRow;
+import com.google.common.base.Joiner;
+import java.io.Serializable;
+import java.util.Arrays;
+import org.apache.beam.runners.flink.FlinkTestPipeline;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.io.TextIO;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.windowing.Sessions;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.flink.streaming.util.StreamingProgramTestBase;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+
+
+/**
+ * Session window test.
+ */
+public class TopWikipediaSessionsITCase extends StreamingProgramTestBase implements Serializable {
+  protected String resultPath;
+
+  public TopWikipediaSessionsITCase(){
+  }
+
+  static final String[] EXPECTED_RESULT = new String[] {
+      "user: user1 value:3",
+      "user: user1 value:1",
+      "user: user2 value:4",
+      "user: user2 value:6",
+      "user: user3 value:7",
+      "user: user3 value:2"
+  };
+
+  @Override
+  protected void preSubmit() throws Exception {
+    resultPath = getTempDirPath("result");
+  }
+
+  @Override
+  protected void postSubmit() throws Exception {
+    compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath);
+  }
+
+  @Override
+  protected void testProgram() throws Exception {
+
+    Pipeline p = FlinkTestPipeline.createForStreaming();
+
+    Long now = (System.currentTimeMillis() + 10000) / 1000;
+
+    PCollection<KV<String, Long>> output =
+      p.apply(Create.of(Arrays.asList(new TableRow().set("timestamp", now).set
+          ("contributor_username", "user1"), new TableRow().set("timestamp", now + 10).set
+          ("contributor_username", "user3"), new TableRow().set("timestamp", now).set
+          ("contributor_username", "user2"), new TableRow().set("timestamp", now).set
+          ("contributor_username", "user1"), new TableRow().set("timestamp", now + 2).set
+          ("contributor_username", "user1"), new TableRow().set("timestamp", now).set
+          ("contributor_username", "user2"), new TableRow().set("timestamp", now + 1).set
+          ("contributor_username", "user2"), new TableRow().set("timestamp", now + 5).set
+          ("contributor_username", "user2"), new TableRow().set("timestamp", now + 7).set
+          ("contributor_username", "user2"), new TableRow().set("timestamp", now + 8).set
+          ("contributor_username", "user2"), new TableRow().set("timestamp", now + 200).set
+          ("contributor_username", "user2"), new TableRow().set("timestamp", now + 230).set
+          ("contributor_username", "user1"), new TableRow().set("timestamp", now + 230).set
+          ("contributor_username", "user2"), new TableRow().set("timestamp", now + 240).set
+          ("contributor_username", "user2"), new TableRow().set("timestamp", now + 245).set
+          ("contributor_username", "user3"), new TableRow().set("timestamp", now + 235).set
+          ("contributor_username", "user3"), new TableRow().set("timestamp", now + 236).set
+          ("contributor_username", "user3"), new TableRow().set("timestamp", now + 237).set
+          ("contributor_username", "user3"), new TableRow().set("timestamp", now + 238).set
+          ("contributor_username", "user3"), new TableRow().set("timestamp", now + 239).set
+          ("contributor_username", "user3"), new TableRow().set("timestamp", now + 240).set
+          ("contributor_username", "user3"), new TableRow().set("timestamp", now + 241).set
+          ("contributor_username", "user2"), new TableRow().set("timestamp", now)
+          .set("contributor_username", "user3"))))
+
+
+
+      .apply(ParDo.of(new DoFn<TableRow, String>() {
+        @ProcessElement
+        public void processElement(ProcessContext c) throws Exception {
+          TableRow row = c.element();
+          long timestamp = (Integer) row.get("timestamp");
+          String userName = (String) row.get("contributor_username");
+          if (userName != null) {
+            // Sets the timestamp field to be used in windowing.
+            c.outputWithTimestamp(userName, new Instant(timestamp * 1000L));
+          }
+        }
+      }))
+
+      .apply(Window.<String>into(Sessions.withGapDuration(Duration.standardMinutes(1))))
+
+      .apply(Count.<String>perElement());
+
+    PCollection<String> format = output.apply(ParDo.of(new DoFn<KV<String, Long>, String>() {
+      @ProcessElement
+      public void processElement(ProcessContext c) throws Exception {
+        KV<String, Long> el = c.element();
+        String out = "user: " + el.getKey() + " value:" + el.getValue();
+        c.output(out);
+      }
+    }));
+
+    format.apply(TextIO.Write.to(resultPath));
+
+    p.run();
+  }
+}