You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ie...@apache.org on 2017/04/19 13:09:13 UTC
[04/18] beam git commit: [BEAM-1994] Remove Flink examples package
http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/BoundedSourceWrapper.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/BoundedSourceWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/BoundedSourceWrapper.java
new file mode 100644
index 0000000..2ed5024
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/BoundedSourceWrapper.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.translation.wrappers.streaming.io;
+
+import com.google.common.annotations.VisibleForTesting;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
+import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.flink.api.common.functions.StoppableFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Wrapper for executing {@link BoundedSource BoundedSources} as a Flink Source.
+ */
+public class BoundedSourceWrapper<OutputT>
+ extends RichParallelSourceFunction<WindowedValue<OutputT>>
+ implements StoppableFunction {
+
+ private static final Logger LOG = LoggerFactory.getLogger(BoundedSourceWrapper.class);
+
+ /**
+ * Keep the options so that we can initialize the readers.
+ */
+ private final SerializedPipelineOptions serializedOptions;
+
+ /**
+ * The split sources. We split them in the constructor to ensure that all parallel
+ * sources are consistent about the split sources.
+ */
+ private List<? extends BoundedSource<OutputT>> splitSources;
+
+ /**
+ * Make it a field so that we can access it in {@link #close()}.
+ */
+ private transient List<BoundedSource.BoundedReader<OutputT>> readers;
+
+ /**
+ * Initialize here and not in run() to prevent races where we cancel a job before run() is
+ * ever called or run() is called after cancel().
+ */
+ private volatile boolean isRunning = true;
+
+ @SuppressWarnings("unchecked")
+ public BoundedSourceWrapper(
+ PipelineOptions pipelineOptions,
+ BoundedSource<OutputT> source,
+ int parallelism) throws Exception {
+ this.serializedOptions = new SerializedPipelineOptions(pipelineOptions);
+
+ long desiredBundleSize = source.getEstimatedSizeBytes(pipelineOptions) / parallelism;
+
+ // get the splits early. we assume that the generated splits are stable,
+ // this is necessary so that the mapping of state to source is correct
+ // when restoring
+ splitSources = source.split(desiredBundleSize, pipelineOptions);
+ }
+
+ @Override
+ public void run(SourceContext<WindowedValue<OutputT>> ctx) throws Exception {
+
+ // figure out which split sources we're responsible for
+ int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
+ int numSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
+
+ List<BoundedSource<OutputT>> localSources = new ArrayList<>();
+
+ for (int i = 0; i < splitSources.size(); i++) {
+ if (i % numSubtasks == subtaskIndex) {
+ localSources.add(splitSources.get(i));
+ }
+ }
+
+ LOG.info("Bounded Flink Source {}/{} is reading from sources: {}",
+ subtaskIndex,
+ numSubtasks,
+ localSources);
+
+ readers = new ArrayList<>();
+ // initialize readers from scratch
+ for (BoundedSource<OutputT> source : localSources) {
+ readers.add(source.createReader(serializedOptions.getPipelineOptions()));
+ }
+
+ if (readers.size() == 1) {
+ // the easy case, we just read from one reader
+ BoundedSource.BoundedReader<OutputT> reader = readers.get(0);
+
+ boolean dataAvailable = reader.start();
+ if (dataAvailable) {
+ emitElement(ctx, reader);
+ }
+
+ while (isRunning) {
+ dataAvailable = reader.advance();
+
+ if (dataAvailable) {
+ emitElement(ctx, reader);
+ } else {
+ break;
+ }
+ }
+ } else {
+ // a bit more complicated, we are responsible for several readers
+ // loop through them and sleep if none of them had any data
+
+ int currentReader = 0;
+
+ // start each reader and emit data if immediately available
+ for (BoundedSource.BoundedReader<OutputT> reader : readers) {
+ boolean dataAvailable = reader.start();
+ if (dataAvailable) {
+ emitElement(ctx, reader);
+ }
+ }
+
+ // a flag telling us whether any of the readers had data
+ // if no reader had data, sleep for bit
+ boolean hadData = false;
+ while (isRunning && !readers.isEmpty()) {
+ BoundedSource.BoundedReader<OutputT> reader = readers.get(currentReader);
+ boolean dataAvailable = reader.advance();
+
+ if (dataAvailable) {
+ emitElement(ctx, reader);
+ hadData = true;
+ } else {
+ readers.remove(currentReader);
+ currentReader--;
+ if (readers.isEmpty()) {
+ break;
+ }
+ }
+
+ currentReader = (currentReader + 1) % readers.size();
+ if (currentReader == 0 && !hadData) {
+ Thread.sleep(50);
+ } else if (currentReader == 0) {
+ hadData = false;
+ }
+ }
+
+ }
+
+ // emit final Long.MAX_VALUE watermark, just to be sure
+ ctx.emitWatermark(new Watermark(Long.MAX_VALUE));
+ }
+
+ /**
+ * Emit the current element from the given Reader. The reader is guaranteed to have data.
+ */
+ private void emitElement(
+ SourceContext<WindowedValue<OutputT>> ctx,
+ BoundedSource.BoundedReader<OutputT> reader) {
+ // make sure that reader state update and element emission are atomic
+ // with respect to snapshots
+ synchronized (ctx.getCheckpointLock()) {
+
+ OutputT item = reader.getCurrent();
+ Instant timestamp = reader.getCurrentTimestamp();
+
+ WindowedValue<OutputT> windowedValue =
+ WindowedValue.of(item, timestamp, GlobalWindow.INSTANCE, PaneInfo.NO_FIRING);
+ ctx.collectWithTimestamp(windowedValue, timestamp.getMillis());
+ }
+ }
+
+ @Override
+ public void close() throws Exception {
+ super.close();
+ if (readers != null) {
+ for (BoundedSource.BoundedReader<OutputT> reader: readers) {
+ reader.close();
+ }
+ }
+ }
+
+ @Override
+ public void cancel() {
+ isRunning = false;
+ }
+
+ @Override
+ public void stop() {
+ this.isRunning = false;
+ }
+
+ /**
+ * Visible so that we can check this in tests. Must not be used for anything else.
+ */
+ @VisibleForTesting
+ public List<? extends BoundedSource<OutputT>> getSplitSources() {
+ return splitSources;
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSocketSource.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSocketSource.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSocketSource.java
new file mode 100644
index 0000000..910a33f
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSocketSource.java
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.translation.wrappers.streaming.io;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.net.InetSocketAddress;
+import java.net.Socket;
+import java.util.Collections;
+import java.util.List;
+import java.util.NoSuchElementException;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.io.UnboundedSource;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * An example unbounded Beam source that reads input from a socket.
+ * This is used mainly for testing and debugging.
+ * */
+public class UnboundedSocketSource<CheckpointMarkT extends UnboundedSource.CheckpointMark>
+ extends UnboundedSource<String, CheckpointMarkT> {
+
+ private static final Coder<String> DEFAULT_SOCKET_CODER = StringUtf8Coder.of();
+
+ private static final long serialVersionUID = 1L;
+
+ private static final int DEFAULT_CONNECTION_RETRY_SLEEP = 500;
+
+ private static final int CONNECTION_TIMEOUT_TIME = 0;
+
+ private final String hostname;
+ private final int port;
+ private final char delimiter;
+ private final long maxNumRetries;
+ private final long delayBetweenRetries;
+
+ public UnboundedSocketSource(String hostname, int port, char delimiter, long maxNumRetries) {
+ this(hostname, port, delimiter, maxNumRetries, DEFAULT_CONNECTION_RETRY_SLEEP);
+ }
+
+ public UnboundedSocketSource(String hostname,
+ int port,
+ char delimiter,
+ long maxNumRetries,
+ long delayBetweenRetries) {
+ this.hostname = hostname;
+ this.port = port;
+ this.delimiter = delimiter;
+ this.maxNumRetries = maxNumRetries;
+ this.delayBetweenRetries = delayBetweenRetries;
+ }
+
+ public String getHostname() {
+ return this.hostname;
+ }
+
+ public int getPort() {
+ return this.port;
+ }
+
+ public char getDelimiter() {
+ return this.delimiter;
+ }
+
+ public long getMaxNumRetries() {
+ return this.maxNumRetries;
+ }
+
+ public long getDelayBetweenRetries() {
+ return this.delayBetweenRetries;
+ }
+
+ @Override
+ public List<? extends UnboundedSource<String, CheckpointMarkT>> split(
+ int desiredNumSplits,
+ PipelineOptions options) throws Exception {
+ return Collections.<UnboundedSource<String, CheckpointMarkT>>singletonList(this);
+ }
+
+ @Override
+ public UnboundedReader<String> createReader(PipelineOptions options,
+ @Nullable CheckpointMarkT checkpointMark) {
+ return new UnboundedSocketReader(this);
+ }
+
+ @Nullable
+ @Override
+ public Coder getCheckpointMarkCoder() {
+ // Flink and Dataflow have different checkpointing mechanisms.
+ // In our case we do not need a coder.
+ return null;
+ }
+
+ @Override
+ public void validate() {
+ checkArgument(port > 0 && port < 65536, "port is out of range");
+ checkArgument(maxNumRetries >= -1, "maxNumRetries must be zero or larger (num retries), "
+ + "or -1 (infinite retries)");
+ checkArgument(delayBetweenRetries >= 0, "delayBetweenRetries must be zero or positive");
+ }
+
+ @Override
+ public Coder getDefaultOutputCoder() {
+ return DEFAULT_SOCKET_CODER;
+ }
+
+ /**
+ * Unbounded socket reader.
+ */
+ public static class UnboundedSocketReader extends UnboundedSource.UnboundedReader<String> {
+
+ private static final Logger LOG = LoggerFactory.getLogger(UnboundedSocketReader.class);
+
+ private final UnboundedSocketSource source;
+
+ private Socket socket;
+ private BufferedReader reader;
+
+ private boolean isRunning;
+
+ private String currentRecord;
+
+ public UnboundedSocketReader(UnboundedSocketSource source) {
+ this.source = source;
+ }
+
+ private void openConnection() throws IOException {
+ this.socket = new Socket();
+ this.socket.connect(new InetSocketAddress(this.source.getHostname(), this.source.getPort()),
+ CONNECTION_TIMEOUT_TIME);
+ this.reader = new BufferedReader(new InputStreamReader(this.socket.getInputStream()));
+ this.isRunning = true;
+ }
+
+ @Override
+ public boolean start() throws IOException {
+ int attempt = 0;
+ while (!isRunning) {
+ try {
+ openConnection();
+ LOG.info("Connected to server socket " + this.source.getHostname() + ':'
+ + this.source.getPort());
+
+ return advance();
+ } catch (IOException e) {
+ LOG.info("Lost connection to server socket " + this.source.getHostname() + ':'
+ + this.source.getPort() + ". Retrying in "
+ + this.source.getDelayBetweenRetries() + " msecs...");
+
+ if (this.source.getMaxNumRetries() == -1 || attempt++ < this.source.getMaxNumRetries()) {
+ try {
+ Thread.sleep(this.source.getDelayBetweenRetries());
+ } catch (InterruptedException e1) {
+ e1.printStackTrace();
+ }
+ } else {
+ this.isRunning = false;
+ break;
+ }
+ }
+ }
+ LOG.error("Unable to connect to host " + this.source.getHostname()
+ + " : " + this.source.getPort());
+ return false;
+ }
+
+ @Override
+ public boolean advance() throws IOException {
+ final StringBuilder buffer = new StringBuilder();
+ int data;
+ while (isRunning && (data = reader.read()) != -1) {
+ // check if the string is complete
+ if (data != this.source.getDelimiter()) {
+ buffer.append((char) data);
+ } else {
+ if (buffer.length() > 0 && buffer.charAt(buffer.length() - 1) == '\r') {
+ buffer.setLength(buffer.length() - 1);
+ }
+ this.currentRecord = buffer.toString();
+ buffer.setLength(0);
+ return true;
+ }
+ }
+ return false;
+ }
+
+ @Override
+ public byte[] getCurrentRecordId() throws NoSuchElementException {
+ return new byte[0];
+ }
+
+ @Override
+ public String getCurrent() throws NoSuchElementException {
+ return this.currentRecord;
+ }
+
+ @Override
+ public Instant getCurrentTimestamp() throws NoSuchElementException {
+ return Instant.now();
+ }
+
+ @Override
+ public void close() throws IOException {
+ this.reader.close();
+ this.socket.close();
+ this.isRunning = false;
+ LOG.info("Closed connection to server socket at " + this.source.getHostname() + ":"
+ + this.source.getPort() + ".");
+ }
+
+ @Override
+ public Instant getWatermark() {
+ return Instant.now();
+ }
+
+ @Override
+ public CheckpointMark getCheckpointMark() {
+ return null;
+ }
+
+ @Override
+ public UnboundedSource<String, ?> getCurrentSource() {
+ return this.source;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java
new file mode 100644
index 0000000..bb9b58a
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.translation.wrappers.streaming.io;
+
+import com.google.common.annotations.VisibleForTesting;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.List;
+import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
+import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.SerializableCoder;
+import org.apache.beam.sdk.io.UnboundedSource;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.functions.StoppableFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.CheckpointListener;
+import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeCallback;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Wrapper for executing {@link UnboundedSource UnboundedSources} as a Flink Source.
+ */
+public class UnboundedSourceWrapper<
+ OutputT, CheckpointMarkT extends UnboundedSource.CheckpointMark>
+ extends RichParallelSourceFunction<WindowedValue<OutputT>>
+ implements ProcessingTimeCallback, StoppableFunction,
+ CheckpointListener, CheckpointedFunction {
+
+ private static final Logger LOG = LoggerFactory.getLogger(UnboundedSourceWrapper.class);
+
+ /**
+ * Keep the options so that we can initialize the localReaders.
+ */
+ private final SerializedPipelineOptions serializedOptions;
+
+ /**
+ * For snapshot and restore.
+ */
+ private final KvCoder<
+ ? extends UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT> checkpointCoder;
+
+ /**
+ * The split sources. We split them in the constructor to ensure that all parallel
+ * sources are consistent about the split sources.
+ */
+ private final List<? extends UnboundedSource<OutputT, CheckpointMarkT>> splitSources;
+
+ /**
+ * The local split sources. Assigned at runtime when the wrapper is executed in parallel.
+ */
+ private transient List<UnboundedSource<OutputT, CheckpointMarkT>> localSplitSources;
+
+ /**
+ * The local split readers. Assigned at runtime when the wrapper is executed in parallel.
+ * Make it a field so that we can access it in {@link #onProcessingTime(long)} for
+ * emitting watermarks.
+ */
+ private transient List<UnboundedSource.UnboundedReader<OutputT>> localReaders;
+
+ /**
+ * Flag to indicate whether the source is running.
+ * Initialize here and not in run() to prevent races where we cancel a job before run() is
+ * ever called or run() is called after cancel().
+ */
+ private volatile boolean isRunning = true;
+
+ /**
+ * Make it a field so that we can access it in {@link #onProcessingTime(long)} for registering new
+ * triggers.
+ */
+ private transient StreamingRuntimeContext runtimeContext;
+
+ /**
+ * Make it a field so that we can access it in {@link #onProcessingTime(long)} for emitting
+ * watermarks.
+ */
+ private transient SourceContext<WindowedValue<OutputT>> context;
+
+ /**
+ * Pending checkpoints which have not been acknowledged yet.
+ */
+ private transient LinkedHashMap<Long, List<CheckpointMarkT>> pendingCheckpoints;
+ /**
+ * Keep a maximum of 32 checkpoints for {@code CheckpointMark.finalizeCheckpoint()}.
+ */
+ private static final int MAX_NUMBER_PENDING_CHECKPOINTS = 32;
+
+ private transient ListState<KV<? extends
+ UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT>> stateForCheckpoint;
+
+ /**
+ * false if checkpointCoder is null or no restore state by starting first.
+ */
+ private transient boolean isRestored = false;
+
+ @SuppressWarnings("unchecked")
+ public UnboundedSourceWrapper(
+ PipelineOptions pipelineOptions,
+ UnboundedSource<OutputT, CheckpointMarkT> source,
+ int parallelism) throws Exception {
+ this.serializedOptions = new SerializedPipelineOptions(pipelineOptions);
+
+ if (source.requiresDeduping()) {
+ LOG.warn("Source {} requires deduping but Flink runner doesn't support this yet.", source);
+ }
+
+ Coder<CheckpointMarkT> checkpointMarkCoder = source.getCheckpointMarkCoder();
+ if (checkpointMarkCoder == null) {
+ LOG.info("No CheckpointMarkCoder specified for this source. Won't create snapshots.");
+ checkpointCoder = null;
+ } else {
+
+ Coder<? extends UnboundedSource<OutputT, CheckpointMarkT>> sourceCoder =
+ (Coder) SerializableCoder.of(new TypeDescriptor<UnboundedSource>() {
+ });
+
+ checkpointCoder = KvCoder.of(sourceCoder, checkpointMarkCoder);
+ }
+
+ // get the splits early. we assume that the generated splits are stable,
+ // this is necessary so that the mapping of state to source is correct
+ // when restoring
+ splitSources = source.split(parallelism, pipelineOptions);
+ }
+
+
+ /**
+ * Initialize and restore state before starting execution of the source.
+ */
+ @Override
+ public void open(Configuration parameters) throws Exception {
+ runtimeContext = (StreamingRuntimeContext) getRuntimeContext();
+
+ // figure out which split sources we're responsible for
+ int subtaskIndex = runtimeContext.getIndexOfThisSubtask();
+ int numSubtasks = runtimeContext.getNumberOfParallelSubtasks();
+
+ localSplitSources = new ArrayList<>();
+ localReaders = new ArrayList<>();
+
+ pendingCheckpoints = new LinkedHashMap<>();
+
+ if (isRestored) {
+ // restore the splitSources from the checkpoint to ensure consistent ordering
+ for (KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT> restored:
+ stateForCheckpoint.get()) {
+ localSplitSources.add(restored.getKey());
+ localReaders.add(restored.getKey().createReader(
+ serializedOptions.getPipelineOptions(), restored.getValue()));
+ }
+ } else {
+ // initialize localReaders and localSources from scratch
+ for (int i = 0; i < splitSources.size(); i++) {
+ if (i % numSubtasks == subtaskIndex) {
+ UnboundedSource<OutputT, CheckpointMarkT> source =
+ splitSources.get(i);
+ UnboundedSource.UnboundedReader<OutputT> reader =
+ source.createReader(serializedOptions.getPipelineOptions(), null);
+ localSplitSources.add(source);
+ localReaders.add(reader);
+ }
+ }
+ }
+
+ LOG.info("Unbounded Flink Source {}/{} is reading from sources: {}",
+ subtaskIndex,
+ numSubtasks,
+ localSplitSources);
+ }
+
+ @Override
+ public void run(SourceContext<WindowedValue<OutputT>> ctx) throws Exception {
+
+ context = ctx;
+
+ if (localReaders.size() == 0) {
+ // do nothing, but still look busy ...
+ // also, output a Long.MAX_VALUE watermark since we know that we're not
+ // going to emit anything
+ // we can't return here since Flink requires that all operators stay up,
+ // otherwise checkpointing would not work correctly anymore
+ ctx.emitWatermark(new Watermark(Long.MAX_VALUE));
+
+ // wait until this is canceled
+ final Object waitLock = new Object();
+ while (isRunning) {
+ try {
+ // Flink will interrupt us at some point
+ //noinspection SynchronizationOnLocalVariableOrMethodParameter
+ synchronized (waitLock) {
+ // don't wait indefinitely, in case something goes horribly wrong
+ waitLock.wait(1000);
+ }
+ } catch (InterruptedException e) {
+ if (!isRunning) {
+ // restore the interrupted state, and fall through the loop
+ Thread.currentThread().interrupt();
+ }
+ }
+ }
+ } else if (localReaders.size() == 1) {
+ // the easy case, we just read from one reader
+ UnboundedSource.UnboundedReader<OutputT> reader = localReaders.get(0);
+
+ boolean dataAvailable = reader.start();
+ if (dataAvailable) {
+ emitElement(ctx, reader);
+ }
+
+ setNextWatermarkTimer(this.runtimeContext);
+
+ while (isRunning) {
+ dataAvailable = reader.advance();
+
+ if (dataAvailable) {
+ emitElement(ctx, reader);
+ } else {
+ Thread.sleep(50);
+ }
+ }
+ } else {
+ // a bit more complicated, we are responsible for several localReaders
+ // loop through them and sleep if none of them had any data
+
+ int numReaders = localReaders.size();
+ int currentReader = 0;
+
+ // start each reader and emit data if immediately available
+ for (UnboundedSource.UnboundedReader<OutputT> reader : localReaders) {
+ boolean dataAvailable = reader.start();
+ if (dataAvailable) {
+ emitElement(ctx, reader);
+ }
+ }
+
+ // a flag telling us whether any of the localReaders had data
+ // if no reader had data, sleep for bit
+ boolean hadData = false;
+ while (isRunning) {
+ UnboundedSource.UnboundedReader<OutputT> reader = localReaders.get(currentReader);
+ boolean dataAvailable = reader.advance();
+
+ if (dataAvailable) {
+ emitElement(ctx, reader);
+ hadData = true;
+ }
+
+ currentReader = (currentReader + 1) % numReaders;
+ if (currentReader == 0 && !hadData) {
+ Thread.sleep(50);
+ } else if (currentReader == 0) {
+ hadData = false;
+ }
+ }
+
+ }
+ }
+
+ /**
+ * Emit the current element from the given Reader. The reader is guaranteed to have data.
+ */
+ private void emitElement(
+ SourceContext<WindowedValue<OutputT>> ctx,
+ UnboundedSource.UnboundedReader<OutputT> reader) {
+ // make sure that reader state update and element emission are atomic
+ // with respect to snapshots
+ synchronized (ctx.getCheckpointLock()) {
+
+ OutputT item = reader.getCurrent();
+ Instant timestamp = reader.getCurrentTimestamp();
+
+ WindowedValue<OutputT> windowedValue =
+ WindowedValue.of(item, timestamp, GlobalWindow.INSTANCE, PaneInfo.NO_FIRING);
+ ctx.collectWithTimestamp(windowedValue, timestamp.getMillis());
+ }
+ }
+
+ @Override
+ public void close() throws Exception {
+ super.close();
+ if (localReaders != null) {
+ for (UnboundedSource.UnboundedReader<OutputT> reader: localReaders) {
+ reader.close();
+ }
+ }
+ }
+
+ @Override
+ public void cancel() {
+ isRunning = false;
+ }
+
+ @Override
+ public void stop() {
+ isRunning = false;
+ }
+
+ // ------------------------------------------------------------------------
+ // Checkpoint and restore
+ // ------------------------------------------------------------------------
+
+ @Override
+ public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
+ if (!isRunning) {
+ LOG.debug("snapshotState() called on closed source");
+ } else {
+
+ if (checkpointCoder == null) {
+ // no checkpoint coder available in this source
+ return;
+ }
+
+ stateForCheckpoint.clear();
+
+ long checkpointId = functionSnapshotContext.getCheckpointId();
+
+ // we checkpoint the sources along with the CheckpointMarkT to ensure
+ // than we have a correct mapping of checkpoints to sources when
+ // restoring
+ List<CheckpointMarkT> checkpointMarks = new ArrayList<>(localSplitSources.size());
+
+ for (int i = 0; i < localSplitSources.size(); i++) {
+ UnboundedSource<OutputT, CheckpointMarkT> source = localSplitSources.get(i);
+ UnboundedSource.UnboundedReader<OutputT> reader = localReaders.get(i);
+
+ @SuppressWarnings("unchecked")
+ CheckpointMarkT mark = (CheckpointMarkT) reader.getCheckpointMark();
+ checkpointMarks.add(mark);
+ KV<UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT> kv =
+ KV.of(source, mark);
+ stateForCheckpoint.add(kv);
+ }
+
+ // cleanup old pending checkpoints and add new checkpoint
+ int diff = pendingCheckpoints.size() - MAX_NUMBER_PENDING_CHECKPOINTS;
+ if (diff >= 0) {
+ for (Iterator<Long> iterator = pendingCheckpoints.keySet().iterator();
+ diff >= 0;
+ diff--) {
+ iterator.next();
+ iterator.remove();
+ }
+ }
+ pendingCheckpoints.put(checkpointId, checkpointMarks);
+
+ }
+ }
+
+ @Override
+ public void initializeState(FunctionInitializationContext context) throws Exception {
+ if (checkpointCoder == null) {
+ // no checkpoint coder available in this source
+ return;
+ }
+
+ OperatorStateStore stateStore = context.getOperatorStateStore();
+ CoderTypeInformation<
+ KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT>>
+ typeInformation = (CoderTypeInformation) new CoderTypeInformation<>(checkpointCoder);
+ stateForCheckpoint = stateStore.getOperatorState(
+ new ListStateDescriptor<>(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME,
+ typeInformation.createSerializer(new ExecutionConfig())));
+
+ if (context.isRestored()) {
+ isRestored = true;
+ LOG.info("Having restore state in the UnbounedSourceWrapper.");
+ } else {
+ LOG.info("No restore state for UnbounedSourceWrapper.");
+ }
+ }
+
+ @Override
+ public void onProcessingTime(long timestamp) throws Exception {
+ if (this.isRunning) {
+ synchronized (context.getCheckpointLock()) {
+ // find minimum watermark over all localReaders
+ long watermarkMillis = Long.MAX_VALUE;
+ for (UnboundedSource.UnboundedReader<OutputT> reader: localReaders) {
+ Instant watermark = reader.getWatermark();
+ if (watermark != null) {
+ watermarkMillis = Math.min(watermark.getMillis(), watermarkMillis);
+ }
+ }
+ context.emitWatermark(new Watermark(watermarkMillis));
+ }
+ setNextWatermarkTimer(this.runtimeContext);
+ }
+ }
+
+ private void setNextWatermarkTimer(StreamingRuntimeContext runtime) {
+ if (this.isRunning) {
+ long watermarkInterval = runtime.getExecutionConfig().getAutoWatermarkInterval();
+ long timeToNextWatermark = getTimeToNextWatermark(watermarkInterval);
+ runtime.getProcessingTimeService().registerTimer(timeToNextWatermark, this);
+ }
+ }
+
+ private long getTimeToNextWatermark(long watermarkInterval) {
+ return System.currentTimeMillis() + watermarkInterval;
+ }
+
+ /**
+ * Visible so that we can check this in tests. Must not be used for anything else.
+ */
+ @VisibleForTesting
+ public List<? extends UnboundedSource<OutputT, CheckpointMarkT>> getSplitSources() {
+ return splitSources;
+ }
+
+ /**
+ * Visible so that we can check this in tests. Must not be used for anything else.
+ */
+ @VisibleForTesting
+ public List<? extends UnboundedSource<OutputT, CheckpointMarkT>> getLocalSplitSources() {
+ return localSplitSources;
+ }
+
+ @Override
+ public void notifyCheckpointComplete(long checkpointId) throws Exception {
+
+ List<CheckpointMarkT> checkpointMarks = pendingCheckpoints.get(checkpointId);
+
+ if (checkpointMarks != null) {
+
+ // remove old checkpoints including the current one
+ Iterator<Long> iterator = pendingCheckpoints.keySet().iterator();
+ long currentId;
+ do {
+ currentId = iterator.next();
+ iterator.remove();
+ } while (currentId != checkpointId);
+
+ // confirm all marks
+ for (CheckpointMarkT mark : checkpointMarks) {
+ mark.finalizeCheckpoint();
+ }
+
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/package-info.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/package-info.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/package-info.java
new file mode 100644
index 0000000..b431ce7
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Internal implementation of the Beam runner for Apache Flink.
+ */
+package org.apache.beam.runners.flink.translation.wrappers.streaming.io;
http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/package-info.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/package-info.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/package-info.java
new file mode 100644
index 0000000..0674871
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Internal implementation of the Beam runner for Apache Flink.
+ */
+package org.apache.beam.runners.flink.translation.wrappers.streaming;
http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java
new file mode 100644
index 0000000..3203446
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java
@@ -0,0 +1,865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.translation.wrappers.streaming.state;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.coders.MapCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.CombineWithContext;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.OutputTimeFn;
+import org.apache.beam.sdk.util.CombineContextFactory;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.CombiningState;
+import org.apache.beam.sdk.util.state.MapState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.beam.sdk.util.state.SetState;
+import org.apache.beam.sdk.util.state.State;
+import org.apache.beam.sdk.util.state.StateContext;
+import org.apache.beam.sdk.util.state.StateContexts;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.beam.sdk.util.state.WatermarkHoldState;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+
+/**
+ * {@link StateInternals} that uses a Flink {@link DefaultOperatorStateBackend}
+ * to manage the broadcast state.
+ * The state is the same on all parallel instances of the operator,
+ * so we only need to store the state of operator-0 in the OperatorStateBackend.
+ *
+ * <p>Note: the key index is ignored; this is mainly intended for side inputs.
+ */
+public class FlinkBroadcastStateInternals<K> implements StateInternals<K> {
+
+ private int indexInSubtaskGroup;
+ private final DefaultOperatorStateBackend stateBackend;
+ // stateName -> <namespace, state>
+ private Map<String, Map<String, ?>> stateForNonZeroOperator;
+
+ public FlinkBroadcastStateInternals(int indexInSubtaskGroup, OperatorStateBackend stateBackend) {
+ //TODO flink do not yet expose through public API
+ this.stateBackend = (DefaultOperatorStateBackend) stateBackend;
+ this.indexInSubtaskGroup = indexInSubtaskGroup;
+ if (indexInSubtaskGroup != 0) {
+ stateForNonZeroOperator = new HashMap<>();
+ }
+ }
+
+ @Override
+ public K getKey() {
+ return null;
+ }
+
+ @Override
+ public <T extends State> T state(
+ final StateNamespace namespace,
+ StateTag<? super K, T> address) {
+
+ return state(namespace, address, StateContexts.nullContext());
+ }
+
+ @Override
+ public <T extends State> T state(
+ final StateNamespace namespace,
+ StateTag<? super K, T> address,
+ final StateContext<?> context) {
+
+ return address.bind(new StateTag.StateBinder<K>() {
+
+ @Override
+ public <T> ValueState<T> bindValue(
+ StateTag<? super K, ValueState<T>> address,
+ Coder<T> coder) {
+
+ return new FlinkBroadcastValueState<>(stateBackend, address, namespace, coder);
+ }
+
+ @Override
+ public <T> BagState<T> bindBag(
+ StateTag<? super K, BagState<T>> address,
+ Coder<T> elemCoder) {
+
+ return new FlinkBroadcastBagState<>(stateBackend, address, namespace, elemCoder);
+ }
+
+ @Override
+ public <T> SetState<T> bindSet(
+ StateTag<? super K, SetState<T>> address,
+ Coder<T> elemCoder) {
+ throw new UnsupportedOperationException(
+ String.format("%s is not supported", SetState.class.getSimpleName()));
+ }
+
+ @Override
+ public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+ StateTag<? super K, MapState<KeyT, ValueT>> spec,
+ Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) {
+ throw new UnsupportedOperationException(
+ String.format("%s is not supported", MapState.class.getSimpleName()));
+ }
+
+ @Override
+ public <InputT, AccumT, OutputT>
+ CombiningState<InputT, AccumT, OutputT>
+ bindCombiningValue(
+ StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address,
+ Coder<AccumT> accumCoder,
+ Combine.CombineFn<InputT, AccumT, OutputT> combineFn) {
+
+ return new FlinkCombiningState<>(
+ stateBackend, address, combineFn, namespace, accumCoder);
+ }
+
+ @Override
+ public <InputT, AccumT, OutputT>
+ CombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValue(
+ StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address,
+ Coder<AccumT> accumCoder,
+ final Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) {
+ return new FlinkKeyedCombiningState<>(
+ stateBackend,
+ address,
+ combineFn,
+ namespace,
+ accumCoder,
+ FlinkBroadcastStateInternals.this);
+ }
+
+ @Override
+ public <InputT, AccumT, OutputT>
+ CombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValueWithContext(
+ StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address,
+ Coder<AccumT> accumCoder,
+ CombineWithContext.KeyedCombineFnWithContext<
+ ? super K, InputT, AccumT, OutputT> combineFn) {
+ return new FlinkCombiningStateWithContext<>(
+ stateBackend,
+ address,
+ combineFn,
+ namespace,
+ accumCoder,
+ FlinkBroadcastStateInternals.this,
+ CombineContextFactory.createFromStateContext(context));
+ }
+
+ @Override
+ public <W extends BoundedWindow> WatermarkHoldState<W> bindWatermark(
+ StateTag<? super K, WatermarkHoldState<W>> address,
+ OutputTimeFn<? super W> outputTimeFn) {
+ throw new UnsupportedOperationException(
+ String.format("%s is not supported", WatermarkHoldState.class.getSimpleName()));
+ }
+ });
+ }
+
+ /**
+ * 1. The way we would use it is to only checkpoint anything from the operator
+ * with subtask index 0 because we assume that the state is the same on all
+ * parallel instances of the operator.
+ *
+ * <p>2. Use map to support namespace.
+ */
+ private abstract class AbstractBroadcastState<T> {
+
+ private String name;
+ private final StateNamespace namespace;
+ private final ListStateDescriptor<Map<String, T>> flinkStateDescriptor;
+ private final DefaultOperatorStateBackend flinkStateBackend;
+
+ AbstractBroadcastState(
+ DefaultOperatorStateBackend flinkStateBackend,
+ String name,
+ StateNamespace namespace,
+ Coder<T> coder) {
+ this.name = name;
+
+ this.namespace = namespace;
+ this.flinkStateBackend = flinkStateBackend;
+
+ CoderTypeInformation<Map<String, T>> typeInfo =
+ new CoderTypeInformation<>(MapCoder.of(StringUtf8Coder.of(), coder));
+
+ flinkStateDescriptor = new ListStateDescriptor<>(name,
+ typeInfo.createSerializer(new ExecutionConfig()));
+ }
+
+ /**
+ * Get the map (namespace -> T). Subtask 0 reads directly from the broadcast
+ * operator state; every other subtask lazily restores a local copy from the
+ * broadcast state written by subtask 0 and then clears the backend copy.
+ */
+ Map<String, T> getMap() throws Exception {
+ if (indexInSubtaskGroup == 0) {
+ return getMapFromBroadcastState();
+ } else {
+ // Unchecked cast: stateForNonZeroOperator maps each state name to a
+ // differently-typed value map, so the value type cannot be expressed statically.
+ Map<String, T> result = (Map<String, T>) stateForNonZeroOperator.get(name);
+ // maybe restore from BroadcastState of Operator-0
+ if (result == null) {
+ result = getMapFromBroadcastState();
+ if (result != null) {
+ stateForNonZeroOperator.put(name, result);
+ // we don't need it anymore, must clear it.
+ flinkStateBackend.getBroadcastOperatorState(
+ flinkStateDescriptor).clear();
+ }
+ }
+ return result;
+ }
+ }
+
+ /**
+ * Reads the namespace map from the Flink broadcast operator state.
+ * Returns {@code null} when the backend holds no entry for this descriptor.
+ */
+ Map<String, T> getMapFromBroadcastState() throws Exception {
+ ListState<Map<String, T>> state = flinkStateBackend.getBroadcastOperatorState(
+ flinkStateDescriptor);
+ Iterable<Map<String, T>> iterable = state.get();
+ Map<String, T> ret = null;
+ if (iterable != null) {
+ // just use index 0
+ Iterator<Map<String, T>> iterator = iterable.iterator();
+ if (iterator.hasNext()) {
+ ret = iterator.next();
+ }
+ }
+ return ret;
+ }
+
+ /**
+ * Update the map (namespace -> T). Subtask 0 writes through to the broadcast
+ * operator state (clearing it first, and leaving it empty when the map is empty);
+ * other subtasks only update their local copy.
+ */
+ void updateMap(Map<String, T> map) throws Exception {
+ if (indexInSubtaskGroup == 0) {
+ ListState<Map<String, T>> state = flinkStateBackend.getBroadcastOperatorState(
+ flinkStateDescriptor);
+ state.clear();
+ if (map.size() > 0) {
+ state.add(map);
+ }
+ } else {
+ if (map.size() == 0) {
+ stateForNonZeroOperator.remove(name);
+ // updateMap is always behind getMap,
+ // getMap will clear map in BroadcastOperatorState,
+ // we don't need clear here.
+ } else {
+ stateForNonZeroOperator.put(name, map);
+ }
+ }
+ }
+
+ void writeInternal(T input) {
+ try {
+ Map<String, T> map = getMap();
+ if (map == null) {
+ map = new HashMap<>();
+ }
+ map.put(namespace.stringKey(), input);
+ updateMap(map);
+ } catch (Exception e) {
+ throw new RuntimeException("Error updating state.", e);
+ }
+ }
+
+ T readInternal() {
+ try {
+ Map<String, T> map = getMap();
+ if (map == null) {
+ return null;
+ } else {
+ return map.get(namespace.stringKey());
+ }
+ } catch (Exception e) {
+ throw new RuntimeException("Error reading state.", e);
+ }
+ }
+
+ void clearInternal() {
+ try {
+ Map<String, T> map = getMap();
+ if (map != null) {
+ map.remove(namespace.stringKey());
+ updateMap(map);
+ }
+ } catch (Exception e) {
+ throw new RuntimeException("Error clearing state.", e);
+ }
+ }
+
+ }
+
+ private class FlinkBroadcastValueState<K, T>
+ extends AbstractBroadcastState<T> implements ValueState<T> {
+
+ private final StateNamespace namespace;
+ private final StateTag<? super K, ValueState<T>> address;
+
+ FlinkBroadcastValueState(
+ DefaultOperatorStateBackend flinkStateBackend,
+ StateTag<? super K, ValueState<T>> address,
+ StateNamespace namespace,
+ Coder<T> coder) {
+ super(flinkStateBackend, address.getId(), namespace, coder);
+
+ this.namespace = namespace;
+ this.address = address;
+
+ }
+
+ @Override
+ public void write(T input) {
+ writeInternal(input);
+ }
+
+ @Override
+ public ValueState<T> readLater() {
+ return this;
+ }
+
+ @Override
+ public T read() {
+ return readInternal();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ FlinkBroadcastValueState<?, ?> that = (FlinkBroadcastValueState<?, ?>) o;
+
+ return namespace.equals(that.namespace) && address.equals(that.address);
+
+ }
+
+ @Override
+ public int hashCode() {
+ int result = namespace.hashCode();
+ result = 31 * result + address.hashCode();
+ return result;
+ }
+
+ @Override
+ public void clear() {
+ clearInternal();
+ }
+ }
+
+ private class FlinkBroadcastBagState<K, T> extends AbstractBroadcastState<List<T>>
+ implements BagState<T> {
+
+ private final StateNamespace namespace;
+ private final StateTag<? super K, BagState<T>> address;
+
+ FlinkBroadcastBagState(
+ DefaultOperatorStateBackend flinkStateBackend,
+ StateTag<? super K, BagState<T>> address,
+ StateNamespace namespace,
+ Coder<T> coder) {
+ super(flinkStateBackend, address.getId(), namespace, ListCoder.of(coder));
+
+ this.namespace = namespace;
+ this.address = address;
+ }
+
+ @Override
+ public void add(T input) {
+ List<T> list = readInternal();
+ if (list == null) {
+ list = new ArrayList<>();
+ }
+ list.add(input);
+ writeInternal(list);
+ }
+
+ @Override
+ public BagState<T> readLater() {
+ return this;
+ }
+
+ @Override
+ public Iterable<T> read() {
+ List<T> result = readInternal();
+ return result != null ? result : Collections.<T>emptyList();
+ }
+
+ @Override
+ public ReadableState<Boolean> isEmpty() {
+ return new ReadableState<Boolean>() {
+ @Override
+ public Boolean read() {
+ try {
+ List<T> result = readInternal();
+ return result == null;
+ } catch (Exception e) {
+ throw new RuntimeException("Error reading state.", e);
+ }
+
+ }
+
+ @Override
+ public ReadableState<Boolean> readLater() {
+ return this;
+ }
+ };
+ }
+
+ @Override
+ public void clear() {
+ clearInternal();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ FlinkBroadcastBagState<?, ?> that = (FlinkBroadcastBagState<?, ?>) o;
+
+ return namespace.equals(that.namespace) && address.equals(that.address);
+
+ }
+
+ @Override
+ public int hashCode() {
+ int result = namespace.hashCode();
+ result = 31 * result + address.hashCode();
+ return result;
+ }
+ }
+
+ private class FlinkCombiningState<K, InputT, AccumT, OutputT>
+ extends AbstractBroadcastState<AccumT>
+ implements CombiningState<InputT, AccumT, OutputT> {
+
+ private final StateNamespace namespace;
+ private final StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address;
+ private final Combine.CombineFn<InputT, AccumT, OutputT> combineFn;
+
+ FlinkCombiningState(
+ DefaultOperatorStateBackend flinkStateBackend,
+ StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address,
+ Combine.CombineFn<InputT, AccumT, OutputT> combineFn,
+ StateNamespace namespace,
+ Coder<AccumT> accumCoder) {
+ super(flinkStateBackend, address.getId(), namespace, accumCoder);
+
+ this.namespace = namespace;
+ this.address = address;
+ this.combineFn = combineFn;
+ }
+
+ @Override
+ public CombiningState<InputT, AccumT, OutputT> readLater() {
+ return this;
+ }
+
+ @Override
+ public void add(InputT value) {
+ AccumT current = readInternal();
+ if (current == null) {
+ current = combineFn.createAccumulator();
+ }
+ current = combineFn.addInput(current, value);
+ writeInternal(current);
+ }
+
+ @Override
+ public void addAccum(AccumT accum) {
+ AccumT current = readInternal();
+
+ if (current == null) {
+ writeInternal(accum);
+ } else {
+ current = combineFn.mergeAccumulators(Arrays.asList(current, accum));
+ writeInternal(current);
+ }
+ }
+
+ @Override
+ public AccumT getAccum() {
+ return readInternal();
+ }
+
+ @Override
+ public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+ return combineFn.mergeAccumulators(accumulators);
+ }
+
+ @Override
+ public OutputT read() {
+ AccumT accum = readInternal();
+ if (accum != null) {
+ return combineFn.extractOutput(accum);
+ } else {
+ return combineFn.extractOutput(combineFn.createAccumulator());
+ }
+ }
+
+ @Override
+ public ReadableState<Boolean> isEmpty() {
+ return new ReadableState<Boolean>() {
+ @Override
+ public Boolean read() {
+ try {
+ return readInternal() == null;
+ } catch (Exception e) {
+ throw new RuntimeException("Error reading state.", e);
+ }
+
+ }
+
+ @Override
+ public ReadableState<Boolean> readLater() {
+ return this;
+ }
+ };
+ }
+
+ @Override
+ public void clear() {
+ clearInternal();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ FlinkCombiningState<?, ?, ?, ?> that =
+ (FlinkCombiningState<?, ?, ?, ?>) o;
+
+ return namespace.equals(that.namespace) && address.equals(that.address);
+
+ }
+
+ @Override
+ public int hashCode() {
+ int result = namespace.hashCode();
+ result = 31 * result + address.hashCode();
+ return result;
+ }
+ }
+
+ private class FlinkKeyedCombiningState<K, InputT, AccumT, OutputT>
+ extends AbstractBroadcastState<AccumT>
+ implements CombiningState<InputT, AccumT, OutputT> {
+
+ private final StateNamespace namespace;
+ private final StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address;
+ private final Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn;
+ private final FlinkBroadcastStateInternals<K> flinkStateInternals;
+
+ FlinkKeyedCombiningState(
+ DefaultOperatorStateBackend flinkStateBackend,
+ StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address,
+ Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn,
+ StateNamespace namespace,
+ Coder<AccumT> accumCoder,
+ FlinkBroadcastStateInternals<K> flinkStateInternals) {
+ super(flinkStateBackend, address.getId(), namespace, accumCoder);
+
+ this.namespace = namespace;
+ this.address = address;
+ this.combineFn = combineFn;
+ this.flinkStateInternals = flinkStateInternals;
+
+ }
+
+ @Override
+ public CombiningState<InputT, AccumT, OutputT> readLater() {
+ return this;
+ }
+
+ @Override
+ public void add(InputT value) {
+ try {
+ AccumT current = readInternal();
+ if (current == null) {
+ current = combineFn.createAccumulator(flinkStateInternals.getKey());
+ }
+ current = combineFn.addInput(flinkStateInternals.getKey(), current, value);
+ writeInternal(current);
+ } catch (Exception e) {
+ throw new RuntimeException("Error adding to state." , e);
+ }
+ }
+
+ @Override
+ public void addAccum(AccumT accum) {
+ try {
+ AccumT current = readInternal();
+ if (current == null) {
+ writeInternal(accum);
+ } else {
+ current = combineFn.mergeAccumulators(
+ flinkStateInternals.getKey(),
+ Arrays.asList(current, accum));
+ writeInternal(current);
+ }
+ } catch (Exception e) {
+ throw new RuntimeException("Error adding to state.", e);
+ }
+ }
+
+ @Override
+ public AccumT getAccum() {
+ try {
+ return readInternal();
+ } catch (Exception e) {
+ throw new RuntimeException("Error reading state.", e);
+ }
+ }
+
+ @Override
+ public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+ return combineFn.mergeAccumulators(flinkStateInternals.getKey(), accumulators);
+ }
+
+ @Override
+ public OutputT read() {
+ try {
+ AccumT accum = readInternal();
+ if (accum != null) {
+ return combineFn.extractOutput(flinkStateInternals.getKey(), accum);
+ } else {
+ return combineFn.extractOutput(
+ flinkStateInternals.getKey(),
+ combineFn.createAccumulator(flinkStateInternals.getKey()));
+ }
+ } catch (Exception e) {
+ throw new RuntimeException("Error reading state.", e);
+ }
+ }
+
+ @Override
+ public ReadableState<Boolean> isEmpty() {
+ return new ReadableState<Boolean>() {
+ @Override
+ public Boolean read() {
+ try {
+ return readInternal() == null;
+ } catch (Exception e) {
+ throw new RuntimeException("Error reading state.", e);
+ }
+
+ }
+
+ @Override
+ public ReadableState<Boolean> readLater() {
+ return this;
+ }
+ };
+ }
+
+ @Override
+ public void clear() {
+ clearInternal();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ FlinkKeyedCombiningState<?, ?, ?, ?> that =
+ (FlinkKeyedCombiningState<?, ?, ?, ?>) o;
+
+ return namespace.equals(that.namespace) && address.equals(that.address);
+
+ }
+
+ @Override
+ public int hashCode() {
+ int result = namespace.hashCode();
+ result = 31 * result + address.hashCode();
+ return result;
+ }
+ }
+
+ private class FlinkCombiningStateWithContext<K, InputT, AccumT, OutputT>
+ extends AbstractBroadcastState<AccumT>
+ implements CombiningState<InputT, AccumT, OutputT> {
+
+ private final StateNamespace namespace;
+ private final StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address;
+ private final CombineWithContext.KeyedCombineFnWithContext<
+ ? super K, InputT, AccumT, OutputT> combineFn;
+ private final FlinkBroadcastStateInternals<K> flinkStateInternals;
+ private final CombineWithContext.Context context;
+
+ FlinkCombiningStateWithContext(
+ DefaultOperatorStateBackend flinkStateBackend,
+ StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address,
+ CombineWithContext.KeyedCombineFnWithContext<
+ ? super K, InputT, AccumT, OutputT> combineFn,
+ StateNamespace namespace,
+ Coder<AccumT> accumCoder,
+ FlinkBroadcastStateInternals<K> flinkStateInternals,
+ CombineWithContext.Context context) {
+ super(flinkStateBackend, address.getId(), namespace, accumCoder);
+
+ this.namespace = namespace;
+ this.address = address;
+ this.combineFn = combineFn;
+ this.flinkStateInternals = flinkStateInternals;
+ this.context = context;
+
+ }
+
+ @Override
+ public CombiningState<InputT, AccumT, OutputT> readLater() {
+ return this;
+ }
+
+ @Override
+ public void add(InputT value) {
+ try {
+ AccumT current = readInternal();
+ if (current == null) {
+ current = combineFn.createAccumulator(flinkStateInternals.getKey(), context);
+ }
+ current = combineFn.addInput(flinkStateInternals.getKey(), current, value, context);
+ writeInternal(current);
+ } catch (Exception e) {
+ throw new RuntimeException("Error adding to state." , e);
+ }
+ }
+
+ @Override
+ public void addAccum(AccumT accum) {
+ try {
+
+ AccumT current = readInternal();
+ if (current == null) {
+ writeInternal(accum);
+ } else {
+ current = combineFn.mergeAccumulators(
+ flinkStateInternals.getKey(),
+ Arrays.asList(current, accum),
+ context);
+ writeInternal(current);
+ }
+ } catch (Exception e) {
+ throw new RuntimeException("Error adding to state.", e);
+ }
+ }
+
+ @Override
+ public AccumT getAccum() {
+ try {
+ return readInternal();
+ } catch (Exception e) {
+ throw new RuntimeException("Error reading state.", e);
+ }
+ }
+
+ @Override
+ public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+ return combineFn.mergeAccumulators(flinkStateInternals.getKey(), accumulators, context);
+ }
+
+ /**
+ * Extracts the combined output. When no accumulator has been stored yet, an
+ * empty accumulator is created first so that reading empty state yields the
+ * combine fn's identity output instead of passing {@code null} to
+ * {@code extractOutput} -- consistent with FlinkCombiningState.read() and
+ * FlinkKeyedCombiningState.read() in this file.
+ */
+ @Override
+ public OutputT read() {
+ try {
+ AccumT accum = readInternal();
+ if (accum == null) {
+ // no state yet: fall back to the identity accumulator, as the sibling
+ // combining-state implementations do
+ accum = combineFn.createAccumulator(flinkStateInternals.getKey(), context);
+ }
+ return combineFn.extractOutput(flinkStateInternals.getKey(), accum, context);
+ } catch (Exception e) {
+ throw new RuntimeException("Error reading state.", e);
+ }
+ }
+
+ @Override
+ public ReadableState<Boolean> isEmpty() {
+ return new ReadableState<Boolean>() {
+ @Override
+ public Boolean read() {
+ try {
+ return readInternal() == null;
+ } catch (Exception e) {
+ throw new RuntimeException("Error reading state.", e);
+ }
+
+ }
+
+ @Override
+ public ReadableState<Boolean> readLater() {
+ return this;
+ }
+ };
+ }
+
+ @Override
+ public void clear() {
+ clearInternal();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ FlinkCombiningStateWithContext<?, ?, ?, ?> that =
+ (FlinkCombiningStateWithContext<?, ?, ?, ?>) o;
+
+ return namespace.equals(that.namespace) && address.equals(that.address);
+
+ }
+
+ @Override
+ public int hashCode() {
+ int result = namespace.hashCode();
+ result = 31 * result + address.hashCode();
+ return result;
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/cdd2544b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkKeyGroupStateInternals.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkKeyGroupStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkKeyGroupStateInternals.java
new file mode 100644
index 0000000..24b340e
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkKeyGroupStateInternals.java
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.translation.wrappers.streaming.state;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateTag;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.Coder.Context;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.CombineWithContext;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.OutputTimeFn;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.CombiningState;
+import org.apache.beam.sdk.util.state.MapState;
+import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.beam.sdk.util.state.SetState;
+import org.apache.beam.sdk.util.state.State;
+import org.apache.beam.sdk.util.state.StateContext;
+import org.apache.beam.sdk.util.state.StateContexts;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.beam.sdk.util.state.WatermarkHoldState;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.state.KeyGroupsList;
+import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.streaming.api.operators.HeapInternalTimerService;
+import org.apache.flink.util.InstantiationUtil;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * {@link StateInternals} that uses {@link KeyGroupCheckpointedOperator}
+ * to checkpoint state.
+ *
+ * <p>Note:
+ * The index of the key is ignored, and only {@code BagState} is implemented.
+ *
+ * <p>The handling of the local key-group range follows {@link HeapInternalTimerService}.
+ */
+public class FlinkKeyGroupStateInternals<K> implements StateInternals<K> {
+
+ private final Coder<K> keyCoder;
+ private final KeyGroupsList localKeyGroupRange;
+ private KeyedStateBackend keyedStateBackend;
+ private final int localKeyGroupRangeStartIdx;
+
+ // stateName -> namespace -> (valueCoder, value)
+ // Indexed by LOCAL key-group index (key-group index minus localKeyGroupRangeStartIdx).
+ private final Map<String, Tuple2<Coder<?>, Map<String, ?>>>[] stateTables;
+
+ public FlinkKeyGroupStateInternals(
+ Coder<K> keyCoder,
+ KeyedStateBackend keyedStateBackend) {
+ this.keyCoder = keyCoder;
+ this.keyedStateBackend = keyedStateBackend;
+ this.localKeyGroupRange = keyedStateBackend.getKeyGroupRange();
+ // find the starting index of the local key-group range
+ int startIdx = Integer.MAX_VALUE;
+ for (Integer keyGroupIdx : localKeyGroupRange) {
+ startIdx = Math.min(keyGroupIdx, startIdx);
+ }
+ this.localKeyGroupRangeStartIdx = startIdx;
+ // Unchecked generic-array creation: one state table per key group owned by this operator.
+ stateTables = (Map<String, Tuple2<Coder<?>, Map<String, ?>>>[])
+ new Map[localKeyGroupRange.getNumberOfKeyGroups()];
+ for (int i = 0; i < stateTables.length; i++) {
+ stateTables[i] = new HashMap<>();
+ }
+ }
+
+ @Override
+ public K getKey() {
+ // The current key is stored in the backend as a ByteBuffer of encoded key bytes.
+ ByteBuffer keyBytes = (ByteBuffer) keyedStateBackend.getCurrentKey();
+ try {
+ // NOTE(review): array() returns the whole backing array and ignores
+ // position/arrayOffset — assumes the buffer exactly wraps the key bytes; verify.
+ return CoderUtils.decodeFromByteArray(keyCoder, keyBytes.array());
+ } catch (CoderException e) {
+ throw new RuntimeException("Error decoding key.", e);
+ }
+ }
+
+ @Override
+ public <T extends State> T state(
+ final StateNamespace namespace,
+ StateTag<? super K, T> address) {
+
+ // Convenience overload: bind with a null state context.
+ return state(namespace, address, StateContexts.nullContext());
+ }
+
+ @Override
+ public <T extends State> T state(
+ final StateNamespace namespace,
+ StateTag<? super K, T> address,
+ final StateContext<?> context) {
+
+ // Only BagState is supported by this binder; every other state kind throws.
+ return address.bind(new StateTag.StateBinder<K>() {
+
+ @Override
+ public <T> ValueState<T> bindValue(
+ StateTag<? super K, ValueState<T>> address,
+ Coder<T> coder) {
+ throw new UnsupportedOperationException(
+ String.format("%s is not supported", ValueState.class.getSimpleName()));
+ }
+
+ @Override
+ public <T> BagState<T> bindBag(
+ StateTag<? super K, BagState<T>> address,
+ Coder<T> elemCoder) {
+
+ return new FlinkKeyGroupBagState<>(address, namespace, elemCoder);
+ }
+
+ @Override
+ public <T> SetState<T> bindSet(
+ StateTag<? super K, SetState<T>> address,
+ Coder<T> elemCoder) {
+ throw new UnsupportedOperationException(
+ String.format("%s is not supported", SetState.class.getSimpleName()));
+ }
+
+ @Override
+ public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+ StateTag<? super K, MapState<KeyT, ValueT>> spec,
+ Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) {
+ throw new UnsupportedOperationException(
+ String.format("%s is not supported", MapState.class.getSimpleName()));
+ }
+
+ @Override
+ public <InputT, AccumT, OutputT>
+ CombiningState<InputT, AccumT, OutputT>
+ bindCombiningValue(
+ StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address,
+ Coder<AccumT> accumCoder,
+ Combine.CombineFn<InputT, AccumT, OutputT> combineFn) {
+ throw new UnsupportedOperationException("bindCombiningValue is not supported.");
+ }
+
+ @Override
+ public <InputT, AccumT, OutputT>
+ CombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValue(
+ StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address,
+ Coder<AccumT> accumCoder,
+ final Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) {
+ throw new UnsupportedOperationException("bindKeyedCombiningValue is not supported.");
+
+ }
+
+ @Override
+ public <InputT, AccumT, OutputT>
+ CombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValueWithContext(
+ StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address,
+ Coder<AccumT> accumCoder,
+ CombineWithContext.KeyedCombineFnWithContext<
+ ? super K, InputT, AccumT, OutputT> combineFn) {
+ throw new UnsupportedOperationException(
+ "bindKeyedCombiningValueWithContext is not supported.");
+ }
+
+ @Override
+ public <W extends BoundedWindow> WatermarkHoldState<W> bindWatermark(
+ StateTag<? super K, WatermarkHoldState<W>> address,
+ OutputTimeFn<? super W> outputTimeFn) {
+ throw new UnsupportedOperationException(
+ String.format("%s is not supported", CombiningState.class.getSimpleName()));
+ }
+ });
+ }
+
+ /**
+ * Reference from {@link Combine.CombineFn}.
+ *
+ * <p>Accumulators are stored in each KeyGroup, call addInput() when a element comes,
+ * call extractOutput() to produce the desired value when need to read data.
+ */
+ interface KeyGroupCombiner<InputT, AccumT, OutputT> {
+
+ /**
+ * Returns a new, mutable accumulator value, representing the accumulation
+ * of zero input values.
+ */
+ AccumT createAccumulator();
+
+ /**
+ * Adds the given input value to the given accumulator, returning the
+ * new accumulator value.
+ */
+ AccumT addInput(AccumT accumulator, InputT input);
+
+ /**
+ * Returns the output value that is the result of all accumulators from KeyGroups
+ * that are assigned to this operator.
+ */
+ OutputT extractOutput(Iterable<AccumT> accumulators);
+ }
+
+ // Base class holding the per-key-group accumulator plumbing shared by concrete states.
+ // Accumulators live in the in-memory stateTables, partitioned by key group so that
+ // snapshot/restore can be done per key group.
+ private abstract class AbstractKeyGroupState<InputT, AccumT, OutputT> {
+
+ private String stateName;
+ private String namespace;
+ private Coder<AccumT> coder;
+ private KeyGroupCombiner<InputT, AccumT, OutputT> keyGroupCombiner;
+
+ AbstractKeyGroupState(
+ String stateName,
+ String namespace,
+ Coder<AccumT> coder,
+ KeyGroupCombiner<InputT, AccumT, OutputT> keyGroupCombiner) {
+ this.stateName = stateName;
+ this.namespace = namespace;
+ this.coder = coder;
+ this.keyGroupCombiner = keyGroupCombiner;
+ }
+
+ /**
+ * Choose keyGroup of input and addInput to accumulator.
+ */
+ void addInput(InputT input) {
+ // Route to the key group of the backend's CURRENT key (set by the caller).
+ int keyGroupIdx = keyedStateBackend.getCurrentKeyGroupIndex();
+ int localIdx = getIndexForKeyGroup(keyGroupIdx);
+ Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable = stateTables[localIdx];
+ // Lazily create the (coder, namespace->value) entry for this state name.
+ Tuple2<Coder<?>, Map<String, ?>> tuple2 = stateTable.get(stateName);
+ if (tuple2 == null) {
+ tuple2 = new Tuple2<>();
+ tuple2.f0 = coder;
+ tuple2.f1 = new HashMap<>();
+ stateTable.put(stateName, tuple2);
+ }
+ Map<String, AccumT> map = (Map<String, AccumT>) tuple2.f1;
+ AccumT accumulator = map.get(namespace);
+ if (accumulator == null) {
+ accumulator = keyGroupCombiner.createAccumulator();
+ }
+ // Store the returned accumulator; addInput may return a new instance.
+ accumulator = keyGroupCombiner.addInput(accumulator, input);
+ map.put(namespace, accumulator);
+ }
+
+ /**
+ * Get all accumulators and invoke extractOutput().
+ */
+ OutputT extractOutput() {
+ // Collects the accumulators of ALL local key groups for this (stateName, namespace).
+ List<AccumT> accumulators = new ArrayList<>(stateTables.length);
+ for (Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable : stateTables) {
+ Tuple2<Coder<?>, Map<String, ?>> tuple2 = stateTable.get(stateName);
+ if (tuple2 != null) {
+ AccumT accumulator = (AccumT) tuple2.f1.get(namespace);
+ if (accumulator != null) {
+ accumulators.add(accumulator);
+ }
+ }
+ }
+ return keyGroupCombiner.extractOutput(accumulators);
+ }
+
+ /**
+ * Find the first accumulator and return immediately.
+ */
+ boolean isEmptyInternal() {
+ for (Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable : stateTables) {
+ Tuple2<Coder<?>, Map<String, ?>> tuple2 = stateTable.get(stateName);
+ if (tuple2 != null) {
+ AccumT accumulator = (AccumT) tuple2.f1.get(namespace);
+ if (accumulator != null) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
+ /**
+ * Clear accumulators and clean empty map.
+ */
+ void clearInternal() {
+ for (Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable : stateTables) {
+ Tuple2<Coder<?>, Map<String, ?>> tuple2 = stateTable.get(stateName);
+ if (tuple2 != null) {
+ tuple2.f1.remove(namespace);
+ // Drop the state-name entry entirely once its last namespace is gone,
+ // so snapshots don't carry empty tables.
+ if (tuple2.f1.size() == 0) {
+ stateTable.remove(stateName);
+ }
+ }
+ }
+ }
+
+ }
+
+ // Maps a global key-group index to its slot in stateTables; rejects key groups
+ // not owned by this operator instance.
+ private int getIndexForKeyGroup(int keyGroupIdx) {
+ checkArgument(localKeyGroupRange.contains(keyGroupIdx),
+ "Key Group " + keyGroupIdx + " does not belong to the local range.");
+ return keyGroupIdx - this.localKeyGroupRangeStartIdx;
+ }
+
+ // Bag semantics expressed as a KeyGroupCombiner: the accumulator is a list,
+ // and extractOutput concatenates the per-key-group lists.
+ private class KeyGroupBagCombiner<T> implements KeyGroupCombiner<T, List<T>, Iterable<T>> {
+
+ @Override
+ public List<T> createAccumulator() {
+ return new ArrayList<>();
+ }
+
+ @Override
+ public List<T> addInput(List<T> accumulator, T input) {
+ accumulator.add(input);
+ return accumulator;
+ }
+
+ @Override
+ public Iterable<T> extractOutput(Iterable<List<T>> accumulators) {
+ List<T> result = new ArrayList<>();
+ // maybe can return an unmodifiable view.
+ for (List<T> list : accumulators) {
+ result.addAll(list);
+ }
+ return result;
+ }
+ }
+
+ // BagState backed by the per-key-group accumulator machinery above.
+ private class FlinkKeyGroupBagState<T> extends AbstractKeyGroupState<T, List<T>, Iterable<T>>
+ implements BagState<T> {
+
+ private final StateNamespace namespace;
+ private final StateTag<? super K, BagState<T>> address;
+
+ FlinkKeyGroupBagState(
+ StateTag<? super K, BagState<T>> address,
+ StateNamespace namespace,
+ Coder<T> coder) {
+ super(address.getId(), namespace.stringKey(), ListCoder.of(coder),
+ new KeyGroupBagCombiner<T>());
+ this.namespace = namespace;
+ this.address = address;
+ }
+
+ @Override
+ public void add(T input) {
+ addInput(input);
+ }
+
+ @Override
+ public BagState<T> readLater() {
+ return this;
+ }
+
+ @Override
+ public Iterable<T> read() {
+ // Defensive null check; extractOutput of the bag combiner returns a list.
+ Iterable<T> result = extractOutput();
+ return result != null ? result : Collections.<T>emptyList();
+ }
+
+ @Override
+ public ReadableState<Boolean> isEmpty() {
+ return new ReadableState<Boolean>() {
+ @Override
+ public Boolean read() {
+ try {
+ return isEmptyInternal();
+ } catch (Exception e) {
+ throw new RuntimeException("Error reading state.", e);
+ }
+
+ }
+
+ @Override
+ public ReadableState<Boolean> readLater() {
+ return this;
+ }
+ };
+ }
+
+ @Override
+ public void clear() {
+ clearInternal();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ FlinkKeyGroupBagState<?> that = (FlinkKeyGroupBagState<?>) o;
+
+ return namespace.equals(that.namespace) && address.equals(that.address);
+
+ }
+
+ @Override
+ public int hashCode() {
+ int result = namespace.hashCode();
+ result = 31 * result + address.hashCode();
+ return result;
+ }
+ }
+
+ /**
+ * Snapshots the state {@code (stateName -> (valueCoder && (namespace -> value)))} for a given
+ * {@code keyGroupIdx}.
+ *
+ * <p>Wire format (must stay in sync with {@code restoreKeyGroupState}): a short entry
+ * count, then per state name: UTF name, Java-serialized coder, int namespace count,
+ * then (namespace, value) pairs encoded with StringUtf8Coder / the value coder.
+ *
+ * @param keyGroupIdx the id of the key-group to be put in the snapshot.
+ * @param out the stream to write to.
+ */
+ public void snapshotKeyGroupState(int keyGroupIdx, DataOutputStream out) throws Exception {
+ int localIdx = getIndexForKeyGroup(keyGroupIdx);
+ Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable = stateTables[localIdx];
+ // The entry count is written as a short, hence the Short.MAX_VALUE cap.
+ Preconditions.checkState(stateTable.size() <= Short.MAX_VALUE,
+ "Too many States: " + stateTable.size() + ". Currently at most "
+ + Short.MAX_VALUE + " states are supported");
+ out.writeShort(stateTable.size());
+ for (Map.Entry<String, Tuple2<Coder<?>, Map<String, ?>>> entry : stateTable.entrySet()) {
+ out.writeUTF(entry.getKey());
+ Coder coder = entry.getValue().f0;
+ // The coder itself is snapshotted so restore can decode without prior registration.
+ InstantiationUtil.serializeObject(out, coder);
+ Map<String, ?> map = entry.getValue().f1;
+ out.writeInt(map.size());
+ for (Map.Entry<String, ?> entry1 : map.entrySet()) {
+ StringUtf8Coder.of().encode(entry1.getKey(), out, Context.NESTED);
+ coder.encode(entry1.getValue(), out, Context.NESTED);
+ }
+ }
+ }
+
+ /**
+ * Restore the state {@code (stateName -> (valueCoder && (namespace -> value)))}
+ * for a given {@code keyGroupIdx}.
+ *
+ * <p>Reads exactly the format written by {@code snapshotKeyGroupState}; entries are
+ * merged into any existing state table for the key group.
+ *
+ * @param keyGroupIdx the id of the key-group to be put in the snapshot.
+ * @param in the stream to read from.
+ * @param userCodeClassLoader the class loader that will be used to deserialize
+ * the valueCoder.
+ */
+ public void restoreKeyGroupState(int keyGroupIdx, DataInputStream in,
+ ClassLoader userCodeClassLoader) throws Exception {
+ int localIdx = getIndexForKeyGroup(keyGroupIdx);
+ Map<String, Tuple2<Coder<?>, Map<String, ?>>> stateTable = stateTables[localIdx];
+ int numStates = in.readShort();
+ for (int i = 0; i < numStates; ++i) {
+ String stateName = in.readUTF();
+ Coder coder = InstantiationUtil.deserializeObject(in, userCodeClassLoader);
+ Tuple2<Coder<?>, Map<String, ?>> tuple2 = stateTable.get(stateName);
+ if (tuple2 == null) {
+ tuple2 = new Tuple2<>();
+ tuple2.f0 = coder;
+ tuple2.f1 = new HashMap<>();
+ stateTable.put(stateName, tuple2);
+ }
+ Map<String, Object> map = (Map<String, Object>) tuple2.f1;
+ int mapSize = in.readInt();
+ for (int j = 0; j < mapSize; j++) {
+ String namespace = StringUtf8Coder.of().decode(in, Context.NESTED);
+ Object value = coder.decode(in, Context.NESTED);
+ map.put(namespace, value);
+ }
+ }
+ }
+
+}