You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mina.apache.org by lg...@apache.org on 2021/03/05 09:17:55 UTC
[mina-sshd] 02/03: [SSHD-1125] Added mechanism to throttle pending
write requests in BufferedIoOutputStream
This is an automated email from the ASF dual-hosted git repository.
lgoldstein pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/mina-sshd.git
commit 18609370696cc52ac780864237b37b2f173c4090
Author: Lyor Goldstein <lg...@apache.org>
AuthorDate: Thu Feb 25 21:05:49 2021 +0200
[SSHD-1125] Added mechanism to throttle pending write requests in BufferedIoOutputStream
---
CHANGES.md | 1 +
.../common/channel/BufferedIoOutputStream.java | 203 ++++++++++++++++++---
.../SshChannelBufferedOutputException.java | 41 +++++
.../org/apache/sshd/core/CoreModuleProperties.java | 21 ++-
.../sshd/server/forward/TcpipServerChannel.java | 8 +-
.../sshd/util/test/AsyncEchoShellFactory.java | 13 +-
.../org/apache/sshd/sftp/server/SftpSubsystem.java | 3 +-
.../client/impl/SftpRemotePathChannelTest.java | 2 +
8 files changed, 260 insertions(+), 32 deletions(-)
diff --git a/CHANGES.md b/CHANGES.md
index 685e749..4bf9e41 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -38,4 +38,5 @@
* [SSHD-1114](https://issues.apache.org/jira/browse/SSHD-1114) Added callbacks for client-side host-based authentication progress
* [SSHD-1114](https://issues.apache.org/jira/browse/SSHD-1114) Added capability for interactive password authentication participation via UserInteraction
* [SSHD-1114](https://issues.apache.org/jira/browse/SSHD-1114) Added capability for interactive key based authentication participation via UserInteraction
+* [SSHD-1125](https://issues.apache.org/jira/browse/SSHD-1125) Added mechanism to throttle pending write requests in BufferedIoOutputStream
* [SSHD-1127](https://issues.apache.org/jira/browse/SSHD-1127) Added capability to register a custom receiver for SFTP STDERR channel raw or stream data
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/channel/BufferedIoOutputStream.java b/sshd-core/src/main/java/org/apache/sshd/common/channel/BufferedIoOutputStream.java
index 3ee3ece..e8b81d8 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/channel/BufferedIoOutputStream.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/channel/BufferedIoOutputStream.java
@@ -20,29 +20,55 @@ package org.apache.sshd.common.channel;
import java.io.EOFException;
import java.io.IOException;
+import java.time.Duration;
+import java.util.Objects;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.sshd.common.Closeable;
+import org.apache.sshd.common.PropertyResolver;
+import org.apache.sshd.common.channel.exception.SshChannelBufferedOutputException;
import org.apache.sshd.common.future.SshFutureListener;
import org.apache.sshd.common.io.IoOutputStream;
import org.apache.sshd.common.io.IoWriteFuture;
+import org.apache.sshd.common.util.GenericUtils;
+import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.Buffer;
import org.apache.sshd.common.util.closeable.AbstractInnerCloseable;
+import org.apache.sshd.core.CoreModuleProperties;
/**
* An {@link IoOutputStream} capable of queuing write requests.
*/
public class BufferedIoOutputStream extends AbstractInnerCloseable implements IoOutputStream {
+ protected final Object id;
+ protected final int channelId;
+ protected final int maxPendingBytesCount;
+ protected final Duration maxWaitForPendingWrites;
protected final IoOutputStream out;
+ protected final AtomicInteger pendingBytesCount = new AtomicInteger();
+ protected final AtomicLong writtenBytesCount = new AtomicLong();
protected final Queue<IoWriteFutureImpl> writes = new ConcurrentLinkedQueue<>();
protected final AtomicReference<IoWriteFutureImpl> currentWrite = new AtomicReference<>();
- protected final Object id;
+ protected final AtomicReference<SshChannelBufferedOutputException> pendingException = new AtomicReference<>();
+
+ public BufferedIoOutputStream(Object id, int channelId, IoOutputStream out, PropertyResolver resolver) {
+ this(id, channelId, out, CoreModuleProperties.BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_SIZE.getRequired(resolver),
+ CoreModuleProperties.BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_WAIT.getRequired(resolver));
+ }
- public BufferedIoOutputStream(Object id, IoOutputStream out) {
- this.out = out;
- this.id = id;
+ public BufferedIoOutputStream(
+ Object id, int channelId, IoOutputStream out, int maxPendingBytesCount,
+ Duration maxWaitForPendingWrites) {
+ this.id = Objects.requireNonNull(id, "No stream identifier provided");
+ this.channelId = channelId;
+ this.out = Objects.requireNonNull(out, "No delegate output stream provided");
+ this.maxPendingBytesCount = maxPendingBytesCount;
+ ValidateUtils.checkTrue(maxPendingBytesCount > 0, "Invalid max. pending bytes count: %d", maxPendingBytesCount);
+ this.maxWaitForPendingWrites = Objects.requireNonNull(maxWaitForPendingWrites, "No max. pending time value provided");
}
public Object getId() {
@@ -52,60 +78,187 @@ public class BufferedIoOutputStream extends AbstractInnerCloseable implements Io
@Override
public IoWriteFuture writeBuffer(Buffer buffer) throws IOException {
if (isClosing()) {
- throw new EOFException("Closed - state=" + state);
+ throw new EOFException("Closed/ing - state=" + state);
}
+ waitForAvailableWriteSpace(buffer.available());
+
IoWriteFutureImpl future = new IoWriteFutureImpl(getId(), buffer);
writes.add(future);
startWriting();
return future;
}
+ protected void waitForAvailableWriteSpace(int requiredSize) throws IOException {
+ /*
+ * NOTE: this code allows a single pending write to give this mechanism "the slip" and
+ * exit the loop "unscathed" even though there is a pending exception. However, the goal
+ * here is to avoid an OOM by having an unlimited accumulation of pending write requests
+ * due to the fact that the peer is not consuming the sent data. Please note that the pending
+ * exception is "sticky" - i.e., the next write attempt will fail. This also means that if
+ * the write request that "got away" was the last one by chance and it was consumed by the
+ * peer there will be no exception thrown - which is also fine since as mentioned the goal
+ * is not to enforce a strict limit on the pending bytes size but rather on the accumulation
+ * of the pending write requests.
+ *
+ * We could have counted pending requests rather than bytes. However, we also want to avoid
+ * having a large amount of data pending consumption by the peer as well. This code strikes
+ * such a balance by allowing a single pending request to exceed the limit, but at the same
+ * time prevents too many bytes from pending by having a bunch of pending requests that while
+ * below the imposed number limit may cumulatively represent a lot of pending bytes.
+ */
+
+ long expireTime = System.currentTimeMillis() + maxWaitForPendingWrites.toMillis();
+ synchronized (pendingBytesCount) {
+ for (int count = pendingBytesCount.get();
+ /*
+ * The (count > 0) condition is put in place to allow a single pending
+ * write to exceed the maxPendingBytesCount as long as there are no
+ * other pending ones.
+ */
+ (count > 0)
+ // Keep waiting while adding this request would push the pending bytes over the limit
+ && ((count + requiredSize) > maxPendingBytesCount)
+ // No pending exception signaled
+ && (pendingException.get() == null);
+ count = pendingBytesCount.get()) {
+ long remTime = expireTime - System.currentTimeMillis();
+ if (remTime <= 0L) {
+ pendingException.compareAndSet(null,
+ new SshChannelBufferedOutputException(
+ channelId,
+ "Max. pending write timeout expired after " + writtenBytesCount + " bytes"));
+ throw pendingException.get();
+ }
+
+ try {
+ pendingBytesCount.wait(remTime);
+ } catch (InterruptedException e) {
+ pendingException.compareAndSet(null,
+ new SshChannelBufferedOutputException(
+ channelId,
+ "Waiting for pending writes interrupted after " + writtenBytesCount + " bytes"));
+ throw pendingException.get();
+ }
+ }
+
+ IOException e = pendingException.get();
+ if (e != null) {
+ throw e;
+ }
+
+ pendingBytesCount.addAndGet(requiredSize);
+ }
+ }
+
protected void startWriting() throws IOException {
IoWriteFutureImpl future = writes.peek();
+ // No more pending requests
if (future == null) {
return;
}
+ // Don't try to write any further if pending exception signaled
+ Throwable pendingError = pendingException.get();
+ if (pendingError != null) {
+ log.error("startWriting({})[{}] propagate to {} write requests pending error={}[{}]",
+ getId(), out, writes.size(), getClass().getSimpleName(), pendingError.getMessage());
+
+ IoWriteFutureImpl currentFuture = currentWrite.getAndSet(null);
+ for (IoWriteFutureImpl pendingWrite : writes) {
+ // Checking reference by design
+ if (GenericUtils.isSameReference(pendingWrite, currentFuture)) {
+ continue; // will be taken care of when its listener is eventually called
+ }
+
+ future.setValue(pendingError);
+ }
+
+ writes.clear();
+ return;
+ }
+
+ // Cannot honor this request yet since other pending one incomplete
if (!currentWrite.compareAndSet(null, future)) {
return;
}
- out.writeBuffer(future.getBuffer()).addListener(
- new SshFutureListener<IoWriteFuture>() {
- @Override
- public void operationComplete(IoWriteFuture f) {
- if (f.isWritten()) {
- future.setValue(Boolean.TRUE);
- } else {
- future.setValue(f.getException());
- }
- finishWrite(future);
- }
- });
+ Buffer buffer = future.getBuffer();
+ int bufferSize = buffer.available();
+ out.writeBuffer(buffer).addListener(new SshFutureListener<IoWriteFuture>() {
+ @Override
+ public void operationComplete(IoWriteFuture f) {
+ if (f.isWritten()) {
+ future.setValue(Boolean.TRUE);
+ } else {
+ future.setValue(f.getException());
+ }
+ finishWrite(future, bufferSize);
+ }
+ });
}
- protected void finishWrite(IoWriteFutureImpl future) {
+ protected void finishWrite(IoWriteFutureImpl future, int bufferSize) {
+ /*
+ * Update the pending bytes count only if successfully written,
+ * otherwise signal an error
+ */
+ if (future.isWritten()) {
+ long writtenSize = writtenBytesCount.addAndGet(bufferSize);
+
+ int stillPending;
+ synchronized (pendingBytesCount) {
+ stillPending = pendingBytesCount.addAndGet(0 - bufferSize);
+ pendingBytesCount.notifyAll();
+ }
+
+ /*
+ * NOTE: since the pending exception is updated outside the synchronized block
+ * a pending write could be successfully enqueued, however this is acceptable
+ * - see comment in waitForAvailableWriteSpace
+ */
+ if (stillPending < 0) {
+ log.error("finishWrite({})[{}] - pending byte counts underflow ({}) after {} bytes", getId(), out, stillPending,
+ writtenSize);
+ pendingException.compareAndSet(null,
+ new SshChannelBufferedOutputException(channelId, "Pending byte counts underflow"));
+ }
+ } else {
+ Throwable t = future.getException();
+ if (t instanceof SshChannelBufferedOutputException) {
+ pendingException.compareAndSet(null, (SshChannelBufferedOutputException) t);
+ } else {
+ pendingException.compareAndSet(null, new SshChannelBufferedOutputException(channelId, t));
+ }
+
+ // In case someone waiting so that they can detect the exception
+ synchronized (pendingBytesCount) {
+ pendingBytesCount.notifyAll();
+ }
+ }
+
writes.remove(future);
currentWrite.compareAndSet(future, null);
try {
startWriting();
} catch (IOException e) {
- error("finishWrite({}) failed ({}) re-start writing: {}",
- out, e.getClass().getSimpleName(), e.getMessage(), e);
+ if (e instanceof SshChannelBufferedOutputException) {
+ pendingException.compareAndSet(null, (SshChannelBufferedOutputException) e);
+ } else {
+ pendingException.compareAndSet(null, new SshChannelBufferedOutputException(channelId, e));
+ }
+ error("finishWrite({})[{}] failed ({}) re-start writing: {}",
+ getId(), out, e.getClass().getSimpleName(), e.getMessage(), e);
}
}
@Override
protected Closeable getInnerCloseable() {
- return builder()
- .when(getId(), writes)
- .close(out)
- .build();
+ return builder().when(getId(), writes).close(out).build();
}
@Override
public String toString() {
- return getClass().getSimpleName() + "[" + out + "]";
+ return getClass().getSimpleName() + "(" + getId() + ")[" + out + "]";
}
}
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/channel/exception/SshChannelBufferedOutputException.java b/sshd-core/src/main/java/org/apache/sshd/common/channel/exception/SshChannelBufferedOutputException.java
new file mode 100644
index 0000000..97e6105
--- /dev/null
+++ b/sshd-core/src/main/java/org/apache/sshd/common/channel/exception/SshChannelBufferedOutputException.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sshd.common.channel.exception;
+
+/**
+ * Used by the {@code BufferedIoOutputStream} to signal a non-recoverable error
+ *
+ * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
+ */
+public class SshChannelBufferedOutputException extends SshChannelException {
+ private static final long serialVersionUID = -8663890657820958046L;
+
+ public SshChannelBufferedOutputException(int channelId, String message) {
+ this(channelId, message, null);
+ }
+
+ public SshChannelBufferedOutputException(int channelId, Throwable cause) {
+ this(channelId, cause.getMessage(), cause);
+ }
+
+ public SshChannelBufferedOutputException(int channelId, String message, Throwable cause) {
+ super(channelId, message, cause);
+ }
+}
diff --git a/sshd-core/src/main/java/org/apache/sshd/core/CoreModuleProperties.java b/sshd-core/src/main/java/org/apache/sshd/core/CoreModuleProperties.java
index d728c3e..88d9724 100644
--- a/sshd-core/src/main/java/org/apache/sshd/core/CoreModuleProperties.java
+++ b/sshd-core/src/main/java/org/apache/sshd/core/CoreModuleProperties.java
@@ -24,6 +24,7 @@ import java.time.Duration;
import org.apache.sshd.client.config.keys.ClientIdentityLoader;
import org.apache.sshd.common.Property;
+import org.apache.sshd.common.SshConstants;
import org.apache.sshd.common.channel.Channel;
import org.apache.sshd.common.session.Session;
import org.apache.sshd.common.util.OsUtils;
@@ -244,6 +245,24 @@ public final class CoreModuleProperties {
= Property.duration("window-timeout", Duration.ZERO);
/**
+ * Key used when creating a {@code BufferedIoOutputStream} in order to specify max. allowed unwritten pending bytes.
+ * If this value is exceeded then the code waits up to {@link #BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_WAIT} for the
+ * pending data to be written and thus make room for the new request.
+ */
+ public static final Property<Integer> BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_SIZE
+ = Property.integer("buffered-io-output-max-pending-write-size",
+ SshConstants.SSH_REQUIRED_PAYLOAD_PACKET_LENGTH_SUPPORT * 8);
+
+ /**
+ * Key used when creating a {@code BufferedIoOutputStream} in order to specify max. wait time (msec.) for pending
+ * writes to be completed before enqueuing a new request
+ *
+ * @see #BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_SIZE
+ */
+ public static final Property<Duration> BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_WAIT
+ = Property.duration("buffered-io-output-max-pending-write-wait", Duration.ofSeconds(30L));
+
+ /**
* Key used to retrieve the value of the maximum packet size in the configuration properties map.
*/
public static final Property<Long> MAX_PACKET_SIZE
@@ -689,7 +708,7 @@ public final class CoreModuleProperties {
/**
* The lower threshold. If not set, half the higher threshold will be used.
- *
+ *
* @see #TCPIP_SERVER_CHANNEL_BUFFER_SIZE_THRESHOLD_HIGH
*/
public static final Property<Long> TCPIP_SERVER_CHANNEL_BUFFER_SIZE_THRESHOLD_LOW
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/forward/TcpipServerChannel.java b/sshd-core/src/main/java/org/apache/sshd/server/forward/TcpipServerChannel.java
index 874b49e..0581ed4 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/forward/TcpipServerChannel.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/forward/TcpipServerChannel.java
@@ -215,10 +215,12 @@ public class TcpipServerChannel extends AbstractServerChannel implements Streami
}
if (streaming == Streaming.Async) {
+ int channelId = getId();
out = new BufferedIoOutputStream(
- "tcpip channel", new ChannelAsyncOutputStream(this, SshConstants.SSH_MSG_CHANNEL_DATA) {
- @SuppressWarnings("synthetic-access")
+ "aysnc-tcpip-channel@" + channelId, channelId,
+ new ChannelAsyncOutputStream(this, SshConstants.SSH_MSG_CHANNEL_DATA) {
@Override
+ @SuppressWarnings("synthetic-access")
protected CloseFuture doCloseGracefully() {
try {
sendEof();
@@ -227,7 +229,7 @@ public class TcpipServerChannel extends AbstractServerChannel implements Streami
}
return super.doCloseGracefully();
}
- });
+ }, this);
} else {
this.out = new SimpleIoOutputStream(
new ChannelOutputStream(
diff --git a/sshd-core/src/test/java/org/apache/sshd/util/test/AsyncEchoShellFactory.java b/sshd-core/src/test/java/org/apache/sshd/util/test/AsyncEchoShellFactory.java
index b550893..7218ffd 100644
--- a/sshd-core/src/test/java/org/apache/sshd/util/test/AsyncEchoShellFactory.java
+++ b/sshd-core/src/test/java/org/apache/sshd/util/test/AsyncEchoShellFactory.java
@@ -99,12 +99,21 @@ public class AsyncEchoShellFactory implements ShellFactory {
@Override
public void setIoOutputStream(IoOutputStream out) {
- this.out = new BufferedIoOutputStream("STDOUT", out);
+ this.out = wrapOutputStream("SHELL-STDOUT", out);
}
@Override
public void setIoErrorStream(IoOutputStream err) {
- this.err = new BufferedIoOutputStream("STDERR", err);
+ this.err = wrapOutputStream("SHELL-STDERR", err);
+ }
+
+ protected BufferedIoOutputStream wrapOutputStream(String prefix, IoOutputStream stream) {
+ if (stream instanceof BufferedIoOutputStream) {
+ return (BufferedIoOutputStream) stream;
+ }
+
+ int channelId = session.getId();
+ return new BufferedIoOutputStream(prefix + "@" + channelId, channelId, stream, session);
}
@Override
diff --git a/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/SftpSubsystem.java b/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/SftpSubsystem.java
index a5bb4ae..2c79d23 100644
--- a/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/SftpSubsystem.java
+++ b/sshd-sftp/src/main/java/org/apache/sshd/sftp/server/SftpSubsystem.java
@@ -242,7 +242,8 @@ public class SftpSubsystem
@Override
public void setIoOutputStream(IoOutputStream out) {
- this.out = new BufferedIoOutputStream("sftp out buffer", out);
+ int channelId = channelSession.getId();
+ this.out = new BufferedIoOutputStream("sftp-out@" + channelId, channelId, out, channelSession);
}
@Override
diff --git a/sshd-sftp/src/test/java/org/apache/sshd/sftp/client/impl/SftpRemotePathChannelTest.java b/sshd-sftp/src/test/java/org/apache/sshd/sftp/client/impl/SftpRemotePathChannelTest.java
index 5d80f62..a69e3db 100644
--- a/sshd-sftp/src/test/java/org/apache/sshd/sftp/client/impl/SftpRemotePathChannelTest.java
+++ b/sshd-sftp/src/test/java/org/apache/sshd/sftp/client/impl/SftpRemotePathChannelTest.java
@@ -48,6 +48,7 @@ import org.apache.sshd.sftp.common.SftpConstants;
import org.apache.sshd.util.test.CommonTestSupportUtils;
import org.junit.Before;
import org.junit.FixMethodOrder;
+import org.junit.Ignore;
import org.junit.Test;
import org.junit.runners.MethodSorters;
import org.slf4j.Logger;
@@ -217,6 +218,7 @@ public class SftpRemotePathChannelTest extends AbstractSftpClientTestSupport {
* limit the available heap memory of the junit execution by passing "-Xmx256m" to the VM.
*/
@Test(timeout = 5L * 60L * 1000L) // see SSHD-1125
+ @Ignore("Used only for debugging SSHD-1125")
public void testReadRequestsOutOfMemory() throws Exception {
Path targetPath = detectTargetFolder();
Path parentPath = targetPath.getParent();