You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by sr...@apache.org on 2018/07/09 14:13:16 UTC

[1/2] flink git commit: [FLINK-9486][state] Introduce InternalPriorityQueue as state in keyed state backends

Repository: flink
Updated Branches:
  refs/heads/master b12acea2a -> 79b38f8f9


http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/HeapInternalTimerService.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/HeapInternalTimerService.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/HeapInternalTimerService.java
index 7bf652f..6c1b188 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/HeapInternalTimerService.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/HeapInternalTimerService.java
@@ -24,13 +24,15 @@ import org.apache.flink.api.common.typeutils.CompatibilityUtil;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.state.InternalPriorityQueue;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.heap.HeapPriorityQueueSet;
+import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeCallback;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
 import org.apache.flink.util.CloseableIterator;
 import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.Preconditions;
 
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Set;
 import java.util.concurrent.ScheduledFuture;
@@ -50,12 +52,12 @@ public class HeapInternalTimerService<K, N> implements InternalTimerService<N>,
 	/**
 	 * Processing time timers that are currently in-flight.
 	 */
-	private final HeapPriorityQueueSet<TimerHeapInternalTimer<K, N>> processingTimeTimersQueue;
+	private final KeyGroupedInternalPriorityQueue<TimerHeapInternalTimer<K, N>> processingTimeTimersQueue;
 
 	/**
 	 * Event time timers that are currently in-flight.
 	 */
-	private final HeapPriorityQueueSet<TimerHeapInternalTimer<K, N>> eventTimeTimersQueue;
+	private final KeyGroupedInternalPriorityQueue<TimerHeapInternalTimer<K, N>> eventTimeTimersQueue;
 
 	/**
 	 * Information concerning the local key-group range.
@@ -94,14 +96,17 @@ public class HeapInternalTimerService<K, N> implements InternalTimerService<N>,
 	private InternalTimersSnapshot<K, N> restoredTimersSnapshot;
 
 	HeapInternalTimerService(
-		int totalKeyGroups,
 		KeyGroupRange localKeyGroupRange,
 		KeyContext keyContext,
-		ProcessingTimeService processingTimeService) {
+		ProcessingTimeService processingTimeService,
+		KeyGroupedInternalPriorityQueue<TimerHeapInternalTimer<K, N>> processingTimeTimersQueue,
+		KeyGroupedInternalPriorityQueue<TimerHeapInternalTimer<K, N>> eventTimeTimersQueue) {
 
 		this.keyContext = checkNotNull(keyContext);
 		this.processingTimeService = checkNotNull(processingTimeService);
 		this.localKeyGroupRange = checkNotNull(localKeyGroupRange);
+		this.processingTimeTimersQueue = checkNotNull(processingTimeTimersQueue);
+		this.eventTimeTimersQueue = checkNotNull(eventTimeTimersQueue);
 
 		// find the starting index of the local key-group range
 		int startIdx = Integer.MAX_VALUE;
@@ -109,9 +114,6 @@ public class HeapInternalTimerService<K, N> implements InternalTimerService<N>,
 			startIdx = Math.min(keyGroupIdx, startIdx);
 		}
 		this.localKeyGroupRangeStartIdx = startIdx;
-
-		this.eventTimeTimersQueue = createPriorityQueue(localKeyGroupRange, totalKeyGroups);
-		this.processingTimeTimersQueue = createPriorityQueue(localKeyGroupRange, totalKeyGroups);
 	}
 
 	/**
@@ -225,16 +227,20 @@ public class HeapInternalTimerService<K, N> implements InternalTimerService<N>,
 		// inside the callback.
 		nextTimer = null;
 
-		InternalTimer<K, N> timer;
-
-		while ((timer = processingTimeTimersQueue.peek()) != null && timer.getTimestamp() <= time) {
-			processingTimeTimersQueue.poll();
-			keyContext.setCurrentKey(timer.getKey());
-			triggerTarget.onProcessingTime(timer);
-		}
+		processingTimeTimersQueue.bulkPoll(
+			(timer) -> (timer.getTimestamp() <= time),
+			(timer) -> {
+				keyContext.setCurrentKey(timer.getKey());
+				try {
+					triggerTarget.onProcessingTime(timer);
+				} catch (Exception e) {
+					throw new FlinkRuntimeException("Problem in trigger target.", e);
+				}
+			});
 
-		if (timer != null) {
-			if (nextTimer == null) {
+		if (nextTimer == null) {
+			final TimerHeapInternalTimer<K, N> timer = processingTimeTimersQueue.peek();
+			if (timer != null) {
 				nextTimer = processingTimeService.registerTimer(timer.getTimestamp(), this);
 			}
 		}
@@ -242,14 +248,16 @@ public class HeapInternalTimerService<K, N> implements InternalTimerService<N>,
 
 	public void advanceWatermark(long time) throws Exception {
 		currentWatermark = time;
-
-		InternalTimer<K, N> timer;
-
-		while ((timer = eventTimeTimersQueue.peek()) != null && timer.getTimestamp() <= time) {
-			eventTimeTimersQueue.poll();
-			keyContext.setCurrentKey(timer.getKey());
-			triggerTarget.onEventTime(timer);
-		}
+		eventTimeTimersQueue.bulkPoll(
+			(timer) -> (timer.getTimestamp() <= time),
+			(timer) -> {
+				keyContext.setCurrentKey(timer.getKey());
+				try {
+					triggerTarget.onEventTime(timer);
+				} catch (Exception e) {
+					throw new FlinkRuntimeException("Problem in trigger target.", e);
+				}
+			});
 	}
 
 	/**
@@ -264,8 +272,8 @@ public class HeapInternalTimerService<K, N> implements InternalTimerService<N>,
 			keySerializer.snapshotConfiguration(),
 			namespaceSerializer,
 			namespaceSerializer.snapshotConfiguration(),
-			eventTimeTimersQueue.getElementsForKeyGroup(keyGroupIdx),
-			processingTimeTimersQueue.getElementsForKeyGroup(keyGroupIdx));
+			eventTimeTimersQueue.getSubsetForKeyGroup(keyGroupIdx),
+			processingTimeTimersQueue.getSubsetForKeyGroup(keyGroupIdx));
 	}
 
 	/**
@@ -339,27 +347,24 @@ public class HeapInternalTimerService<K, N> implements InternalTimerService<N>,
 
 	@VisibleForTesting
 	List<Set<TimerHeapInternalTimer<K, N>>> getEventTimeTimersPerKeyGroup() {
-		return eventTimeTimersQueue.getElementsByKeyGroup();
+		return partitionElementsByKeyGroup(eventTimeTimersQueue);
 	}
 
 	@VisibleForTesting
 	List<Set<TimerHeapInternalTimer<K, N>>> getProcessingTimeTimersPerKeyGroup() {
-		return processingTimeTimersQueue.getElementsByKeyGroup();
+		return partitionElementsByKeyGroup(processingTimeTimersQueue);
+	}
+
+	private <T> List<Set<T>> partitionElementsByKeyGroup(KeyGroupedInternalPriorityQueue<T> keyGroupedQueue) {
+		List<Set<T>> result = new ArrayList<>(localKeyGroupRange.getNumberOfKeyGroups());
+		for (int keyGroup : localKeyGroupRange) {
+			result.add(Collections.unmodifiableSet(keyGroupedQueue.getSubsetForKeyGroup(keyGroup)));
+		}
+		return result;
 	}
 
 	private boolean areSnapshotSerializersIncompatible(InternalTimersSnapshot<?, ?> restoredSnapshot) {
 		return (this.keyDeserializer != null && !this.keyDeserializer.equals(restoredSnapshot.getKeySerializer())) ||
 			(this.namespaceDeserializer != null && !this.namespaceDeserializer.equals(restoredSnapshot.getNamespaceSerializer()));
 	}
-
-	private static <K, N> HeapPriorityQueueSet<TimerHeapInternalTimer<K, N>> createPriorityQueue(
-		KeyGroupRange localKeyGroupRange,
-		int totalKeyGroups) {
-		return new HeapPriorityQueueSet<>(
-			TimerHeapInternalTimer.getTimerComparator(),
-			TimerHeapInternalTimer.getKeyExtractorFunction(),
-			128,
-			localKeyGroupRange,
-			totalKeyGroups);
-	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
index e62883a..ad1617e 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
@@ -20,16 +20,17 @@ package org.apache.flink.streaming.api.operators;
 
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.VisibleForTesting;
-import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.VoidNamespaceSerializer;
+import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
+import org.apache.flink.runtime.state.PriorityQueueSetFactory;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
 import org.apache.flink.util.Preconditions;
 
 import java.io.IOException;
 import java.io.InputStream;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -49,6 +50,7 @@ public class InternalTimeServiceManager<K> {
 	private final KeyGroupRange localKeyGroupRange;
 	private final KeyContext keyContext;
 
+	private final PriorityQueueSetFactory priorityQueueSetFactory;
 	private final ProcessingTimeService processingTimeService;
 
 	private final Map<String, HeapInternalTimerService<K, ?>> timerServices;
@@ -57,52 +59,66 @@ public class InternalTimeServiceManager<K> {
 			int totalKeyGroups,
 			KeyGroupRange localKeyGroupRange,
 			KeyContext keyContext,
+			PriorityQueueSetFactory priorityQueueSetFactory,
 			ProcessingTimeService processingTimeService) {
 
 		Preconditions.checkArgument(totalKeyGroups > 0);
 		this.totalKeyGroups = totalKeyGroups;
 		this.localKeyGroupRange = Preconditions.checkNotNull(localKeyGroupRange);
-
+		this.priorityQueueSetFactory = Preconditions.checkNotNull(priorityQueueSetFactory);
 		this.keyContext = Preconditions.checkNotNull(keyContext);
 		this.processingTimeService = Preconditions.checkNotNull(processingTimeService);
 
 		this.timerServices = new HashMap<>();
 	}
 
-	/**
-	 * Returns a {@link InternalTimerService} that can be used to query current processing time
-	 * and event time and to set timers. An operator can have several timer services, where
-	 * each has its own namespace serializer. Timer services are differentiated by the string
-	 * key that is given when requesting them, if you call this method with the same key
-	 * multiple times you will get the same timer service instance in subsequent requests.
-	 *
-	 * <p>Timers are always scoped to a key, the currently active key of a keyed stream operation.
-	 * When a timer fires, this key will also be set as the currently active key.
-	 *
-	 * <p>Each timer has attached metadata, the namespace. Different timer services
-	 * can have a different namespace type. If you don't need namespace differentiation you
-	 * can use {@link VoidNamespaceSerializer} as the namespace serializer.
-	 *
-	 * @param name The name of the requested timer service. If no service exists under the given
-	 *             name a new one will be created and returned.
-	 * @param keySerializer {@code TypeSerializer} for the timer keys.
-	 * @param namespaceSerializer {@code TypeSerializer} for the timer namespace.
-	 * @param triggerable The {@link Triggerable} that should be invoked when timers fire
-	 */
 	@SuppressWarnings("unchecked")
