Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2020/10/22 14:19:49 UTC

[GitHub] [flink] StephanEwen commented on a change in pull request #13595: [FLINK-19582][network] Introduce sort-merge based blocking shuffle to Flink

StephanEwen commented on a change in pull request #13595:
URL: https://github.com/apache/flink/pull/13595#discussion_r510200867



##########
File path: flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PartitionSortedBuffer.java
##########
@@ -0,0 +1,390 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+
+import javax.annotation.concurrent.NotThreadSafe;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType;
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * A {@link SortBuffer} implementation which sorts all appended records only by subpartition index. Records of the
+ * same subpartition are kept in the order in which they were appended.
+ *
+ * <p>It maintains a list of {@link MemorySegment}s as a joint buffer. Data will be appended to the joint buffer
+ * sequentially. When writing a record, an index entry will be appended first. Each index entry consists of 4
+ * bytes record length, 4 bytes {@link DataType} and 8 bytes address pointing to the next index entry of the same
+ * channel, which is used to locate the next record to read when copying data from this {@link SortBuffer}. For
+ * simplicity, no index entry can span multiple segments. The corresponding record data sits right after its index
+ * entry and, unlike index entries, records have variable length and thus may span multiple segments.
+ */
+@NotThreadSafe
+public class PartitionSortedBuffer implements SortBuffer {
+
+	/**
+	 * Size of an index entry: 4 bytes for record length, 4 bytes for data type and 8 bytes
+	 * for pointer to next entry.
+	 */
+	private static final int INDEX_ENTRY_SIZE = 4 + 4 + 8;
+
+	/** A buffer pool to request memory segments from. */
+	private final BufferPool bufferPool;
+
+	/** A segment list as a joint buffer which stores all records and index entries. */
+	private final ArrayList<MemorySegment> buffers = new ArrayList<>();
+
+	/** Addresses of the first record's index entry for each subpartition. */
+	private final long[] firstIndexEntryAddresses;
+
+	/** Addresses of the last record's index entry for each subpartition. */
+	private final long[] lastIndexEntryAddresses;
+
+	/** Size of buffers requested from buffer pool. All buffers must be of the same size. */
+	private final int bufferSize;
+
+	// ----------------------------------------------------------------------------------------------
+	// Statistics and states
+	// ----------------------------------------------------------------------------------------------
+
+	/** Total number of bytes already appended to this sort buffer. */
+	private long numTotalBytes;
+
+	/** Total number of records already appended to this sort buffer. */
+	private long numTotalRecords;
+
+	/** Total number of bytes already read from this sort buffer. */
+	private long numTotalBytesRead;
+
+	/** Whether this sort buffer is finished. One can only read a finished sort buffer. */
+	private boolean isFinished;
+
+	/** Whether this sort buffer is released. A released sort buffer cannot be used. */
+	private boolean isReleased;
+
+	// ----------------------------------------------------------------------------------------------
+	// For writing
+	// ----------------------------------------------------------------------------------------------
+
+	/** Array index in the segment list of the currently available buffer for writing. */
+	private int writeSegmentIndex;
+
+	/** Next position in the currently available buffer for writing. */
+	private int writeSegmentOffset;
+
+	// ----------------------------------------------------------------------------------------------
+	// For reading
+	// ----------------------------------------------------------------------------------------------
+
+	/** Index entry address of the current record or event to be read. */
+	private long readIndexEntryAddress;
+
+	/** Record bytes remaining after the last copy, which must be read first in the next copy. */
+	private int recordRemainingBytes;
+
+	/** Currently available channel to read data from. */
+	private int readChannelIndex = -1;
+
+	public PartitionSortedBuffer(BufferPool bufferPool, int numSubpartitions, int bufferSize) {
+		checkArgument(bufferSize > INDEX_ENTRY_SIZE, "Buffer size is too small.");
+
+		this.bufferPool = checkNotNull(bufferPool);
+		this.bufferSize = bufferSize;
+		this.firstIndexEntryAddresses = new long[numSubpartitions];
+		this.lastIndexEntryAddresses = new long[numSubpartitions];
+
+		// initialize with -1 to indicate that the corresponding channel has no data
+		Arrays.fill(firstIndexEntryAddresses, -1L);
+		Arrays.fill(lastIndexEntryAddresses, -1L);
+	}
+
+	@Override
+	public boolean append(ByteBuffer source, int targetChannel, DataType dataType) throws IOException {
+		checkState(!isFinished, "Sort buffer is already finished.");
+		checkState(!isReleased, "Sort buffer is already released.");
+
+		int totalBytes = source.remaining();
+		if (totalBytes == 0) {
+			return true;
+		}
+
+		// return false directly if we cannot allocate enough buffers for the given record
+		if (!allocateBuffersForRecord(totalBytes)) {
+			return false;
+		}
+
+		// write the index entry and record or event data
+		writeIndex(targetChannel, totalBytes, dataType);
+		writeRecord(source);
+
+		++numTotalRecords;
+		numTotalBytes += totalBytes;
+
+		return true;
+	}
+
+	private void writeIndex(int channelIndex, int numRecordBytes, Buffer.DataType dataType) {
+		MemorySegment segment = buffers.get(writeSegmentIndex);
+
+		// record length takes the high 32 bits and data type takes the low 32 bits
+		segment.putLong(writeSegmentOffset, ((long) numRecordBytes << 32) | dataType.ordinal());
+
+		// segment index takes the high 32 bits and segment offset takes the low 32 bits
+		long indexEntryAddress = ((long) writeSegmentIndex << 32) | writeSegmentOffset;
+
+		long lastIndexEntryAddress = lastIndexEntryAddresses[channelIndex];
+		lastIndexEntryAddresses[channelIndex] = indexEntryAddress;
+
+		if (lastIndexEntryAddress >= 0) {
+			// link the previous index entry of the given channel to the new index entry
+			segment = buffers.get(getHigh32BitsFromLongAsInteger(lastIndexEntryAddress));
+			segment.putLong(getLow32BitsFromLongAsInteger(lastIndexEntryAddress) + 8, indexEntryAddress);
+		} else {
+			firstIndexEntryAddresses[channelIndex] = indexEntryAddress;
+		}
+
+		// move the write position forward so as to write the corresponding record
+		updateWriteSegmentIndexAndOffset(INDEX_ENTRY_SIZE);
+	}
+
+	private void writeRecord(ByteBuffer source) {
+		while (source.hasRemaining()) {
+			MemorySegment segment = buffers.get(writeSegmentIndex);
+			int toCopy = Math.min(bufferSize - writeSegmentOffset, source.remaining());
+			segment.put(writeSegmentOffset, source, toCopy);
+
+			// move the write position forward so as to write the remaining bytes or next record
+			updateWriteSegmentIndexAndOffset(toCopy);
+		}
+	}
+
+	private boolean allocateBuffersForRecord(int numRecordBytes) throws IOException {
+		int numBytesRequired = INDEX_ENTRY_SIZE + numRecordBytes;
+		int availableBytes = writeSegmentIndex == buffers.size() ? 0 : bufferSize - writeSegmentOffset;
+
+		// return directly if the currently available bytes are adequate
+		if (availableBytes >= numBytesRequired) {
+			return true;
+		}
+
+		// skip the remaining free space if the available bytes are not enough for an index entry
+		if (availableBytes < INDEX_ENTRY_SIZE) {
+			updateWriteSegmentIndexAndOffset(availableBytes);
+			availableBytes = 0;
+		}
+
+		// allocate exactly enough buffers for the appended record
+		do {
+			MemorySegment segment = requestBufferFromPool();
+			if (segment == null) {
+				// return false if we cannot allocate enough buffers for the appended record
+				return false;
+			}
+
+			assert segment.size() == bufferSize;
+			availableBytes += bufferSize;
+			buffers.add(segment);
+		} while (availableBytes < numBytesRequired);
+
+		return true;
+	}
+
+	private MemorySegment requestBufferFromPool() throws IOException {
+		try {
+			// request buffers in a blocking way as long as the guaranteed memory is not used up
+			if (buffers.size() < bufferPool.getNumberOfRequiredMemorySegments()) {
+				return bufferPool.requestBufferBuilderBlocking().getMemorySegment();
+			}
+		} catch (InterruptedException e) {
+			throw new IOException("Interrupted while requesting buffer.");
+		}
+
+		BufferBuilder buffer = bufferPool.requestBufferBuilder();
+		return buffer != null ? buffer.getMemorySegment() : null;
+	}
+
+	private void updateWriteSegmentIndexAndOffset(int numBytes) {
+		writeSegmentOffset += numBytes;
+
+		// use the next available free buffer if the current one is full
+		if (writeSegmentOffset == bufferSize) {
+			++writeSegmentIndex;
+			writeSegmentOffset = 0;
+		}
+	}
+
+	@Override
+	public BufferWithChannel copyData(MemorySegment target) {
+		checkState(hasRemaining(), "No data remaining.");
+		checkState(isFinished, "Should finish the sort buffer first before copying any data.");
+		checkState(!isReleased, "Sort buffer is already released.");
+
+		int numBytesCopied = 0;
+		DataType bufferDataType = DataType.DATA_BUFFER;
+		int channelIndex = readChannelIndex;
+
+		do {
+			int sourceSegmentIndex = getHigh32BitsFromLongAsInteger(readIndexEntryAddress);
+			int sourceSegmentOffset = getLow32BitsFromLongAsInteger(readIndexEntryAddress);
+			MemorySegment sourceSegment = buffers.get(sourceSegmentIndex);
+
+			long lengthAndDataType = sourceSegment.getLong(sourceSegmentOffset);
+			int length = getHigh32BitsFromLongAsInteger(lengthAndDataType);
+			DataType dataType = DataType.values()[getLow32BitsFromLongAsInteger(lengthAndDataType)];
+
+			// return the data read so far if the next entry to read is an event
+			if (dataType.isEvent() && numBytesCopied > 0) {
+				break;
+			}
+			bufferDataType = dataType;
+
+			// get the next index entry address and move the read position forward
+			long nextReadIndexEntryAddress = sourceSegment.getLong(sourceSegmentOffset + 8);
+			sourceSegmentOffset += INDEX_ENTRY_SIZE;
+
+			// allocate a temp buffer for the event if the target buffer is not big enough
+			if (bufferDataType.isEvent() && target.size() < length) {
+				target = MemorySegmentFactory.allocateUnpooledSegment(length);
+			}
+
+			numBytesCopied += copyRecordOrEvent(
+				target, numBytesCopied, sourceSegmentIndex, sourceSegmentOffset, length);
+
+			if (recordRemainingBytes == 0) {
+				// move to the next channel if the current channel is finished
+				if (readIndexEntryAddress == lastIndexEntryAddresses[channelIndex]) {
+					updateReadChannelAndIndexEntryAddress();
+					break;
+				}
+				readIndexEntryAddress = nextReadIndexEntryAddress;
+			}
+		} while (numBytesCopied < target.size() && bufferDataType.isBuffer());
+
+		numTotalBytesRead += numBytesCopied;
+		Buffer buffer = new NetworkBuffer(target, (buf) -> {}, bufferDataType, numBytesCopied);
+		return new BufferWithChannel(buffer, channelIndex);
+	}
+
+	private int copyRecordOrEvent(
+			MemorySegment targetSegment,
+			int targetSegmentOffset,
+			int sourceSegmentIndex,
+			int sourceSegmentOffset,
+			int recordLength) {
+		if (recordRemainingBytes > 0) {
+			// skip the data already read if a partial record remains after the previous copy
+			long position = (long) sourceSegmentOffset + (recordLength - recordRemainingBytes);
+			sourceSegmentIndex += (position / bufferSize);
+			sourceSegmentOffset = (int) (position % bufferSize);
+		} else {
+			recordRemainingBytes = recordLength;
+		}
+
+		int targetSegmentSize = targetSegment.size();
+		int numBytesToCopy = Math.min(targetSegmentSize - targetSegmentOffset, recordRemainingBytes);
+		do {
+			// move to the next data buffer if all data of the current one has been copied
+			if (sourceSegmentOffset == bufferSize) {
+				++sourceSegmentIndex;
+				sourceSegmentOffset = 0;
+			}
+
+			int sourceRemainingBytes = Math.min(bufferSize - sourceSegmentOffset, recordRemainingBytes);
+			int numBytes = Math.min(targetSegmentSize - targetSegmentOffset, sourceRemainingBytes);
+			MemorySegment sourceSegment = buffers.get(sourceSegmentIndex);
+			sourceSegment.copyTo(sourceSegmentOffset, targetSegment, targetSegmentOffset, numBytes);
+
+			recordRemainingBytes -= numBytes;
+			targetSegmentOffset += numBytes;
+			sourceSegmentOffset += numBytes;
+		} while (recordRemainingBytes > 0 && targetSegmentOffset < targetSegmentSize);
+
+		return numBytesToCopy;
+	}
+
+	private void updateReadChannelAndIndexEntryAddress() {
+		// skip the channels without any data
+		while (++readChannelIndex < firstIndexEntryAddresses.length) {
+			if ((readIndexEntryAddress = firstIndexEntryAddresses[readChannelIndex]) >= 0) {
+				break;
+			}
+		}
+	}
+
+	private int getHigh32BitsFromLongAsInteger(long value) {
+		return (int) (value >>> 32);
+	}
+
+	private int getLow32BitsFromLongAsInteger(long value) {
+		return (int) (value & 0xffff);

Review comment:
       This only gets the low 16 bits, not the low 32 bits.
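
       For illustration only (not part of the PR; the class and method names below are made up), a minimal stand-alone sketch of the difference, assuming the index entry address packing used in writeIndex(), i.e. segment index in the high 32 bits and segment offset in the low 32 bits:

           // hypothetical demo, not part of the PR
           public class Low32BitsMaskDemo {

               // current version: the 0xffff mask only keeps the low 16 bits
               static int low16(long value) {
                   return (int) (value & 0xffff);
               }

               // possible fix: keep the full low 32 bits
               static int low32(long value) {
                   return (int) (value & 0xffffffffL);
               }

               public static void main(String[] args) {
                   int segmentIndex = 3;
                   int segmentOffset = 70_000; // possible once the buffer size exceeds 64 KiB
                   long address = ((long) segmentIndex << 32) | segmentOffset;

                   System.out.println(low16(address)); // prints 4464  -> wrong offset
                   System.out.println(low32(address)); // prints 70000 -> expected offset
               }
           }

       With the 0xffff mask, any value whose low 32 bits exceed 65535 (for example, a segment offset in a buffer larger than 64 KiB) is silently truncated.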




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org