You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@celeborn.apache.org by re...@apache.org on 2023/03/29 02:23:37 UTC
[incubator-celeborn] 22/42: [CELEBORN-418][FLINK] Need drop unused bytes from netty when task was already failed (#1348)
This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
commit a52bc8c0d2c4bc68f9f80b49827363bc0cae716b
Author: Shuang <lv...@gmail.com>
AuthorDate: Tue Mar 14 19:48:41 2023 +0800
[CELEBORN-418][FLINK] Need drop unused bytes from netty when task was already failed (#1348)
---
.../TransportFrameDecoderWithBufferSupplier.java | 50 +++++---
...nsportFrameDecoderWithBufferSupplierSuiteJ.java | 127 +++++++++++++++++++++
2 files changed, 160 insertions(+), 17 deletions(-)
diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
index 01283547a..eacd375d1 100644
--- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
+++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
@@ -17,8 +17,6 @@
package org.apache.celeborn.plugin.flink.network;
-import static org.apache.celeborn.plugin.flink.utils.Utils.checkState;
-
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
@@ -29,7 +27,6 @@ import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
import org.apache.celeborn.common.network.protocol.Message;
import org.apache.celeborn.common.network.util.FrameDecoder;
import org.apache.celeborn.plugin.flink.protocol.ReadData;
@@ -46,6 +43,8 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl
private ByteBuf externalBuf = null;
private final ByteBuf msgBuf = Unpooled.buffer(8);
private Message curMsg = null;
+ private int remainingSize = -1;
+
private final ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers;
public TransportFrameDecoderWithBufferSupplier(
@@ -60,6 +59,18 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl
}
}
+ // Discards body bytes that nobody will consume; remainingSize counts how many are still owed.
+ private void dropUnusedBytes(io.netty.buffer.ByteBuf source) {
+ int available = source.readableBytes();
+ if (available > 0 && remainingSize > available) {
+ remainingSize -= available; // body continues in a later buffer
+ source.skipBytes(available);
+ } else if (available > 0) {
+ source.skipBytes(remainingSize); // the tail of the body sits in this buffer
+ clear();
+ }
+ }
+
private void decodeHeader(io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) {
copyByteBuf(buf, headerBuf, HEADER_SIZE);
if (!headerBuf.isWritable()) {
@@ -121,12 +132,24 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl
private io.netty.buffer.ByteBuf decodeBodyCopyOut(
io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) {
+ if (remainingSize > 0) {
+ dropUnusedBytes(buf);
+ return buf;
+ }
+
ReadData readData = (ReadData) curMsg;
+ long streamId = readData.getStreamId();
if (externalBuf == null) {
- Supplier<ByteBuf> supplier = bufferSuppliers.get(readData.getStreamId());
- checkState(supplier != null, "Stream " + readData.getStreamId() + " buffer supplier is null");
- externalBuf = bufferSuppliers.get(readData.getStreamId()).get();
+ Supplier<ByteBuf> supplier = bufferSuppliers.get(streamId);
+ if (supplier == null) {
+ logger.warn("Need drop unused bytes, streamId: {}, bodySize: {}", streamId, bodySize);
+ remainingSize = bodySize;
+ dropUnusedBytes(buf);
+ return buf;
+ }
+ externalBuf = supplier.get();
}
+
copyByteBuf(buf, externalBuf, bodySize);
if (externalBuf.readableBytes() == bodySize) {
((ReadData) curMsg).setFlinkBuffer(externalBuf);
@@ -146,20 +169,13 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl
decodeMsg(nettyBuf, ctx);
} else if (bodySize > 0) {
if (curMsg.needCopyOut()) {
- // Only readdata will enter this branch
+ // Only read data will enter this branch
nettyBuf = decodeBodyCopyOut(nettyBuf, ctx);
} else {
nettyBuf = decodeBody(nettyBuf, ctx);
}
}
}
- } catch (IllegalStateException e) {
- // Decode ReadData might encounter IllegalStateException.
- long streamId = ((ReadData) curMsg).getStreamId();
- logger.info("Stream {} is closed,reply to server", streamId);
- if (ctx.channel().isActive()) {
- ctx.channel().writeAndFlush(new BufferStreamEnd(streamId));
- }
} finally {
if (nettyBuf != null) {
nettyBuf.release();
@@ -174,6 +190,7 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl
headerBuf.clear();
bodyBuf = null;
bodySize = -1;
+ remainingSize = -1;
}
@Override
@@ -184,10 +201,9 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
clear();
- if (externalBuf != null) {
- externalBuf.clear();
- }
+
headerBuf.release();
+ msgBuf.release();
super.handlerRemoved(ctx);
}
diff --git a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
new file mode 100644
index 000000000..150711624
--- /dev/null
+++ b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.network;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Supplier;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import org.junit.Assert;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
+import org.apache.celeborn.common.network.protocol.Message;
+import org.apache.celeborn.common.network.protocol.ReadData;
+
+public class TransportFrameDecoderWithBufferSupplierSuiteJ {
+
+ @Test
+ public void testDropUnusedBytes() throws IOException {
+ ConcurrentHashMap<Long, Supplier<org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf>>
+ supplier = new ConcurrentHashMap<>();
+ List<org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf> buffers = new ArrayList<>();
+ // Only key 2 gets a supplier; every buffer it hands out is recorded so it can be checked later.
+ supplier.put(
+ 2L,
+ () -> {
+ org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf buffer =
+ org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled.buffer(32000);
+ buffers.add(buffer);
+ return buffer;
+ });
+
+ TransportFrameDecoderWithBufferSupplier decoder =
+ new TransportFrameDecoderWithBufferSupplier(supplier);
+ ChannelHandlerContext context = Mockito.mock(ChannelHandlerContext.class);
+ // The ReadData built with 1 presumably targets a stream with no supplier, so it should be dropped.
+ BacklogAnnouncement announcement = new BacklogAnnouncement(0, 0);
+ ReadData unUsedReadData = new ReadData(1, 8, 0, generateData(1024));
+ ReadData readData = new ReadData(2, 8, 0, generateData(1024));
+ BacklogAnnouncement announcement1 = new BacklogAnnouncement(0, 0);
+ ReadData unUsedReadData1 = new ReadData(1, 8, 0, generateData(1024));
+ ReadData readData1 = new ReadData(2, 8, 0, generateData(8));
+
+ ByteBuf buffer = Unpooled.buffer(5000);
+ encodeMessage(announcement, buffer);
+ encodeMessage(unUsedReadData, buffer);
+ encodeMessage(readData, buffer);
+ encodeMessage(announcement1, buffer);
+ encodeMessage(unUsedReadData1, buffer);
+ encodeMessage(readData1, buffer);
+
+ // simulate 0 - all six frames arrive in one buffer
+ buffer.retain(); // keep a reference: the decoder appears to release the buffer it is handed
+ decoder.channelRead(context, buffer);
+ Assert.assertEquals(buffers.get(0).nioBuffer(), readData.body().nioByteBuffer());
+ Assert.assertEquals(buffers.get(1).nioBuffer(), readData1.body().nioByteBuffer());
+
+ // simulate 1 - same bytes in two reads, split at offset 555 (inside the unUsedReadData frame)
+ buffer.retain();
+ buffer.resetReaderIndex();
+ decoder.channelRead(context, buffer.retainedSlice(0, 555));
+ ByteBuf byteBuf = buffer.retainedSlice(0, buffer.readableBytes());
+ byteBuf.readerIndex(555);
+ decoder.channelRead(context, byteBuf);
+
+ Assert.assertEquals(buffers.get(2).nioBuffer(), readData.body().nioByteBuffer());
+ Assert.assertEquals(buffers.get(3).nioBuffer(), readData1.body().nioByteBuffer());
+
+ // simulate 2 - same bytes in two reads, split at offset 1500 (inside the readData frame)
+ buffer.retain();
+ buffer.resetReaderIndex();
+ decoder.channelRead(context, buffer.retainedSlice(0, 1500));
+ byteBuf = buffer.retainedSlice(0, buffer.readableBytes());
+ byteBuf.readerIndex(1500);
+ decoder.channelRead(context, byteBuf);
+ // Two supplied buffers per pass, three passes: none may have been requested for the unused data.
+ Assert.assertEquals(buffers.get(4).nioBuffer(), readData.body().nioByteBuffer());
+ Assert.assertEquals(buffers.get(5).nioBuffer(), readData1.body().nioByteBuffer());
+ Assert.assertEquals(6, buffers.size());
+ }
+
+ // Appends one frame to byteBuf: encoded length, type, body size, encoded message, then the body.
+ public ByteBuf encodeMessage(Message in, ByteBuf byteBuf) throws IOException {
+ byteBuf.writeInt(in.encodedLength());
+ in.type().encode(byteBuf);
+
+ boolean hasBody = in.body() != null;
+ // A message without a body still carries the size field, written as 0.
+ byteBuf.writeInt(hasBody ? (int) in.body().size() : 0);
+ in.encode(byteBuf);
+ if (hasBody) {
+ byteBuf.writeBytes(in.body().nioByteBuffer());
+ }
+ return byteBuf;
+ }
+
+ public ByteBuf generateData(int size) { // returns a buffer holding size random bytes in [0, 7)
+ ByteBuf data = Unpooled.buffer(size);
+ Random random = new Random(); // one generator for the whole payload, not one per byte
+ for (int i = 0; i < size; i++) {
+ data.writeByte(random.nextInt(7));
+ }
+ return data;
+ }
+}