You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this rendering.
Posted to commits@mina.apache.org by lg...@apache.org on 2015/06/14 13:54:26 UTC

mina-sshd git commit: [SSHD-490] Use an enum to follow the KEX state

Repository: mina-sshd
Updated Branches:
  refs/heads/master d44c35824 -> 9fc3a98ab


[SSHD-490] Use an enum to follow the KEX state


Project: http://git-wip-us.apache.org/repos/asf/mina-sshd/repo
Commit: http://git-wip-us.apache.org/repos/asf/mina-sshd/commit/9fc3a98a
Tree: http://git-wip-us.apache.org/repos/asf/mina-sshd/tree/9fc3a98a
Diff: http://git-wip-us.apache.org/repos/asf/mina-sshd/diff/9fc3a98a

Branch: refs/heads/master
Commit: 9fc3a98ab147228a4cb3a449ceb5cb6fd50a91a5
Parents: d44c358
Author: Lyor Goldstein <lg...@vmware.com>
Authored: Sun Jun 14 14:54:14 2015 +0300
Committer: Lyor Goldstein <lg...@vmware.com>
Committed: Sun Jun 14 14:54:14 2015 +0300

----------------------------------------------------------------------
 .../sshd/client/session/ClientSessionImpl.java  |   7 +-
 .../org/apache/sshd/common/SshConstants.java    |   1 +
 .../org/apache/sshd/common/kex/KexState.java    |  38 ++++++
 .../sshd/common/session/AbstractSession.java    | 117 +++++++++++--------
 .../sshd/server/session/ServerSession.java      |  29 +++--
 5 files changed, 131 insertions(+), 61 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/9fc3a98a/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSessionImpl.java
----------------------------------------------------------------------
diff --git a/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSessionImpl.java b/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSessionImpl.java
index b8a143c..bc7cf74 100644
--- a/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSessionImpl.java
+++ b/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSessionImpl.java
@@ -58,6 +58,7 @@ import org.apache.sshd.common.future.DefaultSshFuture;
 import org.apache.sshd.common.future.SshFuture;
 import org.apache.sshd.common.io.IoSession;
 import org.apache.sshd.common.kex.KexProposalOption;
+import org.apache.sshd.common.kex.KexState;
 import org.apache.sshd.common.scp.ScpTransferEventListener;
 import org.apache.sshd.common.session.AbstractConnectionService;
 import org.apache.sshd.common.session.AbstractSession;
@@ -142,7 +143,7 @@ public class ClientSessionImpl extends AbstractSession implements ClientSession
         authFuture = new DefaultAuthFuture(lock);
         authFuture.setAuthed(false);
         sendClientIdentification();
-        kexState.set(KEX_STATE_INIT);
+        kexState.set(KexState.INIT);
         sendKexInit();
     }
 
@@ -270,7 +271,7 @@ public class ClientSessionImpl extends AbstractSession implements ClientSession
                 || !((AbstractConnectionService) currentService).getChannels().isEmpty()) {
             throw new IllegalStateException("The switch to the none cipher must be done immediately after authentication");
         }
