Posted to commits@beam.apache.org by ke...@apache.org on 2017/04/21 17:52:53 UTC
[09/50] [abbrv] beam git commit: [BEAM-1994] Remove Flink examples package
http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java
new file mode 100644
index 0000000..4c826d1
--- /dev/null
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/DoFnOperatorTest.java
@@ -0,0 +1,600 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import static org.hamcrest.Matchers.emptyIterable;
+import static org.hamcrest.collection.IsIterableContainingInOrder.contains;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThat;
+
+import com.google.common.base.Function;
+import com.google.common.base.Predicate;
+import com.google.common.collect.FluentIterable;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import java.util.Collections;
+import java.util.HashMap;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.core.StatefulDoFnRunner;
+import org.apache.beam.runners.flink.FlinkPipelineOptions;
+import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.testing.PCollectionViewTesting;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.join.RawUnionValue;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.TimeDomain;
+import org.apache.beam.sdk.util.Timer;
+import org.apache.beam.sdk.util.TimerSpec;
+import org.apache.beam.sdk.util.TimerSpecs;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowingStrategy;
+import org.apache.beam.sdk.util.state.StateSpec;
+import org.apache.beam.sdk.util.state.StateSpecs;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
+import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness;
+import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
+import org.apache.flink.streaming.util.TwoInputStreamOperatorTestHarness;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link DoFnOperator}.
+ */
+@RunWith(JUnit4.class)
+public class DoFnOperatorTest {
+
+ // views and windows for testing side inputs
+ private static final long WINDOW_MSECS_1 = 100;
+ private static final long WINDOW_MSECS_2 = 500;
+
+ private WindowingStrategy<Object, IntervalWindow> windowingStrategy1 =
+ WindowingStrategy.of(FixedWindows.of(new Duration(WINDOW_MSECS_1)));
+
+ private PCollectionView<Iterable<String>> view1 =
+ PCollectionViewTesting.testingView(
+ new TupleTag<Iterable<WindowedValue<String>>>() {},
+ new PCollectionViewTesting.IdentityViewFn<String>(),
+ StringUtf8Coder.of(),
+ windowingStrategy1);
+
+ private WindowingStrategy<Object, IntervalWindow> windowingStrategy2 =
+ WindowingStrategy.of(FixedWindows.of(new Duration(WINDOW_MSECS_2)));
+
+ private PCollectionView<Iterable<String>> view2 =
+ PCollectionViewTesting.testingView(
+ new TupleTag<Iterable<WindowedValue<String>>>() {},
+ new PCollectionViewTesting.IdentityViewFn<String>(),
+ StringUtf8Coder.of(),
+ windowingStrategy2);
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testSingleOutput() throws Exception {
+
+ WindowedValue.ValueOnlyWindowedValueCoder<String> windowedValueCoder =
+ WindowedValue.getValueOnlyCoder(StringUtf8Coder.of());
+
+ TupleTag<String> outputTag = new TupleTag<>("main-output");
+
+ DoFnOperator<String, String, String> doFnOperator = new DoFnOperator<>(
+ new IdentityDoFn<String>(),
+ windowedValueCoder,
+ outputTag,
+ Collections.<TupleTag<?>>emptyList(),
+ new DoFnOperator.DefaultOutputManagerFactory(),
+ WindowingStrategy.globalDefault(),
+ new HashMap<Integer, PCollectionView<?>>(), /* side-input mapping */
+ Collections.<PCollectionView<?>>emptyList(), /* side inputs */
+ PipelineOptionsFactory.as(FlinkPipelineOptions.class),
+ null);
+
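+ // a null key coder (the last constructor argument) marks the operator as
+ // non-keyed, i.e. no Beam state or timers are involved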
+ OneInputStreamOperatorTestHarness<WindowedValue<String>, String> testHarness =
+ new OneInputStreamOperatorTestHarness<>(doFnOperator);
+
+ testHarness.open();
+
+ testHarness.processElement(new StreamRecord<>(WindowedValue.valueInGlobalWindow("Hello")));
+
+ assertThat(
+ this.<String>stripStreamRecordFromWindowedValue(testHarness.getOutput()),
+ contains(WindowedValue.valueInGlobalWindow("Hello")));
+
+ testHarness.close();
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testMultiOutputOutput() throws Exception {
+
+ WindowedValue.ValueOnlyWindowedValueCoder<String> windowedValueCoder =
+ WindowedValue.getValueOnlyCoder(StringUtf8Coder.of());
+
+ TupleTag<String> mainOutput = new TupleTag<>("main-output");
+ TupleTag<String> additionalOutput1 = new TupleTag<>("output-1");
+ TupleTag<String> additionalOutput2 = new TupleTag<>("output-2");
+ ImmutableMap<TupleTag<?>, Integer> outputMapping = ImmutableMap.<TupleTag<?>, Integer>builder()
+ .put(mainOutput, 1)
+ .put(additionalOutput1, 2)
+ .put(additionalOutput2, 3)
+ .build();
+
+ DoFnOperator<String, String, RawUnionValue> doFnOperator = new DoFnOperator<>(
+ new MultiOutputDoFn(additionalOutput1, additionalOutput2),
+ windowedValueCoder,
+ mainOutput,
+ ImmutableList.<TupleTag<?>>of(additionalOutput1, additionalOutput2),
+ new DoFnOperator.MultiOutputOutputManagerFactory(outputMapping),
+ WindowingStrategy.globalDefault(),
+ new HashMap<Integer, PCollectionView<?>>(), /* side-input mapping */
+ Collections.<PCollectionView<?>>emptyList(), /* side inputs */
+ PipelineOptionsFactory.as(FlinkPipelineOptions.class),
+ null);
+
+ OneInputStreamOperatorTestHarness<WindowedValue<String>, RawUnionValue> testHarness =
+ new OneInputStreamOperatorTestHarness<>(doFnOperator);
+
+ testHarness.open();
+
+ testHarness.processElement(new StreamRecord<>(WindowedValue.valueInGlobalWindow("one")));
+ testHarness.processElement(new StreamRecord<>(WindowedValue.valueInGlobalWindow("two")));
+ testHarness.processElement(new StreamRecord<>(WindowedValue.valueInGlobalWindow("hello")));
+
+ assertThat(
+ this.stripStreamRecordFromRawUnion(testHarness.getOutput()),
+ contains(
+ new RawUnionValue(2, WindowedValue.valueInGlobalWindow("extra: one")),
+ new RawUnionValue(3, WindowedValue.valueInGlobalWindow("extra: two")),
+ new RawUnionValue(1, WindowedValue.valueInGlobalWindow("got: hello")),
+ new RawUnionValue(2, WindowedValue.valueInGlobalWindow("got: hello")),
+ new RawUnionValue(3, WindowedValue.valueInGlobalWindow("got: hello"))));
+
+ testHarness.close();
+ }
+
+ @Test
+ public void testLateDroppingForStatefulFn() throws Exception {
+
+ WindowingStrategy<Object, IntervalWindow> windowingStrategy =
+ WindowingStrategy.of(FixedWindows.of(new Duration(10)));
+
+ DoFn<Integer, String> fn = new DoFn<Integer, String>() {
+
+ @StateId("state")
+ private final StateSpec<Object, ValueState<String>> stateSpec =
+ StateSpecs.value(StringUtf8Coder.of());
+
+ @ProcessElement
+ public void processElement(ProcessContext context) {
+ context.output(context.element().toString());
+ }
+ };
+
+ WindowedValue.FullWindowedValueCoder<Integer> windowedValueCoder =
+ WindowedValue.getFullCoder(
+ VarIntCoder.of(),
+ windowingStrategy.getWindowFn().windowCoder());
+
+ TupleTag<String> outputTag = new TupleTag<>("main-output");
+
+ DoFnOperator<Integer, String, WindowedValue<String>> doFnOperator = new DoFnOperator<>(
+ fn,
+ windowedValueCoder,
+ outputTag,
+ Collections.<TupleTag<?>>emptyList(),
+ new DoFnOperator.DefaultOutputManagerFactory<WindowedValue<String>>(),
+ windowingStrategy,
+ new HashMap<Integer, PCollectionView<?>>(), /* side-input mapping */
+ Collections.<PCollectionView<?>>emptyList(), /* side inputs */
+ PipelineOptionsFactory.as(FlinkPipelineOptions.class),
+ VarIntCoder.of() /* key coder */);
+
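+ // the stateful DoFn above requires a keyed stream, so wrap the operator in a
+ // keyed test harness, keying by the element value to match the key coder above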
+ OneInputStreamOperatorTestHarness<WindowedValue<Integer>, WindowedValue<String>> testHarness =
+ new KeyedOneInputStreamOperatorTestHarness<>(
+ doFnOperator,
+ new KeySelector<WindowedValue<Integer>, Integer>() {
+ @Override
+ public Integer getKey(WindowedValue<Integer> integerWindowedValue) throws Exception {
+ return integerWindowedValue.getValue();
+ }
+ },
+ new CoderTypeInformation<>(VarIntCoder.of()));
+
+ testHarness.open();
+
+ testHarness.processWatermark(0);
+
+ IntervalWindow window1 = new IntervalWindow(new Instant(0), Duration.millis(10));
+
+ // this should not be late
+ testHarness.processElement(
+ new StreamRecord<>(WindowedValue.of(13, new Instant(0), window1, PaneInfo.NO_FIRING)));
+
+ assertThat(
+ this.<String>stripStreamRecordFromWindowedValue(testHarness.getOutput()),
+ contains(WindowedValue.of("13", new Instant(0), window1, PaneInfo.NO_FIRING)));
+
+ testHarness.getOutput().clear();
+
+ testHarness.processWatermark(9);
+
+ // this should still not be considered late
+ testHarness.processElement(
+ new StreamRecord<>(WindowedValue.of(17, new Instant(0), window1, PaneInfo.NO_FIRING)));
+
+ assertThat(
+ this.<String>stripStreamRecordFromWindowedValue(testHarness.getOutput()),
+ contains(WindowedValue.of("17", new Instant(0), window1, PaneInfo.NO_FIRING)));
+
+ testHarness.getOutput().clear();
+
+ testHarness.processWatermark(10);
+
+ // this should now be considered late
+ testHarness.processElement(
+ new StreamRecord<>(WindowedValue.of(17, new Instant(0), window1, PaneInfo.NO_FIRING)));
+
+ assertThat(
+ this.<String>stripStreamRecordFromWindowedValue(testHarness.getOutput()),
+ emptyIterable());
+
+ testHarness.close();
+ }
+
+ @Test
+ public void testStateGCForStatefulFn() throws Exception {
+
+ WindowingStrategy<Object, IntervalWindow> windowingStrategy =
+ WindowingStrategy.of(FixedWindows.of(new Duration(10))).withAllowedLateness(Duration.ZERO);
+
+ final String timerId = "boo";
+ final String stateId = "dazzle";
+
+ final int offset = 5000;
+ final int timerOutput = 4093;
+
+ DoFn<KV<String, Integer>, KV<String, Integer>> fn =
+ new DoFn<KV<String, Integer>, KV<String, Integer>>() {
+
+ @TimerId(timerId)
+ private final TimerSpec spec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+ @StateId(stateId)
+ private final StateSpec<Object, ValueState<String>> stateSpec =
+ StateSpecs.value(StringUtf8Coder.of());
+
+ @ProcessElement
+ public void processElement(
+ ProcessContext context,
+ @TimerId(timerId) Timer timer,
+ @StateId(stateId) ValueState<String> state,
+ BoundedWindow window) {
+ timer.set(window.maxTimestamp());
+ state.write(context.element().getKey());
+ context.output(
+ KV.of(context.element().getKey(), context.element().getValue() + offset));
+ }
+
+ @OnTimer(timerId)
+ public void onTimer(OnTimerContext context, @StateId(stateId) ValueState<String> state) {
+ context.output(KV.of(state.read(), timerOutput));
+ }
+ };
+
+ WindowedValue.FullWindowedValueCoder<KV<String, Integer>> windowedValueCoder =
+ WindowedValue.getFullCoder(
+ KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()),
+ windowingStrategy.getWindowFn().windowCoder());
+
+ TupleTag<KV<String, Integer>> outputTag = new TupleTag<>("main-output");
+
+ DoFnOperator<
+ KV<String, Integer>, KV<String, Integer>, WindowedValue<KV<String, Integer>>> doFnOperator =
+ new DoFnOperator<>(
+ fn,
+ windowedValueCoder,
+ outputTag,
+ Collections.<TupleTag<?>>emptyList(),
+ new DoFnOperator.DefaultOutputManagerFactory<WindowedValue<KV<String, Integer>>>(),
+ windowingStrategy,
+ new HashMap<Integer, PCollectionView<?>>(), /* side-input mapping */
+ Collections.<PCollectionView<?>>emptyList(), /* side inputs */
+ PipelineOptionsFactory.as(FlinkPipelineOptions.class),
+ StringUtf8Coder.of() /* key coder */);
+
+ KeyedOneInputStreamOperatorTestHarness<
+ String,
+ WindowedValue<KV<String, Integer>>,
+ WindowedValue<KV<String, Integer>>> testHarness =
+ new KeyedOneInputStreamOperatorTestHarness<>(
+ doFnOperator,
+ new KeySelector<WindowedValue<KV<String, Integer>>, String>() {
+ @Override
+ public String getKey(
+ WindowedValue<KV<String, Integer>> kvWindowedValue) throws Exception {
+ return kvWindowedValue.getValue().getKey();
+ }
+ },
+ new CoderTypeInformation<>(StringUtf8Coder.of()));
+
+ testHarness.open();
+
+ testHarness.processWatermark(0);
+
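+ // no keyed state should exist before any element has been processed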
+ assertEquals(0, testHarness.numKeyedStateEntries());
+
+ IntervalWindow window1 = new IntervalWindow(new Instant(0), Duration.millis(10));
+
+ testHarness.processElement(
+ new StreamRecord<>(
+ WindowedValue.of(KV.of("key1", 5), new Instant(1), window1, PaneInfo.NO_FIRING)));
+
+ testHarness.processElement(
+ new StreamRecord<>(
+ WindowedValue.of(KV.of("key2", 7), new Instant(3), window1, PaneInfo.NO_FIRING)));
+
+ assertThat(
+ this.<KV<String, Integer>>stripStreamRecordFromWindowedValue(testHarness.getOutput()),
+ contains(
+ WindowedValue.of(
+ KV.of("key1", 5 + offset), new Instant(1), window1, PaneInfo.NO_FIRING),
+ WindowedValue.of(
+ KV.of("key2", 7 + offset), new Instant(3), window1, PaneInfo.NO_FIRING)));
+
+ assertEquals(2, testHarness.numKeyedStateEntries());
+
+ testHarness.getOutput().clear();
+
+ // this should trigger both the window.maxTimestamp() timer and the GC timer
+ // this tests that the GC timer fires after the user timer
+ testHarness.processWatermark(
+ window1.maxTimestamp()
+ .plus(windowingStrategy.getAllowedLateness())
+ .plus(StatefulDoFnRunner.TimeInternalsCleanupTimer.GC_DELAY_MS)
+ .getMillis());
+
+ assertThat(
+ this.<KV<String, Integer>>stripStreamRecordFromWindowedValue(testHarness.getOutput()),
+ contains(
+ WindowedValue.of(
+ KV.of("key1", timerOutput), new Instant(9), window1, PaneInfo.NO_FIRING),
+ WindowedValue.of(
+ KV.of("key2", timerOutput), new Instant(9), window1, PaneInfo.NO_FIRING)));
+
+ // ensure the state was garbage collected
+ assertEquals(0, testHarness.numKeyedStateEntries());
+
+ testHarness.close();
+ }
+
+ public void testSideInputs(boolean keyed) throws Exception {
+
+ WindowedValue.ValueOnlyWindowedValueCoder<String> windowedValueCoder =
+ WindowedValue.getValueOnlyCoder(StringUtf8Coder.of());
+
+ TupleTag<String> outputTag = new TupleTag<>("main-output");
+
+ ImmutableMap<Integer, PCollectionView<?>> sideInputMapping =
+ ImmutableMap.<Integer, PCollectionView<?>>builder()
+ .put(1, view1)
+ .put(2, view2)
+ .build();
+
+ Coder<String> keyCoder = null;
+ if (keyed) {
+ keyCoder = StringUtf8Coder.of();
+ }
+
+ DoFnOperator<String, String, String> doFnOperator = new DoFnOperator<>(
+ new IdentityDoFn<String>(),
+ windowedValueCoder,
+ outputTag,
+ Collections.<TupleTag<?>>emptyList(),
+ new DoFnOperator.DefaultOutputManagerFactory<String>(),
+ WindowingStrategy.globalDefault(),
+ sideInputMapping, /* side-input mapping */
+ ImmutableList.<PCollectionView<?>>of(view1, view2), /* side inputs */
+ PipelineOptionsFactory.as(FlinkPipelineOptions.class),
+ keyCoder);
+
+ TwoInputStreamOperatorTestHarness<WindowedValue<String>, RawUnionValue, String> testHarness =
+ new TwoInputStreamOperatorTestHarness<>(doFnOperator);
+
+ if (keyed) {
+ // we use a dummy key for the second input since it is considered to be broadcast
+ testHarness = new KeyedTwoInputStreamOperatorTestHarness<>(
+ doFnOperator,
+ new StringKeySelector(),
+ new DummyKeySelector(),
+ BasicTypeInfo.STRING_TYPE_INFO);
+ }
+
+ testHarness.open();
+
+ IntervalWindow firstWindow = new IntervalWindow(new Instant(0), new Instant(100));
+ IntervalWindow secondWindow = new IntervalWindow(new Instant(0), new Instant(500));
+
+ // push in side-input elements for both windows; the operator should keep them
+ testHarness.processElement2(
+ new StreamRecord<>(
+ new RawUnionValue(
+ 1,
+ valuesInWindow(ImmutableList.of("hello", "ciao"), new Instant(0), firstWindow))));
+ testHarness.processElement2(
+ new StreamRecord<>(
+ new RawUnionValue(
+ 2,
+ valuesInWindow(ImmutableList.of("foo", "bar"), new Instant(0), secondWindow))));
+
+ // push in regular (main-input) elements
+ WindowedValue<String> helloElement = valueInWindow("Hello", new Instant(0), firstWindow);
+ WindowedValue<String> worldElement = valueInWindow("World", new Instant(1000), firstWindow);
+ testHarness.processElement1(new StreamRecord<>(helloElement));
+ testHarness.processElement1(new StreamRecord<>(worldElement));
+
+ // push in a second batch of side-input elements (exercises keeping of pushed-back state)
+ testHarness.processElement2(
+ new StreamRecord<>(
+ new RawUnionValue(
+ 1,
+ valuesInWindow(ImmutableList.of("hello", "ciao"),
+ new Instant(1000), firstWindow))));
+ testHarness.processElement2(
+ new StreamRecord<>(
+ new RawUnionValue(
+ 2,
+ valuesInWindow(ImmutableList.of("foo", "bar"), new Instant(1000), secondWindow))));
+
+ assertThat(
+ this.<String>stripStreamRecordFromWindowedValue(testHarness.getOutput()),
+ contains(helloElement, worldElement));
+
+ testHarness.close();
+
+ }
+
+ /**
+ * {@link TwoInputStreamOperatorTestHarness} supports an OperatorStateBackend
+ * but not a KeyedStateBackend, so here we only test side inputs for a normal ParDo.
+ */
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testNormalParDoSideInputs() throws Exception {
+ testSideInputs(false);
+ }
+
+ @Test
+ public void testKeyedSideInputs() throws Exception {
+ testSideInputs(true);
+ }
+
+ private <T> Iterable<WindowedValue<T>> stripStreamRecordFromWindowedValue(
+ Iterable<Object> input) {
+
+ return FluentIterable.from(input).filter(new Predicate<Object>() {
+ @Override
+ public boolean apply(@Nullable Object o) {
+ return o instanceof StreamRecord && ((StreamRecord) o).getValue() instanceof WindowedValue;
+ }
+ }).transform(new Function<Object, WindowedValue<T>>() {
+ @Nullable
+ @Override
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ public WindowedValue<T> apply(@Nullable Object o) {
+ if (o instanceof StreamRecord && ((StreamRecord) o).getValue() instanceof WindowedValue) {
+ return (WindowedValue) ((StreamRecord) o).getValue();
+ }
+ throw new RuntimeException("unreachable");
+ }
+ });
+ }
+
+ private Iterable<RawUnionValue> stripStreamRecordFromRawUnion(Iterable<Object> input) {
+ return FluentIterable.from(input).filter(new Predicate<Object>() {
+ @Override
+ public boolean apply(@Nullable Object o) {
+ return o instanceof StreamRecord && ((StreamRecord) o).getValue() instanceof RawUnionValue;
+ }
+ }).transform(new Function<Object, RawUnionValue>() {
+ @Nullable
+ @Override
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ public RawUnionValue apply(@Nullable Object o) {
+ if (o instanceof StreamRecord && ((StreamRecord) o).getValue() instanceof RawUnionValue) {
+ return (RawUnionValue) ((StreamRecord) o).getValue();
+ }
+ throw new RuntimeException("unreachable");
+ }
+ });
+ }
+
+ private static class MultiOutputDoFn extends DoFn<String, String> {
+ private TupleTag<String> additionalOutput1;
+ private TupleTag<String> additionalOutput2;
+
+ public MultiOutputDoFn(TupleTag<String> additionalOutput1, TupleTag<String> additionalOutput2) {
+ this.additionalOutput1 = additionalOutput1;
+ this.additionalOutput2 = additionalOutput2;
+ }
+
+ @ProcessElement
+ public void processElement(ProcessContext c) throws Exception {
+ if (c.element().equals("one")) {
+ c.output(additionalOutput1, "extra: one");
+ } else if (c.element().equals("two")) {
+ c.output(additionalOutput2, "extra: two");
+ } else {
+ c.output("got: " + c.element());
+ c.output(additionalOutput1, "got: " + c.element());
+ c.output(additionalOutput2, "got: " + c.element());
+ }
+ }
+ }
+
+ private static class IdentityDoFn<T> extends DoFn<T, T> {
+ @ProcessElement
+ public void processElement(ProcessContext c) throws Exception {
+ c.output(c.element());
+ }
+ }
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ private WindowedValue<Iterable<?>> valuesInWindow(
+ Iterable<?> values, Instant timestamp, BoundedWindow window) {
+ return (WindowedValue) WindowedValue.of(values, timestamp, window, PaneInfo.NO_FIRING);
+ }
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ private <T> WindowedValue<T> valueInWindow(
+ T value, Instant timestamp, BoundedWindow window) {
+ return WindowedValue.of(value, timestamp, window, PaneInfo.NO_FIRING);
+ }
+
+
+ private static class DummyKeySelector implements KeySelector<RawUnionValue, String> {
+ @Override
+ public String getKey(RawUnionValue stringWindowedValue) throws Exception {
+ return "dummy_key";
+ }
+ }
+
+ private static class StringKeySelector implements KeySelector<WindowedValue<String>, String> {
+ @Override
+ public String getKey(WindowedValue<String> stringWindowedValue) throws Exception {
+ return stringWindowedValue.getValue();
+ }
+ }
+}
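
Distilled from the tests above, the minimal recipe for driving a DoFn through
DoFnOperator with Flink's one-input test harness looks like this (a sketch, not
part of the commit; MyDoFn stands in for any DoFn<String, String>):

    DoFnOperator<String, String, String> operator = new DoFnOperator<>(
        new MyDoFn(),                                       // DoFn under test (placeholder)
        WindowedValue.getValueOnlyCoder(StringUtf8Coder.of()),
        new TupleTag<String>("main-output"),
        Collections.<TupleTag<?>>emptyList(),               // no additional outputs
        new DoFnOperator.DefaultOutputManagerFactory<String>(),
        WindowingStrategy.globalDefault(),
        new HashMap<Integer, PCollectionView<?>>(),         // side-input mapping
        Collections.<PCollectionView<?>>emptyList(),        // side inputs
        PipelineOptionsFactory.as(FlinkPipelineOptions.class),
        null);                                              // key coder; null = non-keyed

    OneInputStreamOperatorTestHarness<WindowedValue<String>, String> harness =
        new OneInputStreamOperatorTestHarness<>(operator);
    harness.open();
    harness.processElement(new StreamRecord<>(WindowedValue.valueInGlobalWindow("x")));
    // harness.getOutput() now holds the emitted StreamRecord-wrapped WindowedValues
    harness.close();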
http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java
new file mode 100644
index 0000000..7e7d1e1
--- /dev/null
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkBroadcastStateInternalsTest.java
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertThat;
+
+import java.util.Arrays;
+import org.apache.beam.runners.core.StateMerging;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaceForTest;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkBroadcastStateInternals;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.CombiningState;
+import org.apache.beam.sdk.util.state.GroupingState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link FlinkBroadcastStateInternals}. This is based on the tests for
+ * {@code InMemoryStateInternals}.
+ */
+@RunWith(JUnit4.class)
+public class FlinkBroadcastStateInternalsTest {
+ private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1");
+ private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2");
+ private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3");
+
+ private static final StateTag<Object, ValueState<String>> STRING_VALUE_ADDR =
+ StateTags.value("stringValue", StringUtf8Coder.of());
+ private static final StateTag<Object, CombiningState<Integer, int[], Integer>>
+ SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal(
+ "sumInteger", VarIntCoder.of(), Sum.ofIntegers());
+ private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR =
+ StateTags.bag("stringBag", StringUtf8Coder.of());
+
+ FlinkBroadcastStateInternals<String> underTest;
+
+ @Before
+ public void initStateInternals() {
+ MemoryStateBackend backend = new MemoryStateBackend();
+ try {
+ OperatorStateBackend operatorStateBackend =
+ backend.createOperatorStateBackend(new DummyEnvironment("test", 1, 0), "");
+ underTest = new FlinkBroadcastStateInternals<>(1, operatorStateBackend);
+
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Test
+ public void testValue() throws Exception {
+ ValueState<String> value = underTest.state(NAMESPACE_1, STRING_VALUE_ADDR);
+
+ assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value);
+ assertNotEquals(
+ underTest.state(NAMESPACE_2, STRING_VALUE_ADDR),
+ value);
+
+ assertThat(value.read(), Matchers.nullValue());
+ value.write("hello");
+ assertThat(value.read(), Matchers.equalTo("hello"));
+ value.write("world");
+ assertThat(value.read(), Matchers.equalTo("world"));
+
+ value.clear();
+ assertThat(value.read(), Matchers.nullValue());
+ assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value);
+
+ }
+
+ @Test
+ public void testBag() throws Exception {
+ BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+ assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR));
+ assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)));
+
+ assertThat(value.read(), Matchers.emptyIterable());
+ value.add("hello");
+ assertThat(value.read(), Matchers.containsInAnyOrder("hello"));
+
+ value.add("world");
+ assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world"));
+
+ value.clear();
+ assertThat(value.read(), Matchers.emptyIterable());
+ assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value);
+
+ }
+
+ @Test
+ public void testBagIsEmpty() throws Exception {
+ BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+ assertThat(value.isEmpty().read(), Matchers.is(true));
+ ReadableState<Boolean> readFuture = value.isEmpty();
+ value.add("hello");
+ assertThat(readFuture.read(), Matchers.is(false));
+
+ value.clear();
+ assertThat(readFuture.read(), Matchers.is(true));
+ }
+
+ @Test
+ public void testMergeBagIntoSource() throws Exception {
+ BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+ BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+
+ bag1.add("Hello");
+ bag2.add("World");
+ bag1.add("!");
+
+ StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1);
+
+ // Reading the merged bag returns the contents of both source bags
+ assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+ assertThat(bag2.read(), Matchers.emptyIterable());
+ }
+
+ @Test
+ public void testMergeBagIntoNewNamespace() throws Exception {
+ BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+ BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+ BagState<String> bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR);
+
+ bag1.add("Hello");
+ bag2.add("World");
+ bag1.add("!");
+
+ StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3);
+
+ // Reading the merged bag returns the contents of all source bags
+ assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+ assertThat(bag1.read(), Matchers.emptyIterable());
+ assertThat(bag2.read(), Matchers.emptyIterable());
+ }
+
+ @Test
+ public void testCombiningValue() throws Exception {
+ GroupingState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+
+ // State instances are cached, but depend on the namespace.
+ assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR));
+ assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR)));
+
+ assertThat(value.read(), Matchers.equalTo(0));
+ value.add(2);
+ assertThat(value.read(), Matchers.equalTo(2));
+
+ value.add(3);
+ assertThat(value.read(), Matchers.equalTo(5));
+
+ value.clear();
+ assertThat(value.read(), Matchers.equalTo(0));
+ assertEquals(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), value);
+ }
+
+ @Test
+ public void testCombiningIsEmpty() throws Exception {
+ GroupingState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+
+ assertThat(value.isEmpty().read(), Matchers.is(true));
+ ReadableState<Boolean> readFuture = value.isEmpty();
+ value.add(5);
+ assertThat(readFuture.read(), Matchers.is(false));
+
+ value.clear();
+ assertThat(readFuture.read(), Matchers.is(true));
+ }
+
+ @Test
+ public void testMergeCombiningValueIntoSource() throws Exception {
+ CombiningState<Integer, int[], Integer> value1 =
+ underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+ CombiningState<Integer, int[], Integer> value2 =
+ underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
+
+ value1.add(5);
+ value2.add(10);
+ value1.add(6);
+
+ assertThat(value1.read(), Matchers.equalTo(11));
+ assertThat(value2.read(), Matchers.equalTo(10));
+
+ // Merging clears the old values and updates the result value.
+ StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1);
+
+ assertThat(value1.read(), Matchers.equalTo(21));
+ assertThat(value2.read(), Matchers.equalTo(0));
+ }
+
+ @Test
+ public void testMergeCombiningValueIntoNewNamespace() throws Exception {
+ CombiningState<Integer, int[], Integer> value1 =
+ underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+ CombiningState<Integer, int[], Integer> value2 =
+ underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
+ CombiningState<Integer, int[], Integer> value3 =
+ underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR);
+
+ value1.add(5);
+ value2.add(10);
+ value1.add(6);
+
+ StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3);
+
+ // Merging clears the old values and updates the result value.
+ assertThat(value1.read(), Matchers.equalTo(0));
+ assertThat(value2.read(), Matchers.equalTo(0));
+ assertThat(value3.read(), Matchers.equalTo(21));
+ }
+
+}
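
A note on the addressing model all of these state-internals tests share: a state
cell is identified by a (StateNamespace, StateTag) pair, so the same tag under
different namespaces yields independent state, while repeated lookups of the same
pair return equal state objects. In miniature (using the constants above):

    BagState<String> a = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
    BagState<String> b = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
    a.add("x");   // visible only through NAMESPACE_1; b.read() stays empty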
http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java
new file mode 100644
index 0000000..5433d07
--- /dev/null
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkKeyGroupStateInternalsTest.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import org.apache.beam.runners.core.StateMerging;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaceForTest;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkKeyGroupStateInternals;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.java.typeutils.GenericTypeInfo;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.query.KvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.apache.flink.streaming.api.operators.KeyContext;
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link FlinkKeyGroupStateInternals}. This is based on the tests for
+ * {@code InMemoryStateInternals}.
+ */
+@RunWith(JUnit4.class)
+public class FlinkKeyGroupStateInternalsTest {
+ private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1");
+ private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2");
+ private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3");
+
+ private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR =
+ StateTags.bag("stringBag", StringUtf8Coder.of());
+
+ FlinkKeyGroupStateInternals<String> underTest;
+ private KeyedStateBackend keyedStateBackend;
+
+ @Before
+ public void initStateInternals() {
+ try {
+ keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1));
+ underTest = new FlinkKeyGroupStateInternals<>(StringUtf8Coder.of(), keyedStateBackend);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private KeyedStateBackend getKeyedStateBackend(int numberOfKeyGroups,
+ KeyGroupRange keyGroupRange) {
+ MemoryStateBackend backend = new MemoryStateBackend();
+ try {
+ AbstractKeyedStateBackend<ByteBuffer> keyedStateBackend = backend.createKeyedStateBackend(
+ new DummyEnvironment("test", 1, 0),
+ new JobID(),
+ "test_op",
+ new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig()),
+ numberOfKeyGroups,
+ keyGroupRange,
+ new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
+ keyedStateBackend.setCurrentKey(ByteBuffer.wrap(
+ CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "1")));
+ return keyedStateBackend;
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Test
+ public void testBag() throws Exception {
+ BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+ assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR));
+ assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)));
+
+ assertThat(value.read(), Matchers.emptyIterable());
+ value.add("hello");
+ assertThat(value.read(), Matchers.containsInAnyOrder("hello"));
+
+ value.add("world");
+ assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world"));
+
+ value.clear();
+ assertThat(value.read(), Matchers.emptyIterable());
+ assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value);
+
+ }
+
+ @Test
+ public void testBagIsEmpty() throws Exception {
+ BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+ assertThat(value.isEmpty().read(), Matchers.is(true));
+ ReadableState<Boolean> readFuture = value.isEmpty();
+ value.add("hello");
+ assertThat(readFuture.read(), Matchers.is(false));
+
+ value.clear();
+ assertThat(readFuture.read(), Matchers.is(true));
+ }
+
+ @Test
+ public void testMergeBagIntoSource() throws Exception {
+ BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+ BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+
+ bag1.add("Hello");
+ bag2.add("World");
+ bag1.add("!");
+
+ StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1);
+
+ // Reading the merged bag returns the contents of both source bags
+ assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+ assertThat(bag2.read(), Matchers.emptyIterable());
+ }
+
+ @Test
+ public void testMergeBagIntoNewNamespace() throws Exception {
+ BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+ BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+ BagState<String> bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR);
+
+ bag1.add("Hello");
+ bag2.add("World");
+ bag1.add("!");
+
+ StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3);
+
+ // Reading the merged bag returns the contents of all source bags
+ assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+ assertThat(bag1.read(), Matchers.emptyIterable());
+ assertThat(bag2.read(), Matchers.emptyIterable());
+ }
+
+ @Test
+ public void testKeyGroupAndCheckpoint() throws Exception {
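+ // Key groups are Flink's unit of state redistribution on rescaling: this test
+ // snapshots each key group separately, restores them into narrower ranges
+ // (scale-up), and then back into a single backend (scale-down).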
+ // assign to keyGroup 0
+ ByteBuffer key0 = ByteBuffer.wrap(
+ CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "11111111"));
+ // assign to keyGroup 1
+ ByteBuffer key1 = ByteBuffer.wrap(
+ CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "22222222"));
+ FlinkKeyGroupStateInternals<String> allState;
+ {
+ KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1));
+ allState = new FlinkKeyGroupStateInternals<>(
+ StringUtf8Coder.of(), keyedStateBackend);
+ BagState<String> valueForNamespace0 = allState.state(NAMESPACE_1, STRING_BAG_ADDR);
+ BagState<String> valueForNamespace1 = allState.state(NAMESPACE_2, STRING_BAG_ADDR);
+ keyedStateBackend.setCurrentKey(key0);
+ valueForNamespace0.add("0");
+ valueForNamespace1.add("2");
+ keyedStateBackend.setCurrentKey(key1);
+ valueForNamespace0.add("1");
+ valueForNamespace1.add("3");
+ assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1"));
+ assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3"));
+ }
+
+ ClassLoader classLoader = FlinkKeyGroupStateInternalsTest.class.getClassLoader();
+
+ // 1. scale up
+ ByteArrayOutputStream out0 = new ByteArrayOutputStream();
+ allState.snapshotKeyGroupState(0, new DataOutputStream(out0));
+ DataInputStream in0 = new DataInputStream(
+ new ByteArrayInputStream(out0.toByteArray()));
+ {
+ KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 0));
+ FlinkKeyGroupStateInternals<String> state0 =
+ new FlinkKeyGroupStateInternals<>(
+ StringUtf8Coder.of(), keyedStateBackend);
+ state0.restoreKeyGroupState(0, in0, classLoader);
+ BagState<String> valueForNamespace0 = state0.state(NAMESPACE_1, STRING_BAG_ADDR);
+ BagState<String> valueForNamespace1 = state0.state(NAMESPACE_2, STRING_BAG_ADDR);
+ assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0"));
+ assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2"));
+ }
+
+ ByteArrayOutputStream out1 = new ByteArrayOutputStream();
+ allState.snapshotKeyGroupState(1, new DataOutputStream(out1));
+ DataInputStream in1 = new DataInputStream(
+ new ByteArrayInputStream(out1.toByteArray()));
+ {
+ KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(1, 1));
+ FlinkKeyGroupStateInternals<String> state1 =
+ new FlinkKeyGroupStateInternals<>(
+ StringUtf8Coder.of(), keyedStateBackend);
+ state1.restoreKeyGroupState(1, in1, classLoader);
+ BagState<String> valueForNamespace0 = state1.state(NAMESPACE_1, STRING_BAG_ADDR);
+ BagState<String> valueForNamespace1 = state1.state(NAMESPACE_2, STRING_BAG_ADDR);
+ assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("1"));
+ assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("3"));
+ }
+
+ // 2. scale down
+ {
+ KeyedStateBackend keyedStateBackend = getKeyedStateBackend(2, new KeyGroupRange(0, 1));
+ FlinkKeyGroupStateInternals<String> newAllState = new FlinkKeyGroupStateInternals<>(
+ StringUtf8Coder.of(), keyedStateBackend);
+ in0.reset();
+ in1.reset();
+ newAllState.restoreKeyGroupState(0, in0, classLoader);
+ newAllState.restoreKeyGroupState(1, in1, classLoader);
+ BagState<String> valueForNamespace0 = newAllState.state(NAMESPACE_1, STRING_BAG_ADDR);
+ BagState<String> valueForNamespace1 = newAllState.state(NAMESPACE_2, STRING_BAG_ADDR);
+ assertThat(valueForNamespace0.read(), Matchers.containsInAnyOrder("0", "1"));
+ assertThat(valueForNamespace1.read(), Matchers.containsInAnyOrder("2", "3"));
+ }
+
+ }
+
+ private static class TestKeyContext implements KeyContext {
+
+ private Object key;
+
+ @Override
+ public void setCurrentKey(Object key) {
+ this.key = key;
+ }
+
+ @Override
+ public Object getCurrentKey() {
+ return key;
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
new file mode 100644
index 0000000..08ae0c4
--- /dev/null
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
+
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaceForTest;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkSplitStateInternals;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link FlinkSplitStateInternals}. This is based on the tests for
+ * {@code InMemoryStateInternals}.
+ */
+@RunWith(JUnit4.class)
+public class FlinkSplitStateInternalsTest {
+ private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1");
+ private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2");
+
+ private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR =
+ StateTags.bag("stringBag", StringUtf8Coder.of());
+
+ FlinkSplitStateInternals<String> underTest;
+
+ @Before
+ public void initStateInternals() {
+ MemoryStateBackend backend = new MemoryStateBackend();
+ try {
+ OperatorStateBackend operatorStateBackend =
+ backend.createOperatorStateBackend(new DummyEnvironment("test", 1, 0), "");
+ underTest = new FlinkSplitStateInternals<>(operatorStateBackend);
+
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Test
+ public void testBag() throws Exception {
+ BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+ assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR));
+ assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)));
+
+ assertThat(value.read(), Matchers.emptyIterable());
+ value.add("hello");
+ assertThat(value.read(), Matchers.containsInAnyOrder("hello"));
+
+ value.add("world");
+ assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world"));
+
+ value.clear();
+ assertThat(value.read(), Matchers.emptyIterable());
+ assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value);
+
+ }
+
+ @Test
+ public void testBagIsEmpty() throws Exception {
+ BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+ assertThat(value.isEmpty().read(), Matchers.is(true));
+ ReadableState<Boolean> readFuture = value.isEmpty();
+ value.add("hello");
+ assertThat(readFuture.read(), Matchers.is(false));
+
+ value.clear();
+ assertThat(readFuture.read(), Matchers.is(true));
+ }
+
+}
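
Only bag state is exercised in this test: FlinkSplitStateInternals is built on the
OperatorStateBackend (see initStateInternals above), whose list-style operator
state maps naturally onto bag semantics.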
http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
new file mode 100644
index 0000000..d140271
--- /dev/null
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertThat;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import org.apache.beam.runners.core.StateMerging;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaceForTest;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.OutputTimeFns;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.CombiningState;
+import org.apache.beam.sdk.util.state.GroupingState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.beam.sdk.util.state.WatermarkHoldState;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.java.typeutils.GenericTypeInfo;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.query.KvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.hamcrest.Matchers;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link FlinkStateInternals}. This is based on the tests for
+ * {@code InMemoryStateInternals}.
+ */
+@RunWith(JUnit4.class)
+public class FlinkStateInternalsTest {
+ private static final BoundedWindow WINDOW_1 = new IntervalWindow(new Instant(0), new Instant(10));
+ private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1");
+ private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2");
+ private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3");
+
+ private static final StateTag<Object, ValueState<String>> STRING_VALUE_ADDR =
+ StateTags.value("stringValue", StringUtf8Coder.of());
+ private static final StateTag<Object, CombiningState<Integer, int[], Integer>>
+ SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal(
+ "sumInteger", VarIntCoder.of(), Sum.ofIntegers());
+ private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR =
+ StateTags.bag("stringBag", StringUtf8Coder.of());
+ private static final StateTag<Object, WatermarkHoldState<BoundedWindow>>
+ WATERMARK_EARLIEST_ADDR =
+ StateTags.watermarkStateInternal("watermark", OutputTimeFns.outputAtEarliestInputTimestamp());
+ private static final StateTag<Object, WatermarkHoldState<BoundedWindow>>
+ WATERMARK_LATEST_ADDR =
+ StateTags.watermarkStateInternal("watermark", OutputTimeFns.outputAtLatestInputTimestamp());
+ private static final StateTag<Object, WatermarkHoldState<BoundedWindow>> WATERMARK_EOW_ADDR =
+ StateTags.watermarkStateInternal("watermark", OutputTimeFns.outputAtEndOfWindow());
+
+ FlinkStateInternals<String> underTest;
+
+ @Before
+ public void initStateInternals() {
+ MemoryStateBackend backend = new MemoryStateBackend();
+ try {
+ AbstractKeyedStateBackend<ByteBuffer> keyedStateBackend = backend.createKeyedStateBackend(
+ new DummyEnvironment("test", 1, 0),
+ new JobID(),
+ "test_op",
+ new GenericTypeInfo<>(ByteBuffer.class).createSerializer(new ExecutionConfig()),
+ 1,
+ new KeyGroupRange(0, 0),
+ new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
+ underTest = new FlinkStateInternals<>(keyedStateBackend, StringUtf8Coder.of());
+
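+ // keyed state lookups resolve against the backend's current key, so set one
+ // before any state is accessed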
+ keyedStateBackend.setCurrentKey(
+ ByteBuffer.wrap(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "Hello")));
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Test
+ public void testValue() throws Exception {
+ ValueState<String> value = underTest.state(NAMESPACE_1, STRING_VALUE_ADDR);
+
+ assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value);
+ assertNotEquals(
+ underTest.state(NAMESPACE_2, STRING_VALUE_ADDR),
+ value);
+
+ assertThat(value.read(), Matchers.nullValue());
+ value.write("hello");
+ assertThat(value.read(), Matchers.equalTo("hello"));
+ value.write("world");
+ assertThat(value.read(), Matchers.equalTo("world"));
+
+ value.clear();
+ assertThat(value.read(), Matchers.nullValue());
+ assertEquals(underTest.state(NAMESPACE_1, STRING_VALUE_ADDR), value);
+
+ }
+
+ @Test
+ public void testBag() throws Exception {
+ BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+ assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR));
+ assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)));
+
+ assertThat(value.read(), Matchers.emptyIterable());
+ value.add("hello");
+ assertThat(value.read(), Matchers.containsInAnyOrder("hello"));
+
+ value.add("world");
+ assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world"));
+
+ value.clear();
+ assertThat(value.read(), Matchers.emptyIterable());
+ assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value);
+
+ }
+
+ @Test
+ public void testBagIsEmpty() throws Exception {
+ BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+ assertThat(value.isEmpty().read(), Matchers.is(true));
+ ReadableState<Boolean> readFuture = value.isEmpty();
+ value.add("hello");
+ assertThat(readFuture.read(), Matchers.is(false));
+
+ value.clear();
+ assertThat(readFuture.read(), Matchers.is(true));
+ }
+
+ @Test
+ public void testMergeBagIntoSource() throws Exception {
+ BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+ BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+
+ bag1.add("Hello");
+ bag2.add("World");
+ bag1.add("!");
+
+ StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1);
+
+ // Reading the merged bag returns the contents of both source bags
+ assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+ assertThat(bag2.read(), Matchers.emptyIterable());
+ }
+
+ @Test
+ public void testMergeBagIntoNewNamespace() throws Exception {
+ BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+ BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+ BagState<String> bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR);
+
+ bag1.add("Hello");
+ bag2.add("World");
+ bag1.add("!");
+
+ StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3);
+
+ // Reading the merged bag returns the contents of all source bags
+ assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+ assertThat(bag1.read(), Matchers.emptyIterable());
+ assertThat(bag2.read(), Matchers.emptyIterable());
+ }
+
+ @Test
+ public void testCombiningValue() throws Exception {
+ GroupingState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+
+ // State instances are cached, but depend on the namespace.
+ assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR));
+ assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR)));
+
+ assertThat(value.read(), Matchers.equalTo(0));
+ value.add(2);
+ assertThat(value.read(), Matchers.equalTo(2));
+
+ value.add(3);
+ assertThat(value.read(), Matchers.equalTo(5));
+
+ value.clear();
+ assertThat(value.read(), Matchers.equalTo(0));
+ assertEquals(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), value);
+ }
+
+ @Test
+ public void testCombiningIsEmpty() throws Exception {
+ GroupingState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+
+ assertThat(value.isEmpty().read(), Matchers.is(true));
+ ReadableState<Boolean> readFuture = value.isEmpty();
+ value.add(5);
+ assertThat(readFuture.read(), Matchers.is(false));
+
+ value.clear();
+ assertThat(readFuture.read(), Matchers.is(true));
+ }
+
+ @Test
+ public void testMergeCombiningValueIntoSource() throws Exception {
+ CombiningState<Integer, int[], Integer> value1 =
+ underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+ CombiningState<Integer, int[], Integer> value2 =
+ underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
+
+ value1.add(5);
+ value2.add(10);
+ value1.add(6);
+
+ assertThat(value1.read(), Matchers.equalTo(11));
+ assertThat(value2.read(), Matchers.equalTo(10));
+
+ // Merging clears the old values and updates the result value.
+ StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1);
+
+ assertThat(value1.read(), Matchers.equalTo(21));
+ assertThat(value2.read(), Matchers.equalTo(0));
+ }
+
+ @Test
+ public void testMergeCombiningValueIntoNewNamespace() throws Exception {
+ CombiningState<Integer, int[], Integer> value1 =
+ underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+ CombiningState<Integer, int[], Integer> value2 =
+ underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
+ CombiningState<Integer, int[], Integer> value3 =
+ underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR);
+
+ value1.add(5);
+ value2.add(10);
+ value1.add(6);
+
+ StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3);
+
+ // Merging clears the old values and updates the result value.
+ assertThat(value1.read(), Matchers.equalTo(0));
+ assertThat(value2.read(), Matchers.equalTo(0));
+ assertThat(value3.read(), Matchers.equalTo(21));
+ }
+
+ @Test
+ public void testWatermarkEarliestState() throws Exception {
+ WatermarkHoldState<BoundedWindow> value =
+ underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR);
+
+ // State instances are cached, but depend on the namespace.
+ assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR));
+ assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR)));
+
+ assertThat(value.read(), Matchers.nullValue());
+ value.add(new Instant(2000));
+ assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+
+ value.add(new Instant(3000));
+ assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+
+ value.add(new Instant(1000));
+ assertThat(value.read(), Matchers.equalTo(new Instant(1000)));
+
+ value.clear();
+ assertThat(value.read(), Matchers.equalTo(null));
+ assertEquals(underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR), value);
+ }
+
+ @Test
+ public void testWatermarkLatestState() throws Exception {
+ WatermarkHoldState<BoundedWindow> value =
+ underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR);
+
+ // State instances are cached, but depend on the namespace.
+ assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR));
+ assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR)));
+
+ assertThat(value.read(), Matchers.nullValue());
+ value.add(new Instant(2000));
+ assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+
+ value.add(new Instant(3000));
+ assertThat(value.read(), Matchers.equalTo(new Instant(3000)));
+
+ value.add(new Instant(1000));
+ assertThat(value.read(), Matchers.equalTo(new Instant(3000)));
+
+ value.clear();
+ assertThat(value.read(), Matchers.equalTo(null));
+ assertEquals(underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR), value);
+ }
+
+ @Test
+ public void testWatermarkEndOfWindowState() throws Exception {
+ WatermarkHoldState<BoundedWindow> value = underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR);
+
+ // State instances are cached, but depend on the namespace.
+ assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR));
+ assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EOW_ADDR)));
+
+ assertThat(value.read(), Matchers.nullValue());
+ value.add(new Instant(2000));
+ assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+
+ value.clear();
+ assertThat(value.read(), Matchers.equalTo(null));
+ assertEquals(underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR), value);
+ }
+
+ @Test
+ public void testWatermarkStateIsEmpty() throws Exception {
+ WatermarkHoldState<BoundedWindow> value =
+ underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR);
+
+ assertThat(value.isEmpty().read(), Matchers.is(true));
+ ReadableState<Boolean> readFuture = value.isEmpty();
+ value.add(new Instant(1000));
+ assertThat(readFuture.read(), Matchers.is(false));
+
+ value.clear();
+ assertThat(readFuture.read(), Matchers.is(true));
+ }
+
+ @Test
+ public void testMergeEarliestWatermarkIntoSource() throws Exception {
+ WatermarkHoldState<BoundedWindow> value1 =
+ underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR);
+ WatermarkHoldState<BoundedWindow> value2 =
+ underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR);
+
+ value1.add(new Instant(3000));
+ value2.add(new Instant(5000));
+ value1.add(new Instant(4000));
+ value2.add(new Instant(2000));
+
+ // Merging clears the old values and updates the merged value.
+ StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value1, WINDOW_1);
+
+ assertThat(value1.read(), Matchers.equalTo(new Instant(2000)));
+ assertThat(value2.read(), Matchers.equalTo(null));
+ }
+
+ @Test
+ public void testMergeLatestWatermarkIntoSource() throws Exception {
+ WatermarkHoldState<BoundedWindow> value1 =
+ underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR);
+ WatermarkHoldState<BoundedWindow> value2 =
+ underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR);
+ WatermarkHoldState<BoundedWindow> value3 =
+ underTest.state(NAMESPACE_3, WATERMARK_LATEST_ADDR);
+
+ value1.add(new Instant(3000));
+ value2.add(new Instant(5000));
+ value1.add(new Instant(4000));
+ value2.add(new Instant(2000));
+
+ // Merging clears the old values and updates the result value.
+ StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value3, WINDOW_1);
+
+ assertThat(value3.read(), Matchers.equalTo(new Instant(5000)));
+ assertThat(value1.read(), Matchers.equalTo(null));
+ assertThat(value2.read(), Matchers.equalTo(null));
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/GroupByNullKeyTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/GroupByNullKeyTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/GroupByNullKeyTest.java
new file mode 100644
index 0000000..663b910
--- /dev/null
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/GroupByNullKeyTest.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import com.google.common.base.Joiner;
+import java.io.Serializable;
+import java.util.Arrays;
+import org.apache.beam.runners.flink.FlinkTestPipeline;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.io.TextIO;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.windowing.AfterWatermark;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.flink.streaming.util.StreamingProgramTestBase;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+
+/**
+ * Tests that GroupByKey in a streaming pipeline works when every element is assigned the same null key.
+ */
+public class GroupByNullKeyTest extends StreamingProgramTestBase implements Serializable {
+
+ protected String resultPath;
+
+ static final String[] EXPECTED_RESULT = new String[] {
+ "k: null v: user1 user1 user1 user2 user2 user2 user2 user3"
+ };
+
+ public GroupByNullKeyTest() {
+ }
+
+ @Override
+ protected void preSubmit() throws Exception {
+ resultPath = getTempDirPath("result");
+ }
+
+ @Override
+ protected void postSubmit() throws Exception {
+ compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath);
+ }
+
+ /**
+ * DoFn extracting user and timestamp.
+ */
+ private static class ExtractUserAndTimestamp extends DoFn<KV<Integer, String>, String> {
+ @ProcessElement
+ public void processElement(ProcessContext c) {
+ KV<Integer, String> record = c.element();
+ int timestamp = record.getKey();
+ String userName = record.getValue();
+ if (userName != null) {
+ // Sets the implicit timestamp field to be used in windowing.
+ c.outputWithTimestamp(userName, new Instant(timestamp));
+ }
+ }
+ }
+
+ @Override
+ protected void testProgram() throws Exception {
+
+ Pipeline p = FlinkTestPipeline.createForStreaming();
+
+ PCollection<String> output =
+ p.apply(Create.of(Arrays.asList(
+ KV.<Integer, String>of(0, "user1"),
+ KV.<Integer, String>of(1, "user1"),
+ KV.<Integer, String>of(2, "user1"),
+ KV.<Integer, String>of(10, "user2"),
+ KV.<Integer, String>of(1, "user2"),
+ KV.<Integer, String>of(15000, "user2"),
+ KV.<Integer, String>of(12000, "user2"),
+ KV.<Integer, String>of(25000, "user3"))))
+ .apply(ParDo.of(new ExtractUserAndTimestamp()))
+ .apply(Window.<String>into(FixedWindows.of(Duration.standardHours(1)))
+ .triggering(AfterWatermark.pastEndOfWindow())
+ .withAllowedLateness(Duration.ZERO)
+ .discardingFiredPanes())
+
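+ // Re-key every element to a single null key, so that GroupByKey must handle null keys.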
+ .apply(ParDo.of(new DoFn<String, KV<Void, String>>() {
+ @ProcessElement
+ public void processElement(ProcessContext c) throws Exception {
+ String elem = c.element();
+ c.output(KV.<Void, String>of(null, elem));
+ }
+ }))
+ .apply(GroupByKey.<Void, String>create())
+ .apply(ParDo.of(new DoFn<KV<Void, Iterable<String>>, String>() {
+ @ProcessElement
+ public void processElement(ProcessContext c) throws Exception {
+ KV<Void, Iterable<String>> elem = c.element();
+ StringBuilder str = new StringBuilder();
+ str.append("k: " + elem.getKey() + " v:");
+ for (String v : elem.getValue()) {
+ str.append(" " + v);
+ }
+ c.output(str.toString());
+ }
+ }));
+ output.apply(TextIO.Write.to(resultPath));
+ p.run();
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TestCountingSource.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TestCountingSource.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TestCountingSource.java
new file mode 100644
index 0000000..3a08088
--- /dev/null
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TestCountingSource.java
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import static org.apache.beam.sdk.util.CoderUtils.encodeToByteArray;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ThreadLocalRandom;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.DelegateCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.io.UnboundedSource;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.values.KV;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * An unbounded source for testing the unbounded sources framework code.
+ *
+ * <p>Each split of this source produces records of the form KV(split_id, i),
+ * where i counts up from 0. Each record has a timestamp of i, and the watermark
+ * accurately tracks these timestamps. When deduplication is being tested, the
+ * reader occasionally re-emits the current record from {@code advance} instead
+ * of advancing, so that duplicates are produced.
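+ *
+ * <p>A minimal usage sketch (illustrative only; a runner normally drives the
+ * reader, and the options below are placeholders):
+ *
+ * <pre>{@code
+ * TestCountingSource source = new TestCountingSource(100).withoutSplitting();
+ * UnboundedSource.UnboundedReader<KV<Integer, Integer>> reader =
+ *     source.createReader(PipelineOptionsFactory.create(), null);
+ * for (boolean more = reader.start(); more; more = reader.advance()) {
+ *   KV<Integer, Integer> kv = reader.getCurrent(); // KV(shard_id, 0), KV(shard_id, 1), ...
+ * }
+ * }</pre>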
+ */
+public class TestCountingSource
+ extends UnboundedSource<KV<Integer, Integer>, TestCountingSource.CounterMark> {
+ private static final Logger LOG = LoggerFactory.getLogger(TestCountingSource.class);
+
+ private static List<Integer> finalizeTracker;
+ private final int numMessagesPerShard;
+ private final int shardNumber;
+ private final boolean dedup;
+ private final boolean throwOnFirstSnapshot;
+ private final boolean allowSplitting;
+
+ /**
+ * We allow an exception to be thrown from getCheckpointMark at most once.
+ * This must be static since the entire TestCountingSource instance may be
+ * re-serialized when the pipeline recovers and retries.
+ */
+ private static boolean thrown = false;
+
+ public static void setFinalizeTracker(List<Integer> finalizeTracker) {
+ TestCountingSource.finalizeTracker = finalizeTracker;
+ }
+
+ public TestCountingSource(int numMessagesPerShard) {
+ this(numMessagesPerShard, 0, false, false, true);
+ }
+
+ public TestCountingSource withDedup() {
+ return new TestCountingSource(
+ numMessagesPerShard, shardNumber, true, throwOnFirstSnapshot, true);
+ }
+
+ private TestCountingSource withShardNumber(int shardNumber) {
+ return new TestCountingSource(
+ numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, true);
+ }
+
+ public TestCountingSource withThrowOnFirstSnapshot(boolean throwOnFirstSnapshot) {
+ return new TestCountingSource(
+ numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, true);
+ }
+
+ public TestCountingSource withoutSplitting() {
+ return new TestCountingSource(
+ numMessagesPerShard, shardNumber, dedup, throwOnFirstSnapshot, false);
+ }
+
+ private TestCountingSource(int numMessagesPerShard, int shardNumber, boolean dedup,
+ boolean throwOnFirstSnapshot, boolean allowSplitting) {
+ this.numMessagesPerShard = numMessagesPerShard;
+ this.shardNumber = shardNumber;
+ this.dedup = dedup;
+ this.throwOnFirstSnapshot = throwOnFirstSnapshot;
+ this.allowSplitting = allowSplitting;
+ }
+
+ public int getShardNumber() {
+ return shardNumber;
+ }
+
+ @Override
+ public List<TestCountingSource> split(
+ int desiredNumSplits, PipelineOptions options) {
+ List<TestCountingSource> splits = new ArrayList<>();
+ int numSplits = allowSplitting ? desiredNumSplits : 1;
+ for (int i = 0; i < numSplits; i++) {
+ splits.add(withShardNumber(i));
+ }
+ return splits;
+ }
+
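+ /**
+ * Checkpoint mark holding the last value emitted by a reader. Finalizing it
+ * records that value in the static finalize tracker, when one is set.
+ */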
+ class CounterMark implements UnboundedSource.CheckpointMark {
+ int current;
+
+ public CounterMark(int current) {
+ this.current = current;
+ }
+
+ @Override
+ public void finalizeCheckpoint() {
+ if (finalizeTracker != null) {
+ finalizeTracker.add(current);
+ }
+ }
+ }
+
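+ /**
+ * A {@link DelegateCoder} that encodes a {@link CounterMark} as its integer
+ * position, so no dedicated coder class is needed for the checkpoint.
+ */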
+ @Override
+ public Coder<CounterMark> getCheckpointMarkCoder() {
+ return DelegateCoder.of(
+ VarIntCoder.of(),
+ new DelegateCoder.CodingFunction<CounterMark, Integer>() {
+ @Override
+ public Integer apply(CounterMark input) {
+ return input.current;
+ }
+ },
+ new DelegateCoder.CodingFunction<Integer, CounterMark>() {
+ @Override
+ public CounterMark apply(Integer input) {
+ return new CounterMark(input);
+ }
+ });
+ }
+
+ @Override
+ public boolean requiresDeduping() {
+ return dedup;
+ }
+
+ /**
+ * Public only so that the checkpoint can be conveyed from {@link #getCheckpointMark()} to
+ * {@link TestCountingSource#createReader(PipelineOptions, CounterMark)} without cast.
+ */
+ public class CountingSourceReader extends UnboundedReader<KV<Integer, Integer>> {
+ private int current;
+
+ public CountingSourceReader(int startingPoint) {
+ this.current = startingPoint;
+ }
+
+ @Override
+ public boolean start() {
+ return advance();
+ }
+
+ @Override
+ public boolean advance() {
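+ // Stop once this shard has emitted all of its values (0 .. numMessagesPerShard - 1).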
+ if (current >= numMessagesPerShard - 1) {
+ return false;
+ }
+ // If testing dedup, occasionally emit the current value again as a duplicate.
+ if (current >= 0 && dedup && ThreadLocalRandom.current().nextInt(5) == 0) {
+ return true;
+ }
+ current++;
+ return true;
+ }
+
+ @Override
+ public KV<Integer, Integer> getCurrent() {
+ return KV.of(shardNumber, current);
+ }
+
+ @Override
+ public Instant getCurrentTimestamp() {
+ return new Instant(current);
+ }
+
+ @Override
+ public byte[] getCurrentRecordId() {
+ try {
+ return encodeToByteArray(KvCoder.of(VarIntCoder.of(), VarIntCoder.of()), getCurrent());
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public void close() {}
+
+ @Override
+ public TestCountingSource getCurrentSource() {
+ return TestCountingSource.this;
+ }
+
+ @Override
+ public Instant getWatermark() {
+ // The watermark is a promise about future elements, and the timestamps of elements are
+ // strictly increasing for this source.
+ return new Instant(current + 1);
+ }
+
+ @Override
+ public CounterMark getCheckpointMark() {
+ if (throwOnFirstSnapshot && !thrown) {
+ thrown = true;
+ LOG.error("Throwing exception while checkpointing counter");
+ throw new RuntimeException("failed during checkpoint");
+ }
+ // The checkpoint can assume all records read, including the current, have
+ // been committed.
+ return new CounterMark(current);
+ }
+
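+ // Report a constant, arbitrary backlog size.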
+ @Override
+ public long getSplitBacklogBytes() {
+ return 7L;
+ }
+ }
+
+ @Override
+ public CountingSourceReader createReader(
+ PipelineOptions options, @Nullable CounterMark checkpointMark) {
+ if (checkpointMark == null) {
+ LOG.debug("creating reader");
+ } else {
+ LOG.debug("restoring reader from checkpoint with current = {}", checkpointMark.current);
+ }
+ return new CountingSourceReader(checkpointMark != null ? checkpointMark.current : -1);
+ }
+
+ @Override
+ public void validate() {}
+
+ @Override
+ public Coder<KV<Integer, Integer>> getDefaultOutputCoder() {
+ return KvCoder.of(VarIntCoder.of(), VarIntCoder.of());
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TopWikipediaSessionsITCase.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TopWikipediaSessionsITCase.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TopWikipediaSessionsITCase.java
new file mode 100644
index 0000000..9e6bba8
--- /dev/null
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/TopWikipediaSessionsITCase.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.streaming;
+
+import com.google.api.services.bigquery.model.TableRow;
+import com.google.common.base.Joiner;
+import java.io.Serializable;
+import java.util.Arrays;
+import org.apache.beam.runners.flink.FlinkTestPipeline;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.io.TextIO;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.windowing.Sessions;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.flink.streaming.util.StreamingProgramTestBase;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+
+/**
+ * Tests session windows: counts each user's contributions within sessions that close after a one-minute gap.
+ */
+public class TopWikipediaSessionsITCase extends StreamingProgramTestBase implements Serializable {
+ protected String resultPath;
+
+ public TopWikipediaSessionsITCase() {
+ }
+
+ static final String[] EXPECTED_RESULT = new String[] {
+ "user: user1 value:3",
+ "user: user1 value:1",
+ "user: user2 value:4",
+ "user: user2 value:6",
+ "user: user3 value:7",
+ "user: user3 value:2"
+ };
+
+ @Override
+ protected void preSubmit() throws Exception {
+ resultPath = getTempDirPath("result");
+ }
+
+ @Override
+ protected void postSubmit() throws Exception {
+ compareResultsByLinesInMemory(Joiner.on('\n').join(EXPECTED_RESULT), resultPath);
+ }
+
+ @Override
+ protected void testProgram() throws Exception {
+
+ Pipeline p = FlinkTestPipeline.createForStreaming();
+
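+ // Base timestamp in epoch seconds, roughly ten seconds in the future.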
+ long now = (System.currentTimeMillis() + 10000) / 1000;
+
+ PCollection<KV<String, Long>> output =
+ p.apply(Create.of(Arrays.asList(
+ new TableRow().set("timestamp", now).set("contributor_username", "user1"),
+ new TableRow().set("timestamp", now + 10).set("contributor_username", "user3"),
+ new TableRow().set("timestamp", now).set("contributor_username", "user2"),
+ new TableRow().set("timestamp", now).set("contributor_username", "user1"),
+ new TableRow().set("timestamp", now + 2).set("contributor_username", "user1"),
+ new TableRow().set("timestamp", now).set("contributor_username", "user2"),
+ new TableRow().set("timestamp", now + 1).set("contributor_username", "user2"),
+ new TableRow().set("timestamp", now + 5).set("contributor_username", "user2"),
+ new TableRow().set("timestamp", now + 7).set("contributor_username", "user2"),
+ new TableRow().set("timestamp", now + 8).set("contributor_username", "user2"),
+ new TableRow().set("timestamp", now + 200).set("contributor_username", "user2"),
+ new TableRow().set("timestamp", now + 230).set("contributor_username", "user1"),
+ new TableRow().set("timestamp", now + 230).set("contributor_username", "user2"),
+ new TableRow().set("timestamp", now + 240).set("contributor_username", "user2"),
+ new TableRow().set("timestamp", now + 245).set("contributor_username", "user3"),
+ new TableRow().set("timestamp", now + 235).set("contributor_username", "user3"),
+ new TableRow().set("timestamp", now + 236).set("contributor_username", "user3"),
+ new TableRow().set("timestamp", now + 237).set("contributor_username", "user3"),
+ new TableRow().set("timestamp", now + 238).set("contributor_username", "user3"),
+ new TableRow().set("timestamp", now + 239).set("contributor_username", "user3"),
+ new TableRow().set("timestamp", now + 240).set("contributor_username", "user3"),
+ new TableRow().set("timestamp", now + 241).set("contributor_username", "user2"),
+ new TableRow().set("timestamp", now).set("contributor_username", "user3"))))
+
+ .apply(ParDo.of(new DoFn<TableRow, String>() {
+ @ProcessElement
+ public void processElement(ProcessContext c) throws Exception {
+ TableRow row = c.element();
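+ // The numeric timestamp is expected to decode as an Integer after the
+ // TableRow's JSON coder round-trip, hence the cast.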
+ long timestamp = (Integer) row.get("timestamp");
+ String userName = (String) row.get("contributor_username");
+ if (userName != null) {
+ // Sets the timestamp field to be used in windowing.
+ c.outputWithTimestamp(userName, new Instant(timestamp * 1000L));
+ }
+ }
+ }))
+
+ .apply(Window.<String>into(Sessions.withGapDuration(Duration.standardMinutes(1))))
+
+ .apply(Count.<String>perElement());
+
+ PCollection<String> format = output.apply(ParDo.of(new DoFn<KV<String, Long>, String>() {
+ @ProcessElement
+ public void processElement(ProcessContext c) throws Exception {
+ KV<String, Long> el = c.element();
+ String out = "user: " + el.getKey() + " value:" + el.getValue();
+ c.output(out);
+ }
+ }));
+
+ format.apply(TextIO.Write.to(resultPath));
+
+ p.run();
+ }
+}