You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by pn...@apache.org on 2020/05/19 15:12:06 UTC

[flink] 06/10: [FLINK-17547][task] Use iterator for unconsumed buffers. Motivation: support spilled records. Changes: (1) change SpillingAdaptiveSpanningRecordDeserializer.getUnconsumedBuffer signature; (2) adapt channel state persistence to new types.

This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch release-1.11
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 1c9bf0368e9a233a2a013436628790c2c2b60bcb
Author: Roman Khachatryan <kh...@gmail.com>
AuthorDate: Mon May 18 20:29:05 2020 +0200

    [FLINK-17547][task] Use iterator for unconsumed buffers.
    Motivation: support spilled records
    Changes:
    1. change SpillingAdaptiveSpanningRecordDeserializer.getUnconsumedBuffer
    signature
    2. adapt channel state persistence to new types
    
    No changes in existing logic.
---
 .../org/apache/flink/util/CloseableIterator.java   | 77 +++++++++++++++++++++-
 .../main/java/org/apache/flink/util/IOUtils.java   |  7 ++
 .../channel/ChannelStateWriteRequest.java          | 33 +++++++---
 .../ChannelStateWriteRequestDispatcherImpl.java    |  6 +-
 .../ChannelStateWriteRequestExecutorImpl.java      | 20 ++++--
 .../checkpoint/channel/ChannelStateWriter.java     |  6 +-
 .../checkpoint/channel/ChannelStateWriterImpl.java | 17 +++--
 .../api/serialization/NonSpanningWrapper.java      | 22 +++++--
 .../api/serialization/RecordDeserializer.java      |  4 +-
 .../network/api/serialization/SpanningWrapper.java | 12 ++--
 ...SpillingAdaptiveSpanningRecordDeserializer.java | 15 +----
 .../partition/consumer/RemoteInputChannel.java     |  3 +-
 .../ChannelStateWriteRequestDispatcherTest.java    | 10 ++-
 .../ChannelStateWriteRequestExecutorImplTest.java  |  1 -
 .../channel/ChannelStateWriterImplTest.java        | 13 ++--
 .../channel/CheckpointInProgressRequestTest.java   |  7 +-
 .../checkpoint/channel/MockChannelStateWriter.java | 11 +++-
 .../channel/RecordingChannelStateWriter.java       | 12 +++-
 .../SpanningRecordSerializationTest.java           |  9 +--
 .../partition/consumer/SingleInputGateTest.java    | 14 +++-
 .../runtime/state/ChannelPersistenceITCase.java    |  3 +-
 .../runtime/io/CheckpointBarrierUnaligner.java     |  3 +-
 .../runtime/io/StreamTaskNetworkInput.java         | 11 ++--
 23 files changed, 235 insertions(+), 81 deletions(-)

diff --git a/flink-core/src/main/java/org/apache/flink/util/CloseableIterator.java b/flink-core/src/main/java/org/apache/flink/util/CloseableIterator.java
index 09ea046..cc51324 100644
--- a/flink-core/src/main/java/org/apache/flink/util/CloseableIterator.java
+++ b/flink-core/src/main/java/org/apache/flink/util/CloseableIterator.java
@@ -20,10 +20,15 @@ package org.apache.flink.util;
 
 import javax.annotation.Nonnull;
 
+import java.util.ArrayDeque;
 import java.util.Collections;
+import java.util.Deque;
 import java.util.Iterator;
+import java.util.List;
 import java.util.function.Consumer;
 
