You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@mina.apache.org by lg...@apache.org on 2020/04/23 14:33:15 UTC

[mina-sshd] branch master updated: [SSHD-982] Fix race condition when loading known hosts file

This is an automated email from the ASF dual-hosted git repository.

lgoldstein pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/mina-sshd.git


The following commit(s) were added to refs/heads/master by this push:
     new 5e89f97  [SSHD-982] Fix race condition when loading known hosts file
5e89f97 is described below

commit 5e89f970f73531c4712c94b76e898b932d046c1e
Author: FliegenKLATSCH <ch...@koras.de>
AuthorDate: Thu Apr 23 17:16:08 2020 +0300

    [SSHD-982] Fix race condition when loading known hosts file
---
 .../org/apache/sshd/common/util/GenericUtils.java  | 24 +++++++++++++++
 .../keyverifier/KnownHostsServerKeyVerifier.java   | 32 +++++++++++++-------
 .../KnownHostsServerKeyVerifierTest.java           | 34 ++++++++++++++++++++++
 3 files changed, 79 insertions(+), 11 deletions(-)

diff --git a/sshd-common/src/main/java/org/apache/sshd/common/util/GenericUtils.java b/sshd-common/src/main/java/org/apache/sshd/common/util/GenericUtils.java
index 62a2910..e5da0d5 100644
--- a/sshd-common/src/main/java/org/apache/sshd/common/util/GenericUtils.java
+++ b/sshd-common/src/main/java/org/apache/sshd/common/util/GenericUtils.java
@@ -42,6 +42,7 @@ import java.util.SortedSet;
 import java.util.TreeMap;
 import java.util.TreeSet;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BinaryOperator;
 import java.util.function.Consumer;
 import java.util.function.Function;
@@ -1010,4 +1011,27 @@ public final class GenericUtils {
             Iterable<? extends Supplier<? extends Iterable<? extends T>>> providers) {
         return () -> stream(providers).<T> flatMap(s -> stream(s.get())).map(Function.identity()).iterator();
     }
+
+    /**
+     * Invokes the delegate's {@code get()} lazily on first use and caches its non-{@code null} result; thread-safe.
+     *
+     * @param  delegate The actual Supplier - invoked again on the next call if it throws or returns {@code null}
+     * @return          A memoizing Supplier that never returns {@code null}
+     */
+    public static <T> Supplier<T> memoizeLock(Supplier<T> delegate) {
+        AtomicReference<T> value = new AtomicReference<>();
+        return () -> {
+            T val = value.get();
+            if (val == null) {
+                synchronized (value) {
+                    val = value.get();
+                    if (val == null) {
+                        val = Objects.requireNonNull(delegate.get(), "Delegate supplier returned null");
+                        value.set(val);
+                    }
+                }
+            }
+            return val;
+        };
+    }
 }
diff --git a/sshd-core/src/main/java/org/apache/sshd/client/keyverifier/KnownHostsServerKeyVerifier.java b/sshd-core/src/main/java/org/apache/sshd/client/keyverifier/KnownHostsServerKeyVerifier.java
index 8cdecd7..b0d1f35 100644
--- a/sshd-core/src/main/java/org/apache/sshd/client/keyverifier/KnownHostsServerKeyVerifier.java
+++ b/sshd-core/src/main/java/org/apache/sshd/client/keyverifier/KnownHostsServerKeyVerifier.java
@@ -38,6 +38,7 @@ import java.util.List;
 import java.util.Objects;
 import java.util.TreeSet;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
 
 import org.apache.sshd.client.config.hosts.KnownHostEntry;
 import org.apache.sshd.client.config.hosts.KnownHostHashValue;