-	public <N> InternalTimerService<N> getInternalTimerService(String name, TypeSerializer<K> keySerializer,
-														TypeSerializer<N> namespaceSerializer, Triggerable<K, N> triggerable) {
+	public <N> InternalTimerService<N> getInternalTimerService(
+		String name,
+		TimerSerializer<K, N> timerSerializer,
+		Triggerable<K, N> triggerable) {
+
+		HeapInternalTimerService<K, N> timerService = registerOrGetTimerService(name, timerSerializer);
+
+		timerService.startTimerService(
+			timerSerializer.getKeySerializer(),
+			timerSerializer.getNamespaceSerializer(),
+			triggerable);
+
+		return timerService;
+	}
 
+	@SuppressWarnings("unchecked")
+	<N> HeapInternalTimerService<K, N> registerOrGetTimerService(String name, TimerSerializer<K, N> timerSerializer) {
 		HeapInternalTimerService<K, N> timerService = (HeapInternalTimerService<K, N>) timerServices.get(name);
 		if (timerService == null) {
-			timerService = new HeapInternalTimerService<>(totalKeyGroups,
-				localKeyGroupRange, keyContext, processingTimeService);
+
+			timerService = new HeapInternalTimerService<>(
+				localKeyGroupRange,
+				keyContext,
+				processingTimeService,
+				createTimerPriorityQueue("__ts_" + name + "/processing_timers", timerSerializer),
+				createTimerPriorityQueue("__ts_" + name + "/event_timers", timerSerializer));
+
 			timerServices.put(name, timerService);
 		}
-		timerService.startTimerService(keySerializer, namespaceSerializer, triggerable);
 		return timerService;
 	}
 
+	Map<String, HeapInternalTimerService<K, ?>> getRegisteredTimerServices() {
+		return Collections.unmodifiableMap(timerServices);
+	}
+
+	private <N> KeyGroupedInternalPriorityQueue<TimerHeapInternalTimer<K, N>> createTimerPriorityQueue(
+		String name,
+		TimerSerializer<K, N> timerSerializer) {
+		return priorityQueueSetFactory.create(
+			name,
+			timerSerializer,
+			InternalTimer.getTimerComparator(),
+			InternalTimer.getKeyExtractorFunction());
+	}
+
 	public void advanceWatermark(Watermark watermark) throws Exception {
 		for (HeapInternalTimerService<?, ?> service : timerServices.values()) {
 			service.advanceWatermark(watermark.getTimestamp());
@@ -113,7 +129,7 @@ public class InternalTimeServiceManager<K> {
 
 	public void snapshotStateForKeyGroup(DataOutputView stream, int keyGroupIdx) throws IOException {
 		InternalTimerServiceSerializationProxy<K> serializationProxy =
-			new InternalTimerServiceSerializationProxy<>(timerServices, keyGroupIdx);
+			new InternalTimerServiceSerializationProxy<>(this, keyGroupIdx);
 
 		serializationProxy.write(stream);
 	}
@@ -125,12 +141,8 @@ public class InternalTimeServiceManager<K> {
 
 		InternalTimerServiceSerializationProxy<K> serializationProxy =
 			new InternalTimerServiceSerializationProxy<>(
-				timerServices,
+				this,
 				userCodeClassLoader,
-				totalKeyGroups,
-				localKeyGroupRange,
-				keyContext,
-				processingTimeService,
 				keyGroupIdx);
 
 		serializationProxy.read(stream);

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimer.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimer.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimer.java
index 5ba1a0f..f88b4fb 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimer.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimer.java
@@ -19,6 +19,10 @@
 package org.apache.flink.streaming.api.operators;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.state.KeyExtractorFunction;
+import org.apache.flink.runtime.state.PriorityComparator;
+
+import javax.annotation.Nonnull;
 
 /**
  * Internal interface for in-flight timers.
@@ -29,6 +33,12 @@ import org.apache.flink.annotation.Internal;
 @Internal
 public interface InternalTimer<K, N> {
 
+	/** Function to extract the key from a {@link InternalTimer}. */
+	KeyExtractorFunction<InternalTimer<?, ?>> KEY_EXTRACTOR_FUNCTION = InternalTimer::getKey;
+
+	/** Function to compare instances of {@link InternalTimer}. */
+	PriorityComparator<InternalTimer<?, ?>> TIMER_COMPARATOR =
+		(left, right) -> Long.compare(left.getTimestamp(), right.getTimestamp());
 	/**
 	 * Returns the timestamp of the timer. This value determines the point in time when the timer will fire.
 	 */
@@ -37,10 +47,22 @@ public interface InternalTimer<K, N> {
 	/**
 	 * Returns the key that is bound to this timer.
 	 */
+	@Nonnull
 	K getKey();
 
 	/**
 	 * Returns the namespace that is bound to this timer.
 	 */
+	@Nonnull
 	N getNamespace();
+
+	@SuppressWarnings("unchecked")
+	static <T extends InternalTimer> PriorityComparator<T> getTimerComparator() {
+		return (PriorityComparator<T>) TIMER_COMPARATOR;
+	}
+
+	@SuppressWarnings("unchecked")
+	static <T extends InternalTimer> KeyExtractorFunction<T> getKeyExtractorFunction() {
+		return (KeyExtractorFunction<T>) KEY_EXTRACTOR_FUNCTION;
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceSerializationProxy.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceSerializationProxy.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceSerializationProxy.java
index efa93d3..ce490b5 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceSerializationProxy.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerServiceSerializationProxy.java
@@ -19,11 +19,10 @@
 package org.apache.flink.streaming.api.operators;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.io.PostVersionedIOReadableWritable;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
 
 import java.io.IOException;
 import java.util.Map;
@@ -39,36 +38,24 @@ public class InternalTimerServiceSerializationProxy<K> extends PostVersionedIORe
 	public static final int VERSION = 1;
 
 	/** The key-group timer services to write / read. */
-	private Map<String, HeapInternalTimerService<K, ?>> timerServices;
+	private final InternalTimeServiceManager<K> timerServicesManager;
 
 	/** The user classloader; only relevant if the proxy is used to restore timer services. */
 	private ClassLoader userCodeClassLoader;
 
 	/** Properties of restored timer services. */
-	private int keyGroupIdx;
-	private int totalKeyGroups;
-	private KeyGroupRange localKeyGroupRange;
-	private KeyContext keyContext;
-	private ProcessingTimeService processingTimeService;
+	private final int keyGroupIdx;
+
 
 	/**
 	 * Constructor to use when restoring timer services.
 	 */
 	public InternalTimerServiceSerializationProxy(
-			Map<String, HeapInternalTimerService<K, ?>> timerServicesMapToPopulate,
-			ClassLoader userCodeClassLoader,
-			int totalKeyGroups,
-			KeyGroupRange localKeyGroupRange,
-			KeyContext keyContext,
-			ProcessingTimeService processingTimeService,
-			int keyGroupIdx) {
-
-		this.timerServices = checkNotNull(timerServicesMapToPopulate);
+		InternalTimeServiceManager<K> timerServicesManager,
+		ClassLoader userCodeClassLoader,
+		int keyGroupIdx) {
+		this.timerServicesManager = checkNotNull(timerServicesManager);
 		this.userCodeClassLoader = checkNotNull(userCodeClassLoader);
-		this.totalKeyGroups = totalKeyGroups;
-		this.localKeyGroupRange = checkNotNull(localKeyGroupRange);
-		this.keyContext = checkNotNull(keyContext);
-		this.processingTimeService = checkNotNull(processingTimeService);
 		this.keyGroupIdx = keyGroupIdx;
 	}
 
@@ -76,10 +63,9 @@ public class InternalTimerServiceSerializationProxy<K> extends PostVersionedIORe
 	 * Constructor to use when writing timer services.
 	 */
 	public InternalTimerServiceSerializationProxy(
-			Map<String, HeapInternalTimerService<K, ?>> timerServices,
-			int keyGroupIdx) {
-
-		this.timerServices = checkNotNull(timerServices);
+		InternalTimeServiceManager<K> timerServicesManager,
+		int keyGroupIdx) {
+		this.timerServicesManager = checkNotNull(timerServicesManager);
 		this.keyGroupIdx = keyGroupIdx;
 	}
 
@@ -91,9 +77,11 @@ public class InternalTimerServiceSerializationProxy<K> extends PostVersionedIORe
 	@Override
 	public void write(DataOutputView out) throws IOException {
 		super.write(out);
+		final Map<String, HeapInternalTimerService<K, ?>> registeredTimerServices =
+			timerServicesManager.getRegisteredTimerServices();
 
-		out.writeInt(timerServices.size());
-		for (Map.Entry<String, HeapInternalTimerService<K, ?>> entry : timerServices.entrySet()) {
+		out.writeInt(registeredTimerServices.size());
+		for (Map.Entry<String, HeapInternalTimerService<K, ?>> entry : registeredTimerServices.entrySet()) {
 			String serviceName = entry.getKey();
 			HeapInternalTimerService<K, ?> timerService = entry.getValue();
 
@@ -111,22 +99,25 @@ public class InternalTimerServiceSerializationProxy<K> extends PostVersionedIORe
 		for (int i = 0; i < noOfTimerServices; i++) {
 			String serviceName = in.readUTF();
 
-			HeapInternalTimerService<K, ?> timerService = timerServices.get(serviceName);
-			if (timerService == null) {
-				timerService = new HeapInternalTimerService<>(
-					totalKeyGroups,
-					localKeyGroupRange,
-					keyContext,
-					processingTimeService);
-				timerServices.put(serviceName, timerService);
-			}
-
 			int readerVersion = wasVersioned ? getReadVersion() : InternalTimersSnapshotReaderWriters.NO_VERSION;
 			InternalTimersSnapshot<?, ?> restoredTimersSnapshot = InternalTimersSnapshotReaderWriters
 				.getReaderForVersion(readerVersion, userCodeClassLoader)
 				.readTimersSnapshot(in);
 
+			HeapInternalTimerService<K, ?> timerService = registerOrGetTimerService(
+				serviceName,
+				restoredTimersSnapshot);
+
 			timerService.restoreTimersForKeyGroup(restoredTimersSnapshot, keyGroupIdx);
 		}
 	}
+
+	@SuppressWarnings("unchecked")
+	private <N> HeapInternalTimerService<K, N> registerOrGetTimerService(
+		String serviceName, InternalTimersSnapshot<?, ?> restoredTimersSnapshot) {
+		final TypeSerializer<K> keySerializer = (TypeSerializer<K>) restoredTimersSnapshot.getKeySerializer();
+		final TypeSerializer<N> namespaceSerializer = (TypeSerializer<N>) restoredTimersSnapshot.getNamespaceSerializer();
+		TimerSerializer<K, N> timerSerializer = new TimerSerializer<>(keySerializer, namespaceSerializer);
+		return timerServicesManager.registerOrGetTimerService(serviceName, timerSerializer);
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java
index 578302b..594f337 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java
@@ -207,6 +207,7 @@ public class StreamTaskStateInitializerImpl implements StreamTaskStateInitialize
 			keyedStatedBackend.getNumberOfKeyGroups(),
 			keyGroupRange,
 			keyContext,
+			keyedStatedBackend,
 			processingTimeService);
 
 		// and then initialize the timer services

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerHeapInternalTimer.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerHeapInternalTimer.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerHeapInternalTimer.java
index bd821c4..b9ef88e 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerHeapInternalTimer.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerHeapInternalTimer.java
@@ -19,21 +19,18 @@
 package org.apache.flink.streaming.api.operators;
 
 import org.apache.flink.annotation.Internal;
-import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.typeutils.CompatibilityResult;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot;
 import org.apache.flink.api.common.typeutils.base.LongSerializer;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.runtime.state.KeyExtractorFunction;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueSet;
 
 import javax.annotation.Nonnull;
 
 import java.io.IOException;
-import java.util.Comparator;
 
 /**
  * Implementation of {@link InternalTimer} to use with a {@link HeapPriorityQueueSet}.
@@ -44,14 +41,6 @@ import java.util.Comparator;
 @Internal
 public final class TimerHeapInternalTimer<K, N> implements InternalTimer<K, N>, HeapPriorityQueueElement {
 
-	/** Function to extract the key from a {@link TimerHeapInternalTimer}. */
-	private static final KeyExtractorFunction<TimerHeapInternalTimer<?, ?>> KEY_EXTRACTOR_FUNCTION =
-		TimerHeapInternalTimer::getKey;
-
-	/** Function to compare instances of {@link TimerHeapInternalTimer}. */
-	private static final Comparator<TimerHeapInternalTimer<?, ?>> TIMER_COMPARATOR =
-		(o1, o2) -> Long.compare(o1.getTimestamp(), o2.getTimestamp());
-
 	/** The key for which the timer is scoped. */
 	@Nonnull
 	private final K key;
@@ -144,18 +133,6 @@ public final class TimerHeapInternalTimer<K, N> implements InternalTimer<K, N>,
 				'}';
 	}
 
-	@VisibleForTesting
-	@SuppressWarnings("unchecked")
-	static <T extends TimerHeapInternalTimer> Comparator<T> getTimerComparator() {
-		return (Comparator<T>) TIMER_COMPARATOR;
-	}
-
-	@SuppressWarnings("unchecked")
-	@VisibleForTesting
-	static <T extends TimerHeapInternalTimer> KeyExtractorFunction<T> getKeyExtractorFunction() {
-		return (KeyExtractorFunction<T>) KEY_EXTRACTOR_FUNCTION;
-	}
-
 	/**
 	 * A {@link TypeSerializer} used to serialize/deserialize a {@link TimerHeapInternalTimer}.
 	 */

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerSerializer.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerSerializer.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerSerializer.java
new file mode 100644
index 0000000..87a3159
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerSerializer.java
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.api.common.typeutils.CompatibilityResult;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.util.MathUtils;
+
+import javax.annotation.Nonnull;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * A serializer for {@link TimerHeapInternalTimer} objects that produces a serialization format that is aligned with
+ * {@link InternalTimer#getTimerComparator()}.
+ *
+ * @param <K> type of the timer key.
+ * @param <N> type of the timer namespace.
+ */
+public class TimerSerializer<K, N> extends TypeSerializer<TimerHeapInternalTimer<K, N>> {
+
+	private static final long serialVersionUID = 1L;
+
+	/** Serializer for the key. */
+	@Nonnull
+	private final TypeSerializer<K> keySerializer;
+
+	/** Serializer for the namespace. */
+	@Nonnull
+	private final TypeSerializer<N> namespaceSerializer;
+
+	/** The bytes written for one timer, or -1 if variable size. */
+	private final int length;
+
+	/** True iff the serialized type (and composite objects) are immutable. */
+	private final boolean immutableType;
+
+	TimerSerializer(
+		@Nonnull TypeSerializer<K> keySerializer,
+		@Nonnull TypeSerializer<N> namespaceSerializer) {
+		this(
+			keySerializer,
+			namespaceSerializer,
+			computeTotalByteLength(keySerializer, namespaceSerializer),
+			keySerializer.isImmutableType() && namespaceSerializer.isImmutableType());
+	}
+
+	private TimerSerializer(
+		@Nonnull TypeSerializer<K> keySerializer,
+		@Nonnull TypeSerializer<N> namespaceSerializer,
+		int length,
+		boolean immutableType) {
+
+		this.keySerializer = keySerializer;
+		this.namespaceSerializer = namespaceSerializer;
+		this.length = length;
+		this.immutableType = immutableType;
+	}
+
+	private static int computeTotalByteLength(
+		TypeSerializer<?> keySerializer,
+		TypeSerializer<?> namespaceSerializer) {
+		if (keySerializer.getLength() >= 0 && namespaceSerializer.getLength() >= 0) {
+			// timestamp + key + namespace
+			return Long.BYTES + keySerializer.getLength() + namespaceSerializer.getLength();
+		} else {
+			return -1;
+		}
+	}
+
+	@Override
+	public boolean isImmutableType() {
+		return immutableType;
+	}
+
+	@Override
+	public TimerSerializer<K, N> duplicate() {
+
+		final TypeSerializer<K> keySerializerDuplicate = keySerializer.duplicate();
+		final TypeSerializer<N> namespaceSerializerDuplicate = namespaceSerializer.duplicate();
+
+		if (keySerializerDuplicate == keySerializer &&
+			namespaceSerializerDuplicate == namespaceSerializer) {
+			// all delegate serializers seem stateless, so this is also stateless.
+			return this;
+		} else {
+			// at least one delegate serializer seems to be stateful, so we return a new instance.
+			return new TimerSerializer<>(
+				keySerializerDuplicate,
+				namespaceSerializerDuplicate,
+				length,
+				immutableType);
+		}
+	}
+
+	@Override
+	public TimerHeapInternalTimer<K, N> createInstance() {
+		return new TimerHeapInternalTimer<>(
+			0L,
+			keySerializer.createInstance(),
+			namespaceSerializer.createInstance());
+	}
+
+	@Override
+	public TimerHeapInternalTimer<K, N> copy(TimerHeapInternalTimer<K, N> from) {
+
+		K keyDuplicate;
+		N namespaceDuplicate;
+		if (isImmutableType()) {
+			keyDuplicate = from.getKey();
+			namespaceDuplicate = from.getNamespace();
+		} else {
+			keyDuplicate = keySerializer.copy(from.getKey());
+			namespaceDuplicate = namespaceSerializer.copy(from.getNamespace());
+		}
+
+		return new TimerHeapInternalTimer<>(from.getTimestamp(), keyDuplicate, namespaceDuplicate);
+	}
+
+	@Override
+	public TimerHeapInternalTimer<K, N> copy(TimerHeapInternalTimer<K, N> from, TimerHeapInternalTimer<K, N> reuse) {
+		return copy(from);
+	}
+
+	@Override
+	public int getLength() {
+		return length;
+	}
+
+	@Override
+	public void serialize(TimerHeapInternalTimer<K, N> record, DataOutputView target) throws IOException {
+		target.writeLong(MathUtils.flipSignBit(record.getTimestamp()));
+		keySerializer.serialize(record.getKey(), target);
+		namespaceSerializer.serialize(record.getNamespace(), target);
+	}
+
+	@Override
+	public TimerHeapInternalTimer<K, N> deserialize(DataInputView source) throws IOException {
+		long timestamp = MathUtils.flipSignBit(source.readLong());
+		K key = keySerializer.deserialize(source);
+		N namespace = namespaceSerializer.deserialize(source);
+		return new TimerHeapInternalTimer<>(timestamp, key, namespace);
+	}
+
+	@Override
+	public TimerHeapInternalTimer<K, N> deserialize(
+		TimerHeapInternalTimer<K, N> reuse,
+		DataInputView source) throws IOException {
+		return deserialize(source);
+	}
+
+	@Override
+	public void copy(DataInputView source, DataOutputView target) throws IOException {
+		target.writeLong(source.readLong());
+		keySerializer.copy(source, target);
+		namespaceSerializer.copy(source, target);
+	}
+
+	@Override
+	public boolean equals(Object o) {
+		if (this == o) {
+			return true;
+		}
+		if (o == null || getClass() != o.getClass()) {
+			return false;
+		}
+		TimerSerializer<?, ?> that = (TimerSerializer<?, ?>) o;
+		return Objects.equals(keySerializer, that.keySerializer) &&
+			Objects.equals(namespaceSerializer, that.namespaceSerializer);
+	}
+
+	@Override
+	public int hashCode() {
+		return Objects.hash(keySerializer, namespaceSerializer);
+	}
+
+	@Override
+	public boolean canEqual(Object obj) {
+		return obj instanceof TimerSerializer;
+	}
+
+	@Override
+	public TypeSerializerConfigSnapshot snapshotConfiguration() {
+		throw new UnsupportedOperationException("This serializer is currently not used to write state.");
+	}
+
+	@Override
+	public CompatibilityResult<TimerHeapInternalTimer<K, N>> ensureCompatibility(
+		TypeSerializerConfigSnapshot configSnapshot) {
+		throw new UnsupportedOperationException("This serializer is currently not used to write state.");
+	}
+
+	@Nonnull
+	public TypeSerializer<K> getKeySerializer() {
+		return keySerializer;
+	}
+
+	@Nonnull
+	public TypeSerializer<N> getNamespaceSerializer() {
+		return namespaceSerializer;
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/HeapInternalTimerServiceTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/HeapInternalTimerServiceTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/HeapInternalTimerServiceTest.java
index b008fa2..519f10e 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/HeapInternalTimerServiceTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/HeapInternalTimerServiceTest.java
@@ -18,12 +18,16 @@
 
 package org.apache.flink.streaming.api.operators;
 
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
+import org.apache.flink.runtime.state.PriorityQueueSetFactory;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
 import org.apache.flink.streaming.runtime.tasks.TestProcessingTimeService;
 
@@ -85,12 +89,13 @@ public class HeapInternalTimerServiceTest {
 
 		TestProcessingTimeService processingTimeService = new TestProcessingTimeService();
 
-		HeapInternalTimerService<Integer, String> service =
-				new HeapInternalTimerService<>(
-						testKeyGroupList.getNumberOfKeyGroups(),
-						testKeyGroupList,
-						keyContext,
-						processingTimeService);
+		HeapInternalTimerService<Integer, String> service = createInternalTimerService(
+			testKeyGroupList,
+			keyContext,
+			processingTimeService,
+			IntSerializer.INSTANCE,
+			StringSerializer.INSTANCE,
+			createQueueFactory());
 
 		Assert.assertEquals(startKeyGroupIdx, service.getLocalKeyGroupRangeStartIdx());
 	}
@@ -105,14 +110,20 @@ public class HeapInternalTimerServiceTest {
 
 		@SuppressWarnings("unchecked")
 		Set<TimerHeapInternalTimer<Integer, String>>[] expectedNonEmptyTimerSets = new HashSet[totalNoOfKeyGroups];
-
 		TestKeyContext keyContext = new TestKeyContext();
-		HeapInternalTimerService<Integer, String> timerService =
-				new HeapInternalTimerService<>(
-						totalNoOfKeyGroups,
-						new KeyGroupRange(startKeyGroupIdx, endKeyGroupIdx),
-						keyContext,
-						new TestProcessingTimeService());
+
+		final KeyGroupRange keyGroupRange = new KeyGroupRange(startKeyGroupIdx, endKeyGroupIdx);
+
+		final PriorityQueueSetFactory priorityQueueSetFactory =
+			createQueueFactory(keyGroupRange, totalNoOfKeyGroups);
+
+		HeapInternalTimerService<Integer, String> timerService = createInternalTimerService(
+			keyGroupRange,
+			keyContext,
+			new TestProcessingTimeService(),
+			IntSerializer.INSTANCE,
+			StringSerializer.INSTANCE,
+			priorityQueueSetFactory);
 
 		timerService.startTimerService(IntSerializer.INSTANCE, StringSerializer.INSTANCE, mock(Triggerable.class));
 
@@ -169,9 +180,10 @@ public class HeapInternalTimerServiceTest {
 		TestKeyContext keyContext = new TestKeyContext();
 
 		TestProcessingTimeService processingTimeService = new TestProcessingTimeService();
-
+		PriorityQueueSetFactory priorityQueueSetFactory =
+			new HeapPriorityQueueSetFactory(testKeyGroupRange, maxParallelism, 128);
 		HeapInternalTimerService<Integer, String> timerService =
-				createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism);
+				createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, priorityQueueSetFactory);
 
 		int key = getKeyInKeyGroupRange(testKeyGroupRange, maxParallelism);
 		keyContext.setCurrentKey(key);
@@ -233,7 +245,7 @@ public class HeapInternalTimerServiceTest {
 		TestProcessingTimeService processingTimeService = new TestProcessingTimeService();
 
 		HeapInternalTimerService<Integer, String> timerService =
-				createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism);
+				createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory());
 
 		int key = getKeyInKeyGroupRange(testKeyGroupRange, maxParallelism);
 
@@ -266,7 +278,7 @@ public class HeapInternalTimerServiceTest {
 		TestProcessingTimeService processingTimeService = new TestProcessingTimeService();
 
 		final HeapInternalTimerService<Integer, String> timerService =
-				createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism);
+				createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory());
 
 		int key = getKeyInKeyGroupRange(testKeyGroupRange, maxParallelism);
 
@@ -317,7 +329,7 @@ public class HeapInternalTimerServiceTest {
 		TestKeyContext keyContext = new TestKeyContext();
 		TestProcessingTimeService processingTimeService = new TestProcessingTimeService();
 		HeapInternalTimerService<Integer, String> timerService =
-				createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism);
+				createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory());
 
 		processingTimeService.setCurrentTime(17L);
 		assertEquals(17, timerService.currentProcessingTime());
@@ -335,7 +347,7 @@ public class HeapInternalTimerServiceTest {
 		TestKeyContext keyContext = new TestKeyContext();
 		TestProcessingTimeService processingTimeService = new TestProcessingTimeService();
 		HeapInternalTimerService<Integer, String> timerService =
-				createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism);
+				createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory());
 
 		timerService.advanceWatermark(17);
 		assertEquals(17, timerService.currentWatermark());
@@ -355,7 +367,7 @@ public class HeapInternalTimerServiceTest {
 		TestKeyContext keyContext = new TestKeyContext();
 		TestProcessingTimeService processingTimeService = new TestProcessingTimeService();
 		HeapInternalTimerService<Integer, String> timerService =
-				createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism);
+				createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory());
 
 		// get two different keys
 		int key1 = getKeyInKeyGroupRange(testKeyGroupRange, maxParallelism);
@@ -400,7 +412,7 @@ public class HeapInternalTimerServiceTest {
 		TestKeyContext keyContext = new TestKeyContext();
 		TestProcessingTimeService processingTimeService = new TestProcessingTimeService();
 		HeapInternalTimerService<Integer, String> timerService =
-				createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism);
+				createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory());
 
 		// get two different keys
 		int key1 = getKeyInKeyGroupRange(testKeyGroupRange, maxParallelism);
@@ -447,7 +459,7 @@ public class HeapInternalTimerServiceTest {
 		TestKeyContext keyContext = new TestKeyContext();
 		TestProcessingTimeService processingTimeService = new TestProcessingTimeService();
 		HeapInternalTimerService<Integer, String> timerService =
-				createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism);
+				createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory());
 
 		// get two different keys
 		int key1 = getKeyInKeyGroupRange(testKeyGroupRange, maxParallelism);
@@ -504,7 +516,7 @@ public class HeapInternalTimerServiceTest {
 		TestKeyContext keyContext = new TestKeyContext();
 		TestProcessingTimeService processingTimeService = new TestProcessingTimeService();
 		HeapInternalTimerService<Integer, String> timerService =
-				createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism);
+				createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory());
 
 		// get two different keys
 		int key1 = getKeyInKeyGroupRange(testKeyGroupRange, maxParallelism);
@@ -579,7 +591,7 @@ public class HeapInternalTimerServiceTest {
 		TestKeyContext keyContext = new TestKeyContext();
 		TestProcessingTimeService processingTimeService = new TestProcessingTimeService();
 		HeapInternalTimerService<Integer, String> timerService =
-			createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism);
+			createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, createQueueFactory());
 
 		// get two different keys
 		int key1 = getKeyInKeyGroupRange(testKeyGroupRange, maxParallelism);
@@ -631,7 +643,7 @@ public class HeapInternalTimerServiceTest {
 			keyContext,
 			processingTimeService,
 			testKeyGroupRange,
-			maxParallelism);
+			createQueueFactory());
 
 		processingTimeService.setCurrentTime(10);
 		timerService.advanceWatermark(10);
@@ -652,8 +664,9 @@ public class HeapInternalTimerServiceTest {
 
 		TestKeyContext keyContext = new TestKeyContext();
 		TestProcessingTimeService processingTimeService = new TestProcessingTimeService();
+		final PriorityQueueSetFactory queueFactory = createQueueFactory();
 		HeapInternalTimerService<Integer, String> timerService =
-			createTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, maxParallelism);
+			createAndStartInternalTimerService(mockTriggerable, keyContext, processingTimeService, testKeyGroupRange, queueFactory);
 
 		int midpoint = testKeyGroupRange.getStartKeyGroup() +
 			(testKeyGroupRange.getEndKeyGroup() - testKeyGroupRange.getStartKeyGroup()) / 2;
@@ -724,7 +737,7 @@ public class HeapInternalTimerServiceTest {
 			keyContext1,
 			processingTimeService1,
 			subKeyGroupRange1,
-			maxParallelism);
+			queueFactory);
 
 		HeapInternalTimerService<Integer, String> timerService2 = restoreTimerService(
 			snapshot2,
@@ -733,7 +746,7 @@ public class HeapInternalTimerServiceTest {
 			keyContext2,
 			processingTimeService2,
 			subKeyGroupRange2,
-			maxParallelism);
+			queueFactory);
 
 		processingTimeService1.setCurrentTime(10);
 		timerService1.advanceWatermark(10);
