You are viewing a plain-text version of this content; the hyperlink to its canonical version ("here") was not preserved in this copy.
Posted to commits@mina.apache.org by lg...@apache.org on 2020/05/15 08:11:38 UTC

[mina-sshd] 01/01: [SSHD-966] Using same lock to synchronize session pending packets and ChannelOutputStream mutual exclusion

This is an automated email from the ASF dual-hosted git repository.

lgoldstein pushed a commit to branch SSHD-966
in repository https://gitbox.apache.org/repos/asf/mina-sshd.git

commit 25e7e13c95ada68f3def91dd153b44fd0127971b
Author: Lyor Goldstein <lg...@apache.org>
AuthorDate: Fri May 15 10:24:50 2020 +0300

    [SSHD-966] Using same lock to synchronize session pending packets and ChannelOutputStream mutual exclusion
---
 .../sshd/common/channel/ChannelOutputStream.java   | 253 ++++++++++++++-------
 .../common/session/helpers/AbstractSession.java    |  98 ++++++--
 2 files changed, 245 insertions(+), 106 deletions(-)

diff --git a/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelOutputStream.java b/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelOutputStream.java
index d0d879b..14f7e91 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelOutputStream.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelOutputStream.java
@@ -26,11 +26,14 @@ import java.util.Objects;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 
+import org.apache.sshd.common.FactoryManager;
+import org.apache.sshd.common.RuntimeSshException;
 import org.apache.sshd.common.SshConstants;
 import org.apache.sshd.common.SshException;
 import org.apache.sshd.common.channel.exception.SshChannelClosedException;
 import org.apache.sshd.common.io.PacketWriter;
 import org.apache.sshd.common.session.Session;
+import org.apache.sshd.common.session.helpers.AbstractSession;
 import org.apache.sshd.common.util.ValidateUtils;
 import org.apache.sshd.common.util.buffer.Buffer;
 import org.slf4j.Logger;
@@ -103,24 +106,91 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
     }
 
     @Override