@@ -79,7 +80,7 @@ public class KnownHostsServerKeyVerifier
 
     /**
      * Represents an entry in the internal verifier's cache
-     * 
+     *
      * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
      */
     public static class HostEntryPair {
@@ -119,7 +120,8 @@ public class KnownHostsServerKeyVerifier
 
     protected final Object updateLock = new Object();
     private final ServerKeyVerifier delegate;
-    private final AtomicReference<Collection<HostEntryPair>> keysHolder = new AtomicReference<>(Collections.emptyList());
+    private final AtomicReference<Supplier<? extends Collection<HostEntryPair>>> keysSupplier
+            = new AtomicReference<>(GenericUtils.memoizeLock(() -> getKnownHostSupplier(null, getPath()).get()));
     private ModifiedServerKeyAcceptor modKeyAcceptor;
 
     public KnownHostsServerKeyVerifier(ServerKeyVerifier delegate, Path file) {
@@ -153,35 +155,43 @@ public class KnownHostsServerKeyVerifier
 
     @Override
     public boolean verifyServerKey(ClientSession clientSession, SocketAddress remoteAddress, PublicKey serverKey) {
-        Collection<HostEntryPair> knownHosts = getLoadedHostsEntries();
         try {
             if (checkReloadRequired()) {
                 Path file = getPath();
                 if (exists()) {
-                    knownHosts = reloadKnownHosts(clientSession, file);
+                    keysSupplier.set(GenericUtils.memoizeLock(getKnownHostSupplier(clientSession, file)));
+                    updateReloadAttributes(); // only after the new supplier is published
                 } else {
                     if (log.isDebugEnabled()) {
                         log.debug("verifyServerKey({})[{}] missing known hosts file {}",
                                 clientSession, remoteAddress, file);
                     }
-                    knownHosts = Collections.emptyList();
+                    keysSupplier.set(Collections::emptyList);
                 }
-
-                setLoadedHostsEntries(knownHosts);
             }
         } catch (Throwable t) {
             return acceptIncompleteHostKeys(clientSession, remoteAddress, serverKey, t);
         }
 
+        // Resolved outside the try-block: the loading supplier logs load failures and yields no entries
+        Collection<HostEntryPair> knownHosts = keysSupplier.get().get();
         return acceptKnownHostEntries(clientSession, remoteAddress, serverKey, knownHosts);
     }
 
-    protected Collection<HostEntryPair> getLoadedHostsEntries() {
-        return keysHolder.get();
+    protected Supplier<Collection<HostEntryPair>> getKnownHostSupplier(ClientSession clientSession, Path file) {
+        return () -> { // lazy: the file is only read when get() is invoked
+            try {
+                return reloadKnownHosts(clientSession, file);
+            } catch (Exception e) { // best-effort: log the failure and report no entries
+                log.warn("verifyServerKey({}) Could not reload known hosts file {}",
+                        clientSession, file, e);
+                return Collections.emptyList(); // NOTE(review): cached until the next reload if memoized
+            }
+        };
     }
 
     protected void setLoadedHostsEntries(Collection<HostEntryPair> keys) {
-        keysHolder.set(keys);
+        keysSupplier.set(() -> keys); // NOTE(review): the matching getLoadedHostsEntries() was removed
     }
 
     /**
@@ -579,7 +589,7 @@ public class KnownHostsServerKeyVerifier
 
         if (delegate.verifyServerKey(clientSession, remoteAddress, serverKey)) {
             Path file = getPath();
-            Collection<HostEntryPair> keys = getLoadedHostsEntries();
+            Collection<HostEntryPair> keys = keysSupplier.get().get();
             try {
                 updateKnownHostsFile(clientSession, remoteAddress, serverKey, file, keys);
             } catch (Throwable t) {
diff --git a/sshd-core/src/test/java/org/apache/sshd/client/keyverifier/KnownHostsServerKeyVerifierTest.java b/sshd-core/src/test/java/org/apache/sshd/client/keyverifier/KnownHostsServerKeyVerifierTest.java
index 9552eee..9482ffd 100644
--- a/sshd-core/src/test/java/org/apache/sshd/client/keyverifier/KnownHostsServerKeyVerifierTest.java
+++ b/sshd-core/src/test/java/org/apache/sshd/client/keyverifier/KnownHostsServerKeyVerifierTest.java
@@ -94,6 +94,40 @@ public class KnownHostsServerKeyVerifierTest extends BaseTestSupport {
     }
 
     @Test
+    public void testParallelLoading() {
+        KnownHostsServerKeyVerifier hostsVerifier
+                = new KnownHostsServerKeyVerifier(AcceptAllServerKeyVerifier.INSTANCE, entriesFile) {
+                    @Override
+                    public ModifiedServerKeyAcceptor getModifiedServerKeyAcceptor() {
+                        return (sess, addr, hostEntry, expectedKey, actualKey) -> true; // accept any key change
+                    }
+
+                    @Override
+                    protected boolean acceptKnownHostEntries(
+                            ClientSession clientSession, SocketAddress remoteAddress, PublicKey serverKey,
+                            Collection<HostEntryPair> knownHosts) {
+                        if ((knownHosts == null) || knownHosts.isEmpty()) { // must see the loaded file
+                            fail("Loaded known_hosts collection is empty!");
+                        }
+                        return super.acceptKnownHostEntries(clientSession, remoteAddress, serverKey, knownHosts);
+                    }
+                };
+
+        ClientFactoryManager factoryManager = Mockito.mock(ClientFactoryManager.class);
+        Mockito.when(factoryManager.getRandomFactory()).thenReturn(JceRandomFactory.INSTANCE);
+
+        HOST_KEYS.entrySet().parallelStream().forEach(hostKey -> {
+            SocketAddress address = hostKey.getKey();
+            KnownHostEntry expected = hostsEntries.get(address);
+            ClientSession session = Mockito.mock(ClientSession.class);
+            Mockito.when(session.getFactoryManager()).thenReturn(factoryManager);
+            Mockito.when(session.getConnectAddress()).thenReturn(address);
+            boolean accepted = hostsVerifier.verifyServerKey(session, address, hostKey.getValue());
+            assertTrue("Failed to validate server=" + expected, accepted);
+        });
+    }
+
+    @Test
     public void testNoUpdatesNoNewHostsAuthentication() throws Exception {
         AtomicInteger delegateCount = new AtomicInteger(0);
         ServerKeyVerifier delegate = (clientSession, remoteAddress, serverKey) -> {