You are viewing a plain text version of this content. The link to the canonical version is not preserved in this plain-text rendering.
Posted to commits@mina.apache.org by lg...@apache.org on 2020/05/15 08:11:38 UTC
[mina-sshd] 01/01: [SSHD-966] Using same lock to synchronize session pending packets and ChannelOutputStream mutual exclusion
This is an automated email from the ASF dual-hosted git repository.
lgoldstein pushed a commit to branch SSHD-966
in repository https://gitbox.apache.org/repos/asf/mina-sshd.git
commit 25e7e13c95ada68f3def91dd153b44fd0127971b
Author: Lyor Goldstein <lg...@apache.org>
AuthorDate: Fri May 15 10:24:50 2020 +0300
[SSHD-966] Using same lock to synchronize session pending packets and ChannelOutputStream mutual exclusion
---
.../sshd/common/channel/ChannelOutputStream.java | 253 ++++++++++++++-------
.../common/session/helpers/AbstractSession.java | 98 ++++++--
2 files changed, 245 insertions(+), 106 deletions(-)
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelOutputStream.java b/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelOutputStream.java
index d0d879b..14f7e91 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelOutputStream.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelOutputStream.java
@@ -26,11 +26,14 @@ import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
+import org.apache.sshd.common.FactoryManager;
+import org.apache.sshd.common.RuntimeSshException;
import org.apache.sshd.common.SshConstants;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.channel.exception.SshChannelClosedException;
import org.apache.sshd.common.io.PacketWriter;
import org.apache.sshd.common.session.Session;
+import org.apache.sshd.common.session.helpers.AbstractSession;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.Buffer;
import org.slf4j.Logger;
@@ -103,24 +106,62 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
}
@Override
- public synchronized void write(int w) throws IOException {
- b[0] = (byte) w;
- write(b, 0, 1);
+ public void write(int w) throws IOException { // mutual exclusion via the session pending packets lock (SSHD-966)
+ try {
+ Channel channel = getChannel();
+ Session session = channel.getSession();
+ ((AbstractSession) session).executeUnderPendingPacketsLock(
+ getExtraPendingPacketLockWaitTime(1), () -> {
+ b[0] = (byte) w; // 'b' is a shared field - only touched while holding the lock
+ lockedWrite(session, channel, b, 0, 1);
+ return null;
+ });
+ } catch (IOException | RuntimeException e) {
+ log.error("write(" + this + ") value=0x" + Integer.toHexString(w) + " failed to write", e);
+ throw e;
+ } catch (InterruptedException e) { // same mapping as flush(): interruption is reported as an I/O error
+ throw (IOException) new InterruptedIOException(
+ "Interrupted while waiting to write value=0x" + Integer.toHexString(w) + " to " + this).initCause(e);
+ } catch (Exception e) { // e.g., timed out acquiring the pending packets lock
+ log.error("write(" + this + ") value=0x" + Integer.toHexString(w) + " failed to write", e);
+ throw new RuntimeSshException(e);
+ }
+ }
@Override
- public synchronized void write(byte[] buf, int s, int l) throws IOException {
- Channel channel = getChannel();
+ public void write(byte[] buf, int startOffset, int dataLen) throws IOException { // mutual exclusion via the session pending packets lock (SSHD-966)
+ try {
+ Channel channel = getChannel();
+ Session session = channel.getSession();
+ ((AbstractSession) session).executeUnderPendingPacketsLock(
+ getExtraPendingPacketLockWaitTime(dataLen), () -> {
+ lockedWrite(session, channel, buf, startOffset, dataLen);
+ return null;
+ });
+ } catch (IOException | RuntimeException e) {
+ log.error("write(" + this + ") len=" + dataLen + " failed to write", e);
+ throw e;
+ } catch (InterruptedException e) { // same mapping as flush(): interruption is reported as an I/O error
+ throw (IOException) new InterruptedIOException(
+ "Interrupted while waiting to write len=" + dataLen + " to " + this).initCause(e);
+ } catch (Exception e) { // e.g., timed out acquiring the pending packets lock
+ log.error("write(" + this + ") len=" + dataLen + " failed to write", e);
+ throw new RuntimeSshException(e);
+ }
+ }
+
+ protected void lockedWrite(
+ Session session, Channel channel, byte[] buf, int startOffset, int dataLen)
+ throws Exception {
if (!isOpen()) {
throw new SshChannelClosedException(
channel.getId(),
- "write(" + this + ") len=" + l + " - channel already closed");
+ "lockedWrite(" + this + ") len=" + dataLen + " - channel already closed");
}
- Session session = channel.getSession();
boolean debugEnabled = log.isDebugEnabled();
boolean traceEnabled = log.isTraceEnabled();
- while (l > 0) {
+ while (dataLen > 0) {
// The maximum amount we should admit without flushing again
// is enough to make up one full packet within our allowed
// window size. We give ourselves a credit equal to the last
@@ -128,31 +169,31 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
// out the next packet before we block and wait for space to
// become available again.
long minReqLen = Math.min(remoteWindow.getSize() + lastSize, remoteWindow.getPacketSize());
- long l2 = Math.min(l, minReqLen - bufferLength);
+ long l2 = Math.min(dataLen, minReqLen - bufferLength);
if (l2 <= 0) {
if (bufferLength > 0) {
- flush();
+ lockedFlush(session, channel);
} else {
session.resetIdleTimeout();
try {
long available = remoteWindow.waitForSpace(maxWaitTimeout);
if (traceEnabled) {
- log.trace("write({}) len={} - available={}", this, l, available);
+ log.trace("lockedWrite({}) len={} - available={}", this, dataLen, available);
}
} catch (IOException e) {
- log.error("write({}) failed ({}) to wait for space of len={}: {}",
- this, e.getClass().getSimpleName(), l, e.getMessage());
+ log.error("lockedWrite({}) failed ({}) to wait for space of len={}: {}",
+ this, e.getClass().getSimpleName(), dataLen, e.getMessage());
if ((e instanceof WindowClosedException) && (!closedState.getAndSet(true))) {
if (debugEnabled) {
- log.debug("write({})[len={}] closing due to window closed", this, l);
+ log.debug("lockedWrite({})[len={}] closing due to window closed", this, dataLen);
}
}
throw e;
} catch (InterruptedException e) {
throw (IOException) new InterruptedIOException(
- "Interrupted while waiting for remote space on write len=" + l + " to " + this)
+ "Interrupted while waiting for remote space on write len=" + dataLen + " to " + this)
.initCause(e);
}
}
@@ -162,81 +203,30 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
ValidateUtils.checkTrue(l2 <= Integer.MAX_VALUE,
"Accumulated bytes length exceeds int boundary: %d", l2);
- buffer.putRawBytes(buf, s, (int) l2);
+ buffer.putRawBytes(buf, startOffset, (int) l2);
bufferLength += l2;
- s += l2;
- l -= l2;
+ startOffset += l2;
+ dataLen -= l2;
}
if (isNoDelay()) {
- flush();
+ lockedFlush(session, channel);
} else {
session.resetIdleTimeout();
}
}
@Override
- public synchronized void flush() throws IOException {
- AbstractChannel channel = getChannel();
- if (!isOpen()) {
- throw new SshChannelClosedException(
- channel.getId(),
- "flush(" + this + ") length=" + bufferLength + " - stream is already closed");
- }
-
+ public void flush() throws IOException {
+ Channel channel = getChannel();
+ Session session = channel.getSession();
try {
- Session session = channel.getSession();
- boolean traceEnabled = log.isTraceEnabled();
- while (bufferLength > 0) {
- session.resetIdleTimeout();
-
- Buffer buf = buffer;
- long total = bufferLength;
- long available;
- try {
- available = remoteWindow.waitForSpace(maxWaitTimeout);
- if (traceEnabled) {
- log.trace("flush({}) len={}, available={}", this, total, available);
- }
- } catch (IOException e) {
- log.error("flush({}) failed ({}) to wait for space of len={}: {}",
- this, e.getClass().getSimpleName(), total, e.getMessage());
- if (log.isDebugEnabled()) {
- log.error("flush(" + this + ") wait for space len=" + total + " exception details", e);
- }
- throw e;
- }
-
- long lenToSend = Math.min(available, total);
- long length = Math.min(lenToSend, remoteWindow.getPacketSize());
- if (length > Integer.MAX_VALUE) {
- throw new StreamCorruptedException(
- "Accumulated " + SshConstants.getCommandMessageName(cmd)
- + " command bytes size (" + length + ") exceeds int boundaries");
- }
-
- int pos = buf.wpos();
- buf.wpos((cmd == SshConstants.SSH_MSG_CHANNEL_EXTENDED_DATA) ? 14 : 10);
- buf.putInt(length);
- buf.wpos(buf.wpos() + (int) length);
- if (total == length) {
- newBuffer((int) length);
- } else {
- long leftover = total - length;
- newBuffer((int) Math.max(leftover, length));
- buffer.putRawBytes(buf.array(), pos - (int) leftover, (int) leftover);
- bufferLength = (int) leftover;
- }
- lastSize = (int) length;
-
- session.resetIdleTimeout();
- remoteWindow.waitAndConsume(length, maxWaitTimeout);
- if (traceEnabled) {
- log.trace("flush({}) send {} len={}",
- channel, SshConstants.getCommandMessageName(cmd), length);
- }
- packetWriter.writePacket(buf);
- }
+ ((AbstractSession) session).executeUnderPendingPacketsLock(
+ getExtraPendingPacketLockWaitTime(bufferLength),
+ () -> {
+ lockedFlush(session, channel);
+ return null;
+ });
} catch (WindowClosedException e) {
if (!closedState.getAndSet(true)) {
if (log.isDebugEnabled()) {
@@ -245,8 +235,11 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
}
throw e;
} catch (Exception e) {
+ log.error("flush(" + this + ") failed", e);
if (e instanceof IOException) {
throw (IOException) e;
+ } else if (e instanceof RuntimeException) {
+ throw (RuntimeException) e;
} else if (e instanceof InterruptedException) {
throw (IOException) new InterruptedIOException(
"Interrupted while waiting for remote space flush len=" + bufferLength + " to " + this)
@@ -257,21 +250,87 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
}
}
- @Override
- public synchronized void close() throws IOException {
+ protected void lockedFlush(Session session, Channel channel) throws Exception { // NOTE(review): expected to run while holding the session pending packets lock - flush(), lockedWrite() and lockedClose() invoke it that way
+ boolean traceEnabled = log.isTraceEnabled();
if (!isOpen()) {
+ if (bufferLength > 0) { // closed with unflushed data is an error; closed and empty is a no-op
+ throw new SshChannelClosedException(
+ channel.getId(),
+ "lockedFlush(" + this + ") length=" + bufferLength + " - stream is already closed");
+ }
+
+ if (traceEnabled) {
+ log.trace("lockedFlush({}) nothing to flush", this);
+ }
return;
}
+ while (bufferLength > 0) { // may take several packets - each bounded by available window space and max. packet size
+ session.resetIdleTimeout();
+
+ Buffer buf = buffer;
+ long total = bufferLength;
+ long available;
+ try {
+ available = remoteWindow.waitForSpace(maxWaitTimeout); // NOTE: any waiting here happens while the pending packets lock is held
+ if (traceEnabled) {
+ log.trace("lockedFlush({}) len={}, available={}", this, total, available);
+ }
+ } catch (IOException e) {
+ log.error("lockedFlush({}) failed ({}) to wait for space of len={}: {}",
+ this, e.getClass().getSimpleName(), total, e.getMessage());
+ if (log.isDebugEnabled()) {
+ log.error("lockedFlush(" + this + ") wait for space len=" + total + " exception details", e);
+ }
+ throw e;
+ }
+
+ long lenToSend = Math.min(available, total);
+ long length = Math.min(lenToSend, remoteWindow.getPacketSize()); // payload size of the packet about to be sent
+ if (length > Integer.MAX_VALUE) {
+ throw new StreamCorruptedException(
+ "Accumulated " + SshConstants.getCommandMessageName(cmd)
+ + " command bytes size (" + length
+ + ") exceeds int boundaries");
+ }
+
+ int pos = buf.wpos();
+ buf.wpos((cmd == SshConstants.SSH_MSG_CHANNEL_EXTENDED_DATA) ? 14 : 10); // presumably the data-length field offset in the header built by newBuffer() - TODO confirm
+ buf.putInt(length);
+ buf.wpos(buf.wpos() + (int) length); // end of this packet's payload
+ if (total == length) {
+ newBuffer((int) length);
+ } else {
+ long leftover = total - length; // bytes beyond 'length' are carried over into the fresh buffer
+ newBuffer((int) Math.max(leftover, length));
+ buffer.putRawBytes(buf.array(), pos - (int) leftover, (int) leftover);
+ bufferLength = (int) leftover;
+ }
+ lastSize = (int) length; // used by lockedWrite() as the credit for the next packet
+
+ session.resetIdleTimeout();
+ remoteWindow.waitAndConsume(length, maxWaitTimeout);
+ if (traceEnabled) {
+ log.trace("lockedFlush({}) send len={}", this, length);
+ }
+ packetWriter.writePacket(buf); // 'buf' is the previous buffer; the stream's buffer was replaced via newBuffer() above
+ }
+ }
+
+ protected long getExtraPendingPacketLockWaitTime(int dataSize) { // extra lock wait (msec.) passed to AbstractSession#executeUnderPendingPacketsLock - never negative for a non-negative maxWaitTimeout
+ long sizeBased = Math.min(Math.max(dataSize, 0), FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT); // TODO see if can do anything better
+ return (maxWaitTimeout > Long.MAX_VALUE - sizeBased) ? Long.MAX_VALUE : sizeBased + maxWaitTimeout; // saturate - a huge configured maxWaitTimeout must not overflow to a negative value
+ }
+
+ protected void lockedClose(Session session, AbstractChannel channel) throws Exception {
if (log.isTraceEnabled()) {
- log.trace("close({}) closing", this);
+ log.trace("lockedClose({}) closing", this);
}
try {
- flush();
+ lockedFlush(session, channel);
if (isEofOnClose()) {
- AbstractChannel channel = getChannel();
channel.sendEof();
}
} finally {
@@ -285,6 +344,32 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
}
}
+ @Override
+ public void close() throws IOException { // flushes and (if configured) sends EOF while holding the session pending packets lock
+ if (!isOpen()) {
+ return; // fast path only - the state is re-checked under the lock below
+ }
+
+ AbstractChannel channel = getChannel();
+ Session session = channel.getSession();
+ try {
+ ((AbstractSession) session).executeUnderPendingPacketsLock(
+ getExtraPendingPacketLockWaitTime(bufferLength),
+ () -> { // another thread may have closed the stream while we waited for the lock
+ if (isOpen()) { lockedClose(session, channel); }
+ return null;
+ });
+ } catch (Exception e) {
+ if (e instanceof IOException) {
+ throw (IOException) e;
+ } else if (e instanceof RuntimeException) {
+ throw (RuntimeException) e;
+ } else if (e instanceof InterruptedException) {
+ throw (IOException) new InterruptedIOException("Interrupted while closing " + this).initCause(e);
+ } else { throw new RuntimeSshException(e); }
+ }
+ }
+
protected void newBuffer(int size) {
Channel channel = getChannel();
Session session = channel.getSession();
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
index 47a56b9..849a561 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
@@ -35,10 +35,15 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Queue;
+import java.util.concurrent.Callable;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import org.apache.sshd.common.Closeable;
@@ -187,6 +192,7 @@ public abstract class AbstractSession extends SessionHelper {
protected long maxRekyPackets = FactoryManager.DEFAULT_REKEY_PACKETS_LIMIT;
protected long maxRekeyBytes = FactoryManager.DEFAULT_REKEY_BYTES_LIMIT;
protected long maxRekeyInterval = FactoryManager.DEFAULT_REKEY_TIME_LIMIT;
+ protected final Lock pendingPacketsLock = new ReentrantLock(); // guards pendingPackets; acquire via executeUnderPendingPacketsLock (SSHD-966)
protected final Queue<PendingWriteFuture> pendingPackets = new LinkedList<>();
protected Service currentService;
@@ -656,6 +662,38 @@ public abstract class AbstractSession extends SessionHelper {
doKexNegotiation();
}
+ /**
+ * Acquires the (re-entrant) pending packets lock and executes the given code while holding it. The max. time to
+ * wait for the lock is derived from the current number of pending packets plus the caller's extra wait time.
+ *
+ * @param <V> The executed code return value
+ * @param extraWait An extra amount of time (msec.) that the caller is willing to wait beyond the time derived from
+ * the number of pending packets. <B>Note:</B> the total is clamped to at least {@link FactoryManager#DEFAULT_NIO2_MIN_WRITE_TIMEOUT}
+ * and at most {@link FactoryManager#DEFAULT_AUTH_TIMEOUT}.
+ * @param executor The code to execute under lock
+ * @return The executed code result
+ * @throws Exception A {@link TimeoutException} if the lock is not acquired in time, {@link InterruptedException} if interrupted while waiting, or whatever the executor throws
+ * @see <A HREF="https://issues.apache.org/jira/browse/SSHD-966">SSHD-966</A>
+ */
+ public <V> V executeUnderPendingPacketsLock(long extraWait, Callable<? extends V> executor) throws Exception {
+ ValidateUtils.checkTrue(extraWait >= 0L, "Invalid extra wait time: %d", extraWait);
+ int numPending = pendingPackets.size(); // NOTE: unlocked (racy) read of a LinkedList - only used as a wait-time heuristic
+ long maxWait = Math.min(extraWait, FactoryManager.DEFAULT_AUTH_TIMEOUT) + (long) numPending * FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT;
+ // in case zero
+ maxWait = Math.max(maxWait, FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT);
+ // in case lots of pending packets or large extra time (pre-clamping extraWait above prevents long overflow)
+ maxWait = Math.min(maxWait, FactoryManager.DEFAULT_AUTH_TIMEOUT);
+ if (!pendingPacketsLock.tryLock(maxWait, TimeUnit.MILLISECONDS)) {
+ throw new TimeoutException("Failed to acquire pending packets lock within " + maxWait + " msec. (pending=" + numPending + ")");
+ }
+
+ try {
+ return executor.call();
+ } finally {
+ pendingPacketsLock.unlock();
+ }
+ }
+
protected void doKexNegotiation() throws Exception {
if (kexState.compareAndSet(KexState.DONE, KexState.RUN)) {
sendKexInit();
@@ -669,9 +707,10 @@ public abstract class AbstractSession extends SessionHelper {
KeyExchangeFactory kexFactory = NamedResource.findByName(
kexAlgorithm, String.CASE_INSENSITIVE_ORDER, kexFactories);
ValidateUtils.checkNotNull(kexFactory, "Unknown negotiated KEX algorithm: %s", kexAlgorithm);
- synchronized (pendingPackets) {
+ executeUnderPendingPacketsLock(0L, () -> {
kex = kexFactory.createKeyExchange(this);
- }
+ return kex;
+ });
byte[] v_s = serverVersion.getBytes(StandardCharsets.UTF_8);
byte[] v_c = clientVersion.getBytes(StandardCharsets.UTF_8);
@@ -707,12 +746,13 @@ public abstract class AbstractSession extends SessionHelper {
signalSessionEvent(SessionListener.Event.KeyEstablished);
- Collection<? extends Map.Entry<? extends SshFutureListener<IoWriteFuture>, IoWriteFuture>> pendingWrites;
- synchronized (pendingPackets) {
- pendingWrites = sendPendingPackets(pendingPackets);
- kex = null; // discard and GC since KEX is completed
- kexState.set(KexState.DONE);
- }
+ Collection<? extends Map.Entry<? extends SshFutureListener<IoWriteFuture>, IoWriteFuture>> pendingWrites
+ = executeUnderPendingPacketsLock(FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT, () -> {
+ List<Map.Entry<PendingWriteFuture, IoWriteFuture>> result = sendPendingPackets(pendingPackets);
+ kex = null; // discard and GC since KEX is completed
+ kexState.set(KexState.DONE);
+ return result;
+ });
int pendingCount = pendingWrites.size();
if (pendingCount > 0) {
@@ -734,7 +774,7 @@ public abstract class AbstractSession extends SessionHelper {
}
}
- protected List<SimpleImmutableEntry<PendingWriteFuture, IoWriteFuture>> sendPendingPackets(
+ protected List<Map.Entry<PendingWriteFuture, IoWriteFuture>> sendPendingPackets(
Queue<PendingWriteFuture> packetsQueue)
throws IOException {
if (GenericUtils.isEmpty(packetsQueue)) {
@@ -742,7 +782,7 @@ public abstract class AbstractSession extends SessionHelper {
}
int numPending = packetsQueue.size();
- List<SimpleImmutableEntry<PendingWriteFuture, IoWriteFuture>> pendingWrites = new ArrayList<>(numPending);
+ List<Map.Entry<PendingWriteFuture, IoWriteFuture>> pendingWrites = new ArrayList<>(numPending);
synchronized (encodeLock) {
for (PendingWriteFuture future = packetsQueue.poll();
future != null;
@@ -872,10 +912,11 @@ public abstract class AbstractSession extends SessionHelper {
* Checks if key-exchange is done - if so, or the packet is related to the key-exchange protocol, then allows the
* packet to go through, otherwise enqueues it to be sent when key-exchange completed
*
- * @param buffer The {@link Buffer} containing the packet to be sent
- * @return A {@link PendingWriteFuture} if enqueued, {@code null} if packet can go through.
+ * @param buffer The {@link Buffer} containing the packet to be sent
+ * @return A {@link PendingWriteFuture} if enqueued, {@code null} if packet can go through.
+ * @throws IOException If failed to enqueue
*/
- protected PendingWriteFuture enqueuePendingPacket(Buffer buffer) {
+ protected PendingWriteFuture enqueuePendingPacket(Buffer buffer) throws IOException {
if (KexState.DONE.equals(kexState.get())) {
return null;
}
@@ -887,20 +928,33 @@ public abstract class AbstractSession extends SessionHelper {
}
String cmdName = SshConstants.getCommandMessageName(cmd);
+ AtomicInteger numPending = new AtomicInteger();
PendingWriteFuture future;
- int numPending;
- synchronized (pendingPackets) {
- if (KexState.DONE.equals(kexState.get())) {
- return null;
- }
+ try {
+ future = executeUnderPendingPacketsLock(0L, () -> {
+ if (KexState.DONE.equals(kexState.get())) {
+ return null;
+ }
+
+ PendingWriteFuture pending = new PendingWriteFuture(cmdName, buffer);
+ pendingPackets.add(pending);
+ numPending.set(pendingPackets.size());
+ return pending;
+ });
+ } catch (Exception e) {
+ log.error("enqueuePendingPacket(" + this + ")[" + cmdName + "] failed to generate future", e);
- future = new PendingWriteFuture(cmdName, buffer);
- pendingPackets.add(future);
- numPending = pendingPackets.size();
+ if (e instanceof IOException) {
+ throw (IOException) e;
+ } else if (e instanceof RuntimeException) {
+ throw (RuntimeException) e;
+ } else {
+ throw new RuntimeSshException(e);
+ }
}
if (log.isDebugEnabled()) {
- if (numPending == 1) {
+ if (numPending.get() == 1) {
log.debug("enqueuePendingPacket({})[{}] Start flagging packets as pending until key exchange is done", this,
cmdName);
} else {