Posted to commits@celeborn.apache.org by re...@apache.org on 2023/03/29 02:23:37 UTC

[incubator-celeborn] 22/42: [CELEBORN-418][FLINK] Need drop unused bytes from netty when task was already failed (#1348)

This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git

commit a52bc8c0d2c4bc68f9f80b49827363bc0cae716b
Author: Shuang <lv...@gmail.com>
AuthorDate: Tue Mar 14 19:48:41 2023 +0800

    [CELEBORN-418][FLINK] Need drop unused bytes from netty when task was already failed (#1348)
---
 .../TransportFrameDecoderWithBufferSupplier.java   |  50 +++++---
 ...nsportFrameDecoderWithBufferSupplierSuiteJ.java | 127 +++++++++++++++++++++
 2 files changed, 160 insertions(+), 17 deletions(-)
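
The change itself is small: when a Flink task has already failed or been
cancelled, the buffer supplier for its stream may no longer be present in
bufferSuppliers, yet the server can still be pushing ReadData frames for
that stream. Previously the decoder tripped the checkState, threw
IllegalStateException, and replied with BufferStreamEnd from the catch
block in channelRead; with this patch it instead records the frame's body
size in a new remainingSize field and silently skips those bytes, even
when the body arrives split across several network reads. A minimal
standalone sketch of that skip-across-reads pattern (class and method
names here are illustrative, not Celeborn's actual API):

    import io.netty.buffer.ByteBuf;

    // Illustrative sketch of the drop logic this commit adds to
    // TransportFrameDecoderWithBufferSupplier.
    class DropUnusedBytesSketch {
      private int remainingSize = -1; // -1 means nothing pending

      // Arm the drop when no buffer supplier exists for the frame's stream.
      void armDrop(int bodySize) {
        remainingSize = bodySize;
      }

      // Discard pending bytes from an inbound buffer; called repeatedly
      // until the whole unwanted body has been consumed.
      void dropUnusedBytes(ByteBuf source) {
        if (source.readableBytes() > 0) {
          if (remainingSize > source.readableBytes()) {
            // Body split across reads: eat this chunk, wait for the rest.
            remainingSize -= source.readableBytes();
            source.skipBytes(source.readableBytes());
          } else {
            // The tail of the unwanted body is here: skip it, reset state.
            source.skipBytes(remainingSize);
            remainingSize = -1;
          }
        }
      }
    }

The key design point is that remainingSize is decoder state rather than a
local variable: Netty makes no guarantee that a frame body arrives within
a single channelRead call, so the "how much is still owed" counter has to
survive across invocations (and is reset in clear()).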

diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
index 01283547a..eacd375d1 100644
--- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
+++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
@@ -17,8 +17,6 @@
 
 package org.apache.celeborn.plugin.flink.network;
 
-import static org.apache.celeborn.plugin.flink.utils.Utils.checkState;
-
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Supplier;
 
@@ -29,7 +27,6 @@ import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
 import org.apache.celeborn.common.network.protocol.Message;
 import org.apache.celeborn.common.network.util.FrameDecoder;
 import org.apache.celeborn.plugin.flink.protocol.ReadData;
@@ -46,6 +43,8 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl
   private ByteBuf externalBuf = null;
   private final ByteBuf msgBuf = Unpooled.buffer(8);
   private Message curMsg = null;
+  private int remainingSize = -1;
+
   private final ConcurrentHashMap<Long, Supplier<ByteBuf>> bufferSuppliers;
 
   public TransportFrameDecoderWithBufferSupplier(
@@ -60,6 +59,18 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl
     }
   }
 
+  private void dropUnusedBytes(io.netty.buffer.ByteBuf source) {
+    if (source.readableBytes() > 0) {
+      if (remainingSize > source.readableBytes()) {
+        remainingSize = remainingSize - source.readableBytes();
+        source.skipBytes(source.readableBytes());
+      } else {
+        source.skipBytes(remainingSize);
+        clear();
+      }
+    }
+  }
+
   private void decodeHeader(io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) {
     copyByteBuf(buf, headerBuf, HEADER_SIZE);
     if (!headerBuf.isWritable()) {
@@ -121,12 +132,24 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl
 
   private io.netty.buffer.ByteBuf decodeBodyCopyOut(
       io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) {
+    if (remainingSize > 0) {
+      dropUnusedBytes(buf);
+      return buf;
+    }
+
     ReadData readData = (ReadData) curMsg;
+    long streamId = readData.getStreamId();
     if (externalBuf == null) {
-      Supplier<ByteBuf> supplier = bufferSuppliers.get(readData.getStreamId());
-      checkState(supplier != null, "Stream " + readData.getStreamId() + " buffer supplier is null");
-      externalBuf = bufferSuppliers.get(readData.getStreamId()).get();
+      Supplier<ByteBuf> supplier = bufferSuppliers.get(streamId);
+      if (supplier == null) {
+        logger.warn("Need drop unused bytes, streamId: {}, bodySize: {}", streamId, bodySize);
+        remainingSize = bodySize;
+        dropUnusedBytes(buf);
+        return buf;
+      }
+      externalBuf = supplier.get();
     }
+
     copyByteBuf(buf, externalBuf, bodySize);
     if (externalBuf.readableBytes() == bodySize) {
       ((ReadData) curMsg).setFlinkBuffer(externalBuf);
@@ -146,20 +169,13 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl
           decodeMsg(nettyBuf, ctx);
         } else if (bodySize > 0) {
           if (curMsg.needCopyOut()) {
-            // Only readdata will enter this branch
+            // Only read data will enter this branch
             nettyBuf = decodeBodyCopyOut(nettyBuf, ctx);
           } else {
             nettyBuf = decodeBody(nettyBuf, ctx);
           }
         }
       }
-    } catch (IllegalStateException e) {
-      // Decode ReadData might encounter IllegalStateException.
-      long streamId = ((ReadData) curMsg).getStreamId();
-      logger.info("Stream {} is closed,reply to server", streamId);
-      if (ctx.channel().isActive()) {
-        ctx.channel().writeAndFlush(new BufferStreamEnd(streamId));
-      }
     } finally {
       if (nettyBuf != null) {
         nettyBuf.release();
@@ -174,6 +190,7 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl
     headerBuf.clear();
     bodyBuf = null;
     bodySize = -1;
+    remainingSize = -1;
   }
 
   @Override
@@ -184,10 +201,9 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl
   @Override
   public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
     clear();
-    if (externalBuf != null) {
-      externalBuf.clear();
-    }
+
     headerBuf.release();
+    msgBuf.release();
     super.handlerRemoved(ctx);
   }
 
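The new test below encodes a stream of six frames (two of them ReadData
for stream 1, whose supplier is deliberately absent) and replays it three
ways against a Mockito-mocked ChannelHandlerContext: in one piece, split
inside the unUsedReadData body (at byte 555), and split at byte 1500. As
a hedged alternative harness, Netty's EmbeddedChannel can drive the same
split-read scenarios; SplitReadHarness and replaySplit below are
hypothetical helpers, not part of this patch:

    import io.netty.buffer.ByteBuf;
    import io.netty.channel.ChannelHandler;
    import io.netty.channel.embedded.EmbeddedChannel;

    // Hypothetical helper: replays one encoded stream (readerIndex 0)
    // into a handler as two chunks split at an arbitrary offset.
    final class SplitReadHarness {
      static void replaySplit(ChannelHandler handler, ByteBuf stream, int splitAt) {
        EmbeddedChannel channel = new EmbeddedChannel(handler);
        channel.writeInbound(stream.retainedSlice(0, splitAt));
        channel.writeInbound(
            stream.retainedSlice(splitAt, stream.readableBytes() - splitAt));
        channel.finishAndReleaseAll();
      }
    }

Splitting at an arbitrary offset exercises both branches of
dropUnusedBytes: the chunk-smaller-than-remainingSize path and the
skip-the-tail path.
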
diff --git a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
new file mode 100644
index 000000000..150711624
--- /dev/null
+++ b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.network;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Supplier;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import org.junit.Assert;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
+import org.apache.celeborn.common.network.protocol.Message;
+import org.apache.celeborn.common.network.protocol.ReadData;
+
+public class TransportFrameDecoderWithBufferSupplierSuiteJ {
+
+  @Test
+  public void testDropUnusedBytes() throws IOException {
+    ConcurrentHashMap<Long, Supplier<org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf>>
+        supplier = new ConcurrentHashMap<>();
+    List<org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf> buffers = new ArrayList<>();
+
+    supplier.put(
+        2L,
+        () -> {
+          org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf buffer =
+              org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled.buffer(32000);
+          buffers.add(buffer);
+          return buffer;
+        });
+
+    TransportFrameDecoderWithBufferSupplier decoder =
+        new TransportFrameDecoderWithBufferSupplier(supplier);
+    ChannelHandlerContext context = Mockito.mock(ChannelHandlerContext.class);
+
+    BacklogAnnouncement announcement = new BacklogAnnouncement(0, 0);
+    ReadData unUsedReadData = new ReadData(1, 8, 0, generateData(1024));
+    ReadData readData = new ReadData(2, 8, 0, generateData(1024));
+    BacklogAnnouncement announcement1 = new BacklogAnnouncement(0, 0);
+    ReadData unUsedReadData1 = new ReadData(1, 8, 0, generateData(1024));
+    ReadData readData1 = new ReadData(2, 8, 0, generateData(8));
+
+    ByteBuf buffer = Unpooled.buffer(5000);
+    encodeMessage(announcement, buffer);
+    encodeMessage(unUsedReadData, buffer);
+    encodeMessage(readData, buffer);
+    encodeMessage(announcement1, buffer);
+    encodeMessage(unUsedReadData1, buffer);
+    encodeMessage(readData1, buffer);
+
+    // simulate
+    buffer.retain();
+    decoder.channelRead(context, buffer);
+    Assert.assertEquals(buffers.get(0).nioBuffer(), readData.body().nioByteBuffer());
+    Assert.assertEquals(buffers.get(1).nioBuffer(), readData1.body().nioByteBuffer());
+
+    // simulate 1 - split the unUsedReadData buffer
+    buffer.retain();
+    buffer.resetReaderIndex();
+    decoder.channelRead(context, buffer.retainedSlice(0, 555));
+    ByteBuf byteBuf = buffer.retainedSlice(0, buffer.readableBytes());
+    byteBuf.readerIndex(555);
+    decoder.channelRead(context, byteBuf);
+
+    Assert.assertEquals(buffers.get(2).nioBuffer(), readData.body().nioByteBuffer());
+    Assert.assertEquals(buffers.get(3).nioBuffer(), readData1.body().nioByteBuffer());
+
+    // simulate 2 - split the readData buffer
+    buffer.retain();
+    buffer.resetReaderIndex();
+    decoder.channelRead(context, buffer.retainedSlice(0, 1500));
+    byteBuf = buffer.retainedSlice(0, buffer.readableBytes());
+    byteBuf.readerIndex(1500);
+    decoder.channelRead(context, byteBuf);
+
+    Assert.assertEquals(buffers.get(4).nioBuffer(), readData.body().nioByteBuffer());
+    Assert.assertEquals(buffers.get(5).nioBuffer(), readData1.body().nioByteBuffer());
+    Assert.assertEquals(buffers.size(), 6);
+  }
+
+  public ByteBuf encodeMessage(Message in, ByteBuf byteBuf) throws IOException {
+    byteBuf.writeInt(in.encodedLength());
+    in.type().encode(byteBuf);
+    if (in.body() != null) {
+      byteBuf.writeInt((int) in.body().size());
+      in.encode(byteBuf);
+      byteBuf.writeBytes(in.body().nioByteBuffer());
+    } else {
+      byteBuf.writeInt(0);
+      in.encode(byteBuf);
+    }
+
+    return byteBuf;
+  }
+
+  public ByteBuf generateData(int size) {
+    ByteBuf data = Unpooled.buffer(size);
+    for (int i = 0; i < size; i++) {
+      data.writeByte(new Random().nextInt(7));
+    }
+
+    return data;
+  }
+}
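
Taken together, the three replays assert the same invariant: however the
byte stream is chunked, frames for stream 1 (whose supplier is absent)
are dropped without error, while both frames for stream 2 still arrive
intact in freshly supplied buffers (six supplied buffers in total across
the three replays).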