You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by xt...@apache.org on 2022/07/26 01:50:17 UTC
[flink] 04/04: [FLINK-27904][runtime] Introduce HsMemoryDataManager to manage in-memory data of hybrid shuffle mode
This is an automated email from the ASF dual-hosted git repository.
xtsong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 0d466603a981acfee8d7136220461f63849fbe0f
Author: Weijie Guo <re...@163.com>
AuthorDate: Mon Jul 18 16:52:00 2022 +0800
[FLINK-27904][runtime] Introduce HsMemoryDataManager to manage in-memory data of hybrid shuffle mode
This closes #20293
---
.../network/partition/hybrid/HsBufferContext.java | 128 ++++++
.../partition/hybrid/HsMemoryDataManager.java | 286 +++++++++++++
.../hybrid/HsMemoryDataManagerOperation.java | 52 +++
.../partition/hybrid/HsMemoryDataSpiller.java | 1 -
.../hybrid/HsSubpartitionMemoryDataManager.java | 471 +++++++++++++++++++++
.../partition/hybrid/HsBufferContextTest.java | 131 ++++++
.../hybrid/HsFullSpillingStrategyTest.java | 2 +-
.../partition/hybrid/HsMemoryDataManagerTest.java | 214 ++++++++++
.../hybrid/HsSelectiveSpillingStrategyTest.java | 2 +-
.../hybrid/HsSpillingStrategyUtilsTest.java | 4 +-
.../HsSubpartitionMemoryDataManagerTest.java | 427 +++++++++++++++++++
...yTestUtils.java => HybridShuffleTestUtils.java} | 21 +-
.../partition/hybrid/TestingFileDataIndex.java | 96 +++++
.../hybrid/TestingMemoryDataManagerOperation.java | 119 ++++++
.../partition/hybrid/TestingSpillingStrategy.java | 119 ++++++
15 files changed, 2066 insertions(+), 7 deletions(-)
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContext.java
new file mode 100644
index 00000000000..8feb6fd0c41
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContext.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid;
+
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+
+import javax.annotation.Nullable;
+
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * This class maintains the buffer's reference count and its status for hybrid shuffle mode.
+ *
+ * <p>Each buffer has three statuses: {@link #released}, {@link #spillStarted}, {@link #consumed}.
+ *
+ * <ul>
+ *   <li>{@link #released} indicates that the buffer has been released from the memory data
+ *       manager, and can no longer be spilled or consumed.
+ *   <li>{@link #spillStarted} indicates that spilling of the buffer has started, whether or not it
+ *       has completed yet.
+ *   <li>{@link #consumed} indicates that the buffer has been consumed by the downstream.
+ * </ul>
+ *
+ * <p>The reference count of the buffer is maintained as follows:
+ *
+ * <ul>
+ *   <li>+1 when the buffer is obtained by the memory data manager (from the buffer pool), and -1
+ *       when it is released from the memory data manager.
+ *   <li>+1 when spilling of the buffer is started, and -1 when it is completed.
+ *   <li>+1 when the buffer is being consumed, and -1 when consuming is completed (by the
+ *       downstream).
+ * </ul>
+ *
+ * <p>Note: This class is not thread-safe.
+ */
+public class HsBufferContext {
+    private final Buffer buffer;
+
+    private final BufferIndexAndChannel bufferIndexAndChannel;
+
+    // --------------------------
+    // Buffer Status
+    // --------------------------
+    private boolean released;
+
+    private boolean spillStarted;
+
+    private boolean consumed;
+
+    /** Completed when spilling of this buffer finishes; {@code null} until spilling starts. */
+    @Nullable private CompletableFuture<Void> spilledFuture;
+
+    /**
+     * Creates a context for a buffer that is not yet released, spilled or consumed.
+     *
+     * @param buffer the buffer whose status and reference count are tracked.
+     * @param bufferIndex index of the buffer within its subpartition.
+     * @param subpartitionId the subpartition the buffer belongs to.
+     */
+    public HsBufferContext(Buffer buffer, int bufferIndex, int subpartitionId) {
+        this.bufferIndexAndChannel = new BufferIndexAndChannel(bufferIndex, subpartitionId);
+        this.buffer = buffer;
+    }
+
+    public Buffer getBuffer() {
+        return buffer;
+    }
+
+    public BufferIndexAndChannel getBufferIndexAndChannel() {
+        return bufferIndexAndChannel;
+    }
+
+    public boolean isReleased() {
+        return released;
+    }
+
+    public boolean isSpillStarted() {
+        return spillStarted;
+    }
+
+    public boolean isConsumed() {
+        return consumed;
+    }
+
+    /**
+     * Returns the future passed to {@link #startSpilling}, or {@link Optional#empty()} if spilling
+     * has not started.
+     */
+    public Optional<CompletableFuture<Void>> getSpilledFuture() {
+        return Optional.ofNullable(spilledFuture);
+    }
+
+    /**
+     * Marks the buffer as released from memory and drops the memory data manager's reference.
+     * Fails the state check if the buffer has already been released.
+     */
+    public void release() {
+        checkState(!released, "Release buffer repeatedly is unexpected.");
+        released = true;
+        // decrease ref count when buffer is released from memory.
+        buffer.recycleBuffer();
+    }
+
+    /**
+     * Marks the buffer as being spilled. It takes an extra reference that is dropped once {@code
+     * spilledFuture} completes normally.
+     *
+     * @param spilledFuture future completed when spilling of this buffer has finished.
+     */
+    public void startSpilling(CompletableFuture<Void> spilledFuture) {
+        checkState(!released, "Buffer is already released.");
+        checkState(
+                !spillStarted && this.spilledFuture == null,
+                "Spill buffer repeatedly is unexpected.");
+        spillStarted = true;
+        this.spilledFuture = spilledFuture;
+        // increase ref count when buffer is decided to spill.
+        buffer.retainBuffer();
+        // decrease ref count when buffer spilling is finished.
+        spilledFuture.thenRun(buffer::recycleBuffer);
+    }
+
+    /**
+     * Marks the buffer as consumed and takes an extra reference on behalf of the downstream
+     * consumer.
+     */
+    public void consumed() {
+        checkState(!released, "Buffer is already released.");
+        checkState(!consumed, "Consume buffer repeatedly is unexpected.");
+        consumed = true;
+        // increase ref count when buffer is consumed, will be decreased when downstream finish
+        // consuming.
+        buffer.retainBuffer();
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManager.java
new file mode 100644
index 00000000000..0a7c60cf9b2
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManager.java
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingStrategy.Decision;
+import org.apache.flink.util.function.SupplierWithException;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
+/**
+ * This class is responsible for managing the in-memory data of a hybrid shuffle result partition.
+ * Records are appended to the {@link HsSubpartitionMemoryDataManager} of their target
+ * subpartition, buffers are requested from the {@link BufferPool}, and decisions made by the
+ * {@link HsSpillingStrategy} (which buffers to spill or release) are carried out here.
+ */
+public class HsMemoryDataManager implements HsSpillingInfoProvider, HsMemoryDataManagerOperation {
+
+    private final int numSubpartitions;
+
+    private final HsSubpartitionMemoryDataManager[] subpartitionMemoryDataManagers;
+
+    private final HsMemoryDataSpiller spiller;
+
+    private final HsSpillingStrategy spillStrategy;
+
+    private final HsFileDataIndex fileDataIndex;
+
+    private final BufferPool bufferPool;
+
+    /**
+     * Write lock of a read-write lock whose read lock is shared with every subpartition memory
+     * data manager. Holding it excludes all subpartition operations guarded by the read lock.
+     */
+    private final Lock lock;
+
+    /** Number of buffers requested from {@link #bufferPool} and not yet recycled. */
+    private final AtomicInteger numRequestedBuffers = new AtomicInteger(0);
+
+    /** Number of finished buffers that have not yet been selected for spilling. */
+    private final AtomicInteger numUnSpillBuffers = new AtomicInteger(0);
+
+    /**
+     * @param numSubpartitions number of subpartitions managed by this instance.
+     * @param bufferSize size in bytes of each buffer requested from the pool.
+     * @param bufferPool pool that network buffers are requested from and recycled to.
+     * @param spillStrategy strategy deciding which buffers to spill or release.
+     * @param fileDataIndex index updated when buffers are spilled or become readable from file.
+     * @param dataFileChannel channel the spilled data is written to.
+     */
+    public HsMemoryDataManager(
+            int numSubpartitions,
+            int bufferSize,
+            BufferPool bufferPool,
+            HsSpillingStrategy spillStrategy,
+            HsFileDataIndex fileDataIndex,
+            FileChannel dataFileChannel) {
+        this.numSubpartitions = numSubpartitions;
+        this.bufferPool = bufferPool;
+        this.spiller = new HsMemoryDataSpiller(dataFileChannel);
+        this.spillStrategy = spillStrategy;
+        this.fileDataIndex = fileDataIndex;
+        this.subpartitionMemoryDataManagers = new HsSubpartitionMemoryDataManager[numSubpartitions];
+
+        ReadWriteLock readWriteLock = new ReentrantReadWriteLock(true);
+        this.lock = readWriteLock.writeLock();
+
+        for (int subpartitionId = 0; subpartitionId < numSubpartitions; ++subpartitionId) {
+            subpartitionMemoryDataManagers[subpartitionId] =
+                    new HsSubpartitionMemoryDataManager(
+                            subpartitionId, bufferSize, readWriteLock.readLock(), this);
+        }
+    }
+
+    // ------------------------------------
+    //  For ResultPartition
+    // ------------------------------------
+
+    /**
+     * Append record to {@link HsMemoryDataManager}. It will be managed by the {@link
+     * HsSubpartitionMemoryDataManager} which it belongs to.
+     *
+     * @param record to be managed by this class.
+     * @param targetChannel target subpartition of this record.
+     * @param dataType the type of this record. In other words, is it data or event.
+     * @throws IOException if the thread is interrupted while waiting for a buffer; the interrupt
+     *     flag is restored before the exception is thrown.
+     */
+    public void append(ByteBuffer record, int targetChannel, Buffer.DataType dataType)
+            throws IOException {
+        try {
+            getSubpartitionMemoryDataManager(targetChannel).append(record, dataType);
+        } catch (InterruptedException e) {
+            // Catching InterruptedException clears the interrupt status; restore it so callers
+            // further up the stack can still observe the interruption.
+            Thread.currentThread().interrupt();
+            throw new IOException(e);
+        }
+    }
+
+    // ------------------------------------
+    //  For Spilling Strategy
+    // ------------------------------------
+
+    @Override
+    public int getPoolSize() {
+        return bufferPool.getNumBuffers();
+    }
+
+    @Override
+    public int getNumSubpartitions() {
+        return numSubpartitions;
+    }
+
+    @Override
+    public int getNumTotalRequestedBuffers() {
+        return numRequestedBuffers.get();
+    }
+
+    @Override
+    public int getNumTotalUnSpillBuffers() {
+        return numUnSpillBuffers.get();
+    }
+
+    // The write lock should be acquired before invoking this method.
+    @Override
+    public Deque<BufferIndexAndChannel> getBuffersInOrder(
+            int subpartitionId, SpillStatus spillStatus, ConsumeStatus consumeStatus) {
+        HsSubpartitionMemoryDataManager targetSubpartitionDataManager =
+                getSubpartitionMemoryDataManager(subpartitionId);
+        return targetSubpartitionDataManager.getBuffersSatisfyStatus(spillStatus, consumeStatus);
+    }
+
+    // The write lock should be acquired before invoking this method.
+    @Override
+    public List<Integer> getNextBufferIndexToConsume() {
+        // TODO implement this logic when subpartition view is implemented.
+        return Collections.emptyList();
+    }
+
+    // ------------------------------------
+    //  Callback for subpartition
+    // ------------------------------------
+
+    @Override
+    public void markBufferReadableFromFile(int subpartitionId, int bufferIndex) {
+        fileDataIndex.markBufferReadable(subpartitionId, bufferIndex);
+    }
+
+    @Override
+    public BufferBuilder requestBufferFromPool() throws InterruptedException {
+        MemorySegment segment = bufferPool.requestMemorySegmentBlocking();
+        Optional<Decision> decisionOpt =
+                spillStrategy.onMemoryUsageChanged(
+                        numRequestedBuffers.incrementAndGet(), getPoolSize());
+
+        handleDecision(decisionOpt);
+        return new BufferBuilder(segment, this::recycleBuffer);
+    }
+
+    @Override
+    public void onBufferConsumed(BufferIndexAndChannel consumedBuffer) {
+        Optional<Decision> decision = spillStrategy.onBufferConsumed(consumedBuffer);
+        handleDecision(decision);
+    }
+
+    @Override
+    public void onBufferFinished() {
+        Optional<Decision> decision =
+                spillStrategy.onBufferFinished(numUnSpillBuffers.incrementAndGet());
+        handleDecision(decision);
+    }
+
+    // ------------------------------------
+    //  Internal Method
+    // ------------------------------------
+
+    // Attention: Do not call this method within the read lock and subpartition lock, otherwise
+    // deadlock may occur as this method may acquire the write lock and other subpartitions' locks
+    // inside.
+    private void handleDecision(
+            @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
+                    Optional<Decision> decisionOpt) {
+        // Without a decision from the strategy, ask it for one based on global info, under the
+        // write lock so that the global view is consistent.
+        Decision decision =
+                decisionOpt.orElseGet(
+                        () -> callWithLock(() -> spillStrategy.decideActionWithGlobalInfo(this)));
+
+        if (!decision.getBufferToSpill().isEmpty()) {
+            spillBuffers(decision.getBufferToSpill());
+        }
+        if (!decision.getBufferToRelease().isEmpty()) {
+            releaseBuffers(decision.getBufferToRelease());
+        }
+    }
+
+    /**
+     * Spill buffers for each subpartition in a decision.
+     *
+     * <p>Note that: The method should not be locked, it is the responsibility of each subpartition
+     * to maintain thread safety itself.
+     *
+     * @param toSpill All buffers that need to be spilled in a decision.
+     */
+    private void spillBuffers(Map<Integer, List<BufferIndexAndChannel>> toSpill) {
+        CompletableFuture<Void> spillingCompleteFuture = new CompletableFuture<>();
+        List<BufferWithIdentity> bufferWithIdentities = new ArrayList<>();
+        toSpill.forEach(
+                (subpartitionId, bufferIndexAndChannels) -> {
+                    HsSubpartitionMemoryDataManager subpartitionDataManager =
+                            getSubpartitionMemoryDataManager(subpartitionId);
+                    bufferWithIdentities.addAll(
+                            subpartitionDataManager.spillSubpartitionBuffers(
+                                    bufferIndexAndChannels, spillingCompleteFuture));
+                    // decrease numUnSpillBuffers as this subpartition's buffers are spilled.
+                    numUnSpillBuffers.getAndAdd(-bufferIndexAndChannels.size());
+                });
+
+        // Once the data is on disk, register it in the file data index and then complete the
+        // shared future, which lets every spilled buffer context drop its spilling reference.
+        spiller.spillAsync(bufferWithIdentities)
+                .thenAccept(
+                        spilledBuffers -> {
+                            fileDataIndex.addBuffers(spilledBuffers);
+                            spillingCompleteFuture.complete(null);
+                        });
+    }
+
+    /**
+     * Release buffers for each subpartition in a decision.
+     *
+     * <p>Note that: The method should not be locked, it is the responsibility of each subpartition
+     * to maintain thread safety itself.
+     *
+     * @param toRelease All buffers that need to be released in a decision.
+     */
+    private void releaseBuffers(Map<Integer, List<BufferIndexAndChannel>> toRelease) {
+        toRelease.forEach(
+                (subpartitionId, subpartitionBuffers) ->
+                        getSubpartitionMemoryDataManager(subpartitionId)
+                                .releaseSubpartitionBuffers(subpartitionBuffers));
+    }
+
+    private HsSubpartitionMemoryDataManager getSubpartitionMemoryDataManager(int targetChannel) {
+        return subpartitionMemoryDataManagers[targetChannel];
+    }
+
+    /** Recycler of requested buffers: updates the counter and returns the segment to the pool. */
+    private void recycleBuffer(MemorySegment buffer) {
+        numRequestedBuffers.decrementAndGet();
+        bufferPool.recycle(buffer);
+    }
+
+    /**
+     * Runs {@code callable} while holding the write lock and returns its result.
+     *
+     * <p>The lock is acquired before entering the {@code try} block so that a failed acquisition
+     * never leads to unlocking a lock that is not held.
+     */
+    public <T, R extends Exception> T callWithLock(SupplierWithException<T, R> callable) throws R {
+        lock.lock();
+        try {
+            return callable.get();
+        } finally {
+            lock.unlock();
+        }
+    }
+
+    /** Integrate the buffer and dataType of next buffer. */
+    public static class BufferAndNextDataType {
+        private final Buffer buffer;
+
+        private final Buffer.DataType nextDataType;
+
+        public BufferAndNextDataType(Buffer buffer, Buffer.DataType nextDataType) {
+            this.buffer = buffer;
+            this.nextDataType = nextDataType;
+        }
+
+        public Buffer getBuffer() {
+            return buffer;
+        }
+
+        public Buffer.DataType getNextDataType() {
+            return nextDataType;
+        }
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManagerOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManagerOperation.java
new file mode 100644
index 00000000000..d34b251a700
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManagerOperation.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid;
+
+import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
+
+/**
+ * Callback interface through which a {@link HsSubpartitionMemoryDataManager} operates on its
+ * owning {@link HsMemoryDataManager}. Spilling decisions may be made and handled inside these
+ * operations.
+ */
+public interface HsMemoryDataManagerOperation {
+
+    /**
+     * Requests a buffer from the buffer pool.
+     *
+     * @return a {@link BufferBuilder} over the requested buffer.
+     * @throws InterruptedException if interrupted while requesting the buffer.
+     */
+    BufferBuilder requestBufferFromPool() throws InterruptedException;
+
+    /** Called whenever a buffer has been finished. */
+    void onBufferFinished();
+
+    /**
+     * Called whenever a buffer has been consumed.
+     *
+     * @param consumedBuffer the buffer that has been consumed.
+     */
+    void onBufferConsumed(BufferIndexAndChannel consumedBuffer);
+
+    /**
+     * Called when a buffer should be marked as readable in the {@link HsFileDataIndex}.
+     *
+     * @param subpartitionId the subpartition the target buffer belongs to.
+     * @param bufferIndex index of the buffer to mark as readable.
+     */
+    void markBufferReadableFromFile(int subpartitionId, int bufferIndex);
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataSpiller.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataSpiller.java
index dd225ba6b27..be376b5e241 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataSpiller.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataSpiller.java
@@ -89,7 +89,6 @@ public class HsMemoryDataSpiller implements AutoCloseable {
// complete spill future when buffers are written to disk successfully.
// note that the ownership of these buffers is transferred to the MemoryDataManager,
// which controls data's life cycle.
- // TODO update file data index and handle buffers release in future ticket.
spilledFuture.complete(spilledBuffers);
} catch (IOException exception) {
// if spilling is failed, throw exception directly to uncaughtExceptionHandler.
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManager.java
new file mode 100644
index 00000000000..56814024911
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManager.java
@@ -0,0 +1,471 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.Buffer.DataType;
+import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
+import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatus;
+import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.SpillStatus;
+import org.apache.flink.util.function.SupplierWithException;
+import org.apache.flink.util.function.ThrowingRunnable;
+
+import javax.annotation.concurrent.GuardedBy;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.locks.Lock;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * This class is responsible for managing the data in a single subpartition. One {@link
+ * HsMemoryDataManager} will hold multiple {@link HsSubpartitionMemoryDataManager}.
+ */
+public class HsSubpartitionMemoryDataManager {
+ // Id of the subpartition managed by this instance; stamped into every HsBufferContext and
+ // BufferWithIdentity created here.
+ private final int targetChannel;
+
+ // Size in bytes of each buffer requested from the pool; used to compute how many buffers a
+ // record needs.
+ private final int bufferSize;
+
+ // Callback into the owning memory data manager (buffer requests and buffer lifecycle events).
+ private final HsMemoryDataManagerOperation memoryDataManagerOperation;
+
+ // Buffers currently being written; the head is the one the next bytes are appended to.
+ // Not guarded by lock because it is expected to be accessed only from the task's main thread.
+ private final Queue<BufferBuilder> unfinishedBuffers = new LinkedList<>();
+
+ // Index assigned to the next finished buffer; incremented each time a buffer is finished.
+ // Not guarded by lock because it is expected to be accessed only from the task's main thread.
+ private int finishedBufferIndex;
+
+ // Finished buffers of this subpartition, appended in the order they are finished.
+ @GuardedBy("subpartitionLock")
+ private final Deque<HsBufferContext> allBuffers = new LinkedList<>();
+
+ // Finished buffers not yet consumed; consumption polls from the head, and released buffers at
+ // the head are trimmed lazily.
+ @GuardedBy("subpartitionLock")
+ private final Deque<HsBufferContext> unConsumedBuffers = new LinkedList<>();
+
+ // Lookup from buffer index to its context, used when buffers are released by index.
+ @GuardedBy("subpartitionLock")
+ private final Map<Integer, HsBufferContext> bufferIndexToContexts = new HashMap<>();
+
+ // Lock shared with the owning result partition; HsMemoryDataManager passes the read lock of
+ // its read-write lock here.
+ /** DO NOT USE DIRECTLY. Use {@link #runWithLock} or {@link #callWithLock} instead. */
+ private final Lock resultPartitionLock;
+
+ // Lock protecting this subpartition's own buffer collections.
+ /** DO NOT USE DIRECTLY. Use {@link #runWithLock} or {@link #callWithLock} instead. */
+ private final Object subpartitionLock = new Object();
+
+    /**
+     * @param targetChannel id of the subpartition this instance manages.
+     * @param bufferSize size in bytes of each buffer requested from the pool.
+     * @param resultPartitionLock lock shared with the owning result partition.
+     * @param memoryDataManagerOperation callback into the owning memory data manager.
+     */
+    HsSubpartitionMemoryDataManager(
+            int targetChannel,
+            int bufferSize,
+            Lock resultPartitionLock,
+            HsMemoryDataManagerOperation memoryDataManagerOperation) {
+        this.memoryDataManagerOperation = memoryDataManagerOperation;
+        this.resultPartitionLock = resultPartitionLock;
+        this.bufferSize = bufferSize;
+        this.targetChannel = targetChannel;
+    }
+
+ // ------------------------------------------------------------------------
+ // Called by Consumer
+ // ------------------------------------------------------------------------
+
+ /**
+ * Check whether the head of {@link #unConsumedBuffers} is the buffer to be consumed next time.
+ * If so, return the next buffer's data type.
+ *
+ * @param nextToConsumeIndex index of the buffer to be consumed next time.
+ * @return If the head of {@link #unConsumedBuffers} is target, return the buffer's data type.
+ * Otherwise, return {@link DataType#NONE}.
+ */
+ @SuppressWarnings("FieldAccessNotGuarded")
+ // Note that: callWithLock ensure that code block guarded by resultPartitionReadLock and
+ // subpartitionLock.
+ public DataType peekNextToConsumeDataType(int nextToConsumeIndex) {
+ return callWithLock(() -> peekNextToConsumeDataTypeInternal(nextToConsumeIndex));
+ }
+
+    /**
+     * Consumes the head of {@link #unConsumedBuffers} if its index equals {@code toConsumeIndex}.
+     *
+     * <p>On success the buffer is marked as consumed under the lock. The owning memory data
+     * manager is notified through {@link HsMemoryDataManagerOperation#onBufferConsumed} only after
+     * the lock has been left.
+     *
+     * @param toConsumeIndex index of the buffer to be consumed.
+     * @return the consumed buffer together with the data type of the buffer following it, or
+     *     {@link Optional#empty()} if the head of {@link #unConsumedBuffers} is not the requested
+     *     buffer.
+     */
+    @SuppressWarnings("FieldAccessNotGuarded")
+    // Note that: callWithLock ensure that code block guarded by resultPartitionReadLock and
+    // subpartitionLock.
+    public Optional<HsMemoryDataManager.BufferAndNextDataType> consumeBuffer(int toConsumeIndex) {
+        Optional<Tuple2<HsBufferContext, DataType>> consumedWithNextType =
+                callWithLock(
+                        () -> {
+                            if (checkFirstUnConsumedBufferIndex(toConsumeIndex)) {
+                                HsBufferContext head = checkNotNull(unConsumedBuffers.pollFirst());
+                                head.consumed();
+                                DataType followingType =
+                                        peekNextToConsumeDataTypeInternal(toConsumeIndex + 1);
+                                return Optional.of(Tuple2.of(head, followingType));
+                            }
+                            return Optional.empty();
+                        });
+
+        if (!consumedWithNextType.isPresent()) {
+            return Optional.empty();
+        }
+
+        // Notify outside the lock: handling the notification may need the result partition's
+        // write lock (see HsMemoryDataManager#handleDecision).
+        HsBufferContext consumedContext = consumedWithNextType.get().f0;
+        DataType nextDataType = consumedWithNextType.get().f1;
+        memoryDataManagerOperation.onBufferConsumed(consumedContext.getBufferIndexAndChannel());
+        return Optional.of(
+                new HsMemoryDataManager.BufferAndNextDataType(
+                        consumedContext.getBuffer(), nextDataType));
+    }
+
+ // ------------------------------------------------------------------------
+ // Called by MemoryDataManager
+ // ------------------------------------------------------------------------
+
+    /**
+     * Appends a record or an event to this subpartition.
+     *
+     * @param record the serialized record or event.
+     * @param dataType the type of {@code record}; events are written to an exclusive buffer.
+     * @throws InterruptedException if interrupted while requesting buffers for a record.
+     */
+    public void append(ByteBuffer record, DataType dataType) throws InterruptedException {
+        if (!dataType.isEvent()) {
+            writeRecord(record, dataType);
+            return;
+        }
+        writeEvent(record, dataType);
+    }
+
+    /**
+     * Collects, in their original order, the buffers of {@link #allBuffers} whose spill and
+     * consume state match the requested {@link SpillStatus} and {@link ConsumeStatus}.
+     *
+     * @param spillStatus the expected spilling status.
+     * @param consumeStatus the expected consuming status.
+     * @return the matching buffers, in order.
+     */
+    @SuppressWarnings("FieldAccessNotGuarded")
+    // Note that: callWithLock ensure that code block guarded by resultPartitionReadLock and
+    // subpartitionLock.
+    public Deque<BufferIndexAndChannel> getBuffersSatisfyStatus(
+            SpillStatus spillStatus, ConsumeStatus consumeStatus) {
+        return callWithLock(
+                () -> {
+                    // TODO return iterator to avoid completely traversing the queue for each call.
+                    Deque<BufferIndexAndChannel> matched = new ArrayDeque<>();
+                    for (HsBufferContext context : allBuffers) {
+                        if (isBufferSatisfyStatus(context, spillStatus, consumeStatus)) {
+                            matched.add(context.getBufferIndexAndChannel());
+                        }
+                    }
+                    return matched;
+                });
+    }
+
+    /**
+     * Starts spilling the given buffers of this subpartition, as chosen by a spilling decision.
+     *
+     * @param toSpill all buffers of this subpartition that the decision selected for spilling.
+     * @param spillDoneFuture future completed once spilling has finished.
+     * @return one {@link BufferWithIdentity} per spilled buffer, in the order of {@code toSpill}.
+     */
+    @SuppressWarnings("FieldAccessNotGuarded")
+    // Note that: callWithLock ensure that code block guarded by resultPartitionReadLock and
+    // subpartitionLock.
+    public List<BufferWithIdentity> spillSubpartitionBuffers(
+            List<BufferIndexAndChannel> toSpill, CompletableFuture<Void> spillDoneFuture) {
+        return callWithLock(
+                () ->
+                        toSpill.stream()
+                                .mapToInt(BufferIndexAndChannel::getBufferIndex)
+                                .mapToObj(
+                                        bufferIndex ->
+                                                new BufferWithIdentity(
+                                                        startSpillingBuffer(
+                                                                        bufferIndex,
+                                                                        spillDoneFuture)
+                                                                .getBuffer(),
+                                                        bufferIndex,
+                                                        targetChannel))
+                                .collect(Collectors.toList()));
+    }
+
+    /**
+     * Releases the given buffers of this subpartition, as chosen by a release decision. Each
+     * buffer is checked and marked readable before it is released.
+     *
+     * @param toRelease all buffers of this subpartition that the decision selected for release.
+     */
+    @SuppressWarnings("FieldAccessNotGuarded")
+    // Note that: runWithLock ensure that code block guarded by resultPartitionReadLock and
+    // subpartitionLock.
+    public void releaseSubpartitionBuffers(List<BufferIndexAndChannel> toRelease) {
+        runWithLock(
+                () -> {
+                    for (BufferIndexAndChannel indexAndChannel : toRelease) {
+                        int bufferIndex = indexAndChannel.getBufferIndex();
+                        HsBufferContext context =
+                                checkNotNull(bufferIndexToContexts.get(bufferIndex));
+                        checkAndMarkBufferReadable(context);
+                        releaseBuffer(bufferIndex);
+                    }
+                });
+    }
+
+ // ------------------------------------------------------------------------
+ // Internal Methods
+ // ------------------------------------------------------------------------
+
+    /** Writes an event into its own finished buffer, separate from any records. */
+    private void writeEvent(ByteBuffer event, DataType dataType) {
+        checkArgument(dataType.isEvent());
+
+        // An event never shares a buffer with records, so seal any partially written buffer first.
+        finishCurrentWritingBufferIfNotEmpty();
+
+        // Events live in ad-hoc heap segments recycled by FreeingBufferRecycler instead of
+        // occupying buffers from the network buffer pool.
+        // NOTE(review): wraps the whole backing array, ignoring position/arrayOffset; assumes the
+        // event is a heap buffer holding exactly the serialized event -- TODO confirm.
+        MemorySegment eventSegment = MemorySegmentFactory.wrap(event.array());
+        Buffer eventBuffer =
+                new NetworkBuffer(
+                        eventSegment,
+                        FreeingBufferRecycler.INSTANCE,
+                        dataType,
+                        eventSegment.size());
+
+        addFinishedBuffer(new HsBufferContext(eventBuffer, finishedBufferIndex, targetChannel));
+        memoryDataManagerOperation.onBufferFinished();
+    }
+
+    /** Writes a data record, first making sure enough buffers are available to hold it. */
+    private void writeRecord(ByteBuffer record, DataType dataType) throws InterruptedException {
+        checkArgument(!dataType.isEvent());
+        ensureCapacityForRecord(record);
+        writeRecord(record);
+    }
+
+    /**
+     * Requests buffers from the pool until the unfinished buffers can hold the remaining bytes of
+     * {@code record}. The head buffer contributes its writable bytes; every other unfinished
+     * buffer contributes a full {@code bufferSize}.
+     */
+    private void ensureCapacityForRecord(ByteBuffer record) throws InterruptedException {
+        final int numRecordBytes = record.remaining();
+
+        int availableBytes = 0;
+        BufferBuilder head = unfinishedBuffers.peek();
+        if (head != null) {
+            availableBytes =
+                    head.getWritableBytes() + bufferSize * (unfinishedBuffers.size() - 1);
+        }
+
+        for (; availableBytes < numRecordBytes; availableBytes += bufferSize) {
+            // request unfinished buffer.
+            unfinishedBuffers.add(memoryDataManagerOperation.requestBufferFromPool());
+        }
+    }
+
+    /**
+     * Copies the remaining bytes of {@code record} into the unfinished buffers, finishing each
+     * buffer as soon as it becomes full.
+     */
+    private void writeRecord(ByteBuffer record) {
+        while (record.hasRemaining()) {
+            BufferBuilder target =
+                    checkNotNull(unfinishedBuffers.peek(), "Expect enough capacity for the record.");
+            target.append(record);
+            boolean full = target.isFull();
+            if (full) {
+                finishCurrentWritingBuffer();
+            }
+        }
+    }
+
+    /** Finishes the head unfinished buffer, but only if at least one byte was written to it. */
+    private void finishCurrentWritingBufferIfNotEmpty() {
+        BufferBuilder head = unfinishedBuffers.peek();
+        boolean hasWrittenBytes = head != null && head.getWritableBytes() != bufferSize;
+        if (hasWrittenBytes) {
+            finishCurrentWritingBuffer();
+        }
+    }
+
+    /**
+     * Removes the head of {@link #unfinishedBuffers} (if any), turns its written content into a
+     * {@link Buffer}, and registers it as a finished buffer with the next buffer index.
+     */
+    private void finishCurrentWritingBuffer() {
+        BufferBuilder builder = unfinishedBuffers.poll();
+        if (builder != null) {
+            builder.finish();
+            BufferConsumer consumer = builder.createBufferConsumerFromBeginning();
+            Buffer finishedBuffer = consumer.build();
+            // Close builder and consumer only after the buffer has been built from them.
+            builder.close();
+            consumer.close();
+
+            addFinishedBuffer(
+                    new HsBufferContext(finishedBuffer, finishedBufferIndex, targetChannel));
+            memoryDataManagerOperation.onBufferFinished();
+        }
+    }
+
+    /**
+     * Registers a newly finished buffer in {@code allBuffers}, {@code unConsumedBuffers} and
+     * {@code bufferIndexToContexts}, and advances {@code finishedBufferIndex}.
+     *
+     * <p>The collections are mutated inside {@code callWithLock}, which holds both {@code
+     * resultPartitionLock} and the {@code subpartitionLock} monitor; this is why the
+     * guarded-field-access inspection is suppressed.
+     */
+    @SuppressWarnings("FieldAccessNotGuarded")
+    private void addFinishedBuffer(HsBufferContext bufferContext) {
+        finishedBufferIndex++;
+        boolean needNotify =
+                callWithLock(
+                        () -> {
+                            allBuffers.add(bufferContext);
+                            unConsumedBuffers.add(bufferContext);
+                            bufferIndexToContexts.put(
+                                    bufferContext.getBufferIndexAndChannel().getBufferIndex(),
+                                    bufferContext);
+                            trimHeadingReleasedBuffers(unConsumedBuffers);
+                            // The freshly added buffer is not released, so the queue can never be
+                            // empty here. A notification is needed when this buffer is the only
+                            // un-consumed one, i.e. the queue was empty before the add.
+                            return unConsumedBuffers.size() <= 1;
+                        });
+        if (needNotify) {
+            // TODO notify data available, the notification mechanism may need further
+            // consideration.
+        }
+    }
+
+    /**
+     * Returns the data type of the head un-consumed buffer when its index equals {@code
+     * nextToConsumeIndex}; otherwise returns {@link DataType#NONE}.
+     */
+    @GuardedBy("subpartitionLock")
+    private DataType peekNextToConsumeDataTypeInternal(int nextToConsumeIndex) {
+        if (!checkFirstUnConsumedBufferIndex(nextToConsumeIndex)) {
+            return DataType.NONE;
+        }
+        HsBufferContext head = checkNotNull(unConsumedBuffers.peekFirst());
+        return head.getBuffer().getDataType();
+    }
+
+    /**
+     * Drops released buffers from the head of {@code unConsumedBuffers}, then reports whether the
+     * remaining head buffer has index {@code expectedBufferIndex}.
+     */
+    @GuardedBy("subpartitionLock")
+    private boolean checkFirstUnConsumedBufferIndex(int expectedBufferIndex) {
+        trimHeadingReleasedBuffers(unConsumedBuffers);
+        HsBufferContext head = unConsumedBuffers.peekFirst();
+        if (head == null) {
+            return false;
+        }
+        return head.getBufferIndexAndChannel().getBufferIndex() == expectedBufferIndex;
+    }
+
+    /**
+     * Removes released buffers from the head of {@code bufferQueue}, stopping as soon as the queue
+     * is empty or its head buffer is not released.
+     */
+    @GuardedBy("subpartitionLock")
+    private void trimHeadingReleasedBuffers(Deque<HsBufferContext> bufferQueue) {
+        for (HsBufferContext head = bufferQueue.peekFirst();
+                head != null && head.isReleased();
+                head = bufferQueue.peekFirst()) {
+            bufferQueue.removeFirst();
+        }
+    }
+
+    /**
+     * Stops tracking the buffer with the given index and releases it; released buffers at the head
+     * of {@code allBuffers} are then dropped lazily.
+     *
+     * @throws NullPointerException if no buffer with {@code bufferIndex} is tracked.
+     */
+    @GuardedBy("subpartitionLock")
+    private void releaseBuffer(int bufferIndex) {
+        HsBufferContext removed = bufferIndexToContexts.remove(bufferIndex);
+        checkNotNull(removed).release();
+        trimHeadingReleasedBuffers(allBuffers);
+    }
+
+    /**
+     * Marks the tracked buffer with the given index as spilling and returns its context.
+     *
+     * @throws NullPointerException if no buffer with {@code bufferIndex} is tracked.
+     */
+    @GuardedBy("subpartitionLock")
+    private HsBufferContext startSpillingBuffer(
+            int bufferIndex, CompletableFuture<Void> spillFuture) {
+        HsBufferContext target = checkNotNull(bufferIndexToContexts.get(bufferIndex));
+        target.startSpilling(spillFuture);
+        return target;
+    }
+
+    /**
+     * For a buffer that has started spilling but is not yet consumed, schedules it to be marked
+     * readable from file once its spill future completes. Any other buffer is ignored.
+     *
+     * @throws IllegalStateException if a spilling buffer has no spilled future.
+     */
+    @GuardedBy("subpartitionLock")
+    private void checkAndMarkBufferReadable(HsBufferContext bufferContext) {
+        if (!isBufferSatisfyStatus(bufferContext, SpillStatus.SPILL, ConsumeStatus.NOT_CONSUMED)) {
+            return;
+        }
+        bufferContext
+                .getSpilledFuture()
+                .orElseThrow(
+                        () ->
+                                new IllegalStateException(
+                                        "Buffer in spill status should already set spilled future."))
+                .thenRun(
+                        () ->
+                                memoryDataManagerOperation.markBufferReadableFromFile(
+                                        bufferContext.getBufferIndexAndChannel().getChannel(),
+                                        bufferContext
+                                                .getBufferIndexAndChannel()
+                                                .getBufferIndex()));
+    }
+
+    /**
+     * Checks whether an un-released buffer matches both the requested spill status and consume
+     * status. Released buffers never match; a status other than the two explicit ones (e.g.
+     * {@code ALL}) places no constraint on that dimension.
+     */
+    @GuardedBy("subpartitionLock")
+    private boolean isBufferSatisfyStatus(
+            HsBufferContext bufferContext, SpillStatus spillStatus, ConsumeStatus consumeStatus) {
+        if (bufferContext.isReleased()) {
+            return false;
+        }
+
+        boolean spillMatches = true;
+        switch (spillStatus) {
+            case NOT_SPILL:
+                spillMatches = !bufferContext.isSpillStarted();
+                break;
+            case SPILL:
+                spillMatches = bufferContext.isSpillStarted();
+                break;
+            default:
+                break;
+        }
+
+        boolean consumeMatches = true;
+        switch (consumeStatus) {
+            case NOT_CONSUMED:
+                consumeMatches = !bufferContext.isConsumed();
+                break;
+            case CONSUMED:
+                consumeMatches = bufferContext.isConsumed();
+                break;
+            default:
+                break;
+        }
+
+        return spillMatches && consumeMatches;
+    }
+
+    /**
+     * Runs {@code runnable} while holding {@code resultPartitionLock} and, nested inside it, the
+     * monitor of {@code subpartitionLock}.
+     *
+     * @throws E whatever {@code runnable} throws.
+     */
+    private <E extends Exception> void runWithLock(ThrowingRunnable<E> runnable) throws E {
+        // Acquire before entering try: if lock() itself fails, the finally block must not try to
+        // unlock a lock that was never acquired. That unlock would typically throw
+        // IllegalMonitorStateException and mask the original failure.
+        resultPartitionLock.lock();
+        try {
+            synchronized (subpartitionLock) {
+                runnable.run();
+            }
+        } finally {
+            resultPartitionLock.unlock();
+        }
+    }
+
+    /**
+     * Evaluates {@code callable} while holding {@code resultPartitionLock} and, nested inside it,
+     * the monitor of {@code subpartitionLock}, and returns its result.
+     *
+     * @throws E whatever {@code callable} throws.
+     */
+    private <R, E extends Exception> R callWithLock(SupplierWithException<R, E> callable) throws E {
+        // Acquire before entering try: if lock() itself fails, the finally block must not try to
+        // unlock a lock that was never acquired. That unlock would typically throw
+        // IllegalMonitorStateException and mask the original failure.
+        resultPartitionLock.lock();
+        try {
+            synchronized (subpartitionLock) {
+                return callable.get();
+            }
+        } finally {
+            resultPartitionLock.unlock();
+        }
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContextTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContextTest.java
new file mode 100644
index 00000000000..a16b811b68f
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContextTest.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid;
+
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.util.concurrent.CompletableFuture;
+
+import static org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleTestUtils.createBuffer;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link HsBufferContext}. */
+class HsBufferContextTest {
+    private static final int BUFFER_SIZE = 16;
+
+    private static final int SUBPARTITION_ID = 0;
+
+    private static final int BUFFER_INDEX = 0;
+
+    /** Context under test; recreated before every test method. */
+    private HsBufferContext bufferContext;
+
+    @BeforeEach
+    void before() {
+        bufferContext = createBufferContext();
+    }
+
+    @Test
+    void testBufferStartSpillingRefCount() {
+        CompletableFuture<Void> spillDone = new CompletableFuture<>();
+        Buffer buffer = bufferContext.getBuffer();
+
+        bufferContext.startSpilling(spillDone);
+        assertThat(bufferContext.isSpillStarted()).isTrue();
+        // An extra reference is held while the spill is in progress ...
+        assertThat(buffer.refCnt()).isEqualTo(2);
+
+        // ... and dropped once the spill future completes.
+        spillDone.complete(null);
+        assertThat(buffer.refCnt()).isEqualTo(1);
+    }
+
+    @Test
+    void testBufferStartSpillingRepeatedly() {
+        bufferContext.startSpilling(new CompletableFuture<>());
+
+        assertThatThrownBy(() -> bufferContext.startSpilling(new CompletableFuture<>()))
+                .isInstanceOf(IllegalStateException.class)
+                .hasMessageContaining("Spill buffer repeatedly is unexpected.");
+    }
+
+    @Test
+    void testBufferReleaseRefCount() {
+        Buffer buffer = bufferContext.getBuffer();
+        assertThat(buffer.refCnt()).isEqualTo(1);
+
+        bufferContext.release();
+
+        assertThat(bufferContext.isReleased()).isTrue();
+        assertThat(buffer.isRecycled()).isTrue();
+    }
+
+    @Test
+    void testBufferReleaseRepeatedly() {
+        bufferContext.release();
+
+        assertThatThrownBy(bufferContext::release)
+                .isInstanceOf(IllegalStateException.class)
+                .hasMessageContaining("Release buffer repeatedly is unexpected.");
+    }
+
+    @Test
+    void testBufferConsumed() {
+        Buffer buffer = bufferContext.getBuffer();
+
+        bufferContext.consumed();
+
+        assertThat(bufferContext.isConsumed()).isTrue();
+        // Consumption takes an additional reference on the buffer.
+        assertThat(buffer.refCnt()).isEqualTo(2);
+    }
+
+    @Test
+    void testBufferConsumedRepeatedly() {
+        bufferContext.consumed();
+
+        assertThatThrownBy(bufferContext::consumed)
+                .isInstanceOf(IllegalStateException.class)
+                .hasMessageContaining("Consume buffer repeatedly is unexpected.");
+    }
+
+    @Test
+    void testBufferStartSpillOrConsumedAfterReleased() {
+        bufferContext.release();
+
+        // Neither spilling nor consuming is allowed on a released buffer.
+        assertThatThrownBy(() -> bufferContext.startSpilling(new CompletableFuture<>()))
+                .isInstanceOf(IllegalStateException.class)
+                .hasMessageContaining("Buffer is already released.");
+        assertThatThrownBy(bufferContext::consumed)
+                .isInstanceOf(IllegalStateException.class)
+                .hasMessageContaining("Buffer is already released.");
+    }
+
+    @Test
+    void testBufferStartSpillingThenRelease() {
+        CompletableFuture<Void> spillDone = new CompletableFuture<>();
+        Buffer buffer = bufferContext.getBuffer();
+
+        bufferContext.startSpilling(spillDone);
+        bufferContext.release();
+        spillDone.complete(null);
+
+        // Releasing during a spill recycles the buffer once the spill completes.
+        assertThat(buffer.isRecycled()).isTrue();
+    }
+
+    @Test
+    void testBufferConsumedThenRelease() {
+        Buffer buffer = bufferContext.getBuffer();
+
+        bufferContext.consumed();
+        bufferContext.release();
+
+        // The reference taken on consumption survives the release.
+        assertThat(buffer.refCnt()).isEqualTo(1);
+    }
+
+    /** Creates a context around a fresh {@code BUFFER_SIZE}-byte buffer. */
+    private static HsBufferContext createBufferContext() {
+        return new HsBufferContext(createBuffer(BUFFER_SIZE, false), BUFFER_INDEX, SUBPARTITION_ID);
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsFullSpillingStrategyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsFullSpillingStrategyTest.java
index 2f7674aa4af..b8d1289efb8 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsFullSpillingStrategyTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsFullSpillingStrategyTest.java
@@ -29,7 +29,7 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
-import static org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingStrategyTestUtils.createBufferIndexAndChannelsList;
+import static org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleTestUtils.createBufferIndexAndChannelsList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.entry;
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManagerTest.java
new file mode 100644
index 00000000000..6608be9c500
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManagerTest.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.hybrid.HsFileDataIndex.SpilledBuffer;
+import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingStrategy.Decision;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link HsMemoryDataManager}. */
+class HsMemoryDataManagerTest {
+    private static final int NUM_BUFFERS = 10;
+
+    private static final int NUM_SUBPARTITIONS = 3;
+
+    /** Size of the local buffer pool; individual tests may shrink it before creating a manager. */
+    private int poolSize = 10;
+
+    /** Size of each network buffer; individual tests may enlarge it before creating a manager. */
+    private int bufferSize = Integer.BYTES;
+
+    private FileChannel dataFileChannel;
+
+    @BeforeEach
+    void before(@TempDir Path tempDir) throws Exception {
+        Path dataPath = Files.createFile(tempDir.resolve(".data"));
+        dataFileChannel = FileChannel.open(dataPath, StandardOpenOption.WRITE);
+    }
+
+    @Test
+    void testAppendMarkBufferFinished() throws Exception {
+        AtomicInteger finishedBuffers = new AtomicInteger(0);
+        HsSpillingStrategy spillingStrategy =
+                TestingSpillingStrategy.builder()
+                        .setOnBufferFinishedFunction(
+                                (numTotalUnSpillBuffers) -> {
+                                    finishedBuffers.incrementAndGet();
+                                    return Optional.of(Decision.NO_ACTION);
+                                })
+                        .build();
+        // Each buffer holds exactly three int records.
+        bufferSize = Integer.BYTES * 3;
+        HsMemoryDataManager memoryDataManager = createMemoryDataManager(spillingStrategy);
+
+        memoryDataManager.append(createRecord(0), 0, Buffer.DataType.DATA_BUFFER);
+        memoryDataManager.append(createRecord(1), 0, Buffer.DataType.DATA_BUFFER);
+        assertThat(finishedBuffers).hasValue(0);
+
+        // The third record fills the buffer, which finishes it.
+        memoryDataManager.append(createRecord(2), 0, Buffer.DataType.DATA_BUFFER);
+        assertThat(finishedBuffers).hasValue(1);
+
+        // The event finishes the partially filled data buffer and is a finished buffer itself.
+        memoryDataManager.append(createRecord(3), 0, Buffer.DataType.DATA_BUFFER);
+        memoryDataManager.append(createRecord(4), 0, Buffer.DataType.EVENT_BUFFER);
+        assertThat(finishedBuffers).hasValue(3);
+    }
+
+    @Test
+    void testAppendRequestBuffer() throws Exception {
+        poolSize = 3;
+        List<Tuple2<Integer, Integer>> numFinishedBufferAndPoolSize = new ArrayList<>();
+        HsSpillingStrategy spillingStrategy =
+                TestingSpillingStrategy.builder()
+                        .setOnMemoryUsageChangedFunction(
+                                // Named currentPoolSize so that it does not shadow the field.
+                                (finishedBuffer, currentPoolSize) -> {
+                                    numFinishedBufferAndPoolSize.add(
+                                            Tuple2.of(finishedBuffer, currentPoolSize));
+                                    return Optional.of(Decision.NO_ACTION);
+                                })
+                        .build();
+        HsMemoryDataManager memoryDataManager = createMemoryDataManager(spillingStrategy);
+        memoryDataManager.append(createRecord(0), 0, Buffer.DataType.DATA_BUFFER);
+        memoryDataManager.append(createRecord(1), 1, Buffer.DataType.DATA_BUFFER);
+        memoryDataManager.append(createRecord(2), 2, Buffer.DataType.DATA_BUFFER);
+        assertThat(memoryDataManager.getNumTotalRequestedBuffers()).isEqualTo(3);
+        List<Tuple2<Integer, Integer>> expectedFinishedBufferAndPoolSize =
+                Arrays.asList(Tuple2.of(1, 3), Tuple2.of(2, 3), Tuple2.of(3, 3));
+        assertThat(numFinishedBufferAndPoolSize).isEqualTo(expectedFinishedBufferAndPoolSize);
+    }
+
+    @Test
+    void testHandleDecision() throws Exception {
+        final int targetSubpartition = 0;
+        final int numFinishedBufferToTriggerDecision = 4;
+        List<BufferIndexAndChannel> toSpill =
+                HybridShuffleTestUtils.createBufferIndexAndChannelsList(
+                        targetSubpartition, 0, 1, 2);
+        List<BufferIndexAndChannel> toRelease =
+                HybridShuffleTestUtils.createBufferIndexAndChannelsList(targetSubpartition, 2, 3);
+        HsSpillingStrategy spillingStrategy =
+                TestingSpillingStrategy.builder()
+                        .setOnBufferFinishedFunction(
+                                (numFinishedBuffers) -> {
+                                    if (numFinishedBuffers < numFinishedBufferToTriggerDecision) {
+                                        return Optional.of(Decision.NO_ACTION);
+                                    }
+                                    return Optional.of(
+                                            Decision.builder()
+                                                    .addBufferToSpill(targetSubpartition, toSpill)
+                                                    .addBufferToRelease(
+                                                            targetSubpartition, toRelease)
+                                                    .build());
+                                })
+                        .build();
+        CompletableFuture<List<SpilledBuffer>> spilledFuture = new CompletableFuture<>();
+        CompletableFuture<Integer> readableFuture = new CompletableFuture<>();
+        TestingFileDataIndex dataIndex =
+                TestingFileDataIndex.builder()
+                        .setAddBuffersConsumer(spilledFuture::complete)
+                        .setMarkBufferReadableConsumer(
+                                (subpartitionId, bufferIndex) ->
+                                        readableFuture.complete(bufferIndex))
+                        .build();
+        HsMemoryDataManager memoryDataManager =
+                createMemoryDataManager(spillingStrategy, dataIndex);
+        // With bufferSize == Integer.BYTES every record fills, and thus finishes, one buffer.
+        for (int i = 0; i < numFinishedBufferToTriggerDecision; i++) {
+            memoryDataManager.append(
+                    createRecord(i), targetSubpartition, Buffer.DataType.DATA_BUFFER);
+        }
+
+        assertThat(spilledFuture).succeedsWithin(10, TimeUnit.SECONDS);
+        assertThat(readableFuture).succeedsWithin(10, TimeUnit.SECONDS);
+        // Buffer 2 is in both the spill list and the release list.
+        assertThat(readableFuture).isCompletedWithValue(2);
+        assertThat(memoryDataManager.getNumTotalUnSpillBuffers()).isEqualTo(1);
+    }
+
+    @Test
+    void testHandleEmptyDecision() throws Exception {
+        CompletableFuture<Void> globalDecisionFuture = new CompletableFuture<>();
+        HsSpillingStrategy spillingStrategy =
+                TestingSpillingStrategy.builder()
+                        .setOnBufferFinishedFunction(
+                                (finishedBuffer) -> {
+                                    // return empty optional to trigger global decision.
+                                    return Optional.empty();
+                                })
+                        .setDecideActionWithGlobalInfoFunction(
+                                (provider) -> {
+                                    globalDecisionFuture.complete(null);
+                                    return Decision.NO_ACTION;
+                                })
+                        .build();
+        HsMemoryDataManager memoryDataManager = createMemoryDataManager(spillingStrategy);
+        // trigger an empty decision.
+        memoryDataManager.onBufferFinished();
+        assertThat(globalDecisionFuture).isCompleted();
+    }
+
+    /** Creates a manager backed by a real {@link HsFileDataIndexImpl}. */
+    private HsMemoryDataManager createMemoryDataManager(HsSpillingStrategy spillStrategy)
+            throws Exception {
+        return createMemoryDataManager(spillStrategy, new HsFileDataIndexImpl(NUM_SUBPARTITIONS));
+    }
+
+    /**
+     * Creates a manager whose buffer pool uses the current {@code poolSize} and {@code
+     * bufferSize} values.
+     */
+    private HsMemoryDataManager createMemoryDataManager(
+            HsSpillingStrategy spillStrategy, HsFileDataIndex fileDataIndex) throws Exception {
+        NetworkBufferPool networkBufferPool = new NetworkBufferPool(NUM_BUFFERS, bufferSize);
+        BufferPool bufferPool = networkBufferPool.createBufferPool(poolSize, poolSize);
+        return new HsMemoryDataManager(
+                NUM_SUBPARTITIONS,
+                bufferSize,
+                bufferPool,
+                spillStrategy,
+                fileDataIndex,
+                dataFileChannel);
+    }
+
+    /** Encodes {@code value} as a flipped, ready-to-read 4-byte record. */
+    private static ByteBuffer createRecord(int value) {
+        ByteBuffer byteBuffer = ByteBuffer.allocate(Integer.BYTES);
+        byteBuffer.putInt(value);
+        byteBuffer.flip();
+        return byteBuffer;
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSelectiveSpillingStrategyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSelectiveSpillingStrategyTest.java
index 9e7d7254dba..6862c4c6e4a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSelectiveSpillingStrategyTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSelectiveSpillingStrategyTest.java
@@ -29,7 +29,7 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
-import static org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingStrategyTestUtils.createBufferIndexAndChannelsList;
+import static org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleTestUtils.createBufferIndexAndChannelsList;
import static org.assertj.core.api.Assertions.assertThat;
/** Tests for {@link HsSelectiveSpillingStrategy}. */
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingStrategyUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingStrategyUtilsTest.java
index 98d3ab7907f..b5849a48644 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingStrategyUtilsTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingStrategyUtilsTest.java
@@ -26,8 +26,8 @@ import java.util.Deque;
import java.util.List;
import java.util.TreeMap;
-import static org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingStrategyTestUtils.createBufferIndexAndChannelsDeque;
-import static org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingStrategyTestUtils.createBufferIndexAndChannelsList;
+import static org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleTestUtils.createBufferIndexAndChannelsDeque;
+import static org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleTestUtils.createBufferIndexAndChannelsList;
import static org.assertj.core.api.Assertions.assertThat;
/** Tests for {@link HsSpillingStrategyUtils}. */
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManagerTest.java
new file mode 100644
index 00000000000..a15dfddc9f2
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManagerTest.java
@@ -0,0 +1,427 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.Buffer.DataType;
+import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
+import org.apache.flink.runtime.io.network.partition.hybrid.HsMemoryDataManager.BufferAndNextDataType;
+import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatus;
+import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.SpillStatus;
+
+import org.junit.jupiter.api.Test;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleTestUtils.createBufferBuilder;
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link HsSubpartitionMemoryDataManager}. */
+class HsSubpartitionMemoryDataManagerTest {
+ // Subpartition id of the manager under test; presumably passed to
+ // createSubpartitionMemoryDataManager (defined further down) — confirm.
+ private static final int SUBPARTITION_ID = 0;
+
+ // NOTE(review): a single static lock shared by all tests; presumably handed to the manager as
+ // its result partition lock — confirm. Constant naming would suggest LOCK.
+ private static final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
+
+ // Every record written by these tests is a single int.
+ private static final int RECORD_SIZE = Integer.BYTES;
+
+ // Size of buffers created by suppliers that read this field; some tests enlarge it.
+ private int bufferSize = RECORD_SIZE;
+
+    @Test
+    void testAppendDataRequestBuffer() throws Exception {
+        CompletableFuture<Void> bufferRequested = new CompletableFuture<>();
+        HsMemoryDataManagerOperation operation =
+                TestingMemoryDataManagerOperation.builder()
+                        .setRequestBufferFromPoolSupplier(
+                                () -> {
+                                    bufferRequested.complete(null);
+                                    return createBufferBuilder(bufferSize);
+                                })
+                        .build();
+        HsSubpartitionMemoryDataManager dataManager =
+                createSubpartitionMemoryDataManager(operation);
+
+        // Appending a data record must request a buffer from the pool.
+        dataManager.append(createRecord(0), DataType.DATA_BUFFER);
+
+        assertThat(bufferRequested).isCompleted();
+    }
+
+    @Test
+    void testAppendEventNotRequestBuffer() throws Exception {
+        CompletableFuture<Void> bufferRequested = new CompletableFuture<>();
+        HsMemoryDataManagerOperation operation =
+                TestingMemoryDataManagerOperation.builder()
+                        .setRequestBufferFromPoolSupplier(
+                                () -> {
+                                    bufferRequested.complete(null);
+                                    return null;
+                                })
+                        .build();
+        HsSubpartitionMemoryDataManager dataManager =
+                createSubpartitionMemoryDataManager(operation);
+
+        // Events are wrapped directly and must not request a buffer from the pool.
+        dataManager.append(createRecord(0), DataType.EVENT_BUFFER);
+
+        assertThat(bufferRequested).isNotDone();
+    }
+
+    @Test
+    void testAppendEventFinishCurrentBuffer() throws Exception {
+        // Each buffer can hold three records.
+        bufferSize = RECORD_SIZE * 3;
+        AtomicInteger finishedBuffers = new AtomicInteger(0);
+        HsMemoryDataManagerOperation operation =
+                TestingMemoryDataManagerOperation.builder()
+                        .setRequestBufferFromPoolSupplier(() -> createBufferBuilder(bufferSize))
+                        .setOnBufferFinishedRunnable(finishedBuffers::incrementAndGet)
+                        .build();
+        HsSubpartitionMemoryDataManager dataManager =
+                createSubpartitionMemoryDataManager(operation);
+
+        dataManager.append(createRecord(0), DataType.DATA_BUFFER);
+        dataManager.append(createRecord(1), DataType.DATA_BUFFER);
+        // Two records only partially fill the buffer, so nothing is finished yet.
+        assertThat(finishedBuffers).hasValue(0);
+
+        // The event finishes the partially filled buffer and is a finished buffer itself.
+        dataManager.append(createRecord(2), DataType.EVENT_BUFFER);
+        assertThat(finishedBuffers).hasValue(2);
+    }
+
+    @Test
+    void testPeekNextToConsumeDataTypeNotMeetBufferIndexToConsume() throws Exception {
+        HsSubpartitionMemoryDataManager dataManager =
+                createSubpartitionMemoryDataManager(
+                        TestingMemoryDataManagerOperation.builder()
+                                .setRequestBufferFromPoolSupplier(
+                                        () -> createBufferBuilder(RECORD_SIZE))
+                                .build());
+
+        // Only buffer 0 exists, so peeking buffer 1 yields no data type.
+        dataManager.append(createRecord(0), DataType.DATA_BUFFER);
+
+        assertThat(dataManager.peekNextToConsumeDataType(1)).isEqualTo(DataType.NONE);
+    }
+
+    @Test
+    void testPeekNextToConsumeDataTypeTrimHeadingReleasedBuffers() throws Exception {
+        HsSubpartitionMemoryDataManager dataManager =
+                createSubpartitionMemoryDataManager(
+                        TestingMemoryDataManagerOperation.builder()
+                                .setRequestBufferFromPoolSupplier(
+                                        () -> createBufferBuilder(RECORD_SIZE))
+                                .build());
+
+        dataManager.append(createRecord(0), DataType.DATA_BUFFER);
+        dataManager.append(createRecord(1), DataType.DATA_BUFFER);
+        dataManager.append(createRecord(2), DataType.EVENT_BUFFER);
+
+        // Release buffers 0 and 1; peeking buffer 2 must skip over them.
+        dataManager.releaseSubpartitionBuffers(
+                HybridShuffleTestUtils.createBufferIndexAndChannelsList(0, 0, 1));
+
+        assertThat(dataManager.peekNextToConsumeDataType(2)).isEqualTo(DataType.EVENT_BUFFER);
+    }
+
+    @Test
+    void testConsumeBufferFirstUnConsumedBufferIndexNotMeetNextToConsume() throws Exception {
+        HsSubpartitionMemoryDataManager dataManager =
+                createSubpartitionMemoryDataManager(
+                        TestingMemoryDataManagerOperation.builder()
+                                .setRequestBufferFromPoolSupplier(
+                                        () -> createBufferBuilder(RECORD_SIZE))
+                                .build());
+
+        // Only buffer 0 exists, so consuming buffer 1 yields nothing.
+        dataManager.append(createRecord(0), DataType.DATA_BUFFER);
+
+        assertThat(dataManager.consumeBuffer(1)).isNotPresent();
+    }
+
+    @Test
+    void testConsumeBufferTrimHeadingReleasedBuffers() throws Exception {
+        HsSubpartitionMemoryDataManager dataManager =
+                createSubpartitionMemoryDataManager(
+                        TestingMemoryDataManagerOperation.builder()
+                                .setRequestBufferFromPoolSupplier(
+                                        () -> createBufferBuilder(RECORD_SIZE))
+                                .build());
+
+        dataManager.append(createRecord(0), DataType.DATA_BUFFER);
+        dataManager.append(createRecord(1), DataType.DATA_BUFFER);
+        dataManager.append(createRecord(2), DataType.EVENT_BUFFER);
+
+        // Release buffers 0 and 1; consuming buffer 2 must skip over them.
+        dataManager.releaseSubpartitionBuffers(
+                HybridShuffleTestUtils.createBufferIndexAndChannelsList(0, 0, 1));
+
+        assertThat(dataManager.consumeBuffer(2)).isPresent();
+    }
+
+    @Test
+    void testConsumeBuffer() throws Exception {
+        List<BufferIndexAndChannel> consumedBufferIndexAndChannel = new ArrayList<>();
+        TestingMemoryDataManagerOperation memoryDataManagerOperation =
+                TestingMemoryDataManagerOperation.builder()
+                        .setRequestBufferFromPoolSupplier(() -> createBufferBuilder(RECORD_SIZE))
+                        .setOnBufferConsumedConsumer(consumedBufferIndexAndChannel::add)
+                        .build();
+        HsSubpartitionMemoryDataManager subpartitionMemoryDataManager =
+                createSubpartitionMemoryDataManager(memoryDataManagerOperation);
+
+        subpartitionMemoryDataManager.append(createRecord(0), DataType.DATA_BUFFER);
+        subpartitionMemoryDataManager.append(createRecord(1), DataType.DATA_BUFFER);
+        subpartitionMemoryDataManager.append(createRecord(2), DataType.EVENT_BUFFER);
+
+        // Use the imported DataType consistently instead of mixing in Buffer.DataType.
+        List<Tuple2<Integer, DataType>> expectedRecords = new ArrayList<>();
+        expectedRecords.add(Tuple2.of(0, DataType.DATA_BUFFER));
+        expectedRecords.add(Tuple2.of(1, DataType.DATA_BUFFER));
+        expectedRecords.add(Tuple2.of(2, DataType.EVENT_BUFFER));
+        checkConsumedBufferAndNextDataType(
+                expectedRecords,
+                Arrays.asList(
+                        subpartitionMemoryDataManager.consumeBuffer(0),
+                        subpartitionMemoryDataManager.consumeBuffer(1),
+                        subpartitionMemoryDataManager.consumeBuffer(2)));
+
+        // Every consumed buffer must have been reported, in consumption order.
+        List<BufferIndexAndChannel> expectedBufferIndexAndChannel =
+                HybridShuffleTestUtils.createBufferIndexAndChannelsList(0, 0, 1, 2);
+        assertThat(consumedBufferIndexAndChannel)
+                .zipSatisfy(
+                        expectedBufferIndexAndChannel,
+                        (consumed, expected) -> {
+                            assertThat(consumed.getChannel()).isEqualTo(expected.getChannel());
+                            assertThat(consumed.getBufferIndex())
+                                    .isEqualTo(expected.getBufferIndex());
+                        });
+    }
+
+    @Test
+    void testGetBuffersSatisfyStatus() throws Exception {
+        TestingMemoryDataManagerOperation operation =
+                TestingMemoryDataManagerOperation.builder()
+                        .setRequestBufferFromPoolSupplier(() -> createBufferBuilder(RECORD_SIZE))
+                        .build();
+        HsSubpartitionMemoryDataManager dataManager =
+                createSubpartitionMemoryDataManager(operation);
+        final int numBuffers = 4;
+        for (int bufferIndex = 0; bufferIndex < numBuffers; bufferIndex++) {
+            dataManager.append(createRecord(bufferIndex), DataType.DATA_BUFFER);
+        }
+
+        // Start spilling buffers 1 and 2; the spill future is never completed in this test.
+        CompletableFuture<Void> neverCompletedSpill = new CompletableFuture<>();
+        dataManager.spillSubpartitionBuffers(
+                HybridShuffleTestUtils.createBufferIndexAndChannelsList(0, 1, 2),
+                neverCompletedSpill);
+
+        // Consume buffers 0 and 1.
+        dataManager.consumeBuffer(0);
+        dataManager.consumeBuffer(1);
+
+        // Resulting states: 0 consumed, 1 spilling + consumed, 2 spilling, 3 untouched.
+        checkBufferIndex(
+                dataManager.getBuffersSatisfyStatus(SpillStatus.ALL, ConsumeStatus.ALL),
+                Arrays.asList(0, 1, 2, 3));
+        checkBufferIndex(
+                dataManager.getBuffersSatisfyStatus(SpillStatus.ALL, ConsumeStatus.CONSUMED),
+                Arrays.asList(0, 1));
+        checkBufferIndex(
+                dataManager.getBuffersSatisfyStatus(SpillStatus.ALL, ConsumeStatus.NOT_CONSUMED),
+                Arrays.asList(2, 3));
+        checkBufferIndex(
+                dataManager.getBuffersSatisfyStatus(SpillStatus.SPILL, ConsumeStatus.ALL),
+                Arrays.asList(1, 2));
+        checkBufferIndex(
+                dataManager.getBuffersSatisfyStatus(SpillStatus.NOT_SPILL, ConsumeStatus.ALL),
+                Arrays.asList(0, 3));
+        checkBufferIndex(
+                dataManager.getBuffersSatisfyStatus(
+                        SpillStatus.SPILL, ConsumeStatus.NOT_CONSUMED),
+                Collections.singletonList(2));
+        checkBufferIndex(
+                dataManager.getBuffersSatisfyStatus(SpillStatus.SPILL, ConsumeStatus.CONSUMED),
+                Collections.singletonList(1));
+        checkBufferIndex(
+                dataManager.getBuffersSatisfyStatus(
+                        SpillStatus.NOT_SPILL, ConsumeStatus.CONSUMED),
+                Collections.singletonList(0));
+        checkBufferIndex(
+                dataManager.getBuffersSatisfyStatus(
+                        SpillStatus.NOT_SPILL, ConsumeStatus.NOT_CONSUMED),
+                Collections.singletonList(3));
+    }
+
+    /**
+     * Spills three appended buffers and verifies the returned identities, the buffer contents,
+     * and that each buffer's reference count drops by one once the spill-done future completes.
+     */
+    @Test
+    void testSpillSubpartitionBuffers() throws Exception {
+        HsSubpartitionMemoryDataManager dataManager =
+                createSubpartitionMemoryDataManager(
+                        TestingMemoryDataManagerOperation.builder()
+                                .setRequestBufferFromPoolSupplier(
+                                        () -> createBufferBuilder(RECORD_SIZE))
+                                .build());
+
+        final int numBuffers = 3;
+        for (int value = 0; value < numBuffers; value++) {
+            dataManager.append(createRecord(value), DataType.DATA_BUFFER);
+        }
+
+        CompletableFuture<Void> spillDoneFuture = new CompletableFuture<>();
+        List<BufferIndexAndChannel> spillRequests =
+                HybridShuffleTestUtils.createBufferIndexAndChannelsList(0, 0, 1, 2);
+        List<BufferWithIdentity> spilledBuffers =
+                dataManager.spillSubpartitionBuffers(spillRequests, spillDoneFuture);
+
+        // Each returned buffer must carry the identity of the request at the same position.
+        assertThat(spilledBuffers)
+                .zipSatisfy(
+                        spillRequests,
+                        (spilled, requested) -> {
+                            assertThat(spilled.getBufferIndex())
+                                    .isEqualTo(requested.getBufferIndex());
+                            assertThat(spilled.getChannelIndex())
+                                    .isEqualTo(requested.getChannel());
+                        });
+
+        // Reference counts are 2 while the spill is pending and 1 after it completes.
+        List<Integer> expectedValues = Arrays.asList(0, 1, 2);
+        checkBuffersRefCountAndValue(spilledBuffers, Arrays.asList(2, 2, 2), expectedValues);
+        spillDoneFuture.complete(null);
+        checkBuffersRefCountAndValue(spilledBuffers, Arrays.asList(1, 1, 1), expectedValues);
+    }
+
+    /**
+     * Starts spilling only the last of three buffers, then releases all of them. Verifies that
+     * the non-spilling buffers are recycled immediately, and that the spilled buffer is marked
+     * readable and recycled only after its spill completes.
+     */
+    @Test
+    void testReleaseAndMarkReadableSubpartitionBuffers() throws Exception {
+        final int targetChannel = 0;
+        final List<Integer> readableBufferIndexes = new ArrayList<>();
+        final List<MemorySegment> recycledSegments = new ArrayList<>();
+
+        TestingMemoryDataManagerOperation memoryDataManagerOperation =
+                TestingMemoryDataManagerOperation.builder()
+                        .setRequestBufferFromPoolSupplier(
+                                () ->
+                                        new BufferBuilder(
+                                                MemorySegmentFactory.allocateUnpooledSegment(
+                                                        bufferSize),
+                                                recycledSegments::add))
+                        .setMarkBufferReadableConsumer(
+                                (channel, bufferIndex) -> {
+                                    assertThat(channel).isEqualTo(targetChannel);
+                                    readableBufferIndexes.add(bufferIndex);
+                                })
+                        .build();
+        HsSubpartitionMemoryDataManager dataManager =
+                createSubpartitionMemoryDataManager(memoryDataManagerOperation);
+
+        final int numBuffers = 3;
+        for (int value = 0; value < numBuffers; value++) {
+            dataManager.append(createRecord(value), DataType.DATA_BUFFER);
+        }
+
+        List<BufferIndexAndChannel> allBuffers =
+                HybridShuffleTestUtils.createBufferIndexAndChannelsList(targetChannel, 0, 1, 2);
+        List<BufferIndexAndChannel> lastBuffer = allBuffers.subList(numBuffers - 1, numBuffers);
+
+        // Start spilling only the last buffer, then release every buffer.
+        CompletableFuture<Void> spillDoneFuture = new CompletableFuture<>();
+        dataManager.spillSubpartitionBuffers(lastBuffer, spillDoneFuture);
+        dataManager.releaseSubpartitionBuffers(allBuffers);
+
+        // Nothing is marked readable yet; buffers that were not spilling are recycled on release.
+        assertThat(readableBufferIndexes).isEmpty();
+        checkMemorySegmentValue(recycledSegments, Arrays.asList(0, 1));
+
+        // Completing the spill marks the spilled buffer readable and recycles it as well.
+        spillDoneFuture.complete(null);
+        assertThat(readableBufferIndexes).containsExactly(2);
+        checkMemorySegmentValue(recycledSegments, Arrays.asList(0, 1, 2));
+    }
+
+    /**
+     * Asserts that the buffer indexes in {@code bufferIndexAndChannels}, in iteration order, equal
+     * {@code expectedIndexes}.
+     */
+    private static void checkBufferIndex(
+            Deque<BufferIndexAndChannel> bufferIndexAndChannels, List<Integer> expectedIndexes) {
+        List<Integer> actualIndexes = new ArrayList<>(bufferIndexAndChannels.size());
+        for (BufferIndexAndChannel bufferIndexAndChannel : bufferIndexAndChannels) {
+            actualIndexes.add(bufferIndexAndChannel.getBufferIndex());
+        }
+        assertThat(actualIndexes).isEqualTo(expectedIndexes);
+    }
+
+    /**
+     * Asserts that {@code memorySegments} contains exactly one segment per expected value, in
+     * order, and that each segment starts with the expected int. The int is read little-endian
+     * to match {@link #createRecord}.
+     */
+    private static void checkMemorySegmentValue(
+            List<MemorySegment> memorySegments, List<Integer> expectedValues) {
+        // Read little-endian explicitly: MemorySegment#getInt uses the platform's native byte
+        // order and would fail on big-endian hosts.
+        List<Integer> actualValues =
+                memorySegments.stream()
+                        .map(segment -> segment.getIntLittleEndian(0))
+                        .collect(Collectors.toList());
+        // Comparing whole lists also catches missing or extra segments. A loop over the actual
+        // list alone would pass silently when it is empty or too short.
+        assertThat(actualValues).isEqualTo(expectedValues);
+    }
+
+    /**
+     * Checks each consumed buffer against the expected (value, data type) record at the same
+     * position. The reported next data type must be the following record's type, or {@link
+     * Buffer.DataType#NONE} for the last buffer.
+     *
+     * @throws IllegalArgumentException if the two lists differ in size
+     */
+    private static void checkConsumedBufferAndNextDataType(
+            List<Tuple2<Integer, Buffer.DataType>> expectedRecords,
+            List<Optional<BufferAndNextDataType>> bufferAndNextDataTypesOpt) {
+        checkArgument(expectedRecords.size() == bufferAndNextDataTypesOpt.size());
+        final int lastIndex = bufferAndNextDataTypesOpt.size() - 1;
+        for (int i = 0; i <= lastIndex; i++) {
+            final Tuple2<Integer, Buffer.DataType> expectedRecord = expectedRecords.get(i);
+            final Buffer.DataType expectedNextDataType =
+                    i < lastIndex ? expectedRecords.get(i + 1).f1 : Buffer.DataType.NONE;
+            assertThat(bufferAndNextDataTypesOpt.get(i))
+                    .hasValueSatisfying(
+                            bufferAndNextDataType -> {
+                                Buffer buffer = bufferAndNextDataType.getBuffer();
+                                int value =
+                                        buffer.getNioBufferReadable()
+                                                .order(ByteOrder.LITTLE_ENDIAN)
+                                                .getInt();
+                                assertThat(value).isEqualTo(expectedRecord.f0);
+                                assertThat(buffer.getDataType()).isEqualTo(expectedRecord.f1);
+                                assertThat(bufferAndNextDataType.getNextDataType())
+                                        .isEqualTo(expectedNextDataType);
+                            });
+        }
+    }
+
+    /**
+     * Asserts, position by position, that each buffer starts with the expected little-endian int
+     * and has the expected reference count.
+     *
+     * <p>All three lists must have the same size.
+     */
+    private static void checkBuffersRefCountAndValue(
+            List<BufferWithIdentity> bufferWithIdentities,
+            List<Integer> expectedRefCounts,
+            List<Integer> expectedValues) {
+        // Without these size checks, a shorter actual list would silently skip the remaining
+        // expectations, and a longer one would fail with an IndexOutOfBoundsException.
+        assertThat(bufferWithIdentities).hasSameSizeAs(expectedRefCounts);
+        assertThat(bufferWithIdentities).hasSameSizeAs(expectedValues);
+        for (int i = 0; i < bufferWithIdentities.size(); i++) {
+            Buffer buffer = bufferWithIdentities.get(i).getBuffer();
+            assertThat(buffer.getNioBufferReadable().order(ByteOrder.LITTLE_ENDIAN).getInt())
+                    .isEqualTo(expectedValues.get(i));
+            assertThat(buffer.refCnt()).isEqualTo(expectedRefCounts.get(i));
+        }
+    }
+
+    /**
+     * Creates the manager under test for {@code SUBPARTITION_ID}. It uses the test's {@code
+     * bufferSize}, the read side of {@code lock}, and the given operation callbacks.
+     */
+    private HsSubpartitionMemoryDataManager createSubpartitionMemoryDataManager(
+            HsMemoryDataManagerOperation memoryDataManagerOperation) {
+        return new HsSubpartitionMemoryDataManager(
+                SUBPARTITION_ID,
+                bufferSize,
+                lock.readLock(),
+                memoryDataManagerOperation);
+    }
+
+    /**
+     * Encodes {@code value} as a little-endian int in a fresh buffer of capacity {@code
+     * RECORD_SIZE}, flipped for reading (position 0, limit 4).
+     */
+    private static ByteBuffer createRecord(int value) {
+        ByteBuffer record = ByteBuffer.allocate(RECORD_SIZE).order(ByteOrder.LITTLE_ENDIAN);
+        record.putInt(value);
+        // Not chained: before Java 9, flip() returns java.nio.Buffer rather than ByteBuffer.
+        record.flip();
+        return record;
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingStrategyTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HybridShuffleTestUtils.java
similarity index 72%
rename from flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingStrategyTestUtils.java
rename to flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HybridShuffleTestUtils.java
index e959c150f63..0d7ec874ba9 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingStrategyTestUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HybridShuffleTestUtils.java
@@ -20,6 +20,9 @@ package org.apache.flink.runtime.io.network.partition.hybrid;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import java.util.ArrayDeque;
@@ -27,8 +30,8 @@ import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
-/** Test utils for {@link HsSpillingStrategy}. */
-public class HsSpillingStrategyTestUtils {
+/** Test utils for hybrid shuffle mode. */
+public class HybridShuffleTestUtils {
public static final int MEMORY_SEGMENT_SIZE = 128;
public static List<BufferIndexAndChannel> createBufferIndexAndChannelsList(
@@ -51,4 +54,18 @@ public class HsSpillingStrategyTestUtils {
}
return bufferIndexAndChannels;
}
+
+ public static Buffer createBuffer(int bufferSize, boolean isEvent) {
+ return new NetworkBuffer(
+ MemorySegmentFactory.allocateUnpooledSegment(bufferSize),
+ FreeingBufferRecycler.INSTANCE,
+ isEvent ? Buffer.DataType.EVENT_BUFFER : Buffer.DataType.DATA_BUFFER,
+ bufferSize);
+ }
+
+ public static BufferBuilder createBufferBuilder(int bufferSize) {
+ return new BufferBuilder(
+ MemorySegmentFactory.allocateUnpooledSegment(bufferSize),
+ FreeingBufferRecycler.INSTANCE);
+ }
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingFileDataIndex.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingFileDataIndex.java
new file mode 100644
index 00000000000..db5663bb42f
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingFileDataIndex.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.function.BiConsumer;
+import java.util.function.BiFunction;
+import java.util.function.Consumer;
+
+/**
+ * Test double for {@link HsFileDataIndex}. Every method forwards to a callback configured via
+ * {@link Builder}. Unset callbacks return {@link Optional#empty()} or do nothing.
+ */
+public class TestingFileDataIndex implements HsFileDataIndex {
+
+    private final BiFunction<Integer, Integer, Optional<ReadableRegion>> readableRegionLookup;
+
+    private final Consumer<List<SpilledBuffer>> addBuffersCallback;
+
+    private final BiConsumer<Integer, Integer> markReadableCallback;
+
+    /** Captures the callbacks currently configured on {@code builder}. */
+    private TestingFileDataIndex(Builder builder) {
+        this.readableRegionLookup = builder.getReadableRegionFunction;
+        this.addBuffersCallback = builder.addBuffersConsumer;
+        this.markReadableCallback = builder.markBufferReadableConsumer;
+    }
+
+    @Override
+    public Optional<ReadableRegion> getReadableRegion(int subpartitionId, int bufferIndex) {
+        return readableRegionLookup.apply(subpartitionId, bufferIndex);
+    }
+
+    @Override
+    public void addBuffers(List<SpilledBuffer> spilledBuffers) {
+        addBuffersCallback.accept(spilledBuffers);
+    }
+
+    @Override
+    public void markBufferReadable(int subpartitionId, int bufferIndex) {
+        markReadableCallback.accept(subpartitionId, bufferIndex);
+    }
+
+    /** Returns a new builder with every callback set to its default. */
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    /** Builder for {@link TestingFileDataIndex}. */
+    public static class Builder {
+        private BiFunction<Integer, Integer, Optional<ReadableRegion>> getReadableRegionFunction =
+                (subpartitionId, bufferIndex) -> Optional.empty();
+
+        private Consumer<List<SpilledBuffer>> addBuffersConsumer = spilledBuffers -> {};
+
+        private BiConsumer<Integer, Integer> markBufferReadableConsumer =
+                (subpartitionId, bufferIndex) -> {};
+
+        private Builder() {}
+
+        public Builder setGetReadableRegionFunction(
+                BiFunction<Integer, Integer, Optional<ReadableRegion>> getReadableRegionFunction) {
+            this.getReadableRegionFunction = getReadableRegionFunction;
+            return this;
+        }
+
+        public Builder setAddBuffersConsumer(Consumer<List<SpilledBuffer>> addBuffersConsumer) {
+            this.addBuffersConsumer = addBuffersConsumer;
+            return this;
+        }
+
+        public Builder setMarkBufferReadableConsumer(
+                BiConsumer<Integer, Integer> markBufferReadableConsumer) {
+            this.markBufferReadableConsumer = markBufferReadableConsumer;
+            return this;
+        }
+
+        public TestingFileDataIndex build() {
+            return new TestingFileDataIndex(this);
+        }
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingMemoryDataManagerOperation.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingMemoryDataManagerOperation.java
new file mode 100644
index 00000000000..f78774ca674
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingMemoryDataManagerOperation.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid;
+
+import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
+import org.apache.flink.util.function.SupplierWithException;
+
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+
+/**
+ * Test double for {@link HsMemoryDataManagerOperation}. Each method forwards to a callback
+ * configured via {@link Builder}. Unset callbacks do nothing, and {@link #requestBufferFromPool()}
+ * returns {@code null} unless a supplier is configured.
+ */
+public class TestingMemoryDataManagerOperation implements HsMemoryDataManagerOperation {
+
+    private final SupplierWithException<BufferBuilder, InterruptedException> bufferSupplier;
+
+    private final BiConsumer<Integer, Integer> markReadableCallback;
+
+    private final Consumer<BufferIndexAndChannel> bufferConsumedCallback;
+
+    private final Runnable bufferFinishedCallback;
+
+    /** Captures the callbacks currently configured on {@code builder}. */
+    private TestingMemoryDataManagerOperation(Builder builder) {
+        this.bufferSupplier = builder.requestBufferFromPoolSupplier;
+        this.markReadableCallback = builder.markBufferReadableConsumer;
+        this.bufferConsumedCallback = builder.onBufferConsumedConsumer;
+        this.bufferFinishedCallback = builder.onBufferFinishedRunnable;
+    }
+
+    @Override
+    public BufferBuilder requestBufferFromPool() throws InterruptedException {
+        return bufferSupplier.get();
+    }
+
+    @Override
+    public void markBufferReadableFromFile(int subpartitionId, int bufferIndex) {
+        markReadableCallback.accept(subpartitionId, bufferIndex);
+    }
+
+    @Override
+    public void onBufferConsumed(BufferIndexAndChannel consumedBuffer) {
+        bufferConsumedCallback.accept(consumedBuffer);
+    }
+
+    @Override
+    public void onBufferFinished() {
+        bufferFinishedCallback.run();
+    }
+
+    /** Returns a new builder with every callback set to its default. */
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    /** Builder for {@link TestingMemoryDataManagerOperation}. */
+    public static class Builder {
+        private SupplierWithException<BufferBuilder, InterruptedException>
+                requestBufferFromPoolSupplier = () -> null;
+
+        private BiConsumer<Integer, Integer> markBufferReadableConsumer =
+                (subpartitionId, bufferIndex) -> {};
+
+        private Consumer<BufferIndexAndChannel> onBufferConsumedConsumer =
+                consumedBuffer -> {};
+
+        private Runnable onBufferFinishedRunnable = () -> {};
+
+        private Builder() {}
+
+        public Builder setRequestBufferFromPoolSupplier(
+                SupplierWithException<BufferBuilder, InterruptedException>
+                        requestBufferFromPoolSupplier) {
+            this.requestBufferFromPoolSupplier = requestBufferFromPoolSupplier;
+            return this;
+        }
+
+        public Builder setMarkBufferReadableConsumer(
+                BiConsumer<Integer, Integer> markBufferReadableConsumer) {
+            this.markBufferReadableConsumer = markBufferReadableConsumer;
+            return this;
+        }
+
+        public Builder setOnBufferConsumedConsumer(
+                Consumer<BufferIndexAndChannel> onBufferConsumedConsumer) {
+            this.onBufferConsumedConsumer = onBufferConsumedConsumer;
+            return this;
+        }
+
+        public Builder setOnBufferFinishedRunnable(Runnable onBufferFinishedRunnable) {
+            this.onBufferFinishedRunnable = onBufferFinishedRunnable;
+            return this;
+        }
+
+        public TestingMemoryDataManagerOperation build() {
+            return new TestingMemoryDataManagerOperation(this);
+        }
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingSpillingStrategy.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingSpillingStrategy.java
new file mode 100644
index 00000000000..4cce53db0a9
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingSpillingStrategy.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid;
+
+import java.util.Optional;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+
+/**
+ * Test double for {@link HsSpillingStrategy}. Each hook forwards to a function configured via
+ * {@link Builder}. Unset hooks answer with {@code Decision.NO_ACTION}, wrapped in an {@link
+ * Optional} where the hook returns one.
+ */
+public class TestingSpillingStrategy implements HsSpillingStrategy {
+
+    private final BiFunction<Integer, Integer, Optional<Decision>> memoryUsageHook;
+
+    private final Function<Integer, Optional<Decision>> bufferFinishedHook;
+
+    private final Function<BufferIndexAndChannel, Optional<Decision>> bufferConsumedHook;
+
+    private final Function<HsSpillingInfoProvider, Decision> globalDecisionHook;
+
+    /** Captures the hooks currently configured on {@code builder}. */
+    private TestingSpillingStrategy(Builder builder) {
+        this.memoryUsageHook = builder.onMemoryUsageChangedFunction;
+        this.bufferFinishedHook = builder.onBufferFinishedFunction;
+        this.bufferConsumedHook = builder.onBufferConsumedFunction;
+        this.globalDecisionHook = builder.decideActionWithGlobalInfoFunction;
+    }
+
+    @Override
+    public Optional<Decision> onMemoryUsageChanged(
+            int numTotalRequestedBuffers, int currentPoolSize) {
+        return memoryUsageHook.apply(numTotalRequestedBuffers, currentPoolSize);
+    }
+
+    @Override
+    public Optional<Decision> onBufferFinished(int numTotalUnSpillBuffers) {
+        return bufferFinishedHook.apply(numTotalUnSpillBuffers);
+    }
+
+    @Override
+    public Optional<Decision> onBufferConsumed(BufferIndexAndChannel consumedBuffer) {
+        return bufferConsumedHook.apply(consumedBuffer);
+    }
+
+    @Override
+    public Decision decideActionWithGlobalInfo(HsSpillingInfoProvider spillingInfoProvider) {
+        return globalDecisionHook.apply(spillingInfoProvider);
+    }
+
+    /** Returns a new builder with every hook set to its default. */
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    /** Builder for {@link TestingSpillingStrategy}. */
+    public static class Builder {
+        private BiFunction<Integer, Integer, Optional<Decision>> onMemoryUsageChangedFunction =
+                (numTotalRequestedBuffers, currentPoolSize) -> Optional.of(Decision.NO_ACTION);
+
+        private Function<Integer, Optional<Decision>> onBufferFinishedFunction =
+                numTotalUnSpillBuffers -> Optional.of(Decision.NO_ACTION);
+
+        private Function<BufferIndexAndChannel, Optional<Decision>> onBufferConsumedFunction =
+                consumedBuffer -> Optional.of(Decision.NO_ACTION);
+
+        private Function<HsSpillingInfoProvider, Decision> decideActionWithGlobalInfoFunction =
+                spillingInfoProvider -> Decision.NO_ACTION;
+
+        private Builder() {}
+
+        public Builder setOnMemoryUsageChangedFunction(
+                BiFunction<Integer, Integer, Optional<Decision>> onMemoryUsageChangedFunction) {
+            this.onMemoryUsageChangedFunction = onMemoryUsageChangedFunction;
+            return this;
+        }
+
+        public Builder setOnBufferFinishedFunction(
+                Function<Integer, Optional<Decision>> onBufferFinishedFunction) {
+            this.onBufferFinishedFunction = onBufferFinishedFunction;
+            return this;
+        }
+
+        public Builder setOnBufferConsumedFunction(
+                Function<BufferIndexAndChannel, Optional<Decision>> onBufferConsumedFunction) {
+            this.onBufferConsumedFunction = onBufferConsumedFunction;
+            return this;
+        }
+
+        public Builder setDecideActionWithGlobalInfoFunction(
+                Function<HsSpillingInfoProvider, Decision> decideActionWithGlobalInfoFunction) {
+            this.decideActionWithGlobalInfoFunction = decideActionWithGlobalInfoFunction;
+            return this;
+        }
+
+        public TestingSpillingStrategy build() {
+            return new TestingSpillingStrategy(this);
+        }
+    }
+}