+import static java.util.Arrays.asList;
+
 /**
  * This interface represents an {@link Iterator} that is also {@link AutoCloseable}. A typical use-case for this
  * interface are iterators that are based on native-resources such as files, network, or database connections. Clients
@@ -37,7 +42,42 @@ public interface CloseableIterator<T> extends Iterator<T>, AutoCloseable {
 
 	@Nonnull
 	static <T> CloseableIterator<T> adapterForIterator(@Nonnull Iterator<T> iterator) {
-		return new IteratorAdapter<>(iterator);
+		return adapterForIterator(iterator, () -> {});
+	}
+
+	static <T> CloseableIterator<T> adapterForIterator(@Nonnull Iterator<T> iterator, AutoCloseable close) {
+		return new IteratorAdapter<>(iterator, close);
+	}
+
+	static <T> CloseableIterator<T> fromList(List<T> list, Consumer<T> closeNotConsumed) {
+		return new CloseableIterator<T>(){
+			private final Deque<T> stack = new ArrayDeque<>(list);
+
+			@Override
+			public boolean hasNext() {
+				return !stack.isEmpty();
+			}
+
+			@Override
+			public T next() {
+				return stack.poll();
+			}
+
+			@Override
+			public void close() throws Exception {
+				Exception exception = null;
+				for (T el : stack) {
+					try {
+						closeNotConsumed.accept(el);
+					} catch (Exception e) {
+						exception = ExceptionUtils.firstOrSuppressed(e, exception);
+					}
+				}
+				if (exception != null) {
+					throw exception;
+				}
+			}
+		};
 	}
 
 	@SuppressWarnings("unchecked")
@@ -45,6 +85,34 @@ public interface CloseableIterator<T> extends Iterator<T>, AutoCloseable {
 		return (CloseableIterator<T>) EMPTY_INSTANCE;
 	}
 
+	static <T> CloseableIterator<T> ofElements(Consumer<T> closeNotConsumed, T... elements) {
+		return fromList(asList(elements), closeNotConsumed);
+	}
+
+	static <E> CloseableIterator<E> ofElement(E element, Consumer<E> closeIfNotConsumed) {
+		return new CloseableIterator<E>(){
+			private boolean hasNext = true;
+
+			@Override
+			public boolean hasNext() {
+				return hasNext;
+			}
+
+			@Override
+			public E next() {
+				hasNext = false;
+				return element;
+			}
+
+			@Override
+			public void close() {
+				if (hasNext) {
+					closeIfNotConsumed.accept(element);
+				}
+			}
+		};
+	}
+
 	/**
 	 * Adapter from {@link Iterator} to {@link CloseableIterator}. Does nothing on {@link #close()}.
 	 *
@@ -54,9 +122,11 @@ public interface CloseableIterator<T> extends Iterator<T>, AutoCloseable {
 
 		@Nonnull
 		private final Iterator<E> delegate;
+		private final AutoCloseable close;
 
-		IteratorAdapter(@Nonnull Iterator<E> delegate) {
+		IteratorAdapter(@Nonnull Iterator<E> delegate, AutoCloseable close) {
 			this.delegate = delegate;
+			this.close = close;
 		}
 
 		@Override
@@ -80,7 +150,8 @@ public interface CloseableIterator<T> extends Iterator<T>, AutoCloseable {
 		}
 
 		@Override
-		public void close() {
+		public void close() throws Exception {
+			close.close();
 		}
 	}
 }
diff --git a/flink-core/src/main/java/org/apache/flink/util/IOUtils.java b/flink-core/src/main/java/org/apache/flink/util/IOUtils.java
index 1f9af18..0b8f210 100644
--- a/flink-core/src/main/java/org/apache/flink/util/IOUtils.java
+++ b/flink-core/src/main/java/org/apache/flink/util/IOUtils.java
@@ -216,6 +216,13 @@ public final class IOUtils {
 	}
 
 	/**
+	 * @see #closeAll(Iterable)
+	 */
+	public static void closeAll(AutoCloseable... closeables) throws Exception {
+		closeAll(asList(closeables));
+	}
+
+	/**
 	 * Closes all {@link AutoCloseable} objects in the parameter, suppressing exceptions. Exception will be emitted
 	 * after calling close() on every object.
 	 *
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java
index bd6c7ba..0848698 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java
@@ -20,23 +20,24 @@ package org.apache.flink.runtime.checkpoint.channel;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
+import org.apache.flink.util.CloseableIterator;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.ThrowingConsumer;
 
 import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.Consumer;
 
 import static org.apache.flink.runtime.checkpoint.channel.CheckpointInProgressRequestState.CANCELLED;
 import static org.apache.flink.runtime.checkpoint.channel.CheckpointInProgressRequestState.COMPLETED;
 import static org.apache.flink.runtime.checkpoint.channel.CheckpointInProgressRequestState.EXECUTING;
 import static org.apache.flink.runtime.checkpoint.channel.CheckpointInProgressRequestState.FAILED;
 import static org.apache.flink.runtime.checkpoint.channel.CheckpointInProgressRequestState.NEW;
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 interface ChannelStateWriteRequest {
 	long getCheckpointId();
 
-	void cancel(Throwable cause);
+	void cancel(Throwable cause) throws Exception;
 
 	static CheckpointInProgressRequest completeInput(long checkpointId) {
 		return new CheckpointInProgressRequest("completeInput", checkpointId, ChannelStateCheckpointWriter::completeInput, false);
@@ -46,8 +47,24 @@ interface ChannelStateWriteRequest {
 		return new CheckpointInProgressRequest("completeOutput", checkpointId, ChannelStateCheckpointWriter::completeOutput, false);
 	}
 
-	static ChannelStateWriteRequest write(long checkpointId, InputChannelInfo info, Buffer... flinkBuffers) {
-		return new CheckpointInProgressRequest("writeInput", checkpointId, writer -> writer.writeInput(info, flinkBuffers), recycle(flinkBuffers), false);
+	static ChannelStateWriteRequest write(long checkpointId, InputChannelInfo info, CloseableIterator<Buffer> iterator) {
+		return new CheckpointInProgressRequest(
+			"writeInput",
+			checkpointId,
+			writer -> {
+				while (iterator.hasNext()) {
+					Buffer buffer = iterator.next();
+					try {
+						checkArgument(buffer.isBuffer());
+					} catch (Exception e) {
+						buffer.recycleBuffer();
+						throw e;
+					}
+					writer.writeInput(info, buffer);
+				}
+			},
+			throwable -> iterator.close(),
+			false);
 	}
 
 	static ChannelStateWriteRequest write(long checkpointId, ResultSubpartitionInfo info, Buffer... flinkBuffers) {
@@ -62,7 +79,7 @@ interface ChannelStateWriteRequest {
 		return new CheckpointInProgressRequest("abort", checkpointId, writer -> writer.fail(cause), true);
 	}
 
-	static Consumer<Throwable> recycle(Buffer[] flinkBuffers) {
+	static ThrowingConsumer<Throwable, Exception> recycle(Buffer[] flinkBuffers) {
 		return unused -> {
 			for (Buffer b : flinkBuffers) {
 				b.recycleBuffer();
@@ -112,7 +129,7 @@ enum CheckpointInProgressRequestState {
 
 final class CheckpointInProgressRequest implements ChannelStateWriteRequest {
 	private final ThrowingConsumer<ChannelStateCheckpointWriter, Exception> action;
-	private final Consumer<Throwable> discardAction;
+	private final ThrowingConsumer<Throwable, Exception> discardAction;
 	private final long checkpointId;
 	private final String name;
 	private final boolean ignoreMissingWriter;
@@ -123,7 +140,7 @@ final class CheckpointInProgressRequest implements ChannelStateWriteRequest {
 		}, ignoreMissingWriter);
 	}
 
-	CheckpointInProgressRequest(String name, long checkpointId, ThrowingConsumer<ChannelStateCheckpointWriter, Exception> action, Consumer<Throwable> discardAction, boolean ignoreMissingWriter) {
+	CheckpointInProgressRequest(String name, long checkpointId, ThrowingConsumer<ChannelStateCheckpointWriter, Exception> action, ThrowingConsumer<Throwable, Exception> discardAction, boolean ignoreMissingWriter) {
 		this.checkpointId = checkpointId;
 		this.action = checkNotNull(action);
 		this.discardAction = checkNotNull(discardAction);
@@ -137,7 +154,7 @@ final class CheckpointInProgressRequest implements ChannelStateWriteRequest {
 	}
 
 	@Override
-	public void cancel(Throwable cause) {
+	public void cancel(Throwable cause) throws Exception {
 		if (state.compareAndSet(NEW, CANCELLED) || state.compareAndSet(FAILED, CANCELLED)) {
 			discardAction.accept(cause);
 		}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java
index 843663e..0a15b91 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java
@@ -51,7 +51,11 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
 		try {
 			dispatchInternal(request);
 		} catch (Exception e) {
-			request.cancel(e);
+			try {
+				request.cancel(e);
+			} catch (Exception ex) {
+				e.addSuppressed(ex);
+			}
 			throw e;
 		}
 	}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImpl.java
index cbcc3f7..e87a21c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImpl.java
@@ -32,6 +32,9 @@ import java.util.List;
 import java.util.concurrent.BlockingDeque;
 import java.util.concurrent.CancellationException;
 import java.util.concurrent.LinkedBlockingDeque;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.util.IOUtils.closeAll;
 
 /**
  * Executes {@link ChannelStateWriteRequest}s in a separate thread. Any exception occurred during execution causes this
@@ -67,8 +70,15 @@ class ChannelStateWriteRequestExecutorImpl implements ChannelStateWriteRequestEx
 		} catch (Exception ex) {
 			thrown = ex;
 		} finally {
-			cleanupRequests();
-			dispatcher.fail(thrown == null ? new CancellationException() : thrown);
+			try {
+				closeAll(
+					this::cleanupRequests,
+					() -> dispatcher.fail(thrown == null ? new CancellationException() : thrown)
+				);
+			} catch (Exception e) {
+				//noinspection NonAtomicOperationOnVolatileField
+				thrown = ExceptionUtils.firstOrSuppressed(e, thrown);
+			}
 		}
 		LOG.debug("loop terminated");
 	}
@@ -87,14 +97,12 @@ class ChannelStateWriteRequestExecutorImpl implements ChannelStateWriteRequestEx
 		}
 	}
 
-	private void cleanupRequests() {
+	private void cleanupRequests() throws Exception {
 		Throwable cause = thrown == null ? new CancellationException() : thrown;
 		List<ChannelStateWriteRequest> drained = new ArrayList<>();
 		deque.drainTo(drained);
 		LOG.info("discarding {} drained requests", drained.size());
-		for (ChannelStateWriteRequest request : drained) {
-			request.cancel(cause);
-		}
+		closeAll(drained.stream().<AutoCloseable>map(request -> () -> request.cancel(cause)).collect(Collectors.toList()));
 	}
 
 	@Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java
index e19b1e2..5dad559 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriter.java
@@ -22,6 +22,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.state.InputChannelStateHandle;
 import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
+import org.apache.flink.util.CloseableIterator;
 
 import java.io.Closeable;
 import java.util.Collection;
@@ -99,11 +100,10 @@ public interface ChannelStateWriter extends Closeable {
 	 *                    It is intended to use for incremental snapshots.
 	 *                    If no data is passed it is ignored.
 	 * @param data zero or more <b>data</b> buffers ordered by their sequence numbers
-	 * @throws IllegalArgumentException if one or more passed buffers {@link Buffer#isBuffer()  isn't a buffer}
 	 * @see org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter#SEQUENCE_NUMBER_RESTORED
 	 * @see org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter#SEQUENCE_NUMBER_UNKNOWN
 	 */
