You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@beam.apache.org by ke...@apache.org on 2016/11/12 02:28:33 UTC
[15/39] incubator-beam git commit: BEAM-784 Checkpointing for StateInternals
BEAM-784 Checkpointing for StateInternals
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/1db4ff63
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/1db4ff63
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/1db4ff63
Branch: refs/heads/master
Commit: 1db4ff631736172882976c33316bc089d58483af
Parents: 0a1b278
Author: Thomas Weise <th...@apache.org>
Authored: Tue Oct 25 08:32:23 2016 -0700
Committer: Thomas Weise <th...@apache.org>
Committed: Tue Oct 25 10:06:12 2016 -0700
----------------------------------------------------------------------
.../apex/translators/GroupByKeyTranslator.java | 3 +-
.../translators/ParDoBoundMultiTranslator.java | 4 +-
.../apex/translators/ParDoBoundTranslator.java | 4 +-
.../apex/translators/TranslationContext.java | 10 +
.../functions/ApexGroupByKeyOperator.java | 12 +-
.../functions/ApexParDoOperator.java | 11 +-
.../translators/utils/ApexStateInternals.java | 438 +++++++++++++++++++
.../translators/ParDoBoundTranslatorTest.java | 62 ++-
.../utils/ApexStateInternalsTest.java | 361 +++++++++++++++
9 files changed, 883 insertions(+), 22 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/GroupByKeyTranslator.java
----------------------------------------------------------------------
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/GroupByKeyTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/GroupByKeyTranslator.java
index d3e7d2d..cb78579 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/GroupByKeyTranslator.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/GroupByKeyTranslator.java
@@ -33,7 +33,8 @@ public class GroupByKeyTranslator<K, V> implements TransformTranslator<GroupByKe
public void translate(GroupByKey<K, V> transform, TranslationContext context) {
PCollection<KV<K, V>> input = context.getInput();
+ // state internals are supplied by the translation context instead of being created in the operator
ApexGroupByKeyOperator<K, V> group = new ApexGroupByKeyOperator<>(context.getPipelineOptions(),
- input);
+ input, context.<K>stateInternalsFactory());
context.addOperator(group, group.output);
context.addStream(input, group.input);
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java
----------------------------------------------------------------------
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java
index 13f07c1..2678869 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundMultiTranslator.java
@@ -64,7 +64,9 @@ public class ParDoBoundMultiTranslator<InputT, OutputT>
ApexParDoOperator<InputT, OutputT> operator = new ApexParDoOperator<>(
context.getPipelineOptions(),
doFn, transform.getMainOutputTag(), transform.getSideOutputTags().getAll(),
- context.<PCollection<?>>getInput().getWindowingStrategy(), sideInputs, wvInputCoder);
+ context.<PCollection<?>>getInput().getWindowingStrategy(), sideInputs, wvInputCoder,
+ context.<Void>stateInternalsFactory()
+ );
Map<TupleTag<?>, PCollection<?>> outputs = output.getAll();
Map<PCollection<?>, OutputPort<?>> ports = Maps.newHashMapWithExpectedSize(outputs.size());
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java
----------------------------------------------------------------------
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java
index bd7115e..92567a6 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslator.java
@@ -52,7 +52,9 @@ public class ParDoBoundTranslator<InputT, OutputT> implements
ApexParDoOperator<InputT, OutputT> operator = new ApexParDoOperator<>(
context.getPipelineOptions(),
doFn, new TupleTag<OutputT>(), TupleTagList.empty().getAll() /*sideOutputTags*/,
- output.getWindowingStrategy(), sideInputs, wvInputCoder);
+ output.getWindowingStrategy(), sideInputs, wvInputCoder,
+ context.<Void>stateInternalsFactory()
+ );
context.addOperator(operator, operator.output);
context.addStream(context.getInput(), operator.input);
if (!sideInputs.isEmpty()) {
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/TranslationContext.java
----------------------------------------------------------------------
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/TranslationContext.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/TranslationContext.java
index ddacc29..07c6494 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/TranslationContext.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/TranslationContext.java
@@ -31,6 +31,7 @@ import java.util.List;
import java.util.Map;
import org.apache.beam.runners.apex.ApexPipelineOptions;
+import org.apache.beam.runners.apex.translators.utils.ApexStateInternals;
import org.apache.beam.runners.apex.translators.utils.ApexStreamTuple;
import org.apache.beam.runners.apex.translators.utils.CoderAdapterStreamCodec;
import org.apache.beam.sdk.coders.Coder;
@@ -38,6 +39,7 @@ import org.apache.beam.sdk.runners.TransformTreeNode;
import org.apache.beam.sdk.transforms.AppliedPTransform;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
+import org.apache.beam.sdk.util.state.StateInternalsFactory;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PInput;
@@ -165,4 +167,12 @@ public class TranslationContext {
}
}
+ /**
+ * Returns a new {@link StateInternalsFactory} for keys of type {@code K}.
+ * @return a factory that creates a serializable {@link ApexStateInternals} per key
+ */
+ public <K> StateInternalsFactory<K> stateInternalsFactory() {
+ return new ApexStateInternals.ApexStateInternalsFactory<K>();
+ }
+
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexGroupByKeyOperator.java
----------------------------------------------------------------------
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexGroupByKeyOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexGroupByKeyOperator.java
index 845618d..d69aeab 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexGroupByKeyOperator.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexGroupByKeyOperator.java
@@ -64,7 +64,6 @@ import org.apache.beam.sdk.util.TimerInternals;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowingInternals;
import org.apache.beam.sdk.util.WindowingStrategy;
-import org.apache.beam.sdk.util.state.InMemoryStateInternals;
import org.apache.beam.sdk.util.state.StateInternals;
import org.apache.beam.sdk.util.state.StateInternalsFactory;
import org.apache.beam.sdk.values.KV;
@@ -97,8 +96,8 @@ public class ApexGroupByKeyOperator<K, V> implements Operator {
@Bind(JavaSerializer.class)
private final SerializablePipelineOptions serializedOptions;
@Bind(JavaSerializer.class)
-// TODO: InMemoryStateInternals not serializable
- private transient Map<ByteBuffer, StateInternals<K>> perKeyStateInternals = new HashMap<>();
+ private final StateInternalsFactory<K> stateInternalsFactory;
+ private Map<ByteBuffer, StateInternals<K>> perKeyStateInternals = new HashMap<>();
private Map<ByteBuffer, Set<TimerInternals.TimerData>> activeTimers = new HashMap<>();
private transient ProcessContext context;
@@ -137,17 +136,20 @@ public class ApexGroupByKeyOperator<K, V> implements Operator {
output = new DefaultOutputPort<>();
@SuppressWarnings("unchecked")
- public ApexGroupByKeyOperator(ApexPipelineOptions pipelineOptions, PCollection<KV<K, V>> input) {
+ public ApexGroupByKeyOperator(ApexPipelineOptions pipelineOptions, PCollection<KV<K, V>> input,
+ StateInternalsFactory<K> stateInternalsFactory) {
checkNotNull(pipelineOptions);
this.serializedOptions = new SerializablePipelineOptions(pipelineOptions);
this.windowingStrategy = (WindowingStrategy<V, BoundedWindow>) input.getWindowingStrategy();
this.keyCoder = ((KvCoder<K, V>) input.getCoder()).getKeyCoder();
this.valueCoder = ((KvCoder<K, V>) input.getCoder()).getValueCoder();
+ this.stateInternalsFactory = checkNotNull(stateInternalsFactory); // fail fast
}
@SuppressWarnings("unused") // for Kryo
private ApexGroupByKeyOperator() {
this.serializedOptions = null;
+ this.stateInternalsFactory = null; // final field; presumably set by Kryo on restore
}
@Override
@@ -230,7 +232,7 @@ public class ApexGroupByKeyOperator<K, V> implements Operator {
}
StateInternals<K> stateInternals = perKeyStateInternals.get(keyBytes);
if (stateInternals == null) {
- stateInternals = InMemoryStateInternals.forKey(key);
+ stateInternals = stateInternalsFactory.stateInternalsForKey(key);
perKeyStateInternals.put(keyBytes, stateInternals);
}
return stateInternals;
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java
----------------------------------------------------------------------
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java
index 9e8f3dc..43384d6 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/functions/ApexParDoOperator.java
@@ -57,8 +57,8 @@ import org.apache.beam.sdk.util.SideInputReader;
import org.apache.beam.sdk.util.UserCodeException;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowingStrategy;
-import org.apache.beam.sdk.util.state.InMemoryStateInternals;
import org.apache.beam.sdk.util.state.StateInternals;
+import org.apache.beam.sdk.util.state.StateInternalsFactory;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.slf4j.Logger;
@@ -84,9 +84,7 @@ public class ApexParDoOperator<InputT, OutputT> extends BaseOperator implements
@Bind(JavaSerializer.class)
private final List<PCollectionView<?>> sideInputs;
-// TODO: not Kryo serializable, integrate codec
- private transient StateInternals<Void> sideInputStateInternals = InMemoryStateInternals
- .forKey(null);
+ private final StateInternals<Void> sideInputStateInternals;
private final ValueAndCoderKryoSerializable<List<WindowedValue<InputT>>> pushedBack;
private LongMin pushedBackWatermark = new LongMin();
private long currentInputWatermark = Long.MIN_VALUE;
@@ -104,7 +102,8 @@ public class ApexParDoOperator<InputT, OutputT> extends BaseOperator implements
List<TupleTag<?>> sideOutputTags,
WindowingStrategy<?, ?> windowingStrategy,
List<PCollectionView<?>> sideInputs,
- Coder<WindowedValue<InputT>> inputCoder
+ Coder<WindowedValue<InputT>> inputCoder,
+ StateInternalsFactory<Void> stateInternalsFactory
) {
this.pipelineOptions = new SerializablePipelineOptions(pipelineOptions);
this.doFn = doFn;
@@ -112,6 +111,7 @@ public class ApexParDoOperator<InputT, OutputT> extends BaseOperator implements
this.sideOutputTags = sideOutputTags;
this.windowingStrategy = windowingStrategy;
this.sideInputs = sideInputs;
+ this.sideInputStateInternals = stateInternalsFactory.stateInternalsForKey(null);
if (sideOutputTags.size() > sideOutputPorts.length) {
String msg = String.format("Too many side outputs (currently only supporting %s).",
@@ -134,6 +134,7 @@ public class ApexParDoOperator<InputT, OutputT> extends BaseOperator implements
this.windowingStrategy = null;
this.sideInputs = null;
this.pushedBack = null;
+ this.sideInputStateInternals = null;
}
public final transient DefaultInputPort<ApexStreamTuple<WindowedValue<InputT>>> input =
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/utils/ApexStateInternals.java
----------------------------------------------------------------------
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/utils/ApexStateInternals.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/utils/ApexStateInternals.java
new file mode 100644
index 0000000..edc1220
--- /dev/null
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translators/utils/ApexStateInternals.java
@@ -0,0 +1,438 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.apex.translators.utils;
+
+import com.esotericsoftware.kryo.DefaultSerializer;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.serializers.JavaSerializer;
+import com.google.common.collect.HashBasedTable;
+import com.google.common.collect.Table;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.Coder.Context;
+import org.apache.beam.sdk.coders.InstantCoder;
+import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.Combine.KeyedCombineFn;
+import org.apache.beam.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.OutputTimeFn;
+import org.apache.beam.sdk.util.CombineFnUtil;
+import org.apache.beam.sdk.util.state.AccumulatorCombiningState;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.beam.sdk.util.state.State;
+import org.apache.beam.sdk.util.state.StateContext;
+import org.apache.beam.sdk.util.state.StateContexts;
+import org.apache.beam.sdk.util.state.StateInternals;
+import org.apache.beam.sdk.util.state.StateInternalsFactory;
+import org.apache.beam.sdk.util.state.StateNamespace;
+import org.apache.beam.sdk.util.state.StateTag;
+import org.apache.beam.sdk.util.state.StateTag.StateBinder;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.beam.sdk.util.state.WatermarkHoldState;
+import org.joda.time.Instant;
+
+/**
+ * Implementation of {@link StateInternals} that can be serialized and
+ * checkpointed with the operator. Suitable for small states; in the future this
+ * should be based on the incremental state saving components in the Apex
+ * library.
+ *
+ * <p>All state is kept in one table that maps (namespace, state tag) to the encoded value,
+ * so every read decodes and every write re-encodes the complete value of a state cell.
+ *
+ * <p>Instances are checkpointed with Java serialization, which includes the key; a non-null
+ * key must therefore be {@link Serializable}.
+ *
+ * @param <K> the type of the key this state is scoped to
+ */
+@DefaultSerializer(JavaSerializer.class)
+public class ApexStateInternals<K> implements StateInternals<K>, Serializable {
+  private static final long serialVersionUID = 1L;
+
+  /** The key that all state of this instance belongs to. */
+  private final K key;
+
+  /** Serializable state for internals: {@code namespace -> state tag -> coded value}. */
+  private final Table<String, String, byte[]> stateTable = HashBasedTable.create();
+
+  /**
+   * Creates new, empty state internals for the given key.
+   *
+   * @param key the key the state is scoped to
+   * @return new {@link ApexStateInternals} that do not hold any state yet
+   */
+  public static <K> ApexStateInternals<K> forKey(K key) {
+    return new ApexStateInternals<>(key);
+  }
+
+  protected ApexStateInternals(K key) {
+    this.key = key;
+  }
+
+  @Override
+  public K getKey() {
+    return key;
+  }
+
+  @Override
+  public <T extends State> T state(StateNamespace namespace, StateTag<? super K, T> address) {
+    return state(namespace, address, StateContexts.nullContext());
+  }
+
+  @Override
+  public <T extends State> T state(
+      StateNamespace namespace, StateTag<? super K, T> address, final StateContext<?> c) {
+    return address.bind(new ApexStateBinder(namespace, c));
+  }
+
+  /**
+   * A {@link StateBinder} that returns {@link State} wrappers for serialized state.
+   * All wrappers read and write the enclosing instance's table within one namespace.
+   */
+  private class ApexStateBinder implements StateBinder<K> {
+    private final StateNamespace namespace;
+    private final StateContext<?> c;
+
+    private ApexStateBinder(StateNamespace namespace, StateContext<?> c) {
+      this.namespace = namespace;
+      this.c = c;
+    }
+
+    @Override
+    public <T> ValueState<T> bindValue(
+        StateTag<? super K, ValueState<T>> address, Coder<T> coder) {
+      return new ApexValueState<T>(namespace, address, coder);
+    }
+
+    @Override
+    public <T> BagState<T> bindBag(
+        final StateTag<? super K, BagState<T>> address, Coder<T> elemCoder) {
+      return new ApexBagState<T>(namespace, address, elemCoder);
+    }
+
+    @Override
+    public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT>
+        bindCombiningValue(
+            StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address,
+            Coder<AccumT> accumCoder,
+            final CombineFn<InputT, AccumT, OutputT> combineFn) {
+      // adapt the unkeyed fn so that all combining state shares one implementation
+      return new ApexAccumulatorCombiningState<InputT, AccumT, OutputT>(
+          namespace, address, accumCoder, combineFn.<K>asKeyedFn());
+    }
+
+    @Override
+    public <W extends BoundedWindow> WatermarkHoldState<W> bindWatermark(
+        StateTag<? super K, WatermarkHoldState<W>> address,
+        OutputTimeFn<? super W> outputTimeFn) {
+      return new ApexWatermarkHoldState<W>(namespace, address, outputTimeFn);
+    }
+
+    @Override
+    public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT>
+        bindKeyedCombiningValue(
+            StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address,
+            Coder<AccumT> accumCoder,
+            KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) {
+      return new ApexAccumulatorCombiningState<InputT, AccumT, OutputT>(
+          namespace, address, accumCoder, combineFn);
+    }
+
+    @Override
+    public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT>
+        bindKeyedCombiningValueWithContext(
+            StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address,
+            Coder<AccumT> accumCoder,
+            KeyedCombineFnWithContext<? super K, InputT, AccumT, OutputT> combineFn) {
+      // bind the state context now so the fn can be treated like a plain keyed fn
+      return bindKeyedCombiningValue(address, accumCoder, CombineFnUtil.bindContext(combineFn, c));
+    }
+  }
+
+  /**
+   * Base class for state that is stored as one encoded value in the state table.
+   *
+   * @param <T> the type of the stored value
+   */
+  private abstract class AbstractState<T> {
+    protected final StateNamespace namespace;
+    protected final StateTag<?, ? extends State> address;
+    protected final Coder<T> coder;
+
+    private AbstractState(
+        StateNamespace namespace,
+        StateTag<?, ? extends State> address,
+        Coder<T> coder) {
+      this.namespace = namespace;
+      this.address = address;
+      this.coder = coder;
+    }
+
+    /** Decodes and returns the stored value, or {@code null} if nothing is stored. */
+    protected T readValue() {
+      byte[] buf = stateTable.get(namespace.stringKey(), address.getId());
+      if (buf == null) {
+        return null;
+      }
+      // TODO: reuse input
+      Input input = new Input(buf);
+      try {
+        return coder.decode(input, Context.OUTER);
+      } catch (IOException e) {
+        throw new RuntimeException("Failed to decode state " + address.getId(), e);
+      }
+    }
+
+    /** Encodes the given value and stores it, replacing any previously stored value. */
+    protected void writeValue(T value) {
+      ByteArrayOutputStream output = new ByteArrayOutputStream();
+      try {
+        coder.encode(value, output, Context.OUTER);
+      } catch (IOException e) {
+        throw new RuntimeException("Failed to encode state " + address.getId(), e);
+      }
+      stateTable.put(namespace.stringKey(), address.getId(), output.toByteArray());
+    }
+
+    /** Removes the stored value, if any. */
+    public void clear() {
+      stateTable.remove(namespace.stringKey(), address.getId());
+    }
+
+    /** Returns a {@link ReadableState} reporting whether no value is currently stored. */
+    protected ReadableState<Boolean> isEmptyState() {
+      return new ReadableState<Boolean>() {
+        @Override
+        public ReadableState<Boolean> readLater() {
+          return this;
+        }
+
+        @Override
+        public Boolean read() {
+          return !stateTable.contains(namespace.stringKey(), address.getId());
+        }
+      };
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      @SuppressWarnings("unchecked")
+      AbstractState<?> that = (AbstractState<?>) o;
+      return namespace.equals(that.namespace) && address.equals(that.address);
+    }
+
+    @Override
+    public int hashCode() {
+      int result = namespace.hashCode();
+      result = 31 * result + address.hashCode();
+      return result;
+    }
+  }
+
+  private final class ApexValueState<T> extends AbstractState<T> implements ValueState<T> {
+
+    private ApexValueState(
+        StateNamespace namespace,
+        StateTag<?, ValueState<T>> address,
+        Coder<T> coder) {
+      super(namespace, address, coder);
+    }
+
+    @Override
+    public ApexValueState<T> readLater() {
+      // state is held in the in-memory table, so there is nothing to prefetch
+      return this;
+    }
+
+    @Override
+    public T read() {
+      return readValue();
+    }
+
+    @Override
+    public void write(T input) {
+      writeValue(input);
+    }
+  }
+
+  private final class ApexWatermarkHoldState<W extends BoundedWindow>
+      extends AbstractState<Instant> implements WatermarkHoldState<W> {
+
+    private final OutputTimeFn<? super W> outputTimeFn;
+
+    private ApexWatermarkHoldState(
+        StateNamespace namespace,
+        StateTag<?, WatermarkHoldState<W>> address,
+        OutputTimeFn<? super W> outputTimeFn) {
+      super(namespace, address, InstantCoder.of());
+      this.outputTimeFn = outputTimeFn;
+    }
+
+    @Override
+    public ApexWatermarkHoldState<W> readLater() {
+      return this;
+    }
+
+    @Override
+    public Instant read() {
+      return readValue();
+    }
+
+    @Override
+    public void add(Instant outputTime) {
+      // the first hold is stored as-is; later holds are merged using the OutputTimeFn
+      Instant combined = read();
+      combined = (combined == null) ? outputTime : outputTimeFn.combine(combined, outputTime);
+      writeValue(combined);
+    }
+
+    @Override
+    public ReadableState<Boolean> isEmpty() {
+      return isEmptyState();
+    }
+
+    @Override
+    public OutputTimeFn<? super W> getOutputTimeFn() {
+      return outputTimeFn;
+    }
+  }
+
+  private final class ApexAccumulatorCombiningState<InputT, AccumT, OutputT>
+      extends AbstractState<AccumT>
+      implements AccumulatorCombiningState<InputT, AccumT, OutputT> {
+    private final KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn;
+
+    private ApexAccumulatorCombiningState(StateNamespace namespace,
+        StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address,
+        Coder<AccumT> coder,
+        KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) {
+      super(namespace, address, coder);
+      this.combineFn = combineFn;
+    }
+
+    @Override
+    public ApexAccumulatorCombiningState<InputT, AccumT, OutputT> readLater() {
+      return this;
+    }
+
+    @Override
+    public OutputT read() {
+      return combineFn.extractOutput(key, getAccum());
+    }
+
+    @Override
+    public void add(InputT input) {
+      // addInput may return a new accumulator instead of mutating its argument
+      AccumT accum = combineFn.addInput(key, getAccum(), input);
+      writeValue(accum);
+    }
+
+    @Override
+    public AccumT getAccum() {
+      AccumT accum = readValue();
+      if (accum == null) {
+        // nothing stored yet: start from a fresh accumulator
+        accum = combineFn.createAccumulator(key);
+      }
+      return accum;
+    }
+
+    @Override
+    public ReadableState<Boolean> isEmpty() {
+      return isEmptyState();
+    }
+
+    @Override
+    public void addAccum(AccumT accum) {
+      accum = combineFn.mergeAccumulators(key, Arrays.asList(getAccum(), accum));
+      writeValue(accum);
+    }
+
+    @Override
+    public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+      return combineFn.mergeAccumulators(key, accumulators);
+    }
+  }
+
+  private final class ApexBagState<T> extends AbstractState<List<T>> implements BagState<T> {
+    private ApexBagState(
+        StateNamespace namespace,
+        StateTag<?, BagState<T>> address,
+        Coder<T> coder) {
+      super(namespace, address, ListCoder.of(coder));
+    }
+
+    @Override
+    public ApexBagState<T> readLater() {
+      return this;
+    }
+
+    @Override
+    public List<T> read() {
+      List<T> value = readValue();
+      if (value == null) {
+        value = new ArrayList<T>();
+      }
+      return value;
+    }
+
+    @Override
+    public void add(T input) {
+      // the whole list is decoded and re-encoded on every add (fine for small bags)
+      List<T> value = read();
+      value.add(input);
+      writeValue(value);
+    }
+
+    @Override
+    public ReadableState<Boolean> isEmpty() {
+      return isEmptyState();
+    }
+  }
+
+  /**
+   * Factory for {@link ApexStateInternals}.
+   *
+   * @param <K> the key type of the state internals created by this factory
+   */
+  public static class ApexStateInternalsFactory<K>
+      implements StateInternalsFactory<K>, Serializable {
+    private static final long serialVersionUID = 1L;
+
+    @Override
+    public StateInternals<K> stateInternalsForKey(K key) {
+      return ApexStateInternals.forKey(key);
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslatorTest.java
----------------------------------------------------------------------
diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslatorTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslatorTest.java
index ad22acd..9ea4233 100644
--- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslatorTest.java
+++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/ParDoBoundTranslatorTest.java
@@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import com.datatorrent.api.DAG;
+import com.datatorrent.api.Sink;
import com.datatorrent.lib.util.KryoCloneUtils;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
@@ -37,6 +38,7 @@ import org.apache.beam.runners.apex.ApexRunnerResult;
import org.apache.beam.runners.apex.TestApexRunner;
import org.apache.beam.runners.apex.translators.functions.ApexParDoOperator;
import org.apache.beam.runners.apex.translators.io.ApexReadUnboundedInputOperator;
+import org.apache.beam.runners.apex.translators.utils.ApexStateInternals;
import org.apache.beam.runners.apex.translators.utils.ApexStreamTuple;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
@@ -107,14 +109,22 @@ public class ParDoBoundTranslatorTest {
@SuppressWarnings("serial")
private static class Add extends OldDoFn<Integer, Integer> {
private final Integer number;
+ private final PCollectionView<Integer> sideInputView;
- public Add(Integer number) {
+ private Add(Integer number) {
this.number = number;
+ this.sideInputView = null;
}
+ private Add(PCollectionView<Integer> sideInputView) {
+ this.number = null;
+ this.sideInputView = sideInputView;
+ }
+
@Override
public void processElement(ProcessContext c) throws Exception {
- c.output(c.element() + number);
+ // choose the addend per element rather than overwriting the number field
+ c.output(c.element() + (sideInputView != null ? c.sideInput(sideInputView) : number));
}
}
@@ -190,17 +200,51 @@ public class ParDoBoundTranslatorTest {
.apply(Sum.integersGlobally().asSingletonView());
ApexParDoOperator<Integer, Integer> operator = new ApexParDoOperator<>(options,
- new Add(0), new TupleTag<Integer>(), TupleTagList.empty().getAll(),
+ new Add(singletonView), new TupleTag<Integer>(), TupleTagList.empty().getAll(),
WindowingStrategy.globalDefault(),
Collections.<PCollectionView<?>>singletonList(singletonView),
- coder);
+ coder,
+ new ApexStateInternals.ApexStateInternalsFactory<Void>()
+ );
operator.setup(null);
operator.beginWindow(0);
- WindowedValue<Integer> wv = WindowedValue.valueInGlobalWindow(0);
- operator.input.process(ApexStreamTuple.DataTuple.of(wv));
- operator.input.process(ApexStreamTuple.WatermarkTuple.<WindowedValue<Integer>>of(0));
- operator.endWindow();
- Assert.assertNotNull("Serialization", KryoCloneUtils.cloneObject(operator));
+ WindowedValue<Integer> wv1 = WindowedValue.valueInGlobalWindow(1);
+ WindowedValue<Iterable<?>> sideInput = WindowedValue.<Iterable<?>>valueInGlobalWindow(
+ Lists.<Integer>newArrayList(22));
+ operator.input.process(ApexStreamTuple.DataTuple.of(wv1)); // pushed back input
+
+ final List<Object> results = Lists.newArrayList();
+ Sink<Object> sink = new Sink<Object>() {
+ @Override
+ public void put(Object tuple) {
+ results.add(tuple);
+ }
+ @Override
+ public int getCount(boolean reset) {
+ return 0;
+ }
+ };
+ // verify pushed back input checkpointing
+ Assert.assertNotNull("Serialization", operator = KryoCloneUtils.cloneObject(operator));
+ operator.output.setSink(sink);
+ operator.setup(null);
+ operator.beginWindow(1);
+ WindowedValue<Integer> wv2 = WindowedValue.valueInGlobalWindow(2);
+ operator.sideInput1.process(ApexStreamTuple.DataTuple.of(sideInput));
+ Assert.assertEquals("number outputs", 1, results.size());
+ Assert.assertEquals("result", WindowedValue.valueInGlobalWindow(23),
+ ((ApexStreamTuple.DataTuple) results.get(0)).getValue());
+
+ // verify side input checkpointing
+ results.clear();
+ Assert.assertNotNull("Serialization", operator = KryoCloneUtils.cloneObject(operator));
+ operator.output.setSink(sink);
+ operator.setup(null);
+ operator.beginWindow(2);
+ operator.input.process(ApexStreamTuple.DataTuple.of(wv2));
+ Assert.assertEquals("number outputs", 1, results.size());
+ Assert.assertEquals("result", WindowedValue.valueInGlobalWindow(24),
+ ((ApexStreamTuple.DataTuple) results.get(0)).getValue());
}
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1db4ff63/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/utils/ApexStateInternalsTest.java
----------------------------------------------------------------------
diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/utils/ApexStateInternalsTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/utils/ApexStateInternalsTest.java
new file mode 100644
index 0000000..055d98c
--- /dev/null
+++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translators/utils/ApexStateInternalsTest.java
@@ -0,0 +1,361 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.apex.translators.utils;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertThat;
+
+import com.datatorrent.lib.util.KryoCloneUtils;
+
+import java.util.Arrays;
+
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.OutputTimeFns;
+import org.apache.beam.sdk.util.state.AccumulatorCombiningState;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.CombiningState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.beam.sdk.util.state.StateMerging;
+import org.apache.beam.sdk.util.state.StateNamespace;
+import org.apache.beam.sdk.util.state.StateNamespaceForTest;
+import org.apache.beam.sdk.util.state.StateTag;
+import org.apache.beam.sdk.util.state.StateTags;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.beam.sdk.util.state.WatermarkHoldState;
+import org.hamcrest.Matchers;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Tests for {@link ApexStateInternals}. This is based on the tests for
+ * {@code InMemoryStateInternals}.
+ *
+ * <p>Covers bag, combining-value and watermark-hold state, the merging of each of those via
+ * {@link StateMerging}, and (in {@link #testSerialization()}) that written state is still
+ * readable after a Kryo round trip through {@code KryoCloneUtils}.
+ */
+public class ApexStateInternalsTest {
+  private static final BoundedWindow WINDOW_1 = new IntervalWindow(new Instant(0), new Instant(10));
+  private static final StateNamespace NAMESPACE_1 = new StateNamespaceForTest("ns1");
+  private static final StateNamespace NAMESPACE_2 = new StateNamespaceForTest("ns2");
+  private static final StateNamespace NAMESPACE_3 = new StateNamespaceForTest("ns3");
+
+  private static final StateTag<Object, ValueState<String>> STRING_VALUE_ADDR =
+      StateTags.value("stringValue", StringUtf8Coder.of());
+  private static final StateTag<Object, AccumulatorCombiningState<Integer, int[], Integer>>
+      SUM_INTEGER_ADDR = StateTags.combiningValueFromInputInternal(
+          "sumInteger", VarIntCoder.of(), new Sum.SumIntegerFn());
+  private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR =
+      StateTags.bag("stringBag", StringUtf8Coder.of());
+  // NOTE(review): the three watermark tags share the id "watermark"; each test below uses only
+  // one of them against a fresh ApexStateInternals (see @Before), so they never coexist.
+  private static final StateTag<Object, WatermarkHoldState<BoundedWindow>>
+      WATERMARK_EARLIEST_ADDR = StateTags.watermarkStateInternal(
+          "watermark", OutputTimeFns.outputAtEarliestInputTimestamp());
+  private static final StateTag<Object, WatermarkHoldState<BoundedWindow>>
+      WATERMARK_LATEST_ADDR = StateTags.watermarkStateInternal(
+          "watermark", OutputTimeFns.outputAtLatestInputTimestamp());
+  private static final StateTag<Object, WatermarkHoldState<BoundedWindow>> WATERMARK_EOW_ADDR =
+      StateTags.watermarkStateInternal("watermark", OutputTimeFns.outputAtEndOfWindow());
+
+  private ApexStateInternals<String> underTest;
+
+  /** Creates a fresh, empty state internals instance for every test. */
+  @Before
+  public void initStateInternals() {
+    underTest = new ApexStateInternals<>(null);
+  }
+
+  @Test
+  public void testBag() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    // States for the same namespace and tag compare equal; other namespaces give other states.
+    assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)));
+
+    assertThat(value.read(), Matchers.emptyIterable());
+    value.add("hello");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello"));
+
+    value.add("world");
+    assertThat(value.read(), Matchers.containsInAnyOrder("hello", "world"));
+
+    value.clear();
+    assertThat(value.read(), Matchers.emptyIterable());
+    assertEquals(underTest.state(NAMESPACE_1, STRING_BAG_ADDR), value);
+  }
+
+  @Test
+  public void testBagIsEmpty() throws Exception {
+    BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    // An isEmpty() view obtained before the writes must reflect later adds and clears.
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add("hello");
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+  @Test
+  public void testMergeBagIntoSource() throws Exception {
+    BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+    BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+
+    bag1.add("Hello");
+    bag2.add("World");
+    bag1.add("!");
+
+    StateMerging.mergeBags(Arrays.asList(bag1, bag2), bag1);
+
+    // The target (one of the sources) holds the contents of both bags; the other is emptied.
+    assertThat(bag1.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+    assertThat(bag2.read(), Matchers.emptyIterable());
+  }
+
+  @Test
+  public void testMergeBagIntoNewNamespace() throws Exception {
+    BagState<String> bag1 = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
+    BagState<String> bag2 = underTest.state(NAMESPACE_2, STRING_BAG_ADDR);
+    BagState<String> bag3 = underTest.state(NAMESPACE_3, STRING_BAG_ADDR);
+
+    bag1.add("Hello");
+    bag2.add("World");
+    bag1.add("!");
+
+    StateMerging.mergeBags(Arrays.asList(bag1, bag2, bag3), bag3);
+
+    // The target namespace holds the contents of both bags; the sources are emptied.
+    assertThat(bag3.read(), Matchers.containsInAnyOrder("Hello", "World", "!"));
+    assertThat(bag1.read(), Matchers.emptyIterable());
+    assertThat(bag2.read(), Matchers.emptyIterable());
+  }
+
+  @Test
+  public void testCombiningValue() throws Exception {
+    CombiningState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+
+    // States for the same namespace and tag compare equal; other namespaces give other states.
+    assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR)));
+
+    assertThat(value.read(), Matchers.equalTo(0));
+    value.add(2);
+    assertThat(value.read(), Matchers.equalTo(2));
+
+    value.add(3);
+    assertThat(value.read(), Matchers.equalTo(5));
+
+    value.clear();
+    assertThat(value.read(), Matchers.equalTo(0));
+    assertEquals(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), value);
+  }
+
+  @Test
+  public void testCombiningIsEmpty() throws Exception {
+    CombiningState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    // An isEmpty() view obtained before the writes must reflect later adds and clears.
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add(5);
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+  @Test
+  public void testMergeCombiningValueIntoSource() throws Exception {
+    AccumulatorCombiningState<Integer, int[], Integer> value1 =
+        underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+    AccumulatorCombiningState<Integer, int[], Integer> value2 =
+        underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
+
+    value1.add(5);
+    value2.add(10);
+    value1.add(6);
+
+    assertThat(value1.read(), Matchers.equalTo(11));
+    assertThat(value2.read(), Matchers.equalTo(10));
+
+    // Merging clears the old values and updates the result value.
+    StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1);
+
+    assertThat(value1.read(), Matchers.equalTo(21));
+    assertThat(value2.read(), Matchers.equalTo(0));
+  }
+
+  @Test
+  public void testMergeCombiningValueIntoNewNamespace() throws Exception {
+    AccumulatorCombiningState<Integer, int[], Integer> value1 =
+        underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
+    AccumulatorCombiningState<Integer, int[], Integer> value2 =
+        underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
+    AccumulatorCombiningState<Integer, int[], Integer> value3 =
+        underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR);
+
+    value1.add(5);
+    value2.add(10);
+    value1.add(6);
+
+    StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3);
+
+    // Merging clears the old values and updates the result value.
+    assertThat(value1.read(), Matchers.equalTo(0));
+    assertThat(value2.read(), Matchers.equalTo(0));
+    assertThat(value3.read(), Matchers.equalTo(21));
+  }
+
+  @Test
+  public void testWatermarkEarliestState() throws Exception {
+    WatermarkHoldState<BoundedWindow> value =
+        underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR);
+
+    // States for the same namespace and tag compare equal; other namespaces give other states.
+    assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR)));
+
+    // The hold tracks the earliest timestamp added so far.
+    assertThat(value.read(), Matchers.nullValue());
+    value.add(new Instant(2000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+
+    value.add(new Instant(3000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+
+    value.add(new Instant(1000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(1000)));
+
+    value.clear();
+    assertThat(value.read(), Matchers.nullValue());
+    assertEquals(underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR), value);
+  }
+
+  @Test
+  public void testWatermarkLatestState() throws Exception {
+    WatermarkHoldState<BoundedWindow> value =
+        underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR);
+
+    // States for the same namespace and tag compare equal; other namespaces give other states.
+    assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR)));
+
+    // The hold tracks the latest timestamp added so far.
+    assertThat(value.read(), Matchers.nullValue());
+    value.add(new Instant(2000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+
+    value.add(new Instant(3000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(3000)));
+
+    value.add(new Instant(1000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(3000)));
+
+    value.clear();
+    assertThat(value.read(), Matchers.nullValue());
+    assertEquals(underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR), value);
+  }
+
+  @Test
+  public void testWatermarkEndOfWindowState() throws Exception {
+    WatermarkHoldState<BoundedWindow> value = underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR);
+
+    // States for the same namespace and tag compare equal; other namespaces give other states.
+    assertEquals(value, underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR));
+    assertFalse(value.equals(underTest.state(NAMESPACE_2, WATERMARK_EOW_ADDR)));
+
+    assertThat(value.read(), Matchers.nullValue());
+    value.add(new Instant(2000));
+    assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+
+    value.clear();
+    assertThat(value.read(), Matchers.nullValue());
+    assertEquals(underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR), value);
+  }
+
+  @Test
+  public void testWatermarkStateIsEmpty() throws Exception {
+    WatermarkHoldState<BoundedWindow> value =
+        underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    // An isEmpty() view obtained before the writes must reflect later adds and clears.
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add(new Instant(1000));
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+  @Test
+  public void testMergeEarliestWatermarkIntoSource() throws Exception {
+    WatermarkHoldState<BoundedWindow> value1 =
+        underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR);
+    WatermarkHoldState<BoundedWindow> value2 =
+        underTest.state(NAMESPACE_2, WATERMARK_EARLIEST_ADDR);
+
+    value1.add(new Instant(3000));
+    value2.add(new Instant(5000));
+    value1.add(new Instant(4000));
+    value2.add(new Instant(2000));
+
+    // Merging clears the old values and updates the merged value.
+    StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value1, WINDOW_1);
+
+    assertThat(value1.read(), Matchers.equalTo(new Instant(2000)));
+    assertThat(value2.read(), Matchers.nullValue());
+  }
+
+  @Test
+  public void testMergeLatestWatermarkIntoSource() throws Exception {
+    WatermarkHoldState<BoundedWindow> value1 =
+        underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR);
+    WatermarkHoldState<BoundedWindow> value2 =
+        underTest.state(NAMESPACE_2, WATERMARK_LATEST_ADDR);
+    WatermarkHoldState<BoundedWindow> value3 =
+        underTest.state(NAMESPACE_3, WATERMARK_LATEST_ADDR);
+
+    value1.add(new Instant(3000));
+    value2.add(new Instant(5000));
+    value1.add(new Instant(4000));
+    value2.add(new Instant(2000));
+
+    // NOTE(review): despite the method name, the merge target here is a third namespace
+    // (value3), not one of the sources.
+    StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value3, WINDOW_1);
+
+    // The target receives the latest of the source holds; the sources are cleared.
+    assertThat(value3.read(), Matchers.equalTo(new Instant(5000)));
+    assertThat(value1.read(), Matchers.nullValue());
+    assertThat(value2.read(), Matchers.nullValue());
+  }
+
+  @Test
+  public void testSerialization() throws Exception {
+    ApexStateInternals<String> original = new ApexStateInternals<>(null);
+    ValueState<String> value = original.state(NAMESPACE_1, STRING_VALUE_ADDR);
+    assertEquals(original.state(NAMESPACE_1, STRING_VALUE_ADDR), value);
+    value.write("hello");
+
+    // Round-trip the whole state internals through Kryo; the written value must be
+    // readable from the copy.
+    ApexStateInternals<String> cloned = KryoCloneUtils.cloneObject(original);
+    assertNotNull("Serialization", cloned);
+    ValueState<String> clonedValue = cloned.state(NAMESPACE_1, STRING_VALUE_ADDR);
+    assertThat(clonedValue.read(), Matchers.equalTo("hello"));
+    assertEquals(cloned.state(NAMESPACE_1, STRING_VALUE_ADDR), value);
+  }
+}