-    public synchronized void write(int w) throws IOException {
-        b[0] = (byte) w;
-        write(b, 0, 1);
+    public void write(int w) throws IOException {
+        try {
+            Channel channel = getChannel();
+            Session session = channel.getSession();
+            // SSHD-966: mutual exclusion is provided by the session's pending packets lock (instead of
+            // "synchronized"), which also guards the single-byte scratch buffer used below
+            ((AbstractSession) session).executeUnderPendingPacketsLock(
+                    getExtraPendingPacketLockWaitTime(1), () -> {
+                        b[0] = (byte) w;
+                        lockedWrite(session, channel, b, 0, 1);
+                        return null;
+                    });
+        } catch (Exception e) {
+            log.error("write(" + this + ") value=0x" + Integer.toHexString(w) + " failed to write", e);
+            rethrowLockedFailure("write(" + this + ") value=0x" + Integer.toHexString(w), e);
+        }
     }
 
     @Override
-    public synchronized void write(byte[] buf, int s, int l) throws IOException {
-        Channel channel = getChannel();
+    public void write(byte[] buf, int startOffset, int dataLen) throws IOException {
+        try {
+            Channel channel = getChannel();
+            Session session = channel.getSession();
+            // SSHD-966: mutual exclusion is provided by the session's pending packets lock
+            ((AbstractSession) session).executeUnderPendingPacketsLock(
+                    getExtraPendingPacketLockWaitTime(dataLen), () -> {
+                        lockedWrite(session, channel, buf, startOffset, dataLen);
+                        return null;
+                    });
+        } catch (Exception e) {
+            log.error("write(" + this + ") len=" + dataLen + " failed to write", e);
+            rethrowLockedFailure("write(" + this + ") len=" + dataLen, e);
+        }
+    }
+
+    /**
+     * Re-throws a failure raised while executing under the session's pending packets lock in a form that is
+     * compatible with the {@link OutputStream} contract: {@link IOException}s and {@link RuntimeException}s are
+     * re-thrown as-is, a lock acquisition time-out becomes an {@link SshException} and an interruption becomes an
+     * {@link InterruptedIOException} (after restoring the thread's interrupted status).
+     *
+     * @param  location    Description of the failed operation - used in the generated exception messages
+     * @param  e           The caught failure
+     * @throws IOException The mapped failure - this method never returns normally
+     */
+    protected void rethrowLockedFailure(String location, Exception e) throws IOException {
+        if (e instanceof IOException) {
+            throw (IOException) e;
+        } else if (e instanceof RuntimeException) {
+            throw (RuntimeException) e;
+        } else if (e instanceof InterruptedException) {
+            // restore the interrupted status so that callers further up the stack can still observe it
+            Thread.currentThread().interrupt();
+            throw (IOException) new InterruptedIOException(
+                    location + " - interrupted while waiting for pending packets lock").initCause(e);
+        } else if (e instanceof java.util.concurrent.TimeoutException) {
+            throw new SshException(location + " - timed out waiting for pending packets lock", e);
+        } else {
+            throw new RuntimeSshException(e);
+        }
+    }
+
+    /**
+     * Performs the actual write - invoked by the {@code write} methods while holding the session's pending packets
+     * lock.
+     *
+     * @param  session     The channel's session
+     * @param  channel     The channel being written to
+     * @param  buf         The data source
+     * @param  startOffset Offset of the first byte to write
+     * @param  dataLen     Number of bytes to write
+     * @throws Exception   If the channel is closed, or failed to wait for remote window space or to flush the data
+     */
+    protected void lockedWrite(
+            Session session, Channel channel, byte[] buf, int startOffset, int dataLen)
+            throws Exception {
         if (!isOpen()) {
             throw new SshChannelClosedException(
                     channel.getId(),
-                    "write(" + this + ") len=" + l + " - channel already closed");
+                    "lockedWrite(" + this + ") len=" + dataLen + " - channel already closed");
         }
 
-        Session session = channel.getSession();
         boolean debugEnabled = log.isDebugEnabled();
         boolean traceEnabled = log.isTraceEnabled();
-        while (l > 0) {
+        while (dataLen > 0) {
             // The maximum amount we should admit without flushing again
             // is enough to make up one full packet within our allowed
             // window size. We give ourselves a credit equal to the last
@@ -128,31 +169,33 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
             // out the next packet before we block and wait for space to
             // become available again.
             long minReqLen = Math.min(remoteWindow.getSize() + lastSize, remoteWindow.getPacketSize());
-            long l2 = Math.min(l, minReqLen - bufferLength);
+            long l2 = Math.min(dataLen, minReqLen - bufferLength);
             if (l2 <= 0) {
                 if (bufferLength > 0) {
-                    flush();
+                    lockedFlush(session, channel);
                 } else {
                     session.resetIdleTimeout();
                     try {
                         long available = remoteWindow.waitForSpace(maxWaitTimeout);
                         if (traceEnabled) {
-                            log.trace("write({}) len={} - available={}", this, l, available);
+                            log.trace("lockedWrite({}) len={} - available={}", this, dataLen, available);
                         }
                     } catch (IOException e) {
-                        log.error("write({}) failed ({}) to wait for space of len={}: {}",
-                                this, e.getClass().getSimpleName(), l, e.getMessage());
+                        log.error("lockedWrite({}) failed ({}) to wait for space of len={}: {}",
+                                this, e.getClass().getSimpleName(), dataLen, e.getMessage());
 
                         if ((e instanceof WindowClosedException) && (!closedState.getAndSet(true))) {
                             if (debugEnabled) {
-                                log.debug("write({})[len={}] closing due to window closed", this, l);
+                                log.debug("lockedWrite({})[len={}] closing due to window closed", this, dataLen);
                             }
                         }
 
                         throw e;
                     } catch (InterruptedException e) {
+                        // restore the interrupted status so that callers further up the stack can still observe it
+                        Thread.currentThread().interrupt();
                         throw (IOException) new InterruptedIOException(
-                                "Interrupted while waiting for remote space on write len=" + l + " to " + this)
+                                "Interrupted while waiting for remote space on write len=" + dataLen + " to " + this)
                                         .initCause(e);
                     }
                 }
@@ -162,81 +203,30 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
 
             ValidateUtils.checkTrue(l2 <= Integer.MAX_VALUE,
                     "Accumulated bytes length exceeds int boundary: %d", l2);
-            buffer.putRawBytes(buf, s, (int) l2);
+            buffer.putRawBytes(buf, startOffset, (int) l2);
             bufferLength += l2;
-            s += l2;
-            l -= l2;
+            startOffset += l2;
+            dataLen -= l2;
         }
 
         if (isNoDelay()) {
-            flush();
+            lockedFlush(session, channel);
         } else {
             session.resetIdleTimeout();
         }
     }
 
     @Override
-    public synchronized void flush() throws IOException {
-        AbstractChannel channel = getChannel();
-        if (!isOpen()) {
-            throw new SshChannelClosedException(
-                    channel.getId(),
-                    "flush(" + this + ") length=" + bufferLength + " - stream is already closed");
-        }
-
+    public void flush() throws IOException {
+        Channel channel = getChannel();
+        Session session = channel.getSession();
         try {
-            Session session = channel.getSession();
-            boolean traceEnabled = log.isTraceEnabled();
-            while (bufferLength > 0) {
-                session.resetIdleTimeout();
-
-                Buffer buf = buffer;
-                long total = bufferLength;
-                long available;
-                try {
-                    available = remoteWindow.waitForSpace(maxWaitTimeout);
-                    if (traceEnabled) {
-                        log.trace("flush({}) len={}, available={}", this, total, available);
-                    }
-                } catch (IOException e) {
-                    log.error("flush({}) failed ({}) to wait for space of len={}: {}",
-                            this, e.getClass().getSimpleName(), total, e.getMessage());
-                    if (log.isDebugEnabled()) {
-                        log.error("flush(" + this + ") wait for space len=" + total + " exception details", e);
-                    }
-                    throw e;
-                }
-
-                long lenToSend = Math.min(available, total);
-                long length = Math.min(lenToSend, remoteWindow.getPacketSize());
-                if (length > Integer.MAX_VALUE) {
-                    throw new StreamCorruptedException(
-                            "Accumulated " + SshConstants.getCommandMessageName(cmd)
-                                                       + " command bytes size (" + length + ") exceeds int boundaries");
-                }
-
-                int pos = buf.wpos();
-                buf.wpos((cmd == SshConstants.SSH_MSG_CHANNEL_EXTENDED_DATA) ? 14 : 10);
-                buf.putInt(length);
-                buf.wpos(buf.wpos() + (int) length);
-                if (total == length) {
-                    newBuffer((int) length);
-                } else {
-                    long leftover = total - length;
-                    newBuffer((int) Math.max(leftover, length));
-                    buffer.putRawBytes(buf.array(), pos - (int) leftover, (int) leftover);
-                    bufferLength = (int) leftover;
-                }
-                lastSize = (int) length;
-
-                session.resetIdleTimeout();
-                remoteWindow.waitAndConsume(length, maxWaitTimeout);
-                if (traceEnabled) {
-                    log.trace("flush({}) send {} len={}",
-                            channel, SshConstants.getCommandMessageName(cmd), length);
-                }
-                packetWriter.writePacket(buf);
-            }
+            ((AbstractSession) session).executeUnderPendingPacketsLock(
+                    getExtraPendingPacketLockWaitTime(bufferLength),
+                    () -> {
+                        lockedFlush(session, channel);
+                        return null;
+                    });
         } catch (WindowClosedException e) {
             if (!closedState.getAndSet(true)) {
                 if (log.isDebugEnabled()) {
@@ -245,8 +235,11 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
             }
             throw e;
         } catch (Exception e) {
+            log.error("flush(" + this + ") failed", e);
             if (e instanceof IOException) {
                 throw (IOException) e;
+            } else if (e instanceof RuntimeException) {
+                throw (RuntimeException) e;
             } else if (e instanceof InterruptedException) {
                 throw (IOException) new InterruptedIOException(
                         "Interrupted while waiting for remote space flush len=" + bufferLength + " to " + this)
@@ -257,21 +250,115 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
         }
     }
 
-    @Override
-    public synchronized void close() throws IOException {
+    /**
+     * Sends the currently buffered data - invoked while holding the session's pending packets lock.
+     * <B>Note:</B> waiting for remote window space happens while that session-wide lock is held.
+     *
+     * @param  session   The channel's session
+     * @param  channel   The channel whose data is being flushed
+     * @throws Exception If the stream is closed while data is still buffered, or failed to wait for remote
+     *                   window space or to write the packet
+     */
+    protected void lockedFlush(Session session, Channel channel) throws Exception {
+        boolean traceEnabled = log.isTraceEnabled();
         if (!isOpen()) {
+            if (bufferLength > 0) {
+                throw new SshChannelClosedException(
+                        channel.getId(),
+                        "lockedFlush(" + this + ") length=" + bufferLength + " - stream is already closed");
+            }
+
+            if (traceEnabled) {
+                log.trace("lockedFlush({}) nothing to flush", this);
+            }
             return;
         }
 
+        while (bufferLength > 0) {
+            session.resetIdleTimeout();
+
+            Buffer buf = buffer;
+            long total = bufferLength;
+            long available;
+            try {
+                available = remoteWindow.waitForSpace(maxWaitTimeout);
+                if (traceEnabled) {
+                    log.trace("lockedFlush({}) len={}, available={}", this, total, available);
+                }
+            } catch (IOException e) {
+                log.error("lockedFlush({}) failed ({}) to wait for space of len={}: {}",
+                        this, e.getClass().getSimpleName(), total, e.getMessage());
+                if (log.isDebugEnabled()) {
+                    log.error("lockedFlush(" + this + ") wait for space len=" + total + " exception details", e);
+                }
+                throw e;
+            }
+
+            long lenToSend = Math.min(available, total);
+            long length = Math.min(lenToSend, remoteWindow.getPacketSize());
+            if (length > Integer.MAX_VALUE) {
+                throw new StreamCorruptedException(
+                        "Accumulated " + SshConstants.getCommandMessageName(cmd)
+                                                   + " command bytes size (" + length
+                                                   + ") exceeds int boundaries");
+            }
+
+            int pos = buf.wpos();
+            buf.wpos((cmd == SshConstants.SSH_MSG_CHANNEL_EXTENDED_DATA) ? 14 : 10);
+            buf.putInt(length);
+            buf.wpos(buf.wpos() + (int) length);
+            if (total == length) {
+                newBuffer((int) length);
+            } else {
+                long leftover = total - length;
+                newBuffer((int) Math.max(leftover, length));
+                buffer.putRawBytes(buf.array(), pos - (int) leftover, (int) leftover);
+                bufferLength = (int) leftover;
+            }
+            lastSize = (int) length;
+
+            session.resetIdleTimeout();
+            remoteWindow.waitAndConsume(length, maxWaitTimeout);
+            if (traceEnabled) {
+                log.trace("lockedFlush({}) send {} len={}", this, SshConstants.getCommandMessageName(cmd), length);
+            }
+            packetWriter.writePacket(buf);
+        }
+    }
+
+    /**
+     * Computes the extra time (msec.) that a stream operation is willing to wait for the session's pending
+     * packets lock.
+     *
+     * @param  dataSize The amount of data (bytes) involved in the operation - negative values count as zero
+     * @return          {@code min(dataSize, DEFAULT_NIO2_MIN_WRITE_TIMEOUT) + maxWaitTimeout}, saturated at
+     *                  {@link Long#MAX_VALUE} so that a very large {@code maxWaitTimeout} does not overflow
+     *                  into a negative value (which the lock acquisition code rejects)
+     */
+    protected long getExtraPendingPacketLockWaitTime(int dataSize) {
+        // TODO see if can do anything better - this mixes a byte count with a time value
+        long extra = Math.min(Math.max(dataSize, 0), FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT);
+        // saturating addition - e.g., if maxWaitTimeout is configured as Long.MAX_VALUE
+        return (maxWaitTimeout > (Long.MAX_VALUE - extra)) ? Long.MAX_VALUE : (maxWaitTimeout + extra);
+    }
+
+    /**
+     * Flushes any buffered data and sends EOF if {@link #isEofOnClose()} - invoked while holding the session's
+     * pending packets lock.
+     *
+     * @param  session   The channel's session
+     * @param  channel   The channel being closed
+     * @throws Exception If failed to flush the buffered data or send the EOF
+     */
+    protected void lockedClose(Session session, AbstractChannel channel) throws Exception {
         if (log.isTraceEnabled()) {
-            log.trace("close({}) closing", this);
+            log.trace("lockedClose({}) closing", this);
         }
 
         try {
-            flush();
+            lockedFlush(session, channel);
 
             if (isEofOnClose()) {
-                AbstractChannel channel = getChannel();
                 channel.sendEof();
             }
         } finally {
@@ -285,6 +344,44 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
         }
     }
 
+    @Override
+    public void close() throws IOException {
+        // fast path - nothing to do (and no need to lock) if already closed
+        if (!isOpen()) {
+            return;
+        }
+
+        AbstractChannel channel = getChannel();
+        Session session = channel.getSession();
+        try {
+            ((AbstractSession) session).executeUnderPendingPacketsLock(
+                    getExtraPendingPacketLockWaitTime(bufferLength),
+                    () -> {
+                        // re-check under the lock - the stream may have been closed (e.g., by a concurrent
+                        // close() call) while this thread was waiting to acquire it
+                        if (isOpen()) {
+                            lockedClose(session, channel);
+                        }
+                        return null;
+                    });
+        } catch (Exception e) {
+            if (e instanceof IOException) {
+                throw (IOException) e;
+            } else if (e instanceof RuntimeException) {
+                throw (RuntimeException) e;
+            } else if (e instanceof InterruptedException) {
+                // restore the interrupted status so that callers further up the stack can still observe it
+                Thread.currentThread().interrupt();
+                throw (IOException) new InterruptedIOException(
+                        "close(" + this + ") interrupted while waiting for pending packets lock").initCause(e);
+            } else if (e instanceof java.util.concurrent.TimeoutException) {
+                throw new SshException("close(" + this + ") timed out waiting for pending packets lock", e);
+            } else {
+                throw new RuntimeSshException(e);
+            }
+        }
+    }
+
     protected void newBuffer(int size) {
         Channel channel = getChannel();
         Session session = channel.getSession();
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
index 47a56b9..849a561 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
@@ -35,10 +35,15 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Queue;
+import java.util.concurrent.Callable;
 import java.util.concurrent.CopyOnWriteArraySet;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
 import java.util.logging.Level;
 
 import org.apache.sshd.common.Closeable;
@@ -187,6 +192,7 @@ public abstract class AbstractSession extends SessionHelper {
     protected long maxRekyPackets = FactoryManager.DEFAULT_REKEY_PACKETS_LIMIT;
     protected long maxRekeyBytes = FactoryManager.DEFAULT_REKEY_BYTES_LIMIT;
     protected long maxRekeyInterval = FactoryManager.DEFAULT_REKEY_TIME_LIMIT;
+    protected final Lock pendingPacketsLock = new ReentrantLock();
     protected final Queue<PendingWriteFuture> pendingPackets = new LinkedList<>();
 
     protected Service currentService;
@@ -656,6 +662,45 @@ public abstract class AbstractSession extends SessionHelper {
         doKexNegotiation();
     }
 
+    /**
+     * Attempts to lock the pending packets access and execute the relevant code. The max. wait time for the lock
+     * is derived from the current number of pending packets plus the caller's extra wait time, and is clamped to
+     * [{@link FactoryManager#DEFAULT_NIO2_MIN_WRITE_TIMEOUT}, {@link FactoryManager#DEFAULT_AUTH_TIMEOUT}]. Since
+     * the lock is re-entrant, the executed code may itself invoke this method.
+     *
+     * @param  <V>              The executed code return value
+     * @param  extraWait        An extra amount of time (msec.) that the caller is willing to wait beyond the time
+     *                          derived from the number of pending packets - must be non-negative
+     * @param  executor         The code to execute under lock
+     * @return                  The executed code result
+     * @throws TimeoutException If failed to acquire the lock within the calculated time
+     * @throws Exception        If interrupted while waiting for the lock, or exception thrown by executor code
+     * @see                     <A HREF="https://issues.apache.org/jira/browse/SSHD-966">SSHD-966</A>
+     */
+    public <V> V executeUnderPendingPacketsLock(long extraWait, Callable<? extends V> executor) throws Exception {
+        ValidateUtils.checkTrue(extraWait >= 0L, "Invalid extra wait time: %d", extraWait);
+        // NOTE: read without holding the lock (the queue is not thread-safe) - only used as a hint for the
+        // wait time, so a slightly stale value is acceptable
+        int numPending = pendingPackets.size();
+        long pendingWait = numPending * FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT;
+        // saturating addition - a huge extra wait must not overflow into a (negative) tiny one
+        long maxWait = (extraWait > (Long.MAX_VALUE - pendingWait)) ? Long.MAX_VALUE : (pendingWait + extraWait);
+        // in case zero
+        maxWait = Math.max(maxWait, FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT);
+        // in case lots of pending packets or large extra time
+        maxWait = Math.min(maxWait, FactoryManager.DEFAULT_AUTH_TIMEOUT);
+        if (!pendingPacketsLock.tryLock(maxWait, TimeUnit.MILLISECONDS)) {
+            throw new TimeoutException("Failed to acquire pending packets lock within " + maxWait
+                                       + " msec. (pending=" + numPending + ")");
+        }
+
+        try {
+            return executor.call();
+        } finally {
+            pendingPacketsLock.unlock();
+        }
+    }
+
     protected void doKexNegotiation() throws Exception {
         if (kexState.compareAndSet(KexState.DONE, KexState.RUN)) {
             sendKexInit();
@@ -669,9 +707,10 @@ public abstract class AbstractSession extends SessionHelper {
         KeyExchangeFactory kexFactory = NamedResource.findByName(
                 kexAlgorithm, String.CASE_INSENSITIVE_ORDER, kexFactories);
         ValidateUtils.checkNotNull(kexFactory, "Unknown negotiated KEX algorithm: %s", kexAlgorithm);
-        synchronized (pendingPackets) {
+        executeUnderPendingPacketsLock(0L, () -> {
             kex = kexFactory.createKeyExchange(this);
-        }
+            return kex;
+        });
 
         byte[] v_s = serverVersion.getBytes(StandardCharsets.UTF_8);
         byte[] v_c = clientVersion.getBytes(StandardCharsets.UTF_8);
@@ -707,12 +746,13 @@ public abstract class AbstractSession extends SessionHelper {
 
         signalSessionEvent(SessionListener.Event.KeyEstablished);
 
-        Collection<? extends Map.Entry<? extends SshFutureListener<IoWriteFuture>, IoWriteFuture>> pendingWrites;
-        synchronized (pendingPackets) {
-            pendingWrites = sendPendingPackets(pendingPackets);
-            kex = null; // discard and GC since KEX is completed
-            kexState.set(KexState.DONE);
-        }
+        Collection<? extends Map.Entry<? extends SshFutureListener<IoWriteFuture>, IoWriteFuture>> pendingWrites
+                = executeUnderPendingPacketsLock(FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT, () -> {
+                    List<Map.Entry<PendingWriteFuture, IoWriteFuture>> result = sendPendingPackets(pendingPackets);
+                    kex = null; // discard and GC since KEX is completed
+                    kexState.set(KexState.DONE);
+                    return result;
+                });
 
         int pendingCount = pendingWrites.size();
         if (pendingCount > 0) {
@@ -734,7 +774,7 @@ public abstract class AbstractSession extends SessionHelper {
         }
     }
 
-    protected List<SimpleImmutableEntry<PendingWriteFuture, IoWriteFuture>> sendPendingPackets(
+    protected List<Map.Entry<PendingWriteFuture, IoWriteFuture>> sendPendingPackets(
             Queue<PendingWriteFuture> packetsQueue)
             throws IOException {
         if (GenericUtils.isEmpty(packetsQueue)) {
@@ -742,7 +782,7 @@ public abstract class AbstractSession extends SessionHelper {
         }
 
         int numPending = packetsQueue.size();
-        List<SimpleImmutableEntry<PendingWriteFuture, IoWriteFuture>> pendingWrites = new ArrayList<>(numPending);
+        List<Map.Entry<PendingWriteFuture, IoWriteFuture>> pendingWrites = new ArrayList<>(numPending);
         synchronized (encodeLock) {
             for (PendingWriteFuture future = packetsQueue.poll();
                  future != null;
@@ -872,10 +912,11 @@ public abstract class AbstractSession extends SessionHelper {
      * Checks if key-exchange is done - if so, or the packet is related to the key-exchange protocol, then allows the
      * packet to go through, otherwise enqueues it to be sent when key-exchange completed
      *
-     * @param  buffer The {@link Buffer} containing the packet to be sent
-     * @return        A {@link PendingWriteFuture} if enqueued, {@code null} if packet can go through.
+     * @param  buffer      The {@link Buffer} containing the packet to be sent
+     * @return             A {@link PendingWriteFuture} if enqueued, {@code null} if packet can go through.
+     * @throws IOException If failed to enqueue
      */
-    protected PendingWriteFuture enqueuePendingPacket(Buffer buffer) {
+    protected PendingWriteFuture enqueuePendingPacket(Buffer buffer) throws IOException {
         if (KexState.DONE.equals(kexState.get())) {
             return null;
         }
@@ -887,20 +928,33 @@ public abstract class AbstractSession extends SessionHelper {
         }
 
         String cmdName = SshConstants.getCommandMessageName(cmd);
+        AtomicInteger numPending = new AtomicInteger();
         PendingWriteFuture future;
-        int numPending;
-        synchronized (pendingPackets) {
-            if (KexState.DONE.equals(kexState.get())) {
-                return null;
-            }
+        try {
+            future = executeUnderPendingPacketsLock(0L, () -> {
+                if (KexState.DONE.equals(kexState.get())) {
+                    return null;
+                }
+
+                PendingWriteFuture pending = new PendingWriteFuture(cmdName, buffer);
+                pendingPackets.add(pending);
+                numPending.set(pendingPackets.size());
+                return pending;
+            });
+        } catch (Exception e) {
+            log.error("enqueuePendingPacket(" + this + ")[" + cmdName + "] failed to generate future", e);
 
-            future = new PendingWriteFuture(cmdName, buffer);
-            pendingPackets.add(future);
-            numPending = pendingPackets.size();
+            if (e instanceof IOException) {
+                throw (IOException) e;
+            } else if (e instanceof RuntimeException) {
+                throw (RuntimeException) e;
+            } else {
+                throw new RuntimeSshException(e);
+            }
         }
 
         if (log.isDebugEnabled()) {
-            if (numPending == 1) {
+            if (numPending.get() == 1) {
                 log.debug("enqueuePendingPacket({})[{}] Start flagging packets as pending until key exchange is done", this,
                         cmdName);
             } else {