-	void addInputData(long checkpointId, InputChannelInfo info, int startSeqNum, Buffer... data) throws IllegalArgumentException;
+	void addInputData(long checkpointId, InputChannelInfo info, int startSeqNum, CloseableIterator<Buffer> data);
 
 	/**
 	 * Add in-flight buffers from the {@link org.apache.flink.runtime.io.network.partition.ResultSubpartition ResultSubpartition}.
@@ -161,7 +161,7 @@ public interface ChannelStateWriter extends Closeable {
 		}
 
 		@Override
-		public void addInputData(long checkpointId, InputChannelInfo info, int startSeqNum, Buffer... data) {
+		public void addInputData(long checkpointId, InputChannelInfo info, int startSeqNum, CloseableIterator<Buffer> data) {
 		}
 
 		@Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java
index 412a9f5..b6fa588 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java
@@ -24,6 +24,7 @@ import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.state.CheckpointStorageWorkerView;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.util.CloseableIterator;
 import org.apache.flink.util.Preconditions;
 
 import org.slf4j.Logger;
@@ -103,10 +104,9 @@ public class ChannelStateWriterImpl implements ChannelStateWriter {
 	}
 
 	@Override
-	public void addInputData(long checkpointId, InputChannelInfo info, int startSeqNum, Buffer... data) {
-		LOG.debug("add input data, checkpoint id: {}, channel: {}, startSeqNum: {}, num buffers: {}",
-			checkpointId, info, startSeqNum, data == null ? 0 : data.length);
-		enqueue(write(checkpointId, info, checkBufferType(data)), false);
+	public void addInputData(long checkpointId, InputChannelInfo info, int startSeqNum, CloseableIterator<Buffer> iterator) {
+		LOG.debug("add input data, checkpoint id: {}, channel: {}, startSeqNum: {}", checkpointId, info, startSeqNum);
+		enqueue(write(checkpointId, info, iterator), false);
 	}
 
 	@Override
@@ -168,8 +168,13 @@ public class ChannelStateWriterImpl implements ChannelStateWriter {
 				executor.submit(request);
 			}
 		} catch (Exception e) {
-			request.cancel(e);
-			throw new RuntimeException("unable to send request to worker", e);
+			RuntimeException wrapped = new RuntimeException("unable to send request to worker", e);
+			try {
+				request.cancel(e);
+			} catch (Exception cancelException) {
+				wrapped.addSuppressed(cancelException);
+			}
+			throw wrapped;
 		}
 	}
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/NonSpanningWrapper.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/NonSpanningWrapper.java
index 5de5467..343c6f4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/NonSpanningWrapper.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/NonSpanningWrapper.java
@@ -22,17 +22,21 @@ import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer.NextRecordResponse;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.util.CloseableIterator;
 
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.UTFDataFormatException;
 import java.nio.ByteBuffer;
-import java.util.Optional;
 
 import static org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult.INTERMEDIATE_RECORD_FROM_BUFFER;
 import static org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult.LAST_RECORD_FROM_BUFFER;
 import static org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult.PARTIAL_RECORD;
 import static org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer.LENGTH_BYTES;
+import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType.DATA_BUFFER;
 
 final class NonSpanningWrapper implements DataInputView {
 
@@ -69,13 +73,13 @@ final class NonSpanningWrapper implements DataInputView {
 		this.limit = limit;
 	}
 
-	Optional<MemorySegment> getUnconsumedSegment() {
+	CloseableIterator<Buffer> getUnconsumedSegment() {
 		if (!hasRemaining()) {
-			return Optional.empty();
+			return CloseableIterator.empty();
 		}
-		MemorySegment target = MemorySegmentFactory.allocateUnpooledSegment(remaining());
-		segment.copyTo(position, target, 0, remaining());
-		return Optional.of(target);
+		MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(remaining());
+		this.segment.copyTo(position, segment, 0, remaining());
+		return singleBufferIterator(segment);
 	}
 
 	boolean hasRemaining() {
@@ -359,4 +363,10 @@ final class NonSpanningWrapper implements DataInputView {
 		return recordLength <= remaining();
 	}
 
+	static CloseableIterator<Buffer> singleBufferIterator(MemorySegment target) {
+		return CloseableIterator.ofElement(
+			new NetworkBuffer(target, FreeingBufferRecycler.INSTANCE, DATA_BUFFER, target.size()),
+			Buffer::recycleBuffer);
+	}
+
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/RecordDeserializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/RecordDeserializer.java
index 4f4d621..07ff5ff 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/RecordDeserializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/RecordDeserializer.java
@@ -20,9 +20,9 @@ package org.apache.flink.runtime.io.network.api.serialization;
 
 import org.apache.flink.core.io.IOReadableWritable;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.util.CloseableIterator;
 
 import java.io.IOException;
-import java.util.Optional;
 
 /**
  * Interface for turning sequences of memory segments into records.
@@ -71,5 +71,5 @@ public interface RecordDeserializer<T extends IOReadableWritable> {
 	 * <p>Note that the unconsumed buffer might be null if the whole buffer was already consumed
 	 * before and there are no partial length or data remained in the end of buffer.
 	 */
