You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mina.apache.org by lg...@apache.org on 2021/03/05 09:17:55 UTC

[mina-sshd] 02/03: [SSHD-1125] Added mechanism to throttle pending write requests in BufferedIoOutputStream

This is an automated email from the ASF dual-hosted git repository.

lgoldstein pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/mina-sshd.git

commit 18609370696cc52ac780864237b37b2f173c4090
Author: Lyor Goldstein <lg...@apache.org>
AuthorDate: Thu Feb 25 21:05:49 2021 +0200

    [SSHD-1125] Added mechanism to throttle pending write requests in BufferedIoOutputStream
---
 CHANGES.md                                         |   1 +
 .../common/channel/BufferedIoOutputStream.java     | 203 ++++++++++++++++++---
 .../SshChannelBufferedOutputException.java         |  41 +++++
 .../org/apache/sshd/core/CoreModuleProperties.java |  21 ++-
 .../sshd/server/forward/TcpipServerChannel.java    |   8 +-
 .../sshd/util/test/AsyncEchoShellFactory.java      |  13 +-
 .../org/apache/sshd/sftp/server/SftpSubsystem.java |   3 +-
 .../client/impl/SftpRemotePathChannelTest.java     |   2 +
 8 files changed, 260 insertions(+), 32 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 685e749..4bf9e41 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -38,4 +38,5 @@
 * [SSHD-1114](https://issues.apache.org/jira/browse/SSHD-1114) Added callbacks for client-side host-based authentication progress
 * [SSHD-1114](https://issues.apache.org/jira/browse/SSHD-1114) Added capability for interactive password authentication participation via UserInteraction
 * [SSHD-1114](https://issues.apache.org/jira/browse/SSHD-1114) Added capability for interactive key based authentication participation via UserInteraction
+* [SSHD-1125](https://issues.apache.org/jira/browse/SSHD-1125) Added mechanism to throttle pending write requests in BufferedIoOutputStream
 * [SSHD-1127](https://issues.apache.org/jira/browse/SSHD-1127) Added capability to register a custom receiver for SFTP STDERR channel raw or stream data
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/channel/BufferedIoOutputStream.java b/sshd-core/src/main/java/org/apache/sshd/common/channel/BufferedIoOutputStream.java
index 3ee3ece..e8b81d8 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/channel/BufferedIoOutputStream.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/channel/BufferedIoOutputStream.java
@@ -20,29 +20,55 @@ package org.apache.sshd.common.channel;
 
 import java.io.EOFException;
 import java.io.IOException;
+import java.time.Duration;
+import java.util.Objects;
 import java.util.Queue;
 import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 
 import org.apache.sshd.common.Closeable;
+import org.apache.sshd.common.PropertyResolver;
+import org.apache.sshd.common.channel.exception.SshChannelBufferedOutputException;
 import org.apache.sshd.common.future.SshFutureListener;
 import org.apache.sshd.common.io.IoOutputStream;
 import org.apache.sshd.common.io.IoWriteFuture;
+import org.apache.sshd.common.util.GenericUtils;
+import org.apache.sshd.common.util.ValidateUtils;
 import org.apache.sshd.common.util.buffer.Buffer;
 import org.apache.sshd.common.util.closeable.AbstractInnerCloseable;
+import org.apache.sshd.core.CoreModuleProperties;
 
 /**
  * An {@link IoOutputStream} capable of queuing write requests.
  */
 public class BufferedIoOutputStream extends AbstractInnerCloseable implements IoOutputStream {
+    protected final Object id; // stream identifier - used in log messages and toString()
+    protected final int channelId; // reported in the SshChannelBufferedOutputException(s) raised by this stream
+    protected final int maxPendingBytesCount; // pending (unwritten) bytes threshold above which writers wait
+    protected final Duration maxWaitForPendingWrites; // max. time a writer waits for room before failing
     protected final IoOutputStream out;
+    protected final AtomicInteger pendingBytesCount = new AtomicInteger(); // accepted but unwritten bytes - also the wait/notify monitor
+    protected final AtomicLong writtenBytesCount = new AtomicLong(); // total bytes written so far - used in error messages
     protected final Queue<IoWriteFutureImpl> writes = new ConcurrentLinkedQueue<>();
     protected final AtomicReference<IoWriteFutureImpl> currentWrite = new AtomicReference<>();
-    protected final Object id;
+    protected final AtomicReference<SshChannelBufferedOutputException> pendingException = new AtomicReference<>(); // first error - "sticky"
+    // Throttling limits below are read from the CoreModuleProperties BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_* properties
+    public BufferedIoOutputStream(Object id, int channelId, IoOutputStream out, PropertyResolver resolver) {
+        this(id, channelId, out, CoreModuleProperties.BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_SIZE.getRequired(resolver),
+             CoreModuleProperties.BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_WAIT.getRequired(resolver));
+    }
 
-    public BufferedIoOutputStream(Object id, IoOutputStream out) {
-        this.out = out;
-        this.id = id;
+    public BufferedIoOutputStream(
+                                  Object id, int channelId, IoOutputStream out, int maxPendingBytesCount,
+                                  Duration maxWaitForPendingWrites) {
+        this.id = Objects.requireNonNull(id, "No stream identifier provided");
+        this.channelId = channelId;
+        this.out = Objects.requireNonNull(out, "No delegate output stream provided");
+        this.maxPendingBytesCount = maxPendingBytesCount;
+        ValidateUtils.checkTrue(maxPendingBytesCount > 0, "Invalid max. pending bytes count: %d", maxPendingBytesCount);
+        this.maxWaitForPendingWrites = Objects.requireNonNull(maxWaitForPendingWrites, "No max. pending time value provided");
     }
 
     public Object getId() {
@@ -52,60 +78,187 @@ public class BufferedIoOutputStream extends AbstractInnerCloseable implements Io
     @Override
     public IoWriteFuture writeBuffer(Buffer buffer) throws IOException {
         if (isClosing()) {
-            throw new EOFException("Closed - state=" + state);
+            throw new EOFException("Closed/ing - state=" + state);
         }
 
+        waitForAvailableWriteSpace(buffer.available()); // may block - throws if timed out, interrupted or failed earlier
+        // Room reserved - enqueue the request and start writing it unless another write is in progress
         IoWriteFutureImpl future = new IoWriteFutureImpl(getId(), buffer);
         writes.add(future);
         startWriting();
         return future;
     }
 
+    protected void waitForAvailableWriteSpace(int requiredSize) throws IOException {
+        /*
+         * Blocks the caller until adding requiredSize bytes would not exceed maxPendingBytesCount,
+         * the max. wait time expires or a ("sticky") pending exception is signaled. On success the
+         * required size is added to the pending bytes count, otherwise an exception is thrown.
+         *
+         * NOTE: a single pending write may give this mechanism "the slip" while a pending exception
+         * is being signaled concurrently (see finishWrite). This is acceptable since the goal here
+         * is to avoid an OOM due to an unlimited accumulation of pending write requests because the
+         * peer is not consuming the sent data - not to enforce a strict limit. Since the pending
+         * exception is "sticky", the next write attempt fails anyway.
+         *
+         * We could have counted pending requests rather than bytes. However, we also want to avoid
+         * having a large amount of data pending consumption by the peer. This code strikes a balance
+         * by allowing a single request to exceed the limit when no other ones are pending, while
+         * preventing a bunch of pending requests - each below the limit - from cumulatively
+         * representing a lot of pending bytes.
+         */
+
+        long expireTime = System.currentTimeMillis() + maxWaitForPendingWrites.toMillis();
+        synchronized (pendingBytesCount) {
+            for (int count = pendingBytesCount.get();
+                 /*
+                  * The (count > 0) condition is put in place to allow a single pending
+                  * write to exceed the maxPendingBytesCount as long as there are no
+                  * other pending ones.
+                  */
+                 (count > 0)
+                         // Adding this request would exceed the limit (long math avoids int overflow)
+                         && (((long) count + requiredSize) > maxPendingBytesCount)
+                         // No pending exception signaled so far
+                         && (pendingException.get() == null);
+                 count = pendingBytesCount.get()) {
+                long remTime = expireTime - System.currentTimeMillis();
+                if (remTime <= 0L) {
+                    pendingException.compareAndSet(null,
+                            new SshChannelBufferedOutputException(
+                                    channelId,
+                                    "Max. pending write timeout expired after " + writtenBytesCount + " bytes"));
+                    throw pendingException.get(); // the first signaled exception - possibly set by another thread
+                }
+
+                try {
+                    pendingBytesCount.wait(remTime);
+                } catch (InterruptedException e) {
+                    Thread.currentThread().interrupt(); // preserve the interrupt status for the caller
+                    pendingException.compareAndSet(null, new SshChannelBufferedOutputException(
+                            channelId,
+                            "Waiting for pending writes interrupted after " + writtenBytesCount + " bytes", e));
+                    throw pendingException.get();
+                }
+            }
+
+            IOException e = pendingException.get();
+            if (e != null) {
+                throw e;
+            }
+
+            pendingBytesCount.addAndGet(requiredSize); // reserve room - released by finishWrite once written
+        }
+    }
+
     protected void startWriting() throws IOException {
         IoWriteFutureImpl future = writes.peek();
+        // No more pending requests
         if (future == null) {
             return;
         }
 
+        // Don't try to write any further if pending exception signaled - fail the queued requests instead
+        Throwable pendingError = pendingException.get();
+        if (pendingError != null) {
+            log.error("startWriting({})[{}] propagate to {} write requests pending error={}[{}]",
+                    getId(), out, writes.size(), pendingError.getClass().getSimpleName(), pendingError.getMessage());
+
+            IoWriteFutureImpl currentFuture = currentWrite.getAndSet(null);
+            // Drain via poll() so that a request enqueued concurrently is not discarded without being completed
+            for (IoWriteFutureImpl pendingWrite = writes.poll(); pendingWrite != null; pendingWrite = writes.poll()) {
+                // Checking reference by design
+                if (GenericUtils.isSameReference(pendingWrite, currentFuture)) {
+                    continue;   // will be taken care of when its listener is eventually called
+                }
+
+                pendingWrite.setValue(pendingError); // complete each queued request - not just the head one
+            }
+
+            return;
+        }
+
+        // Cannot honor this request yet since other pending one incomplete
         if (!currentWrite.compareAndSet(null, future)) {
             return;
         }
 
-        out.writeBuffer(future.getBuffer()).addListener(
-                new SshFutureListener<IoWriteFuture>() {
-                    @Override
-                    public void operationComplete(IoWriteFuture f) {
-                        if (f.isWritten()) {
-                            future.setValue(Boolean.TRUE);
-                        } else {
-                            future.setValue(f.getException());
-                        }
-                        finishWrite(future);
-                    }
-                });
+        Buffer buffer = future.getBuffer();
+        int bufferSize = buffer.available(); // captured up-front - presumably the delegate consumes the buffer
+        out.writeBuffer(buffer).addListener(new SshFutureListener<IoWriteFuture>() {
+            @Override
+            public void operationComplete(IoWriteFuture f) {
+                if (f.isWritten()) {
+                    future.setValue(Boolean.TRUE);
+                } else {
+                    future.setValue(f.getException());
+                }
+                finishWrite(future, bufferSize);
+            }
+        });
     }
 
-    protected void finishWrite(IoWriteFutureImpl future) {
+    protected void finishWrite(IoWriteFutureImpl future, int bufferSize) {
+        /*
+         * Invoked once a delegated write completes (see startWriting): on success release the written
+         * bytes from the pending count and wake up waiting writers, otherwise record a "sticky" error.
+         */
+        if (future.isWritten()) {
+            long writtenSize = writtenBytesCount.addAndGet(bufferSize);
+
+            int stillPending;
+            synchronized (pendingBytesCount) {
+                stillPending = pendingBytesCount.addAndGet(0 - bufferSize);
+                pendingBytesCount.notifyAll();
+            }
+
+            /*
+             * NOTE: since the pending exception is updated outside the synchronized block
+             * a pending write could be successfully enqueued, however this is acceptable
+             * - see comment in waitForAvailableWriteSpace
+             */
+            if (stillPending < 0) {
+                log.error("finishWrite({})[{}] - pending byte counts underflow ({}) after {} bytes", getId(), out, stillPending,
+                        writtenSize);
+                pendingException.compareAndSet(null,
+                        new SshChannelBufferedOutputException(channelId, "Pending byte counts underflow"));
+            }
+        } else {
+            Throwable t = future.getException();
+            if (t instanceof SshChannelBufferedOutputException) {
+                pendingException.compareAndSet(null, (SshChannelBufferedOutputException) t);
+            } else {
+                pendingException.compareAndSet(null, new SshChannelBufferedOutputException(channelId, t));
+            }
+
+            // Wake up any waiting writers so that they can detect the exception
+            synchronized (pendingBytesCount) {
+                pendingBytesCount.notifyAll();
+            }
+        }
+
         writes.remove(future);
         currentWrite.compareAndSet(future, null);
         try {
             startWriting();
         } catch (IOException e) {
-            error("finishWrite({}) failed ({}) re-start writing: {}",
-                    out, e.getClass().getSimpleName(), e.getMessage(), e);
+            if (e instanceof SshChannelBufferedOutputException) { // make the failure "sticky" for subsequent writes
+                pendingException.compareAndSet(null, (SshChannelBufferedOutputException) e);
+            } else {
+                pendingException.compareAndSet(null, new SshChannelBufferedOutputException(channelId, e));
+            }
+            error("finishWrite({})[{}] failed ({}) re-start writing: {}",
+                    getId(), out, e.getClass().getSimpleName(), e.getMessage(), e);
         }
     }
 
     @Override
     protected Closeable getInnerCloseable() {
-        return builder()
-                .when(getId(), writes)
-                .close(out)
-                .build();
+        return builder().when(getId(), writes).close(out).build(); // NOTE(review): presumably waits for queued writes before closing 'out' - confirm
     }
 
     @Override
     public String toString() {
-        return getClass().getSimpleName() + "[" + out + "]";
+        return getClass().getSimpleName() + "(" + getId() + ")[" + out + "]"; // <class>(<id>)[<delegate>]
     }
 }
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/channel/exception/SshChannelBufferedOutputException.java b/sshd-core/src/main/java/org/apache/sshd/common/channel/exception/SshChannelBufferedOutputException.java
new file mode 100644
index 0000000..97e6105
--- /dev/null
+++ b/sshd-core/src/main/java/org/apache/sshd/common/channel/exception/SshChannelBufferedOutputException.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sshd.common.channel.exception;
+
+/**
+ * Used by the {@code BufferedIoOutputStream} to signal a non-recoverable (e.g., failed or timed-out) write error
+ *
+ * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
+ */
+public class SshChannelBufferedOutputException extends SshChannelException {
+    private static final long serialVersionUID = -8663890657820958046L;
+
+    public SshChannelBufferedOutputException(int channelId, String message) {
+        this(channelId, message, null);
+    }
+
+    public SshChannelBufferedOutputException(int channelId, Throwable cause) {
+        this(channelId, (cause == null) ? null : cause.getMessage(), cause); // tolerate a null cause instead of an NPE
+    }
+
+    public SshChannelBufferedOutputException(int channelId, String message, Throwable cause) {
+        super(channelId, message, cause);
+    }
+}
diff --git a/sshd-core/src/main/java/org/apache/sshd/core/CoreModuleProperties.java b/sshd-core/src/main/java/org/apache/sshd/core/CoreModuleProperties.java
index d728c3e..88d9724 100644
--- a/sshd-core/src/main/java/org/apache/sshd/core/CoreModuleProperties.java
+++ b/sshd-core/src/main/java/org/apache/sshd/core/CoreModuleProperties.java
@@ -24,6 +24,7 @@ import java.time.Duration;
 
 import org.apache.sshd.client.config.keys.ClientIdentityLoader;
 import org.apache.sshd.common.Property;
+import org.apache.sshd.common.SshConstants;
 import org.apache.sshd.common.channel.Channel;
 import org.apache.sshd.common.session.Session;
 import org.apache.sshd.common.util.OsUtils;
@@ -244,6 +245,24 @@ public final class CoreModuleProperties {
             = Property.duration("window-timeout", Duration.ZERO);
 
     /**
+     * Key used when creating a {@code BufferedIoOutputStream} in order to specify max. allowed unwritten pending bytes.
+     * If enqueuing a new write request would exceed this value (while other writes are still pending) then the code
+     * waits up to {@link #BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_WAIT} for the pending data to be written and make room.
+     */
+    public static final Property<Integer> BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_SIZE
+            = Property.integer("buffered-io-output-max-pending-write-size",
+                    SshConstants.SSH_REQUIRED_PAYLOAD_PACKET_LENGTH_SUPPORT * 8);
+
+    /**
+     * Key used when creating a {@code BufferedIoOutputStream} in order to specify max. wait time for pending writes
+     * to be completed before enqueuing a new request - if it expires, the write attempt fails with an exception
+     *
+     * @see #BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_SIZE
+     */
+    public static final Property<Duration> BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_WAIT
+            = Property.duration("buffered-io-output-max-pending-write-wait", Duration.ofSeconds(30L));
+
+    /**
      * Key used to retrieve the value of the maximum packet size in the configuration properties map.
      */
     public static final Property<Long> MAX_PACKET_SIZE
@@ -689,7 +708,7 @@ public final class CoreModuleProperties {
 
     /**
      * The lower threshold. If not set, half the higher threshold will be used.
-     * 
+     *
      * @see #TCPIP_SERVER_CHANNEL_BUFFER_SIZE_THRESHOLD_HIGH
      */
     public static final Property<Long> TCPIP_SERVER_CHANNEL_BUFFER_SIZE_THRESHOLD_LOW
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/forward/TcpipServerChannel.java b/sshd-core/src/main/java/org/apache/sshd/server/forward/TcpipServerChannel.java
index 874b49e..0581ed4 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/forward/TcpipServerChannel.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/forward/TcpipServerChannel.java
@@ -215,10 +215,12 @@ public class TcpipServerChannel extends AbstractServerChannel implements Streami
         }
 
         if (streaming == Streaming.Async) {
+            int channelId = getId(); // used in the stream identifier and for error reporting
             out = new BufferedIoOutputStream(
-                    "tcpip channel", new ChannelAsyncOutputStream(this, SshConstants.SSH_MSG_CHANNEL_DATA) {
-                        @SuppressWarnings("synthetic-access")
+                    "async-tcpip-channel@" + channelId, channelId,
+                    new ChannelAsyncOutputStream(this, SshConstants.SSH_MSG_CHANNEL_DATA) {
                         @Override
+                        @SuppressWarnings("synthetic-access")
                         protected CloseFuture doCloseGracefully() {
                             try {
                                 sendEof();
@@ -227,7 +229,7 @@ public class TcpipServerChannel extends AbstractServerChannel implements Streami
                             }
                             return super.doCloseGracefully();
                         }
-                    });
+                    }, this);
         } else {
             this.out = new SimpleIoOutputStream(
                     new ChannelOutputStream(
diff --git a/sshd-core/src/test/java/org/apache/sshd/util/test/AsyncEchoShellFactory.java b/sshd-core/src/test/java/org/apache/sshd/util/test/AsyncEchoShellFactory.java
index b550893..7218ffd 100644
--- a/sshd-core/src/test/java/org/apache/sshd/util/test/AsyncEchoShellFactory.java
+++ b/sshd-core/src/test/java/org/apache/sshd/util/test/AsyncEchoShellFactory.java
@@ -99,12 +99,21 @@ public class AsyncEchoShellFactory implements ShellFactory {
 
         @Override
         public void setIoOutputStream(IoOutputStream out) {
-            this.out = new BufferedIoOutputStream("STDOUT", out);
+            this.out = wrapOutputStream("SHELL-STDOUT", out);
         }
 
         @Override
         public void setIoErrorStream(IoOutputStream err) {
-            this.err = new BufferedIoOutputStream("STDERR", err);
+            this.err = wrapOutputStream("SHELL-STDERR", err);
+        }
+        // Wraps the given stream in a BufferedIoOutputStream unless it already is one
+        protected BufferedIoOutputStream wrapOutputStream(String prefix, IoOutputStream stream) {
+            if (stream instanceof BufferedIoOutputStream) { // avoid double buffering
+                return (BufferedIoOutputStream) stream;
+            }
+
+            int channelId = session.getId(); // NOTE(review): assumes 'session' is already set when streams are assigned - confirm
+            return new BufferedIoOutputStream(prefix + "@" + channelId, channelId, stream, session);
         }
 
         @Override
diff --git a/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/SftpSubsystem.java b/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/SftpSubsystem.java
index a5bb4ae..2c79d23 100644
--- a/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/SftpSubsystem.java
+++ b/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/SftpSubsystem.java
@@ -242,7 +242,8 @@ public class SftpSubsystem
 
     @Override
     public void setIoOutputStream(IoOutputStream out) {
-        this.out = new BufferedIoOutputStream("sftp out buffer", out);
+        int channelId = channelSession.getId(); // used in the stream identifier and for error reporting
+        this.out = new BufferedIoOutputStream("sftp-out@" + channelId, channelId, out, channelSession); // limits read from channel properties
     }
 
     @Override
diff --git a/sshd-sftp/src/test/java/org/apache/sshd/sftp/client/impl/SftpRemotePathChannelTest.java b/sshd-sftp/src/test/java/org/apache/sshd/sftp/client/impl/SftpRemotePathChannelTest.java
index 5d80f62..a69e3db 100644
--- a/sshd-sftp/src/test/java/org/apache/sshd/sftp/client/impl/SftpRemotePathChannelTest.java
+++ b/sshd-sftp/src/test/java/org/apache/sshd/sftp/client/impl/SftpRemotePathChannelTest.java
@@ -48,6 +48,7 @@ import org.apache.sshd.sftp.common.SftpConstants;
 import org.apache.sshd.util.test.CommonTestSupportUtils;
 import org.junit.Before;
 import org.junit.FixMethodOrder;
+import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runners.MethodSorters;
 import org.slf4j.Logger;
@@ -217,6 +218,7 @@ public class SftpRemotePathChannelTest extends AbstractSftpClientTestSupport {
      * limit the available heap memory of the junit execution by passing "-Xmx256m" to the VM.
      */
     @Test(timeout = 5L * 60L * 1000L)   // see SSHD-1125
+    @Ignore("Used only for debugging SSHD-1125")
     public void testReadRequestsOutOfMemory() throws Exception {
         Path targetPath = detectTargetFolder();
         Path parentPath = targetPath.getParent();