Posted to commits@flink.apache.org by pn...@apache.org on 2020/05/28 06:15:04 UTC

[flink] 04/04: [FLINK-17928][checkpointing] Fix ChannelStateHandle size

This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch release-1.11
in repository https://gitbox.apache.org/repos/asf/flink.git

commit a9782a2483cace80ffd06fb4fbcb990407375a4d
Author: Roman Khachatryan <kh...@gmail.com>
AuthorDate: Wed May 27 15:39:19 2020 +0200

    [FLINK-17928][checkpointing] Fix ChannelStateHandle size
    
    Store state size explicitly because underlying state
    handle may be shared.
---
 .../channel/ChannelStateCheckpointWriter.java      | 33 ++++++++++--------
 .../metadata/ChannelStateHandleSerializer.java     | 11 +++---
 .../runtime/state/AbstractChannelStateHandle.java  | 39 ++++++++++++++++++++--
 .../runtime/state/InputChannelStateHandle.java     | 10 +++++-
 .../state/ResultSubpartitionStateHandle.java       | 10 +++++-
 .../channel/ChannelStateCheckpointWriterTest.java  | 30 +++++++++++++++++
 6 files changed, 112 insertions(+), 21 deletions(-)

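Why the size has to be stored: ChannelStateCheckpointWriter multiplexes the data of all input channels and result subpartitions of a task into a single checkpoint stream, so every per-channel handle ends up wrapping the same underlying StreamStateHandle. Delegating getStateSize() to that shared handle reports the size of the whole stream for every channel, and summing the handles over-counts the checkpoint size. The following self-contained sketch (hypothetical classes, not Flink code) illustrates the over-counting and the fix of storing the per-channel size explicitly.

// A minimal, self-contained sketch (hypothetical classes, not Flink code) of why
// delegating getStateSize() to a shared underlying handle over-counts, and how
// storing the per-channel size fixes it.
public class SharedDelegateSizeSketch {

	// Stand-in for the shared StreamStateHandle: one file/stream holding all channels' data.
	static class SharedFileHandle {
		long getStateSize() {
			return 100; // total bytes of the shared file
		}
	}

	// Stand-in for AbstractChannelStateHandle before the fix: size is delegated.
	static class ChannelHandleBefore {
		final SharedFileHandle delegate;
		ChannelHandleBefore(SharedFileHandle delegate) {
			this.delegate = delegate;
		}
		long getStateSize() {
			return delegate.getStateSize(); // reports the whole file for every channel
		}
	}

	// Stand-in after the fix: the bytes written for this channel are stored explicitly.
	static class ChannelHandleAfter {
		final SharedFileHandle delegate;
		final long size;
		ChannelHandleAfter(SharedFileHandle delegate, long size) {
			this.delegate = delegate;
			this.size = size;
		}
		long getStateSize() {
			return size;
		}
	}

	public static void main(String[] args) {
		SharedFileHandle shared = new SharedFileHandle();
		// Two channels whose data lives in the same 100-byte file: 60 and 40 bytes each.
		long before = new ChannelHandleBefore(shared).getStateSize()
			+ new ChannelHandleBefore(shared).getStateSize();    // 200: over-counted
		long after = new ChannelHandleAfter(shared, 60).getStateSize()
			+ new ChannelHandleAfter(shared, 40).getStateSize(); // 100: matches the file
		System.out.println("before fix: " + before + ", after fix: " + after);
	}
}
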
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java
index 97d1625..9643d4c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.checkpoint.channel;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.state.AbstractChannelStateHandle;
+import org.apache.flink.runtime.state.AbstractChannelStateHandle.StateContentMetaInfo;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.CheckpointStreamFactory.CheckpointStateOutputStream;
 import org.apache.flink.runtime.state.InputChannelStateHandle;
@@ -61,8 +62,8 @@ class ChannelStateCheckpointWriter {
 	private final DataOutputStream dataStream;
 	private final CheckpointStateOutputStream checkpointStream;
 	private final ChannelStateWriteResult result;
-	private final Map<InputChannelInfo, List<Long>> inputChannelOffsets = new HashMap<>();
-	private final Map<ResultSubpartitionInfo, List<Long>> resultSubpartitionOffsets = new HashMap<>();
+	private final Map<InputChannelInfo, StateContentMetaInfo> inputChannelOffsets = new HashMap<>();
+	private final Map<ResultSubpartitionInfo, StateContentMetaInfo> resultSubpartitionOffsets = new HashMap<>();
 	private final ChannelStateSerializer serializer;
 	private final long checkpointId;
 	private boolean allInputsReceived = false;
@@ -115,17 +116,19 @@ class ChannelStateCheckpointWriter {
 		write(resultSubpartitionOffsets, info, flinkBuffers, !allOutputsReceived);
 	}
 
-	private <K> void write(Map<K, List<Long>> offsets, K key, Buffer[] flinkBuffers, boolean precondition) throws Exception {
+	private <K> void write(Map<K, StateContentMetaInfo> offsets, K key, Buffer[] flinkBuffers, boolean precondition) throws Exception {
 		try {
 			if (result.isDone()) {
 				return;
 			}
 			runWithChecks(() -> {
 				checkState(precondition);
-				offsets
-					.computeIfAbsent(key, unused -> new ArrayList<>())
-					.add(checkpointStream.getPos());
+				long offset = checkpointStream.getPos();
 				serializer.writeData(dataStream, flinkBuffers);
+				long size = checkpointStream.getPos() - offset;
+				offsets
+					.computeIfAbsent(key, unused -> new StateContentMetaInfo())
+					.withDataAdded(offset, size);
 			});
 		} finally {
 			for (Buffer flinkBuffer : flinkBuffers) {
@@ -179,10 +182,10 @@ class ChannelStateCheckpointWriter {
 	private <I, H extends AbstractChannelStateHandle<I>> void complete(
 			StreamStateHandle underlying,
 			CompletableFuture<Collection<H>> future,
-			Map<I, List<Long>> offsets,
+			Map<I, StateContentMetaInfo> offsets,
 			HandleFactory<I, H> handleFactory) throws IOException {
 		final Collection<H> handles = new ArrayList<>();
-		for (Map.Entry<I, List<Long>> e : offsets.entrySet()) {
+		for (Map.Entry<I, StateContentMetaInfo> e : offsets.entrySet()) {
 			handles.add(createHandle(handleFactory, underlying, e.getKey(), e.getValue()));
 		}
 		future.complete(handles);
@@ -193,15 +196,19 @@ class ChannelStateCheckpointWriter {
 			HandleFactory<I, H> handleFactory,
 			StreamStateHandle underlying,
 			I channelInfo,
-			List<Long> offsets) throws IOException {
+			StateContentMetaInfo contentMetaInfo) throws IOException {
 		Optional<byte[]> bytes = underlying.asBytesIfInMemory(); // todo: consider restructuring channel state and removing this method: https://issues.apache.org/jira/browse/FLINK-17972
 		if (bytes.isPresent()) {
+			StreamStateHandle extracted = new ByteStreamStateHandle(
+				randomUUID().toString(),
+				serializer.extractAndMerge(bytes.get(), contentMetaInfo.getOffsets()));
 			return handleFactory.create(
 				channelInfo,
-				new ByteStreamStateHandle(randomUUID().toString(), serializer.extractAndMerge(bytes.get(), offsets)),
-				singletonList(serializer.getHeaderLength()));
+				extracted,
+				singletonList(serializer.getHeaderLength()),
+				extracted.getStateSize());
 		} else {
-			return handleFactory.create(channelInfo, underlying, offsets);
+			return handleFactory.create(channelInfo, underlying, contentMetaInfo.getOffsets(), contentMetaInfo.getSize());
 		}
 	}
 
@@ -221,7 +228,7 @@ class ChannelStateCheckpointWriter {
 	}
 
 	private interface HandleFactory<I, H extends AbstractChannelStateHandle<I>> {
-		H create(I info, StreamStateHandle underlying, List<Long> offsets);
+		H create(I info, StreamStateHandle underlying, List<Long> offsets, long size);
 
 		HandleFactory<InputChannelInfo, InputChannelStateHandle> INPUT_CHANNEL = InputChannelStateHandle::new;
 
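The write path above now records the stream position before serializing each set of buffers and the number of bytes actually written, accumulating both per channel in a StateContentMetaInfo. A minimal, self-contained sketch of that bookkeeping (hypothetical stand-ins, assuming a 4-byte length header per write, which is what the new test expects):

// Minimal, self-contained sketch (hypothetical names, not the Flink API) of the
// per-channel bookkeeping: remember the stream position before each write and
// accumulate the number of bytes actually written.
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class OffsetAndSizeSketch {

	// Stand-in for StateContentMetaInfo.
	static class ContentMetaInfo {
		final List<Long> offsets = new ArrayList<>();
		long size;
		void withDataAdded(long offset, long bytes) {
			offsets.add(offset);
			size += bytes;
		}
	}

	public static void main(String[] args) throws IOException {
		ByteArrayOutputStream out = new ByteArrayOutputStream();
		DataOutputStream stream = new DataOutputStream(out);
		ContentMetaInfo channel = new ContentMetaInfo();

		byte[][] writes = {{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}};
		for (byte[] payload : writes) {
			long offset = stream.size();     // position before writing, like checkpointStream.getPos()
			stream.writeInt(payload.length); // 4-byte length header per write (assumed here)
			stream.write(payload);
			channel.withDataAdded(offset, stream.size() - offset); // header + payload bytes
		}
		// Two writes of 5 payload bytes, each with a 4-byte header: 2 * (4 + 5) = 18 bytes.
		System.out.println("offsets=" + channel.offsets + ", size=" + channel.size);
	}
}

With the parameters of the new test below (4 writes of 5 bytes per channel), the same arithmetic gives (Integer.BYTES + 5) * 4 = 36 bytes per handle, which is what testFileHandleSize asserts.
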
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/metadata/ChannelStateHandleSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/metadata/ChannelStateHandleSerializer.java
index 30a4a30..d1f79ec 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/metadata/ChannelStateHandleSerializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/metadata/ChannelStateHandleSerializer.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.checkpoint.metadata;
 import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
 import org.apache.flink.runtime.state.AbstractChannelStateHandle;
+import org.apache.flink.runtime.state.AbstractChannelStateHandle.StateContentMetaInfo;
 import org.apache.flink.runtime.state.InputChannelStateHandle;
 import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -51,7 +52,7 @@ class ChannelStateHandleSerializer {
 
 		return deserializeChannelStateHandle(
 			is -> new ResultSubpartitionInfo(is.readInt(), is.readInt()),
-			(streamStateHandle, longs, info) -> new ResultSubpartitionStateHandle(info, streamStateHandle, longs),
+			(streamStateHandle, contentMetaInfo, info) -> new ResultSubpartitionStateHandle(info, streamStateHandle, contentMetaInfo),
 			dis,
 			context);
 	}
@@ -69,7 +70,7 @@ class ChannelStateHandleSerializer {
 
 		return deserializeChannelStateHandle(
 			is -> new InputChannelInfo(is.readInt(), is.readInt()),
-			(streamStateHandle, longs, inputChannelInfo) -> new InputChannelStateHandle(inputChannelInfo, streamStateHandle, longs),
+			(streamStateHandle, contentMetaInfo, inputChannelInfo) -> new InputChannelStateHandle(inputChannelInfo, streamStateHandle, contentMetaInfo),
 			dis,
 			context);
 	}
@@ -83,12 +84,13 @@ class ChannelStateHandleSerializer {
 		for (long offset : handle.getOffsets()) {
 			dos.writeLong(offset);
 		}
+		dos.writeLong(handle.getStateSize());
 		serializeStreamStateHandle(handle.getDelegate(), dos);
 	}
 
 	private static <Info, Handle extends AbstractChannelStateHandle<Info>> Handle deserializeChannelStateHandle(
 			FunctionWithException<DataInputStream, Info, IOException> infoReader,
-			TriFunctionWithException<StreamStateHandle, List<Long>, Info, Handle, IOException> handleBuilder,
+			TriFunctionWithException<StreamStateHandle, StateContentMetaInfo, Info, Handle, IOException> handleBuilder,
 			DataInputStream dis,
 			MetadataV2V3SerializerBase.DeserializationContext context) throws IOException {
 
@@ -98,6 +100,7 @@ class ChannelStateHandleSerializer {
 		for (int i = 0; i < offsetsSize; i++) {
 			offsets.add(dis.readLong());
 		}
-		return handleBuilder.apply(deserializeStreamStateHandle(dis, context), offsets, info);
+		final long size = dis.readLong();
+		return handleBuilder.apply(deserializeStreamStateHandle(dis, context), new StateContentMetaInfo(offsets, size), info);
 	}
 }
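
The only change to the serialized format is one extra long written after the list of offsets and read back into a StateContentMetaInfo on deserialization. A self-contained round-trip sketch of the resulting field order (the delegate StreamStateHandle that follows in the real format is omitted):

// Round-trip sketch of the serialized field order after this change:
// offset count, the offsets, then the state size added by this commit.
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class HandleLayoutSketch {
	public static void main(String[] args) throws IOException {
		List<Long> offsets = Arrays.asList(0L, 9L, 18L);
		long stateSize = 27;

		ByteArrayOutputStream buf = new ByteArrayOutputStream();
		DataOutputStream dos = new DataOutputStream(buf);
		dos.writeInt(offsets.size());
		for (long offset : offsets) {
			dos.writeLong(offset);
		}
		dos.writeLong(stateSize); // the new field written after the offsets
		// ... the real serializer then writes the delegate StreamStateHandle

		DataInputStream dis = new DataInputStream(new ByteArrayInputStream(buf.toByteArray()));
		int count = dis.readInt();
		List<Long> readOffsets = new ArrayList<>();
		for (int i = 0; i < count; i++) {
			readOffsets.add(dis.readLong());
		}
		long readSize = dis.readLong(); // feeds new StateContentMetaInfo(offsets, size)
		System.out.println("offsets=" + readOffsets + ", size=" + readSize);
	}
}
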
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractChannelStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractChannelStateHandle.java
index 04bb757..c359d9d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractChannelStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractChannelStateHandle.java
@@ -19,6 +19,8 @@ package org.apache.flink.runtime.state;
 
 import org.apache.flink.annotation.Internal;
 
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 
@@ -39,11 +41,13 @@ public abstract class AbstractChannelStateHandle<Info> implements StateObject {
 	 * Start offsets in a {@link org.apache.flink.core.fs.FSDataInputStream stream} {@link StreamStateHandle#openInputStream obtained} from {@link #delegate}.
 	 */
 	private final List<Long> offsets;
+	private final long size;
 
-	AbstractChannelStateHandle(StreamStateHandle delegate, List<Long> offsets, Info info) {
+	AbstractChannelStateHandle(StreamStateHandle delegate, List<Long> offsets, Info info, long size) {
 		this.info = checkNotNull(info);
 		this.delegate = checkNotNull(delegate);
 		this.offsets = checkNotNull(offsets);
+		this.size = size;
 	}
 
 	@Override
@@ -53,7 +57,7 @@ public abstract class AbstractChannelStateHandle<Info> implements StateObject {
 
 	@Override
 	public long getStateSize() {
-		return delegate.getStateSize();
+		return size; // can not rely on delegate.getStateSize because it can be shared
 	}
 
 	public List<Long> getOffsets() {
@@ -84,4 +88,35 @@ public abstract class AbstractChannelStateHandle<Info> implements StateObject {
 	public int hashCode() {
 		return Objects.hash(info, delegate, offsets);
 	}
+
+	/**
+	 * Describes the underlying content.
+	 */
+	public static class StateContentMetaInfo {
+		private final List<Long> offsets;
+		private long size = 0;
+
+		public StateContentMetaInfo() {
+			this(new ArrayList<>(), 0);
+		}
+
+		public StateContentMetaInfo(List<Long> offsets, long size) {
+			this.offsets = offsets;
+			this.size = size;
+		}
+
+		public void withDataAdded(long offset, long size) {
+			this.offsets.add(offset);
+			this.size += size;
+		}
+
+		public List<Long> getOffsets() {
+			return Collections.unmodifiableList(offsets);
+		}
+
+		public long getSize() {
+			return size;
+		}
+	}
+
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/InputChannelStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/InputChannelStateHandle.java
index 3a0f5d4..a0f0099 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/InputChannelStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/InputChannelStateHandle.java
@@ -30,7 +30,15 @@ public class InputChannelStateHandle extends AbstractChannelStateHandle<InputCha
 
 	private static final long serialVersionUID = 1L;
 
+	public InputChannelStateHandle(InputChannelInfo info, StreamStateHandle delegate, StateContentMetaInfo contentMetaInfo) {
+		this(info, delegate, contentMetaInfo.getOffsets(), contentMetaInfo.getSize());
+	}
+
 	public InputChannelStateHandle(InputChannelInfo info, StreamStateHandle delegate, List<Long> offset) {
-		super(delegate, offset, info);
+		this(info, delegate, offset, delegate.getStateSize());
+	}
+
+	public InputChannelStateHandle(InputChannelInfo info, StreamStateHandle delegate, List<Long> offset, long size) {
+		super(delegate, offset, info, size);
 	}
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ResultSubpartitionStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ResultSubpartitionStateHandle.java
index 761c35a..92fc7c6 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ResultSubpartitionStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ResultSubpartitionStateHandle.java
@@ -30,7 +30,15 @@ public class ResultSubpartitionStateHandle extends AbstractChannelStateHandle<Re
 
 	private static final long serialVersionUID = 1L;
 
+	public ResultSubpartitionStateHandle(ResultSubpartitionInfo info, StreamStateHandle delegate, StateContentMetaInfo contentMetaInfo) {
+		this(info, delegate, contentMetaInfo.getOffsets(), contentMetaInfo.getSize());
+	}
+
 	public ResultSubpartitionStateHandle(ResultSubpartitionInfo info, StreamStateHandle delegate, List<Long> offset) {
-		super(delegate, offset, info);
+		this(info, delegate, offset, delegate.getStateSize());
+	}
+
+	public ResultSubpartitionStateHandle(ResultSubpartitionInfo info, StreamStateHandle delegate, List<Long> offset, long size) {
+		super(delegate, offset, info, size);
 	}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java
index 91b5f9f..adacf02 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java
@@ -41,6 +41,7 @@ import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Random;
+import java.util.stream.IntStream;
 
 import static java.util.Collections.singletonList;
 import static org.apache.flink.core.fs.Path.fromLocalFile;
@@ -64,6 +65,35 @@ public class ChannelStateCheckpointWriterTest {
 	public final TemporaryFolder temporaryFolder = new TemporaryFolder();
 
 	@Test
+	public void testFileHandleSize() throws Exception {
+		int numChannels = 3;
+		int numWritesPerChannel = 4;
+		int numBytesPerWrite = 5;
+		ChannelStateWriteResult result = new ChannelStateWriteResult();
+		ChannelStateCheckpointWriter writer = createWriter(
+			result,
+			new FsCheckpointStreamFactory(
+				getSharedInstance(),
+				fromLocalFile(temporaryFolder.newFolder("checkpointsDir")),
+				fromLocalFile(temporaryFolder.newFolder("sharedStateDir")),
+					numBytesPerWrite - 1,
+					numBytesPerWrite - 1).createCheckpointStateOutputStream(EXCLUSIVE));
+
+		InputChannelInfo[] channels = IntStream.range(0, numChannels).mapToObj(i -> new InputChannelInfo(0, i)).toArray(InputChannelInfo[]::new);
+		for (int call = 0; call < numWritesPerChannel; call++) {
+			for (int channel = 0; channel < numChannels; channel++) {
+				write(writer, channels[channel], getData(numBytesPerWrite));
+			}
+		}
+		writer.completeInput();
+		writer.completeOutput();
+
+		for (InputChannelStateHandle handle : result.inputChannelStateHandles.get()) {
+			assertEquals((Integer.BYTES + numBytesPerWrite) * numWritesPerChannel, handle.getStateSize());
+		}
+	}
+
+	@Test
 	@SuppressWarnings("ConstantConditions")
 	public void testSmallFilesNotWritten() throws Exception {
 		int threshold = 100;