-	Optional<Buffer> getUnconsumedBuffer() throws IOException;
+	CloseableIterator<Buffer> getUnconsumedBuffer() throws IOException;
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningWrapper.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningWrapper.java
index 430f0db..18ea6cc 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningWrapper.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningWrapper.java
@@ -23,6 +23,8 @@ import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputSerializer;
 import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.util.CloseableIterator;
 import org.apache.flink.util.StringUtils;
 
 import java.io.BufferedInputStream;
@@ -34,11 +36,11 @@ import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.nio.channels.FileChannel;
 import java.util.Arrays;
-import java.util.Optional;
 import java.util.Random;
 
 import static java.lang.Math.max;
 import static java.lang.Math.min;
+import static org.apache.flink.runtime.io.network.api.serialization.NonSpanningWrapper.singleBufferIterator;
 import static org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer.LENGTH_BYTES;
 import static org.apache.flink.util.FileUtils.writeCompletely;
 import static org.apache.flink.util.IOUtils.closeAllQuietly;
@@ -165,15 +167,15 @@ final class SpanningWrapper {
 		}
 	}
 
-	Optional<MemorySegment> getUnconsumedSegment() throws IOException {
+	CloseableIterator<Buffer> getUnconsumedSegment() throws IOException {
 		if (isReadingLength()) {
-			return Optional.of(copyLengthBuffer());
+			return singleBufferIterator(copyLengthBuffer());
 		} else if (isAboveSpillingThreshold()) {
 			throw new UnsupportedOperationException("Unaligned checkpoint currently do not support spilled records.");
 		} else if (recordLength == -1) {
-			return Optional.empty(); // no remaining partial length or data
+			return CloseableIterator.empty(); // no remaining partial length or data
 		} else {
-			return Optional.of(copyDataBuffer());
+			return singleBufferIterator(copyDataBuffer());
 		}
 	}
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpillingAdaptiveSpanningRecordDeserializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpillingAdaptiveSpanningRecordDeserializer.java
index 75e6b0b..2d4c24c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpillingAdaptiveSpanningRecordDeserializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpillingAdaptiveSpanningRecordDeserializer.java
@@ -21,18 +21,15 @@ package org.apache.flink.runtime.io.network.api.serialization;
 import org.apache.flink.core.io.IOReadableWritable;
 import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
-import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
-import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.util.CloseableIterator;
 
 import javax.annotation.concurrent.NotThreadSafe;
 
 import java.io.IOException;
-import java.util.Optional;
 
 import static org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult.INTERMEDIATE_RECORD_FROM_BUFFER;
 import static org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult.LAST_RECORD_FROM_BUFFER;
 import static org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult.PARTIAL_RECORD;
-import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType.DATA_BUFFER;
 
 /**
  * @param <T> The type of the record to be deserialized.
@@ -76,14 +73,8 @@ public class SpillingAdaptiveSpanningRecordDeserializer<T extends IOReadableWrit
 	}
 
 	@Override
-	public Optional<Buffer> getUnconsumedBuffer() throws IOException {
-		final Optional<MemorySegment> unconsumedSegment;
-		if (nonSpanningWrapper.hasRemaining()) {
-			unconsumedSegment = nonSpanningWrapper.getUnconsumedSegment();
-		} else {
-			unconsumedSegment = spanningWrapper.getUnconsumedSegment();
-		}
-		return unconsumedSegment.map(segment -> new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE, DATA_BUFFER, segment.size()));
+	public CloseableIterator<Buffer> getUnconsumedBuffer() throws IOException {
+		return nonSpanningWrapper.hasRemaining() ? nonSpanningWrapper.getUnconsumedSegment() : spanningWrapper.getUnconsumedSegment();
 	}
 
 	@Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
index 6db81e9..4e1f260 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
@@ -31,6 +31,7 @@ import org.apache.flink.runtime.io.network.buffer.BufferProvider;
 import org.apache.flink.runtime.io.network.buffer.BufferReceivedListener;
 import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.util.CloseableIterator;
 
 import javax.annotation.Nullable;
 import javax.annotation.concurrent.GuardedBy;
@@ -207,7 +208,7 @@ public class RemoteInputChannel extends InputChannel {
 				checkpointId,
 				channelInfo,
 				ChannelStateWriter.SEQUENCE_NUMBER_UNKNOWN,
-				inflightBuffers.toArray(new Buffer[0]));
+				CloseableIterator.fromList(inflightBuffers, Buffer::recycleBuffer));
 		}
 	}
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherTest.java
index f953c22..00c8ca7 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherTest.java
@@ -17,8 +17,13 @@
 
 package org.apache.flink.runtime.checkpoint.channel;
 
+import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
+import org.apache.flink.util.CloseableIterator;
 
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -83,7 +88,10 @@ public class ChannelStateWriteRequestDispatcherTest {
 	}
 
 	private static ChannelStateWriteRequest writeIn() {
-		return write(CHECKPOINT_ID, new InputChannelInfo(1, 1));
+		return write(CHECKPOINT_ID, new InputChannelInfo(1, 1), CloseableIterator.ofElement(
+			new NetworkBuffer(MemorySegmentFactory.allocateUnpooledSegment(1), FreeingBufferRecycler.INSTANCE),
+			Buffer::recycleBuffer
+		));
 	}
 
 	private static ChannelStateWriteRequest writeOut() {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImplTest.java
index 5aad9c6..a299b34 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImplTest.java
@@ -30,7 +30,6 @@ import java.util.concurrent.LinkedBlockingDeque;
 
 import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestDispatcher.NO_OP;
 import static org.apache.flink.util.ExceptionUtils.findThrowable;
-import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
index 44552e6..92a7e88 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
@@ -35,6 +35,7 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Consumer;
 
 import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
+import static org.apache.flink.util.CloseableIterator.ofElements;
 import static org.apache.flink.util.ExceptionUtils.findThrowable;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertSame;
@@ -47,14 +48,16 @@ public class ChannelStateWriterImplTest {
 	private static final long CHECKPOINT_ID = 42L;
 
 	@Test(expected = IllegalArgumentException.class)
-	public void testAddEventBuffer() {
+	public void testAddEventBuffer() throws Exception {
+
 		NetworkBuffer dataBuf = getBuffer();
 		NetworkBuffer eventBuf = getBuffer();
 		eventBuf.setDataType(Buffer.DataType.EVENT_BUFFER);
-		ChannelStateWriterImpl writer = openWriter();
-		callStart(writer);
 		try {
-			writer.addInputData(CHECKPOINT_ID, new InputChannelInfo(1, 1), 1, eventBuf, dataBuf);
+			runWithSyncWorker(writer -> {
+				callStart(writer);
+				writer.addInputData(CHECKPOINT_ID, new InputChannelInfo(1, 1), 1, ofElements(Buffer::recycleBuffer, eventBuf, dataBuf));
+			});
 		} finally {
 			assertTrue(dataBuf.isRecycled());
 		}
@@ -285,7 +288,7 @@ public class ChannelStateWriterImplTest {
 	}
 
 	private void callAddInputData(ChannelStateWriter writer, NetworkBuffer... buffer) {
-		writer.addInputData(CHECKPOINT_ID, new InputChannelInfo(1, 1), 1, buffer);
+		writer.addInputData(CHECKPOINT_ID, new InputChannelInfo(1, 1), 1, ofElements(Buffer::recycleBuffer, buffer));
 	}
 
 	private void callAbort(ChannelStateWriter writer) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
index 3617b8f..556bf49 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
@@ -23,6 +23,7 @@ import java.util.concurrent.CyclicBarrier;
 import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
 
 /**
  * {@link CheckpointInProgressRequest} test.
@@ -41,7 +42,11 @@ public class CheckpointInProgressRequestTest {
 		Thread[] threads = new Thread[barrier.getParties()];
 		for (int i = 0; i < barrier.getParties(); i++) {
 			threads[i] = new Thread(() -> {
-				request.cancel(new RuntimeException("test"));
+				try {
+					request.cancel(new RuntimeException("test"));
+				} catch (Exception e) {
+					fail(e.getMessage());
+				}
 				await(barrier);
 			});
 		}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/MockChannelStateWriter.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/MockChannelStateWriter.java
index 5dcc00c..0a61066d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/MockChannelStateWriter.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/MockChannelStateWriter.java
@@ -19,6 +19,9 @@ package org.apache.flink.runtime.checkpoint.channel;
 
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.util.CloseableIterator;
+
+import static org.apache.flink.util.ExceptionUtils.rethrow;
 
 /**
  * A no op implementation that performs basic checks of the contract, but does not actually write any data.
@@ -49,10 +52,12 @@ public class MockChannelStateWriter implements ChannelStateWriter {
 	}
 
 	@Override
-	public void addInputData(long checkpointId, InputChannelInfo info, int startSeqNum, Buffer... data) {
+	public void addInputData(long checkpointId, InputChannelInfo info, int startSeqNum, CloseableIterator<Buffer> iterator) {
 		checkCheckpointId(checkpointId);
-		for (final Buffer buffer : data) {
-			buffer.recycleBuffer();
+		try {
+			iterator.close();
+		} catch (Exception e) {
+			rethrow(e);
 		}
 	}
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecordingChannelStateWriter.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecordingChannelStateWriter.java
index b53e37b..d0cfe3f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecordingChannelStateWriter.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/RecordingChannelStateWriter.java
@@ -19,12 +19,15 @@ package org.apache.flink.runtime.checkpoint.channel;
 
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.util.CloseableIterator;
 
 import org.apache.flink.shaded.guava18.com.google.common.collect.LinkedListMultimap;
 import org.apache.flink.shaded.guava18.com.google.common.collect.ListMultimap;
 
 import java.util.Arrays;
 
+import static org.apache.flink.util.ExceptionUtils.rethrow;
+
 /**
  * A simple {@link ChannelStateWriter} used to write unit tests.
  */
@@ -54,9 +57,14 @@ public class RecordingChannelStateWriter extends MockChannelStateWriter {
 	}
 
 	@Override
-	public void addInputData(long checkpointId, InputChannelInfo info, int startSeqNum, Buffer... data) {
+	public void addInputData(long checkpointId, InputChannelInfo info, int startSeqNum, CloseableIterator<Buffer> iterator) {
 		checkCheckpointId(checkpointId);
-		addedInput.putAll(info, Arrays.asList(data));
+		iterator.forEachRemaining(b -> addedInput.put(info, b));
+		try {
+			iterator.close();
+		} catch (Exception e) {
+			rethrow(e);
+		}
 	}
 
 	@Override
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java
index 183df10..e35b311 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializationTest.java
@@ -30,6 +30,7 @@ import org.apache.flink.testutils.serialization.types.IntType;
 import org.apache.flink.testutils.serialization.types.SerializationTestType;
 import org.apache.flink.testutils.serialization.types.SerializationTestTypeFactory;
 import org.apache.flink.testutils.serialization.types.Util;
+import org.apache.flink.util.CloseableIterator;
 import org.apache.flink.util.TestLogger;
 
 import org.junit.Assert;
@@ -46,7 +47,6 @@ import java.nio.channels.WritableByteChannel;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.Optional;
 import java.util.Random;
 
 import static org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils.buildSingleBuffer;
@@ -293,14 +293,15 @@ public class SpanningRecordSerializationTest extends TestLogger {
 		}
 	}
 
-	private static void assertUnconsumedBuffer(ByteArrayOutputStream expected, Optional<Buffer> actual) {
-		if (!actual.isPresent()) {
+	private static void assertUnconsumedBuffer(ByteArrayOutputStream expected, CloseableIterator<Buffer> actual) throws Exception {
+		if (!actual.hasNext()) {
 			Assert.assertEquals(expected.size(), 0);
 		}
 
 		ByteBuffer expectedByteBuffer = ByteBuffer.wrap(expected.toByteArray());
-		ByteBuffer actualByteBuffer = actual.get().getNioBufferReadable();
+		ByteBuffer actualByteBuffer = actual.next().getNioBufferReadable();
 		Assert.assertEquals(expectedByteBuffer, actualByteBuffer);
+		actual.close();
 	}
 
 	private static void writeBuffer(ByteBuffer buffer, OutputStream stream) throws IOException {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
index abcf563..6f49b44 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
@@ -66,6 +66,7 @@ import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor;
 import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
 import org.apache.flink.runtime.shuffle.UnknownShuffleDescriptor;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
+import org.apache.flink.util.CloseableIterator;
 import org.apache.flink.util.ExceptionUtils;
 
 import org.junit.Test;
@@ -74,7 +75,6 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
@@ -94,6 +94,7 @@ import static org.apache.flink.runtime.io.network.partition.InputGateFairnessTes
 import static org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannelTest.submitTasksAndWaitForResults;
 import static org.apache.flink.runtime.io.network.util.TestBufferFactory.createBuffer;
 import static org.apache.flink.runtime.util.NettyShuffleDescriptorBuilder.createRemoteWithIdAndLocation;
+import static org.apache.flink.util.ExceptionUtils.rethrow;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
@@ -989,8 +990,15 @@ public class SingleInputGateTest extends InputGateTestBase {
 
 		inputChannel.spillInflightBuffers(0, new ChannelStateWriterImpl.NoOpChannelStateWriter() {
 			@Override
-			public void addInputData(long checkpointId, InputChannelInfo info, int startSeqNum, Buffer... data) {
-				inflightBuffers.addAll(Arrays.asList(data));
+			public void addInputData(long checkpointId, InputChannelInfo info, int startSeqNum, CloseableIterator<Buffer> iterator) {
+				List<Buffer> list = new ArrayList<>();
+				iterator.forEachRemaining(list::add);
+				inflightBuffers.addAll(list);
+				try {
+					iterator.close();
+				} catch (Exception e) {
+					rethrow(e);
+				}
 			}
 		});
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
index 3f5e2cc..a77dbbf 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
@@ -50,6 +50,7 @@ import static java.util.Collections.singletonMap;
 import static org.apache.flink.runtime.checkpoint.CheckpointType.CHECKPOINT;
 import static org.apache.flink.runtime.checkpoint.channel.ChannelStateReader.ReadResult.NO_MORE_DATA;
 import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.SEQUENCE_NUMBER_UNKNOWN;
+import static org.apache.flink.util.CloseableIterator.ofElements;
 import static org.apache.flink.util.Preconditions.checkState;
 import static org.junit.Assert.assertArrayEquals;
 
@@ -102,7 +103,7 @@ public class ChannelPersistenceITCase {
 			writer.open();
 			writer.start(checkpointId, new CheckpointOptions(CHECKPOINT, new CheckpointStorageLocationReference("poly".getBytes())));
 			for (Map.Entry<InputChannelInfo, Buffer> e : icBuffers.entrySet()) {
-				writer.addInputData(checkpointId, e.getKey(), SEQUENCE_NUMBER_UNKNOWN, e.getValue());
+				writer.addInputData(checkpointId, e.getKey(), SEQUENCE_NUMBER_UNKNOWN, ofElements(Buffer::recycleBuffer, e.getValue()));
 			}
 			writer.finishInput(checkpointId);
 			for (Map.Entry<ResultSubpartitionInfo, Buffer> e : rsBuffers.entrySet()) {
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnaligner.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnaligner.java
index f98e83f..d39accf 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnaligner.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnaligner.java
@@ -45,6 +45,7 @@ import java.util.concurrent.CompletableFuture;
 import java.util.function.Function;
 import java.util.stream.IntStream;
 
+import static org.apache.flink.util.CloseableIterator.ofElement;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
@@ -330,7 +331,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 					currentReceivedCheckpointId,
 					channelInfo,
 					ChannelStateWriter.SEQUENCE_NUMBER_UNKNOWN,
-					buffer);
+					ofElement(buffer, Buffer::recycleBuffer));
 			} else {
 				buffer.recycleBuffer();
 			}
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInput.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInput.java
index 07826c7..6723a2d 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInput.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInput.java
@@ -212,12 +212,11 @@ public final class StreamTaskNetworkInput<T> implements StreamTaskInput<T> {
 			// Assumption for retrieving buffers = one concurrent checkpoint
 			RecordDeserializer<?> deserializer = recordDeserializers[channelIndex];
 			if (deserializer != null) {
-				deserializer.getUnconsumedBuffer().ifPresent(buffer ->
-					channelStateWriter.addInputData(
-						checkpointId,
-						channel.getChannelInfo(),
-						ChannelStateWriter.SEQUENCE_NUMBER_UNKNOWN,
-						buffer));
+				channelStateWriter.addInputData(
+					checkpointId,
+					channel.getChannelInfo(),
+					ChannelStateWriter.SEQUENCE_NUMBER_UNKNOWN,
+					deserializer.getUnconsumedBuffer());
 			}
 
 			checkpointedInputGate.spillInflightBuffers(checkpointId, channelIndex, channelStateWriter);