@@ -793,18 +806,19 @@ public class HeapInternalTimerServiceTest {
 		return result;
 	}
 
-	private static HeapInternalTimerService<Integer, String> createTimerService(
+	private static HeapInternalTimerService<Integer, String> createAndStartInternalTimerService(
 			Triggerable<Integer, String> triggerable,
 			KeyContext keyContext,
 			ProcessingTimeService processingTimeService,
 			KeyGroupRange keyGroupList,
-			int maxParallelism) {
-		HeapInternalTimerService<Integer, String> service =
-			new HeapInternalTimerService<>(
-					maxParallelism,
-					keyGroupList,
-					keyContext,
-					processingTimeService);
+			PriorityQueueSetFactory priorityQueueSetFactory) {
+		HeapInternalTimerService<Integer, String> service = createInternalTimerService(
+			keyGroupList,
+			keyContext,
+			processingTimeService,
+			IntSerializer.INSTANCE,
+			StringSerializer.INSTANCE,
+			priorityQueueSetFactory);
 
 		service.startTimerService(IntSerializer.INSTANCE, StringSerializer.INSTANCE, triggerable);
 		return service;
@@ -817,15 +831,16 @@ public class HeapInternalTimerServiceTest {
 			KeyContext keyContext,
 			ProcessingTimeService processingTimeService,
 			KeyGroupRange keyGroupsList,
-			int maxParallelism) throws Exception {
+			PriorityQueueSetFactory priorityQueueSetFactory) throws Exception {
 
 		// create an empty service
-		HeapInternalTimerService<Integer, String> service =
-			new HeapInternalTimerService<>(
-					maxParallelism,
-					keyGroupsList,
-					keyContext,
-					processingTimeService);
+		HeapInternalTimerService<Integer, String> service = createInternalTimerService(
+			keyGroupsList,
+			keyContext,
+			processingTimeService,
+			IntSerializer.INSTANCE,
+			StringSerializer.INSTANCE,
+			priorityQueueSetFactory);
 
 		// restore the timers
 		for (Integer keyGroupIndex : keyGroupsList) {
@@ -846,6 +861,14 @@ public class HeapInternalTimerServiceTest {
 		return service;
 	}
 
+	private PriorityQueueSetFactory createQueueFactory() {
+		return createQueueFactory(testKeyGroupRange, maxParallelism);
+	}
+
+	protected PriorityQueueSetFactory createQueueFactory(KeyGroupRange keyGroupRange, int numKeyGroups) {
+		return new HeapPriorityQueueSetFactory(keyGroupRange, numKeyGroups, 128);
+	}
+
 	// ------------------------------------------------------------------------
 	//  Parametrization for testing with different key-group ranges
 	// ------------------------------------------------------------------------
@@ -862,4 +885,33 @@ public class HeapInternalTimerServiceTest {
 						{2, 5, 6}
 		});
 	}
+
+	private static <K, N> HeapInternalTimerService<K, N> createInternalTimerService(
+		KeyGroupRange keyGroupsList,
+		KeyContext keyContext,
+		ProcessingTimeService processingTimeService,
+		TypeSerializer<K> keySerializer,
+		TypeSerializer<N> namespaceSerializer,
+		PriorityQueueSetFactory priorityQueueSetFactory) {
+
+		TimerSerializer<K, N> timerSerializer = new TimerSerializer<>(keySerializer, namespaceSerializer);
+
+		return new HeapInternalTimerService<>(
+			keyGroupsList,
+			keyContext,
+			processingTimeService,
+			createTimerQueue("__test_processing_timers", timerSerializer, priorityQueueSetFactory),
+			createTimerQueue("__test_event_timers", timerSerializer, priorityQueueSetFactory));
+	}
+
+	private static <K, N> KeyGroupedInternalPriorityQueue<TimerHeapInternalTimer<K, N>> createTimerQueue(
+		String name,
+		TimerSerializer<K, N> timerSerializer,
+		PriorityQueueSetFactory priorityQueueSetFactory) {
+		return priorityQueueSetFactory.create(
+			name,
+			timerSerializer,
+			InternalTimer.getTimerComparator(),
+			InternalTimer.getKeyExtractorFunction());
+	}
 }


[2/2] flink git commit: [FLINK-9486][state] Introduce InternalPriorityQueue as state in keyed state backends

Posted by sr...@apache.org.
[FLINK-9486][state] Introduce InternalPriorityQueue as state in keyed state backends

This commit does not include the integration with checkpointing.

This closes #6276.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/79b38f8f
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/79b38f8f
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/79b38f8f

Branch: refs/heads/master
Commit: 79b38f8f9a79b917d525842cf46087c5b8c40f3d
Parents: b12acea
Author: Stefan Richter <s....@data-artisans.com>
Authored: Wed Jul 4 13:43:49 2018 +0200
Committer: Stefan Richter <s....@data-artisans.com>
Committed: Mon Jul 9 16:12:51 2018 +0200

----------------------------------------------------------------------
 .../KVStateRequestSerializerRocksDBTest.java    |  16 +-
 .../network/KvStateRequestSerializerTest.java   |  19 +-
 .../runtime/state/InternalPriorityQueue.java    |  12 +
 .../state/KeyGroupedInternalPriorityQueue.java  |  38 ++++
 .../flink/runtime/state/KeyedStateBackend.java  |   3 +-
 .../flink/runtime/state/PriorityComparator.java |  42 ++++
 .../runtime/state/PriorityQueueSetFactory.java  |  46 ++++
 .../state/TieBreakingPriorityComparator.java    | 122 ++++++++++
 .../state/filesystem/FsStateBackend.java        |   6 +-
 .../heap/CachingInternalPriorityQueueSet.java   |  26 ++-
 .../state/heap/HeapKeyedStateBackend.java       |  28 ++-
 .../runtime/state/heap/HeapPriorityQueue.java   |  35 ++-
 .../state/heap/HeapPriorityQueueSet.java        |  46 ++--
 .../state/heap/HeapPriorityQueueSetFactory.java |  69 ++++++
 .../heap/KeyGroupPartitionedPriorityQueue.java  |  63 ++++--
 .../runtime/state/heap/TreeOrderedSetCache.java |   7 +
 .../state/memory/MemoryStateBackend.java        |   7 +-
 .../state/InternalPriorityQueueTestBase.java    |  12 +-
 .../state/StateSnapshotCompressionTest.java     |  12 +-
 .../state/heap/HeapPriorityQueueSetTest.java    |   2 +-
 .../state/heap/HeapPriorityQueueTest.java       |   2 +-
 .../state/heap/HeapStateBackendTestBase.java    |  10 +-
 .../KeyGroupPartitionedPriorityQueueTest.java   |   2 +-
 .../streaming/state/RockDBBackendOptions.java   |  38 ++++
 .../state/RocksDBKeyedStateBackend.java         | 171 +++++++++++++-
 .../streaming/state/RocksDBOrderedSetStore.java |  13 +-
 .../streaming/state/RocksDBStateBackend.java    |  22 +-
 ...nalPriorityQueueSetWithRocksDBStoreTest.java |   1 -
 .../state/RocksDBOrderedSetStoreTest.java       |   1 -
 .../state/RocksDBStateBackendTest.java          |   3 +-
 .../api/operators/AbstractStreamOperator.java   |   6 +-
 .../api/operators/HeapInternalTimerService.java |  85 +++----
 .../operators/InternalTimeServiceManager.java   |  80 ++++---
 .../streaming/api/operators/InternalTimer.java  |  22 ++
 .../InternalTimerServiceSerializationProxy.java |  65 +++---
 .../StreamTaskStateInitializerImpl.java         |   1 +
 .../api/operators/TimerHeapInternalTimer.java   |  23 --
 .../api/operators/TimerSerializer.java          | 222 +++++++++++++++++++
 .../operators/HeapInternalTimerServiceTest.java | 138 ++++++++----
 39 files changed, 1234 insertions(+), 282 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java
----------------------------------------------------------------------
diff --git a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java
index a49fdd2..9ea3198 100644
--- a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java
+++ b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java
@@ -25,6 +25,7 @@ import org.apache.flink.api.common.typeutils.base.LongSerializer;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.contrib.streaming.state.PredefinedOptions;
 import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend;
+import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
 import org.apache.flink.queryablestate.client.VoidNamespace;
 import org.apache.flink.queryablestate.client.VoidNamespaceSerializer;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
@@ -74,9 +75,12 @@ public final class KVStateRequestSerializerRocksDBTest {
 				columnFamilyOptions,
 				mock(TaskKvStateRegistry.class),
 				LongSerializer.INSTANCE,
-				1, new KeyGroupRange(0, 0),
-				new ExecutionConfig(), false,
-				TestLocalRecoveryConfig.disabled()
+				1,
+				new KeyGroupRange(0, 0),
+				new ExecutionConfig(),
+				false,
+				TestLocalRecoveryConfig.disabled(),
+				RocksDBStateBackend.PriorityQueueStateType.HEAP
 			);
 		longHeapKeyedStateBackend.restore(null);
 		longHeapKeyedStateBackend.setCurrentKey(key);
@@ -112,10 +116,12 @@ public final class KVStateRequestSerializerRocksDBTest {
 				columnFamilyOptions,
 				mock(TaskKvStateRegistry.class),
 				LongSerializer.INSTANCE,
-				1, new KeyGroupRange(0, 0),
+				1,
+				new KeyGroupRange(0, 0),
 				new ExecutionConfig(),
 				false,
-				TestLocalRecoveryConfig.disabled());
+				TestLocalRecoveryConfig.disabled(),
+				RocksDBStateBackend.PriorityQueueStateType.HEAP);
 		longHeapKeyedStateBackend.restore(null);
 		longHeapKeyedStateBackend.setCurrentKey(key);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateRequestSerializerTest.java
