You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by az...@apache.org on 2020/05/19 07:46:55 UTC

[flink] 03/04: [FLINK-15758][MemManager] Release segment and its unsafe memory in GC Cleaner

This is an automated email from the ASF dual-hosted git repository.

azagrebin pushed a commit to branch release-1.10
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 2a4520935c8b546ac05f664f947d143b06322000
Author: Andrey Zagrebin <az...@apache.org>
AuthorDate: Tue Feb 4 17:34:20 2020 +0100

    [FLINK-15758][MemManager] Release segment and its unsafe memory in GC Cleaner
    
    After #9747, managed memory is allocated from UNSAFE, not as direct nio buffers as before 1.10.
    Releasing a segment also released its underlying unsafe memory, which is dangerous in general
    because there can still be references to Java objects giving access to the released memory. If such a reference
    ever leaks, the illegal memory access can corrupt memory used by other code without even causing a segmentation fault.
    
    The solution is similar to how Java handles the direct memory limit:
    - underlying byte buffers of segments are registered to phantom reference queue with a Java GC cleaner which releases the unsafe memory
    - all allocations and reservations are managed by a memory limit and an atomic available memory
    - if available memory is insufficient during a reservation, phantom reference queue processing is triggered to run any cleaners the GC has hopefully discovered by then
    - memory can be released directly or in a GC cleaner
    
    GC is also sped up by nulling out the byte buffer reference in `HybridMemorySegment#free`, since the buffer is inaccessible anyway after freeing.
    Otherwise, many tests that hold accidental references to memory segments would have to be fixed to stop holding them.
    `MemoryManager#verifyEmpty` checks that everything can be GC'ed at the end of tests and,
    in production, after a slot is closed, to detect memory leaks if any other references are still held, e.g. from `HybridMemorySegment#wrap`.
    
    This closes #11109.
---
 .../flink/core/memory/HybridMemorySegment.java     |  27 +-
 .../flink/core/memory/MemorySegmentFactory.java    |  17 +-
 .../org/apache/flink/core/memory/MemoryUtils.java  |   7 +-
 .../apache/flink/util/JavaGcCleanerWrapper.java    | 413 ++++++++++++++-------
 .../flink/core/memory/CrossSegmentTypeTest.java    |   2 +-
 .../flink/core/memory/EndiannessAccessChecks.java  |   2 +-
 .../HybridOffHeapUnsafeMemorySegmentTest.java      |   4 +-
 .../flink/core/memory/MemorySegmentChecksTest.java |   4 +-
 .../core/memory/MemorySegmentUndersizedTest.java   |   4 +-
 .../core/memory/OperationsOnFreedSegmentTest.java  |   2 +-
 .../flink/util/JavaGcCleanerWrapperTest.java       |   2 +-
 .../apache/flink/runtime/memory/MemoryManager.java | 100 ++---
 .../flink/runtime/memory/UnsafeMemoryBudget.java   | 183 +++++++++
 .../flink/runtime/memory/MemoryManagerTest.java    |  48 +++
 .../runtime/memory/UnsafeMemoryBudgetTest.java     |  85 +++++
 15 files changed, 665 insertions(+), 235 deletions(-)

