You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ch...@apache.org on 2016/11/14 12:41:43 UTC

flink git commit: [FLINK-4939] WriteAheadSink: Decouple creation and commit of a pending checkpoint

Repository: flink
Updated Branches:
  refs/heads/master b2e8792b8 -> 381bf5912


[FLINK-4939] WriteAheadSink: Decouple creation and commit of a pending checkpoint

So far the GenericWriteAheadSink expected that
the subtask that wrote a temporary buffer to the
state backend would also be the one to commit it to
the third-party storage system.

This commit removes this assumption. To do this,
it changes the CheckpointCommitter to dynamically
take the subtaskIdx as a parameter when asking
if a checkpoint was committed, and it changes the
state kept by the GenericWriteAheadSink to also
include the index of the subtask that wrote
the pending buffer.

This closes #2707.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/381bf591
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/381bf591
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/381bf591

Branch: refs/heads/master
Commit: 381bf5912237420b8e294c1fd38006a85403fd4f
Parents: b2e8792
Author: kl0u <kk...@gmail.com>
Authored: Wed Oct 26 17:19:12 2016 +0200
Committer: zentol <ch...@apache.org>
Committed: Mon Nov 14 13:41:24 2016 +0100

----------------------------------------------------------------------
 .../cassandra/CassandraCommitter.java           |  54 +++--
 .../cassandra/CassandraConnectorITCase.java     |  26 ++-
 .../runtime/operators/CheckpointCommitter.java  |  22 +--
 .../operators/GenericWriteAheadSink.java        | 196 ++++++++++++-------
 .../operators/GenericWriteAheadSinkTest.java    |  14 +-
 .../operators/WriteAheadSinkTestBase.java       |   2 +-
 6 files changed, 191 insertions(+), 123 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/381bf591/flink-streaming-connectors/flink-connector-cassandra/src/main/java/org/apache/flink/streaming/connectors/cassandra/CassandraCommitter.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-cassandra/src/main/java/org/apache/flink/streaming/connectors/cassandra/CassandraCommitter.java b/flink-streaming-connectors/flink-connector-cassandra/src/main/java/org/apache/flink/streaming/connectors/cassandra/CassandraCommitter.java
index e83b1be..63b76da 100644
--- a/flink-streaming-connectors/flink-connector-cassandra/src/main/java/org/apache/flink/streaming/connectors/cassandra/CassandraCommitter.java
+++ b/flink-streaming-connectors/flink-connector-cassandra/src/main/java/org/apache/flink/streaming/connectors/cassandra/CassandraCommitter.java
@@ -18,11 +18,15 @@
 package org.apache.flink.streaming.connectors.cassandra;
 
 import com.datastax.driver.core.Cluster;
