Posted to commits@celeborn.apache.org by re...@apache.org on 2023/03/29 02:23:16 UTC
[incubator-celeborn] 01/42: [Flink] support 1.15
This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
commit 712c88c3d66d4e0cfdab5c8a5ce1974a4e0e8711
Author: Shuang <lv...@gmail.com>
AuthorDate: Thu Mar 9 19:32:59 2023 +0800
[Flink] support 1.15
---
build/make-distribution.sh | 3 +-
.../plugin/flink/FlinkResultPartitionInfo.java | 0
.../plugin/flink/RemoteBufferStreamReader.java | 0
.../plugin/flink/RemoteShuffleDescriptor.java | 0
.../plugin/flink/RemoteShuffleOutputGate.java | 0
.../plugin/flink/RemoteShuffleResource.java | 0
.../celeborn/plugin/flink/ShuffleResource.java | 0
.../plugin/flink/ShuffleResourceDescriptor.java | 0
.../celeborn/plugin/flink/utils/FlinkUtils.java | 55 ++
.../celeborn/plugin/flink/utils/ThreadUtils.java | 48 ++
client-flink/flink-1.15-shaded/pom.xml | 128 ++++
client-flink/flink-1.15/pom.xml | 74 +++
.../plugin/flink/RemoteShuffleEnvironment.java | 211 ++++++
.../plugin/flink/RemoteShuffleInputGate.java | 728 +++++++++++++++++++++
.../flink/RemoteShuffleInputGateFactory.java | 128 ++++
.../celeborn/plugin/flink/RemoteShuffleMaster.java | 158 +++++
.../plugin/flink/RemoteShuffleResultPartition.java | 432 ++++++++++++
.../flink/RemoteShuffleResultPartitionFactory.java | 193 ++++++
.../plugin/flink/RemoteShuffleServiceFactory.java | 99 +++
.../celeborn/plugin/flink/BufferPackSuitJ.java | 266 ++++++++
.../plugin/flink/PartitionSortedBufferSuitJ.java | 369 +++++++++++
.../plugin/flink/RemoteShuffleMasterTest.java | 226 +++++++
.../flink/RemoteShuffleOutputGateSuiteJ.java | 104 +++
.../flink/RemoteShuffleResultPartitionSuiteJ.java | 623 ++++++++++++++++++
.../flink/RemoteShuffleServiceFactorySuitJ.java | 58 ++
pom.xml | 32 +-
tests/flink-it/pom.xml | 8 +-
27 files changed, 3934 insertions(+), 9 deletions(-)
diff --git a/build/make-distribution.sh b/build/make-distribution.sh
index c4315ac53..7d42640be 100755
--- a/build/make-distribution.sh
+++ b/build/make-distribution.sh
@@ -186,7 +186,7 @@ function build_spark_client {
}
function build_flink_client {
- FLINK_VERSION=$("$MVN" help:evaluate -Dexpression=flink.version $@ 2>/dev/null \
+ FLINK_BINARY_VERSION=$("$MVN" help:evaluate -Dexpression=flink.binary.version $@ 2>/dev/null \
| grep -v "INFO" \
| grep -v "WARNING" \
| tail -n 1)
@@ -194,7 +194,6 @@ function build_flink_client {
| grep -v "INFO" \
| grep -v "WARNING" \
| tail -n 1)
- FLINK_BINARY_VERSION=${FLINK_VERSION%.*}
# Store the command as an array because $MVN variable might have spaces in it.
# Normal quoting tricks don't work.
diff --git a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/FlinkResultPartitionInfo.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/FlinkResultPartitionInfo.java
similarity index 100%
rename from client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/FlinkResultPartitionInfo.java
rename to client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/FlinkResultPartitionInfo.java
diff --git a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
similarity index 100%
rename from client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
rename to client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
diff --git a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleDescriptor.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleDescriptor.java
similarity index 100%
rename from client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleDescriptor.java
rename to client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleDescriptor.java
diff --git a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java
similarity index 100%
rename from client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java
rename to client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java
diff --git a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResource.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResource.java
similarity index 100%
rename from client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResource.java
rename to client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResource.java
diff --git a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/ShuffleResource.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/ShuffleResource.java
similarity index 100%
rename from client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/ShuffleResource.java
rename to client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/ShuffleResource.java
diff --git a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/ShuffleResourceDescriptor.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/ShuffleResourceDescriptor.java
similarity index 100%
rename from client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/ShuffleResourceDescriptor.java
rename to client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/ShuffleResourceDescriptor.java
diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/FlinkUtils.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/FlinkUtils.java
new file mode 100644
index 000000000..06cc38566
--- /dev/null
+++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/FlinkUtils.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.utils;
+
+import java.util.Map;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+
+import org.apache.celeborn.common.CelebornConf;
+
+public class FlinkUtils {
+
+ public static CelebornConf toCelebornConf(Configuration configuration) {
+ CelebornConf tmpCelebornConf = new CelebornConf();
+ Map<String, String> confMap = configuration.toMap();
+ for (Map.Entry<String, String> entry : confMap.entrySet()) {
+ String key = entry.getKey();
+ if (key.startsWith("celeborn.")) {
+ tmpCelebornConf.set(entry.getKey(), entry.getValue());
+ }
+ }
+
+ return tmpCelebornConf;
+ }
+
+ public static String toCelebornAppId(JobID jobID) {
+ return jobID.toString();
+ }
+
+ public static String toShuffleId(JobID jobID, IntermediateDataSetID dataSetID) {
+ return jobID.toString() + "-" + dataSetID.toString();
+ }
+
+ public static String toAttemptId(ExecutionAttemptID attemptID) {
+ return attemptID.toString();
+ }
+}
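For illustration, a minimal usage sketch of FlinkUtils.toCelebornConf above (hypothetical config values; only keys carrying the "celeborn." prefix are copied into the resulting CelebornConf):

    // A minimal sketch, assuming a Flink Configuration populated by the user.
    static CelebornConf toCelebornConfSketch() {
      Configuration flinkConf = new Configuration();
      flinkConf.setString("celeborn.master.endpoints", "localhost:9097"); // copied over
      flinkConf.setString("taskmanager.numberOfTaskSlots", "4");          // filtered out
      return FlinkUtils.toCelebornConf(flinkConf);
    }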
diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/ThreadUtils.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/ThreadUtils.java
new file mode 100644
index 000000000..58e23ddc0
--- /dev/null
+++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/ThreadUtils.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.utils;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.apache.flink.util.ExecutorUtils;
+import org.slf4j.Logger;
+
+public class ThreadUtils {
+
+ public static ThreadFactory createFactoryWithDefaultExceptionHandler(
+ final String executorServiceName, final Logger LOG) {
+ return new ThreadFactoryBuilder()
+ .setNameFormat(executorServiceName + "-%d")
+ .setDaemon(true)
+ .setUncaughtExceptionHandler(
+ (Thread t, Throwable e) ->
+ LOG.error(
+ "exception in serviceName: {}, thread: {}",
+ executorServiceName,
+ t.getName(),
+ e))
+ .build();
+ }
+
+ public static void shutdownExecutors(int timeoutSecs, ExecutorService executorService) {
+ ExecutorUtils.gracefulShutdown(timeoutSecs, TimeUnit.SECONDS, executorService);
+ }
+}
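A short usage sketch of the two helpers above; it mirrors how RemoteShuffleMaster (also added in this commit) wires the factory into its executor, with an illustrative service name:

    // Daemon threads whose uncaught exceptions are logged rather than silently lost.
    ScheduledThreadPoolExecutor executor =
        new ScheduledThreadPoolExecutor(
            1, ThreadUtils.createFactoryWithDefaultExceptionHandler("my-executor", LOG));
    executor.execute(() -> { throw new RuntimeException("boom"); }); // reaches the handler
    ThreadUtils.shutdownExecutors(5, executor); // graceful shutdown with a 5-second timeout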
diff --git a/client-flink/flink-1.15-shaded/pom.xml b/client-flink/flink-1.15-shaded/pom.xml
new file mode 100644
index 000000000..56dd05561
--- /dev/null
+++ b/client-flink/flink-1.15-shaded/pom.xml
@@ -0,0 +1,128 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+~ Licensed to the Apache Software Foundation (ASF) under one or more
+~ contributor license agreements. See the NOTICE file distributed with
+~ this work for additional information regarding copyright ownership.
+~ The ASF licenses this file to You under the Apache License, Version 2.0
+~ (the "License"); you may not use this file except in compliance with
+~ the License. You may obtain a copy of the License at
+~
+~ http://www.apache.org/licenses/LICENSE-2.0
+~
+~ Unless required by applicable law or agreed to in writing, software
+~ distributed under the License is distributed on an "AS IS" BASIS,
+~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+~ See the License for the specific language governing permissions and
+~ limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.celeborn</groupId>
+ <artifactId>celeborn-parent_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+
+ <artifactId>celeborn-client-flink-1.15-shaded_${scala.binary.version}</artifactId>
+ <packaging>jar</packaging>
+ <name>Celeborn Shaded Client for Flink 1.15</name>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.celeborn</groupId>
+ <artifactId>celeborn-client-flink-1.15</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ </dependencies>
+
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-shade-plugin</artifactId>
+ <configuration>
+ <relocations>
+ <relocation>
+ <pattern>com.google.protobuf</pattern>
+ <shadedPattern>${shading.prefix}.com.google.protobuf</shadedPattern>
+ </relocation>
+ <relocation>
+ <pattern>com.google.common</pattern>
+ <shadedPattern>${shading.prefix}.com.google.common</shadedPattern>
+ </relocation>
+ <relocation>
+ <pattern>io.netty</pattern>
+ <shadedPattern>${shading.prefix}.io.netty</shadedPattern>
+ </relocation>
+ <relocation>
+ <pattern>org.apache.commons</pattern>
+ <shadedPattern>${shading.prefix}.org.apache.commons</shadedPattern>
+ </relocation>
+ <relocation>
+ <pattern>org.roaringbitmap</pattern>
+ <shadedPattern>${shading.prefix}.org.roaringbitmap</shadedPattern>
+ </relocation>
+ </relocations>
+ <artifactSet>
+ <includes>
+ <include>org.apache.celeborn:*</include>
+ <include>com.google.protobuf:protobuf-java</include>
+ <include>com.google.guava:guava</include>
+ <include>io.netty:*</include>
+ <include>org.apache.commons:commons-lang3</include>
+ <include>org.roaringbitmap:RoaringBitmap</include>
+ </includes>
+ </artifactSet>
+ <filters>
+ <filter>
+ <artifact>*:*</artifact>
+ <excludes>
+ <exclude>META-INF/*.SF</exclude>
+ <exclude>META-INF/*.DSA</exclude>
+ <exclude>META-INF/*.RSA</exclude>
+ <exclude>**/log4j.properties</exclude>
+ </excludes>
+ </filter>
+ </filters>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-antrun-plugin</artifactId>
+ <version>${maven.plugin.antrun.version}</version>
+ <executions>
+ <execution>
+ <id>rename-native-library</id>
+ <goals>
+ <goal>run</goal>
+ </goals>
+ <phase>package</phase>
+ <configuration>
+ <target>
+ <echo message="unpacking netty jar"></echo>
+ <unzip dest="${project.build.directory}/unpacked/" src="${project.build.directory}/${artifactId}-${version}.jar"></unzip>
+ <echo message="renaming native epoll library"></echo>
+ <move includeemptydirs="false" todir="${project.build.directory}/unpacked/META-INF/native">
+ <fileset dir="${project.build.directory}/unpacked/META-INF/native"></fileset>
+ <mapper from="libnetty_transport_native_epoll_x86_64.so" to="liborg_apache_celeborn_shaded_netty_transport_native_epoll_x86_64.so" type="glob"></mapper>
+ </move>
+ <move includeemptydirs="false" todir="${project.build.directory}/unpacked/META-INF/native">
+ <fileset dir="${project.build.directory}/unpacked/META-INF/native"></fileset>
+ <mapper from="libnetty_transport_native_epoll_aarch_64.so" to="liborg_apache_celeborn_shaded_netty_transport_native_epoll_aarch_64.so.so" type="glob"></mapper>
+ </move>
+ <echo message="deleting native kqueue library"></echo>
+ <delete file="${project.build.directory}/unpacked/META-INF/native/libnetty_transport_native_kqueue_x86_64.jnilib"></delete>
+ <delete file="${project.build.directory}/unpacked/META-INF/native/libnetty_transport_native_kqueue_aarch_64.jnilib"></delete>
+ <delete file="${project.build.directory}/unpacked/META-INF/native/libnetty_resolver_dns_native_macos_aarch_64.jnilib"></delete>
+ <delete file="${project.build.directory}/unpacked/META-INF/native/libnetty_resolver_dns_native_macos_x86_64.jnilib"></delete>
+ <echo message="repackaging netty jar"></echo>
+ <jar basedir="${project.build.directory}/unpacked" destfile="${project.build.directory}/${artifactId}-${version}.jar"></jar>
+ </target>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+</project>
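Background for the antrun step above: shaded Netty derives its native library name from the relocated package prefix, so the bundled epoll libraries must be renamed to match. A hypothetical helper illustrating the rename rule the mapper elements implement (method name and signature are illustrative, not part of the build):

    // "libnetty_transport_native_epoll_x86_64.so" ->
    // "liborg_apache_celeborn_shaded_netty_transport_native_epoll_x86_64.so"
    static String shadedNativeLibName(String original, String shadingPrefix) {
      String prefix = shadingPrefix.replace('.', '_'); // e.g. "org_apache_celeborn_shaded"
      return original.replaceFirst("^libnetty_", "lib" + prefix + "_netty_");
    }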
diff --git a/client-flink/flink-1.15/pom.xml b/client-flink/flink-1.15/pom.xml
new file mode 100644
index 000000000..3f79101bc
--- /dev/null
+++ b/client-flink/flink-1.15/pom.xml
@@ -0,0 +1,74 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+~ Licensed to the Apache Software Foundation (ASF) under one or more
+~ contributor license agreements. See the NOTICE file distributed with
+~ this work for additional information regarding copyright ownership.
+~ The ASF licenses this file to You under the Apache License, Version 2.0
+~ (the "License"); you may not use this file except in compliance with
+~ the License. You may obtain a copy of the License at
+~
+~ http://www.apache.org/licenses/LICENSE-2.0
+~
+~ Unless required by applicable law or agreed to in writing, software
+~ distributed under the License is distributed on an "AS IS" BASIS,
+~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+~ See the License for the specific language governing permissions and
+~ limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.celeborn</groupId>
+ <artifactId>celeborn-parent_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+
+ <artifactId>celeborn-client-flink-1.15</artifactId>
+ <packaging>jar</packaging>
+ <name>Celeborn Client for Flink 1.15</name>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.celeborn</groupId>
+ <artifactId>celeborn-common_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.celeborn</groupId>
+ <artifactId>celeborn-client_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.celeborn</groupId>
+ <artifactId>celeborn-client-flink-common_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-runtime</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <!-- Test dependencies -->
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-core</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.logging.log4j</groupId>
+ <artifactId>log4j-slf4j-impl</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.logging.log4j</groupId>
+ <artifactId>log4j-1.2-api</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+</project>
diff --git a/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleEnvironment.java b/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleEnvironment.java
new file mode 100644
index 000000000..7ad50963e
--- /dev/null
+++ b/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleEnvironment.java
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkNotNull;
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkState;
+import static org.apache.flink.runtime.io.network.metrics.NettyShuffleMetricFactory.METRIC_GROUP_INPUT;
+import static org.apache.flink.runtime.io.network.metrics.NettyShuffleMetricFactory.METRIC_GROUP_OUTPUT;
+import static org.apache.flink.runtime.io.network.metrics.NettyShuffleMetricFactory.createShuffleIOOwnerMetricGroup;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
+import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.PartitionInfo;
+import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.PartitionProducerStateProvider;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
+import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
+import org.apache.flink.runtime.shuffle.ShuffleIOOwnerContext;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.CelebornConf;
+
+/**
+ * The implementation of {@link ShuffleEnvironment} based on the remote shuffle service, providing
+ * the shuffle environment on the Flink TaskManager (TM) side.
+ */
+public class RemoteShuffleEnvironment
+ implements ShuffleEnvironment<ResultPartitionWriter, IndexedInputGate> {
+
+ private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleEnvironment.class);
+
+ /** Network buffer pool for shuffle read and shuffle write. */
+ private final NetworkBufferPool networkBufferPool;
+
+ /** A trivial {@link ResultPartitionManager}. */
+ private final ResultPartitionManager resultPartitionManager;
+
+ /** Factory class to create {@link RemoteShuffleResultPartition}. */
+ private final RemoteShuffleResultPartitionFactory resultPartitionFactory;
+
+ /** Factory class to create {@link RemoteShuffleInputGate}. */
+ private final RemoteShuffleInputGateFactory inputGateFactory;
+
+ /** Whether the shuffle environment is closed. */
+ private boolean isClosed;
+
+ private final Object lock = new Object();
+
+ private final CelebornConf conf;
+
+ /**
+ * @param networkBufferPool Network buffer pool for shuffle read and shuffle write.
+ * @param resultPartitionManager A trivial {@link ResultPartitionManager}.
+ * @param resultPartitionFactory Factory class to create {@link RemoteShuffleResultPartition}.
+ * @param inputGateFactory Factory class to create {@link RemoteShuffleInputGate}.
+ */
+ public RemoteShuffleEnvironment(
+ NetworkBufferPool networkBufferPool,
+ ResultPartitionManager resultPartitionManager,
+ RemoteShuffleResultPartitionFactory resultPartitionFactory,
+ RemoteShuffleInputGateFactory inputGateFactory,
+ CelebornConf conf) {
+
+ this.networkBufferPool = networkBufferPool;
+ this.resultPartitionManager = resultPartitionManager;
+ this.resultPartitionFactory = resultPartitionFactory;
+ this.inputGateFactory = inputGateFactory;
+ this.conf = conf;
+ this.isClosed = false;
+ }
+
+ @Override
+ public List<ResultPartitionWriter> createResultPartitionWriters(
+ ShuffleIOOwnerContext ownerContext,
+ List<ResultPartitionDeploymentDescriptor> resultPartitionDeploymentDescriptors) {
+
+ synchronized (lock) {
+ checkState(!isClosed, "The RemoteShuffleEnvironment has already been shut down.");
+
+ ResultPartitionWriter[] resultPartitions =
+ new ResultPartitionWriter[resultPartitionDeploymentDescriptors.size()];
+ for (int index = 0; index < resultPartitions.length; index++) {
+ resultPartitions[index] =
+ resultPartitionFactory.create(
+ ownerContext.getOwnerName(), index,
+ resultPartitionDeploymentDescriptors.get(index), conf);
+ }
+ return Arrays.asList(resultPartitions);
+ }
+ }
+
+ @Override
+ public List<IndexedInputGate> createInputGates(
+ ShuffleIOOwnerContext ownerContext,
+ PartitionProducerStateProvider producerStateProvider,
+ List<InputGateDeploymentDescriptor> inputGateDescriptors) {
+ synchronized (lock) {
+ checkState(!isClosed, "The RemoteShuffleEnvironment has already been shut down.");
+
+ IndexedInputGate[] inputGates = new IndexedInputGate[inputGateDescriptors.size()];
+ for (int gateIndex = 0; gateIndex < inputGates.length; gateIndex++) {
+ InputGateDeploymentDescriptor igdd = inputGateDescriptors.get(gateIndex);
+ RemoteShuffleInputGate inputGate =
+ inputGateFactory.create(ownerContext.getOwnerName(), gateIndex, igdd);
+ inputGates[gateIndex] = inputGate;
+ }
+ return Arrays.asList(inputGates);
+ }
+ }
+
+ @Override
+ public void close() {
+ LOG.info("Close RemoteShuffleEnvironment.");
+ synchronized (lock) {
+ try {
+ networkBufferPool.destroyAllBufferPools();
+ } catch (Throwable t) {
+ LOG.error("Close RemoteShuffleEnvironment failure.", t);
+ }
+ try {
+ resultPartitionManager.shutdown();
+ } catch (Throwable t) {
+ LOG.error("Close RemoteShuffleEnvironment failure.", t);
+ }
+ try {
+ networkBufferPool.destroy();
+ } catch (Throwable t) {
+ LOG.error("Close RemoteShuffleEnvironment failure.", t);
+ }
+ isClosed = true;
+ }
+ }
+
+ @Override
+ public int start() throws IOException {
+ synchronized (lock) {
+ checkState(!isClosed, "The RemoteShuffleEnvironment has already been shut down.");
+ LOG.info("Starting the network environment and its components.");
+ // trivial value.
+ return 1;
+ }
+ }
+
+ @Override
+ public boolean updatePartitionInfo(ExecutionAttemptID consumerID, PartitionInfo partitionInfo) {
+ throw new FlinkRuntimeException("Not implemented yet.");
+ }
+
+ @Override
+ public ShuffleIOOwnerContext createShuffleIOOwnerContext(
+ String ownerName, ExecutionAttemptID executionAttemptID, MetricGroup parentGroup) {
+ MetricGroup nettyGroup = createShuffleIOOwnerMetricGroup(checkNotNull(parentGroup));
+ return new ShuffleIOOwnerContext(
+ checkNotNull(ownerName),
+ checkNotNull(executionAttemptID),
+ parentGroup,
+ nettyGroup.addGroup(METRIC_GROUP_OUTPUT),
+ nettyGroup.addGroup(METRIC_GROUP_INPUT));
+ }
+
+ @Override
+ public void releasePartitionsLocally(Collection<ResultPartitionID> partitionIds) {
+ throw new FlinkRuntimeException("Not implemented yet.");
+ }
+
+ @Override
+ public Collection<ResultPartitionID> getPartitionsOccupyingLocalResources() {
+ return new ArrayList<>();
+ }
+
+ @VisibleForTesting
+ NetworkBufferPool getNetworkBufferPool() {
+ return networkBufferPool;
+ }
+
+ @VisibleForTesting
+ RemoteShuffleResultPartitionFactory getResultPartitionFactory() {
+ return resultPartitionFactory;
+ }
+}
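For orientation, a sketch of this environment's lifecycle as a Flink TaskManager would drive it; all constructor arguments are assumed to be supplied by the surrounding shuffle service factory (see RemoteShuffleServiceFactory in this commit):

    RemoteShuffleEnvironment environment =
        new RemoteShuffleEnvironment(
            networkBufferPool, resultPartitionManager,
            resultPartitionFactory, inputGateFactory, celebornConf);
    environment.start(); // trivial here; returns 1
    List<ResultPartitionWriter> writers =
        environment.createResultPartitionWriters(ownerContext, partitionDescriptors);
    List<IndexedInputGate> gates =
        environment.createInputGates(ownerContext, producerStateProvider, gateDescriptors);
    // ... the task produces and consumes shuffle data ...
    environment.close(); // destroys buffer pools and shuts down the partition manager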
diff --git a/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleInputGate.java b/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleInputGate.java
new file mode 100644
index 000000000..33581d791
--- /dev/null
+++ b/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleInputGate.java
@@ -0,0 +1,728 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.identity.UserIdentifier;
+import org.apache.celeborn.plugin.flink.buffer.BufferPacker;
+import org.apache.celeborn.plugin.flink.buffer.TransferBufferPool;
+import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
+import org.apache.celeborn.plugin.flink.utils.BufferUtils;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentProvider;
+import org.apache.flink.metrics.SimpleCounter;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
+import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
+import org.apache.flink.runtime.clusterframework.types.ResourceID;
+import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
+import org.apache.flink.runtime.deployment.SubpartitionIndexRange;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.event.TaskEvent;
+import org.apache.flink.runtime.io.network.ConnectionID;
+import org.apache.flink.runtime.io.network.LocalConnectionManager;
+import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
+import org.apache.flink.runtime.io.network.api.EndOfData;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferDecompressor;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
+import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
+import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
+import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
+import org.apache.flink.runtime.throughput.ThroughputCalculator;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.util.CloseableIterator;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.clock.SystemClock;
+import org.apache.flink.util.function.SupplierWithException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.net.InetAddress;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkState;
+
+/** An {@link IndexedInputGate} which ingests data from remote shuffle workers. */
+public class RemoteShuffleInputGate extends IndexedInputGate {
+
+ private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleInputGate.class);
+
+ /** Lock to protect {@link #receivedBuffers}, {@link #cause} and {@link #closed}. */
+ private final Object lock = new Object();
+
+ /** Name of the corresponding computing task. */
+ private final String taskName;
+
+ /** Index of the gate of the corresponding computing task. */
+ private final int gateIndex;
+
+ /** Deployment descriptor for a single input gate instance. */
+ private final InputGateDeploymentDescriptor gateDescriptor;
+
+ /** Buffer pool provider. */
+ private final SupplierWithException<BufferPool, IOException> bufferPoolFactory;
+
+ /** Flink buffer pools to allocate network memory. */
+ private BufferPool bufferPool;
+
+ /** Buffer pool used by the transfer layer. */
+ private final TransferBufferPool transferBufferPool =
+ new TransferBufferPool(Collections.emptySet());
+
+ private final List<RemoteBufferStreamReader> bufferReaders = new ArrayList<>();
+ private final List<InputChannelInfo> channelsInfo;
+ /** Map from channel index to shuffle client index. */
+ private final int[] clientIndexMap;
+
+ /** Map from shuffle client index to channel index. */
+ private final int[] channelIndexMap;
+
+ /** The number of subpartitions that have not been consumed, per channel. */
+ private final int[] numSubPartitionsHasNotConsumed;
+
+ /** The overall number of subpartitions that have not been consumed. */
+ private long numUnconsumedSubpartitions;
+
+ /** Received buffers from remote shuffle workers, consumed by the upper computing task. */
+ private final Queue<Pair<Buffer, InputChannelInfo>> receivedBuffers = new LinkedList<>();
+
+ /** {@link Throwable} set when reading fails. */
+ private Throwable cause;
+
+ /** Whether this remote input gate has been closed or not. */
+ private boolean closed;
+
+ /** Whether we have opened all initial channels or not. */
+ private boolean initialChannelsOpened;
+
+ /** Number of pending {@link EndOfData} events to be received. */
+ private long pendingEndOfDataEvents;
+ /** Max concurrent reader count. */
+ private int numConcurrentReading = Integer.MAX_VALUE;
+ /** Keep compatibility with streaming mode. */
+ private boolean shouldDrainOnEndOfData = true;
+
+ /** Data decompressor. */
+ private final BufferDecompressor bufferDecompressor;
+
+ private FlinkShuffleClientImpl shuffleClient;
+
+ public RemoteShuffleInputGate(
+ CelebornConf celebornConf,
+ String taskName,
+ int gateIndex,
+ InputGateDeploymentDescriptor gateDescriptor,
+ SupplierWithException<BufferPool, IOException> bufferPoolFactory,
+ BufferDecompressor bufferDecompressor) {
+
+ this.taskName = taskName;
+ this.gateIndex = gateIndex;
+ this.gateDescriptor = gateDescriptor;
+ this.bufferPoolFactory = bufferPoolFactory;
+
+ int numChannels = gateDescriptor.getShuffleDescriptors().length;
+ this.clientIndexMap = new int[numChannels];
+ this.channelIndexMap = new int[numChannels];
+ this.numSubPartitionsHasNotConsumed = new int[numChannels];
+ this.bufferDecompressor = bufferDecompressor;
+
+ RemoteShuffleDescriptor remoteShuffleDescriptor =
+ (RemoteShuffleDescriptor) gateDescriptor.getShuffleDescriptors()[0];
+ this.shuffleClient =
+ FlinkShuffleClientImpl.get(
+ remoteShuffleDescriptor.getShuffleResource().getRssMetaServiceHost(),
+ remoteShuffleDescriptor.getShuffleResource().getRssMetaServicePort(),
+ celebornConf,
+ new UserIdentifier("default", "default"));
+
+ this.numUnconsumedSubpartitions = initShuffleReadClients();
+ this.pendingEndOfDataEvents = numUnconsumedSubpartitions;
+ this.channelsInfo = createChannelInfos();
+ }
+
+ private long initShuffleReadClients() {
+ int startSubIdx = gateDescriptor.getConsumedSubpartitionIndex();
+ int endSubIdx = gateDescriptor.getConsumedSubpartitionIndex();
+ int numSubpartitionsPerChannel = endSubIdx - startSubIdx + 1;
+ long numUnconsumedSubpartitions = 0;
+
+ // left element is index
+ List<Pair<Integer, ShuffleDescriptor>> descriptors =
+ IntStream.range(0, gateDescriptor.getShuffleDescriptors().length)
+ .mapToObj(i -> Pair.of(i, gateDescriptor.getShuffleDescriptors()[i]))
+ .collect(Collectors.toList());
+
+ int clientIndex = 0;
+ for (Pair<Integer, ShuffleDescriptor> descriptor : descriptors) {
+ RemoteShuffleDescriptor remoteDescriptor = (RemoteShuffleDescriptor) descriptor.getRight();
+ ShuffleResourceDescriptor shuffleDescriptor =
+ remoteDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor();
+
+ LOG.debug("create shuffle reader for descriptor {}", shuffleDescriptor);
+ String applicationId = remoteDescriptor.getCelebornAppId();
+
+ RemoteBufferStreamReader reader =
+ new RemoteBufferStreamReader(
+ shuffleClient,
+ shuffleDescriptor,
+ applicationId,
+ startSubIdx,
+ endSubIdx,
+ transferBufferPool,
+ getDataListener(descriptor.getLeft()),
+ getFailureListener(remoteDescriptor.getResultPartitionID()));
+
+ bufferReaders.add(reader);
+ numSubPartitionsHasNotConsumed[descriptor.getLeft()] = numSubpartitionsPerChannel;
+ numUnconsumedSubpartitions += numSubpartitionsPerChannel;
+ clientIndexMap[descriptor.getLeft()] = clientIndex;
+ channelIndexMap[clientIndex] = descriptor.getLeft();
+ ++clientIndex;
+ }
+ return numUnconsumedSubpartitions;
+ }
+
+ /** Setup gate and build network connections. */
+ @Override
+ public void setup() throws IOException {
+ long startTime = System.nanoTime();
+
+ bufferPool = bufferPoolFactory.get();
+ BufferUtils.reserveNumRequiredBuffers(bufferPool, 16);
+
+ tryRequestBuffers();
+ // Complete the availability future even though the handshake has not been fired yet, so the
+ // fetcher can 'pollNext' and fire the handshake to the remote side. This avoids bookkeeping
+ // remote reading resources before the task starts processing data from the input gate.
+ availabilityHelper.getUnavailableToResetAvailable().complete(null);
+ LOG.info("Set up read gate by {} ms.", (System.nanoTime() - startTime) / 1000_000);
+ }
+
+ /** Index of the gate of the corresponding computing task. */
+ @Override
+ public int getGateIndex() {
+ return gateIndex;
+ }
+
+ /** Get number of input channels. A channel is a data flow from one shuffle worker. */
+ @Override
+ public int getNumberOfInputChannels() {
+ return bufferReaders.size();
+ }
+
+ /** Whether reading is finished -- all channels are finished and cached buffers are drained. */
+ @Override
+ public boolean isFinished() {
+ synchronized (lock) {
+ return allReadersEOF() && receivedBuffers.isEmpty();
+ }
+ }
+
+ @Override
+ public Optional<BufferOrEvent> getNext() {
+ throw new UnsupportedOperationException("Not implemented (DataSet API is not supported).");
+ }
+
+ /** Poll a received {@link BufferOrEvent}. */
+ @Override
+ public Optional<BufferOrEvent> pollNext() throws IOException {
+ if (!initialChannelsOpened) {
+ tryOpenSomeChannels();
+ initialChannelsOpened = true;
+ // DO NOT return here; the 'getReceived' method manipulates 'availabilityHelper'.
+ }
+
+ Pair<Buffer, InputChannelInfo> pair = getReceived();
+ Optional<BufferOrEvent> bufferOrEvent = Optional.empty();
+ LOG.debug("pollNext called with pair null {}", pair == null);
+ while (pair != null) {
+ Buffer buffer = pair.getLeft();
+ InputChannelInfo channelInfo = pair.getRight();
+ LOG.debug("get buffer {} on channel {}", buffer, channelInfo);
+ if (buffer.isBuffer()) {
+ bufferOrEvent = transformBuffer(buffer, channelInfo);
+ } else {
+ bufferOrEvent = transformEvent(buffer, channelInfo);
+ LOG.info("recevied event: " + bufferOrEvent.get().getEvent().getClass().getName());
+ }
+
+ if (bufferOrEvent.isPresent()) {
+ break;
+ }
+ pair = getReceived();
+ }
+
+ tryRequestBuffers();
+ return bufferOrEvent;
+ }
+
+ private Buffer decompressBufferIfNeeded(Buffer buffer) throws IOException {
+ if (buffer.isCompressed()) {
+ try {
+ checkState(bufferDecompressor != null, "Buffer decompressor not set.");
+ return bufferDecompressor.decompressToIntermediateBuffer(buffer);
+ } catch (Throwable t) {
+ throw new IOException("Decompress failure", t);
+ } finally {
+ buffer.recycleBuffer();
+ }
+ }
+ return buffer;
+ }
+
+ /** Close all reading channels inside this {@link RemoteShuffleInputGate}. */
+ @Override
+ public void close() throws Exception {
+ List<Buffer> buffersToRecycle;
+ Throwable closeException = null;
+ // Do not check closed flag, thus to allow calling this method from both task thread and
+ // cancel thread.
+ for (RemoteBufferStreamReader shuffleReadClient : bufferReaders) {
+ try {
+ shuffleReadClient.close();
+ } catch (Throwable throwable) {
+ closeException = closeException == null ? throwable : closeException;
+ LOG.error("Failed to close shuffle read client.", throwable);
+ }
+ }
+ synchronized (lock) {
+ buffersToRecycle = receivedBuffers.stream().map(Pair::getLeft).collect(Collectors.toList());
+ receivedBuffers.clear();
+ closed = true;
+ }
+
+ try {
+ buffersToRecycle.forEach(Buffer::recycleBuffer);
+ } catch (Throwable throwable) {
+ closeException = closeException == null ? throwable : closeException;
+ LOG.error("Failed to recycle buffers.", throwable);
+ }
+
+ try {
+ transferBufferPool.destroy();
+ } catch (Throwable throwable) {
+ closeException = closeException == null ? throwable : closeException;
+ LOG.error("Failed to close transfer buffer pool.", throwable);
+ }
+
+ try {
+ if (bufferPool != null) {
+ bufferPool.lazyDestroy();
+ }
+ } catch (Throwable throwable) {
+ closeException = closeException == null ? throwable : closeException;
+ LOG.error("Failed to close local buffer pool.", throwable);
+ }
+
+ if (closeException != null) {
+ ExceptionUtils.rethrowException(closeException);
+ }
+ }
+
+ /** Get {@link InputChannelInfo}s of this {@link RemoteShuffleInputGate}. */
+ @Override
+ public List<InputChannelInfo> getChannelInfos() {
+ return channelsInfo;
+ }
+
+ /** Each one corresponds to a reading channel. */
+ public List<RemoteBufferStreamReader> getBufferReaders() {
+ return bufferReaders;
+ }
+
+ private List<InputChannelInfo> createChannelInfos() {
+ return IntStream.range(0, gateDescriptor.getShuffleDescriptors().length)
+ .mapToObj(i -> new InputChannelInfo(gateIndex, i))
+ .collect(Collectors.toList());
+ }
+
+ /** Try to open more readers, up to {@link #numConcurrentReading} concurrent ones. */
+ private void tryOpenSomeChannels() throws IOException {
+ List<RemoteBufferStreamReader> clientsToOpen = new ArrayList<>();
+
+ synchronized (lock) {
+ if (closed) {
+ throw new IOException("Input gate already closed.");
+ }
+
+ LOG.debug("Try open some partition readers.");
+ int numOnGoing = 0;
+ for (int i = 0; i < bufferReaders.size(); i++) {
+ RemoteBufferStreamReader bufferStreamReader = bufferReaders.get(i);
+ LOG.debug(
+ "Trying reader: {}, isOpened={}, numSubPartitionsHasNotConsumed={}.",
+ bufferStreamReader,
+ bufferStreamReader.isOpened(),
+ numSubPartitionsHasNotConsumed[channelIndexMap[i]]);
+ if (numOnGoing >= numConcurrentReading) {
+ break;
+ }
+
+ if (bufferStreamReader.isOpened()
+ && numSubPartitionsHasNotConsumed[channelIndexMap[i]] > 0) {
+ numOnGoing++;
+ continue;
+ }
+
+ if (!bufferStreamReader.isOpened()) {
+ clientsToOpen.add(bufferStreamReader);
+ numOnGoing++;
+ }
+ }
+ }
+
+ for (RemoteBufferStreamReader reader : clientsToOpen) {
+ reader.open(0);
+ }
+ }
+
+ private void tryRequestBuffers() {
+ checkState(bufferPool != null, "Not initialized yet.");
+
+ Buffer buffer;
+ List<ByteBuf> buffers = new ArrayList<>();
+ while ((buffer = bufferPool.requestBuffer()) != null) {
+ buffers.add(buffer.asByteBuf());
+ }
+
+ if (!buffers.isEmpty()) {
+ transferBufferPool.addBuffers(buffers);
+ }
+ }
+
+ private void onBuffer(Buffer buffer, int channelIdx) {
+ synchronized (lock) {
+ if (closed || cause != null) {
+ buffer.recycleBuffer();
+ throw new IllegalStateException("Input gate already closed or failed.");
+ }
+
+ boolean needRecycle = true;
+ try {
+ boolean wasEmpty = receivedBuffers.isEmpty();
+ InputChannelInfo channelInfo = channelsInfo.get(channelIdx);
+ checkState(channelInfo.getInputChannelIdx() == channelIdx, "Illegal channel index.");
+ LOG.debug("ReceivedBuffers is adding buffer {} on {}", buffer, channelInfo);
+ receivedBuffers.add(Pair.of(buffer, channelInfo));
+ needRecycle = false;
+ if (wasEmpty) {
+ availabilityHelper.getUnavailableToResetAvailable().complete(null);
+ }
+ } catch (Throwable throwable) {
+ if (needRecycle) {
+ buffer.recycleBuffer();
+ }
+ throw throwable;
+ }
+ }
+ }
+
+ private Consumer<ByteBuf> getDataListener(int channelIdx) {
+ return byteBuf -> {
+ Queue<Buffer> unpackedBuffers = null;
+ try {
+ unpackedBuffers = BufferPacker.unpack(byteBuf);
+ while (!unpackedBuffers.isEmpty()) {
+ onBuffer(unpackedBuffers.poll(), channelIdx);
+ }
+ } catch (Throwable throwable) {
+ synchronized (lock) {
+ cause = cause == null ? throwable : cause;
+ availabilityHelper.getUnavailableToResetAvailable().complete(null);
+ }
+
+ if (unpackedBuffers != null) {
+ unpackedBuffers.forEach(Buffer::recycleBuffer);
+ }
+ LOG.error("Failed to process the received buffer.", throwable);
+ }
+ };
+ }
+
+ private Consumer<Throwable> getFailureListener(ResultPartitionID rpID) {
+ return throwable -> {
+ synchronized (lock) {
+ if (cause != null) {
+ return;
+ }
+ Class<?> clazz = PartitionNotFoundException.class;
+ if (throwable.getMessage() != null && throwable.getMessage().contains(clazz.getName())) {
+ cause = new PartitionNotFoundException(rpID);
+ } else {
+ cause = throwable;
+ }
+ availabilityHelper.getUnavailableToResetAvailable().complete(null);
+ }
+ };
+ }
+
+ private Pair<Buffer, InputChannelInfo> getReceived() throws IOException {
+ synchronized (lock) {
+ healthCheck();
+ if (!receivedBuffers.isEmpty()) {
+ return receivedBuffers.poll();
+ } else {
+ if (!allReadersEOF()) {
+ availabilityHelper.resetUnavailable();
+ }
+ return null;
+ }
+ }
+ }
+
+ private void healthCheck() throws IOException {
+ if (closed) {
+ throw new IOException("Input gate already closed.");
+ }
+ if (cause != null) {
+ if (cause instanceof IOException) {
+ throw (IOException) cause;
+ } else {
+ throw new IOException(cause);
+ }
+ }
+ }
+
+ private boolean allReadersEOF() {
+ return numUnconsumedSubpartitions <= 0;
+ }
+
+ private Optional<BufferOrEvent> transformBuffer(Buffer buf, InputChannelInfo info)
+ throws IOException {
+ return Optional.of(
+ new BufferOrEvent(decompressBufferIfNeeded(buf), info, !isFinished(), false));
+ }
+
+ private Optional<BufferOrEvent> transformEvent(Buffer buffer, InputChannelInfo channelInfo)
+ throws IOException {
+ final AbstractEvent event;
+ try {
+ event = EventSerializer.fromBuffer(buffer, getClass().getClassLoader());
+ } catch (Throwable t) {
+ throw new IOException("Deserialize failure.", t);
+ } finally {
+ buffer.recycleBuffer();
+ }
+ if (event.getClass() == EndOfPartitionEvent.class) {
+ checkState(
+ numSubPartitionsHasNotConsumed[channelInfo.getInputChannelIdx()] > 0,
+ "BUG -- EndOfPartitionEvent received repeatedly.");
+ numSubPartitionsHasNotConsumed[channelInfo.getInputChannelIdx()]--;
+ numUnconsumedSubpartitions--;
+ // not the real end.
+ if (numSubPartitionsHasNotConsumed[channelInfo.getInputChannelIdx()] != 0) {
+ return Optional.empty();
+ } else {
+ // the real end.
+ bufferReaders.get(clientIndexMap[channelInfo.getInputChannelIdx()]).close();
+ // tryOpenSomeChannels();
+ if (allReadersEOF()) {
+ availabilityHelper.getUnavailableToResetAvailable().complete(null);
+ }
+ }
+ } else if (event.getClass() == EndOfData.class) {
+ checkState(pendingEndOfDataEvents > 0, "Too many EndOfData event.");
+ --pendingEndOfDataEvents;
+ }
+
+ return Optional.of(
+ new BufferOrEvent(
+ event,
+ buffer.getDataType().hasPriority(),
+ channelInfo,
+ !isFinished(),
+ buffer.getSize(),
+ false));
+ }
+
+ @Override
+ public void requestPartitions() {
+ // do-nothing
+ }
+
+ @Override
+ public void checkpointStarted(CheckpointBarrier barrier) {
+ // do-nothing.
+ }
+
+ @Override
+ public void checkpointStopped(long cancelledCheckpointId) {
+ // do-nothing.
+ }
+
+ @Override
+ public void triggerDebloating() {
+ // do-nothing.
+ }
+
+ @Override
+ public List<InputChannelInfo> getUnfinishedChannels() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public EndOfDataStatus hasReceivedEndOfData() {
+ if (pendingEndOfDataEvents > 0) {
+ return EndOfDataStatus.NOT_END_OF_DATA;
+ } else if (shouldDrainOnEndOfData) {
+ return EndOfDataStatus.DRAINED;
+ } else {
+ return EndOfDataStatus.STOPPED;
+ }
+ }
+
+ @Override
+ public void finishReadRecoveredState() {
+ // do-nothing.
+ }
+
+ @Override
+ public InputChannel getChannel(int channelIndex) {
+ return new FakedRemoteInputChannel(channelIndex);
+ }
+
+ @Override
+ public void sendTaskEvent(TaskEvent event) {
+ throw new FlinkRuntimeException("Method should not be called.");
+ }
+
+ @Override
+ public void resumeConsumption(InputChannelInfo channelInfo) {
+ throw new FlinkRuntimeException("Method should not be called.");
+ }
+
+ @Override
+ public void acknowledgeAllRecordsProcessed(InputChannelInfo inputChannelInfo) {}
+
+ @Override
+ public CompletableFuture<Void> getStateConsumedFuture() {
+ return CompletableFuture.completedFuture(null);
+ }
+
+ @Override
+ public String toString() {
+ return String.format(
+ "ReadGate [owning task: %s, gate index: %d, descriptor: %s]",
+ taskName, gateIndex, gateDescriptor.toString());
+ }
+
+ /** Accommodation for the incompleteness of Flink pluggable shuffle service. */
+ private class FakedRemoteInputChannel extends RemoteInputChannel {
+ FakedRemoteInputChannel(int channelIndex) {
+ super(
+ new SingleInputGate(
+ "",
+ gateIndex,
+ new IntermediateDataSetID(),
+ ResultPartitionType.BLOCKING,
+ new SubpartitionIndexRange(0, 0),
+ 1,
+ (a, b, c) -> {},
+ () -> null,
+ null,
+ new FakedMemorySegmentProvider(),
+ 0,
+ new ThroughputCalculator(SystemClock.getInstance()),
+ null),
+ channelIndex,
+ new ResultPartitionID(),
+ 0,
+ new ConnectionID(
+ new TaskManagerLocation(
+ ResourceID.generate(), InetAddress.getLoopbackAddress(), 1),
+ 0),
+ new LocalConnectionManager(),
+ 0,
+ 0,
+ 0,
+ new SimpleCounter(),
+ new SimpleCounter(),
+ new FakedChannelStateWriter());
+ }
+ }
+
+ /** Accommodation for the incompleteness of Flink pluggable shuffle service. */
+ private static class FakedMemorySegmentProvider implements MemorySegmentProvider {
+
+ @Override
+ public Collection<MemorySegment> requestUnpooledMemorySegments(int i) throws IOException {
+ return null;
+ }
+
+ @Override
+ public void recycleUnpooledMemorySegments(Collection<MemorySegment> collection)
+ throws IOException {}
+ }
+
+ /** Accommodation for the incompleteness of Flink pluggable shuffle service. */
+ private static class FakedChannelStateWriter implements ChannelStateWriter {
+
+ @Override
+ public void start(long cpId, CheckpointOptions checkpointOptions) {}
+
+ @Override
+ public void addInputData(
+ long cpId, InputChannelInfo info, int startSeqNum, CloseableIterator<Buffer> data) {}
+
+ @Override
+ public void addOutputData(
+ long cpId, ResultSubpartitionInfo info, int startSeqNum, Buffer... data) {}
+
+ @Override
+ public void finishInput(long checkpointId) {}
+
+ @Override
+ public void finishOutput(long checkpointId) {}
+
+ @Override
+ public void abort(long checkpointId, Throwable cause, boolean cleanup) {}
+
+ @Override
+ public ChannelStateWriteResult getAndRemoveWriteResult(long checkpointId) {
+ return null;
+ }
+
+ @Override
+ public void close() {}
+ }
+}
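A sketch of how a task thread is expected to drain this gate; the process call is a placeholder for record deserialization, not an API from this commit:

    gate.setup();
    while (!gate.isFinished()) {
      Optional<BufferOrEvent> polled = gate.pollNext();
      if (polled.isPresent() && polled.get().isBuffer()) {
        process(polled.get().getBuffer()); // placeholder
      } else if (!polled.isPresent()) {
        gate.getAvailableFuture().get(); // block until more buffers arrive
      }
    }
    gate.close();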
diff --git a/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleInputGateFactory.java b/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleInputGateFactory.java
new file mode 100644
index 000000000..a3485721f
--- /dev/null
+++ b/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleInputGateFactory.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import java.io.IOException;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
+import org.apache.flink.runtime.io.network.buffer.BufferDecompressor;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.BufferPoolFactory;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.util.function.SupplierWithException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.plugin.flink.config.PluginConf;
+import org.apache.celeborn.plugin.flink.utils.Utils;
+
+/** Factory class to create {@link RemoteShuffleInputGate}. */
+public class RemoteShuffleInputGateFactory {
+
+ public static final int MIN_BUFFERS_PER_GATE = 16;
+
+ private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleInputGateFactory.class);
+
+ /** Number of max concurrent reading channels. */
+ private final int numConcurrentReading;
+
+ /** Codec used for compression / decompression. */
+ private static final String compressionCodec = "LZ4";
+
+ /** Network buffer size. */
+ private final int networkBufferSize;
+
+ /**
+ * Network buffer pool used for shuffle read buffers. {@link BufferPool}s will be created from it
+ * and each of them will be used by a channel exclusively.
+ */
+ private final NetworkBufferPool networkBufferPool;
+
+ /** Total number of buffers per input gate. */
+ private final int numBuffersPerGate;
+
+ private CelebornConf celebornConf;
+
+ public RemoteShuffleInputGateFactory(
+ Configuration flinkConf,
+ CelebornConf conf,
+ NetworkBufferPool networkBufferPool,
+ int networkBufferSize) {
+ this.celebornConf = conf;
+ long configuredMemorySize =
+ org.apache.celeborn.common.util.Utils.byteStringAsBytes(
+ PluginConf.getValue(flinkConf, PluginConf.MEMORY_PER_INPUT_GATE));
+ long minMemorySize =
+ org.apache.celeborn.common.util.Utils.byteStringAsBytes(
+ PluginConf.getValue(flinkConf, PluginConf.MIN_MEMORY_PER_GATE));
+ if (configuredMemorySize < minMemorySize) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Insufficient network memory per input gate, please increase %s to at " + "least %s.",
+ PluginConf.MEMORY_PER_INPUT_GATE.name,
+ PluginConf.getValue(flinkConf, PluginConf.MIN_MEMORY_PER_GATE)));
+ }
+
+ this.numBuffersPerGate = Utils.checkedDownCast(configuredMemorySize / networkBufferSize);
+ if (numBuffersPerGate < MIN_BUFFERS_PER_GATE) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Insufficient network memory per input gate, please increase %s to at "
+ + "least %d bytes.",
+ PluginConf.MEMORY_PER_INPUT_GATE.name, networkBufferSize * MIN_BUFFERS_PER_GATE));
+ }
+
+ this.networkBufferSize = networkBufferSize;
+ this.numConcurrentReading =
+ Integer.valueOf(PluginConf.getValue(flinkConf, PluginConf.NUM_CONCURRENT_READINGS));
+ this.networkBufferPool = networkBufferPool;
+ }
+
+ /** Create {@link RemoteShuffleInputGate} from {@link InputGateDeploymentDescriptor}. */
+ public RemoteShuffleInputGate create(
+ String owningTaskName, int gateIndex, InputGateDeploymentDescriptor igdd) {
+ LOG.info(
+ "Create input gate -- number of buffers per input gate={}, "
+ + "number of concurrent readings={}.",
+ numBuffersPerGate,
+ numConcurrentReading);
+
+ SupplierWithException<BufferPool, IOException> bufferPoolFactory =
+ createBufferPoolFactory(networkBufferPool, numBuffersPerGate);
+ BufferDecompressor bufferDecompressor =
+ new BufferDecompressor(networkBufferSize, compressionCodec);
+
+ return createInputGate(owningTaskName, gateIndex, igdd, bufferPoolFactory, bufferDecompressor);
+ }
+
+ // For testing.
+ RemoteShuffleInputGate createInputGate(
+ String owningTaskName,
+ int gateIndex,
+ InputGateDeploymentDescriptor igdd,
+ SupplierWithException<BufferPool, IOException> bufferPoolFactory,
+ BufferDecompressor bufferDecompressor) {
+ return new RemoteShuffleInputGate(
+ this.celebornConf, owningTaskName, gateIndex, igdd, bufferPoolFactory, bufferDecompressor);
+ }
+
+ private SupplierWithException<BufferPool, IOException> createBufferPoolFactory(
+ BufferPoolFactory bufferPoolFactory, int numBuffers) {
+ return () -> bufferPoolFactory.createBufferPool(numBuffers, numBuffers);
+ }
+}
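
Note: the gate-level sizing above reduces to a small piece of arithmetic -- the configured
memory budget is divided by the network buffer size, and the resulting count becomes a
fixed-size pool carved out of the NetworkBufferPool. A minimal, self-contained sketch
(the 32m budget and 32 KB page size are illustrative defaults, not values from this patch):

    import java.io.IOException;
    import org.apache.flink.runtime.io.network.buffer.BufferPool;
    import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;

    public class GateBufferSizingSketch {
      public static void main(String[] args) throws IOException {
        long memoryPerInputGate = 32L * 1024 * 1024; // illustrative, e.g. "32m"
        int networkBufferSize = 32 * 1024;           // illustrative page size
        // same down-cast division the factory performs
        int numBuffersPerGate = (int) (memoryPerInputGate / networkBufferSize); // 1024

        NetworkBufferPool networkBufferPool =
            new NetworkBufferPool(numBuffersPerGate, networkBufferSize);
        // min == max, mirroring createBufferPoolFactory above: the pool is sized exactly
        BufferPool bufferPool =
            networkBufferPool.createBufferPool(numBuffersPerGate, numBuffersPerGate);
        System.out.println("buffers per gate: " + bufferPool.getNumBuffers());
      }
    }
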
diff --git a/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleMaster.java b/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleMaster.java
new file mode 100644
index 000000000..1a8d3d2be
--- /dev/null
+++ b/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleMaster.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.shuffle.JobShuffleContext;
+import org.apache.flink.runtime.shuffle.PartitionDescriptor;
+import org.apache.flink.runtime.shuffle.ProducerDescriptor;
+import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
+import org.apache.flink.runtime.shuffle.ShuffleMaster;
+import org.apache.flink.runtime.shuffle.ShuffleMasterContext;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.client.LifecycleManager;
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.plugin.flink.utils.FlinkUtils;
+import org.apache.celeborn.plugin.flink.utils.ThreadUtils;
+
+public class RemoteShuffleMaster implements ShuffleMaster<RemoteShuffleDescriptor> {
+ private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleMaster.class);
+ private final ShuffleMasterContext shuffleMasterContext;
+ // Flink JobID -> shuffle ids registered with Celeborn
+ private Map<JobID, Set<Integer>> jobShuffleIds = new ConcurrentHashMap<>();
+ private String celebornAppId;
+ private volatile LifecycleManager lifecycleManager;
+
+ private final ScheduledThreadPoolExecutor executor =
+ new ScheduledThreadPoolExecutor(
+ 1,
+ ThreadUtils.createFactoryWithDefaultExceptionHandler(
+ "remote-shuffle-master-executor", LOG));
+
+ public RemoteShuffleMaster(ShuffleMasterContext shuffleMasterContext) {
+ this.shuffleMasterContext = shuffleMasterContext;
+ }
+
+ @Override
+ public void registerJob(JobShuffleContext context) {
+ JobID jobID = context.getJobId();
+ if (lifecycleManager == null) {
+ synchronized (RemoteShuffleMaster.class) {
+ if (lifecycleManager == null) {
+ // use the first job's ID as the Celeborn appId shared by all Flink jobs in this session
+ celebornAppId = FlinkUtils.toCelebornAppId(jobID);
+ CelebornConf celebornConf =
+ FlinkUtils.toCelebornConf(shuffleMasterContext.getConfiguration());
+ lifecycleManager = new LifecycleManager(celebornAppId, celebornConf);
+ }
+ }
+ }
+
+ Set<Integer> previousShuffleIds = jobShuffleIds.putIfAbsent(jobID, new HashSet<>());
+ if (previousShuffleIds != null) {
+ throw new RuntimeException("Duplicated registration job: " + jobID);
+ }
+ }
+
+ @Override
+ public void unregisterJob(JobID jobID) {
+ LOG.info("Unregister job: {}.", jobID);
+ Set<Integer> shuffleIds = jobShuffleIds.remove(jobID);
+ if (shuffleIds != null) {
+ executor.execute(
+ () -> {
+ try {
+ synchronized (shuffleIds) {
+ for (Integer shuffleId : shuffleIds) {
+ lifecycleManager.handleUnregisterShuffle(celebornAppId, shuffleId);
+ }
+ }
+ } catch (Throwable throwable) {
+ LOG.error("Encounter an error when unregistering job: {}.", jobID, throwable);
+ }
+ });
+ }
+ }
+
+ @Override
+ public CompletableFuture<RemoteShuffleDescriptor> registerPartitionWithProducer(
+ JobID jobID, PartitionDescriptor partitionDescriptor, ProducerDescriptor producerDescriptor) {
+ CompletableFuture<RemoteShuffleDescriptor> completableFuture =
+ CompletableFuture.supplyAsync(
+ () -> {
+ Set<Integer> shuffleIds = jobShuffleIds.get(jobID);
+ if (shuffleIds == null) {
+ throw new RuntimeException("Can not find job in lifecycleManager, job: " + jobID);
+ }
+
+ FlinkResultPartitionInfo resultPartitionInfo =
+ new FlinkResultPartitionInfo(jobID, partitionDescriptor, producerDescriptor);
+ LifecycleManager.ShuffleTask shuffleTask =
+ lifecycleManager.encodeExternalShuffleTask(
+ resultPartitionInfo.getShuffleId(),
+ resultPartitionInfo.getTaskId(),
+ resultPartitionInfo.getAttemptId());
+
+ synchronized (shuffleIds) {
+ shuffleIds.add(shuffleTask.shuffleId());
+ }
+
+ ShuffleResourceDescriptor shuffleResourceDescriptor =
+ new ShuffleResourceDescriptor(shuffleTask);
+ RemoteShuffleResource remoteShuffleResource =
+ new RemoteShuffleResource(
+ lifecycleManager.getRssMetaServiceHost(),
+ lifecycleManager.getRssMetaServicePort(),
+ shuffleResourceDescriptor);
+ return new RemoteShuffleDescriptor(
+ celebornAppId,
+ resultPartitionInfo.getShuffleId(),
+ resultPartitionInfo.getResultPartitionId(),
+ remoteShuffleResource);
+ },
+ executor);
+
+ return completableFuture;
+ }
+
+ @Override
+ public void releasePartitionExternally(ShuffleDescriptor shuffleDescriptor) {
+ // TODO
+ }
+
+ @Override
+ public void close() throws Exception {
+ try {
+ jobShuffleIds.clear();
+ if (lifecycleManager != null) {
+ lifecycleManager.stop();
+ }
+ } catch (Exception e) {
+ LOG.warn("Encountered an exception during shutdown: " + e.getMessage(), e);
+ }
+
+ ThreadUtils.shutdownExecutors(10, executor);
+ }
+}
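
Note: registerJob lazily creates a single LifecycleManager that is shared by every job in
the session, using double-checked locking on a volatile field. A stripped-down sketch of
the same idiom (the class and method names here are illustrative only, not part of this
patch):

    /** Minimal sketch of the double-checked locking used by registerJob. */
    public class LazySharedResource {
      private static volatile Object shared;

      public static Object getOrCreate() {
        if (shared == null) {                      // fast path, no locking once initialized
          synchronized (LazySharedResource.class) {
            if (shared == null) {                  // re-check under the lock
              shared = new Object();               // first caller wins; later jobs reuse it
            }
          }
        }
        return shared;
      }
    }
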
diff --git a/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartition.java b/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartition.java
new file mode 100644
index 000000000..d481c0f7a
--- /dev/null
+++ b/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartition.java
@@ -0,0 +1,432 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.concurrent.CompletableFuture;
+
+import javax.annotation.Nullable;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.network.api.EndOfData;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.Buffer.DataType;
+import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
+import org.apache.flink.runtime.io.network.partition.ResultPartition;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SupplierWithException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.plugin.flink.buffer.PartitionSortedBuffer;
+import org.apache.celeborn.plugin.flink.buffer.SortBuffer;
+import org.apache.celeborn.plugin.flink.utils.BufferUtils;
+import org.apache.celeborn.plugin.flink.utils.Utils;
+
+/**
+ * A {@link ResultPartition} which appends records and events to a {@link SortBuffer}. Once the
+ * {@link SortBuffer} is full, all of its data is copied and spilled to the remote shuffle service
+ * sequentially, in subpartition index order. Large records that cannot be appended to an empty
+ * {@link SortBuffer} are spilled directly.
+ */
+public class RemoteShuffleResultPartition extends ResultPartition {
+
+ private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleResultPartition.class);
+
+ /** Size of network buffer and write buffer. */
+ private final int networkBufferSize;
+
+ /** {@link SortBuffer} for records sent by {@link #broadcastRecord(ByteBuffer)}. */
+ private SortBuffer broadcastSortBuffer;
+
+ /** {@link SortBuffer} for records sent by {@link #emitRecord(ByteBuffer, int)}. */
+ private SortBuffer unicastSortBuffer;
+
+ /** Utility to spill data to shuffle workers. */
+ private final RemoteShuffleOutputGate outputGate;
+
+ /** Whether {@link #notifyEndOfData} has been called or not. */
+ private boolean endOfDataNotified;
+
+ public RemoteShuffleResultPartition(
+ String owningTaskName,
+ int partitionIndex,
+ ResultPartitionID partitionId,
+ ResultPartitionType partitionType,
+ int numSubpartitions,
+ int numTargetKeyGroups,
+ int networkBufferSize,
+ ResultPartitionManager partitionManager,
+ @Nullable BufferCompressor bufferCompressor,
+ SupplierWithException<BufferPool, IOException> bufferPoolFactory,
+ RemoteShuffleOutputGate outputGate) {
+
+ super(
+ owningTaskName,
+ partitionIndex,
+ partitionId,
+ partitionType,
+ numSubpartitions,
+ numTargetKeyGroups,
+ partitionManager,
+ bufferCompressor,
+ bufferPoolFactory);
+
+ this.networkBufferSize = networkBufferSize;
+ this.outputGate = outputGate;
+ }
+
+ @Override
+ public void setup() throws IOException {
+ LOG.info("Setup {}", this);
+ super.setup();
+ BufferUtils.reserveNumRequiredBuffers(bufferPool, 1);
+ try {
+ outputGate.setup();
+ } catch (Throwable throwable) {
+ LOG.error("Failed to setup remote output gate.", throwable);
+ Utils.rethrowAsRuntimeException(throwable);
+ }
+ }
+
+ @Override
+ public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException {
+ emit(record, targetSubpartition, DataType.DATA_BUFFER, false);
+ }
+
+ @Override
+ public void broadcastRecord(ByteBuffer record) throws IOException {
+ broadcast(record, DataType.DATA_BUFFER);
+ }
+
+ @Override
+ public void broadcastEvent(AbstractEvent event, boolean isPriorityEvent) throws IOException {
+ Buffer buffer = EventSerializer.toBuffer(event, isPriorityEvent);
+ try {
+ ByteBuffer serializedEvent = buffer.getNioBufferReadable();
+ broadcast(serializedEvent, buffer.getDataType());
+ } finally {
+ buffer.recycleBuffer();
+ }
+ }
+
+ private void broadcast(ByteBuffer record, DataType dataType) throws IOException {
+ emit(record, 0, dataType, true);
+ }
+
+ private void emit(
+ ByteBuffer record, int targetSubpartition, DataType dataType, boolean isBroadcast)
+ throws IOException {
+
+ checkInProduceState();
+ if (isBroadcast) {
+ Preconditions.checkState(
+ targetSubpartition == 0, "Target subpartition index can only be 0 when broadcast.");
+ }
+
+ SortBuffer sortBuffer = isBroadcast ? getBroadcastSortBuffer() : getUnicastSortBuffer();
+ if (sortBuffer.append(record, targetSubpartition, dataType)) {
+ return;
+ }
+
+ try {
+ if (!sortBuffer.hasRemaining()) {
+ // the record cannot be appended even to an empty sort buffer because it is too large
+ sortBuffer.finish();
+ sortBuffer.release();
+ writeLargeRecord(record, targetSubpartition, dataType, isBroadcast);
+ return;
+ }
+ flushSortBuffer(sortBuffer, isBroadcast);
+ } catch (InterruptedException e) {
+ LOG.error("Failed to flush the sort buffer.", e);
+ Utils.rethrowAsRuntimeException(e);
+ }
+ emit(record, targetSubpartition, dataType, isBroadcast);
+ }
+
+ private void releaseSortBuffer(SortBuffer sortBuffer) {
+ if (sortBuffer != null) {
+ sortBuffer.release();
+ }
+ }
+
+ @VisibleForTesting
+ SortBuffer getUnicastSortBuffer() throws IOException {
+ flushBroadcastSortBuffer();
+
+ if (unicastSortBuffer != null && !unicastSortBuffer.isFinished()) {
+ return unicastSortBuffer;
+ }
+
+ unicastSortBuffer =
+ new PartitionSortedBuffer(bufferPool, numSubpartitions, networkBufferSize, null);
+ return unicastSortBuffer;
+ }
+
+ private SortBuffer getBroadcastSortBuffer() throws IOException {
+ flushUnicastSortBuffer();
+
+ if (broadcastSortBuffer != null && !broadcastSortBuffer.isFinished()) {
+ return broadcastSortBuffer;
+ }
+
+ broadcastSortBuffer =
+ new PartitionSortedBuffer(bufferPool, numSubpartitions, networkBufferSize, null);
+ return broadcastSortBuffer;
+ }
+
+ private void flushBroadcastSortBuffer() throws IOException {
+ flushSortBuffer(broadcastSortBuffer, true);
+ }
+
+ private void flushUnicastSortBuffer() throws IOException {
+ flushSortBuffer(unicastSortBuffer, false);
+ }
+
+ @VisibleForTesting
+ void flushSortBuffer(SortBuffer sortBuffer, boolean isBroadcast) throws IOException {
+ if (sortBuffer == null || sortBuffer.isReleased()) {
+ return;
+ }
+ sortBuffer.finish();
+ if (sortBuffer.hasRemaining()) {
+ try {
+ outputGate.regionStart(isBroadcast);
+ while (sortBuffer.hasRemaining()) {
+ MemorySegment segment = outputGate.getBufferPool().requestMemorySegmentBlocking();
+ SortBuffer.BufferWithChannel bufferWithChannel;
+ try {
+ bufferWithChannel =
+ sortBuffer.copyIntoSegment(
+ segment, outputGate.getBufferPool(), BufferUtils.HEADER_LENGTH);
+ } catch (Throwable t) {
+ outputGate.getBufferPool().recycle(segment);
+ throw new FlinkRuntimeException("Shuffle write failure.", t);
+ }
+
+ Buffer buffer = bufferWithChannel.getBuffer();
+ int subpartitionIndex = bufferWithChannel.getChannelIndex();
+ updateStatistics(bufferWithChannel.getBuffer());
+ writeCompressedBufferIfPossible(buffer, subpartitionIndex);
+ }
+ outputGate.regionFinish();
+ } catch (InterruptedException e) {
+ throw new IOException("Failed to flush the sort buffer, broadcast=" + isBroadcast, e);
+ }
+ }
+ releaseSortBuffer(sortBuffer);
+ }
+
+ private void writeCompressedBufferIfPossible(Buffer buffer, int targetSubpartition)
+ throws InterruptedException {
+ Buffer compressedBuffer = null;
+ try {
+ if (canBeCompressed(buffer)) {
+ Buffer dataBuffer =
+ buffer.readOnlySlice(
+ BufferUtils.HEADER_LENGTH, buffer.getSize() - BufferUtils.HEADER_LENGTH);
+ compressedBuffer =
+ Utils.checkNotNull(bufferCompressor).compressToIntermediateBuffer(dataBuffer);
+ }
+ BufferUtils.setCompressedDataWithHeader(buffer, compressedBuffer);
+ } catch (Throwable throwable) {
+ buffer.recycleBuffer();
+ throw new RuntimeException("Shuffle write failure.", throwable);
+ } finally {
+ if (compressedBuffer != null && compressedBuffer.isCompressed()) {
+ compressedBuffer.setReaderIndex(0);
+ compressedBuffer.recycleBuffer();
+ }
+ }
+ outputGate.write(buffer, targetSubpartition);
+ }
+
+ private void updateStatistics(Buffer buffer) {
+ numBuffersOut.inc();
+ numBytesOut.inc(buffer.readableBytes() - BufferUtils.HEADER_LENGTH);
+ }
+
+ /** Spills the large record into {@link RemoteShuffleOutputGate}. */
+ private void writeLargeRecord(
+ ByteBuffer record, int targetSubpartition, DataType dataType, boolean isBroadcast)
+ throws InterruptedException {
+
+ outputGate.regionStart(isBroadcast);
+ while (record.hasRemaining()) {
+ MemorySegment writeBuffer = outputGate.getBufferPool().requestMemorySegmentBlocking();
+ int toCopy = Math.min(record.remaining(), writeBuffer.size() - BufferUtils.HEADER_LENGTH);
+ writeBuffer.put(BufferUtils.HEADER_LENGTH, record, toCopy);
+ NetworkBuffer buffer =
+ new NetworkBuffer(
+ writeBuffer,
+ outputGate.getBufferPool(),
+ dataType,
+ toCopy + BufferUtils.HEADER_LENGTH);
+
+ updateStatistics(buffer);
+ writeCompressedBufferIfPossible(buffer, targetSubpartition);
+ }
+ outputGate.regionFinish();
+ }
+
+ @Override
+ public void finish() throws IOException {
+ Utils.checkState(!isReleased(), "Result partition is already released.");
+ broadcastEvent(EndOfPartitionEvent.INSTANCE, false);
+ Utils.checkState(
+ unicastSortBuffer == null || unicastSortBuffer.isReleased(),
+ "The unicast sort buffer should be either null or released.");
+ flushBroadcastSortBuffer();
+ try {
+ outputGate.finish();
+ } catch (InterruptedException e) {
+ throw new IOException("Output gate fails to finish.", e);
+ }
+ super.finish();
+ }
+
+ @Override
+ public synchronized void close() {
+ Throwable closeException = null;
+ closeException =
+ checkException(
+ () -> releaseSortBuffer(unicastSortBuffer),
+ closeException,
+ "Failed to release unicast sort buffer.");
+
+ closeException =
+ checkException(
+ () -> releaseSortBuffer(broadcastSortBuffer),
+ closeException,
+ "Failed to release broadcast sort buffer.");
+
+ closeException =
+ checkException(() -> super.close(), closeException, "Failed to call super#close() method.");
+
+ try {
+ outputGate.close();
+ } catch (Throwable throwable) {
+ closeException = closeException == null ? throwable : closeException;
+ LOG.error("Failed to close remote shuffle output gate.", throwable);
+ }
+
+ if (closeException != null) {
+ Utils.rethrowAsRuntimeException(closeException);
+ }
+ }
+
+ private Throwable checkException(Runnable runnable, Throwable exception, String errorMessage) {
+ Throwable newException = exception;
+ try {
+ runnable.run();
+ } catch (Throwable throwable) {
+ newException = exception == null ? throwable : exception;
+ LOG.error(errorMessage, throwable);
+ }
+ return newException;
+ }
+
+ @Override
+ protected void releaseInternal() {
+ // no-op
+ }
+
+ @Override
+ public void flushAll() {
+ try {
+ flushUnicastSortBuffer();
+ flushBroadcastSortBuffer();
+ } catch (Throwable t) {
+ LOG.error("Failed to flush the current sort buffer.", t);
+ Utils.rethrowAsRuntimeException(t);
+ }
+ }
+
+ @Override
+ public void flush(int subpartitionIndex) {
+ flushAll();
+ }
+
+ @Override
+ public CompletableFuture<?> getAvailableFuture() {
+ return AVAILABLE;
+ }
+
+ @Override
+ public int getNumberOfQueuedBuffers() {
+ return 0;
+ }
+
+ @Override
+ public long getSizeOfQueuedBuffersUnsafe() {
+ return 0;
+ }
+
+ @Override
+ public int getNumberOfQueuedBuffers(int targetSubpartition) {
+ return 0;
+ }
+
+ @Override
+ public ResultSubpartitionView createSubpartitionView(
+ int index, BufferAvailabilityListener availabilityListener) {
+ throw new UnsupportedOperationException("Not supported.");
+ }
+
+ @Override
+ public void notifyEndOfData(StopMode mode) throws IOException {
+ if (!endOfDataNotified) {
+ broadcastEvent(new EndOfData(mode), false);
+ endOfDataNotified = true;
+ }
+ }
+
+ @Override
+ public CompletableFuture<Void> getAllDataProcessedFuture() {
+ return CompletableFuture.completedFuture(null);
+ }
+
+ @Override
+ public String toString() {
+ return "ResultPartition "
+ + partitionId.toString()
+ + " ["
+ + partitionType
+ + ", "
+ + numSubpartitions
+ + " subpartitions, shuffle-descriptor: "
+ + outputGate.getShuffleDesc()
+ + "]";
+ }
+}
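
Note: the emit path above follows a simple loop -- try to append to the in-memory sort
buffer; if the record does not fit and the buffer is empty, the record is too large and is
spilled directly; otherwise the buffer is flushed and the append is retried. A condensed
sketch of that control flow (the SortSink interface is invented here for illustration; the
real class uses SortBuffer and the output gate directly):

    import java.io.IOException;
    import java.nio.ByteBuffer;

    /** Condensed sketch of RemoteShuffleResultPartition#emit. */
    public class EmitSketch {
      interface SortSink {
        boolean append(ByteBuffer record, int subpartition);    // false when full
        boolean hasRemaining();                                  // any buffered data?
        void flush() throws IOException;                         // spill in subpartition order
        void spillDirectly(ByteBuffer record, int subpartition) throws IOException;
      }

      static void emit(SortSink sink, ByteBuffer record, int subpartition) throws IOException {
        if (sink.append(record, subpartition)) {
          return;                                   // fits into the sort buffer
        }
        if (!sink.hasRemaining()) {
          sink.spillDirectly(record, subpartition); // larger than an empty buffer
          return;
        }
        sink.flush();                               // make room, then retry
        emit(sink, record, subpartition);
      }
    }
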
diff --git a/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionFactory.java b/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionFactory.java
new file mode 100644
index 000000000..c2eb26140
--- /dev/null
+++ b/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionFactory.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
+import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.BufferPoolFactory;
+import org.apache.flink.runtime.io.network.partition.ResultPartition;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
+import org.apache.flink.util.function.SupplierWithException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.protocol.CompressionCodec;
+import org.apache.celeborn.plugin.flink.config.PluginConf;
+import org.apache.celeborn.plugin.flink.utils.Utils;
+
+/** Factory class to create {@link RemoteShuffleResultPartition}. */
+public class RemoteShuffleResultPartitionFactory {
+
+ private static final Logger LOG =
+ LoggerFactory.getLogger(RemoteShuffleResultPartitionFactory.class);
+
+ public static final int MIN_BUFFERS_PER_PARTITION = 16;
+
+ /** Unused; kept only for compatibility with the Flink pluggable shuffle service. */
+ private final ResultPartitionManager partitionManager;
+
+ /** Network buffer pool used for shuffle write buffers. */
+ private final BufferPoolFactory bufferPoolFactory;
+
+ /** Network buffer size. */
+ private final int networkBufferSize;
+
+ /**
+ * Configured number of buffers for shuffle write. It consists of two parts: sorting buffers and
+ * transportation buffers.
+ */
+ private final int numBuffersPerPartition;
+
+ private String compressionCodec;
+
+ public RemoteShuffleResultPartitionFactory(
+ Configuration flinkConf,
+ CelebornConf celebornConf,
+ ResultPartitionManager partitionManager,
+ BufferPoolFactory bufferPoolFactory,
+ int networkBufferSize) {
+ long configuredMemorySize =
+ org.apache.celeborn.common.util.Utils.byteStringAsBytes(
+ PluginConf.getValue(flinkConf, PluginConf.MEMORY_PER_RESULT_PARTITION));
+ long minConfiguredMemorySize =
+ org.apache.celeborn.common.util.Utils.byteStringAsBytes(
+ PluginConf.getValue(flinkConf, PluginConf.MIN_MEMORY_PER_PARTITION));
+ if (configuredMemorySize < minConfiguredMemorySize) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Insufficient network memory per result partition, please increase %s "
+ + "to at least %s.",
+ PluginConf.MEMORY_PER_RESULT_PARTITION.name, minConfiguredMemorySize));
+ }
+
+ this.numBuffersPerPartition = Utils.checkedDownCast(configuredMemorySize / networkBufferSize);
+ if (numBuffersPerPartition < MIN_BUFFERS_PER_PARTITION) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Insufficient network memory per partition, please increase %s to at "
+ + "least %d bytes.",
+ PluginConf.MEMORY_PER_RESULT_PARTITION.name,
+ networkBufferSize * MIN_BUFFERS_PER_PARTITION));
+ }
+
+ this.partitionManager = partitionManager;
+ this.bufferPoolFactory = bufferPoolFactory;
+ this.networkBufferSize = networkBufferSize;
+ if (!Boolean.parseBoolean(PluginConf.getValue(flinkConf, PluginConf.ENABLE_DATA_COMPRESSION))) {
+ throw new RuntimeException(PluginConf.ENABLE_DATA_COMPRESSION.name + " must be true");
+ }
+ this.compressionCodec =
+ PluginConf.getValue(flinkConf, PluginConf.REMOTE_SHUFFLE_COMPRESSION_CODEC);
+ }
+
+ public ResultPartition create(
+ String taskNameWithSubtaskAndId,
+ int partitionIndex,
+ ResultPartitionDeploymentDescriptor desc,
+ CelebornConf celebornConf) {
+ LOG.info(
+ "Create result partition -- number of buffers per result partition={}, "
+ + "number of subpartitions={}.",
+ numBuffersPerPartition,
+ desc.getNumberOfSubpartitions());
+
+ return create(
+ taskNameWithSubtaskAndId,
+ partitionIndex,
+ desc.getShuffleDescriptor().getResultPartitionID(),
+ desc.getPartitionType(),
+ desc.getNumberOfSubpartitions(),
+ desc.getMaxParallelism(),
+ createBufferPoolFactory(),
+ desc.getShuffleDescriptor(),
+ celebornConf,
+ desc.getTotalNumberOfPartitions());
+ }
+
+ private ResultPartition create(
+ String taskNameWithSubtaskAndId,
+ int partitionIndex,
+ ResultPartitionID id,
+ ResultPartitionType type,
+ int numSubpartitions,
+ int maxParallelism,
+ List<SupplierWithException<BufferPool, IOException>> bufferPoolFactories,
+ ShuffleDescriptor shuffleDescriptor,
+ CelebornConf celebornConf,
+ int numMappers) {
+
+ // the Flink 1.14/1.15 plugins only support LZ4
+ if (!compressionCodec.equals(CompressionCodec.LZ4.name())) {
+ throw new IllegalStateException("Unsupported compression codec: " + compressionCodec);
+ }
+ final BufferCompressor bufferCompressor =
+ new BufferCompressor(networkBufferSize, compressionCodec);
+ RemoteShuffleDescriptor rsd = (RemoteShuffleDescriptor) shuffleDescriptor;
+ ResultPartition partition =
+ new RemoteShuffleResultPartition(
+ taskNameWithSubtaskAndId,
+ partitionIndex,
+ id,
+ type,
+ numSubpartitions,
+ maxParallelism,
+ networkBufferSize,
+ partitionManager,
+ bufferCompressor,
+ bufferPoolFactories.get(0),
+ new RemoteShuffleOutputGate(
+ rsd,
+ numSubpartitions,
+ networkBufferSize,
+ bufferPoolFactories.get(1),
+ celebornConf,
+ numMappers));
+ LOG.debug("{}: Initialized {}", taskNameWithSubtaskAndId, this);
+ return partition;
+ }
+
+ /**
+ * Creates the two buffer pools: the sorting buffer pool (7/8 of the buffers) and the
+ * transportation buffer pool (the remaining 1/8).
+ */
+ private List<SupplierWithException<BufferPool, IOException>> createBufferPoolFactory() {
+ int numForResultPartition = numBuffersPerPartition * 7 / 8;
+ int numForOutputGate = numBuffersPerPartition - numForResultPartition;
+
+ List<SupplierWithException<BufferPool, IOException>> factories = new ArrayList<>();
+ factories.add(
+ () -> bufferPoolFactory.createBufferPool(numForResultPartition, numForResultPartition));
+ factories.add(() -> bufferPoolFactory.createBufferPool(numForOutputGate, numForOutputGate));
+ return factories;
+ }
+
+ @VisibleForTesting
+ int getNetworkBufferSize() {
+ return networkBufferSize;
+ }
+}
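
Note: the 7/8 vs 1/8 split above is plain integer arithmetic over the per-partition buffer
budget. With the illustrative defaults below (not taken from this patch), a 64 MB partition
budget and 32 KB buffers yield 2048 buffers: 1792 for sorting and 256 for transportation:

    /** Sketch of the sorting/transportation buffer split (illustrative defaults). */
    public class BufferSplitSketch {
      public static void main(String[] args) {
        long memoryPerResultPartition = 64L * 1024 * 1024; // illustrative "64m"
        int networkBufferSize = 32 * 1024;
        int numBuffersPerPartition =
            (int) (memoryPerResultPartition / networkBufferSize);       // 2048
        int numForResultPartition = numBuffersPerPartition * 7 / 8;     // 1792 sorting buffers
        int numForOutputGate =
            numBuffersPerPartition - numForResultPartition;             // 256 transport buffers
        System.out.println(numForResultPartition + " sorting / " + numForOutputGate + " transport");
      }
    }
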
diff --git a/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactory.java b/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactory.java
new file mode 100644
index 000000000..4cc02c0fc
--- /dev/null
+++ b/client-flink/flink-1.15/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactory.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import static org.apache.flink.runtime.io.network.metrics.NettyShuffleMetricFactory.registerShuffleMetrics;
+
+import java.time.Duration;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.MemorySize;
+import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
+import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
+import org.apache.flink.runtime.shuffle.ShuffleEnvironmentContext;
+import org.apache.flink.runtime.shuffle.ShuffleMaster;
+import org.apache.flink.runtime.shuffle.ShuffleMasterContext;
+import org.apache.flink.runtime.shuffle.ShuffleServiceFactory;
+import org.apache.flink.runtime.util.ConfigurationParserUtils;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.plugin.flink.utils.FlinkUtils;
+
+public class RemoteShuffleServiceFactory
+ implements ShuffleServiceFactory<
+ RemoteShuffleDescriptor, ResultPartitionWriter, IndexedInputGate> {
+
+ @Override
+ public ShuffleMaster<RemoteShuffleDescriptor> createShuffleMaster(
+ ShuffleMasterContext shuffleMasterContext) {
+ return new RemoteShuffleMaster(shuffleMasterContext);
+ }
+
+ @Override
+ public ShuffleEnvironment<ResultPartitionWriter, IndexedInputGate> createShuffleEnvironment(
+ ShuffleEnvironmentContext shuffleEnvironmentContext) {
+ Configuration configuration = shuffleEnvironmentContext.getConfiguration();
+ int bufferSize = ConfigurationParserUtils.getPageSize(configuration);
+ final int numBuffers =
+ calculateNumberOfNetworkBuffers(
+ shuffleEnvironmentContext.getNetworkMemorySize(), bufferSize);
+
+ ResultPartitionManager resultPartitionManager = new ResultPartitionManager();
+ MetricGroup metricGroup = shuffleEnvironmentContext.getParentMetricGroup();
+
+ Duration requestSegmentsTimeout =
+ Duration.ofMillis(
+ configuration.getLong(
+ NettyShuffleEnvironmentOptions
+ .NETWORK_EXCLUSIVE_BUFFERS_REQUEST_TIMEOUT_MILLISECONDS));
+ NetworkBufferPool networkBufferPool =
+ new NetworkBufferPool(numBuffers, bufferSize, requestSegmentsTimeout);
+
+ registerShuffleMetrics(metricGroup, networkBufferPool);
+ CelebornConf celebornConf = FlinkUtils.toCelebornConf(configuration);
+ RemoteShuffleResultPartitionFactory resultPartitionFactory =
+ new RemoteShuffleResultPartitionFactory(
+ configuration, celebornConf, resultPartitionManager, networkBufferPool, bufferSize);
+ RemoteShuffleInputGateFactory inputGateFactory =
+ new RemoteShuffleInputGateFactory(
+ configuration, celebornConf, networkBufferPool, bufferSize);
+
+ return new RemoteShuffleEnvironment(
+ networkBufferPool,
+ resultPartitionManager,
+ resultPartitionFactory,
+ inputGateFactory,
+ celebornConf);
+ }
+
+ private static int calculateNumberOfNetworkBuffers(MemorySize memorySize, int bufferSize) {
+ long numBuffersLong = memorySize.getBytes() / bufferSize;
+ if (numBuffersLong > Integer.MAX_VALUE) {
+ throw new IllegalArgumentException(
+ "The given number of memory bytes ("
+ + memorySize.getBytes()
+ + ") corresponds to more than MAX_INT pages.");
+ }
+ return (int) numBuffersLong;
+ }
+}
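
Note: this factory is picked up through Flink's pluggable shuffle mechanism. A sketch of
the configuration that would select it (the option key is Flink's standard shuffle-service
switch; how the setting reaches the cluster depends on the deployment):

    import org.apache.flink.configuration.Configuration;

    public class ShuffleFactorySelectionSketch {
      public static void main(String[] args) {
        Configuration conf = new Configuration();
        // Flink's pluggable shuffle switch; points the cluster at the Celeborn factory
        conf.setString(
            "shuffle-service-factory.class",
            "org.apache.celeborn.plugin.flink.RemoteShuffleServiceFactory");
        System.out.println(conf);
      }
    }
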
diff --git a/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuitJ.java b/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuitJ.java
new file mode 100644
index 000000000..b9b52175c
--- /dev/null
+++ b/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuitJ.java
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType.DATA_BUFFER;
+import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType.EVENT_BUFFER;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.celeborn.plugin.flink.buffer.BufferPacker;
+import org.apache.celeborn.plugin.flink.utils.BufferUtils;
+
+public class BufferPackSuitJ {
+ private static final int BUFFER_SIZE = 20 + 16;
+
+ private NetworkBufferPool networkBufferPool;
+
+ private BufferPool bufferPool;
+
+ @Before
+ public void setup() throws Exception {
+ networkBufferPool = new NetworkBufferPool(10, BUFFER_SIZE);
+ bufferPool = networkBufferPool.createBufferPool(10, 10);
+ }
+
+ @After
+ public void tearDown() {
+ bufferPool.lazyDestroy();
+ assertEquals(10, networkBufferPool.getNumberOfAvailableMemorySegments());
+ networkBufferPool.destroy();
+ }
+
+ @Test
+ public void testPackEmptyBuffers() throws Exception {
+ List<Buffer> buffers = requestBuffers(3);
+ setCompressed(buffers, true, true, false);
+ setDataType(buffers, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER);
+
+ Integer subIdx = 2;
+
+ List<ByteBuf> output = new ArrayList<>();
+ BufferPacker.BiConsumerWithException<ByteBuf, Integer, InterruptedException> ripeBufferHandler =
+ (ripe, sub) -> {
+ assertEquals(subIdx, sub);
+ output.add(ripe);
+ };
+
+ BufferPacker packer = new BufferPacker(ripeBufferHandler);
+ packer.process(buffers.get(0), subIdx);
+ packer.process(buffers.get(1), subIdx);
+ packer.process(buffers.get(2), subIdx);
+ assertTrue(output.isEmpty());
+
+ packer.drain();
+ assertEquals(0, output.size());
+ }
+
+ @Test
+ public void testPartialBuffersForSameSubIdx() throws Exception {
+ List<Buffer> buffers = requestBuffers(3);
+ setCompressed(buffers, true, true, false);
+ setDataType(buffers, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER);
+
+ List<Pair<ByteBuf, Integer>> output = new ArrayList<>();
+ BufferPacker.BiConsumerWithException<ByteBuf, Integer, InterruptedException> ripeBufferHandler =
+ (ripe, sub) -> output.add(Pair.of(ripe, sub));
+ BufferPacker packer = new BufferPacker(ripeBufferHandler);
+ fillBuffers(buffers, 0, 1, 2);
+
+ packer.process(buffers.get(0), 2);
+ packer.process(buffers.get(1), 2);
+ assertEquals(0, output.size());
+
+ packer.process(buffers.get(2), 2);
+ assertEquals(1, output.size());
+
+ packer.drain();
+ assertEquals(2, output.size());
+
+ List<Buffer> unpacked = new ArrayList<>();
+ output.forEach(
+ pair -> {
+ assertEquals(Integer.valueOf(2), pair.getRight());
+ unpacked.addAll(BufferPacker.unpack(pair.getLeft()));
+ });
+ checkIfCompressed(unpacked, true, true, false);
+ checkDataType(unpacked, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER);
+ verifyBuffers(unpacked, 0, 1, 2);
+ unpacked.forEach(Buffer::recycleBuffer);
+ }
+
+ @Test
+ public void testPartialBuffersForMultipleSubIdx() throws Exception {
+ List<Buffer> buffers = requestBuffers(3);
+ setCompressed(buffers, true, true, false);
+ setDataType(buffers, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER);
+
+ List<Pair<ByteBuf, Integer>> output = new ArrayList<>();
+ BufferPacker.BiConsumerWithException<ByteBuf, Integer, InterruptedException> ripeBufferHandler =
+ (ripe, sub) -> output.add(Pair.of(ripe, sub));
+ BufferPacker packer = new BufferPacker(ripeBufferHandler);
+ fillBuffers(buffers, 0, 1, 2);
+
+ packer.process(buffers.get(0), 0);
+ packer.process(buffers.get(1), 1);
+ assertEquals(1, output.size());
+
+ packer.process(buffers.get(2), 1);
+ assertEquals(1, output.size());
+
+ packer.drain();
+ assertEquals(2, output.size());
+
+ List<Buffer> unpacked = new ArrayList<>();
+ for (int i = 0; i < output.size(); i++) {
+ Pair<ByteBuf, Integer> pair = output.get(i);
+ assertEquals(Integer.valueOf(i), pair.getRight());
+ unpacked.addAll(BufferPacker.unpack(pair.getLeft()));
+ }
+
+ checkIfCompressed(unpacked, true, true, false);
+ checkDataType(unpacked, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER);
+ verifyBuffers(unpacked, 0, 1, 2);
+ unpacked.forEach(Buffer::recycleBuffer);
+ }
+
+ @Test
+ public void testUnpackedBuffers() throws Exception {
+ List<Buffer> buffers = requestBuffers(3);
+ setCompressed(buffers, true, true, false);
+ setDataType(buffers, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER);
+
+ List<Pair<ByteBuf, Integer>> output = new ArrayList<>();
+ BufferPacker.BiConsumerWithException<ByteBuf, Integer, InterruptedException> ripeBufferHandler =
+ (ripe, sub) -> output.add(Pair.of(ripe, sub));
+ BufferPacker packer = new BufferPacker(ripeBufferHandler);
+ fillBuffers(buffers, 0, 1, 2);
+
+ packer.process(buffers.get(0), 0);
+ packer.process(buffers.get(1), 1);
+ assertEquals(1, output.size());
+
+ packer.process(buffers.get(2), 2);
+ assertEquals(2, output.size());
+
+ packer.drain();
+ assertEquals(3, output.size());
+
+ List<Buffer> unpacked = new ArrayList<>();
+ for (int i = 0; i < output.size(); i++) {
+ Pair<ByteBuf, Integer> pair = output.get(i);
+ assertEquals(Integer.valueOf(i), pair.getRight());
+ unpacked.addAll(BufferPacker.unpack(pair.getLeft()));
+ }
+
+ checkIfCompressed(unpacked, true, true, false);
+ checkDataType(unpacked, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER);
+ verifyBuffers(unpacked, 0, 1, 2);
+ unpacked.forEach(Buffer::recycleBuffer);
+ }
+
+ @Test
+ public void testFailedToHandleRipeBufferAndClose() throws Exception {
+ List<Buffer> buffers = requestBuffers(1);
+ setCompressed(buffers, false);
+ setDataType(buffers, DATA_BUFFER);
+ fillBuffers(buffers, 0);
+
+ BufferPacker.BiConsumerWithException<ByteBuf, Integer, InterruptedException> ripeBufferHandler =
+ (ripe, sub) -> {
+ throw new RuntimeException("Test");
+ };
+ BufferPacker packer = new BufferPacker(ripeBufferHandler);
+ packer.process(buffers.get(0), 0);
+ try {
+ packer.drain();
+ } catch (RuntimeException e) {
+ // expected: the ripe buffer handler above always fails
+ }
+
+ // this should never throw any exception
+ packer.close();
+ assertEquals(0, bufferPool.bestEffortGetNumOfUsedBuffers());
+ }
+
+ private List<Buffer> requestBuffers(int n) {
+ List<Buffer> buffers = new ArrayList<>();
+ for (int i = 0; i < n; i++) {
+ Buffer buffer = bufferPool.requestBuffer();
+ buffers.add(buffer);
+ }
+ return buffers;
+ }
+
+ private void setCompressed(List<Buffer> buffers, boolean... values) {
+ for (int i = 0; i < buffers.size(); i++) {
+ buffers.get(i).setCompressed(values[i]);
+ }
+ }
+
+ private void setDataType(List<Buffer> buffers, Buffer.DataType... values) {
+ for (int i = 0; i < buffers.size(); i++) {
+ buffers.get(i).setDataType(values[i]);
+ }
+ }
+
+ private void checkIfCompressed(List<Buffer> buffers, boolean... values) {
+ for (int i = 0; i < buffers.size(); i++) {
+ assertEquals(values[i], buffers.get(i).isCompressed());
+ }
+ }
+
+ private void checkDataType(List<Buffer> buffers, Buffer.DataType... values) {
+ for (int i = 0; i < buffers.size(); i++) {
+ assertEquals(values[i], buffers.get(i).getDataType());
+ }
+ }
+
+ private void fillBuffers(List<Buffer> buffers, int... ints) {
+ for (int i = 0; i < buffers.size(); i++) {
+ Buffer buffer = buffers.get(i);
+ ByteBuf target = buffer.asByteBuf();
+ BufferUtils.setBufferHeader(target, buffer.getDataType(), buffer.isCompressed(), 4);
+ target.writerIndex(BufferUtils.HEADER_LENGTH);
+ target.writeInt(ints[i]);
+ }
+ }
+
+ private void verifyBuffers(List<Buffer> buffers, int... expects) {
+ for (int i = 0; i < buffers.size(); i++) {
+ ByteBuf actual = buffers.get(i).asByteBuf();
+ assertEquals(expects[i], actual.getInt(0));
+ }
+ }
+}
diff --git a/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/PartitionSortedBufferSuitJ.java b/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/PartitionSortedBufferSuitJ.java
new file mode 100644
index 000000000..c1bfdf2a0
--- /dev/null
+++ b/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/PartitionSortedBufferSuitJ.java
@@ -0,0 +1,369 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.Queue;
+import java.util.Random;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.junit.Test;
+
+import org.apache.celeborn.plugin.flink.buffer.PartitionSortedBuffer;
+import org.apache.celeborn.plugin.flink.buffer.SortBuffer;
+
+public class PartitionSortedBufferSuitJ {
+ @Test
+ public void testWriteAndReadSortBuffer() throws Exception {
+ int numSubpartitions = 10;
+ int bufferSize = 1024;
+ int bufferPoolSize = 1000;
+ Random random = new Random(1111);
+
+ // used to store data written to and read from sort buffer for correctness check
+ Queue<DataAndType>[] dataWritten = new Queue[numSubpartitions];
+ Queue<Buffer>[] buffersRead = new Queue[numSubpartitions];
+ for (int i = 0; i < numSubpartitions; ++i) {
+ dataWritten[i] = new ArrayDeque<>();
+ buffersRead[i] = new ArrayDeque<>();
+ }
+
+ int[] numBytesWritten = new int[numSubpartitions];
+ int[] numBytesRead = new int[numSubpartitions];
+ Arrays.fill(numBytesWritten, 0);
+ Arrays.fill(numBytesRead, 0);
+
+ // fill the sort buffer with randomly generated data
+ int totalBytesWritten = 0;
+ SortBuffer sortBuffer =
+ createSortBuffer(
+ bufferPoolSize,
+ bufferSize,
+ numSubpartitions,
+ getRandomSubpartitionOrder(numSubpartitions));
+ while (true) {
+ // record size may be larger than buffer size so a record may span multiple segments
+ int recordSize = random.nextInt(bufferSize * 4 - 1) + 1;
+ byte[] bytes = new byte[recordSize];
+
+ // fill record with random value
+ random.nextBytes(bytes);
+ ByteBuffer record = ByteBuffer.wrap(bytes);
+
+ // select a random subpartition to write
+ int subpartition = random.nextInt(numSubpartitions);
+
+ // select a random data type
+ boolean isBuffer = random.nextBoolean() || recordSize > bufferSize;
+ Buffer.DataType dataType =
+ isBuffer ? Buffer.DataType.DATA_BUFFER : Buffer.DataType.EVENT_BUFFER;
+ if (!sortBuffer.append(record, subpartition, dataType)) {
+ sortBuffer.finish();
+ break;
+ }
+ record.rewind();
+ dataWritten[subpartition].add(new DataAndType(record, dataType));
+ numBytesWritten[subpartition] += recordSize;
+ totalBytesWritten += recordSize;
+ }
+
+ // read all data from the sort buffer
+ while (sortBuffer.hasRemaining()) {
+ MemorySegment readBuffer = MemorySegmentFactory.allocateUnpooledSegment(bufferSize);
+ SortBuffer.BufferWithChannel bufferAndChannel =
+ sortBuffer.copyIntoSegment(readBuffer, ignore -> {}, 0);
+ int subpartition = bufferAndChannel.getChannelIndex();
+ buffersRead[subpartition].add(bufferAndChannel.getBuffer());
+ numBytesRead[subpartition] += bufferAndChannel.getBuffer().readableBytes();
+ }
+
+ assertEquals(totalBytesWritten, sortBuffer.numBytes());
+ checkWriteReadResult(numSubpartitions, numBytesWritten, numBytesRead, dataWritten, buffersRead);
+ }
+
+ public static void checkWriteReadResult(
+ int numSubpartitions,
+ int[] numBytesWritten,
+ int[] numBytesRead,
+ Queue<DataAndType>[] dataWritten,
+ Collection<Buffer>[] buffersRead) {
+ for (int subpartitionIndex = 0; subpartitionIndex < numSubpartitions; ++subpartitionIndex) {
+ assertEquals(numBytesWritten[subpartitionIndex], numBytesRead[subpartitionIndex]);
+
+ List<DataAndType> eventsWritten = new ArrayList<>();
+ List<Buffer> eventsRead = new ArrayList<>();
+
+ ByteBuffer subpartitionDataWritten = ByteBuffer.allocate(numBytesWritten[subpartitionIndex]);
+ for (DataAndType dataAndType : dataWritten[subpartitionIndex]) {
+ subpartitionDataWritten.put(dataAndType.data);
+ dataAndType.data.rewind();
+ if (dataAndType.dataType.isEvent()) {
+ eventsWritten.add(dataAndType);
+ }
+ }
+
+ ByteBuffer subpartitionDataRead = ByteBuffer.allocate(numBytesRead[subpartitionIndex]);
+ for (Buffer buffer : buffersRead[subpartitionIndex]) {
+ subpartitionDataRead.put(buffer.getNioBufferReadable());
+ if (!buffer.isBuffer()) {
+ eventsRead.add(buffer);
+ }
+ }
+
+ subpartitionDataWritten.flip();
+ subpartitionDataRead.flip();
+ assertEquals(subpartitionDataWritten, subpartitionDataRead);
+
+ assertEquals(eventsWritten.size(), eventsRead.size());
+ for (int i = 0; i < eventsWritten.size(); ++i) {
+ assertEquals(eventsWritten.get(i).dataType, eventsRead.get(i).getDataType());
+ assertEquals(eventsWritten.get(i).data, eventsRead.get(i).getNioBufferReadable());
+ }
+ }
+ }
+
+ @Test
+ public void testWriteReadWithEmptyChannel() throws Exception {
+ int bufferPoolSize = 10;
+ int bufferSize = 1024;
+ int numSubpartitions = 5;
+
+ ByteBuffer[] subpartitionRecords = {
+ ByteBuffer.allocate(128), null, ByteBuffer.allocate(1536), null, ByteBuffer.allocate(1024)
+ };
+
+ SortBuffer sortBuffer = createSortBuffer(bufferPoolSize, bufferSize, numSubpartitions);
+ for (int subpartition = 0; subpartition < numSubpartitions; ++subpartition) {
+ ByteBuffer record = subpartitionRecords[subpartition];
+ if (record != null) {
+ sortBuffer.append(record, subpartition, Buffer.DataType.DATA_BUFFER);
+ record.rewind();
+ }
+ }
+ sortBuffer.finish();
+
+ checkReadResult(sortBuffer, subpartitionRecords[0], 0, bufferSize);
+
+ ByteBuffer expected1 = subpartitionRecords[2].duplicate();
+ expected1.limit(bufferSize);
+ checkReadResult(sortBuffer, expected1.slice(), 2, bufferSize);
+
+ ByteBuffer expected2 = subpartitionRecords[2].duplicate();
+ expected2.position(bufferSize);
+ checkReadResult(sortBuffer, expected2.slice(), 2, bufferSize);
+
+ checkReadResult(sortBuffer, subpartitionRecords[4], 4, bufferSize);
+ }
+
+ private void checkReadResult(
+ SortBuffer sortBuffer, ByteBuffer expectedBuffer, int expectedChannel, int bufferSize) {
+ MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(bufferSize);
+ SortBuffer.BufferWithChannel bufferWithChannel =
+ sortBuffer.copyIntoSegment(segment, ignore -> {}, 0);
+ assertEquals(expectedChannel, bufferWithChannel.getChannelIndex());
+ assertEquals(expectedBuffer, bufferWithChannel.getBuffer().getNioBufferReadable());
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testWriteEmptyData() throws Exception {
+ int bufferSize = 1024;
+
+ SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1);
+
+ ByteBuffer record = ByteBuffer.allocate(1);
+ record.position(1);
+
+ sortBuffer.append(record, 0, Buffer.DataType.DATA_BUFFER);
+ }
+
+ @Test(expected = IllegalStateException.class)
+ public void testWriteFinishedSortBuffer() throws Exception {
+ int bufferSize = 1024;
+
+ SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1);
+ sortBuffer.finish();
+
+ sortBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER);
+ }
+
+ @Test(expected = IllegalStateException.class)
+ public void testWriteReleasedSortBuffer() throws Exception {
+ int bufferSize = 1024;
+
+ SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1);
+ sortBuffer.release();
+
+ sortBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER);
+ }
+
+ @Test
+ public void testWriteMoreDataThanCapacity() throws Exception {
+ int bufferPoolSize = 10;
+ int bufferSize = 1024;
+
+ SortBuffer sortBuffer = createSortBuffer(bufferPoolSize, bufferSize, 1);
+
+ for (int i = 1; i < bufferPoolSize; ++i) {
+ appendAndCheckResult(sortBuffer, bufferSize, true, bufferSize * i, i, true);
+ }
+
+ // append should fail for insufficient capacity
+ int numRecords = bufferPoolSize - 1;
+ appendAndCheckResult(sortBuffer, bufferSize, false, bufferSize * numRecords, numRecords, true);
+ }
+
+ @Test
+ public void testWriteLargeRecord() throws Exception {
+ int bufferPoolSize = 10;
+ int bufferSize = 1024;
+
+ SortBuffer sortBuffer = createSortBuffer(bufferPoolSize, bufferSize, 1);
+ // append should fail for insufficient capacity
+ appendAndCheckResult(sortBuffer, bufferPoolSize * bufferSize, false, 0, 0, false);
+ }
+
+ private void appendAndCheckResult(
+ SortBuffer sortBuffer,
+ int recordSize,
+ boolean isSuccessful,
+ long numBytes,
+ long numRecords,
+ boolean hasRemaining)
+ throws IOException {
+ ByteBuffer largeRecord = ByteBuffer.allocate(recordSize);
+
+ assertEquals(isSuccessful, sortBuffer.append(largeRecord, 0, Buffer.DataType.DATA_BUFFER));
+ assertEquals(numBytes, sortBuffer.numBytes());
+ assertEquals(numRecords, sortBuffer.numRecords());
+ assertEquals(hasRemaining, sortBuffer.hasRemaining());
+ }
+
+ @Test(expected = IllegalStateException.class)
+ public void testReadUnfinishedSortBuffer() throws Exception {
+ int bufferSize = 1024;
+
+ SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1);
+ sortBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER);
+
+ assertTrue(sortBuffer.hasRemaining());
+ sortBuffer.copyIntoSegment(
+ MemorySegmentFactory.allocateUnpooledSegment(bufferSize), ignore -> {}, 0);
+ }
+
+ @Test(expected = IllegalStateException.class)
+ public void testReadReleasedSortBuffer() throws Exception {
+ int bufferSize = 1024;
+
+ SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1);
+ sortBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER);
+ sortBuffer.finish();
+ assertTrue(sortBuffer.hasRemaining());
+
+ sortBuffer.release();
+ assertFalse(sortBuffer.hasRemaining());
+
+ sortBuffer.copyIntoSegment(
+ MemorySegmentFactory.allocateUnpooledSegment(bufferSize), ignore -> {}, 0);
+ }
+
+ @Test(expected = IllegalStateException.class)
+ public void testReadEmptySortBuffer() throws Exception {
+ int bufferSize = 1024;
+
+ SortBuffer sortBuffer = createSortBuffer(1, bufferSize, 1);
+ sortBuffer.finish();
+
+ assertFalse(sortBuffer.hasRemaining());
+ sortBuffer.copyIntoSegment(
+ MemorySegmentFactory.allocateUnpooledSegment(bufferSize), ignore -> {}, 0);
+ }
+
+ @Test
+ public void testReleaseSortBuffer() throws Exception {
+ int bufferPoolSize = 10;
+ int bufferSize = 1024;
+ int recordSize = (bufferPoolSize - 1) * bufferSize;
+
+ NetworkBufferPool globalPool = new NetworkBufferPool(bufferPoolSize, bufferSize);
+ BufferPool bufferPool = globalPool.createBufferPool(bufferPoolSize, bufferPoolSize);
+
+ SortBuffer sortBuffer = new PartitionSortedBuffer(bufferPool, 1, bufferSize, null);
+ sortBuffer.append(ByteBuffer.allocate(recordSize), 0, Buffer.DataType.DATA_BUFFER);
+
+ assertEquals(bufferPoolSize, bufferPool.bestEffortGetNumOfUsedBuffers());
+ assertTrue(sortBuffer.hasRemaining());
+ assertEquals(1, sortBuffer.numRecords());
+ assertEquals(recordSize, sortBuffer.numBytes());
+
+ // should release all data and resources
+ sortBuffer.release();
+ assertEquals(0, bufferPool.bestEffortGetNumOfUsedBuffers());
+ assertFalse(sortBuffer.hasRemaining());
+ assertEquals(0, sortBuffer.numRecords());
+ assertEquals(0, sortBuffer.numBytes());
+ }
+
+ private SortBuffer createSortBuffer(int bufferPoolSize, int bufferSize, int numSubpartitions)
+ throws IOException {
+ return createSortBuffer(bufferPoolSize, bufferSize, numSubpartitions, null);
+ }
+
+ private SortBuffer createSortBuffer(
+ int bufferPoolSize, int bufferSize, int numSubpartitions, int[] customReadOrder)
+ throws IOException {
+ NetworkBufferPool globalPool = new NetworkBufferPool(bufferPoolSize, bufferSize);
+ BufferPool bufferPool = globalPool.createBufferPool(bufferPoolSize, bufferPoolSize);
+
+ return new PartitionSortedBuffer(bufferPool, numSubpartitions, bufferSize, customReadOrder);
+ }
+
+ public static int[] getRandomSubpartitionOrder(int numSubpartitions) {
+ Random random = new Random(1111);
+ int[] subpartitionReadOrder = new int[numSubpartitions];
+ int shift = random.nextInt(numSubpartitions);
+ for (int i = 0; i < numSubpartitions; ++i) {
+ subpartitionReadOrder[i] = (i + shift) % numSubpartitions;
+ }
+ return subpartitionReadOrder;
+ }
+
+ /** Data written and its {@link Buffer.DataType}. */
+ public static class DataAndType {
+ private final ByteBuffer data;
+ private final Buffer.DataType dataType;
+
+ DataAndType(ByteBuffer data, Buffer.DataType dataType) {
+ this.data = data;
+ this.dataType = dataType;
+ }
+ }
+}
diff --git a/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterTest.java b/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterTest.java
new file mode 100644
index 000000000..873693d15
--- /dev/null
+++ b/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterTest.java
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import java.net.InetAddress;
+import java.net.UnknownHostException;
+import java.util.Collection;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.clusterframework.types.ResourceID;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.runtime.shuffle.JobShuffleContext;
+import org.apache.flink.runtime.shuffle.PartitionDescriptor;
+import org.apache.flink.runtime.shuffle.ProducerDescriptor;
+import org.apache.flink.runtime.shuffle.ShuffleMasterContext;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.util.PackedPartitionId;
+
+public class RemoteShuffleMasterTest {
+
+ private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleMasterTest.class);
+ private RemoteShuffleMaster remoteShuffleMaster;
+
+ @Before
+ public void setUp() {
+ Configuration configuration = new Configuration();
+ remoteShuffleMaster = createShuffleMaster(configuration);
+ }
+
+ @Test
+ public void testRegisterJob() {
+ JobShuffleContext jobShuffleContext = createJobShuffleContext(JobID.generate());
+ remoteShuffleMaster.registerJob(jobShuffleContext);
+
+ // registering the same job again is expected to fail
+ try {
+ remoteShuffleMaster.registerJob(jobShuffleContext);
+ Assert.fail("Duplicate job registration should throw an exception.");
+ } catch (Exception e) {
+ // expected
+ }
+
+ // after unregistering, the job can be registered again
+ remoteShuffleMaster.unregisterJob(jobShuffleContext.getJobId());
+ remoteShuffleMaster.registerJob(jobShuffleContext);
+ }
+
+ @Test
+ public void testRegisterPartitionWithProducer()
+ throws UnknownHostException, ExecutionException, InterruptedException {
+ JobID jobID = JobID.generate();
+ JobShuffleContext jobShuffleContext = createJobShuffleContext(jobID);
+ remoteShuffleMaster.registerJob(jobShuffleContext);
+
+ IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID();
+ PartitionDescriptor partitionDescriptor = createPartitionDescriptor(intermediateDataSetID, 0);
+ ProducerDescriptor producerDescriptor = createProducerDescriptor();
+ RemoteShuffleDescriptor remoteShuffleDescriptor =
+ remoteShuffleMaster
+ .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor)
+ .get();
+ ShuffleResource shuffleResource = remoteShuffleDescriptor.getShuffleResource();
+ ShuffleResourceDescriptor mapPartitionShuffleDescriptor =
+ shuffleResource.getMapPartitionShuffleDescriptor();
+
+ LOG.info("remoteShuffleDescriptor: {}", remoteShuffleDescriptor);
+ Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId());
+ Assert.assertEquals(0, mapPartitionShuffleDescriptor.getPartitionId());
+ Assert.assertEquals(0, mapPartitionShuffleDescriptor.getAttemptId());
+ Assert.assertEquals(0, mapPartitionShuffleDescriptor.getMapId());
+
+ // use same dataset id
+ partitionDescriptor = createPartitionDescriptor(intermediateDataSetID, 1);
+ remoteShuffleDescriptor =
+ remoteShuffleMaster
+ .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor)
+ .get();
+ mapPartitionShuffleDescriptor =
+ remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor();
+ Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId());
+ Assert.assertEquals(1, mapPartitionShuffleDescriptor.getMapId());
+
+ // use another attemptId
+ producerDescriptor = createProducerDescriptor();
+ remoteShuffleDescriptor =
+ remoteShuffleMaster
+ .registerPartitionWithProducer(jobID, partitionDescriptor, producerDescriptor)
+ .get();
+ mapPartitionShuffleDescriptor =
+ remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor();
+ Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId());
+ Assert.assertEquals(
+ PackedPartitionId.packedPartitionId(1, 1), mapPartitionShuffleDescriptor.getPartitionId());
+ Assert.assertEquals(1, mapPartitionShuffleDescriptor.getAttemptId());
+ Assert.assertEquals(1, mapPartitionShuffleDescriptor.getMapId());
+ }
+
+ @Test
+ public void testRegisterMultipleJobs()
+ throws UnknownHostException, ExecutionException, InterruptedException {
+ JobID jobID1 = JobID.generate();
+ JobShuffleContext jobShuffleContext1 = createJobShuffleContext(jobID1);
+ remoteShuffleMaster.registerJob(jobShuffleContext1);
+
+ JobID jobID2 = JobID.generate();
+ JobShuffleContext jobShuffleContext2 = createJobShuffleContext(jobID2);
+ remoteShuffleMaster.registerJob(jobShuffleContext2);
+
+ IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID();
+ PartitionDescriptor partitionDescriptor = createPartitionDescriptor(intermediateDataSetID, 0);
+ ProducerDescriptor producerDescriptor = createProducerDescriptor();
+ RemoteShuffleDescriptor remoteShuffleDescriptor1 =
+ remoteShuffleMaster
+ .registerPartitionWithProducer(jobID1, partitionDescriptor, producerDescriptor)
+ .get();
+
+ // use same datasetId but different jobId
+ RemoteShuffleDescriptor remoteShuffleDescriptor2 =
+ remoteShuffleMaster
+ .registerPartitionWithProducer(jobID2, partitionDescriptor, producerDescriptor)
+ .get();
+
+ Assert.assertEquals(
+ 0,
+ remoteShuffleDescriptor1
+ .getShuffleResource()
+ .getMapPartitionShuffleDescriptor()
+ .getShuffleId());
+ Assert.assertEquals(
+ 1,
+ remoteShuffleDescriptor2
+ .getShuffleResource()
+ .getMapPartitionShuffleDescriptor()
+ .getShuffleId());
+ }
+
+ @After
+ public void tearDown() {
+ if (remoteShuffleMaster != null) {
+ try {
+ remoteShuffleMaster.close();
+ } catch (Exception e) {
+ LOG.warn(e.getMessage(), e);
+ }
+ }
+ }
+
+ public RemoteShuffleMaster createShuffleMaster(Configuration configuration) {
+ remoteShuffleMaster =
+ new RemoteShuffleMaster(
+ new ShuffleMasterContext() {
+ @Override
+ public Configuration getConfiguration() {
+ return configuration;
+ }
+
+ @Override
+ public void onFatalError(Throwable throwable) {
+ System.exit(-1);
+ }
+ });
+
+ return remoteShuffleMaster;
+ }
+
+ public JobShuffleContext createJobShuffleContext(JobID jobId) {
+ return new JobShuffleContext() {
+ @Override
+ public org.apache.flink.api.common.JobID getJobId() {
+ return jobId;
+ }
+
+ @Override
+ public CompletableFuture<?> stopTrackingAndReleasePartitions(
+ Collection<ResultPartitionID> collection) {
+ return CompletableFuture.completedFuture(null);
+ }
+ };
+ }
+
+ public PartitionDescriptor createPartitionDescriptor(
+ IntermediateDataSetID intermediateDataSetId, int partitionNum) {
+ IntermediateResultPartitionID intermediateResultPartitionId =
+ new IntermediateResultPartitionID(intermediateDataSetId, partitionNum);
+ return new PartitionDescriptor(
+ intermediateDataSetId,
+ 10, // total number of partitions in the result
+ intermediateResultPartitionId,
+ ResultPartitionType.BLOCKING,
+ 5, // number of subpartitions
+ 1); // connection index
+ }
+
+ public ProducerDescriptor createProducerDescriptor() throws UnknownHostException {
+ ExecutionAttemptID executionAttemptId = new ExecutionAttemptID();
+ return new ProducerDescriptor(
+ ResourceID.generate(), executionAttemptId, InetAddress.getLocalHost(), 100);
+ }
+}
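Note: the assertion against PackedPartitionId.packedPartitionId(1, 1) relies on Celeborn packing the attempt id and the raw partition id into a single int. A rough standalone sketch of such a packing; the bit layout here is assumed for illustration, and the authoritative code is org.apache.celeborn.common.util.PackedPartitionId:

    public class PackedIdSketch {
      // Assumed layout: high 8 bits carry the attempt id, low 24 bits the raw partition id.
      static int pack(int rawPartitionId, int attemptId) {
        return (attemptId << 24) | rawPartitionId;
      }

      static int rawPartitionId(int packed) {
        return packed & 0xFFFFFF;
      }

      static int attemptId(int packed) {
        return packed >>> 24;
      }

      public static void main(String[] args) {
        int packed = pack(1, 1);
        // prints: partition 1, attempt 1
        System.out.println("partition " + rawPartitionId(packed) + ", attempt " + attemptId(packed));
      }
    }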
diff --git a/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGateSuiteJ.java b/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGateSuiteJ.java
new file mode 100644
index 000000000..3dd9ec64f
--- /dev/null
+++ b/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGateSuiteJ.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyBoolean;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.io.IOException;
+import java.util.Optional;
+
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
+
+public class RemoteShuffleOutputGateSuiteJ {
+ private RemoteShuffleOutputGate remoteShuffleOutputGate = mock(RemoteShuffleOutputGate.class);
+ private FlinkShuffleClientImpl shuffleClient = mock(FlinkShuffleClientImpl.class);
+ private static final int BUFFER_SIZE = 20;
+ private NetworkBufferPool networkBufferPool;
+ private BufferPool bufferPool;
+
+ @Before
+ public void setup() throws IOException {
+ remoteShuffleOutputGate.shuffleWriteClient = shuffleClient;
+ networkBufferPool = new NetworkBufferPool(10, BUFFER_SIZE);
+ bufferPool = networkBufferPool.createBufferPool(10, 10);
+ }
+
+ @Test
+ public void testSimpleWriteData() throws IOException, InterruptedException {
+ PartitionLocation partitionLocation =
+ new PartitionLocation(1, 0, "localhost", 123, 245, 789, 238, PartitionLocation.Mode.MASTER);
+ when(shuffleClient.registerMapPartitionTask(any(), anyInt(), anyInt(), anyInt(), anyInt()))
+ .thenAnswer(t -> partitionLocation);
+ doNothing()
+ .when(remoteShuffleOutputGate.shuffleWriteClient)
+ .pushDataHandShake(anyString(), anyInt(), anyInt(), anyInt(), anyInt(), anyInt(), any());
+
+ remoteShuffleOutputGate.handshake(true);
+
+ when(remoteShuffleOutputGate.shuffleWriteClient.regionStart(
+ any(), anyInt(), anyInt(), anyInt(), any(), anyInt(), anyBoolean()))
+ .thenAnswer(t -> Optional.empty());
+ remoteShuffleOutputGate.regionStart(false);
+
+ remoteShuffleOutputGate.write(bufferPool.requestBuffer(), 0);
+
+ doNothing()
+ .when(remoteShuffleOutputGate.shuffleWriteClient)
+ .regionFinish(any(), anyInt(), anyInt(), anyInt(), any());
+ remoteShuffleOutputGate.regionFinish();
+
+ doNothing()
+ .when(remoteShuffleOutputGate.shuffleWriteClient)
+ .mapperEnd(any(), anyInt(), anyInt(), anyInt(), anyInt());
+ remoteShuffleOutputGate.finish();
+
+ doNothing().when(remoteShuffleOutputGate.shuffleWriteClient).shutdown();
+ remoteShuffleOutputGate.close();
+ }
+
+ @Test
+ public void testNettyPoolTransform() {
+ Buffer buffer = bufferPool.requestBuffer();
+ ByteBuf byteBuf = buffer.asByteBuf();
+ byteBuf.writeByte(1);
+ Assert.assertEquals(1, byteBuf.refCnt());
+ io.netty.buffer.ByteBuf celebornByteBuf =
+ io.netty.buffer.Unpooled.wrappedBuffer(byteBuf.nioBuffer());
+ Assert.assertEquals(1, celebornByteBuf.refCnt());
+ celebornByteBuf.release();
+ byteBuf.release();
+ Assert.assertEquals(0, byteBuf.refCnt());
+ Assert.assertEquals(0, celebornByteBuf.refCnt());
+ }
+}
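Note: testNettyPoolTransform hands the Flink buffer's backing memory to Celeborn's unshaded Netty via Unpooled.wrappedBuffer(byteBuf.nioBuffer()); the wrapper shares the bytes but owns an independent reference count, so each side must be released separately. A standalone illustration, assuming only netty-buffer on the classpath:

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;

    public class WrappedRefCntDemo {
      public static void main(String[] args) {
        ByteBuf source = Unpooled.buffer(16);
        source.writeByte(1);
        // wrapping the nio view shares memory but not the reference count
        ByteBuf wrapped = Unpooled.wrappedBuffer(source.nioBuffer());
        wrapped.release();
        System.out.println(source.refCnt()); // still 1: counts are independent
        source.release();
        System.out.println(source.refCnt()); // 0
      }
    }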
diff --git a/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionSuiteJ.java b/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionSuiteJ.java
new file mode 100644
index 000000000..aa9a66070
--- /dev/null
+++ b/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionSuiteJ.java
@@ -0,0 +1,623 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.anyBoolean;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.time.Duration;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import java.util.Random;
+import java.util.Set;
+import java.util.stream.IntStream;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
+import org.apache.flink.runtime.io.network.buffer.BufferDecompressor;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.util.function.SupplierWithException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import org.apache.celeborn.client.LifecycleManager;
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.plugin.flink.buffer.BufferPacker;
+import org.apache.celeborn.plugin.flink.buffer.SortBuffer;
+import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
+import org.apache.celeborn.plugin.flink.utils.BufferUtils;
+
+public class RemoteShuffleResultPartitionSuiteJ {
+ private final int networkBufferSize = 32 * 1024;
+ private final String compressCodec = "LZ4";
+ private final BufferCompressor bufferCompressor = new BufferCompressor(networkBufferSize, compressCodec);
+ private final BufferDecompressor bufferDecompressor = new BufferDecompressor(networkBufferSize, compressCodec);
+ private RemoteShuffleOutputGate remoteShuffleOutputGate = mock(RemoteShuffleOutputGate.class);
+ private final CelebornConf conf = new CelebornConf();
+
+ private static final int totalBuffers = 1000;
+
+ private static final int bufferSize = 1024;
+
+ private NetworkBufferPool globalBufferPool;
+
+ private BufferPool sortBufferPool;
+
+ private BufferPool nettyBufferPool;
+
+ private RemoteShuffleResultPartition partitionWriter;
+
+ private FakedRemoteShuffleOutputGate outputGate;
+
+ @Before
+ public void setup() {
+ globalBufferPool = new NetworkBufferPool(totalBuffers, bufferSize);
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ if (outputGate != null) {
+ outputGate.release();
+ }
+
+ if (sortBufferPool != null) {
+ sortBufferPool.lazyDestroy();
+ }
+ if (nettyBufferPool != null) {
+ nettyBufferPool.lazyDestroy();
+ }
+ assertEquals(totalBuffers, globalBufferPool.getNumberOfAvailableMemorySegments());
+ globalBufferPool.destroy();
+ }
+
+ @Test
+ public void testSimpleFlush() throws IOException, InterruptedException {
+ List<SupplierWithException<BufferPool, IOException>> bufferPoolFactories = createBufferPoolFactory();
+ RemoteShuffleResultPartition remoteShuffleResultPartition =
+ new RemoteShuffleResultPartition(
+ "test",
+ 0,
+ new ResultPartitionID(),
+ ResultPartitionType.BLOCKING,
+ 2,
+ 2,
+ 32 * 1024,
+ new ResultPartitionManager(),
+ bufferCompressor,
+ bufferPoolFactories.get(0),
+ remoteShuffleOutputGate);
+ remoteShuffleResultPartition.setup();
+ doNothing().when(remoteShuffleOutputGate).regionStart(anyBoolean());
+ doNothing().when(remoteShuffleOutputGate).regionFinish();
+ when(remoteShuffleOutputGate.getBufferPool()).thenReturn(bufferPoolFactories.get(1).get());
+ SortBuffer sortBuffer = remoteShuffleResultPartition.getUnicastSortBuffer();
+ ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[] {1, 2, 3});
+ sortBuffer.append(byteBuffer, 0, Buffer.DataType.DATA_BUFFER);
+ remoteShuffleResultPartition.flushSortBuffer(sortBuffer, true);
+ }
+
+ private List<SupplierWithException<BufferPool, IOException>> createBufferPoolFactory() {
+ NetworkBufferPool networkBufferPool =
+ new NetworkBufferPool(256 * 8, 32 * 1024, Duration.ofMillis(1000));
+
+ int numBuffersPerPartition = 64 * 1024 / 32;
+ int numForResultPartition = numBuffersPerPartition * 7 / 8;
+ int numForOutputGate = numBuffersPerPartition - numForResultPartition;
+
+ List<SupplierWithException<BufferPool, IOException>> factories = new ArrayList<>();
+ factories.add(
+ () -> networkBufferPool.createBufferPool(numForResultPartition, numForResultPartition));
+ factories.add(() -> networkBufferPool.createBufferPool(numForOutputGate, numForOutputGate));
+ return factories;
+ }
+
+ @Test
+ public void testWriteNormalRecordWithCompressionEnabled() throws Exception {
+ testWriteNormalRecord(true);
+ }
+
+ @Test
+ public void testWriteNormalRecordWithCompressionDisabled() throws Exception {
+ testWriteNormalRecord(false);
+ }
+
+ @Test
+ public void testWriteLargeRecord() throws Exception {
+ int numSubpartitions = 2;
+ int numBuffers = 100;
+ initResultPartitionWriter(numSubpartitions, 10, 200, false, conf, 10);
+
+ partitionWriter.setup();
+
+ byte[] dataWritten = new byte[bufferSize * numBuffers];
+ Random random = new Random();
+ random.nextBytes(dataWritten);
+ ByteBuffer recordWritten = ByteBuffer.wrap(dataWritten);
+ partitionWriter.emitRecord(recordWritten, 0);
+ assertEquals(0, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+ partitionWriter.finish();
+ partitionWriter.close();
+
+ List<Buffer> receivedBuffers = outputGate.getReceivedBuffers()[0];
+
+ ByteBuffer recordRead = ByteBuffer.allocate(bufferSize * numBuffers);
+ for (Buffer buffer : receivedBuffers) {
+ if (buffer.isBuffer()) {
+ recordRead.put(
+ buffer.getNioBuffer(
+ BufferUtils.HEADER_LENGTH, buffer.readableBytes() - BufferUtils.HEADER_LENGTH));
+ }
+ }
+ recordWritten.rewind();
+ recordRead.flip();
+ assertEquals(recordWritten, recordRead);
+ }
+
+ @Test
+ public void testBroadcastLargeRecord() throws Exception {
+ int numSubpartitions = 2;
+ int numBuffers = 100;
+ initResultPartitionWriter(numSubpartitions, 10, 200, false, conf, 10);
+
+ partitionWriter.setup();
+
+ byte[] dataWritten = new byte[bufferSize * numBuffers];
+ Random random = new Random();
+ random.nextBytes(dataWritten);
+ ByteBuffer recordWritten = ByteBuffer.wrap(dataWritten);
+ partitionWriter.broadcastRecord(recordWritten);
+ assertEquals(0, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+ partitionWriter.finish();
+ partitionWriter.close();
+
+ ByteBuffer recordRead0 = ByteBuffer.allocate(bufferSize * numBuffers);
+ for (Buffer buffer : outputGate.getReceivedBuffers()[0]) {
+ if (buffer.isBuffer()) {
+ recordRead0.put(
+ buffer.getNioBuffer(
+ BufferUtils.HEADER_LENGTH, buffer.readableBytes() - BufferUtils.HEADER_LENGTH));
+ }
+ }
+ recordWritten.rewind();
+ recordRead0.flip();
+ assertEquals(recordWritten, recordRead0);
+
+ ByteBuffer recordRead1 = ByteBuffer.allocate(bufferSize * numBuffers);
+ for (Buffer buffer : outputGate.getReceivedBuffers()[1]) {
+ if (buffer.isBuffer()) {
+ recordRead1.put(
+ buffer.getNioBuffer(
+ BufferUtils.HEADER_LENGTH, buffer.readableBytes() - BufferUtils.HEADER_LENGTH));
+ }
+ }
+ recordWritten.rewind();
+ recordRead1.flip();
+ assertEquals(recordWritten, recordRead1);
+ }
+
+ @Test
+ public void testFlush() throws Exception {
+ int numSubpartitions = 10;
+
+ initResultPartitionWriter(numSubpartitions, 10, 20, false, conf, 10);
+ partitionWriter.setup();
+
+ partitionWriter.emitRecord(ByteBuffer.allocate(bufferSize), 0);
+ partitionWriter.emitRecord(ByteBuffer.allocate(bufferSize), 1);
+ assertEquals(3, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+ partitionWriter.broadcastRecord(ByteBuffer.allocate(bufferSize));
+ assertEquals(2, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+ partitionWriter.flush(0);
+ assertEquals(0, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+ partitionWriter.emitRecord(ByteBuffer.allocate(bufferSize), 2);
+ partitionWriter.emitRecord(ByteBuffer.allocate(bufferSize), 3);
+ assertEquals(3, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+ partitionWriter.flushAll();
+ assertEquals(0, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+ partitionWriter.finish();
+ partitionWriter.close();
+ }
+
+ private void testWriteNormalRecord(boolean compressionEnabled) throws Exception {
+ int numSubpartitions = 4;
+ int numRecords = 100;
+ Random random = new Random();
+
+ initResultPartitionWriter(numSubpartitions, 100, 500, compressionEnabled, conf, 10);
+ partitionWriter.setup();
+ assertTrue(outputGate.isSetup());
+
+ Queue<DataAndType>[] dataWritten = new Queue[numSubpartitions];
+ IntStream.range(0, numSubpartitions).forEach(i -> dataWritten[i] = new ArrayDeque<>());
+ int[] numBytesWritten = new int[numSubpartitions];
+ Arrays.fill(numBytesWritten, 0);
+
+ for (int i = 0; i < numRecords; i++) {
+ byte[] data = new byte[random.nextInt(2 * bufferSize) + 1];
+ if (compressionEnabled) {
+ byte randomByte = (byte) random.nextInt();
+ Arrays.fill(data, randomByte);
+ } else {
+ random.nextBytes(data);
+ }
+ ByteBuffer record = ByteBuffer.wrap(data);
+ boolean isBroadCast = random.nextBoolean();
+
+ if (isBroadCast) {
+ partitionWriter.broadcastRecord(record);
+ IntStream.range(0, numSubpartitions)
+ .forEach(
+ subpartition ->
+ recordDataWritten(
+ record,
+ Buffer.DataType.DATA_BUFFER,
+ subpartition,
+ dataWritten,
+ numBytesWritten));
+ } else {
+ int subpartition = random.nextInt(numSubpartitions);
+ partitionWriter.emitRecord(record, subpartition);
+ recordDataWritten(
+ record, Buffer.DataType.DATA_BUFFER, subpartition, dataWritten, numBytesWritten);
+ }
+ }
+
+ partitionWriter.finish();
+ assertTrue(outputGate.isFinished());
+ partitionWriter.close();
+ assertTrue(outputGate.isClosed());
+
+ for (int subpartition = 0; subpartition < numSubpartitions; ++subpartition) {
+ ByteBuffer record = EventSerializer.toSerializedEvent(EndOfPartitionEvent.INSTANCE);
+ recordDataWritten(
+ record, Buffer.DataType.EVENT_BUFFER, subpartition, dataWritten, numBytesWritten);
+ }
+
+ outputGate
+ .getFinishedRegions()
+ .forEach(
+ regionIndex -> assertTrue(outputGate.getNumBuffersByRegion().containsKey(regionIndex)));
+
+ int[] numBytesRead = new int[numSubpartitions];
+ List<Buffer>[] receivedBuffers = outputGate.getReceivedBuffers();
+ List<Buffer>[] validateTarget = new List[numSubpartitions];
+ Arrays.fill(numBytesRead, 0);
+ for (int i = 0; i < numSubpartitions; i++) {
+ validateTarget[i] = new ArrayList<>();
+ for (Buffer buffer : receivedBuffers[i]) {
+ for (Buffer unpackedBuffer : BufferPacker.unpack(buffer.asByteBuf())) {
+ if (compressionEnabled && unpackedBuffer.isCompressed()) {
+ Buffer decompressedBuffer =
+ bufferDecompressor.decompressToIntermediateBuffer(unpackedBuffer);
+ ByteBuffer decompressed = decompressedBuffer.getNioBufferReadable();
+ int numBytes = decompressed.remaining();
+ MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(numBytes);
+ segment.put(0, decompressed, numBytes);
+ decompressedBuffer.recycleBuffer();
+ validateTarget[i].add(
+ new NetworkBuffer(segment, buf -> {}, unpackedBuffer.getDataType(), numBytes));
+ numBytesRead[i] += numBytes;
+ } else {
+ numBytesRead[i] += unpackedBuffer.readableBytes();
+ validateTarget[i].add(unpackedBuffer);
+ }
+ }
+ }
+ }
+ checkWriteReadResult(
+ numSubpartitions, numBytesWritten, numBytesRead, dataWritten, validateTarget);
+ }
+
+ private void initResultPartitionWriter(
+ int numSubpartitions,
+ int sortBufferPoolSize,
+ int nettyBufferPoolSize,
+ boolean compressionEnabled,
+ CelebornConf conf,
+ int numMappers)
+ throws Exception {
+
+ sortBufferPool = globalBufferPool.createBufferPool(sortBufferPoolSize, sortBufferPoolSize);
+ nettyBufferPool = globalBufferPool.createBufferPool(nettyBufferPoolSize, nettyBufferPoolSize);
+
+ outputGate =
+ new FakedRemoteShuffleOutputGate(
+ getShuffleDescriptor(), numSubpartitions, () -> nettyBufferPool, conf, numMappers);
+ outputGate.setup();
+
+ if (compressionEnabled) {
+ partitionWriter =
+ new RemoteShuffleResultPartition(
+ "RemoteShuffleResultPartitionWriterTest",
+ 0,
+ new ResultPartitionID(),
+ ResultPartitionType.BLOCKING,
+ numSubpartitions,
+ numSubpartitions,
+ bufferSize,
+ new ResultPartitionManager(),
+ bufferCompressor,
+ () -> sortBufferPool,
+ outputGate);
+ } else {
+ partitionWriter =
+ new RemoteShuffleResultPartition(
+ "RemoteShuffleResultPartitionWriterTest",
+ 0,
+ new ResultPartitionID(),
+ ResultPartitionType.BLOCKING,
+ numSubpartitions,
+ numSubpartitions,
+ bufferSize,
+ new ResultPartitionManager(),
+ null,
+ () -> sortBufferPool,
+ outputGate);
+ }
+ }
+
+ private void recordDataWritten(
+ ByteBuffer record,
+ Buffer.DataType dataType,
+ int subpartition,
+ Queue<DataAndType>[] dataWritten,
+ int[] numBytesWritten) {
+
+ record.rewind();
+ dataWritten[subpartition].add(new DataAndType(record, dataType));
+ numBytesWritten[subpartition] += record.remaining();
+ }
+
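+ /**
+ * Test double for {@link RemoteShuffleOutputGate} that records written buffers
+ * per subpartition and tracks the region lifecycle in memory instead of
+ * talking to a real Celeborn cluster.
+ */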
+ private static class FakedRemoteShuffleOutputGate extends RemoteShuffleOutputGate {
+
+ private boolean isSetup;
+ private boolean isFinished;
+ private boolean isClosed;
+ private final List<Buffer>[] receivedBuffers;
+ private final Map<Integer, Integer> numBuffersByRegion;
+ private final Set<Integer> finishedRegions;
+ private int currentRegionIndex;
+ private boolean currentIsBroadcast;
+
+ FakedRemoteShuffleOutputGate(
+ RemoteShuffleDescriptor shuffleDescriptor,
+ int numSubpartitions,
+ SupplierWithException<BufferPool, IOException> bufferPoolFactory,
+ CelebornConf celebornConf,
+ int numMappers) {
+
+ super(
+ shuffleDescriptor,
+ numSubpartitions,
+ bufferSize,
+ bufferPoolFactory,
+ celebornConf,
+ numMappers);
+ isSetup = false;
+ isFinished = false;
+ isClosed = false;
+ numBuffersByRegion = new HashMap<>();
+ finishedRegions = new HashSet<>();
+ currentRegionIndex = -1;
+ receivedBuffers = new ArrayList[numSubpartitions];
+ IntStream.range(0, numSubpartitions).forEach(i -> receivedBuffers[i] = new ArrayList<>());
+ currentIsBroadcast = false;
+ }
+
+ @Override
+ FlinkShuffleClientImpl createWriteClient() {
+ FlinkShuffleClientImpl client = mock(FlinkShuffleClientImpl.class);
+ doNothing().when(client).cleanup(anyString(), anyInt(), anyInt(), anyInt());
+ return client;
+ }
+
+ @Override
+ public void setup() throws IOException, InterruptedException {
+ bufferPool = bufferPoolFactory.get();
+ isSetup = true;
+ }
+
+ @Override
+ public void write(Buffer buffer, int subIdx) {
+ if (currentIsBroadcast) {
+ assertEquals(0, subIdx);
+ ByteBuffer byteBuffer = buffer.getNioBufferReadable();
+ for (int i = 0; i < numSubs; i++) {
+ int numBytes = buffer.readableBytes();
+ MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(numBytes);
+ byteBuffer.rewind();
+ segment.put(0, byteBuffer, numBytes);
+ receivedBuffers[i].add(
+ new NetworkBuffer(
+ segment, buf -> {}, buffer.getDataType(), buffer.isCompressed(), numBytes));
+ }
+ buffer.recycleBuffer();
+ } else {
+ receivedBuffers[subIdx].add(buffer);
+ }
+ if (numBuffersByRegion.containsKey(currentRegionIndex)) {
+ int prev = numBuffersByRegion.get(currentRegionIndex);
+ numBuffersByRegion.put(currentRegionIndex, prev + 1);
+ } else {
+ numBuffersByRegion.put(currentRegionIndex, 1);
+ }
+ }
+
+ @Override
+ public void regionStart(boolean isBroadcast) {
+ currentIsBroadcast = isBroadcast;
+ currentRegionIndex++;
+ }
+
+ @Override
+ public void regionFinish() {
+ if (finishedRegions.contains(currentRegionIndex)) {
+ throw new IllegalStateException("Unexpected region: " + currentRegionIndex);
+ }
+ finishedRegions.add(currentRegionIndex);
+ }
+
+ @Override
+ public void finish() throws InterruptedException {
+ isFinished = true;
+ }
+
+ @Override
+ public void close() {
+ isClosed = true;
+ }
+
+ public List<Buffer>[] getReceivedBuffers() {
+ return receivedBuffers;
+ }
+
+ public Map<Integer, Integer> getNumBuffersByRegion() {
+ return numBuffersByRegion;
+ }
+
+ public Set<Integer> getFinishedRegions() {
+ return finishedRegions;
+ }
+
+ public boolean isSetup() {
+ return isSetup;
+ }
+
+ public boolean isFinished() {
+ return isFinished;
+ }
+
+ public boolean isClosed() {
+ return isClosed;
+ }
+
+ public void release() throws Exception {
+ IntStream.range(0, numSubs)
+ .forEach(
+ subpartitionIndex -> {
+ receivedBuffers[subpartitionIndex].forEach(Buffer::recycleBuffer);
+ receivedBuffers[subpartitionIndex].clear();
+ });
+ numBuffersByRegion.clear();
+ finishedRegions.clear();
+ super.close();
+ }
+ }
+
+ private RemoteShuffleDescriptor getShuffleDescriptor() throws Exception {
+ Random random = new Random();
+ byte[] bytes = new byte[16];
+ random.nextBytes(bytes);
+ LifecycleManager.ShuffleTask shuffleTask = Mockito.mock(LifecycleManager.ShuffleTask.class);
+ Mockito.when(shuffleTask.attemptId()).thenReturn(1);
+ Mockito.when(shuffleTask.mapId()).thenReturn(1);
+ Mockito.when(shuffleTask.shuffleId()).thenReturn(1);
+ return new RemoteShuffleDescriptor(
+ new JobID(bytes).toString(),
+ new JobID(bytes).toString(),
+ new ResultPartitionID(),
+ new RemoteShuffleResource("1", 2, new ShuffleResourceDescriptor(shuffleTask)));
+ }
+
+ /** Data written and its {@link Buffer.DataType}. */
+ public static class DataAndType {
+ private final ByteBuffer data;
+ private final Buffer.DataType dataType;
+
+ DataAndType(ByteBuffer data, Buffer.DataType dataType) {
+ this.data = data;
+ this.dataType = dataType;
+ }
+ }
+
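+ /**
+ * Verifies that, for every subpartition, the bytes and events read back match
+ * the bytes and events originally written, in order.
+ */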
+ public static void checkWriteReadResult(
+ int numSubpartitions,
+ int[] numBytesWritten,
+ int[] numBytesRead,
+ Queue<DataAndType>[] dataWritten,
+ Collection<Buffer>[] buffersRead) {
+ for (int subpartitionIndex = 0; subpartitionIndex < numSubpartitions; ++subpartitionIndex) {
+ assertEquals(numBytesWritten[subpartitionIndex], numBytesRead[subpartitionIndex]);
+
+ List<DataAndType> eventsWritten = new ArrayList<>();
+ List<Buffer> eventsRead = new ArrayList<>();
+
+ ByteBuffer subpartitionDataWritten = ByteBuffer.allocate(numBytesWritten[subpartitionIndex]);
+ for (DataAndType dataAndType : dataWritten[subpartitionIndex]) {
+ subpartitionDataWritten.put(dataAndType.data);
+ dataAndType.data.rewind();
+ if (dataAndType.dataType.isEvent()) {
+ eventsWritten.add(dataAndType);
+ }
+ }
+
+ ByteBuffer subpartitionDataRead = ByteBuffer.allocate(numBytesRead[subpartitionIndex]);
+ for (Buffer buffer : buffersRead[subpartitionIndex]) {
+ subpartitionDataRead.put(buffer.getNioBufferReadable());
+ if (!buffer.isBuffer()) {
+ eventsRead.add(buffer);
+ }
+ }
+
+ subpartitionDataWritten.flip();
+ subpartitionDataRead.flip();
+ assertEquals(subpartitionDataWritten, subpartitionDataRead);
+
+ assertEquals(eventsWritten.size(), eventsRead.size());
+ for (int i = 0; i < eventsWritten.size(); ++i) {
+ assertEquals(eventsWritten.get(i).dataType, eventsRead.get(i).getDataType());
+ assertEquals(eventsWritten.get(i).data, eventsRead.get(i).getNioBufferReadable());
+ }
+ }
+ }
+}
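Note: testWriteLargeRecord and testBroadcastLargeRecord assert that sortBufferPool stays empty because a record larger than the whole sort buffer bypasses it and is pushed directly in bufferSize-sized slices, each later prefixed with a BufferUtils.HEADER_LENGTH header, which is why the readers above skip it. A standalone sketch of that slicing, using the same 1024-byte buffer size as the suite:

    import java.nio.ByteBuffer;

    public class LargeRecordChunking {
      public static void main(String[] args) {
        int bufferSize = 1024;
        ByteBuffer record = ByteBuffer.allocate(10 * bufferSize + 17);
        int chunks = 0;
        while (record.hasRemaining()) {
          int length = Math.min(bufferSize, record.remaining());
          record.position(record.position() + length);
          chunks++; // each slice becomes one network buffer on the wire
        }
        System.out.println(chunks + " chunks"); // prints "11 chunks"
      }
    }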
diff --git a/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactorySuitJ.java b/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactorySuitJ.java
new file mode 100644
index 000000000..9a4f232f7
--- /dev/null
+++ b/client-flink/flink-1.15/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactorySuitJ.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.MemorySize;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
+import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
+import org.apache.flink.runtime.shuffle.ShuffleEnvironmentContext;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class RemoteShuffleServiceFactorySuitJ {
+ @Test
+ public void testCreateShuffleEnvironment() {
+ RemoteShuffleServiceFactory remoteShuffleServiceFactory = new RemoteShuffleServiceFactory();
+ ShuffleEnvironmentContext shuffleEnvironmentContext = mock(ShuffleEnvironmentContext.class);
+ when(shuffleEnvironmentContext.getConfiguration()).thenReturn(new Configuration());
+ when(shuffleEnvironmentContext.getNetworkMemorySize())
+ .thenReturn(new MemorySize(64 * 1024 * 1024));
+ MetricGroup parentMetric = mock(MetricGroup.class);
+ when(shuffleEnvironmentContext.getParentMetricGroup()).thenReturn(parentMetric);
+ MetricGroup childGroup = mock(MetricGroup.class);
+ MetricGroup childChildGroup = mock(MetricGroup.class);
+ when(parentMetric.addGroup(anyString())).thenReturn(childGroup);
+ when(childGroup.addGroup(any())).thenReturn(childChildGroup);
+ when(childChildGroup.gauge(any(), any())).thenReturn(null);
+ ShuffleEnvironment<ResultPartitionWriter, IndexedInputGate> shuffleEnvironment =
+ remoteShuffleServiceFactory.createShuffleEnvironment(shuffleEnvironmentContext);
+ Assert.assertEquals(
+ 32 * 1024,
+ ((RemoteShuffleEnvironment) shuffleEnvironment)
+ .getResultPartitionFactory()
+ .getNetworkBufferSize());
+ }
+}
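Note: the test hands the factory 64 MiB of network memory and then asserts a 32 KiB network buffer size on the resulting partition factory. Assuming the factory sizes its NetworkBufferPool as memory divided by buffer size, which is the usual Flink convention, that works out to 2048 buffers:

    public class NetworkBufferMath {
      public static void main(String[] args) {
        long networkMemoryBytes = 64L * 1024 * 1024; // MemorySize passed to the mocked context
        int networkBufferSize = 32 * 1024;           // value the test asserts on
        System.out.println(networkMemoryBytes / networkBufferSize + " buffers"); // prints "2048 buffers"
      }
    }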
diff --git a/pom.xml b/pom.xml
index 667c9a09d..8f55b91f8 100644
--- a/pom.xml
+++ b/pom.xml
@@ -64,6 +64,11 @@
<maven.version>3.6.3</maven.version>
<flink.version>1.14.6</flink.version>
+ <flink-streaming-java-artifactId>flink-streaming-java_${scala.binary.version}</flink-streaming-java-artifactId>
+ <flink-clients-artifactId>flink-clients_${scala.binary.version}</flink-clients-artifactId>
+ <flink-runtime-web-artifactId>flink-runtime-web_${scala.binary.version}</flink-runtime-web-artifactId>
+ <flink-plugin-artifactId>celeborn-client-flink-1.14_${scala.binary.version}</flink-plugin-artifactId>
+
<hadoop.version>3.2.1</hadoop.version>
<spark.version>3.3.1</spark.version>
@@ -168,17 +173,17 @@
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
- <artifactId>flink-streaming-java_${scala.binary.version}</artifactId>
+ <artifactId>${flink-streaming-java-artifactId}</artifactId>
<version>${flink.version}</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
- <artifactId>flink-clients_${scala.binary.version}</artifactId>
+ <artifactId>${flink-clients-artifactId}</artifactId>
<version>${flink.version}</version>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
- <artifactId>flink-runtime-web_${scala.binary.version}</artifactId>
+ <artifactId>${flink-runtime-web-artifactId}</artifactId>
<version>${flink.version}</version>
</dependency>
<dependency>
@@ -1067,8 +1072,29 @@
</modules>
<properties>
<flink.version>1.14.6</flink.version>
+ <flink.binary.version>1.14</flink.binary.version>
+ <scala.version>2.12.7</scala.version>
+ <scala.binary.version>2.12</scala.binary.version>
+ </properties>
+ </profile>
+
+ <profile>
+ <id>flink-1.15</id>
+ <modules>
+ <module>client-flink/common</module>
+ <module>client-flink/flink-1.15</module>
+ <module>client-flink/flink-1.15-shaded</module>
+ <module>tests/flink-it</module>
+ </modules>
+ <properties>
+ <flink.version>1.15.3</flink.version>
+ <flink.binary.version>1.15</flink.binary.version>
<scala.version>2.12.7</scala.version>
<scala.binary.version>2.12</scala.binary.version>
+ <flink-streaming-java-artifactId>flink-streaming-java</flink-streaming-java-artifactId>
+ <flink-clients-artifactId>flink-clients</flink-clients-artifactId>
+ <flink-runtime-web-artifactId>flink-runtime-web</flink-runtime-web-artifactId>
+ <flink-plugin-artifactId>celeborn-client-flink-1.15</flink-plugin-artifactId>
</properties>
</profile>
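Note: Flink 1.15 stopped publishing flink-streaming-java, flink-clients and flink-runtime-web under Scala-suffixed artifact ids, which is why this profile overrides the *-artifactId properties with the suffix-free names while the flink-1.14 profile keeps the _${scala.binary.version} variants. With the profile in place, a 1.15 build is activated by passing -Pflink-1.15 to Maven, for example: mvn install -Pflink-1.15.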
diff --git a/tests/flink-it/pom.xml b/tests/flink-it/pom.xml
index 33be3e20e..efe1c193f 100644
--- a/tests/flink-it/pom.xml
+++ b/tests/flink-it/pom.xml
@@ -75,7 +75,7 @@
</dependency>
<dependency>
<groupId>org.apache.celeborn</groupId>
- <artifactId>celeborn-client-flink-1.14_${scala.binary.version}</artifactId>
+ <artifactId>${flink-plugin-artifactId}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
@@ -92,13 +92,13 @@
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
- <artifactId>flink-streaming-java_${scala.binary.version}</artifactId>
+ <artifactId>${flink-streaming-java-artifactId}</artifactId>
<version>${flink.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
- <artifactId>flink-clients_${scala.binary.version}</artifactId>
+ <artifactId>${flink-clients-artifactId}</artifactId>
<version>${flink.version}</version>
<scope>test</scope>
</dependency>
@@ -110,7 +110,7 @@
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
- <artifactId>flink-runtime-web_${scala.binary.version}</artifactId>
+ <artifactId>${flink-runtime-web-artifactId}</artifactId>
<version>${flink.version}</version>
<scope>test</scope>
</dependency>