-        if (kexState.compareAndSet(KEX_STATE_DONE, KEX_STATE_INIT)) {
+        if (kexState.compareAndSet(KexState.DONE, KexState.INIT)) {
             reexchangeFuture = new DefaultSshFuture(null);
             
             String c2sEncServer = serverProposal.get(KexProposalOption.C2SENC);
@@ -442,7 +443,7 @@ public class ClientSessionImpl extends AbstractSession implements ClientSession
                 if (authed) { // authFuture.isSuccess()
                     cond |= AUTHED;
                 }
-                if (kexState.get() == KEX_STATE_DONE && authFuture.isFailure()) {
+                if (KexState.DONE.equals(kexState.get()) && authFuture.isFailure()) {
                     cond |= WAIT_AUTH;
                 }
                 if ((cond & mask) != 0) {

http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/9fc3a98a/sshd-core/src/main/java/org/apache/sshd/common/SshConstants.java
----------------------------------------------------------------------
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/SshConstants.java b/sshd-core/src/main/java/org/apache/sshd/common/SshConstants.java
index 91ce971..008e66a 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/SshConstants.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/SshConstants.java
@@ -37,6 +37,7 @@ public interface SshConstants {
     static final byte SSH_MSG_SERVICE_REQUEST=                 5;
     static final byte SSH_MSG_SERVICE_ACCEPT=                  6;
     static final byte SSH_MSG_KEXINIT=                        20;
+        static final int MSG_KEX_COOKIE_SIZE = 16;
     static final byte SSH_MSG_NEWKEYS=                        21;
 
     static final byte SSH_MSG_KEX_FIRST=                      30;

http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/9fc3a98a/sshd-core/src/main/java/org/apache/sshd/common/kex/KexState.java
----------------------------------------------------------------------
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/kex/KexState.java b/sshd-core/src/main/java/org/apache/sshd/common/kex/KexState.java
new file mode 100644
index 0000000..47756b8
--- /dev/null
+++ b/sshd-core/src/main/java/org/apache/sshd/common/kex/KexState.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sshd.common.kex;
+
+import java.util.Collections;
+import java.util.EnumSet;
+import java.util.Set;
+
+/**
+ * Used to track the key-exchange (KEX) protocol progression
+ * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
+ */
+public enum KexState {
+    UNKNOWN,
+    INIT,
+    RUN,
+    KEYS,
+    DONE;
+    
+    public static final Set<KexState> VALUES = Collections.unmodifiableSet(EnumSet.allOf(KexState.class));
+}

http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/9fc3a98a/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractSession.java
----------------------------------------------------------------------
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractSession.java b/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractSession.java
index 77bace2..5b60279 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractSession.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractSession.java
@@ -39,7 +39,6 @@ import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.ScheduledFuture;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
-import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
 import org.apache.sshd.common.Closeable;
@@ -59,6 +58,7 @@ import org.apache.sshd.common.future.SshFutureListener;
 import org.apache.sshd.common.io.IoSession;
 import org.apache.sshd.common.io.IoWriteFuture;
 import org.apache.sshd.common.kex.KexProposalOption;
+import org.apache.sshd.common.kex.KexState;
 import org.apache.sshd.common.kex.KeyExchange;
 import org.apache.sshd.common.mac.Mac;
 import org.apache.sshd.common.random.Random;
@@ -93,11 +93,6 @@ public abstract class AbstractSession extends CloseableUtils.AbstractInnerClosea
      */
     public static final String SESSION = "org.apache.sshd.session";
 
-    protected static final int KEX_STATE_INIT = 1;
-    protected static final int KEX_STATE_RUN =  2;
-    protected static final int KEX_STATE_KEYS = 3;
-    protected static final int KEX_STATE_DONE = 4;
-
     /** Client or server side */
     protected final boolean isServer;
     /** The factory manager used to retrieve factories of Ciphers, Macs and other objects */
@@ -127,7 +122,7 @@ public abstract class AbstractSession extends CloseableUtils.AbstractInnerClosea
     protected byte[] I_C; // the payload of the client's SSH_MSG_KEXINIT
     protected byte[] I_S; // the payload of the factoryManager's SSH_MSG_KEXINIT
     protected KeyExchange kex;
-    protected final AtomicInteger kexState = new AtomicInteger();
+    protected final AtomicReference<KexState> kexState = new AtomicReference<KexState>(KexState.UNKNOWN);
     @SuppressWarnings("rawtypes")
     protected DefaultSshFuture reexchangeFuture;
 
@@ -373,9 +368,7 @@ public abstract class AbstractSession extends CloseableUtils.AbstractInnerClosea
             case SSH_MSG_SERVICE_REQUEST:
                 String service = buffer.getString();
                 log.debug("Received SSH_MSG_SERVICE_REQUEST '{}'", service);
-                if (kexState.get() != KEX_STATE_DONE) {
-                    throw new IllegalStateException("Received command " + cmd + " before key exchange is finished");
-                }
+                validateKexState(cmd, KexState.DONE);
                 try {
                     startService(service);
                 } catch (Exception e) {
@@ -390,17 +383,15 @@ public abstract class AbstractSession extends CloseableUtils.AbstractInnerClosea
                 break;
             case SSH_MSG_SERVICE_ACCEPT:
                 log.debug("Received SSH_MSG_SERVICE_ACCEPT");
-                if (kexState.get() != KEX_STATE_DONE) {
-                    throw new IllegalStateException("Received command " + cmd + " before key exchange is finished");
-                }
+                validateKexState(cmd, KexState.DONE);
                 serviceAccept();
                 break;
             case SSH_MSG_KEXINIT:
                 log.debug("Received SSH_MSG_KEXINIT");
                 receiveKexInit(buffer);
-                if (kexState.compareAndSet(KEX_STATE_DONE, KEX_STATE_RUN)) {
+                if (kexState.compareAndSet(KexState.DONE, KexState.RUN)) {
                     sendKexInit();
-                } else if (!kexState.compareAndSet(KEX_STATE_INIT, KEX_STATE_RUN)) {
+                } else if (!kexState.compareAndSet(KexState.INIT, KexState.RUN)) {
                     throw new IllegalStateException("Received SSH_MSG_KEXINIT while key exchange is running");
                 }
 
@@ -417,9 +408,7 @@ public abstract class AbstractSession extends CloseableUtils.AbstractInnerClosea
                 break;
             case SSH_MSG_NEWKEYS:
                 log.debug("Received SSH_MSG_NEWKEYS");
-                if (kexState.get() != KEX_STATE_KEYS) {
-                    throw new IllegalStateException("Received command " + cmd + " before key exchange is finished");
-                }
+                validateKexState(cmd, KexState.KEYS);
                 receiveNewKeys();
                 if (reexchangeFuture != null) {
                     reexchangeFuture.setValue(Boolean.TRUE);
@@ -435,22 +424,20 @@ public abstract class AbstractSession extends CloseableUtils.AbstractInnerClosea
                             }
                         }
                     }
-                    kexState.set(KEX_STATE_DONE);
+                    kexState.set(KexState.DONE);
                 }
                 synchronized (lock) {
                     lock.notifyAll();
                 }
                 break;
             default:
-                if (cmd >= SshConstants.SSH_MSG_KEX_FIRST && cmd <= SshConstants.SSH_MSG_KEX_LAST) {
-                    if (kexState.get() != KEX_STATE_RUN) {
-                        throw new IllegalStateException("Received kex command " + cmd + " while not in key exchange");
-                    }
+                if ((cmd >= SshConstants.SSH_MSG_KEX_FIRST) && (cmd <= SshConstants.SSH_MSG_KEX_LAST)) {
+                    validateKexState(cmd, KexState.RUN);
                     buffer.rpos(buffer.rpos() - 1);
                     if (kex.next(buffer)) {
                         checkKeys();
                         sendNewKeys();
-                        kexState.set(KEX_STATE_KEYS);
+                        kexState.set(KexState.KEYS);
                     }
                 } else if (currentService != null) {
                     currentService.process(cmd, buffer);
@@ -463,8 +450,15 @@ public abstract class AbstractSession extends CloseableUtils.AbstractInnerClosea
         checkRekey();
     }
 
+    protected void validateKexState(byte cmd, KexState expected) {
+        KexState actual = kexState.get();
+        if (!expected.equals(actual)) {
+            throw new IllegalStateException("Received KEX command=" + cmd + " while in state=" + actual + " instead of " + expected);
+        }
+    }
+
     /**
-     * Handle any exceptions that occured on this session.
+     * Handle any exceptions that occurred on this session.
      * The session will be closed and a disconnect packet will be
      * sent before if the given exception is an
      * {@link org.apache.sshd.common.SshException}.
@@ -480,17 +474,20 @@ public abstract class AbstractSession extends CloseableUtils.AbstractInnerClosea
             }
         }
         log.warn("Exception caught", t);
-        try {
-            if (t instanceof SshException) {
-                int code = ((SshException) t).getDisconnectCode();
-                if (code > 0) {
+        if (t instanceof SshException) {
+            int code = ((SshException) t).getDisconnectCode();
+            if (code > 0) {
+                try {
                     disconnect(code, t.getMessage());
-                    return;
+                } catch (Throwable t2) {
+                    if (log.isDebugEnabled()) {
+                        log.debug("Exception while disconnect with code=" + code, t2);
+                    }
                 }
+                return;
             }
-        } catch (Throwable t2) {
-            // Ignore
         }
+
         close(true);
     }
 
@@ -535,11 +532,11 @@ public abstract class AbstractSession extends CloseableUtils.AbstractInnerClosea
     @Override
     public IoWriteFuture writePacket(Buffer buffer) throws IOException {
         // While exchanging key, queue high level packets
-        if (kexState.get() != KEX_STATE_DONE) {
+        if (!KexState.DONE.equals(kexState.get())) {
             byte cmd = buffer.array()[buffer.rpos()];
             if (cmd > SshConstants.SSH_MSG_KEX_LAST) {
                 synchronized (pendingPackets) {
-                    if (kexState.get() != KEX_STATE_DONE) {
+                    if (!KexState.DONE.equals(kexState.get())) {
                         if (pendingPackets.isEmpty()) {
                             log.debug("Start flagging packets as pending until key exchange is done");
                         }
@@ -891,7 +888,7 @@ public abstract class AbstractSession extends CloseableUtils.AbstractInnerClosea
             if (server || str.startsWith("SSH-")) {
                 return str;
             }
-            if (buffer.rpos() > 16 * 1024) {
+            if (buffer.rpos() > (16 * 1024)) {
                 throw new IllegalStateException("Incorrect identification: too many header lines");
             }
         }
@@ -943,18 +940,22 @@ public abstract class AbstractSession extends CloseableUtils.AbstractInnerClosea
         log.debug("Send SSH_MSG_KEXINIT");
         Buffer buffer = createBuffer(SshConstants.SSH_MSG_KEXINIT);
         int p = buffer.wpos();
-        buffer.wpos(p + 16);
-        random.fill(buffer.array(), p, 16);
+        buffer.wpos(p + SshConstants.MSG_KEX_COOKIE_SIZE);
+        random.fill(buffer.array(), p, SshConstants.MSG_KEX_COOKIE_SIZE);
+        if (log.isTraceEnabled()) {
+            log.trace("sendKexInit(" + toString() + ") cookie=" + BufferUtils.printHex(buffer.array(), p, SshConstants.MSG_KEX_COOKIE_SIZE, ':'));
+        }
+
         for (KexProposalOption paramType : KexProposalOption.VALUES) {
             String s = proposal.get(paramType);
             if (log.isTraceEnabled()) {
-                log.trace("sendKexInit(" + paramType.getDescription() + ") " + s);
+                log.trace("sendKexInit(" + toString() + ")[" + paramType.getDescription() + "] " + s);
             }
             buffer.putString(GenericUtils.trimToEmpty(s));
         }
 
-        buffer.putByte((byte) 0);
-        buffer.putInt(0);
+        buffer.putBoolean(false);   // first kex packet follows
+        buffer.putInt(0);   // reserved (FFU)
         byte[] data = buffer.getCompactData();
         writePacket(buffer);
         return data;
@@ -970,27 +971,43 @@ public abstract class AbstractSession extends CloseableUtils.AbstractInnerClosea
      */
     protected byte[] receiveKexInit(Buffer buffer, Map<KexProposalOption,String> proposal) {
         // Recreate the packet payload which will be needed at a later time
-        int size = 22;
         byte[] d = buffer.array();
-        byte[] data = new byte[buffer.available() + 1];
+        byte[] data = new byte[buffer.available() + 1 /* the opcode */];
         data[0] = SshConstants.SSH_MSG_KEXINIT;
-        System.arraycopy(d, buffer.rpos(), data, 1, data.length - 1);
-        // Skip 16 bytes of random data
-        buffer.rpos(buffer.rpos() + 16);
+        
+        int size = 6, cookieStartPos = buffer.rpos();
+        System.arraycopy(d, cookieStartPos, data, 1, data.length - 1);
+        // Skip random cookie data
+        buffer.rpos(cookieStartPos + SshConstants.MSG_KEX_COOKIE_SIZE);
+        size += SshConstants.MSG_KEX_COOKIE_SIZE;
+        if (log.isTraceEnabled()) {
+            log.trace("receiveKexInit(" + toString() + ") cookie=" + BufferUtils.printHex(d, cookieStartPos, SshConstants.MSG_KEX_COOKIE_SIZE, ':'));
+        }
+
         // Read proposal
         for (KexProposalOption paramType : KexProposalOption.VALUES) {
             int lastPos = buffer.rpos();
             String value = buffer.getString();
             if (log.isTraceEnabled()) {
-                log.trace("receiveKexInit(" + paramType.getDescription() + ") " + value);
+                log.trace("receiveKexInit(" + toString() + ")[" + paramType.getDescription() + "] " + value);
             }
             int curPos = buffer.rpos(), readLen = curPos - lastPos;
             proposal.put(paramType, value);
             size += readLen;
         }
-        // Skip 5 bytes
-        buffer.getByte();
-        buffer.getInt();
+        
+        boolean firstKexPacketFollows = buffer.getBoolean();
+        if (log.isTraceEnabled()) {
+            log.trace("receiveKexInit(" + toString() + ") first kex packet follows: " + firstKexPacketFollows);
+        }
+
+        long reserved = buffer.getUInt();
+        if (reserved != 0) {
+            if (log.isTraceEnabled()) {
+                log.trace("receiveKexInit(" + toString() + ") non-zero reserved value: " + reserved);
+            }
+        }
+
         // Return data
         byte[] dataShrinked = new byte[size];
         System.arraycopy(data, 0, dataShrinked, 0, size);
@@ -1377,7 +1394,7 @@ public abstract class AbstractSession extends CloseableUtils.AbstractInnerClosea
     @Override
     @SuppressWarnings("rawtypes")
     public SshFuture reExchangeKeys() throws IOException {
-        if (kexState.compareAndSet(KEX_STATE_DONE, KEX_STATE_INIT)) {
+        if (kexState.compareAndSet(KexState.DONE, KexState.INIT)) {
             log.info("Initiating key re-exchange");
             sendKexInit();
             reexchangeFuture = new DefaultSshFuture(null);

http://git-wip-us.apache.org/repos/asf/mina-sshd/blob/9fc3a98a/sshd-core/src/main/java/org/apache/sshd/server/session/ServerSession.java
----------------------------------------------------------------------
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/session/ServerSession.java b/sshd-core/src/main/java/org/apache/sshd/server/session/ServerSession.java
index d67a023..0639c4b 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/session/ServerSession.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/session/ServerSession.java
@@ -22,6 +22,7 @@ import java.io.IOException;
 import java.security.KeyPair;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 
 import org.apache.sshd.common.FactoryManager;
 import org.apache.sshd.common.FactoryManagerUtils;
@@ -30,9 +31,11 @@ import org.apache.sshd.common.ServiceFactory;
 import org.apache.sshd.common.SshConstants;
 import org.apache.sshd.common.SshException;
 import org.apache.sshd.common.future.SshFutureListener;
+import org.apache.sshd.common.io.IoService;
 import org.apache.sshd.common.io.IoSession;
 import org.apache.sshd.common.io.IoWriteFuture;
 import org.apache.sshd.common.kex.KexProposalOption;
+import org.apache.sshd.common.kex.KexState;
 import org.apache.sshd.common.keyprovider.KeyPairProvider;
 import org.apache.sshd.common.session.AbstractSession;
 import org.apache.sshd.common.util.GenericUtils;
@@ -85,7 +88,7 @@ public class ServerSession extends AbstractSession {
 
     @Override
     protected void checkRekey() throws IOException {
-        if (kexState.get() == KEX_STATE_DONE) {
+        if (KexState.DONE.equals(kexState.get())) {
             if (   inPackets > MAX_PACKETS || outPackets > MAX_PACKETS
                 || inBytes > maxBytes || outBytes > maxBytes
                 || maxKeyInterval > 0 && System.currentTimeMillis() - lastKeyTime > maxKeyInterval)
@@ -156,7 +159,7 @@ public class ServerSession extends AbstractSession {
     @Override
     protected boolean readIdentification(Buffer buffer) throws IOException {
         clientVersion = doReadIdentification(buffer, true);
-        if (clientVersion == null) {
+        if (GenericUtils.isEmpty(clientVersion)) {
             return false;
         }
         log.debug("Client version string: {}", clientVersion);
@@ -170,7 +173,7 @@ public class ServerSession extends AbstractSession {
             });
             throw new SshException(msg);
         } else {
-            kexState.set(KEX_STATE_INIT);
+            kexState.set(KexState.INIT);
             sendKexInit();
         }
         return true;
@@ -195,15 +198,25 @@ public class ServerSession extends AbstractSession {
      * @return The current number of live <code>SshSession</code> objects associated with the user
      */
     protected int getActiveSessionCountForUser(String userName) {
+        IoService service = ioSession.getService();
+        Map<?, IoSession> sessionsMap = service.getManagedSessions();
+        if (GenericUtils.isEmpty(sessionsMap)) {
+            return 0;
+        }
+
         int totalCount = 0;
-        for (IoSession is : ioSession.getService().getManagedSessions().values()) {
+        for (IoSession is : sessionsMap.values()) {
             ServerSession session = (ServerSession) getSession(is, true);
-            if (session != null) {
-                if (session.getUsername() != null && session.getUsername().equals(userName)) {
-                    totalCount++;
-                }
+            if (session == null) {
+                continue;
+            }
+            
+            String sessionUser = session.getUsername();
+            if ((!GenericUtils.isEmpty(sessionUser)) && Objects.equals(sessionUser, userName)) {
+                totalCount++;
             }
         }
+
         return totalCount;
     }