----------------------------------------------------------------------
diff --git a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateRequestSerializerTest.java b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateRequestSerializerTest.java
index 2ba7507..73f8831 100644
--- a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateRequestSerializerTest.java
+++ b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateRequestSerializerTest.java
@@ -32,6 +32,7 @@ import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.TestLocalRecoveryConfig;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
 import org.apache.flink.runtime.state.internal.InternalKvState;
 import org.apache.flink.runtime.state.internal.InternalListState;
 import org.apache.flink.runtime.state.internal.InternalMapState;
@@ -185,18 +186,19 @@ public class KvStateRequestSerializerTest {
 	@Test
 	public void testListSerialization() throws Exception {
 		final long key = 0L;
-
+		final KeyGroupRange keyGroupRange = new KeyGroupRange(0, 0);
 		// objects for heap state list serialisation
 		final HeapKeyedStateBackend<Long> longHeapKeyedStateBackend =
 			new HeapKeyedStateBackend<>(
 				mock(TaskKvStateRegistry.class),
 				LongSerializer.INSTANCE,
 				ClassLoader.getSystemClassLoader(),
-				1,
-				new KeyGroupRange(0, 0),
+				keyGroupRange.getNumberOfKeyGroups(),
+				keyGroupRange,
 				async,
 				new ExecutionConfig(),
-				TestLocalRecoveryConfig.disabled()
+				TestLocalRecoveryConfig.disabled(),
+				new HeapPriorityQueueSetFactory(keyGroupRange, keyGroupRange.getNumberOfKeyGroups(), 128)
 			);
 		longHeapKeyedStateBackend.setCurrentKey(key);
 
@@ -292,18 +294,19 @@ public class KvStateRequestSerializerTest {
 	@Test
 	public void testMapSerialization() throws Exception {
 		final long key = 0L;
-
+		final KeyGroupRange keyGroupRange = new KeyGroupRange(0, 0);
 		// objects for heap state list serialisation
 		final HeapKeyedStateBackend<Long> longHeapKeyedStateBackend =
 			new HeapKeyedStateBackend<>(
 				mock(TaskKvStateRegistry.class),
 				LongSerializer.INSTANCE,
 				ClassLoader.getSystemClassLoader(),
-				1,
-				new KeyGroupRange(0, 0),
+				keyGroupRange.getNumberOfKeyGroups(),
+				keyGroupRange,
 				async,
 				new ExecutionConfig(),
-				TestLocalRecoveryConfig.disabled()
+				TestLocalRecoveryConfig.disabled(),
+				new HeapPriorityQueueSetFactory(keyGroupRange, keyGroupRange.getNumberOfKeyGroups(), 128)
 			);
 		longHeapKeyedStateBackend.setCurrentKey(key);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/main/java/org/apache/flink/runtime/state/InternalPriorityQueue.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/InternalPriorityQueue.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/InternalPriorityQueue.java
index fb3ee82..dc46c8a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/InternalPriorityQueue.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/InternalPriorityQueue.java
@@ -26,6 +26,8 @@ import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
 import java.util.Collection;
+import java.util.function.Consumer;
+import java.util.function.Predicate;
 
 /**
  * Interface for collection that gives in order access to elements w.r.t their priority.
@@ -36,6 +38,16 @@ import java.util.Collection;
 public interface InternalPriorityQueue<T> {
 
 	/**
+	 * Polls from the top of the queue as long as the queue is not empty and passes the elements to
+	 * {@link Consumer} until a {@link Predicate} rejects an offered element. The rejected element is not
+	 * removed from the queue and becomes the new head.
+	 *
+	 * @param canConsume bulk polling ends once this returns false. The rejected element is neither removed nor consumed.
+	 * @param consumer consumer function for elements accepted by canConsume.
+	 */
+	void bulkPoll(@Nonnull Predicate<T> canConsume, @Nonnull Consumer<T> consumer);
+
+	/**
 	 * Retrieves and removes the first element (w.r.t. the order) of this set,
 	 * or returns {@code null} if this set is empty.
 	 *

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupedInternalPriorityQueue.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupedInternalPriorityQueue.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupedInternalPriorityQueue.java
new file mode 100644
index 0000000..68472e2
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupedInternalPriorityQueue.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import javax.annotation.Nonnull;
+
+import java.util.Set;
+
+/**
+ * This interface exists as a (temporary) adapter between the new {@link InternalPriorityQueue} and the old way in which
+ * timers are written in a snapshot. This interface can probably go away once timer state becomes part of the
+ * keyed state backend snapshot.
+ */
+public interface KeyGroupedInternalPriorityQueue<T> extends InternalPriorityQueue<T> {
+
+	/**
+	 * Returns the subset of elements in the priority queue that belongs to the given key-group, within the operator's
+	 * key-group range.
+	 */
+	@Nonnull
+	Set<T> getSubsetForKeyGroup(int keyGroupId);
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
index ad75a1f..7ba14b3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
@@ -31,7 +31,8 @@ import java.util.stream.Stream;
  *
  * @param <K> The key by which state is keyed.
  */
-public interface KeyedStateBackend<K> extends InternalKeyContext<K>, KeyedStateFactory, Disposable {
+public interface KeyedStateBackend<K>
+	extends InternalKeyContext<K>, KeyedStateFactory, PriorityQueueSetFactory, Disposable {
 
 	/**
 	 * Sets the current key that is used for partitioned state.

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityComparator.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityComparator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityComparator.java
new file mode 100644
index 0000000..2f6f5a7
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityComparator.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+/**
+ * This interface works similarly to {@link Comparable} and is used to prioritize between two objects. The main difference
+ * between this interface and {@link Comparable} is that it is not required to follow the usual contract between
+ * {@link Comparable#compareTo(Object)} and {@link Object#equals(Object)}. The contract of this interface is:
+ * When two objects are equal, they indicate the same priority, but indicating the same priority does not require that
+ * both objects are equal.
+ *
+ * @param <T> type of the compared objects.
+ */
+@FunctionalInterface
+public interface PriorityComparator<T> {
+
+	/**
+	 * Compares two objects for priority. Returns a negative integer, zero, or a positive integer as the first
+	 * argument has lower, equal to, or higher priority than the second.
+	 * @param left left operand in the comparison by priority.
+	 * @param right right operand in the comparison by priority.
+	 * @return a negative integer, zero, or a positive integer as the first argument has lower, equal to, or higher
+	 * priority than the second.
+	 */
+	int comparePriority(T left, T right);
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityQueueSetFactory.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityQueueSetFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityQueueSetFactory.java
new file mode 100644
index 0000000..6f509c0
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PriorityQueueSetFactory.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement;
+
+import javax.annotation.Nonnull;
+
+/**
+ * Factory for {@link KeyGroupedInternalPriorityQueue} instances.
+ */
+public interface PriorityQueueSetFactory {
+
+	/**
+	 * Creates a {@link KeyGroupedInternalPriorityQueue}.
+	 *
+	 * @param stateName                    unique name associated with this queue.
+	 * @param byteOrderedElementSerializer a serializer with a format that is lexicographically ordered in
+	 *                                     alignment with elementPriorityComparator.
+	 * @param <T>                          type of the stored elements.
+	 * @return the queue with the specified unique name.
+	 */
+	@Nonnull
+	<T extends HeapPriorityQueueElement> KeyGroupedInternalPriorityQueue<T> create(
+		@Nonnull String stateName,
+		@Nonnull TypeSerializer<T> byteOrderedElementSerializer,
+		@Nonnull PriorityComparator<T> elementPriorityComparator,
+		@Nonnull KeyExtractorFunction<T> keyExtractor);
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/main/java/org/apache/flink/runtime/state/TieBreakingPriorityComparator.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TieBreakingPriorityComparator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TieBreakingPriorityComparator.java
new file mode 100644
index 0000000..4384eb7
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TieBreakingPriorityComparator.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.util.FlinkRuntimeException;
+
+import javax.annotation.Nonnull;
+
+import java.io.IOException;
+import java.util.Comparator;
+
+/**
+ * This class is an adapter between {@link PriorityComparator} and a full {@link Comparator} that respects the
+ * contract between {@link Comparator#compare(Object, Object)} and {@link Object#equals(Object)}. This is currently
+ * needed for implementations of
+ * {@link org.apache.flink.runtime.state.heap.CachingInternalPriorityQueueSet.OrderedSetCache} that are implemented
+ * on top of a data structure that relies on this contract, e.g. a tree set. We should replace this in the near
+ * future.
+ *
+ * @param <T> type of the compared elements.
+ */
+public class TieBreakingPriorityComparator<T> implements Comparator<T>, PriorityComparator<T> {
+
+	/** The {@link PriorityComparator} to which we delegate in a first step. */
+	@Nonnull
+	private final PriorityComparator<T> priorityComparator;
+
+	/** Serializer for instances of the compared objects. */
+	@Nonnull
+	private final TypeSerializer<T> serializer;
+
+	/** Stream that we use in serialization. */
+	@Nonnull
+	private final ByteArrayOutputStreamWithPos outStream;
+
+	/** {@link org.apache.flink.core.memory.DataOutputView} around outStream. */
+	@Nonnull
+	private final DataOutputViewStreamWrapper outView;
+
+	public TieBreakingPriorityComparator(
+		@Nonnull PriorityComparator<T> priorityComparator,
+		@Nonnull TypeSerializer<T> serializer,
+		@Nonnull ByteArrayOutputStreamWithPos outStream,
+		@Nonnull DataOutputViewStreamWrapper outView) {
+
+		this.priorityComparator = priorityComparator;
+		this.serializer = serializer;
+		this.outStream = outStream;
+		this.outView = outView;
+	}
+
+	@SuppressWarnings("unchecked")
+	@Override
+	public int compare(T o1, T o2) {
+
+		// first we compare priority, this should be the most commonly hit case
+		int cmp = priorityComparator.comparePriority(o1, o2);
+
+		if (cmp != 0) {
+			return cmp;
+		}
+
+		// here we start tie breaking and do our best to comply with the compareTo/equals contract, first we try
+		// to simply find an existing way to fully compare.
+		if (o1 instanceof Comparable && o1.getClass().equals(o2.getClass())) {
+			return ((Comparable<T>) o1).compareTo(o2); // o2 has the same runtime class as o1 here
+		}
+
+		// we catch this case before moving to more expensive tie breaks.
+		if (o1.equals(o2)) {
+			return 0;
+		}
+
+		// if objects are not equal, their serialized form should somehow differ as well. this can be costly, and...
+		// TODO we should have an alternative approach in the future, e.g. a cache that does not rely on compare to check equality.
+		try {
+			outStream.reset();
+			serializer.serialize(o1, outView);
+			int leftLen = outStream.getPosition(); // stream position after writing o1
+			serializer.serialize(o2, outView);
+			int rightLen = outStream.getPosition() - leftLen; // bytes added by serializing o2
+			return compareBytes(outStream.getBuf(), 0, leftLen, leftLen, rightLen); // o1 is read from offset 0, o2 from offset leftLen
+		} catch (IOException ex) {
+			throw new FlinkRuntimeException("Serializer problem in comparator.", ex);
+		}
+	}
+
+	@Override
+	public int comparePriority(T left, T right) {
+		return priorityComparator.comparePriority(left, right);
+	}
+
+	public static int compareBytes(byte[] bytes, int offLeft, int leftLen, int offRight, int rightLen) {
+		int maxLen = Math.min(leftLen, rightLen); // only the common length can be compared byte by byte
+		for (int i = 0; i < maxLen; ++i) {
+			int cmp = bytes[offLeft + i] - bytes[offRight + i]; // bytes are compared as signed values
+			if (cmp != 0) {
+				return cmp;
+			}
+		}
+		return leftLen - rightLen; // identical common prefix: the shorter range orders first
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
index 637effd..ad1581b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
@@ -36,6 +36,7 @@ import org.apache.flink.runtime.state.LocalRecoveryConfig;
 import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.TaskStateManager;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
 import org.apache.flink.util.TernaryBoolean;
 
 import org.slf4j.LoggerFactory;
@@ -457,6 +458,8 @@ public class FsStateBackend extends AbstractFileStateBackend implements Configur
 
 		TaskStateManager taskStateManager = env.getTaskStateManager();
 		LocalRecoveryConfig localRecoveryConfig = taskStateManager.createLocalRecoveryConfig();
+		HeapPriorityQueueSetFactory priorityQueueSetFactory =
+			new HeapPriorityQueueSetFactory(keyGroupRange, numberOfKeyGroups, 128);
 
 		return new HeapKeyedStateBackend<>(
 				kvStateRegistry,
@@ -466,7 +469,8 @@ public class FsStateBackend extends AbstractFileStateBackend implements Configur
 				keyGroupRange,
 				isUsingAsynchronousSnapshots(),
 				env.getExecutionConfig(),
-				localRecoveryConfig);
+				localRecoveryConfig,
+				priorityQueueSetFactory);
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CachingInternalPriorityQueueSet.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CachingInternalPriorityQueueSet.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CachingInternalPriorityQueueSet.java
index 771315d..6dc8cf3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CachingInternalPriorityQueueSet.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CachingInternalPriorityQueueSet.java
@@ -27,6 +27,8 @@ import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
 import java.util.Collection;
+import java.util.function.Consumer;
+import java.util.function.Predicate;
 
 /**
  * This class is an implementation of a {@link InternalPriorityQueue} with set semantics that internally consists of
@@ -76,6 +78,15 @@ public class CachingInternalPriorityQueueSet<E> implements InternalPriorityQueue
 		return orderedCache.peekFirst();
 	}
 
+	@Override
+	public void bulkPoll(@Nonnull Predicate<E> canConsume, @Nonnull Consumer<E> consumer) {
+		E element;
+		while ((element = peek()) != null && canConsume.test(element)) {
+			poll();
+			consumer.accept(element);
+		}
+	}
+
 	@Nullable
 	@Override
 	public E poll() {
@@ -158,7 +169,11 @@ public class CachingInternalPriorityQueueSet<E> implements InternalPriorityQueue
 	@Nonnull
 	@Override
 	public CloseableIterator<E> iterator() {
-		return orderedStore.orderedIterator();
+		if (storeOnlyElements) {
+			return orderedStore.orderedIterator();
+		} else {
+			return orderedCache.orderedIterator();
+		}
 	}
 
 	@Override
@@ -184,7 +199,7 @@ public class CachingInternalPriorityQueueSet<E> implements InternalPriorityQueue
 				}
 				storeOnlyElements = iterator.hasNext();
 			} catch (Exception e) {
-				throw new FlinkRuntimeException("Exception while closing RocksDB iterator.", e);
+				throw new FlinkRuntimeException("Exception while refilling store from iterator.", e);
 			}
 		}
 	}
@@ -249,6 +264,13 @@ public class CachingInternalPriorityQueueSet<E> implements InternalPriorityQueue
 		 */
 		@Nullable
 		E peekLast();
+
+		/**
+		 * Returns an iterator over the cache that returns elements in order. The iterator must be closed by the client
+		 * after usage.
+		 */
+		@Nonnull
+		CloseableIterator<E> orderedIterator();
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 82ce584..b5b2626 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -44,13 +44,17 @@ import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.CheckpointStreamWithResultProvider;
 import org.apache.flink.runtime.state.CheckpointedStateScope;
 import org.apache.flink.runtime.state.DoneFuture;
+import org.apache.flink.runtime.state.KeyExtractorFunction;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
+import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
 import org.apache.flink.runtime.state.KeyedStateFunction;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.LocalRecoveryConfig;
+import org.apache.flink.runtime.state.PriorityComparator;
+import org.apache.flink.runtime.state.PriorityQueueSetFactory;
 import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo;
 import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator;
 import org.apache.flink.runtime.state.SnapshotResult;
@@ -102,6 +106,21 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			Tuple2.of(FoldingStateDescriptor.class, (StateFactory) HeapFoldingState::create)
 		).collect(Collectors.toMap(t -> t.f0, t -> t.f1));
 
+	@Nonnull
+	@Override
+	public <T extends HeapPriorityQueueElement> KeyGroupedInternalPriorityQueue<T> create(
+		@Nonnull String stateName,
+		@Nonnull TypeSerializer<T> byteOrderedElementSerializer,
+		@Nonnull PriorityComparator<T> elementPriorityComparator,
+		@Nonnull KeyExtractorFunction<T> keyExtractor) {
+
+		return priorityQueueSetFactory.create(
+			stateName,
+			byteOrderedElementSerializer,
+			elementPriorityComparator,
+			keyExtractor);
+	}
+
 	private interface StateFactory {
 		<K, N, SV, S extends State, IS extends S> IS createState(
 			StateDescriptor<S, SV> stateDesc,
@@ -137,6 +156,11 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	 */
 	private final HeapSnapshotStrategy snapshotStrategy;
 
+	/**
+	 * Factory for state that is organized as priority queue.
+	 */
+	private final PriorityQueueSetFactory priorityQueueSetFactory;
+
 	public HeapKeyedStateBackend(
 			TaskKvStateRegistry kvStateRegistry,
 			TypeSerializer<K> keySerializer,
@@ -145,7 +169,8 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			KeyGroupRange keyGroupRange,
 			boolean asynchronousSnapshots,
 			ExecutionConfig executionConfig,
-			LocalRecoveryConfig localRecoveryConfig) {
+			LocalRecoveryConfig localRecoveryConfig,
+			PriorityQueueSetFactory priorityQueueSetFactory) {
 
 		super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange, executionConfig);
 		this.localRecoveryConfig = Preconditions.checkNotNull(localRecoveryConfig);
@@ -157,6 +182,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		this.snapshotStrategy = new HeapSnapshotStrategy(synchronicityTrait);
 		LOG.info("Initializing heap keyed state backend with stream factory.");
 		this.restoredKvStateMetaInfos = new HashMap<>();
+		this.priorityQueueSetFactory = priorityQueueSetFactory;
 	}
 
 	// ------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueue.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueue.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueue.java
index 7017905..e5f610e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueue.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueue.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.state.heap;
 
 import org.apache.flink.runtime.state.InternalPriorityQueue;
+import org.apache.flink.runtime.state.PriorityComparator;
 import org.apache.flink.util.CloseableIterator;
 
 import javax.annotation.Nonnegative;
@@ -27,9 +28,10 @@ import javax.annotation.Nullable;
 
 import java.util.Arrays;
 import java.util.Collection;
-import java.util.Comparator;
 import java.util.Iterator;
 import java.util.NoSuchElementException;
+import java.util.function.Consumer;
+import java.util.function.Predicate;
 
 import static org.apache.flink.util.CollectionUtil.MAX_ARRAY_SIZE;
 
@@ -56,9 +58,9 @@ public class HeapPriorityQueue<T extends HeapPriorityQueueElement> implements In
 	private static final int QUEUE_HEAD_INDEX = 1;
 
 	/**
-	 * Comparator for the contained elements.
+	 * Comparator for the priority of contained elements.
 	 */
-	private final Comparator<T> elementComparator;
+	private final PriorityComparator<T> elementPriorityComparator;
 
 	/**
 	 * The array that represents the heap-organized priority queue.
@@ -73,19 +75,28 @@ public class HeapPriorityQueue<T extends HeapPriorityQueueElement> implements In
 	/**
 	 * Creates an empty {@link HeapPriorityQueue} with the requested initial capacity.
 	 *
-	 * @param elementComparator comparator for the contained elements.
+	 * @param elementPriorityComparator comparator for the priority of contained elements.
 	 * @param minimumCapacity the minimum and initial capacity of this priority queue.
 	 */
 	@SuppressWarnings("unchecked")
 	public HeapPriorityQueue(
-		@Nonnull Comparator<T> elementComparator,
+		@Nonnull PriorityComparator<T> elementPriorityComparator,
 		@Nonnegative int minimumCapacity) {
 
-		this.elementComparator = elementComparator;
+		this.elementPriorityComparator = elementPriorityComparator;
 		this.queue = (T[]) new HeapPriorityQueueElement[QUEUE_HEAD_INDEX + minimumCapacity];
 	}
 
 	@Override
+	public void bulkPoll(@Nonnull Predicate<T> canConsume, @Nonnull Consumer<T> consumer) {
+		T element;
+		while ((element = peek()) != null && canConsume.test(element)) {
+			poll();
+			consumer.accept(element);
+		}
+	}
+
+	@Override
 	@Nullable
 	public T poll() {
 		return size() > 0 ? removeElementAtIndex(QUEUE_HEAD_INDEX) : null;
@@ -227,7 +238,7 @@ public class HeapPriorityQueue<T extends HeapPriorityQueueElement> implements In
 		final T currentElement = heap[idx];
 		int parentIdx = idx >>> 1;
 
-		while (parentIdx > 0 && isElementLessThen(currentElement, heap[parentIdx])) {
+		while (parentIdx > 0 && isElementPriorityLessThen(currentElement, heap[parentIdx])) {
 			moveElementToIdx(heap[parentIdx], idx);
 			idx = parentIdx;
 			parentIdx >>>= 1;
@@ -245,19 +256,19 @@ public class HeapPriorityQueue<T extends HeapPriorityQueueElement> implements In
 		int secondChildIdx = firstChildIdx + 1;
 
 		if (isElementIndexValid(secondChildIdx, heapSize) &&
-			isElementLessThen(heap[secondChildIdx], heap[firstChildIdx])) {
+			isElementPriorityLessThen(heap[secondChildIdx], heap[firstChildIdx])) {
 			firstChildIdx = secondChildIdx;
 		}
 
 		while (isElementIndexValid(firstChildIdx, heapSize) &&
-			isElementLessThen(heap[firstChildIdx], currentElement)) {
+			isElementPriorityLessThen(heap[firstChildIdx], currentElement)) {
 			moveElementToIdx(heap[firstChildIdx], idx);
 			idx = firstChildIdx;
 			firstChildIdx = idx << 1;
 			secondChildIdx = firstChildIdx + 1;
 
 			if (isElementIndexValid(secondChildIdx, heapSize) &&
-				isElementLessThen(heap[secondChildIdx], heap[firstChildIdx])) {
+				isElementPriorityLessThen(heap[secondChildIdx], heap[firstChildIdx])) {
 				firstChildIdx = secondChildIdx;
 			}
 		}
@@ -269,8 +280,8 @@ public class HeapPriorityQueue<T extends HeapPriorityQueueElement> implements In
 		return elementIndex <= heapSize;
 	}
 
-	private boolean isElementLessThen(T a, T b) {
-		return elementComparator.compare(a, b) < 0;
+	private boolean isElementPriorityLessThen(T a, T b) {
+		return elementPriorityComparator.comparePriority(a, b) < 0;
 	}
 
 	private void moveElementToIdx(T element, int idx) {

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSet.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSet.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSet.java
index 61313e9..79f319c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSet.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSet.java
@@ -18,20 +18,17 @@
 
 package org.apache.flink.runtime.state.heap;
 
-import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.runtime.state.KeyExtractorFunction;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
+import org.apache.flink.runtime.state.PriorityComparator;
 
 import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Comparator;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Set;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
@@ -49,7 +46,9 @@ import static org.apache.flink.util.Preconditions.checkArgument;
  *
  * @param <T> type of the contained elements.
  */
-public class HeapPriorityQueueSet<T extends HeapPriorityQueueElement> extends HeapPriorityQueue<T> {
+public class HeapPriorityQueueSet<T extends HeapPriorityQueueElement>
+	extends HeapPriorityQueue<T>
+	implements KeyGroupedInternalPriorityQueue<T> {
 
 	/**
 	 * Function to extract the key from contained elements.
@@ -74,7 +73,7 @@ public class HeapPriorityQueueSet<T extends HeapPriorityQueueElement> extends He
 	/**
 	 * Creates an empty {@link HeapPriorityQueueSet} with the requested initial capacity.
 	 *
-	 * @param elementComparator comparator for the contained elements.
+	 * @param elementPriorityComparator comparator for the priority of contained elements.
 	 * @param keyExtractor function to extract a key from the contained elements.
 	 * @param minimumCapacity the minimum and initial capacity of this priority queue.
 	 * @param keyGroupRange the key-group range of the elements in this set.
@@ -82,13 +81,13 @@ public class HeapPriorityQueueSet<T extends HeapPriorityQueueElement> extends He
 	 */
 	@SuppressWarnings("unchecked")
 	public HeapPriorityQueueSet(
-		@Nonnull Comparator<T> elementComparator,
+		@Nonnull PriorityComparator<T> elementPriorityComparator,
 		@Nonnull KeyExtractorFunction<T> keyExtractor,
 		@Nonnegative int minimumCapacity,
 		@Nonnull KeyGroupRange keyGroupRange,
 		@Nonnegative int totalNumberOfKeyGroups) {
 
-		super(elementComparator, minimumCapacity);
+		super(elementPriorityComparator, minimumCapacity);
 
 		this.keyExtractor = keyExtractor;
 
@@ -147,28 +146,9 @@ public class HeapPriorityQueueSet<T extends HeapPriorityQueueElement> extends He
 		}
 	}
 
-	/**
-	 * Returns an unmodifiable set of all elements in the given key-group.
-	 */
-	@Nonnull
-	public Set<T> getElementsForKeyGroup(@Nonnegative int keyGroupIdx) {
-		return Collections.unmodifiableSet(getDedupMapForKeyGroup(keyGroupIdx).keySet());
-	}
-
-	@VisibleForTesting
-	@SuppressWarnings("unchecked")
-	@Nonnull
-	public List<Set<T>> getElementsByKeyGroup() {
-		List<Set<T>> result = new ArrayList<>(deduplicationMapsByKeyGroup.length);
-		for (int i = 0; i < deduplicationMapsByKeyGroup.length; ++i) {
-			result.add(i, Collections.unmodifiableSet(deduplicationMapsByKeyGroup[i].keySet()));
-		}
-		return result;
-	}
-
 	private HashMap<T, T> getDedupMapForKeyGroup(
-		@Nonnegative int keyGroupIdx) {
-		return deduplicationMapsByKeyGroup[globalKeyGroupToLocalIndex(keyGroupIdx)];
+		@Nonnegative int keyGroupId) {
+		return deduplicationMapsByKeyGroup[globalKeyGroupToLocalIndex(keyGroupId)];
 	}
 
 	private HashMap<T, T> getDedupMapForElement(T element) {
@@ -182,4 +162,10 @@ public class HeapPriorityQueueSet<T extends HeapPriorityQueueElement> extends He
 		checkArgument(keyGroupRange.contains(keyGroup));
 		return keyGroup - keyGroupRange.getStartKeyGroup();
 	}
+
+	@Nonnull
+	@Override
+	public Set<T> getSubsetForKeyGroup(int keyGroupId) {
+		return getDedupMapForKeyGroup(keyGroupId).keySet();
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java
new file mode 100644
index 0000000..ee6fda9
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.state.KeyExtractorFunction;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
+import org.apache.flink.runtime.state.PriorityComparator;
+import org.apache.flink.runtime.state.PriorityQueueSetFactory;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * {@link PriorityQueueSetFactory} that creates {@link HeapPriorityQueueSet} instances.
+ */
+public class HeapPriorityQueueSetFactory implements PriorityQueueSetFactory {
+
+	@Nonnull
+	private final KeyGroupRange keyGroupRange; // handed to every queue set this factory creates
+
+	@Nonnegative
+	private final int totalKeyGroups; // handed to every queue set this factory creates
+
+	@Nonnegative
+	private final int minimumCapacity; // handed to every queue set this factory creates
+
+	public HeapPriorityQueueSetFactory(
+		@Nonnull KeyGroupRange keyGroupRange,
+		@Nonnegative int totalKeyGroups,
+		@Nonnegative int minimumCapacity) {
+
+		this.keyGroupRange = keyGroupRange;
+		this.totalKeyGroups = totalKeyGroups;
+		this.minimumCapacity = minimumCapacity;
+	}
+
+	@Nonnull
+	@Override
+	public <T extends HeapPriorityQueueElement> KeyGroupedInternalPriorityQueue<T> create(
+		@Nonnull String stateName,
+		@Nonnull TypeSerializer<T> byteOrderedElementSerializer,
+		@Nonnull PriorityComparator<T> elementPriorityComparator,
+		@Nonnull KeyExtractorFunction<T> keyExtractor) {
+		return new HeapPriorityQueueSet<>( // stateName and byteOrderedElementSerializer are not used here
+			elementPriorityComparator,
+			keyExtractor,
+			minimumCapacity,
+			keyGroupRange,
+			totalKeyGroups);
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueue.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueue.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueue.java
index af4d54f..6f4f911 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueue.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueue.java
@@ -22,7 +22,10 @@ import org.apache.flink.runtime.state.InternalPriorityQueue;
 import org.apache.flink.runtime.state.KeyExtractorFunction;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
+import org.apache.flink.runtime.state.PriorityComparator;
 import org.apache.flink.util.CloseableIterator;
+import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.IOUtils;
 
 import javax.annotation.Nonnegative;
@@ -30,7 +33,10 @@ import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
 import java.util.Collection;
-import java.util.Comparator;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.function.Consumer;
+import java.util.function.Predicate;
 
 /**
  * This implementation of {@link InternalPriorityQueue} is internally partitioned into sub-queues per key-group and
@@ -41,7 +47,7 @@ import java.util.Comparator;
  * @param <PQ> type type of sub-queue used for each key-group partition.
  */
 public class KeyGroupPartitionedPriorityQueue<T, PQ extends InternalPriorityQueue<T> & HeapPriorityQueueElement>
-	implements InternalPriorityQueue<T> {
+	implements InternalPriorityQueue<T>, KeyGroupedInternalPriorityQueue<T> {
 
 	/** A heap of heap sets. Each sub-heap represents the partition for a key-group.*/
 	@Nonnull
@@ -66,7 +72,7 @@ public class KeyGroupPartitionedPriorityQueue<T, PQ extends InternalPriorityQueu
 	@SuppressWarnings("unchecked")
 	public KeyGroupPartitionedPriorityQueue(
 		@Nonnull KeyExtractorFunction<T> keyExtractor,
-		@Nonnull Comparator<T> elementComparator,
+		@Nonnull PriorityComparator<T> elementPriorityComparator,
 		@Nonnull PartitionQueueSetFactory<T, PQ> orderedCacheFactory,
 		@Nonnull KeyGroupRange keyGroupRange,
 		@Nonnegative int totalKeyGroups) {
@@ -76,16 +82,25 @@ public class KeyGroupPartitionedPriorityQueue<T, PQ extends InternalPriorityQueu
 		this.firstKeyGroup = keyGroupRange.getStartKeyGroup();
 		this.keyGroupedHeaps = (PQ[]) new InternalPriorityQueue[keyGroupRange.getNumberOfKeyGroups()];
 		this.heapOfkeyGroupedHeaps = new HeapPriorityQueue<>(
-			new InternalPriorityQueueComparator<>(elementComparator),
+			new InternalPriorityQueueComparator<>(elementPriorityComparator),
 			keyGroupRange.getNumberOfKeyGroups());
 		for (int i = 0; i < keyGroupedHeaps.length; i++) {
 			final PQ keyGroupSubHeap =
-				orderedCacheFactory.create(firstKeyGroup + i, totalKeyGroups, elementComparator);
+				orderedCacheFactory.create(firstKeyGroup + i, totalKeyGroups, elementPriorityComparator);
 			keyGroupedHeaps[i] = keyGroupSubHeap;
 			heapOfkeyGroupedHeaps.add(keyGroupSubHeap);
 		}
 	}
 
+	@Override
+	public void bulkPoll(@Nonnull Predicate<T> canConsume, @Nonnull Consumer<T> consumer) {
+		T element;
+		while ((element = peek()) != null && canConsume.test(element)) {
+			poll();
+			consumer.accept(element);
+		}
+	}
+
 	@Nullable
 	@Override
 	public T poll() {
@@ -173,9 +188,28 @@ public class KeyGroupPartitionedPriorityQueue<T, PQ extends InternalPriorityQueu
 	private int computeKeyGroupIndex(T element) {
 		final Object extractKeyFromElement = keyExtractor.extractKeyFromElement(element);
 		final int keyGroupId = KeyGroupRangeAssignment.assignToKeyGroup(extractKeyFromElement, totalKeyGroups);
+		return globalKeyGroupToLocalIndex(keyGroupId);
+	}
+
+	private int globalKeyGroupToLocalIndex(int keyGroupId) {
 		return keyGroupId - firstKeyGroup;
 	}
 
+	@Nonnull
+	@Override
+	public Set<T> getSubsetForKeyGroup(int keyGroupId) {
+		HashSet<T> result = new HashSet<>();
+		PQ partitionQueue = keyGroupedHeaps[globalKeyGroupToLocalIndex(keyGroupId)];
+		try (CloseableIterator<T> iterator = partitionQueue.iterator()) {
+			while (iterator.hasNext()) {
+				result.add(iterator.next());
+			}
+		} catch (Exception e) {
+			throw new FlinkRuntimeException("Exception while iterating key group.", e);
+		}
+		return result;
+	}
+
 	/**
 	 * Iterator for {@link KeyGroupPartitionedPriorityQueue}. This iterator is not guaranteeing any order of elements.
 	 * Using code must {@link #close()} after usage.
@@ -236,24 +270,24 @@ public class KeyGroupPartitionedPriorityQueue<T, PQ extends InternalPriorityQueu
 	 * @param <Q> type of queue.
 	 */
 	private static final class InternalPriorityQueueComparator<T, Q extends InternalPriorityQueue<T>>
-		implements Comparator<Q> {
+		implements PriorityComparator<Q> {
 
 		/** Comparator for the queue elements, so we can compare their heads. */
 		@Nonnull
-		private final Comparator<T> elementComparator;
+		private final PriorityComparator<T> elementPriorityComparator;
 
-		InternalPriorityQueueComparator(@Nonnull Comparator<T> elementComparator) {
-			this.elementComparator = elementComparator;
+		InternalPriorityQueueComparator(@Nonnull PriorityComparator<T> elementPriorityComparator) {
+			this.elementPriorityComparator = elementPriorityComparator;
 		}
 
 		@Override
-		public int compare(Q o1, Q o2) {
+		public int comparePriority(Q o1, Q o2) {
 			final T left = o1.peek();
 			final T right = o2.peek();
 			if (left == null) {
 				return (right == null ? 0 : 1);
 			} else {
-				return (right == null ? -1 : elementComparator.compare(left, right));
+				return (right == null ? -1 : elementPriorityComparator.comparePriority(left, right));
 			}
 		}
 	}
@@ -271,10 +305,13 @@ public class KeyGroupPartitionedPriorityQueue<T, PQ extends InternalPriorityQueu
 		 *
 		 * @param keyGroupId the key-group of the elements managed by the produced queue.
 		 * @param numKeyGroups the total number of key-groups in the job.
-		 * @param elementComparator the comparator that determines the order of the managed elements.
+		 * @param elementPriorityComparator the comparator that determines the order of managed elements by priority.
 		 * @return a new queue for the given key-group.
 		 */
 		@Nonnull
-		PQS create(@Nonnegative int keyGroupId, @Nonnegative int numKeyGroups, @Nonnull Comparator<T> elementComparator);
+		PQS create(
+			@Nonnegative int keyGroupId,
+			@Nonnegative int numKeyGroups,
+			@Nonnull PriorityComparator<T> elementPriorityComparator);
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/TreeOrderedSetCache.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/TreeOrderedSetCache.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/TreeOrderedSetCache.java
index 0e7d9dd..14c281e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/TreeOrderedSetCache.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/TreeOrderedSetCache.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.state.heap;
 
+import org.apache.flink.util.CloseableIterator;
 import org.apache.flink.util.Preconditions;
 
 import it.unimi.dsi.fastutil.objects.ObjectAVLTreeSet;
@@ -125,4 +126,10 @@ public class TreeOrderedSetCache<E> implements CachingInternalPriorityQueueSet.O
 	public E peekLast() {
 		return !avlTree.isEmpty() ? avlTree.last() : null;
 	}
+
+	@Nonnull
+	@Override
+	public CloseableIterator<E> orderedIterator() {
+		return CloseableIterator.adapterForIterator(avlTree.iterator());
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
index 3da60e4..d78944c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
@@ -35,6 +35,7 @@ import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.TaskStateManager;
 import org.apache.flink.runtime.state.filesystem.AbstractFileStateBackend;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
 import org.apache.flink.util.TernaryBoolean;
 
 import javax.annotation.Nullable;
@@ -309,7 +310,8 @@ public class MemoryStateBackend extends AbstractFileStateBackend implements Conf
 			TaskKvStateRegistry kvStateRegistry) {
 
 		TaskStateManager taskStateManager = env.getTaskStateManager();
-
+		HeapPriorityQueueSetFactory priorityQueueSetFactory =
+			new HeapPriorityQueueSetFactory(keyGroupRange, numberOfKeyGroups, 128);
 		return new HeapKeyedStateBackend<>(
 				kvStateRegistry,
 				keySerializer,
@@ -318,7 +320,8 @@ public class MemoryStateBackend extends AbstractFileStateBackend implements Conf
 				keyGroupRange,
 				isUsingAsynchronousSnapshots(),
 				env.getExecutionConfig(),
-				taskStateManager.createLocalRecoveryConfig());
+				taskStateManager.createLocalRecoveryConfig(),
+				priorityQueueSetFactory);
 	}
 
 	// ------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/test/java/org/apache/flink/runtime/state/InternalPriorityQueueTestBase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/InternalPriorityQueueTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/InternalPriorityQueueTestBase.java
index c0c3ba4..0cd551c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/InternalPriorityQueueTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/InternalPriorityQueueTestBase.java
@@ -51,8 +51,16 @@ public abstract class InternalPriorityQueueTestBase extends TestLogger {
 
 	protected static final KeyGroupRange KEY_GROUP_RANGE = new KeyGroupRange(0, 2);
 	protected static final KeyExtractorFunction<TestElement> KEY_EXTRACTOR_FUNCTION = TestElement::getKey;
-	protected static final Comparator<TestElement> TEST_ELEMENT_COMPARATOR =
-		Comparator.comparingLong(TestElement::getPriority).thenComparingLong(TestElement::getKey);
+	protected static final PriorityComparator<TestElement> TEST_ELEMENT_PRIORITY_COMPARATOR =
+		(left, right) -> Long.compare(left.getPriority(), right.getPriority());
+	protected static final Comparator<TestElement> TEST_ELEMENT_COMPARATOR = (o1, o2) -> {
+		int priorityCmp = TEST_ELEMENT_PRIORITY_COMPARATOR.comparePriority(o1, o2);
+		if (priorityCmp != 0) {
+			return priorityCmp;
+		}
+		// to fully comply with compareTo/equals contract.
+		return Long.compare(o1.getKey(), o2.getKey());
+	};
 
 	protected static void insertRandomElements(
 		@Nonnull InternalPriorityQueue<TestElement> priorityQueue,

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java
index 3c06b71..dfcdffc 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java
@@ -53,7 +53,8 @@ public class StateSnapshotCompressionTest extends TestLogger {
 			new KeyGroupRange(0, 15),
 			true,
 			executionConfig,
-			TestLocalRecoveryConfig.disabled());
+			TestLocalRecoveryConfig.disabled(),
+			mock(PriorityQueueSetFactory.class));
 
 		try {
 			Assert.assertTrue(
@@ -75,7 +76,8 @@ public class StateSnapshotCompressionTest extends TestLogger {
 			new KeyGroupRange(0, 15),
 			true,
 			executionConfig,
-			TestLocalRecoveryConfig.disabled());
+			TestLocalRecoveryConfig.disabled(),
+			mock(PriorityQueueSetFactory.class));
 
 		try {
 			Assert.assertTrue(
@@ -115,7 +117,8 @@ public class StateSnapshotCompressionTest extends TestLogger {
 			new KeyGroupRange(0, 15),
 			true,
 			executionConfig,
-			TestLocalRecoveryConfig.disabled());
+			TestLocalRecoveryConfig.disabled(),
+			mock(PriorityQueueSetFactory.class));
 
 		try {
 
@@ -156,7 +159,8 @@ public class StateSnapshotCompressionTest extends TestLogger {
 			new KeyGroupRange(0, 15),
 			true,
 			executionConfig,
-			TestLocalRecoveryConfig.disabled());
+			TestLocalRecoveryConfig.disabled(),
+			mock(PriorityQueueSetFactory.class));
 		try {
 
 			stateBackend.restore(StateObjectCollection.singleton(stateHandle));

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetTest.java
index 618da4e..415497d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetTest.java
@@ -25,7 +25,7 @@ public class HeapPriorityQueueSetTest extends HeapPriorityQueueTest {
 	@Override
 	protected HeapPriorityQueueSet<TestElement> newPriorityQueue(int initialCapacity) {
 		return new HeapPriorityQueueSet<>(
-			TEST_ELEMENT_COMPARATOR,
+			TEST_ELEMENT_PRIORITY_COMPARATOR,
 			KEY_EXTRACTOR_FUNCTION,
 			initialCapacity,
 			KEY_GROUP_RANGE,

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueTest.java
index 8ffb8b8..6ba5a68 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueTest.java
@@ -89,7 +89,7 @@ public class HeapPriorityQueueTest extends InternalPriorityQueueTestBase {
 
 	@Override
 	protected HeapPriorityQueue<TestElement> newPriorityQueue(int initialCapacity) {
-		return new HeapPriorityQueue<>(TEST_ELEMENT_COMPARATOR, initialCapacity);
+		return new HeapPriorityQueue<>(TEST_ELEMENT_PRIORITY_COMPARATOR, initialCapacity);
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapStateBackendTestBase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapStateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapStateBackendTestBase.java
index bf428dc..cf6aef4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapStateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapStateBackendTestBase.java
@@ -49,14 +49,18 @@ public abstract class HeapStateBackendTestBase {
 	}
 
 	public <K> HeapKeyedStateBackend<K> createKeyedBackend(TypeSerializer<K> keySerializer) throws Exception {
+		final KeyGroupRange keyGroupRange = new KeyGroupRange(0, 15);
+		final int numKeyGroups = keyGroupRange.getNumberOfKeyGroups();
+
 		return new HeapKeyedStateBackend<>(
 			mock(TaskKvStateRegistry.class),
 			keySerializer,
 			HeapStateBackendTestBase.class.getClassLoader(),
-			16,
-			new KeyGroupRange(0, 15),
+			numKeyGroups,
+			keyGroupRange,
 			async,
 			new ExecutionConfig(),
-			TestLocalRecoveryConfig.disabled());
+			TestLocalRecoveryConfig.disabled(),
+			new HeapPriorityQueueSetFactory(keyGroupRange, numKeyGroups, 128));
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueueTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueueTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueueTest.java
index 277de19..d348e10 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueueTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueueTest.java
@@ -29,7 +29,7 @@ public class KeyGroupPartitionedPriorityQueueTest extends InternalPriorityQueueT
 	protected InternalPriorityQueue<TestElement> newPriorityQueue(int initialCapacity) {
 		return new KeyGroupPartitionedPriorityQueue<>(
 			KEY_EXTRACTOR_FUNCTION,
-			TEST_ELEMENT_COMPARATOR,
+			TEST_ELEMENT_PRIORITY_COMPARATOR,
 			newFactory(initialCapacity),
 			KEY_GROUP_RANGE, KEY_GROUP_RANGE.getNumberOfKeyGroups());
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RockDBBackendOptions.java
----------------------------------------------------------------------
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RockDBBackendOptions.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RockDBBackendOptions.java
new file mode 100644
index 0000000..ede45e3
--- /dev/null
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RockDBBackendOptions.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import org.apache.flink.configuration.ConfigOption;
+import org.apache.flink.configuration.ConfigOptions;
+
+/**
+ * Configuration options for the RocksDB backend.
+ */
+public class RockDBBackendOptions {
+
+	/**
+	 * Choice of implementation for priority queue state (e.g. timers).
+	 */
+	public static final ConfigOption<String> PRIORITY_QUEUE_STATE_TYPE = ConfigOptions
+		.key("backend.rocksdb.priority_queue_state_type")
+		.defaultValue(RocksDBStateBackend.PriorityQueueStateType.HEAP.name())
+		.withDescription("This determines the implementation for the priority queue state (e.g. timers). Options are " +
+			"either " + RocksDBStateBackend.PriorityQueueStateType.HEAP.name() + " (heap-based, default) or " +
+			RocksDBStateBackend.PriorityQueueStateType.ROCKS.name() + " for an implementation based on RocksDB.");
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 21d2a65..f2430ae 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -58,14 +58,18 @@ import org.apache.flink.runtime.state.DirectoryStateHandle;
 import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.IncrementalKeyedStateHandle;
 import org.apache.flink.runtime.state.IncrementalLocalKeyedStateHandle;
+import org.apache.flink.runtime.state.KeyExtractorFunction;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
+import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.LocalRecoveryConfig;
 import org.apache.flink.runtime.state.LocalRecoveryDirectoryProvider;
 import org.apache.flink.runtime.state.PlaceholderStreamStateHandle;
+import org.apache.flink.runtime.state.PriorityComparator;
+import org.apache.flink.runtime.state.PriorityQueueSetFactory;
 import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo;
 import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator;
 import org.apache.flink.runtime.state.SnapshotDirectory;
@@ -76,7 +80,13 @@ import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TieBreakingPriorityComparator;
 import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
+import org.apache.flink.runtime.state.heap.CachingInternalPriorityQueueSet;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
+import org.apache.flink.runtime.state.heap.KeyGroupPartitionedPriorityQueue;
+import org.apache.flink.runtime.state.heap.TreeOrderedSetCache;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.FileUtils;
 import org.apache.flink.util.FlinkRuntimeException;
@@ -243,6 +253,9 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	/** The snapshot strategy, e.g., if we use full or incremental checkpoints, local state, and so on. */
 	private final SnapshotStrategy<SnapshotResult<KeyedStateHandle>> snapshotStrategy;
 
+	/** Factory for priority queue state. */
+	private PriorityQueueSetFactory priorityQueueFactory;
+
 	public RocksDBKeyedStateBackend(
 		String operatorIdentifier,
 		ClassLoader userCodeClassLoader,
@@ -255,7 +268,8 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		KeyGroupRange keyGroupRange,
 		ExecutionConfig executionConfig,
 		boolean enableIncrementalCheckpointing,
-		LocalRecoveryConfig localRecoveryConfig
+		LocalRecoveryConfig localRecoveryConfig,
+		RocksDBStateBackend.PriorityQueueStateType priorityQueueStateType
 	) throws IOException {
 
 		super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange, executionConfig);
@@ -296,6 +310,17 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 		this.writeOptions = new WriteOptions().setDisableWAL(true);
 
+		switch (priorityQueueStateType) {
+			case HEAP:
+				this.priorityQueueFactory = new HeapPriorityQueueSetFactory(keyGroupRange, numberOfKeyGroups, 128);
+				break;
+			case ROCKS:
+				this.priorityQueueFactory = new RocksDBPriorityQueueSetFactory();
+				break;
+			default:
+				break;
+		}
+
 		LOG.debug("Setting initial keyed backend uid for operator {} to {}.", this.operatorIdentifier, this.backendUID);
 	}
 
@@ -378,6 +403,11 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				IOUtils.closeQuietly(columnMetaData.f0);
 			}
 
+			// ... then close the priority queue related resources ...
+			if (priorityQueueFactory instanceof AutoCloseable) {
+				IOUtils.closeQuietly((AutoCloseable) priorityQueueFactory);
+			}
+
 			// ... and finally close the DB instance ...
 			IOUtils.closeQuietly(db);
 
@@ -394,6 +424,17 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 	}
 
+	@Nonnull
+	@Override
+	public <T extends HeapPriorityQueueElement> KeyGroupedInternalPriorityQueue<T> create(
+		@Nonnull String stateName,
+		@Nonnull TypeSerializer<T> byteOrderedElementSerializer,
+		@Nonnull PriorityComparator<T> elementComparator,
+		@Nonnull KeyExtractorFunction<T> keyExtractor) {
+
+		return priorityQueueFactory.create(stateName, byteOrderedElementSerializer, elementComparator, keyExtractor);
+	}
+
 	private void cleanInstanceBasePath() {
 		LOG.info("Deleting existing instance base directory {}.", instanceBasePath);
 
@@ -1290,7 +1331,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				namespaceSerializer,
 				stateDesc.getSerializer());
 
-			ColumnFamilyHandle columnFamily = createColumnFamily(stateName);
+			ColumnFamilyHandle columnFamily = createColumnFamily(stateName, db);
 
 			stateInfo = Tuple2.of(columnFamily, newMetaInfo);
 			kvStateInformation.put(stateDesc.getName(), stateInfo);
@@ -1302,7 +1343,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	/**
 	 * Creates a column family handle for use with a k/v state.
 	 */
-	private ColumnFamilyHandle createColumnFamily(String stateName) throws IOException {
+	private ColumnFamilyHandle createColumnFamily(String stateName, RocksDB db) {
 		byte[] nameBytes = stateName.getBytes(ConfigConstants.DEFAULT_CHARSET);
 		Preconditions.checkState(!Arrays.equals(RocksDB.DEFAULT_COLUMN_FAMILY, nameBytes),
 			"The chosen state name 'default' collides with the name of the default column family!");
@@ -1312,7 +1353,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		try {
 			return db.createColumnFamily(columnDescriptor);
 		} catch (RocksDBException e) {
-			throw new IOException("Error creating ColumnFamilyHandle.", e);
+			throw new FlinkRuntimeException("Error creating ColumnFamilyHandle.", e);
 		}
 	}
 
@@ -2579,4 +2620,126 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		ReadOptions readOptions) {
 		return new RocksIteratorWrapper(db.newIterator(columnFamilyHandle, readOptions));
 	}
+
+	/**
+	 * Encapsulates the logic and resources in connection with creating priority queue state structures.
+	 */
+	class RocksDBPriorityQueueSetFactory implements PriorityQueueSetFactory, AutoCloseable {
+
+		/** Default cache size per key-group. */
+		private static final int DEFAULT_CACHES_SIZE = 8 * 1024;
+
+		/** A shared buffer to serialize elements for the priority queue. */
+		@Nonnull
+		private final ByteArrayOutputStreamWithPos elementSerializationOutStream;
+
+		/** A shared adapter wrapper around elementSerializationOutStream to become a {@link DataOutputView}. */
+		@Nonnull
+		private final DataOutputViewStreamWrapper elementSerializationOutView;
+
+		/** A shared {@link RocksDBWriteBatchWrapper} to batch modifications to priority queues. */
+		@Nonnull
+		private final RocksDBWriteBatchWrapper writeBatchWrapper;
+
+		/** Map to track all column families created to back priority queues. */
+		@Nonnull
+		private final Map<String, ColumnFamilyHandle> priorityQueueColumnFamilies;
+
+		/** The mandatory default column family, so that we can close it later. */
+		@Nonnull
+		private final ColumnFamilyHandle defaultColumnFamily;
+
+		/** Path of the RocksDB instance that holds the priority queues. */
+		@Nonnull
+		private final File pqInstanceRocksDBPath;
+
+		/** RocksDB instance that holds the priority queues. */
+		@Nonnull
+		private final RocksDB pqDb;
+
+		RocksDBPriorityQueueSetFactory() throws IOException {
+			this.pqInstanceRocksDBPath = new File(instanceBasePath, "pqdb");
+			if (pqInstanceRocksDBPath.exists()) {
+				try {
+					FileUtils.deleteDirectory(pqInstanceRocksDBPath);
+				} catch (IOException ex) {
+					LOG.warn("Could not delete instance path for PQ RocksDB: " + pqInstanceRocksDBPath, ex);
+				}
+			}
+			List<ColumnFamilyHandle> columnFamilyHandles = new ArrayList<>(1);
+			this.pqDb = openDB(pqInstanceRocksDBPath.getAbsolutePath(), Collections.emptyList(), columnFamilyHandles);
+			this.elementSerializationOutStream = new ByteArrayOutputStreamWithPos();
+			this.elementSerializationOutView = new DataOutputViewStreamWrapper(elementSerializationOutStream);
+			this.writeBatchWrapper = new RocksDBWriteBatchWrapper(pqDb, writeOptions);
+			this.defaultColumnFamily = columnFamilyHandles.get(0);
+			this.priorityQueueColumnFamilies = new HashMap<>();
+		}
+
+		@Nonnull
+		@Override
+		public <T extends HeapPriorityQueueElement> KeyGroupedInternalPriorityQueue<T> create(
+			@Nonnull String stateName,
+			@Nonnull TypeSerializer<T> byteOrderedElementSerializer,
+			@Nonnull PriorityComparator<T> elementPriorityComparator,
+			@Nonnull KeyExtractorFunction<T> keyExtractor) {
+
+			final ColumnFamilyHandle columnFamilyHandle =
+				priorityQueueColumnFamilies.computeIfAbsent(
+					stateName,
+					(name) -> RocksDBKeyedStateBackend.this.createColumnFamily(name, pqDb));
+
+			@Nonnull
+			TieBreakingPriorityComparator<T> tieBreakingComparator =
+				new TieBreakingPriorityComparator<>(
+					elementPriorityComparator,
+					byteOrderedElementSerializer,
+					elementSerializationOutStream,
+					elementSerializationOutView);
+
+			return new KeyGroupPartitionedPriorityQueue<>(
+				keyExtractor,
+				elementPriorityComparator,
+				new KeyGroupPartitionedPriorityQueue.PartitionQueueSetFactory<T, CachingInternalPriorityQueueSet<T>>() {
+					@Nonnull
+					@Override
+					public CachingInternalPriorityQueueSet<T> create(
+						int keyGroupId,
+						int numKeyGroups,
+						@Nonnull PriorityComparator<T> elementPriorityComparator) {
+
+						CachingInternalPriorityQueueSet.OrderedSetCache<T> cache =
+							new TreeOrderedSetCache<>(tieBreakingComparator, DEFAULT_CACHES_SIZE);
+						CachingInternalPriorityQueueSet.OrderedSetStore<T> store =
+							new RocksDBOrderedSetStore<>(
+								keyGroupId,
+								keyGroupPrefixBytes,
+								pqDb,
+								columnFamilyHandle,
+								byteOrderedElementSerializer,
+								elementSerializationOutStream,
+								elementSerializationOutView,
+								writeBatchWrapper);
+
+						return new CachingInternalPriorityQueueSet<>(cache, store);
+					}
+				},
+				keyGroupRange,
+				numberOfKeyGroups);
+		}
+
+		@Override
+		public void close() {
+			IOUtils.closeQuietly(writeBatchWrapper);
+			for (ColumnFamilyHandle columnFamilyHandle : priorityQueueColumnFamilies.values()) {
+				IOUtils.closeQuietly(columnFamilyHandle);
+			}
+			IOUtils.closeQuietly(defaultColumnFamily);
+			IOUtils.closeQuietly(pqDb);
+			try {
+				FileUtils.deleteDirectory(pqInstanceRocksDBPath);
+			} catch (IOException ex) {
+				LOG.warn("Could not delete instance path for PQ RocksDB: " + pqInstanceRocksDBPath, ex);
+			}
+		}
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStore.java
----------------------------------------------------------------------
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStore.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStore.java
index e512933..5284314 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStore.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStore.java
@@ -28,7 +28,6 @@ import org.apache.flink.util.CloseableIterator;
 import org.apache.flink.util.FlinkRuntimeException;
 
 import org.rocksdb.ColumnFamilyHandle;
-import org.rocksdb.ReadOptions;
 import org.rocksdb.RocksDB;
 import org.rocksdb.RocksDBException;
 
@@ -61,10 +60,6 @@ public class RocksDBOrderedSetStore<T> implements CachingInternalPriorityQueueSe
 	@Nonnull
 	private final ColumnFamilyHandle columnFamilyHandle;
 
-	/** Read options for RocksDB. */
-	@Nonnull
-	private final ReadOptions readOptions;
-
 	/**
 	 * Serializer for the contained elements. The lexicographical order of the bytes of serialized objects must be
 	 * aligned with their logical order.
@@ -93,14 +88,12 @@ public class RocksDBOrderedSetStore<T> implements CachingInternalPriorityQueueSe
 		@Nonnegative int keyGroupPrefixBytes,
 		@Nonnull RocksDB db,
 		@Nonnull ColumnFamilyHandle columnFamilyHandle,
-		@Nonnull ReadOptions readOptions,
 		@Nonnull TypeSerializer<T> byteOrderProducingSerializer,
 		@Nonnull ByteArrayOutputStreamWithPos outputStream,
 		@Nonnull DataOutputViewStreamWrapper outputView,
 		@Nonnull RocksDBWriteBatchWrapper batchWrapper) {
 		this.db = db;
 		this.columnFamilyHandle = columnFamilyHandle;
-		this.readOptions = readOptions;
 		this.byteOrderProducingSerializer = byteOrderProducingSerializer;
 		this.outputStream = outputStream;
 		this.outputView = outputView;
@@ -169,7 +162,7 @@ public class RocksDBOrderedSetStore<T> implements CachingInternalPriorityQueueSe
 
 		return new RocksToJavaIteratorAdapter(
 			new RocksIteratorWrapper(
-				db.newIterator(columnFamilyHandle, readOptions)));
+				db.newIterator(columnFamilyHandle)));
 	}
 
 	/**
@@ -232,6 +225,10 @@ public class RocksDBOrderedSetStore<T> implements CachingInternalPriorityQueueSe
 		private RocksToJavaIteratorAdapter(@Nonnull RocksIteratorWrapper iterator) {
 			this.iterator = iterator;
 			try {
+				// TODO we could check if it is more efficient to make the seek more specific, e.g. with a provided hint
+				// that is lexicographically closer to the first expected element in the key-group. I wonder if this could
+				// help to improve the seek if there are many tombstones for elements at the beginning of the key-group
+				// (like for elements that have been removed in previous polling, before they are compacted away).
 				iterator.seek(groupPrefixBytes);
 				deserializeNextElementIfAvailable();
 			} catch (Exception ex) {

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
index 81d6265..998521b 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
@@ -59,6 +59,7 @@ import java.util.List;
 import java.util.Random;
 import java.util.UUID;
 
+import static org.apache.flink.contrib.streaming.state.RockDBBackendOptions.PRIORITY_QUEUE_STATE_TYPE;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
@@ -76,6 +77,14 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
  */
 public class RocksDBStateBackend extends AbstractStateBackend implements ConfigurableStateBackend {
 
+	/**
+	 * The options to choose from for the type of priority queue state.
+	 */
+	public enum PriorityQueueStateType {
+		HEAP,
+		ROCKS
+	}
+
 	private static final long serialVersionUID = 1L;
 
 	private static final Logger LOG = LoggerFactory.getLogger(RocksDBStateBackend.class);
@@ -109,6 +118,9 @@ public class RocksDBStateBackend extends AbstractStateBackend implements Configu
 	/** This determines if incremental checkpointing is enabled. */
 	private final TernaryBoolean enableIncrementalCheckpointing;
 
+	/** This determines the type of priority queue state. */
+	private final PriorityQueueStateType priorityQueueStateType;
+
 	// -- runtime values, set on TaskManager when initializing / using the backend
 
 	/** Base paths for RocksDB directory, as initialized. */
@@ -221,6 +233,8 @@ public class RocksDBStateBackend extends AbstractStateBackend implements Configu
 	public RocksDBStateBackend(StateBackend checkpointStreamBackend, TernaryBoolean enableIncrementalCheckpointing) {
 		this.checkpointStreamBackend = checkNotNull(checkpointStreamBackend);
 		this.enableIncrementalCheckpointing = enableIncrementalCheckpointing;
+		// for now, we still use the heap-based implementation as default
+		this.priorityQueueStateType = PriorityQueueStateType.HEAP;
 	}
 
 	/**
@@ -256,6 +270,11 @@ public class RocksDBStateBackend extends AbstractStateBackend implements Configu
 		this.enableIncrementalCheckpointing = original.enableIncrementalCheckpointing.resolveUndefined(
 			config.getBoolean(CheckpointingOptions.INCREMENTAL_CHECKPOINTS));
 
+		final String priorityQueueTypeString = config.getString(PRIORITY_QUEUE_STATE_TYPE.key(), "");
+
+		this.priorityQueueStateType = priorityQueueTypeString.length() > 0 ?
+			PriorityQueueStateType.valueOf(priorityQueueTypeString.toUpperCase()) : original.priorityQueueStateType;
+
 		// configure local directories
 		if (original.localRocksDbDirectories != null) {
 			this.localRocksDbDirectories = original.localRocksDbDirectories;
@@ -422,7 +441,8 @@ public class RocksDBStateBackend extends AbstractStateBackend implements Configu
 				keyGroupRange,
 				env.getExecutionConfig(),
 				isIncrementalCheckpointsEnabled(),
-				localRecoveryConfig);
+				localRecoveryConfig,
+				priorityQueueStateType);
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/CachingInternalPriorityQueueSetWithRocksDBStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/CachingInternalPriorityQueueSetWithRocksDBStoreTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/CachingInternalPriorityQueueSetWithRocksDBStoreTest.java
index ae20cf2..5f26835 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/CachingInternalPriorityQueueSetWithRocksDBStoreTest.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/CachingInternalPriorityQueueSetWithRocksDBStoreTest.java
@@ -57,7 +57,6 @@ public class CachingInternalPriorityQueueSetWithRocksDBStoreTest extends Caching
 			prefixBytes,
 			rocksDBResource.getRocksDB(),
 			rocksDBResource.getDefaultColumnFamily(),
-			rocksDBResource.getReadOptions(),
 			TestElementSerializer.INSTANCE,
 			outputStream,
 			outputView,

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStoreTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStoreTest.java
index 256a83b..0b1d07b 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStoreTest.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBOrderedSetStoreTest.java
@@ -124,7 +124,6 @@ public class RocksDBOrderedSetStoreTest {
 			keyGroupPrefixBytes,
 			rocksDBResource.getRocksDB(),
 			rocksDBResource.getDefaultColumnFamily(),
-			rocksDBResource.getReadOptions(),
 			byteOrderSerializer,
 			outputStreamWithPos,
 			outputView,

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
index ad89583..69069d6 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
@@ -240,7 +240,8 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 				new KeyGroupRange(0, 0),
 				new ExecutionConfig(),
 				enableIncrementalCheckpointing,
-				TestLocalRecoveryConfig.disabled());
+				TestLocalRecoveryConfig.disabled(),
+				RocksDBStateBackend.PriorityQueueStateType.HEAP);
 
 			verify(columnFamilyOptions, Mockito.times(1))
 				.setMergeOperatorName(RocksDBKeyedStateBackend.MERGE_OPERATOR_NAME);

http://git-wip-us.apache.org/repos/asf/flink/blob/79b38f8f/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
index 9915dd5..797a26a 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
@@ -730,9 +730,11 @@ public abstract class AbstractStreamOperator<OUT>
 		checkTimerServiceInitialization();
 
 		// the following casting is to overcome type restrictions.
-		TypeSerializer<K> keySerializer = (TypeSerializer<K>) getKeyedStateBackend().getKeySerializer();
+		KeyedStateBackend<K> keyedStateBackend = getKeyedStateBackend();
+		TypeSerializer<K> keySerializer = keyedStateBackend.getKeySerializer();
 		InternalTimeServiceManager<K> keyedTimeServiceHandler = (InternalTimeServiceManager<K>) timeServiceManager;
-		return keyedTimeServiceHandler.getInternalTimerService(name, keySerializer, namespaceSerializer, triggerable);
+		TimerSerializer<K, N> timerSerializer = new TimerSerializer<>(keySerializer, namespaceSerializer);
+		return keyedTimeServiceHandler.getInternalTimerService(name, timerSerializer, triggerable);
 	}
 
 	public void processWatermark(Watermark mark) throws Exception {