-import com.datastax.driver.core.PreparedStatement;
+import com.datastax.driver.core.Row;
 import com.datastax.driver.core.Session;
 import org.apache.flink.api.java.ClosureCleaner;
 import org.apache.flink.streaming.runtime.operators.CheckpointCommitter;
 
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+
 /**
  * CheckpointCommitter that saves information about completed checkpoints within a separate table in a cassandra
  * database.
@@ -40,10 +44,11 @@ public class CassandraCommitter extends CheckpointCommitter {
 	private String keySpace = "flink_auxiliary";
 	private String table = "checkpoints_";
 
-	private transient PreparedStatement updateStatement;
-	private transient PreparedStatement selectStatement;
-
-	private long lastCommittedCheckpointID = -1;
+	/**
+	 * A cache of the last committed checkpoint ids per subtask index. This is used to
+	 * avoid redundant round-trips to Cassandra (see {@link #isCheckpointCommitted(int, long)}.
+	 */
+	private final Map<Integer, Long> lastCommittedCheckpoints = new HashMap<>();
 
 	public CassandraCommitter(ClusterBuilder builder) {
 		this.builder = builder;
@@ -95,16 +100,11 @@ public class CassandraCommitter extends CheckpointCommitter {
 		}
 		cluster = builder.getCluster();
 		session = cluster.connect();
-
-		updateStatement = session.prepare(String.format("UPDATE %s.%s set checkpoint_id=? where sink_id='%s' and sub_id=%d;", keySpace, table, operatorId, subtaskId));
-		selectStatement = session.prepare(String.format("SELECT checkpoint_id FROM %s.%s where sink_id='%s' and sub_id=%d;", keySpace, table, operatorId, subtaskId));
-
-		session.execute(String.format("INSERT INTO %s.%s (sink_id, sub_id, checkpoint_id) values ('%s', %d, " + -1 + ") IF NOT EXISTS;", keySpace, table, operatorId, subtaskId));
 	}
 
 	@Override
 	public void close() throws Exception {
-		this.lastCommittedCheckpointID = -1;
+		this.lastCommittedCheckpoints.clear();
 		try {
 			session.close();
 		} catch (Exception e) {
@@ -118,16 +118,34 @@ public class CassandraCommitter extends CheckpointCommitter {
 	}
 
 	@Override
-	public void commitCheckpoint(long checkpointID) {
-		session.execute(updateStatement.bind(checkpointID));
-		this.lastCommittedCheckpointID = checkpointID;
+	public void commitCheckpoint(int subtaskIdx, long checkpointId) {
+		String statement = String.format(
+			"UPDATE %s.%s set checkpoint_id=%d where sink_id='%s' and sub_id=%d;",
+			keySpace, table, checkpointId, operatorId, subtaskIdx);
+
+		session.execute(statement);
+		lastCommittedCheckpoints.put(subtaskIdx, checkpointId);
 	}
 
 	@Override
-	public boolean isCheckpointCommitted(long checkpointID) {
-		if (this.lastCommittedCheckpointID == -1) {
-			this.lastCommittedCheckpointID = session.execute(selectStatement.bind()).one().getLong("checkpoint_id");
+	public boolean isCheckpointCommitted(int subtaskIdx, long checkpointId) {
+		// Pending checkpointed buffers are committed in ascending order of their
+		// checkpoint id. This way we can tell if a checkpointed buffer was committed
+		// just by asking the third-party storage system for the last checkpoint id
+		// committed by the specified subtask.
+
+		Long lastCommittedCheckpoint = lastCommittedCheckpoints.get(subtaskIdx);
+		if (lastCommittedCheckpoint == null) {
+			String statement = String.format(
+				"SELECT checkpoint_id FROM %s.%s where sink_id='%s' and sub_id=%d;",
+				keySpace, table, operatorId, subtaskIdx);
+
+			Iterator<Row> resultIt = session.execute(statement).iterator();
+			if (resultIt.hasNext()) {
+				lastCommittedCheckpoint = resultIt.next().getLong("checkpoint_id");
+				lastCommittedCheckpoints.put(subtaskIdx, lastCommittedCheckpoint);
+			}
 		}
-		return checkpointID <= this.lastCommittedCheckpointID;
+		return lastCommittedCheckpoint != null && checkpointId <= lastCommittedCheckpoint;
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/381bf591/flink-streaming-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/CassandraConnectorITCase.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/CassandraConnectorITCase.java b/flink-streaming-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/CassandraConnectorITCase.java
index a29e881..2bb6fd1 100644
--- a/flink-streaming-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/CassandraConnectorITCase.java
+++ b/flink-streaming-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/CassandraConnectorITCase.java
@@ -1,4 +1,4 @@
-/**
+/*
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
  * distributed with this work for additional information
@@ -320,17 +320,14 @@ public class CassandraConnectorITCase extends WriteAheadSinkTestBase<Tuple3<Stri
 		CassandraCommitter cc1 = new CassandraCommitter(builder);
 		cc1.setJobId("job");
 		cc1.setOperatorId("operator");
-		cc1.setOperatorSubtaskId(0);
 
 		CassandraCommitter cc2 = new CassandraCommitter(builder);
 		cc2.setJobId("job");
 		cc2.setOperatorId("operator");
-		cc2.setOperatorSubtaskId(1);
 
 		CassandraCommitter cc3 = new CassandraCommitter(builder);
 		cc3.setJobId("job");
 		cc3.setOperatorId("operator1");
-		cc3.setOperatorSubtaskId(0);
 
 		cc1.createResource();
 
@@ -338,18 +335,18 @@ public class CassandraConnectorITCase extends WriteAheadSinkTestBase<Tuple3<Stri
 		cc2.open();
 		cc3.open();
 
-		Assert.assertFalse(cc1.isCheckpointCommitted(1));
-		Assert.assertFalse(cc2.isCheckpointCommitted(1));
-		Assert.assertFalse(cc3.isCheckpointCommitted(1));
+		Assert.assertFalse(cc1.isCheckpointCommitted(0, 1));
+		Assert.assertFalse(cc2.isCheckpointCommitted(1, 1));
+		Assert.assertFalse(cc3.isCheckpointCommitted(0, 1));
 
-		cc1.commitCheckpoint(1);
-		Assert.assertTrue(cc1.isCheckpointCommitted(1));
+		cc1.commitCheckpoint(0, 1);
+		Assert.assertTrue(cc1.isCheckpointCommitted(0, 1));
 		//verify that other sub-tasks aren't affected
-		Assert.assertFalse(cc2.isCheckpointCommitted(1));
+		Assert.assertFalse(cc2.isCheckpointCommitted(1, 1));
 		//verify that other tasks aren't affected
-		Assert.assertFalse(cc3.isCheckpointCommitted(1));
+		Assert.assertFalse(cc3.isCheckpointCommitted(0, 1));
 
-		Assert.assertFalse(cc1.isCheckpointCommitted(2));
+		Assert.assertFalse(cc1.isCheckpointCommitted(0, 2));
 
 		cc1.close();
 		cc2.close();
@@ -358,13 +355,12 @@ public class CassandraConnectorITCase extends WriteAheadSinkTestBase<Tuple3<Stri
 		cc1 = new CassandraCommitter(builder);
 		cc1.setJobId("job");
 		cc1.setOperatorId("operator");
-		cc1.setOperatorSubtaskId(0);
 
 		cc1.open();
 
 		//verify that checkpoint data is not destroyed within open/close and not reliant on internally cached data
-		Assert.assertTrue(cc1.isCheckpointCommitted(1));
-		Assert.assertFalse(cc1.isCheckpointCommitted(2));
+		Assert.assertTrue(cc1.isCheckpointCommitted(0, 1));
+		Assert.assertFalse(cc1.isCheckpointCommitted(0, 2));
 
 		cc1.close();
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/381bf591/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/CheckpointCommitter.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/CheckpointCommitter.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/CheckpointCommitter.java
index 9ecc2ee..90e3a57 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/CheckpointCommitter.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/CheckpointCommitter.java
@@ -41,9 +41,9 @@ import java.io.Serializable;
  */
 public abstract class CheckpointCommitter implements Serializable {
 	protected static final Logger LOG = LoggerFactory.getLogger(CheckpointCommitter.class);
+
 	protected String jobId;
 	protected String operatorId;
-	protected int subtaskId;
 
 	/**
 	 * Internally used to set the job ID after instantiation.
@@ -66,16 +66,6 @@ public abstract class CheckpointCommitter implements Serializable {
 	}
 
 	/**
-	 * Internally used to set the operator subtask ID after instantiation.
-	 *
-	 * @param id
-	 * @throws Exception
-	 */
-	public void setOperatorSubtaskId(int id) throws Exception {
-		this.subtaskId = id;
-	}
-
-	/**
 	 * Opens/connects to the resource, and possibly creates it beforehand.
 	 *
 	 * @throws Exception
@@ -98,17 +88,19 @@ public abstract class CheckpointCommitter implements Serializable {
 	/**
 	 * Mark the given checkpoint as completed in the resource.
 	 *
-	 * @param checkpointID
+	 * @param subtaskIdx the index of the subtask responsible for committing the checkpoint.
+	 * @param checkpointID the id of the checkpoint to be committed.
 	 * @throws Exception
 	 */
-	public abstract void commitCheckpoint(long checkpointID) throws Exception;
+	public abstract void commitCheckpoint(int subtaskIdx, long checkpointID) throws Exception;
 
 	/**
 	 * Checked the resource whether the given checkpoint was committed completely.
 	 *
-	 * @param checkpointID
+	 * @param subtaskIdx the index of the subtask responsible for committing the checkpoint.
+	 * @param checkpointID the id of the checkpoint we are interested in.
 	 * @return true if the checkpoint was committed completely, false otherwise
 	 * @throws Exception
 	 */
-	public abstract boolean isCheckpointCommitted(long checkpointID) throws Exception;
+	public abstract boolean isCheckpointCommitted(int subtaskIdx, long checkpointID) throws Exception;
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/381bf591/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
index 499fe83..b08b2e9 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
@@ -18,7 +18,6 @@
 package org.apache.flink.streaming.runtime.operators;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
@@ -27,23 +26,24 @@ import org.apache.flink.runtime.io.disk.InputViewIterator;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.util.ReusingMutableToRegularIteratorWrapper;
-import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.util.InstantiationUtil;
+import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.io.Serializable;
-import java.util.HashSet;
+import java.util.Iterator;
 import java.util.Set;
-import java.util.TreeMap;
+import java.util.TreeSet;
 import java.util.UUID;
 
 /**
- * Generic Sink that emits its input elements into an arbitrary backend. This sink is integrated with the checkpointing
+ * Generic Sink that emits its input elements into an arbitrary backend. This sink is integrated with Flink's checkpointing
  * mechanism and can provide exactly-once guarantees; depending on the storage backend and sink/committer implementation.
  * <p/>
  * Incoming records are stored within a {@link org.apache.flink.runtime.state.AbstractStateBackend}, and only committed if a
@@ -57,18 +57,21 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 	private static final long serialVersionUID = 1L;
 
 	protected static final Logger LOG = LoggerFactory.getLogger(GenericWriteAheadSink.class);
+
+	private final String id;
 	private final CheckpointCommitter committer;
-	private transient CheckpointStreamFactory.CheckpointStateOutputStream out;
 	protected final TypeSerializer<IN> serializer;
-	private final String id;
+
+	private transient CheckpointStreamFactory.CheckpointStateOutputStream out;
 	private transient CheckpointStreamFactory checkpointStreamFactory;
 
-	private ExactlyOnceState state = new ExactlyOnceState();
+	private final Set<PendingCheckpoint> pendingCheckpoints = new TreeSet<>();
 
-	public GenericWriteAheadSink(CheckpointCommitter committer, TypeSerializer<IN> serializer, String jobID) throws Exception {
-		this.committer = committer;
-		this.serializer = serializer;
+	public GenericWriteAheadSink(CheckpointCommitter committer,	TypeSerializer<IN> serializer, String jobID) throws Exception {
+		this.committer = Preconditions.checkNotNull(committer);
+		this.serializer = Preconditions.checkNotNull(serializer);
 		this.id = UUID.randomUUID().toString();
+
 		this.committer.setJobId(jobID);
 		this.committer.createResource();
 	}
@@ -77,11 +80,11 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 	public void open() throws Exception {
 		super.open();
 		committer.setOperatorId(id);
-		committer.setOperatorSubtaskId(getRuntimeContext().getIndexOfThisSubtask());
 		committer.open();
-		cleanState();
-		checkpointStreamFactory =
-				getContainingTask().createCheckpointStreamFactory(this);
+
+		checkpointStreamFactory = getContainingTask().createCheckpointStreamFactory(this);
+
+		cleanRestoredHandles();
 	}
 
 	public void close() throws Exception {
@@ -89,52 +92,68 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 	}
 
 	/**
-	 * Saves a handle in the state.
+	 * Called when a checkpoint barrier arrives. It closes any open streams to the backend
+	 * and marks them as pending for committing to the external, third-party storage system.
 	 *
-	 * @param checkpointId
-	 * @throws IOException
+	 * @param checkpointId the id of the latest received checkpoint.
+	 * @throws IOException in case something went wrong when handling the stream to the backend.
 	 */
 	private void saveHandleInState(final long checkpointId, final long timestamp) throws Exception {
 		//only add handle if a new OperatorState was created since the last snapshot
 		if (out != null) {
+			int subtaskIdx = getRuntimeContext().getIndexOfThisSubtask();
 			StreamStateHandle handle = out.closeAndGetHandle();
-			if (state.pendingHandles.containsKey(checkpointId)) {
+
+			PendingCheckpoint pendingCheckpoint = new PendingCheckpoint(checkpointId, subtaskIdx, timestamp, handle);
+
+			if (pendingCheckpoints.contains(pendingCheckpoint)) {
 				//we already have a checkpoint stored for that ID that may have been partially written,
 				//so we discard this "alternate version" and use the stored checkpoint
 				handle.discardState();
 			} else {
-				state.pendingHandles.put(checkpointId, new Tuple2<>(timestamp, handle));
+				pendingCheckpoints.add(pendingCheckpoint);
 			}
 			out = null;
 		}
 	}
 
 	@Override
-	public void snapshotState(FSDataOutputStream out,
-			long checkpointId,
-			long timestamp) throws Exception {
+	public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception {
 		saveHandleInState(checkpointId, timestamp);
 
-		InstantiationUtil.serializeObject(out, state);
+		DataOutputViewStreamWrapper outStream = new DataOutputViewStreamWrapper(out);
+		outStream.writeInt(pendingCheckpoints.size());
+		for (PendingCheckpoint pendingCheckpoint : pendingCheckpoints) {
+			pendingCheckpoint.serialize(outStream);
+		}
 	}
 
 	@Override
 	public void restoreState(FSDataInputStream in) throws Exception {
-		this.state = InstantiationUtil.deserializeObject(in, getUserCodeClassloader());
+		final DataInputViewStreamWrapper inStream = new DataInputViewStreamWrapper(in);
+		int numPendingHandles = inStream.readInt();
+		for (int i = 0; i < numPendingHandles; i++) {
+			pendingCheckpoints.add(PendingCheckpoint.restore(inStream, getUserCodeClassloader()));
+		}
 	}
 
-	private void cleanState() throws Exception {
-		synchronized (this.state.pendingHandles) { //remove all handles that were already committed
-			Set<Long> pastCheckpointIds = this.state.pendingHandles.keySet();
-			Set<Long> checkpointsToRemove = new HashSet<>();
-			for (Long pastCheckpointId : pastCheckpointIds) {
-				if (committer.isCheckpointCommitted(pastCheckpointId)) {
-					checkpointsToRemove.add(pastCheckpointId);
+	/**
+	 * Called at {@link #open()} to clean-up the pending handle list.
+	 * It iterates over all restored pending handles, checks which ones are already
+	 * committed to the outside storage system and removes them from the list.
+	 */
+	private void cleanRestoredHandles() throws Exception {
+		synchronized (pendingCheckpoints) {
+
+			Iterator<PendingCheckpoint> pendingCheckpointIt = pendingCheckpoints.iterator();
+			while (pendingCheckpointIt.hasNext()) {
+				PendingCheckpoint pendingCheckpoint = pendingCheckpointIt.next();
+
+				if (committer.isCheckpointCommitted(pendingCheckpoint.subtaskId, pendingCheckpoint.checkpointId)) {
+					pendingCheckpoint.stateHandle.discardState();
+					pendingCheckpointIt.remove();
 				}
 			}
-			for (Long toRemove : checkpointsToRemove) {
-				this.state.pendingHandles.remove(toRemove);
-			}
 		}
 	}
 
@@ -142,15 +161,19 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 	public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception {
 		super.notifyOfCompletedCheckpoint(checkpointId);
 
-		synchronized (state.pendingHandles) {
-			Set<Long> pastCheckpointIds = state.pendingHandles.keySet();
-			Set<Long> checkpointsToRemove = new HashSet<>();
-			for (Long pastCheckpointId : pastCheckpointIds) {
+		synchronized (pendingCheckpoints) {
+			Iterator<PendingCheckpoint> pendingCheckpointIt = pendingCheckpoints.iterator();
+			while (pendingCheckpointIt.hasNext()) {
+				PendingCheckpoint pendingCheckpoint = pendingCheckpointIt.next();
+				long pastCheckpointId = pendingCheckpoint.checkpointId;
+				int subtaskId = pendingCheckpoint.subtaskId;
+				long timestamp = pendingCheckpoint.timestamp;
+				StreamStateHandle streamHandle = pendingCheckpoint.stateHandle;
+
 				if (pastCheckpointId <= checkpointId) {
 					try {
-						if (!committer.isCheckpointCommitted(pastCheckpointId)) {
-							Tuple2<Long, StreamStateHandle> handle = state.pendingHandles.get(pastCheckpointId);
-							try (FSDataInputStream in = handle.f1.openInputStream()) {
+						if (!committer.isCheckpointCommitted(subtaskId, pastCheckpointId)) {
+							try (FSDataInputStream in = streamHandle.openInputStream()) {
 								boolean success = sendValues(
 										new ReusingMutableToRegularIteratorWrapper<>(
 												new InputViewIterator<>(
@@ -158,30 +181,31 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 																in),
 														serializer),
 												serializer),
-										handle.f0);
-								if (success) { //if the sending has failed we will retry on the next notify
-									committer.commitCheckpoint(pastCheckpointId);
-									checkpointsToRemove.add(pastCheckpointId);
+										timestamp);
+								if (success) {
+									// in case the checkpoint was successfully committed,
+									// discard its state from the backend and mark it for removal
+									// in case it failed, we retry on the next checkpoint
+									committer.commitCheckpoint(subtaskId, pastCheckpointId);
+									streamHandle.discardState();
+									pendingCheckpointIt.remove();
 								}
 							}
 						} else {
-							checkpointsToRemove.add(pastCheckpointId);
+							streamHandle.discardState();
+							pendingCheckpointIt.remove();
 						}
 					} catch (Exception e) {
+						// we have to break here to prevent a new (later) checkpoint
+						// from being committed before this one
 						LOG.error("Could not commit checkpoint.", e);
-						break; // we have to break here to prevent a new checkpoint from being committed before this one
+						break;
 					}
 				}
 			}
-			for (Long toRemove : checkpointsToRemove) {
-				Tuple2<Long, StreamStateHandle> handle = state.pendingHandles.get(toRemove);
-				state.pendingHandles.remove(toRemove);
-				handle.f1.discardState();
-			}
 		}
 	}
 
-
 	/**
 	 * Write the given element into the backend.
 	 *
@@ -201,27 +225,65 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 		serializer.serialize(value, new DataOutputViewStreamWrapper(out));
 	}
 
-	/**
-	 * This state is used to keep a list of all StateHandles (essentially references to past OperatorStates) that were
-	 * used since the last completed checkpoint.
-	 **/
-	public static class ExactlyOnceState implements Serializable {
+	private static final class PendingCheckpoint implements Comparable<PendingCheckpoint>, Serializable {
+
+		private static final long serialVersionUID = -3571036395734603443L;
 
-		private static final long serialVersionUID = -3571063495273460743L;
+		private final long checkpointId;
+		private final int subtaskId;
+		private final long timestamp;
+		private final StreamStateHandle stateHandle;
 
-		protected TreeMap<Long, Tuple2<Long, StreamStateHandle>> pendingHandles;
+		PendingCheckpoint(long checkpointId, int subtaskId, long timestamp, StreamStateHandle handle) {
+			this.checkpointId = checkpointId;
+			this.subtaskId = subtaskId;
+			this.timestamp = timestamp;
+			this.stateHandle = handle;
+		}
 
-		public ExactlyOnceState() {
-			pendingHandles = new TreeMap<>();
+		void serialize(DataOutputViewStreamWrapper outputStream) throws IOException {
+			outputStream.writeLong(checkpointId);
+			outputStream.writeInt(subtaskId);
+			outputStream.writeLong(timestamp);
+			InstantiationUtil.serializeObject(outputStream, stateHandle);
 		}
 
-		public TreeMap<Long, Tuple2<Long, StreamStateHandle>> getState(ClassLoader userCodeClassLoader) throws Exception {
-			return pendingHandles;
+		static PendingCheckpoint restore(
+				DataInputViewStreamWrapper inputStream,
+				ClassLoader classLoader) throws IOException, ClassNotFoundException {
+
+			long checkpointId = inputStream.readLong();
+			int subtaskId = inputStream.readInt();
+			long timestamp = inputStream.readLong();
+			StreamStateHandle handle = InstantiationUtil.deserializeObject(inputStream, classLoader);
+
+			return new PendingCheckpoint(checkpointId, subtaskId, timestamp, handle);
+		}
+
+		@Override
+		public int compareTo(PendingCheckpoint o) {
+			int res = Long.compare(this.checkpointId, o.checkpointId);
+			return res != 0 ? res : Integer.compare(this.subtaskId, o.subtaskId);
+		}
+
+		@Override
+		public boolean equals(Object o) {
+			if (!(o instanceof GenericWriteAheadSink.PendingCheckpoint)) {
+				return false;
+			}
+			PendingCheckpoint other = (PendingCheckpoint) o;
+			return this.checkpointId == other.checkpointId &&
+				this.subtaskId == other.subtaskId &&
+				this.timestamp == other.timestamp;
 		}
 
 		@Override
-		public String toString() {
-			return this.pendingHandles.toString();
+		public int hashCode() {
+			int hash = 17;
+			hash = 31 * hash + (int) (checkpointId ^ (checkpointId >>> 32));
+			hash = 31 * hash + subtaskId;
+			hash = 31 * hash + (int) (timestamp ^ (timestamp >>> 32));
+			return hash;
 		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/381bf591/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
index e186be0..8d092ed 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
@@ -128,7 +128,7 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 		testHarness.notifyOfCompletedCheckpoint(0);
 
 		//isCommitted should have failed, thus sendValues() should never have been called
-		Assert.assertTrue(sink.values.size() == 0);
+		Assert.assertEquals(0, sink.values.size());
 
 		for (int x = 0; x < 10; x++) {
 			testHarness.processElement(new StreamRecord<>(generateValue(elementCounter, 1)));
@@ -139,7 +139,7 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 		testHarness.notifyOfCompletedCheckpoint(1);
 
 		//previous CP should be retried, but will fail the CP commit. Second CP should be skipped.
-		Assert.assertTrue(sink.values.size() == 10);
+		Assert.assertEquals(10, sink.values.size());
 
 		for (int x = 0; x < 10; x++) {
 			testHarness.processElement(new StreamRecord<>(generateValue(elementCounter, 2)));
@@ -150,7 +150,7 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 		testHarness.notifyOfCompletedCheckpoint(2);
 
 		//all CP's should be retried and succeed; since one CP was written twice we have 2 * 10 + 10 + 10 = 40 values
-		Assert.assertTrue(sink.values.size() == 40);
+		Assert.assertEquals(40, sink.values.size());
 	}
 
 	/**
@@ -193,12 +193,12 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 		}
 
 		@Override
-		public void commitCheckpoint(long checkpointID) {
+		public void commitCheckpoint(int subtaskIdx, long checkpointID) {
 			checkpoints.add(checkpointID);
 		}
 
 		@Override
-		public boolean isCheckpointCommitted(long checkpointID) {
+		public boolean isCheckpointCommitted(int subtaskIdx, long checkpointID) {
 			return checkpoints.contains(checkpointID);
 		}
 	}
@@ -245,7 +245,7 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 		}
 
 		@Override
-		public void commitCheckpoint(long checkpointID) {
+		public void commitCheckpoint(int subtaskIdx, long checkpointID) {
 			if (failCommit) {
 				failCommit = false;
 				throw new RuntimeException("Expected exception");
@@ -255,7 +255,7 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 		}
 
 		@Override
-		public boolean isCheckpointCommitted(long checkpointID) {
+		public boolean isCheckpointCommitted(int subtaskIdx, long checkpointID) {
 			if (failIsCommitted) {
 				failIsCommitted = false;
 				throw new RuntimeException("Expected exception");

http://git-wip-us.apache.org/repos/asf/flink/blob/381bf591/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
index ab84bc1..a9c5792 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
@@ -149,7 +149,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 
 		sink = createSink();
 
-		testHarness =new OneInputStreamOperatorTestHarness<>(sink);
+		testHarness = new OneInputStreamOperatorTestHarness<>(sink);
 
 		testHarness.setup();
 		testHarness.restore(latestSnapshot);