Posted to commits@flink.apache.org by ch...@apache.org on 2016/12/08 12:54:56 UTC

[3/3] flink git commit: [FLINK-5020] Make the GenericWriteAheadSink rescalable.

[FLINK-5020] Make the GenericWriteAheadSink rescalable.

Integrates the new state abstractions with the GenericWriteAheadSink
so that the latter can change its parallelism when resuming execution
from a savepoint, without jeopardizing the provided guarantees.

This closes #2759
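
For readers skimming the diff below: the heart of the change is swapping the legacy
StreamCheckpointedOperator snapshot/restore (raw byte streams, one blob per subtask)
for the redistributable operator ListState. The following is a condensed sketch of the
resulting pattern, paraphrasing the diff rather than giving a drop-in class (the
write-ahead-log handling, committer, and logging are elided, and the class name
RescalableSinkSketch is hypothetical):

    import java.io.Serializable;
    import java.util.Set;
    import java.util.TreeSet;

    import org.apache.flink.api.common.state.ListState;
    import org.apache.flink.runtime.state.StateInitializationContext;
    import org.apache.flink.runtime.state.StateSnapshotContext;
    import org.apache.flink.streaming.api.operators.AbstractStreamOperator;

    // Sketch only: the state-handling skeleton of the new GenericWriteAheadSink.
    abstract class RescalableSinkSketch<T extends Serializable & Comparable<T>>
            extends AbstractStreamOperator<T> {

        private transient ListState<T> checkpointedState;
        private final Set<T> pendingCheckpoints = new TreeSet<>();

        @Override
        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            // operator (non-keyed) list state; entries can be redistributed
            // across subtasks when the parallelism changes
            checkpointedState = context.getOperatorStateStore()
                .getSerializableListState("pending-checkpoints");
            if (context.isRestored()) {
                for (T entry : checkpointedState.get()) {
                    pendingCheckpoints.add(entry);
                }
            }
        }

        @Override
        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            checkpointedState.clear();
            for (T entry : pendingCheckpoints) {
                // one list entry per pending checkpoint, so each entry can be
                // reassigned independently on restore
                checkpointedState.add(entry);
            }
        }
    }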


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/4eb71927
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/4eb71927
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/4eb71927

Branch: refs/heads/master
Commit: 4eb71927bc4f0832eb08a79394ad6864a3c2e142
Parents: 86f784a
Author: kl0u <kk...@gmail.com>
Authored: Wed Oct 26 17:19:12 2016 +0200
Committer: zentol <ch...@apache.org>
Committed: Thu Dec 8 12:27:14 2016 +0100

----------------------------------------------------------------------
 .../cassandra/CassandraConnectorITCase.java     |  38 ++--
 .../runtime/operators/CheckpointCommitter.java  |   1 +
 .../operators/GenericWriteAheadSink.java        | 105 ++++++-----
 .../operators/GenericWriteAheadSinkTest.java    |  50 +++---
 .../operators/WriteAheadSinkTestBase.java       | 172 +++++++++++++++++--
 5 files changed, 276 insertions(+), 90 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/4eb71927/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/CassandraConnectorITCase.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/CassandraConnectorITCase.java b/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/CassandraConnectorITCase.java
index 2bb6fd1..f2e8f8b 100644
--- a/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/CassandraConnectorITCase.java
+++ b/flink-connectors/flink-connector-cassandra/src/test/java/org/apache/flink/streaming/connectors/cassandra/CassandraConnectorITCase.java
@@ -47,7 +47,6 @@ import org.apache.flink.streaming.api.datastream.DataStreamSource;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.runtime.operators.WriteAheadSinkTestBase;
-import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
 import org.apache.flink.streaming.util.TestStreamEnvironment;
 import org.apache.flink.test.util.TestEnvironment;
 
@@ -71,6 +70,7 @@ import java.io.File;
 import java.io.FileWriter;
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.Scanner;
 import java.util.UUID;
 
@@ -262,9 +262,7 @@ public class CassandraConnectorITCase extends WriteAheadSinkTestBase<Tuple3<Stri
 	}
 
 	@Override
-	protected void verifyResultsIdealCircumstances(
-		OneInputStreamOperatorTestHarness<Tuple3<String, Integer, Integer>, Tuple3<String, Integer, Integer>> harness,
-		CassandraTupleWriteAheadSink<Tuple3<String, Integer, Integer>> sink) {
+	protected void verifyResultsIdealCircumstances(CassandraTupleWriteAheadSink<Tuple3<String, Integer, Integer>> sink) {
 
 		ResultSet result = session.execute(SELECT_DATA_QUERY);
 		ArrayList<Integer> list = new ArrayList<>();
@@ -279,9 +277,7 @@ public class CassandraConnectorITCase extends WriteAheadSinkTestBase<Tuple3<Stri
 	}
 
 	@Override
-	protected void verifyResultsDataPersistenceUponMissedNotify(
-		OneInputStreamOperatorTestHarness<Tuple3<String, Integer, Integer>, Tuple3<String, Integer, Integer>> harness,
-		CassandraTupleWriteAheadSink<Tuple3<String, Integer, Integer>> sink) {
+	protected void verifyResultsDataPersistenceUponMissedNotify(CassandraTupleWriteAheadSink<Tuple3<String, Integer, Integer>> sink) {
 
 		ResultSet result = session.execute(SELECT_DATA_QUERY);
 		ArrayList<Integer> list = new ArrayList<>();
@@ -296,9 +292,7 @@ public class CassandraConnectorITCase extends WriteAheadSinkTestBase<Tuple3<Stri
 	}
 
 	@Override
-	protected void verifyResultsDataDiscardingUponRestore(
-		OneInputStreamOperatorTestHarness<Tuple3<String, Integer, Integer>, Tuple3<String, Integer, Integer>> harness,
-		CassandraTupleWriteAheadSink<Tuple3<String, Integer, Integer>> sink) {
+	protected void verifyResultsDataDiscardingUponRestore(CassandraTupleWriteAheadSink<Tuple3<String, Integer, Integer>> sink) {
 
 		ResultSet result = session.execute(SELECT_DATA_QUERY);
 		ArrayList<Integer> list = new ArrayList<>();
@@ -315,6 +309,30 @@ public class CassandraConnectorITCase extends WriteAheadSinkTestBase<Tuple3<Stri
 		Assert.assertTrue("The following ID's were not found in the ResultSet: " + list.toString(), list.isEmpty());
 	}
 
+	@Override
+	protected void verifyResultsWhenReScaling(
+		CassandraTupleWriteAheadSink<Tuple3<String, Integer, Integer>> sink, int startElementCounter, int endElementCounter) {
+
+		// IMPORTANT NOTE:
+		//
+		// for cassandra we always have to start from 1 because
+		// all operators will share the same final db
+
+		ArrayList<Integer> expected = new ArrayList<>();
+		for (int i = 1; i <= endElementCounter; i++) {
+			expected.add(i);
+		}
+
+		ArrayList<Integer> actual = new ArrayList<>();
+		ResultSet result = session.execute(SELECT_DATA_QUERY);
+		for (Row s : result) {
+			actual.add(s.getInt("counter"));
+		}
+
+		Collections.sort(actual);
+		Assert.assertArrayEquals(expected.toArray(), actual.toArray());
+	}
+
 	@Test
 	public void testCassandraCommitter() throws Exception {
 		CassandraCommitter cc1 = new CassandraCommitter(builder);

http://git-wip-us.apache.org/repos/asf/flink/blob/4eb71927/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/CheckpointCommitter.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/CheckpointCommitter.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/CheckpointCommitter.java
index 90e3a57..6e50dde 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/CheckpointCommitter.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/CheckpointCommitter.java
@@ -40,6 +40,7 @@ import java.io.Serializable;
  * and as such should be kept as small as possible.
  */
 public abstract class CheckpointCommitter implements Serializable {
+
 	protected static final Logger LOG = LoggerFactory.getLogger(CheckpointCommitter.class);
 
 	protected String jobId;

http://git-wip-us.apache.org/repos/asf/flink/blob/4eb71927/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
index b08b2e9..564fa22 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
@@ -17,20 +17,20 @@
  */
 package org.apache.flink.streaming.runtime.operators;
 
+import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.io.disk.InputViewIterator;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.util.ReusingMutableToRegularIteratorWrapper;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
-import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -52,7 +52,7 @@ import java.util.UUID;
  * @param <IN> Type of the elements emitted by this sink
  */
 public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<IN>
-		implements OneInputStreamOperator<IN, IN>, StreamCheckpointedOperator {
+		implements OneInputStreamOperator<IN, IN> {
 
 	private static final long serialVersionUID = 1L;
 
@@ -65,9 +65,15 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 	private transient CheckpointStreamFactory.CheckpointStateOutputStream out;
 	private transient CheckpointStreamFactory checkpointStreamFactory;
 
+	private transient ListState<PendingCheckpoint> checkpointedState;
+
 	private final Set<PendingCheckpoint> pendingCheckpoints = new TreeSet<>();
 
-	public GenericWriteAheadSink(CheckpointCommitter committer,	TypeSerializer<IN> serializer, String jobID) throws Exception {
+	public GenericWriteAheadSink(
+			CheckpointCommitter committer,
+			TypeSerializer<IN> serializer,
+			String jobID) throws Exception {
+
 		this.committer = Preconditions.checkNotNull(committer);
 		this.serializer = Preconditions.checkNotNull(serializer);
 		this.id = UUID.randomUUID().toString();
@@ -77,12 +83,39 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 	}
 
 	@Override
+	public void initializeState(StateInitializationContext context) throws Exception {
+		super.initializeState(context);
+
+		Preconditions.checkState(this.checkpointedState == null,
+			"The reader state has already been initialized.");
+
+		checkpointedState = context.getOperatorStateStore()
+			.getSerializableListState("pending-checkpoints");
+
+		int subtaskIdx = getRuntimeContext().getIndexOfThisSubtask();
+		if (context.isRestored()) {
+			LOG.info("Restoring state for the GenericWriteAheadSink (taskIdx={}).", subtaskIdx);
+
+			for (PendingCheckpoint pendingCheckpoint : checkpointedState.get()) {
+				this.pendingCheckpoints.add(pendingCheckpoint);
+			}
+
+			if (LOG.isDebugEnabled()) {
+				LOG.debug("GenericWriteAheadSink idx {} restored {}.", subtaskIdx, this.pendingCheckpoints);
+			}
+		} else {
+			LOG.info("No state to restore for the GenericWriteAheadSink (taskIdx={}).", subtaskIdx);
+		}
+	}
+
+	@Override
 	public void open() throws Exception {
 		super.open();
 		committer.setOperatorId(id);
 		committer.open();
 
-		checkpointStreamFactory = getContainingTask().createCheckpointStreamFactory(this);
+		checkpointStreamFactory = getContainingTask()
+			.createCheckpointStreamFactory(this);
 
 		cleanRestoredHandles();
 	}
@@ -99,12 +132,14 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 	 * @throws IOException in case something went wrong when handling the stream to the backend.
 	 */
 	private void saveHandleInState(final long checkpointId, final long timestamp) throws Exception {
+
 		//only add handle if a new OperatorState was created since the last snapshot
 		if (out != null) {
 			int subtaskIdx = getRuntimeContext().getIndexOfThisSubtask();
 			StreamStateHandle handle = out.closeAndGetHandle();
 
-			PendingCheckpoint pendingCheckpoint = new PendingCheckpoint(checkpointId, subtaskIdx, timestamp, handle);
+			PendingCheckpoint pendingCheckpoint = new PendingCheckpoint(
+				checkpointId, subtaskIdx, timestamp, handle);
 
 			if (pendingCheckpoints.contains(pendingCheckpoint)) {
 				//we already have a checkpoint stored for that ID that may have been partially written,
@@ -118,22 +153,23 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 	}
 
 	@Override
-	public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception {
-		saveHandleInState(checkpointId, timestamp);
+	public void snapshotState(StateSnapshotContext context) throws Exception {
+		super.snapshotState(context);
+
+		Preconditions.checkState(this.checkpointedState != null,
+			"The operator state has not been properly initialized.");
 
-		DataOutputViewStreamWrapper outStream = new DataOutputViewStreamWrapper(out);
-		outStream.writeInt(pendingCheckpoints.size());
+		saveHandleInState(context.getCheckpointId(), context.getCheckpointTimestamp());
+
+		this.checkpointedState.clear();
 		for (PendingCheckpoint pendingCheckpoint : pendingCheckpoints) {
-			pendingCheckpoint.serialize(outStream);
+			// create a new partition for each entry.
+			this.checkpointedState.add(pendingCheckpoint);
 		}
-	}
 
-	@Override
-	public void restoreState(FSDataInputStream in) throws Exception {
-		final DataInputViewStreamWrapper inStream = new DataInputViewStreamWrapper(in);
-		int numPendingHandles = inStream.readInt();
-		for (int i = 0; i < numPendingHandles; i++) {
-			pendingCheckpoints.add(PendingCheckpoint.restore(inStream, getUserCodeClassloader()));
+		int subtaskIdx = getRuntimeContext().getIndexOfThisSubtask();
+		if (LOG.isDebugEnabled()) {
+			LOG.debug("{} (taskIdx= {}) checkpointed {}.", getClass().getSimpleName(), subtaskIdx, this.pendingCheckpoints);
 		}
 	}
 
@@ -162,9 +198,12 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 		super.notifyOfCompletedCheckpoint(checkpointId);
 
 		synchronized (pendingCheckpoints) {
+
 			Iterator<PendingCheckpoint> pendingCheckpointIt = pendingCheckpoints.iterator();
 			while (pendingCheckpointIt.hasNext()) {
+
 				PendingCheckpoint pendingCheckpoint = pendingCheckpointIt.next();
+
 				long pastCheckpointId = pendingCheckpoint.checkpointId;
 				int subtaskId = pendingCheckpoint.subtaskId;
 				long timestamp = pendingCheckpoint.timestamp;
@@ -241,34 +280,15 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 			this.stateHandle = handle;
 		}
 
-		void serialize(DataOutputViewStreamWrapper outputStream) throws IOException {
-			outputStream.writeLong(checkpointId);
-			outputStream.writeInt(subtaskId);
-			outputStream.writeLong(timestamp);
-			InstantiationUtil.serializeObject(outputStream, stateHandle);
-		}
-
-		static PendingCheckpoint restore(
-				DataInputViewStreamWrapper inputStream,
-				ClassLoader classLoader) throws IOException, ClassNotFoundException {
-
-			long checkpointId = inputStream.readLong();
-			int subtaskId = inputStream.readInt();
-			long timestamp = inputStream.readLong();
-			StreamStateHandle handle = InstantiationUtil.deserializeObject(inputStream, classLoader);
-
-			return new PendingCheckpoint(checkpointId, subtaskId, timestamp, handle);
-		}
-
 		@Override
 		public int compareTo(PendingCheckpoint o) {
 			int res = Long.compare(this.checkpointId, o.checkpointId);
-			return res != 0 ? res : Integer.compare(this.subtaskId, o.subtaskId);
+			return res != 0 ? res : this.subtaskId - o.subtaskId;
 		}
 
 		@Override
 		public boolean equals(Object o) {
-			if (!(o instanceof GenericWriteAheadSink.PendingCheckpoint)) {
+			if (o == null || !(o instanceof GenericWriteAheadSink.PendingCheckpoint)) {
 				return false;
 			}
 			PendingCheckpoint other = (PendingCheckpoint) o;
@@ -285,5 +305,10 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 			hash = 31 * hash + (int) (timestamp ^ (timestamp >>> 32));
 			return hash;
 		}
+
+		@Override
+		public String toString() {
+			return "Pending Checkpoint: id=" + checkpointId + "/" + subtaskId + "@" + timestamp;
+		}
 	}
 }
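
Worth noting about the snapshotState() change above: each PendingCheckpoint becomes its
own list-state entry ("create a new partition for each entry"), which is what makes the
state splittable; the old restoreState() blob could only ever be handed back to a single
subtask. Conceptually, the redistribution contract behaves like the round-robin sketch
below (an illustration only, with a hypothetical helper class, not the runtime's actual
code):

    import java.util.ArrayList;
    import java.util.List;

    final class RoundRobinSplitSketch {

        // Split all restored entries across the new parallelism; because the
        // operator stores one entry per pending checkpoint, any such split
        // yields a valid operator state for every new subtask.
        static <T> List<List<T>> redistribute(List<T> allEntries, int newParallelism) {
            List<List<T>> perSubtask = new ArrayList<>(newParallelism);
            for (int i = 0; i < newParallelism; i++) {
                perSubtask.add(new ArrayList<T>());
            }
            for (int i = 0; i < allEntries.size(); i++) {
                perSubtask.get(i % newParallelism).add(allEntries.get(i));
            }
            return perSubtask;
        }

        private RoundRobinSplitSketch() {}
    }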

http://git-wip-us.apache.org/repos/asf/flink/blob/4eb71927/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
index 8d092ed..9bcd2e6 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.streaming.runtime.operators;
 
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.java.tuple.Tuple1;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.TupleTypeInfo;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
@@ -29,6 +30,7 @@ import org.junit.Assert;
 import org.junit.Test;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 
 public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Integer>, GenericWriteAheadSinkTest.ListSink> {
@@ -50,9 +52,7 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 
 
 	@Override
-	protected void verifyResultsIdealCircumstances(
-		OneInputStreamOperatorTestHarness<Tuple1<Integer>, Tuple1<Integer>> harness,
-		ListSink sink) {
+	protected void verifyResultsIdealCircumstances(ListSink sink) {
 
 		ArrayList<Integer> list = new ArrayList<>();
 		for (int x = 1; x <= 60; x++) {
@@ -67,9 +67,7 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 	}
 
 	@Override
-	protected void verifyResultsDataPersistenceUponMissedNotify(
-		OneInputStreamOperatorTestHarness<Tuple1<Integer>, Tuple1<Integer>> harness,
-		ListSink sink) {
+	protected void verifyResultsDataPersistenceUponMissedNotify(ListSink sink) {
 
 		ArrayList<Integer> list = new ArrayList<>();
 		for (int x = 1; x <= 60; x++) {
@@ -84,9 +82,7 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 	}
 
 	@Override
-	protected void verifyResultsDataDiscardingUponRestore(
-		OneInputStreamOperatorTestHarness<Tuple1<Integer>, Tuple1<Integer>> harness,
-		ListSink sink) {
+	protected void verifyResultsDataDiscardingUponRestore(ListSink sink) {
 
 		ArrayList<Integer> list = new ArrayList<>();
 		for (int x = 1; x <= 20; x++) {
@@ -103,6 +99,18 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 		Assert.assertTrue("The sink emitted to many values: " + (sink.values.size() - 40), sink.values.size() == 40);
 	}
 
+	@Override
+	protected void verifyResultsWhenReScaling(ListSink sink, int startElementCounter, int endElementCounter) throws Exception {
+
+		ArrayList<Integer> list = new ArrayList<>();
+		for (int i = startElementCounter; i <= endElementCounter; i++) {
+			list.add(i);
+		}
+
+		Collections.sort(sink.values);
+		Assert.assertArrayEquals(list.toArray(), sink.values.toArray());
+	}
+
 	@Test
 	/**
 	 * Verifies that exceptions thrown by a committer do not fail a job and lead to an abort of notify()
@@ -124,33 +132,33 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 			elementCounter++;
 		}
 
-		testHarness.snapshotLegacy(0, 0);
+		testHarness.snapshot(0, 0);
 		testHarness.notifyOfCompletedCheckpoint(0);
 
 		//isCommitted should have failed, thus sendValues() should never have been called
 		Assert.assertEquals(0, sink.values.size());
 
-		for (int x = 0; x < 10; x++) {
+		for (int x = 0; x < 11; x++) {
 			testHarness.processElement(new StreamRecord<>(generateValue(elementCounter, 1)));
 			elementCounter++;
 		}
 
-		testHarness.snapshotLegacy(1, 0);
+		testHarness.snapshot(1, 0);
 		testHarness.notifyOfCompletedCheckpoint(1);
 
 		//previous CP should be retried, but will fail the CP commit. Second CP should be skipped.
 		Assert.assertEquals(10, sink.values.size());
 
-		for (int x = 0; x < 10; x++) {
+		for (int x = 0; x < 12; x++) {
 			testHarness.processElement(new StreamRecord<>(generateValue(elementCounter, 2)));
 			elementCounter++;
 		}
 
-		testHarness.snapshotLegacy(2, 0);
+		testHarness.snapshot(2, 0);
 		testHarness.notifyOfCompletedCheckpoint(2);
 
-		//all CP's should be retried and succeed; since one CP was written twice we have 2 * 10 + 10 + 10 = 40 values
-		Assert.assertEquals(40, sink.values.size());
+		//all CPs should be retried and succeed; since one CP was written twice we have 2 * 10 + 11 + 12 = 43 values
+		Assert.assertEquals(43, sink.values.size());
 	}
 
 	/**
@@ -177,7 +185,7 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 	public static class SimpleCommitter extends CheckpointCommitter {
 		private static final long serialVersionUID = 1L;
 
-		private List<Long> checkpoints;
+		private List<Tuple2<Long, Integer>> checkpoints;
 
 		@Override
 		public void open() throws Exception {
@@ -194,12 +202,12 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 
 		@Override
 		public void commitCheckpoint(int subtaskIdx, long checkpointID) {
-			checkpoints.add(checkpointID);
+			checkpoints.add(new Tuple2<>(checkpointID, subtaskIdx));
 		}
 
 		@Override
 		public boolean isCheckpointCommitted(int subtaskIdx, long checkpointID) {
-			return checkpoints.contains(checkpointID);
+			return checkpoints.contains(new Tuple2<>(checkpointID, subtaskIdx));
 		}
 	}
 
@@ -227,7 +235,7 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 	public static class FailingCommitter extends CheckpointCommitter {
 		private static final long serialVersionUID = 1L;
 
-		private List<Long> checkpoints;
+		private List<Tuple2<Long, Integer>> checkpoints;
 		private boolean failIsCommitted = true;
 		private boolean failCommit = true;
 
@@ -250,7 +258,7 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 				failCommit = false;
 				throw new RuntimeException("Expected exception");
 			} else {
-				checkpoints.add(checkpointID);
+				checkpoints.add(new Tuple2<>(checkpointID, subtaskIdx));
 			}
 		}
 

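One more detail from the test changes above: the committers now track
(checkpointID, subtaskIdx) pairs instead of bare checkpoint IDs, because after rescaling
several subtasks may legitimately commit the same checkpoint ID, and each commit must be
recorded separately. A standalone, in-memory sketch of that bookkeeping (hypothetical
class, not part of the commit):

    import java.util.HashSet;
    import java.util.Set;

    class PerSubtaskCommitTracker {

        // each (subtask, checkpoint) pair is a distinct commit event
        private final Set<String> committed = new HashSet<>();

        void commitCheckpoint(int subtaskIdx, long checkpointId) {
            committed.add(subtaskIdx + ":" + checkpointId);
        }

        boolean isCheckpointCommitted(int subtaskIdx, long checkpointId) {
            return committed.contains(subtaskIdx + ":" + checkpointId);
        }
    }
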
http://git-wip-us.apache.org/repos/asf/flink/blob/4eb71927/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
index a9c5792..46d92af 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
@@ -19,8 +19,9 @@
 package org.apache.flink.streaming.runtime.operators;
 
 import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
 import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
 import org.apache.flink.util.TestLogger;
 
@@ -34,14 +35,13 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 
 	protected abstract IN generateValue(int counter, int checkpointID);
 
-	protected abstract void verifyResultsIdealCircumstances(
-		OneInputStreamOperatorTestHarness<IN, IN> harness, S sink) throws Exception;
+	protected abstract void verifyResultsIdealCircumstances(S sink) throws Exception;
 
-	protected abstract void verifyResultsDataPersistenceUponMissedNotify(
-			OneInputStreamOperatorTestHarness<IN, IN> harness, S sink) throws Exception;
+	protected abstract void verifyResultsDataPersistenceUponMissedNotify(S sink) throws Exception;
 
-	protected abstract void verifyResultsDataDiscardingUponRestore(
-		OneInputStreamOperatorTestHarness<IN, IN> harness, S sink) throws Exception;
+	protected abstract void verifyResultsDataDiscardingUponRestore(S sink) throws Exception;
+
+	protected abstract void verifyResultsWhenReScaling(S sink, int startElementCounter, int endElementCounter) throws Exception;
 
 	@Test
 	public void testIdealCircumstances() throws Exception {
@@ -60,7 +60,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshotLegacy(snapshotCount++, 0);
+		testHarness.snapshot(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		for (int x = 0; x < 20; x++) {
@@ -68,7 +68,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshotLegacy(snapshotCount++, 0);
+		testHarness.snapshot(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		for (int x = 0; x < 20; x++) {
@@ -76,10 +76,10 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshotLegacy(snapshotCount++, 0);
+		testHarness.snapshot(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
-		verifyResultsIdealCircumstances(testHarness, sink);
+		verifyResultsIdealCircumstances(sink);
 	}
 
 	@Test
@@ -99,7 +99,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshotLegacy(snapshotCount++, 0);
+		testHarness.snapshot(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		for (int x = 0; x < 20; x++) {
@@ -107,17 +107,17 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshotLegacy(snapshotCount++, 0);
+		testHarness.snapshot(snapshotCount++, 0);
 
 		for (int x = 0; x < 20; x++) {
 			testHarness.processElement(new StreamRecord<>(generateValue(elementCounter, 2)));
 			elementCounter++;
 		}
 
-		testHarness.snapshotLegacy(snapshotCount++, 0);
+		testHarness.snapshot(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
-		verifyResultsDataPersistenceUponMissedNotify(testHarness, sink);
+		verifyResultsDataPersistenceUponMissedNotify(sink);
 	}
 
 	@Test
@@ -137,7 +137,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		StreamStateHandle latestSnapshot = testHarness.snapshotLegacy(snapshotCount++, 0);
+		OperatorStateHandles latestSnapshot = testHarness.snapshot(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		for (int x = 0; x < 20; x++) {
@@ -152,7 +152,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 		testHarness = new OneInputStreamOperatorTestHarness<>(sink);
 
 		testHarness.setup();
-		testHarness.restore(latestSnapshot);
+		testHarness.initializeState(latestSnapshot);
 		testHarness.open();
 
 		for (int x = 0; x < 20; x++) {
@@ -160,9 +160,143 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshotLegacy(snapshotCount++, 0);
+		testHarness.snapshot(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
-		verifyResultsDataDiscardingUponRestore(testHarness, sink);
+		verifyResultsDataDiscardingUponRestore(sink);
+	}
+
+	@Test
+	public void testScalingDown() throws Exception {
+		S sink1 = createSink();
+		OneInputStreamOperatorTestHarness<IN, IN> testHarness1 =
+			new OneInputStreamOperatorTestHarness<>(sink1, 10, 2, 0);
+		testHarness1.open();
+
+		S sink2 = createSink();
+		OneInputStreamOperatorTestHarness<IN, IN> testHarness2 =
+			new OneInputStreamOperatorTestHarness<>(sink2, 10, 2, 1);
+		testHarness2.open();
+
+		int elementCounter = 1;
+		int snapshotCount = 0;
+
+		for (int x = 0; x < 10; x++) {
+			testHarness1.processElement(new StreamRecord<>(generateValue(elementCounter, 0)));
+			elementCounter++;
+		}
+
+		for (int x = 0; x < 11; x++) {
+			testHarness2.processElement(new StreamRecord<>(generateValue(elementCounter, 0)));
+			elementCounter++;
+		}
+
+	// snapshot at checkpoint 0 for testHarness1 and testHarness2
+		OperatorStateHandles snapshot1 = testHarness1.snapshot(snapshotCount, 0);
+		OperatorStateHandles snapshot2 = testHarness2.snapshot(snapshotCount, 0);
+
+		// merge the two partial states
+		OperatorStateHandles mergedSnapshot = AbstractStreamOperatorTestHarness
+			.repackageState(snapshot1, snapshot2);
+
+		testHarness1.close();
+		testHarness2.close();
+
+		// and create a third instance that operates alone but
+		// has the merged state of the previous 2 instances
+
+		S sink3 = createSink();
+		OneInputStreamOperatorTestHarness<IN, IN> mergedTestHarness =
+			new OneInputStreamOperatorTestHarness<>(sink3, 10, 1, 0);
+
+		mergedTestHarness.setup();
+		mergedTestHarness.initializeState(mergedSnapshot);
+		mergedTestHarness.open();
+
+		for (int x = 0; x < 12; x++) {
+			mergedTestHarness.processElement(new StreamRecord<>(generateValue(elementCounter, 0)));
+			elementCounter++;
+		}
+
+		snapshotCount++;
+		mergedTestHarness.snapshot(snapshotCount, 1);
+		mergedTestHarness.notifyOfCompletedCheckpoint(snapshotCount);
+
+		verifyResultsWhenReScaling(sink3, 1, 33);
+		mergedTestHarness.close();
+	}
+
+	@Test
+	public void testScalingUp() throws Exception {
+
+		S sink1 = createSink();
+		OneInputStreamOperatorTestHarness<IN, IN> testHarness1 =
+			new OneInputStreamOperatorTestHarness<>(sink1, 10, 1, 0);
+
+		int elementCounter = 1;
+		int snapshotCount = 0;
+
+		testHarness1.open();
+
+		// put two more checkpoints as pending
+
+		for (int x = 0; x < 10; x++) {
+			testHarness1.processElement(new StreamRecord<>(generateValue(elementCounter, 0)));
+			elementCounter++;
+		}
+		testHarness1.snapshot(++snapshotCount, 0);
+
+		for (int x = 0; x < 11; x++) {
+			testHarness1.processElement(new StreamRecord<>(generateValue(elementCounter, 0)));
+			elementCounter++;
+		}
+
+		// this will be the state that will be split between the two new operators
+		OperatorStateHandles snapshot = testHarness1.snapshot(++snapshotCount, 0);
+
+		testHarness1.close();
+
+		// verify no elements are in the sink
+		verifyResultsWhenReScaling(sink1, 0, -1);
+
+		// we will create two operator instances, testHarness2 and testHarness3,
+		// that will share the state of testHarness1
+
+		++snapshotCount;
+
+		S sink2 = createSink();
+		OneInputStreamOperatorTestHarness<IN, IN> testHarness2 =
+			new OneInputStreamOperatorTestHarness<>(sink2, 10, 2, 0);
+
+		testHarness2.setup();
+		testHarness2.initializeState(snapshot);
+		testHarness2.open();
+
+		testHarness2.notifyOfCompletedCheckpoint(snapshotCount);
+
+		verifyResultsWhenReScaling(sink2, 1, 10);
+
+		S sink3 = createSink();
+		OneInputStreamOperatorTestHarness<IN, IN> testHarness3 =
+			new OneInputStreamOperatorTestHarness<>(sink3, 10, 2, 1);
+
+		testHarness3.setup();
+		testHarness3.initializeState(snapshot);
+		testHarness3.open();
+
+		// add some more elements to verify that everything functions normally from now on...
+
+		for (int x = 0; x < 10; x++) {
+			testHarness3.processElement(new StreamRecord<>(generateValue(elementCounter, 0)));
+			elementCounter++;
+		}
+
+		testHarness3.snapshot(snapshotCount, 1);
+		testHarness3.notifyOfCompletedCheckpoint(snapshotCount);
+
+		verifyResultsWhenReScaling(sink3, 11, 31);
+
+		testHarness2.close();
+		testHarness3.close();
 	}
 }