diff --git a/flink-core/src/main/java/org/apache/flink/core/memory/HybridMemorySegment.java b/flink-core/src/main/java/org/apache/flink/core/memory/HybridMemorySegment.java
index 1693e9a..fb7a4ba 100644
--- a/flink-core/src/main/java/org/apache/flink/core/memory/HybridMemorySegment.java
+++ b/flink-core/src/main/java/org/apache/flink/core/memory/HybridMemorySegment.java
@@ -56,11 +56,7 @@ public final class HybridMemorySegment extends MemorySegment {
 	 * released.
 	 */
 	@Nullable
-	private final ByteBuffer offHeapBuffer;
-
-	/** The cleaner is called to free the underlying native memory. */
-	@Nullable
-	private final Runnable cleaner;
+	private ByteBuffer offHeapBuffer;
 
 	/**
 	  * Creates a new memory segment that represents the memory backing the given direct byte buffer.
@@ -71,13 +67,11 @@ public final class HybridMemorySegment extends MemorySegment {
 	  *
 	  * @param buffer The byte buffer whose memory is represented by this memory segment.
 	  * @param owner The owner references by this memory segment.
-	  * @param cleaner optional action to run upon freeing the segment.
 	  * @throws IllegalArgumentException Thrown, if the given ByteBuffer is not direct.
 	  */
-	HybridMemorySegment(@Nonnull ByteBuffer buffer, @Nullable Object owner, @Nullable Runnable cleaner) {
+	HybridMemorySegment(@Nonnull ByteBuffer buffer, @Nullable Object owner) {
 		super(getByteBufferAddress(buffer), buffer.capacity(), owner);
 		this.offHeapBuffer = buffer;
-		this.cleaner = cleaner;
 	}
 
 	/**
@@ -91,13 +85,18 @@ public final class HybridMemorySegment extends MemorySegment {
 	HybridMemorySegment(byte[] buffer, Object owner) {
 		super(buffer, owner);
 		this.offHeapBuffer = null;
-		this.cleaner = null;
 	}
 
 	// -------------------------------------------------------------------------
 	//  MemorySegment operations
 	// -------------------------------------------------------------------------
 
+	@Override
+	public void free() {
+		super.free();
+		offHeapBuffer = null; // to enable GC of unsafe memory
+	}
+
 	/**
 	 * Gets the buffer that owns the memory of this memory segment.
 	 *
@@ -106,6 +105,8 @@ public final class HybridMemorySegment extends MemorySegment {
 	public ByteBuffer getOffHeapBuffer() {
 		if (offHeapBuffer != null) {
 			return offHeapBuffer;
+		} else if (isFreed()) {
+			throw new IllegalStateException("segment has been freed");
 		} else {
 			throw new IllegalStateException("Memory segment does not represent off heap memory");
 		}
@@ -134,14 +135,6 @@ public final class HybridMemorySegment extends MemorySegment {
 		}
 	}
 
-	@Override
-	public void free() {
-		super.free();
-		if (cleaner != null) {
-			cleaner.run();
-		}
-	}
-
 	// ------------------------------------------------------------------------
 	//  Random Access get() and put() methods
 	// ------------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/core/memory/MemorySegmentFactory.java b/flink-core/src/main/java/org/apache/flink/core/memory/MemorySegmentFactory.java
index 2751d9c..ee301a1 100644
--- a/flink-core/src/main/java/org/apache/flink/core/memory/MemorySegmentFactory.java
+++ b/flink-core/src/main/java/org/apache/flink/core/memory/MemorySegmentFactory.java
@@ -19,6 +19,7 @@
 package org.apache.flink.core.memory;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.util.ExceptionUtils;
 
 import org.slf4j.Logger;
@@ -37,6 +38,7 @@ import java.nio.ByteBuffer;
 @Internal
 public final class MemorySegmentFactory {
 	private static final Logger LOG = LoggerFactory.getLogger(MemorySegmentFactory.class);
+	private static final Runnable NO_OP = () -> {};
 
 	/**
 	 * Creates a new memory segment that targets the given heap memory region.
@@ -100,7 +102,12 @@ public final class MemorySegmentFactory {
 	 */
 	public static MemorySegment allocateUnpooledOffHeapMemory(int size, Object owner) {
 		ByteBuffer memory = allocateDirectMemory(size);
-		return new HybridMemorySegment(memory, owner, null);
+		return new HybridMemorySegment(memory, owner);
+	}
+
+	@VisibleForTesting
+	public static MemorySegment allocateOffHeapUnsafeMemory(int size) {
+		return allocateOffHeapUnsafeMemory(size, null, NO_OP);
 	}
 
 	private static ByteBuffer allocateDirectMemory(int size) {
@@ -131,12 +138,14 @@ public final class MemorySegmentFactory {
 	 *
 	 * @param size The size of the off-heap unsafe memory segment to allocate.
 	 * @param owner The owner to associate with the off-heap unsafe memory segment.
+	 * @param customCleanupAction A custom action to run upon calling GC cleaner.
 	 * @return A new memory segment, backed by off-heap unsafe memory.
 	 */
-	public static MemorySegment allocateOffHeapUnsafeMemory(int size, Object owner) {
+	public static MemorySegment allocateOffHeapUnsafeMemory(int size, Object owner, Runnable customCleanupAction) {
 		long address = MemoryUtils.allocateUnsafe(size);
 		ByteBuffer offHeapBuffer = MemoryUtils.wrapUnsafeMemoryWithByteBuffer(address, size);
-		return new HybridMemorySegment(offHeapBuffer, owner, MemoryUtils.createMemoryGcCleaner(offHeapBuffer, address));
+		MemoryUtils.createMemoryGcCleaner(offHeapBuffer, address, customCleanupAction);
+		return new HybridMemorySegment(offHeapBuffer, owner);
 	}
 
 	/**
@@ -150,7 +159,7 @@ public final class MemorySegmentFactory {
 	 * @return A new memory segment representing the given off-heap memory.
 	 */
 	public static MemorySegment wrapOffHeapMemory(ByteBuffer memory) {
-		return new HybridMemorySegment(memory, null, null);
+		return new HybridMemorySegment(memory, null);
 	}
 
 }
diff --git a/flink-core/src/main/java/org/apache/flink/core/memory/MemoryUtils.java b/flink-core/src/main/java/org/apache/flink/core/memory/MemoryUtils.java
index 7f6508c..34cac43 100644
--- a/flink-core/src/main/java/org/apache/flink/core/memory/MemoryUtils.java
+++ b/flink-core/src/main/java/org/apache/flink/core/memory/MemoryUtils.java
@@ -107,8 +107,11 @@ public class MemoryUtils {
 	 * @param address address of the unsafe memory to release
 	 * @return action to run to release the unsafe memory manually
 	 */
-	static Runnable createMemoryGcCleaner(Object owner, long address) {
-		return JavaGcCleanerWrapper.create(owner, () -> releaseUnsafe(address));
+	static Runnable createMemoryGcCleaner(Object owner, long address, Runnable customCleanup) {
+		return JavaGcCleanerWrapper.createCleaner(owner, () -> {
+			releaseUnsafe(address);
+			customCleanup.run();
+		});
 	}
 
 	private static void releaseUnsafe(long address) {
diff --git a/flink-core/src/main/java/org/apache/flink/util/JavaGcCleanerWrapper.java b/flink-core/src/main/java/org/apache/flink/util/JavaGcCleanerWrapper.java
index becd028..ae8edd3 100644
--- a/flink-core/src/main/java/org/apache/flink/util/JavaGcCleanerWrapper.java
+++ b/flink-core/src/main/java/org/apache/flink/util/JavaGcCleanerWrapper.java
@@ -21,10 +21,14 @@ package org.apache.flink.util;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import javax.annotation.Nullable;
+
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Optional;
+import java.util.function.Supplier;
 
 /**
  * Java GC Cleaner wrapper.
@@ -43,109 +47,251 @@ public enum JavaGcCleanerWrapper {
 	private static final Logger LOG = LoggerFactory.getLogger(JavaGcCleanerWrapper.class);
 
 	private static final Collection<CleanerProvider> CLEANER_PROVIDERS =
-		Arrays.asList(LegacyCleanerProvider.INSTANCE, Java9CleanerProvider.INSTANCE);
-	private static final CleanerFactory CLEANER_FACTORY = findGcCleaner();
+		Arrays.asList(createLegacyCleanerProvider(), createJava9CleanerProvider());
+	private static final CleanerManager CLEANER_MANAGER = findGcCleanerManager();
+
+	private static CleanerProvider createLegacyCleanerProvider() {
+		String name = "Legacy (before Java 9) cleaner";
+		ReflectionUtils reflectionUtils = new ReflectionUtils(name + " provider");
+		String cleanerClassName = "sun.misc.Cleaner";
+
+		// Actual Legacy code under the hood:
+		//
+		// public static Runnable createCleaner(Object owner, Runnable cleanOperation) {
+		//     sun.misc.Cleaner jvmCleaner = sun.misc.Cleaner.create(owner, cleanOperation);
+		//     return () -> jvmCleaner.clean();
+		// }
+		//
+		// public static boolean tryRunPendingCleaners() throws InterruptedException {
+		//     sun.misc.JavaLangRefAccess javaLangRefAccess = sun.misc.SharedSecrets.getJavaLangRefAccess();
+		//	   return javaLangRefAccess.tryHandlePendingReference();
+		// }
+		//
+		return new CleanerProvider(
+			name,
+			new CleanerFactoryProvider(
+				name,
+				reflectionUtils,
+				cleanerClassName,
+				Optional::empty, // there is no Cleaner object, static method of its class will be called to create it
+				"create", // static method of Cleaner class to create it
+				cleanerClassName, // Cleaner is Cleanable in this case
+				"clean"),
+			new PendingCleanersRunnerProvider(
+				name,
+				reflectionUtils,
+				"sun.misc.SharedSecrets",
+				"sun.misc.JavaLangRefAccess",
+				"getJavaLangRefAccess",
+				"tryHandlePendingReference"));
+	}
 
-	private static CleanerFactory findGcCleaner() {
-		CleanerFactory foundCleanerFactory = null;
+	private static CleanerProvider createJava9CleanerProvider() {
+		String name = "New Java 9+ cleaner";
+		ReflectionUtils reflectionUtils = new ReflectionUtils(name + " provider");
+		String cleanerClassName = "java.lang.ref.Cleaner";
+
+		// Actual Java 9+ code under the hood:
+		//
+		// public static Runnable createCleaner(Object owner, Runnable cleanOperation) {
+		//     java.lang.ref.Cleaner jvmCleaner = java.lang.ref.Cleaner.create();
+		//     java.lang.ref.Cleaner.Cleanable cleanable = jvmCleaner.register(owner, cleanOperation);
+		//     return () -> cleanable.clean();
+		// }
+		//
+		// public static boolean tryRunPendingCleaners() throws InterruptedException {
+		//     jdk.internal.misc.JavaLangRefAccess javaLangRefAccess = jdk.internal.misc.SharedSecrets.getJavaLangRefAccess();
+		//	   return javaLangRefAccess.waitForReferenceProcessing();
+		// }
+		//
+		return new CleanerProvider(
+			name,
+			new CleanerFactoryProvider(
+				name,
+				reflectionUtils,
+				cleanerClassName,
+				() -> {
+					Class<?> cleanerClass = reflectionUtils.findClass(cleanerClassName);
+					Method cleanerCreateMethod = reflectionUtils.findMethod(cleanerClass, "create");
+					try {
+						return Optional.of(cleanerCreateMethod.invoke(null));
+					} catch (IllegalAccessException | InvocationTargetException e) {
+						throw new FlinkRuntimeException("Failed to create a Java 9 Cleaner", e);
+					}
+				},
+				"register",
+				"java.lang.ref.Cleaner$Cleanable",
+				"clean"),
+			new PendingCleanersRunnerProvider(
+				name,
+				reflectionUtils,
+				"jdk.internal.misc.SharedSecrets",
+				"jdk.internal.misc.JavaLangRefAccess",
+				"getJavaLangRefAccess",
+				"waitForReferenceProcessing"));
+	}
+
+	private static CleanerManager findGcCleanerManager() {
+		CleanerManager foundCleanerManager = null;
 		Throwable t = null;
 		for (CleanerProvider cleanerProvider : CLEANER_PROVIDERS) {
-			//noinspection OverlyBroadCatchBlock
 			try {
-				foundCleanerFactory = cleanerProvider.createCleanerFactory();
+				foundCleanerManager = cleanerProvider.createCleanerManager();
 				break;
 			} catch (Throwable e) {
 				t = ExceptionUtils.firstOrSuppressed(e, t);
 			}
 		}
 
-		if (foundCleanerFactory == null) {
+		if (foundCleanerManager == null) {
 			String errorMessage = String.format("Failed to find GC Cleaner among available providers: %s", CLEANER_PROVIDERS);
 			throw new Error(errorMessage, t);
 		}
-		return foundCleanerFactory;
-	}
-
-	public static Runnable create(Object owner, Runnable cleanOperation) {
-		return CLEANER_FACTORY.create(owner, cleanOperation);
+		return foundCleanerManager;
 	}
 
-	@FunctionalInterface
-	private interface CleanerProvider {
-		CleanerFactory createCleanerFactory() throws ClassNotFoundException;
+	public static Runnable createCleaner(Object owner, Runnable cleanOperation) {
+		return CLEANER_MANAGER.create(owner, cleanOperation);
 	}
 
-	@FunctionalInterface
-	private interface CleanerFactory {
-		Runnable create(Object owner, Runnable cleanOperation);
+	public static boolean tryRunPendingCleaners() throws InterruptedException {
+		return CLEANER_MANAGER.tryRunPendingCleaners();
 	}
 
-	private enum LegacyCleanerProvider implements CleanerProvider {
-		INSTANCE;
+	private static class CleanerProvider {
+		private final String cleanerName;
+		private final CleanerFactoryProvider cleanerFactoryProvider;
+		private final PendingCleanersRunnerProvider pendingCleanersRunnerProvider;
+
+		private CleanerProvider(
+				String cleanerName,
+				CleanerFactoryProvider cleanerFactoryProvider,
+				PendingCleanersRunnerProvider pendingCleanersRunnerProvider) {
+			this.cleanerName = cleanerName;
+			this.cleanerFactoryProvider = cleanerFactoryProvider;
+			this.pendingCleanersRunnerProvider = pendingCleanersRunnerProvider;
+		}
 
-		private static final String LEGACY_CLEANER_CLASS_NAME = "sun.misc.Cleaner";
+		private CleanerManager createCleanerManager() {
+			return new CleanerManager(
+				cleanerName,
+				cleanerFactoryProvider.createCleanerFactory(),
+				pendingCleanersRunnerProvider.createPendingCleanersRunner());
+		}
 
 		@Override
-		public CleanerFactory createCleanerFactory() {
-			Class<?> cleanerClass = findCleanerClass();
-			Method cleanerCreateMethod = getCleanerCreateMethod(cleanerClass);
-			Method cleanerCleanMethod = getCleanerCleanMethod(cleanerClass);
-			return new LegacyCleanerFactory(cleanerCreateMethod, cleanerCleanMethod);
+		public String toString() {
+			return cleanerName + " provider";
 		}
+	}
 
-		private static Class<?> findCleanerClass() {
-			try {
-				return Class.forName(LEGACY_CLEANER_CLASS_NAME);
-			} catch (ClassNotFoundException e) {
-				throw new FlinkRuntimeException("Failed to find Java legacy Cleaner class", e);
-			}
+	private static class CleanerManager {
+		private final String cleanerName;
+		private final CleanerFactory cleanerFactory;
+		private final PendingCleanersRunner pendingCleanersRunner;
+
+		private CleanerManager(
+				String cleanerName,
+				CleanerFactory cleanerFactory,
+				PendingCleanersRunner pendingCleanersRunner) {
+			this.cleanerName = cleanerName;
+			this.cleanerFactory = cleanerFactory;
+			this.pendingCleanersRunner = pendingCleanersRunner;
 		}
 
-		private static Method getCleanerCreateMethod(Class<?> cleanerClass) {
-			try {
-				return cleanerClass.getMethod("create", Object.class, Runnable.class);
-			} catch (NoSuchMethodException e) {
-				throw new FlinkRuntimeException("Failed to find Java legacy Cleaner#create method", e);
-			}
+		private Runnable create(Object owner, Runnable cleanOperation) {
+			return cleanerFactory.create(owner, cleanOperation);
 		}
 
-		private static Method getCleanerCleanMethod(Class<?> cleanerClass) {
-			try {
-				return cleanerClass.getMethod("clean");
-			} catch (NoSuchMethodException e) {
-				throw new FlinkRuntimeException("Failed to find Java legacy Cleaner#clean method", e);
-			}
+		private boolean tryRunPendingCleaners() throws InterruptedException {
+			return pendingCleanersRunner.tryRunPendingCleaners();
 		}
 
 		@Override
 		public String toString() {
-			return "Legacy cleaner provider before Java 9 using " + LEGACY_CLEANER_CLASS_NAME;
+			return cleanerName + " manager";
 		}
 	}
 
-	private static final class LegacyCleanerFactory implements CleanerFactory {
-		private final Method cleanerCreateMethod;
-		private final Method cleanerCleanMethod;
+	private static class CleanerFactoryProvider {
+		private final String cleanerName;
+		private final ReflectionUtils reflectionUtils;
+		private final String cleanerClassName;
+		private final Supplier<Optional<Object>> cleanerSupplier;
+		private final String cleanableCreationMethodName;
+		private final String cleanableClassName;
+		private final String cleanMethodName;
+
+		private CleanerFactoryProvider(
+				String cleanerName,
+				ReflectionUtils reflectionUtils,
+				String cleanerClassName,
+				Supplier<Optional<Object>> cleanerSupplier,
+				String cleanableCreationMethodName, // Cleaner is a factory for Cleanable
+				String cleanableClassName,
+				String cleanMethodName) {
+			this.cleanerName = cleanerName;
+			this.reflectionUtils = reflectionUtils;
+			this.cleanerClassName = cleanerClassName;
+			this.cleanerSupplier = cleanerSupplier;
+			this.cleanableCreationMethodName = cleanableCreationMethodName;
+			this.cleanableClassName = cleanableClassName;
+			this.cleanMethodName = cleanMethodName;
+		}
 
-		private LegacyCleanerFactory(Method cleanerCreateMethod, Method cleanerCleanMethod) {
-			this.cleanerCreateMethod = cleanerCreateMethod;
-			this.cleanerCleanMethod = cleanerCleanMethod;
+		private CleanerFactory createCleanerFactory() {
+			Class<?> cleanerClass = reflectionUtils.findClass(cleanerClassName);
+			Method cleanableCreationMethod = reflectionUtils.findMethod(
+				cleanerClass,
+				cleanableCreationMethodName,
+				Object.class,
+				Runnable.class);
+			Class<?> cleanableClass = reflectionUtils.findClass(cleanableClassName);
+			Method cleanMethod = reflectionUtils.findMethod(cleanableClass, cleanMethodName);
+			return new CleanerFactory(
+				cleanerName,
+				cleanerSupplier.get().orElse(null), // static method of Cleaner class will be called to create Cleanable
+				cleanableCreationMethod,
+				cleanMethod);
 		}
 
 		@Override
-		public Runnable create(Object owner, Runnable cleanupOperation) {
-			Object cleaner;
+		public String toString() {
+			return cleanerName + " factory provider using " + cleanerClassName;
+		}
+	}
+
+	private static class CleanerFactory {
+		private final String cleanerName;
+		@Nullable
+		private final Object cleaner;
+		private final Method cleanableCreationMethod;
+		private final Method cleanMethod;
+
+		private CleanerFactory(
+			String cleanerName,
+			@Nullable Object cleaner,
+			Method cleanableCreationMethod,
+			Method cleanMethod) {
+			this.cleanerName = cleanerName;
+			this.cleaner = cleaner;
+			this.cleanableCreationMethod = cleanableCreationMethod;
+			this.cleanMethod = cleanMethod;
+		}
+
+		private Runnable create(Object owner, Runnable cleanupOperation) {
+			Object cleanable;
 			try {
-				cleaner = cleanerCreateMethod.invoke(null, owner, cleanupOperation);
+				cleanable = cleanableCreationMethod.invoke(cleaner, owner, cleanupOperation);
 			} catch (IllegalAccessException | InvocationTargetException e) {
-				throw new Error("Failed to create a Java legacy Cleaner", e);
+				throw new Error("Failed to create a " + cleanerName, e);
 			}
 			String ownerString = owner.toString(); // lambda should not capture the owner object
 			return () -> {
 				try {
-					cleanerCleanMethod.invoke(cleaner);
+					cleanMethod.invoke(cleanable);
 				} catch (IllegalAccessException | InvocationTargetException e) {
-					String message = String.format("FATAL UNEXPECTED - Failed to invoke a Java legacy Cleaner for %s", ownerString);
+					String message = String.format("FATAL UNEXPECTED - Failed to invoke a %s for %s", cleanerName, ownerString);
 					LOG.error(message, e);
 					throw new Error(message, e);
 				}
@@ -153,106 +299,115 @@ public enum JavaGcCleanerWrapper {
 		}
 	}
 
-	/** New cleaner provider for Java 9+. */
-	private enum Java9CleanerProvider implements CleanerProvider {
-		INSTANCE;
+	private static class PendingCleanersRunnerProvider {
+		private final String cleanerName;
+		private final ReflectionUtils reflectionUtils;
+		private final String sharedSecretsClassName;
+		private final String javaLangRefAccessClassName;
+		private final String getJavaLangRefAccessName;
+		private final String tryHandlePendingReferenceName;
+
+		private PendingCleanersRunnerProvider(
+				String cleanerName,
+				ReflectionUtils reflectionUtils,
+				String sharedSecretsClassName,
+				String javaLangRefAccessClassName,
+				String getJavaLangRefAccessName,
+				String tryHandlePendingReferenceName) {
+			this.cleanerName = cleanerName;
+			this.reflectionUtils = reflectionUtils;
+			this.sharedSecretsClassName = sharedSecretsClassName;
+			this.javaLangRefAccessClassName = javaLangRefAccessClassName;
+			this.getJavaLangRefAccessName = getJavaLangRefAccessName;
+			this.tryHandlePendingReferenceName = tryHandlePendingReferenceName;
+		}
 
-		private static final String JAVA9_CLEANER_CLASS_NAME = "java.lang.ref.Cleaner";
+		private PendingCleanersRunner createPendingCleanersRunner() {
+			Class<?> sharedSecretsClass = reflectionUtils.findClass(sharedSecretsClassName);
+			Class<?> javaLangRefAccessClass = reflectionUtils.findClass(javaLangRefAccessClassName);
+			Method getJavaLangRefAccessMethod = reflectionUtils.findMethod(sharedSecretsClass, getJavaLangRefAccessName);
+			Method tryHandlePendingReferenceMethod = reflectionUtils.findMethod(
+				javaLangRefAccessClass,
+				tryHandlePendingReferenceName);
+			return new PendingCleanersRunner(getJavaLangRefAccessMethod, tryHandlePendingReferenceMethod);
+		}
 
 		@Override
-		public CleanerFactory createCleanerFactory() {
-			Class<?> cleanerClass = findCleanerClass();
-			Method cleanerCreateMethod = getCleanerCreateMethod(cleanerClass);
-			Object cleaner = createCleaner(cleanerCreateMethod);
-			Method cleanerRegisterMethod = getCleanerRegisterMethod(cleanerClass);
-			Class<?> cleanableClass = findCleanableClass();
-			Method cleanMethod = getCleanMethod(cleanableClass);
-			return new Java9CleanerFactory(cleaner, cleanerRegisterMethod, cleanMethod);
+		public String toString() {
+			return "Pending " + cleanerName + "s runner provider";
 		}
+	}
 
-		private static Class<?> findCleanerClass() {
-			try {
-				return Class.forName(JAVA9_CLEANER_CLASS_NAME);
-			} catch (ClassNotFoundException e) {
-				throw new FlinkRuntimeException("Failed to find Java 9 Cleaner class", e);
-			}
-		}
+	private static class PendingCleanersRunner {
+		private final Method getJavaLangRefAccessMethod;
+		private final Method waitForReferenceProcessingMethod;
 
-		private static Method getCleanerCreateMethod(Class<?> cleanerClass) {
-			try {
-				return cleanerClass.getMethod("create");
-			} catch (NoSuchMethodException e) {
-				throw new FlinkRuntimeException("Failed to find Java 9 Cleaner#create method", e);
-			}
+		private PendingCleanersRunner(Method getJavaLangRefAccessMethod, Method waitForReferenceProcessingMethod) {
+			this.getJavaLangRefAccessMethod = getJavaLangRefAccessMethod;
+			this.waitForReferenceProcessingMethod = waitForReferenceProcessingMethod;
 		}
 
-		private static Object createCleaner(Method cleanerCreateMethod) {
+		private boolean tryRunPendingCleaners() throws InterruptedException {
+			Object javaLangRefAccess = getJavaLangRefAccess();
 			try {
-				return cleanerCreateMethod.invoke(null);
+				return (Boolean) waitForReferenceProcessingMethod.invoke(javaLangRefAccess);
 			} catch (IllegalAccessException | InvocationTargetException e) {
-				throw new FlinkRuntimeException("Failed to create a Java 9 Cleaner", e);
-			}
-		}
-
-		private static Method getCleanerRegisterMethod(Class<?> cleanerClass) {
-			try {
-				return cleanerClass.getMethod("register", Object.class, Runnable.class);
-			} catch (NoSuchMethodException e) {
-				throw new FlinkRuntimeException("Failed to find Java 9 Cleaner#create method", e);
+				throwIfCauseIsInterruptedException(e);
+				return throwInvocationError(e, javaLangRefAccess, waitForReferenceProcessingMethod);
 			}
 		}
 
-		private static Class<?> findCleanableClass() {
+		private Object getJavaLangRefAccess() {
 			try {
-				return Class.forName("java.lang.ref.Cleaner$Cleanable");
-			} catch (ClassNotFoundException e) {
-				throw new FlinkRuntimeException("Failed to find Java 9 Cleaner#Cleanable class", e);
+				return getJavaLangRefAccessMethod.invoke(null);
+			} catch (IllegalAccessException | InvocationTargetException e) {
+				return throwInvocationError(e, null, waitForReferenceProcessingMethod);
 			}
 		}
 
-		private static Method getCleanMethod(Class<?> cleanableClass) {
-			try {
-				return cleanableClass.getMethod("clean");
-			} catch (NoSuchMethodException e) {
-				throw new FlinkRuntimeException("Failed to find Java 9 Cleaner$Cleanable#clean method", e);
+		private static void throwIfCauseIsInterruptedException(Throwable t) throws InterruptedException {
+			// if the original wrapped method can throw InterruptedException
+			// then we may want to explicitly handle in the user code for certain implementations
+			if (t.getCause() instanceof InterruptedException) {
+				throw (InterruptedException) t.getCause();
 			}
 		}
 
-		@Override
-		public String toString() {
-			return "New cleaner provider for Java 9+" + JAVA9_CLEANER_CLASS_NAME;
+		private static <T> T throwInvocationError(Throwable t, @Nullable Object obj, Method method) {
+			String message = String.format(
+				"FATAL UNEXPECTED - Failed to invoke %s%s",
+				obj == null ? "" : obj.getClass().getSimpleName() + '#',
+				method.getName());
+			LOG.error(message, t);
+			throw new Error(message, t);
 		}
 	}
 
-	private static final class Java9CleanerFactory implements CleanerFactory {
-		private final Object cleaner;
-		private final Method cleanerRegisterMethod;
-		private final Method cleanMethod;
+	private static class ReflectionUtils {
+		private final String logPrefix;
 
-		private Java9CleanerFactory(Object cleaner, Method cleanerRegisterMethod, Method cleanMethod) {
-			this.cleaner = cleaner;
-			this.cleanerRegisterMethod = cleanerRegisterMethod;
-			this.cleanMethod = cleanMethod;
+		private ReflectionUtils(String logPrefix) {
+			this.logPrefix = logPrefix;
 		}
 
-		@Override
-		public Runnable create(Object owner, Runnable cleanupOperation) {
-			Object cleanable;
+		private Class<?> findClass(String className) {
 			try {
-				cleanable = cleanerRegisterMethod.invoke(cleaner, owner, cleanupOperation);
-			} catch (IllegalAccessException | InvocationTargetException e) {
-				throw new Error("Failed to create a Java 9 Cleaner", e);
+				return Class.forName(className);
+			} catch (ClassNotFoundException e) {
+				throw new FlinkRuntimeException(
+					String.format("%s: Failed to find %s class", logPrefix, className.split("\\.")[0]),
+					e);
+			}
+		}
+
+		private Method findMethod(Class<?> cl, String methodName, Class<?>... parameterTypes) {
+			try {
+				return cl.getMethod(methodName, parameterTypes);
+			} catch (NoSuchMethodException e) {
+				throw new FlinkRuntimeException(
+					String.format("%s: Failed to find %s#%s method", logPrefix, cl.getSimpleName(), methodName),
+					e);
 			}
-			String ownerString = owner.toString(); // lambda should not capture the owner object
-			return () -> {
-				try {
-					cleanMethod.invoke(cleanable);
-				} catch (IllegalAccessException | InvocationTargetException e) {
-					String message = String.format("FATAL UNEXPECTED - Failed to invoke a Java 9 Cleaner$Cleanable for %s", ownerString);
-					LOG.error(message, e);
-					throw new Error(message, e);
-				}
-			};
 		}
 	}
 }
diff --git a/flink-core/src/test/java/org/apache/flink/core/memory/CrossSegmentTypeTest.java b/flink-core/src/test/java/org/apache/flink/core/memory/CrossSegmentTypeTest.java
index 51804d0..056c5ee 100644
--- a/flink-core/src/test/java/org/apache/flink/core/memory/CrossSegmentTypeTest.java
+++ b/flink-core/src/test/java/org/apache/flink/core/memory/CrossSegmentTypeTest.java
@@ -159,7 +159,7 @@ public class CrossSegmentTypeTest {
 			new HeapMemorySegment(new byte[size]),
 			MemorySegmentFactory.allocateUnpooledSegment(size),
 			MemorySegmentFactory.allocateUnpooledOffHeapMemory(size),
-			MemorySegmentFactory.allocateOffHeapUnsafeMemory(size, null)
+			MemorySegmentFactory.allocateOffHeapUnsafeMemory(size)
 		};
 		return segments;
 	}
diff --git a/flink-core/src/test/java/org/apache/flink/core/memory/EndiannessAccessChecks.java b/flink-core/src/test/java/org/apache/flink/core/memory/EndiannessAccessChecks.java
index c2db44e..6e81b8e 100644
--- a/flink-core/src/test/java/org/apache/flink/core/memory/EndiannessAccessChecks.java
+++ b/flink-core/src/test/java/org/apache/flink/core/memory/EndiannessAccessChecks.java
@@ -47,7 +47,7 @@ public class EndiannessAccessChecks {
 
 	@Test
 	public void testHybridOffHeapUnsafeSegment() {
-		testBigAndLittleEndianAccessUnaligned(MemorySegmentFactory.allocateOffHeapUnsafeMemory(11111, null));
+		testBigAndLittleEndianAccessUnaligned(MemorySegmentFactory.allocateOffHeapUnsafeMemory(11111));
 	}
 
 	private void testBigAndLittleEndianAccessUnaligned(MemorySegment segment) {
diff --git a/flink-core/src/test/java/org/apache/flink/core/memory/HybridOffHeapUnsafeMemorySegmentTest.java b/flink-core/src/test/java/org/apache/flink/core/memory/HybridOffHeapUnsafeMemorySegmentTest.java
index f167203..e5c70cf 100644
--- a/flink-core/src/test/java/org/apache/flink/core/memory/HybridOffHeapUnsafeMemorySegmentTest.java
+++ b/flink-core/src/test/java/org/apache/flink/core/memory/HybridOffHeapUnsafeMemorySegmentTest.java
@@ -33,11 +33,11 @@ public class HybridOffHeapUnsafeMemorySegmentTest extends HybridOffHeapMemorySeg
 
 	@Override
 	MemorySegment createSegment(int size) {
-		return MemorySegmentFactory.allocateOffHeapUnsafeMemory(size, null);
+		return MemorySegmentFactory.allocateOffHeapUnsafeMemory(size);
 	}
 
 	@Override
 	MemorySegment createSegment(int size, Object owner) {
-		return MemorySegmentFactory.allocateOffHeapUnsafeMemory(size, owner);
+		return MemorySegmentFactory.allocateOffHeapUnsafeMemory(size, owner, () -> {});
 	}
 }
diff --git a/flink-core/src/test/java/org/apache/flink/core/memory/MemorySegmentChecksTest.java b/flink-core/src/test/java/org/apache/flink/core/memory/MemorySegmentChecksTest.java
index 3e3e267..09619cd 100644
--- a/flink-core/src/test/java/org/apache/flink/core/memory/MemorySegmentChecksTest.java
+++ b/flink-core/src/test/java/org/apache/flink/core/memory/MemorySegmentChecksTest.java
@@ -46,12 +46,12 @@ public class MemorySegmentChecksTest {
 
 	@Test(expected = NullPointerException.class)
 	public void testHybridOffHeapNullBuffer2() {
-		new HybridMemorySegment(null, new Object(), () -> {});
+		new HybridMemorySegment((ByteBuffer) null, new Object());
 	}
 
 	@Test(expected = IllegalArgumentException.class)
 	public void testHybridNonDirectBuffer() {
-		new HybridMemorySegment(ByteBuffer.allocate(1024), new Object(), () -> {});
+		new HybridMemorySegment(ByteBuffer.allocate(1024), new Object());
 	}
 
 	@Test(expected = IllegalArgumentException.class)
diff --git a/flink-core/src/test/java/org/apache/flink/core/memory/MemorySegmentUndersizedTest.java b/flink-core/src/test/java/org/apache/flink/core/memory/MemorySegmentUndersizedTest.java
index 1363703..1566ac0 100644
--- a/flink-core/src/test/java/org/apache/flink/core/memory/MemorySegmentUndersizedTest.java
+++ b/flink-core/src/test/java/org/apache/flink/core/memory/MemorySegmentUndersizedTest.java
@@ -63,7 +63,7 @@ public class MemorySegmentUndersizedTest {
 
 	@Test
 	public void testZeroSizeOffHeapUnsafeHybridSegment() {
-		MemorySegment segment = MemorySegmentFactory.allocateOffHeapUnsafeMemory(0, null);
+		MemorySegment segment = MemorySegmentFactory.allocateOffHeapUnsafeMemory(0);
 
 		testZeroSizeBuffer(segment);
 		testSegmentWithSizeLargerZero(segment);
@@ -86,7 +86,7 @@ public class MemorySegmentUndersizedTest {
 
 	@Test
 	public void testSizeOneOffHeapUnsafeHybridSegment() {
-		testSegmentWithSizeLargerZero(MemorySegmentFactory.allocateOffHeapUnsafeMemory(1, null));
+		testSegmentWithSizeLargerZero(MemorySegmentFactory.allocateOffHeapUnsafeMemory(1));
 	}
 
 	private static void testZeroSizeBuffer(MemorySegment segment) {
diff --git a/flink-core/src/test/java/org/apache/flink/core/memory/OperationsOnFreedSegmentTest.java b/flink-core/src/test/java/org/apache/flink/core/memory/OperationsOnFreedSegmentTest.java
index bf27fc1..0013399 100644
--- a/flink-core/src/test/java/org/apache/flink/core/memory/OperationsOnFreedSegmentTest.java
+++ b/flink-core/src/test/java/org/apache/flink/core/memory/OperationsOnFreedSegmentTest.java
@@ -128,7 +128,7 @@ public class OperationsOnFreedSegmentTest {
 		MemorySegment heap = new HeapMemorySegment(new byte[PAGE_SIZE]);
 		MemorySegment hybridHeap = MemorySegmentFactory.wrap(new byte[PAGE_SIZE]);
 		MemorySegment hybridOffHeap = MemorySegmentFactory.allocateUnpooledOffHeapMemory(PAGE_SIZE);
-		MemorySegment hybridOffHeapUnsafe = MemorySegmentFactory.allocateOffHeapUnsafeMemory(PAGE_SIZE, null);
+		MemorySegment hybridOffHeapUnsafe = MemorySegmentFactory.allocateOffHeapUnsafeMemory(PAGE_SIZE);
 
 		MemorySegment[] segments = { heap, hybridHeap, hybridOffHeap, hybridOffHeapUnsafe };
 
diff --git a/flink-core/src/test/java/org/apache/flink/util/JavaGcCleanerWrapperTest.java b/flink-core/src/test/java/org/apache/flink/util/JavaGcCleanerWrapperTest.java
index ead8fce..1785003 100644
--- a/flink-core/src/test/java/org/apache/flink/util/JavaGcCleanerWrapperTest.java
+++ b/flink-core/src/test/java/org/apache/flink/util/JavaGcCleanerWrapperTest.java
@@ -32,7 +32,7 @@ public class JavaGcCleanerWrapperTest {
 	@Test
 	public void testCleanOperationRunsOnlyOnceEitherOnGcOrExplicitly() throws InterruptedException {
 		AtomicInteger callCounter = new AtomicInteger();
-		Runnable cleaner = JavaGcCleanerWrapper.create(new Object(), callCounter::incrementAndGet);
+		Runnable cleaner = JavaGcCleanerWrapper.createCleaner(new Object(), callCounter::incrementAndGet);
 		System.gc(); // not guaranteed to be run always but should in practice
 		Thread.sleep(10); // more chance for GC to run
 		cleaner.run();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/memory/MemoryManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/memory/MemoryManager.java
index e9c7b30..948053a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/memory/MemoryManager.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/memory/MemoryManager.java
@@ -29,7 +29,6 @@ import org.apache.flink.util.function.ThrowingRunnable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import javax.annotation.Nonnegative;
 import javax.annotation.Nullable;
 
 import java.util.ArrayList;
@@ -42,7 +41,6 @@ import java.util.Map;
 import java.util.NoSuchElementException;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 
@@ -77,11 +75,9 @@ public class MemoryManager {
 
 	private final long pageSize;
 
-	private final long totalMemorySize;
-
 	private final long totalNumberOfPages;
 
-	private final AtomicLong availableMemorySize;
+	private final UnsafeMemoryBudget memoryBudget;
 
 	private final SharedResources sharedResources;
 
@@ -98,13 +94,12 @@ public class MemoryManager {
 		sanityCheck(memorySize, pageSize);
 
 		this.pageSize = pageSize;
-		this.totalMemorySize = memorySize;
+		this.memoryBudget = new UnsafeMemoryBudget(memorySize);
 		this.totalNumberOfPages = memorySize / pageSize;
 		this.allocatedSegments = new ConcurrentHashMap<>();
 		this.reservedMemory = new ConcurrentHashMap<>();
-		this.availableMemorySize = new AtomicLong(totalMemorySize);
 		this.sharedResources = new SharedResources();
-		verifyIntTotalNumberOfPages(totalMemorySize, totalNumberOfPages);
+		verifyIntTotalNumberOfPages(memorySize, totalNumberOfPages);
 
 		LOG.debug(
 			"Initialized MemoryManager with total memory size {} and page size {}.",
@@ -146,7 +141,6 @@ public class MemoryManager {
 			// mark as shutdown and release memory
 			isShutDown = true;
 			reservedMemory.clear();
-			availableMemorySize.set(totalMemorySize);
 
 			// go over all allocated segments and release them
 			for (Set<MemorySegment> segments : allocatedSegments.values()) {
@@ -175,7 +169,7 @@ public class MemoryManager {
 	 * @return True, if the memory manager is empty and valid, false if it is not empty or corrupted.
 	 */
 	public boolean verifyEmpty() {
-		return availableMemorySize.get() == totalMemorySize;
+		return memoryBudget.verifyEmpty();
 	}
 
 	// ------------------------------------------------------------------------
@@ -230,16 +224,17 @@ public class MemoryManager {
 
 		long memoryToReserve = numberOfPages * pageSize;
 		try {
-			reserveMemory(memoryToReserve);
+			memoryBudget.reserveMemory(memoryToReserve);
 		} catch (MemoryReservationException e) {
 			throw new MemoryAllocationException(String.format("Could not allocate %d pages", numberOfPages), e);
 		}
 
+		Runnable pageCleanup = this::releasePage;
 		allocatedSegments.compute(owner, (o, currentSegmentsForOwner) -> {
 			Set<MemorySegment> segmentsForOwner = currentSegmentsForOwner == null ?
 				new HashSet<>(numberOfPages) : currentSegmentsForOwner;
 			for (long i = numberOfPages; i > 0; i--) {
-				MemorySegment segment = allocateOffHeapUnsafeMemory(getPageSize(), owner);
+				MemorySegment segment = allocateOffHeapUnsafeMemory(getPageSize(), owner, pageCleanup);
 				target.add(segment);
 				segmentsForOwner.add(segment);
 			}
@@ -249,6 +244,10 @@ public class MemoryManager {
 		Preconditions.checkState(!isShutDown, "Memory manager has been concurrently shut down.");
 	}
 
+	private void releasePage() {
+		memoryBudget.releaseMemory(getPageSize());
+	}
+
 	/**
 	 * Tries to release the memory for the specified segment.
 	 *
@@ -270,9 +269,7 @@ public class MemoryManager {
 		try {
 			allocatedSegments.computeIfPresent(segment.getOwner(), (o, segsForOwner) -> {
 				segment.free();
-				if (segsForOwner.remove(segment)) {
-					releaseMemory(getPageSize());
-				}
+				segsForOwner.remove(segment);
 				return segsForOwner.isEmpty() ? null : segsForOwner;
 			});
 		}
@@ -296,8 +293,6 @@ public class MemoryManager {
 
 		Preconditions.checkState(!isShutDown, "Memory manager has been shut down.");
 
-		AtomicLong releasedMemory = new AtomicLong(0L);
-
 		// since concurrent modifications to the collection
 		// can disturb the release, we need to try potentially multiple times
 		boolean successfullyReleased = false;
@@ -316,7 +311,7 @@ public class MemoryManager {
 					segment = segmentsIterator.next();
 				}
 				while (segment != null) {
-					segment = releaseSegmentsForOwnerUntilNextOwner(segment, segmentsIterator, releasedMemory);
+					segment = releaseSegmentsForOwnerUntilNextOwner(segment, segmentsIterator);
 				}
 				segments.clear();
 				// the only way to exit the loop
@@ -326,18 +321,15 @@ public class MemoryManager {
 				// call releases the memory. fall through the loop and try again
 			}
 		} while (!successfullyReleased);
-
-		releaseMemory(releasedMemory.get());
 	}
 
 	private MemorySegment releaseSegmentsForOwnerUntilNextOwner(
 			MemorySegment firstSeg,
-			Iterator<MemorySegment> segmentsIterator,
-			AtomicLong releasedMemory) {
+			Iterator<MemorySegment> segmentsIterator) {
 		AtomicReference<MemorySegment> nextOwnerMemorySegment = new AtomicReference<>();
 		Object owner = firstSeg.getOwner();
 		allocatedSegments.compute(owner, (o, segsForOwner) -> {
-			releasedMemory.addAndGet(freeSegment(firstSeg, segsForOwner));
+			freeSegment(firstSeg, segsForOwner);
 			while (segmentsIterator.hasNext()) {
 				MemorySegment segment = segmentsIterator.next();
 				try {
@@ -349,7 +341,7 @@ public class MemoryManager {
 						nextOwnerMemorySegment.set(segment);
 						break;
 					}
-					releasedMemory.addAndGet(freeSegment(segment, segsForOwner));
+					freeSegment(segment, segsForOwner);
 				} catch (Throwable t) {
 					throw new RuntimeException(
 						"Error removing book-keeping reference to allocated memory segment.", t);
@@ -360,9 +352,11 @@ public class MemoryManager {
 		return nextOwnerMemorySegment.get();
 	}
 
-	private long freeSegment(MemorySegment segment, @Nullable Collection<MemorySegment> segments) {
+	private static void freeSegment(MemorySegment segment, @Nullable Collection<MemorySegment> segments) {
 		segment.free();
-		return segments != null && segments.remove(segment) ? getPageSize() : 0L;
+		if (segments != null) {
+			segments.remove(segment);
+		}
 	}
 
 	/**
@@ -386,12 +380,9 @@ public class MemoryManager {
 		}
 
 		// free each segment
-		long releasedMemory = 0L;
-		for (MemorySegment segment : segments) {
+		for (MemorySegment segment: segments) {
 			segment.free();
-			releasedMemory += getPageSize();
 		}
-		releaseMemory(releasedMemory);
 
 		segments.clear();
 	}
@@ -410,7 +401,7 @@ public class MemoryManager {
 			return;
 		}
 
-		reserveMemory(size);
+		memoryBudget.reserveMemory(size);
 
 		reservedMemory.compute(owner, (o, memoryReservedForOwner) ->
 			memoryReservedForOwner == null ? size : memoryReservedForOwner + size);
@@ -450,7 +441,7 @@ public class MemoryManager {
 
 	private long releaseAndCalculateReservedMemory(long memoryToFree, long currentlyReserved) {
 		final long effectiveMemoryToRelease = Math.min(currentlyReserved, memoryToFree);
-		releaseMemory(effectiveMemoryToRelease);
+		memoryBudget.releaseMemory(effectiveMemoryToRelease);
 
 		return currentlyReserved - effectiveMemoryToRelease;
 	}
@@ -470,7 +461,7 @@ public class MemoryManager {
 		checkMemoryReservationPreconditions(owner, 0L);
 		Long memoryReservedForOwner = reservedMemory.remove(owner);
 		if (memoryReservedForOwner != null) {
-			releaseMemory(memoryReservedForOwner);
+			memoryBudget.releaseMemory(memoryReservedForOwner);
 		}
 	}
 
@@ -595,7 +586,7 @@ public class MemoryManager {
 	 * @return The total size of memory.
 	 */
 	public long getMemorySize() {
-		return totalMemorySize;
+		return memoryBudget.getTotalMemorySize();
 	}
 
 	/**
@@ -604,7 +595,7 @@ public class MemoryManager {
 	 * @return The available amount of memory.
 	 */
 	public long availableMemory() {
-		return availableMemorySize.get();
+		return memoryBudget.getAvailableMemorySize();
 	}
 
 	/**
@@ -636,44 +627,7 @@ public class MemoryManager {
 			"The fraction of memory to allocate must within (0, 1], was: %s", fraction);
 
 		//noinspection NumericCastThatLosesPrecision
-		return (long) Math.floor(totalMemorySize * fraction);
-	}
-
-	private void reserveMemory(long size) throws MemoryReservationException {
-		long availableOrReserved = tryReserveMemory(size);
-		if (availableOrReserved < size) {
-			throw new MemoryReservationException(
-				String.format("Could not allocate %d bytes, only %d bytes are remaining", size, availableOrReserved));
-		}
-	}
-
-	private long tryReserveMemory(long size) {
-		long currentAvailableMemorySize;
-		while (size <= (currentAvailableMemorySize = availableMemorySize.get())) {
-			if (availableMemorySize.compareAndSet(currentAvailableMemorySize, currentAvailableMemorySize - size)) {
-				return size;
-			}
-		}
-		return currentAvailableMemorySize;
-	}
-
-	private void releaseMemory(@Nonnegative long size) {
-		if (size == 0) {
-			return;
-		}
-		boolean released = false;
-		long currentAvailableMemorySize = 0L;
-		while (!released && totalMemorySize >= (currentAvailableMemorySize = availableMemorySize.get()) + size) {
-			released = availableMemorySize
-				.compareAndSet(currentAvailableMemorySize, currentAvailableMemorySize + size);
-		}
-		if (!released) {
-			throw new IllegalStateException(String.format(
-				"Trying to release more managed memory (%d bytes) than has been allocated (%d bytes), the total size is %d bytes",
-				size,
-				currentAvailableMemorySize,
-				totalMemorySize));
-		}
+		return (long) Math.floor(memoryBudget.getTotalMemorySize() * fraction);
 	}
 
 	// ------------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/memory/UnsafeMemoryBudget.java b/flink-runtime/src/main/java/org/apache/flink/runtime/memory/UnsafeMemoryBudget.java
new file mode 100644
index 0000000..a85f40e
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/memory/UnsafeMemoryBudget.java
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.memory;
+
+import org.apache.flink.util.JavaGcCleanerWrapper;
+
+import javax.annotation.Nonnegative;
+
+import java.util.concurrent.atomic.AtomicLong;
+
+/**
+ * Tracker of memory allocation and release within a custom limit.
+ *
+ * <p>This memory management in Flink uses the same approach as Java uses to allocate direct memory
+ * and release it upon GC but memory here can be also released directly with {@link #releaseMemory(long)}.
+ * The memory reservation {@link #reserveMemory(long)} tries firstly to run all phantom cleaners, created with
+ * {@link org.apache.flink.core.memory.MemoryUtils#createMemoryGcCleaner(Object, long, Runnable)},
+ * for objects which are ready to be garbage collected. If memory is still not available, it triggers GC and
+ * continues to process any ready cleaners making {@link #MAX_SLEEPS} attempts before throwing
+ * {@link MemoryReservationException}.
+ *
+ * <p>The available memory is tracked in an {@link AtomicLong} which is only updated with
+ * compare-and-set loops, so reservations and releases can be issued concurrently from different threads.
+ */
+class UnsafeMemoryBudget {
+	// Max. number of sleeps during try-reserving with exponentially
+	// increasing delay before throwing MemoryReservationException:
+	// 1, 2, 4, 8, 16, 32, 64, 128, 256, 512 (total 1023 ms ~ 1 s)
+	// which means that MemoryReservationException will be thrown after ~1 s of trying.
+	private static final int MAX_SLEEPS = 10;
+	// Once this many sleeps have passed (1 + 2 + ... + 256 = 511 ms ~ 0.5 s),
+	// GC is triggered again before each further attempt.
+	private static final int RETRIGGER_GC_AFTER_SLEEPS = 9;
+
+	/** The fixed limit of memory that can be reserved, in bytes. */
+	private final long totalMemorySize;
+
+	/** The part of {@link #totalMemorySize} which is currently not reserved, in bytes. */
+	private final AtomicLong availableMemorySize;
+
+	UnsafeMemoryBudget(long totalMemorySize) {
+		this.totalMemorySize = totalMemorySize;
+		this.availableMemorySize = new AtomicLong(totalMemorySize);
+	}
+
+	/** Returns the fixed limit of memory of this budget, in bytes. */
+	long getTotalMemorySize() {
+		return totalMemorySize;
+	}
+
+	/** Returns the currently unreserved memory in bytes (a snapshot under concurrent use). */
+	long getAvailableMemorySize() {
+		return availableMemorySize.get();
+	}
+
+	/**
+	 * Checks whether all memory of this budget has been released.
+	 *
+	 * <p>The check tries to reserve the whole budget with {@link #reserveMemory(long)}, which runs pending
+	 * GC cleaners and triggers GC if needed, so it may block for about a second if memory is still held.
+	 *
+	 * @return {@code true} if the whole budget could be reserved and is fully available again afterwards
+	 */
+	boolean verifyEmpty() {
+		try {
+			reserveMemory(totalMemorySize);
+		} catch (MemoryReservationException e) {
+			return false;
+		}
+		releaseMemory(totalMemorySize);
+		return availableMemorySize.get() == totalMemorySize;
+	}
+
+	/**
+	 * Reserve memory of certain size if it is available.
+	 *
+	 * <p>Adjusted version of {@link java.nio.Bits#reserveMemory(long, int)} taken from Java 11.
+	 *
+	 * @param size the number of bytes to reserve, must not be negative
+	 * @throws MemoryReservationException if the memory is still not available after all retry attempts
+	 * @throws IllegalArgumentException if {@code size} is negative
+	 */
+	@SuppressWarnings({"OverlyComplexMethod", "JavadocReference", "NestedTryStatement"})
+	void reserveMemory(long size) throws MemoryReservationException {
+		checkNonNegative(size);
+		long availableOrReserved = tryReserveMemory(size);
+		// optimist!
+		if (availableOrReserved >= size) {
+			return;
+		}
+
+		boolean interrupted = false;
+		try {
+
+			// Retry allocation until success or there are no more
+			// references (including Cleaners that might free direct
+			// buffer memory) to process and allocation still fails.
+			boolean refprocActive;
+			do {
+				try {
+					refprocActive = JavaGcCleanerWrapper.tryRunPendingCleaners();
+				} catch (InterruptedException e) {
+					// Defer interrupts and keep trying.
+					interrupted = true;
+					refprocActive = true;
+				}
+				availableOrReserved = tryReserveMemory(size);
+				if (availableOrReserved >= size) {
+					return;
+				}
+			} while (refprocActive);
+
+			// trigger VM's Reference processing
+			System.gc();
+
+			// A retry loop with exponential back-off delays.
+			// Sometimes it would suffice to give up once reference
+			// processing is complete.  But if there are many threads
+			// competing for memory, this gives more opportunities for
+			// any given thread to make progress.  In particular, this
+			// seems to be enough for a stress test like
+			// DirectBufferAllocTest to (usually) succeed, while
+			// without it that test likely fails.  Since failure here
+			// ends in MemoryReservationException, there's no need to hurry.
+			long sleepTime = 1;
+			int sleeps = 0;
+			while (true) {
+				availableOrReserved = tryReserveMemory(size);
+				if (availableOrReserved >= size) {
+					return;
+				}
+				if (sleeps >= MAX_SLEEPS) {
+					break;
+				}
+				if (sleeps >= RETRIGGER_GC_AFTER_SLEEPS) {
+					// trigger again VM's Reference processing if we have to wait longer
+					System.gc();
+				}
+				try {
+					if (!JavaGcCleanerWrapper.tryRunPendingCleaners()) {
+						Thread.sleep(sleepTime);
+						sleepTime <<= 1;
+						sleeps++;
+					}
+				} catch (InterruptedException e) {
+					interrupted = true;
+				}
+			}
+
+			// no luck
+			throw new MemoryReservationException(
+				String.format("Could not allocate %d bytes, only %d bytes are remaining", size, availableOrReserved));
+
+		} finally {
+			if (interrupted) {
+				// don't swallow interrupts
+				Thread.currentThread().interrupt();
+			}
+		}
+	}
+
+	/**
+	 * Atomically reserves {@code size} bytes if at least that much memory is currently available.
+	 *
+	 * @return {@code size} if the reservation succeeded, otherwise the observed available memory,
+	 *     which is then smaller than {@code size}
+	 */
+	private long tryReserveMemory(long size) {
+		long currentAvailableMemorySize;
+		while (size <= (currentAvailableMemorySize = availableMemorySize.get())) {
+			if (availableMemorySize.compareAndSet(currentAvailableMemorySize, currentAvailableMemorySize - size)) {
+				return size;
+			}
+		}
+		return currentAvailableMemorySize;
+	}
+
+	/**
+	 * Returns previously reserved memory to the budget.
+	 *
+	 * @param size the number of bytes to release, must not be negative
+	 * @throws IllegalArgumentException if {@code size} is negative
+	 * @throws IllegalStateException if {@code size} is more than the currently reserved memory
+	 */
+	void releaseMemory(@Nonnegative long size) {
+		checkNonNegative(size);
+		if (size == 0) {
+			return;
+		}
+		boolean released = false;
+		long currentAvailableMemorySize = 0L;
+		while (!released && totalMemorySize - (currentAvailableMemorySize = availableMemorySize.get()) >= size) {
+			released = availableMemorySize
+				.compareAndSet(currentAvailableMemorySize, currentAvailableMemorySize + size);
+		}
+		if (!released) {
+			// report the reserved (allocated) amount, i.e. total minus the last observed available memory
+			throw new IllegalStateException(String.format(
+				"Trying to release more managed memory (%d bytes) than has been allocated (%d bytes), the total size is %d bytes",
+				size,
+				totalMemorySize - currentAvailableMemorySize,
+				totalMemorySize));
+		}
+	}
+
+	/** Rejects negative sizes, which would otherwise silently corrupt the available memory counter. */
+	private static void checkNonNegative(long size) {
+		if (size < 0) {
+			throw new IllegalArgumentException(String.format("Memory size must not be negative, was %d bytes", size));
+		}
+	}
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/memory/MemoryManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/memory/MemoryManagerTest.java
index adccccd..5297525 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/memory/MemoryManagerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/memory/MemoryManagerTest.java
@@ -27,10 +27,12 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
+import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
 import java.util.Random;
+import java.util.stream.Collectors;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
@@ -292,6 +294,52 @@ public class MemoryManagerTest {
 		memoryManager.releaseAllMemory(owner2);
 	}
 
+	@Test(expected = MemoryAllocationException.class)
+	public void testAllocationFailsIfSegmentsNotGced() throws MemoryAllocationException {
+		List<ByteBuffer> byteBuffers = allocateAndReleaseAllSegmentsButKeepWrappedBufferRefs();
+		// this allocation should fail
+		memoryManager.allocatePages(new Object(), 1);
+		// this should not be reached but keeps the reference to the allocated memory and prevents its GC
+		byteBuffers.get(0).put(0, (byte) 1);
+	}
+
+	@Test(expected = MemoryReservationException.class)
+	public void testReservationFailsIfSegmentsNotGced() throws MemoryAllocationException, MemoryReservationException {
+		List<ByteBuffer> byteBuffers = allocateAndReleaseAllSegmentsButKeepWrappedBufferRefs();
+		// this reservation should fail: the kept wrapped buffers prevent GC of the released segments
+		memoryManager.reserveMemory(new Object(), MemoryManager.DEFAULT_PAGE_SIZE);
+		// this should not be reached but keeps the reference to the allocated memory and prevents its GC
+		byteBuffers.get(0).put(0, (byte) 1);
+	}
+
+	@Test
+	public void testAllocationSuccessIfSegmentsGced() throws MemoryAllocationException {
+		allocateAndReleaseAllSegmentsButKeepWrappedBufferRefs();
+		// no reference to the allocated segments at this point, so the memory should be released by GC
+		// and this allocation should be successful
+		memoryManager.release(memoryManager.allocatePages(new Object(), 1));
+	}
+
+	@Test
+	public void testReservationSuccessIfSegmentsGced() throws MemoryAllocationException, MemoryReservationException {
+		allocateAndReleaseAllSegmentsButKeepWrappedBufferRefs();
+		// no reference to the allocated segments at this point, so the memory should be released by GC
+		Object owner = new Object();
+		// and this reservation should be successful
+		memoryManager.reserveMemory(owner, MemoryManager.DEFAULT_PAGE_SIZE);
+		memoryManager.releaseMemory(owner, MemoryManager.DEFAULT_PAGE_SIZE);
+	}
+
+	private List<ByteBuffer> allocateAndReleaseAllSegmentsButKeepWrappedBufferRefs() throws MemoryAllocationException {
+		List<MemorySegment> segments = memoryManager.allocatePages(new Object(), NUM_PAGES);
+		List<ByteBuffer> buffers = segments
+			.stream()
+			.map(segment -> segment.wrap(0, 1))
+			.collect(Collectors.toList());
+		memoryManager.release(segments);
+		return buffers;
+	}
+
 	@Test
 	public void testComputeMemorySize() {
 		double fraction = 0.6;
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/memory/UnsafeMemoryBudgetTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/memory/UnsafeMemoryBudgetTest.java
new file mode 100644
index 0000000..4f6edd8
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/memory/UnsafeMemoryBudgetTest.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.memory;
+
+import org.apache.flink.util.JavaGcCleanerWrapper;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.Test;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.junit.Assert.assertThat;
+
+/** Test suite for {@link UnsafeMemoryBudget}. */
+public class UnsafeMemoryBudgetTest extends TestLogger {
+
+	@Test
+	public void testGetTotalMemory() {
+		UnsafeMemoryBudget budget = new UnsafeMemoryBudget(100L);
+		assertThat(budget.getTotalMemorySize(), is(100L));
+	}
+
+	@Test
+	public void testReserveMemory() throws MemoryReservationException {
+		UnsafeMemoryBudget budget = new UnsafeMemoryBudget(100L);
+		budget.reserveMemory(50L);
+		assertThat(budget.getAvailableMemorySize(), is(50L));
+	}
+
+	@Test(expected = MemoryReservationException.class)
+	public void testReserveMemoryOverLimitFails() throws MemoryReservationException {
+		UnsafeMemoryBudget budget = new UnsafeMemoryBudget(100L);
+		budget.reserveMemory(120L);
+	}
+
+	@Test
+	public void testReleaseMemory() throws MemoryReservationException {
+		UnsafeMemoryBudget budget = new UnsafeMemoryBudget(100L);
+		budget.reserveMemory(50L);
+		budget.releaseMemory(30L);
+		assertThat(budget.getAvailableMemorySize(), is(80L));
+	}
+
+	@Test(expected = IllegalStateException.class)
+	public void testReleaseMemoryMoreThanReservedFails() throws MemoryReservationException {
+		UnsafeMemoryBudget budget = new UnsafeMemoryBudget(100L);
+		budget.reserveMemory(50L);
+		budget.releaseMemory(70L);
+	}
+
+	@Test(expected = MemoryReservationException.class)
+	public void testReservationFailsIfOwnerNotGced() throws MemoryReservationException {
+		UnsafeMemoryBudget budget = new UnsafeMemoryBudget(100L);
+		Object memoryOwner = new Object();
+		budget.reserveMemory(50L);
+		JavaGcCleanerWrapper.createCleaner(memoryOwner, () -> budget.releaseMemory(50L));
+		budget.reserveMemory(60L);
+		// this should not be reached but keeps the reference to the memoryOwner and prevents its GC
+		log.info(memoryOwner.toString());
+	}
+
+	@Test
+	public void testReservationSuccessIfOwnerGced() throws MemoryReservationException {
+		UnsafeMemoryBudget budget = new UnsafeMemoryBudget(100L);
+		budget.reserveMemory(50L);
+		JavaGcCleanerWrapper.createCleaner(new Object(), () -> budget.releaseMemory(50L));
+		budget.reserveMemory(60L);
+		assertThat(budget.getAvailableMemorySize(), is(40L));
+	}
+
+	@Test
+	public void testVerifyEmptyIfNothingIsReserved() {
+		UnsafeMemoryBudget budget = new UnsafeMemoryBudget(100L);
+		assertThat(budget.verifyEmpty(), is(true));
+		assertThat(budget.getAvailableMemorySize(), is(100L));
+	}
+
+	@Test
+	public void testVerifyEmptyFailsIfMemoryIsReserved() throws MemoryReservationException {
+		UnsafeMemoryBudget budget = new UnsafeMemoryBudget(100L);
+		budget.reserveMemory(10L);
+		// no cleaner will release these 10 bytes, so verifyEmpty must give up after its retries
+		assertThat(budget.verifyEmpty(), is(false));
+		assertThat(budget.getAvailableMemorySize(), is(90L));
+	}
+}