You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mina.apache.org by lg...@apache.org on 2021/01/02 07:14:50 UTC

[mina-sshd] 11/15: [SSHD-1114] Added callbacks for client-side host-based authentication progress

This is an automated email from the ASF dual-hosted git repository.

lgoldstein pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/mina-sshd.git

commit a8ee3aa1e3fdcf7014073e2ac4c1bf4a8eea4b8c
Author: Lyor Goldstein <lg...@apache.org>
AuthorDate: Fri Jan 1 01:08:05 2021 +0200

    [SSHD-1114] Added callbacks for client-side host-based authentication progress
---
 CHANGES.md                                         |  1 +
 docs/event-listeners.md                            |  6 ++
 .../sshd/client/ClientAuthenticationManager.java   |  5 ++
 .../java/org/apache/sshd/client/SshClient.java     | 12 ++++
 .../hostbased/HostBasedAuthenticationReporter.java | 81 ++++++++++++++++++++++
 .../client/auth/hostbased/UserAuthHostBased.java   | 43 ++++++++++--
 .../sshd/client/session/AbstractClientSession.java | 14 ++++
 .../client/ClientAuthenticationManagerTest.java    | 11 +++
 .../sshd/common/auth/AuthenticationTest.java       | 76 +++++++++++++++++---
 9 files changed, 235 insertions(+), 14 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index c3aebd2..6a084dc 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -28,3 +28,4 @@
 * [SSHD-1109](https://issues.apache.org/jira/browse/SSHD-1109) Replace log4j with logback as the slf4j logger implementation for tests
 * [SSHD-1114](https://issues.apache.org/jira/browse/SSHD-1114) Added callbacks for client-side password authentication progress
 * [SSHD-1114](https://issues.apache.org/jira/browse/SSHD-1114) Added callbacks for client-side public key authentication progress
+* [SSHD-1114](https://issues.apache.org/jira/browse/SSHD-1114) Added callbacks for client-side host-based authentication progress
diff --git a/docs/event-listeners.md b/docs/event-listeners.md
index 50869ff..1f17a98 100644
--- a/docs/event-listeners.md
+++ b/docs/event-listeners.md
@@ -205,3 +205,9 @@ overriding any globally registered instance.
 Used to inform about the progress of the client-side public key authentication as described in [RFC-4252 section 7](https://tools.ietf.org/html/rfc4252#section-7).
 Can be registered globally on the `SshClient` and also for a specific `ClientSession` after it is established but before its `auth()` method is called - thus
 overriding any globally registered instance.
+
+### `HostBasedAuthenticationReporter`
+
+Used to inform about the progress of the client-side host-based authentication as described in [RFC-4252 section 9](https://tools.ietf.org/html/rfc4252#section-9).
+Can be registered globally on the `SshClient` and also for a specific `ClientSession` after it is established but before its `auth()` method is called - thus
+overriding any globally registered instance.
diff --git a/sshd-core/src/main/java/org/apache/sshd/client/ClientAuthenticationManager.java b/sshd-core/src/main/java/org/apache/sshd/client/ClientAuthenticationManager.java
index b6c6706..4dff328 100644
--- a/sshd-core/src/main/java/org/apache/sshd/client/ClientAuthenticationManager.java
+++ b/sshd-core/src/main/java/org/apache/sshd/client/ClientAuthenticationManager.java
@@ -27,6 +27,7 @@ import org.apache.sshd.client.auth.AuthenticationIdentitiesProvider;
 import org.apache.sshd.client.auth.BuiltinUserAuthFactories;
 import org.apache.sshd.client.auth.UserAuth;
 import org.apache.sshd.client.auth.UserAuthFactory;
+import org.apache.sshd.client.auth.hostbased.HostBasedAuthenticationReporter;
 import org.apache.sshd.client.auth.keyboard.UserInteraction;
 import org.apache.sshd.client.auth.password.PasswordAuthenticationReporter;
 import org.apache.sshd.client.auth.password.PasswordIdentityProvider;
@@ -116,6 +117,10 @@ public interface ClientAuthenticationManager
 
     void setPublicKeyAuthenticationReporter(PublicKeyAuthenticationReporter reporter);
 
+    HostBasedAuthenticationReporter getHostBasedAuthenticationReporter();
+
+    void setHostBasedAuthenticationReporter(HostBasedAuthenticationReporter reporter);
+
     @Override
     default void setUserAuthFactoriesNames(Collection<String> names) {
         BuiltinUserAuthFactories.ParseResult result = BuiltinUserAuthFactories.parseFactoriesList(names);
diff --git a/sshd-core/src/main/java/org/apache/sshd/client/SshClient.java b/sshd-core/src/main/java/org/apache/sshd/client/SshClient.java
index 2d07cfa..9660510 100644
--- a/sshd-core/src/main/java/org/apache/sshd/client/SshClient.java
+++ b/sshd-core/src/main/java/org/apache/sshd/client/SshClient.java
@@ -44,6 +44,7 @@ import java.util.stream.Collectors;
 import org.apache.sshd.agent.SshAgentFactory;
 import org.apache.sshd.client.auth.AuthenticationIdentitiesProvider;
 import org.apache.sshd.client.auth.UserAuthFactory;
+import org.apache.sshd.client.auth.hostbased.HostBasedAuthenticationReporter;
 import org.apache.sshd.client.auth.keyboard.UserAuthKeyboardInteractiveFactory;
 import org.apache.sshd.client.auth.keyboard.UserInteraction;
 import org.apache.sshd.client.auth.password.PasswordAuthenticationReporter;
@@ -183,6 +184,7 @@ public class SshClient extends AbstractFactoryManager implements ClientFactoryMa
     private FilePasswordProvider filePasswordProvider;
     private PasswordIdentityProvider passwordIdentityProvider;
     private PasswordAuthenticationReporter passwordAuthenticationReporter;
+    private HostBasedAuthenticationReporter hostBasedAuthenticationReporter;
     private UserInteraction userInteraction;
 
     private final List<Object> identities = new CopyOnWriteArrayList<>();
@@ -272,6 +274,16 @@ public class SshClient extends AbstractFactoryManager implements ClientFactoryMa
     }
 
     @Override
+    public HostBasedAuthenticationReporter getHostBasedAuthenticationReporter() {
+        return hostBasedAuthenticationReporter;
+    }
+
+    @Override
+    public void setHostBasedAuthenticationReporter(HostBasedAuthenticationReporter reporter) {
+        this.hostBasedAuthenticationReporter = reporter;
+    }
+
+    @Override
     public List<UserAuthFactory> getUserAuthFactories() {
         return userAuthFactories;
     }
diff --git a/sshd-core/src/main/java/org/apache/sshd/client/auth/hostbased/HostBasedAuthenticationReporter.java b/sshd-core/src/main/java/org/apache/sshd/client/auth/hostbased/HostBasedAuthenticationReporter.java
new file mode 100644
index 0000000..18c2193
--- /dev/null
+++ b/sshd-core/src/main/java/org/apache/sshd/client/auth/hostbased/HostBasedAuthenticationReporter.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sshd.client.auth.hostbased;
+
+import java.security.KeyPair;
+import java.util.List;
+
+import org.apache.sshd.client.session.ClientSession;
+
+/**
+ * Provides reports about the progress of the client-side host-based authentication
+ *
+ * @see    <a href="https://tools.ietf.org/html/rfc4252#section-9">RFC-4252 section 9</a>
+ * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
+ */
+public interface HostBasedAuthenticationReporter {
+    /**
+     * Signals that the initial host-based authentication request is about to be sent to the server
+     *
+     * @param  session   The {@link ClientSession}
+     * @param  service   The requesting service name
+     * @param  identity  The {@link KeyPair} identity being attempted
+     * @param  hostname  The host name value sent to the server
+     * @param  username  The username value sent to the server
+     * @param  signature The signature data that is being sent to the server
+     * @throws Exception If failed to handle the callback - <B>Note:</B> may cause session close
+     */
+    default void signalAuthenticationAttempt(
+            ClientSession session, String service, KeyPair identity, String hostname, String username, byte[] signature)
+            throws Exception {
+        // ignored
+    }
+
+    /**
+     * Signals that the host-based authentication attempt succeeded
+     *
+     * @param  session   The {@link ClientSession}
+     * @param  service   The requesting service name
+     * @param  identity  The {@link KeyPair} identity being attempted
+     * @param  hostname  The host name value sent to the server
+     * @param  username  The username value sent to the server
+     * @throws Exception If failed to handle the callback - <B>Note:</B> may cause session close
+     */
+    default void signalAuthenticationSuccess(
+            ClientSession session, String service, KeyPair identity, String hostname, String username)
+            throws Exception {
+        // ignored
+    }
+
+    /**
+     * Signals that the host-based authentication attempt failed
+     *
+     * @param  session       The {@link ClientSession}
+     * @param  service       The requesting service name
+     * @param  identity      The {@link KeyPair} identity being attempted
+     * @param  hostname      The host name value sent to the server
+     * @param  username      The username value sent to the server
+     * @param  partial       {@code true} if some partial authentication success was achieved so far
+     * @param  serverMethods The {@link List} of authentication methods that can continue
+     * @throws Exception     If failed to handle the callback - <B>Note:</B> may cause session close
+     */
+    default void signalAuthenticationFailure(
+            ClientSession session, String service, KeyPair identity,
+            String hostname, String username, boolean partial, List<String> serverMethods)
+            throws Exception {
+        // ignored
+    }
+}
diff --git a/sshd-core/src/main/java/org/apache/sshd/client/auth/hostbased/UserAuthHostBased.java b/sshd-core/src/main/java/org/apache/sshd/client/auth/hostbased/UserAuthHostBased.java
index 861549a..5c18b5f 100644
--- a/sshd-core/src/main/java/org/apache/sshd/client/auth/hostbased/UserAuthHostBased.java
+++ b/sshd-core/src/main/java/org/apache/sshd/client/auth/hostbased/UserAuthHostBased.java
@@ -48,8 +48,9 @@ import org.apache.sshd.common.util.net.SshdSocketAddress;
 public class UserAuthHostBased extends AbstractUserAuth implements SignatureFactoriesManager {
     public static final String NAME = UserAuthHostBasedFactory.NAME;
 
-    private Iterator<? extends Map.Entry<KeyPair, ? extends Collection<X509Certificate>>> keys;
-    private final HostKeyIdentityProvider clientHostKeys;
+    protected Iterator<? extends Map.Entry<KeyPair, ? extends Collection<X509Certificate>>> keys;
+    protected Map.Entry<KeyPair, ? extends Collection<X509Certificate>> keyInfo;
+    protected final HostKeyIdentityProvider clientHostKeys;
     private List<NamedFactory<Signature>> factories;
     private String clientUsername;
     private String clientHostname;
@@ -103,7 +104,7 @@ public class UserAuthHostBased extends AbstractUserAuth implements SignatureFact
             return false;
         }
 
-        Map.Entry<KeyPair, ? extends Collection<X509Certificate>> keyInfo = keys.next();
+        keyInfo = keys.next();
         KeyPair kp = keyInfo.getKey();
         PublicKey pub = kp.getPublic();
         String keyType = KeyUtils.getKeyType(pub);
@@ -159,13 +160,22 @@ public class UserAuthHostBased extends AbstractUserAuth implements SignatureFact
         buffer.putBytes(keyBytes);
         buffer.putString(clientHostname);
         buffer.putString(clientUsername);
-        appendSignature(session, service, keyType, pub, keyBytes, clientHostname, clientUsername, verifier, buffer);
+
+        byte[] signature = appendSignature(
+                session, service, keyType, pub, keyBytes,
+                clientHostname, clientUsername, verifier, buffer);
+        HostBasedAuthenticationReporter reporter = session.getHostBasedAuthenticationReporter();
+        if (reporter != null) {
+            reporter.signalAuthenticationAttempt(
+                    session, service, kp, clientHostname, clientUsername, signature);
+        }
+
         session.writePacket(buffer);
         return true;
     }
 
     @SuppressWarnings("checkstyle:ParameterNumber")
-    protected void appendSignature(
+    protected byte[] appendSignature(
             ClientSession session, String service,
             String keyType, PublicKey key, byte[] keyBytes,
             String clientHostname, String clientUsername,
@@ -203,6 +213,7 @@ public class UserAuthHostBased extends AbstractUserAuth implements SignatureFact
         bs.putString(keyType);
         bs.putBytes(signature);
         buffer.putBytes(bs.array(), bs.rpos(), bs.available());
+        return signature;
     }
 
     @Override
@@ -215,6 +226,28 @@ public class UserAuthHostBased extends AbstractUserAuth implements SignatureFact
                                         + " received unknown packet: cmd=" + SshConstants.getCommandMessageName(cmd));
     }
 
+    @Override
+    public void signalAuthMethodSuccess(ClientSession session, String service, Buffer buffer) throws Exception {
+        HostBasedAuthenticationReporter reporter = session.getHostBasedAuthenticationReporter();
+        if (reporter != null) {
+            reporter.signalAuthenticationSuccess(
+                    session, service, (keyInfo == null) ? null : keyInfo.getKey(), resolveClientHostname(),
+                    resolveClientUsername());
+        }
+    }
+
+    @Override
+    public void signalAuthMethodFailure(
+            ClientSession session, String service, boolean partial, List<String> serverMethods, Buffer buffer)
+            throws Exception {
+        HostBasedAuthenticationReporter reporter = session.getHostBasedAuthenticationReporter();
+        if (reporter != null) {
+            reporter.signalAuthenticationFailure(
+                    session, service, (keyInfo == null) ? null : keyInfo.getKey(),
+                    resolveClientHostname(), resolveClientUsername(), partial, serverMethods);
+        }
+    }
+
     protected String resolveClientUsername() {
         String value = getClientUsername();
         return GenericUtils.isEmpty(value) ? OsUtils.getCurrentUser() : value;
diff --git a/sshd-core/src/main/java/org/apache/sshd/client/session/AbstractClientSession.java b/sshd-core/src/main/java/org/apache/sshd/client/session/AbstractClientSession.java
index 490bbfb..6b1faff 100644
--- a/sshd-core/src/main/java/org/apache/sshd/client/session/AbstractClientSession.java
+++ b/sshd-core/src/main/java/org/apache/sshd/client/session/AbstractClientSession.java
@@ -32,6 +32,7 @@ import java.util.concurrent.CopyOnWriteArrayList;
 import org.apache.sshd.client.ClientFactoryManager;
 import org.apache.sshd.client.auth.AuthenticationIdentitiesProvider;
 import org.apache.sshd.client.auth.UserAuthFactory;
+import org.apache.sshd.client.auth.hostbased.HostBasedAuthenticationReporter;
 import org.apache.sshd.client.auth.keyboard.UserInteraction;
 import org.apache.sshd.client.auth.password.PasswordAuthenticationReporter;
 import org.apache.sshd.client.auth.password.PasswordIdentityProvider;
@@ -95,6 +96,7 @@ public abstract class AbstractClientSession extends AbstractSession implements C
     private PasswordAuthenticationReporter passwordAuthenticationReporter;
     private KeyIdentityProvider keyIdentityProvider;
     private PublicKeyAuthenticationReporter publicKeyAuthenticationReporter;
+    private HostBasedAuthenticationReporter hostBasedAuthenticationReporter;
     private List<UserAuthFactory> userAuthFactories;
     private SocketAddress connectAddress;
     private ClientProxyConnector proxyConnector;
@@ -229,6 +231,18 @@ public abstract class AbstractClientSession extends AbstractSession implements C
     }
 
     @Override
+    public HostBasedAuthenticationReporter getHostBasedAuthenticationReporter() {
+        ClientFactoryManager manager = getFactoryManager();
+        return resolveEffectiveProvider(HostBasedAuthenticationReporter.class, hostBasedAuthenticationReporter,
+                manager.getHostBasedAuthenticationReporter());
+    }
+
+    @Override
+    public void setHostBasedAuthenticationReporter(HostBasedAuthenticationReporter reporter) {
+        this.hostBasedAuthenticationReporter = reporter;
+    }
+
+    @Override
     public ClientProxyConnector getClientProxyConnector() {
         ClientFactoryManager manager = getFactoryManager();
         return resolveEffectiveProvider(ClientProxyConnector.class, proxyConnector, manager.getClientProxyConnector());
diff --git a/sshd-core/src/test/java/org/apache/sshd/client/ClientAuthenticationManagerTest.java b/sshd-core/src/test/java/org/apache/sshd/client/ClientAuthenticationManagerTest.java
index c1062e6..981f811 100644
--- a/sshd-core/src/test/java/org/apache/sshd/client/ClientAuthenticationManagerTest.java
+++ b/sshd-core/src/test/java/org/apache/sshd/client/ClientAuthenticationManagerTest.java
@@ -30,6 +30,7 @@ import java.util.concurrent.atomic.AtomicReference;
 import org.apache.sshd.client.auth.AuthenticationIdentitiesProvider;
 import org.apache.sshd.client.auth.BuiltinUserAuthFactories;
 import org.apache.sshd.client.auth.UserAuthFactory;
+import org.apache.sshd.client.auth.hostbased.HostBasedAuthenticationReporter;
 import org.apache.sshd.client.auth.keyboard.UserInteraction;
 import org.apache.sshd.client.auth.password.PasswordAuthenticationReporter;
 import org.apache.sshd.client.auth.password.PasswordIdentityProvider;
@@ -102,6 +103,16 @@ public class ClientAuthenticationManagerTest extends BaseTestSupport {
             }
 
             @Override
+            public HostBasedAuthenticationReporter getHostBasedAuthenticationReporter() {
+                return null;
+            }
+
+            @Override
+            public void setHostBasedAuthenticationReporter(HostBasedAuthenticationReporter reporter) {
+                throw new UnsupportedOperationException("setHostBasedAuthenticationReporter(" + reporter + ")");
+            }
+
+            @Override
             public UserInteraction getUserInteraction() {
                 return null;
             }
diff --git a/sshd-core/src/test/java/org/apache/sshd/common/auth/AuthenticationTest.java b/sshd-core/src/test/java/org/apache/sshd/common/auth/AuthenticationTest.java
index c75d07d..beb3141 100644
--- a/sshd-core/src/test/java/org/apache/sshd/common/auth/AuthenticationTest.java
+++ b/sshd-core/src/test/java/org/apache/sshd/common/auth/AuthenticationTest.java
@@ -36,8 +36,10 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
 
 import org.apache.sshd.client.SshClient;
+import org.apache.sshd.client.auth.hostbased.HostBasedAuthenticationReporter;
 import org.apache.sshd.client.auth.hostbased.HostKeyIdentityProvider;
 import org.apache.sshd.client.auth.keyboard.UserInteraction;
 import org.apache.sshd.client.auth.password.PasswordAuthenticationReporter;
@@ -72,6 +74,7 @@ import org.apache.sshd.common.util.security.SecurityUtils;
 import org.apache.sshd.core.CoreModuleProperties;
 import org.apache.sshd.server.ServerFactoryManager;
 import org.apache.sshd.server.SshServer;
+import org.apache.sshd.server.auth.hostbased.HostBasedAuthenticator;
 import org.apache.sshd.server.auth.keyboard.DefaultKeyboardInteractiveAuthenticator;
 import org.apache.sshd.server.auth.keyboard.InteractiveChallenge;
 import org.apache.sshd.server.auth.keyboard.KeyboardInteractiveAuthenticator;
@@ -773,17 +776,75 @@ public class AuthenticationTest extends BaseTestSupport {
 
     @Test // see SSHD-620
     public void testHostBasedAuthentication() throws Exception {
-        String hostClienUser = getClass().getSimpleName();
+        AtomicInteger invocationCount = new AtomicInteger(0);
+        testHostBasedAuthentication(
+                (
+                        session, username, clientHostKey, clientHostName, clientUsername,
+                        certificates) -> invocationCount.incrementAndGet() > 0,
+                session -> {
+                    /* ignored */ });
+        assertEquals("Mismatched authenticator invocation count", 1, invocationCount.get());
+    }
+
+    @Test   // see SSHD-1114
+    public void testHostBasedAuthenticationReporter() throws Exception {
+        AtomicReference<String> hostnameClientHolder = new AtomicReference<>();
+        AtomicReference<String> usernameClientHolder = new AtomicReference<>();
+        AtomicReference<PublicKey> keyClientHolder = new AtomicReference<>();
+        AtomicInteger successCount = new AtomicInteger(0);
+        HostBasedAuthenticator authenticator
+                = (session, username, clientHostKey, clientHostName, clientUsername, certificates) -> {
+                    return Objects.equals(clientHostName, hostnameClientHolder.get())
+                            && Objects.equals(clientUsername, usernameClientHolder.get())
+                            && KeyUtils.compareKeys(clientHostKey, keyClientHolder.get());
+                };
+
+        HostBasedAuthenticationReporter reporter = new HostBasedAuthenticationReporter() {
+            @Override
+            public void signalAuthenticationAttempt(
+                    ClientSession session, String service, KeyPair identity, String hostname, String username, byte[] signature)
+                    throws Exception {
+                hostnameClientHolder.set(hostname);
+                usernameClientHolder.set(username);
+                keyClientHolder.set(identity.getPublic());
+            }
+
+            @Override
+            public void signalAuthenticationSuccess(
+                    ClientSession session, String service, KeyPair identity, String hostname, String username)
+                    throws Exception {
+                assertEquals("Host", hostname, hostnameClientHolder.get());
+                assertEquals("User", username, usernameClientHolder.get());
+                assertKeyEquals("Identity", identity.getPublic(), keyClientHolder.get());
+                successCount.incrementAndGet();
+            }
+
+            @Override
+            public void signalAuthenticationFailure(
+                    ClientSession session, String service, KeyPair identity,
+                    String hostname, String username, boolean partial, List<String> serverMethods)
+                    throws Exception {
+                fail("Unexpected failure signalled");
+            }
+        };
+
+        testHostBasedAuthentication(authenticator, session -> session.setHostBasedAuthenticationReporter(reporter));
+        assertEquals("Mismatched success signals count", 1, successCount.get());
+    }
+
+    private void testHostBasedAuthentication(
+            HostBasedAuthenticator delegate, Consumer<? super ClientSession> preAuthInitializer)
+            throws Exception {
+        String hostClientUser = getClass().getSimpleName();
         String hostClientName = SshdSocketAddress.toAddressString(SshdSocketAddress.getFirstExternalNetwork4Address());
         KeyPair hostClientKey = CommonTestSupportUtils.generateKeyPair(
                 CommonTestSupportUtils.DEFAULT_TEST_HOST_KEY_PROVIDER_ALGORITHM,
                 CommonTestSupportUtils.DEFAULT_TEST_HOST_KEY_SIZE);
-        AtomicInteger invocationCount = new AtomicInteger(0);
         sshd.setHostBasedAuthenticator((session, username, clientHostKey, clientHostName, clientUsername, certificates) -> {
-            invocationCount.incrementAndGet();
-            return hostClienUser.equals(clientUsername)
+            return hostClientUser.equals(clientUsername)
                     && hostClientName.equals(clientHostName)
-                    && KeyUtils.compareKeys(hostClientKey.getPublic(), clientHostKey);
+                    && KeyUtils.compareKeys(hostClientKey.getPublic(), clientHostKey)
+                    && delegate.authenticate(session, username, clientHostKey, clientHostName, clientUsername, certificates);
         });
         sshd.setPasswordAuthenticator(RejectAllPasswordAuthenticator.INSTANCE);
         sshd.setKeyboardInteractiveAuthenticator(KeyboardInteractiveAuthenticator.NONE);
@@ -796,16 +857,16 @@ public class AuthenticationTest extends BaseTestSupport {
             org.apache.sshd.client.auth.hostbased.UserAuthHostBasedFactory factory
                     = new org.apache.sshd.client.auth.hostbased.UserAuthHostBasedFactory();
             // TODO factory.setClientHostname(CLIENT_HOSTNAME);
-            factory.setClientUsername(hostClienUser);
+            factory.setClientUsername(hostClientUser);
             factory.setClientHostKeys(HostKeyIdentityProvider.wrap(hostClientKey));
 
             client.setUserAuthFactories(Collections.singletonList(factory));
             client.start();
-            try (ClientSession s = client.connect(getCurrentTestName(), TEST_LOCALHOST, port)
+            try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port)
                     .verify(CONNECT_TIMEOUT)
                     .getSession()) {
-                s.auth().verify(AUTH_TIMEOUT);
-                assertEquals("Mismatched authenticator invocation count", 1, invocationCount.get());
+                preAuthInitializer.accept(session);
+                session.auth().verify(AUTH_TIMEOUT);
             } finally {
                 client.stop();
             }