You are viewing a plain-text version of this content. The canonical version is the original archive page for this message; the hyperlink to it was not preserved in this plain-text copy.
Posted to jira@kafka.apache.org by "ASF GitHub Bot (JIRA)" <ji...@apache.org> on 2018/04/05 08:42:01 UTC

[jira] [Commented] (KAFKA-4292) KIP-86: Configurable SASL callback handlers

    [ https://issues.apache.org/jira/browse/KAFKA-4292?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16426633#comment-16426633 ] 

ASF GitHub Bot commented on KAFKA-4292:
---------------------------------------

rajinisivaram closed pull request #2022: KAFKA-4292: Configurable SASL callback handlers
URL: https://github.com/apache/kafka/pull/2022
 
 
   

This pull request was merged from a forked repository. Because GitHub does not display the original diff of a pull request merged from a fork, the diff is reproduced below for the sake of provenance:

diff --git a/clients/src/main/java/org/apache/kafka/common/config/SaslConfigs.java b/clients/src/main/java/org/apache/kafka/common/config/SaslConfigs.java
index f61b7dd5e6c..148ab15dfad 100644
--- a/clients/src/main/java/org/apache/kafka/common/config/SaslConfigs.java
+++ b/clients/src/main/java/org/apache/kafka/common/config/SaslConfigs.java
@@ -49,7 +49,24 @@
     public static final String SASL_JAAS_CONFIG = "sasl.jaas.config";
     public static final String SASL_JAAS_CONFIG_DOC = "JAAS login context parameters for SASL connections in the format used by JAAS configuration files. "
         + "JAAS configuration file format is described <a href=\"http://docs.oracle.com/javase/8/docs/technotes/guides/security/jgss/tutorials/LoginConfigFile.html\">here</a>. "
-        + "The format for the value is: '<loginModuleClass> <controlFlag> (<optionName>=<optionValue>)*;'";
+        + "The format for the value is: '<loginModuleClass> <controlFlag> (<optionName>=<optionValue>)*;'. For brokers, "
+        + "the config must be prefixed with listener prefix and SASL mechanism name in lower-case. For example, "
+        + "listener.name.sasl_ssl.scram-sha-256.sasl.jaas.config=com.example.ScramLoginModule required;";
+
+    public static final String SASL_CLIENT_CALLBACK_HANDLER_CLASS = "sasl.client.callback.handler.class";
+    public static final String SASL_CLIENT_CALLBACK_HANDLER_CLASS_DOC = "The fully qualified name of a SASL client callback handler class "
+        + "that implements the AuthenticateCallbackHandler interface.";
+
+    public static final String SASL_LOGIN_CALLBACK_HANDLER_CLASS = "sasl.login.callback.handler.class";
+    public static final String SASL_LOGIN_CALLBACK_HANDLER_CLASS_DOC = "The fully qualified name of a SASL login callback handler class "
+            + "that implements the AuthenticateCallbackHandler interface. For brokers, login callback handler config must be prefixed with "
+            + "listener prefix and SASL mechanism name in lower-case. For example, "
+            + "listener.name.sasl_ssl.scram-sha-256.sasl.login.callback.handler.class=com.example.CustomScramLoginCallbackHandler";
+
+    public static final String SASL_LOGIN_CLASS = "sasl.login.class";
+    public static final String SASL_LOGIN_CLASS_DOC = "The fully qualified name of a class that implements the Login interface. "
+        + "For brokers, login config must be prefixed with listener prefix and SASL mechanism name in lower-case. For example, "
+        + "listener.name.sasl_ssl.scram-sha-256.sasl.login.class=com.example.CustomScramLogin";
 
     public static final String SASL_KERBEROS_SERVICE_NAME = "sasl.kerberos.service.name";
     public static final String SASL_KERBEROS_SERVICE_NAME_DOC = "The Kerberos principal name that Kafka runs as. "
@@ -95,6 +112,9 @@ public static void addClientSaslSupport(ConfigDef config) {
                 .define(SaslConfigs.SASL_KERBEROS_TICKET_RENEW_JITTER, ConfigDef.Type.DOUBLE, SaslConfigs.DEFAULT_KERBEROS_TICKET_RENEW_JITTER, ConfigDef.Importance.LOW, SaslConfigs.SASL_KERBEROS_TICKET_RENEW_JITTER_DOC)
                 .define(SaslConfigs.SASL_KERBEROS_MIN_TIME_BEFORE_RELOGIN, ConfigDef.Type.LONG, SaslConfigs.DEFAULT_KERBEROS_MIN_TIME_BEFORE_RELOGIN, ConfigDef.Importance.LOW, SaslConfigs.SASL_KERBEROS_MIN_TIME_BEFORE_RELOGIN_DOC)
                 .define(SaslConfigs.SASL_MECHANISM, ConfigDef.Type.STRING, SaslConfigs.DEFAULT_SASL_MECHANISM, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_MECHANISM_DOC)
-                .define(SaslConfigs.SASL_JAAS_CONFIG, ConfigDef.Type.PASSWORD, null, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_JAAS_CONFIG_DOC);
+                .define(SaslConfigs.SASL_JAAS_CONFIG, ConfigDef.Type.PASSWORD, null, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_JAAS_CONFIG_DOC)
+                .define(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS, ConfigDef.Type.CLASS, null, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS_DOC)
+                .define(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, ConfigDef.Type.CLASS, null, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS_DOC)
+                .define(SaslConfigs.SASL_LOGIN_CLASS, ConfigDef.Type.CLASS, null, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_LOGIN_CLASS_DOC);
     }
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/config/internals/BrokerSecurityConfigs.java b/clients/src/main/java/org/apache/kafka/common/config/internals/BrokerSecurityConfigs.java
index 18616ec74e5..a29d8069b99 100644
--- a/clients/src/main/java/org/apache/kafka/common/config/internals/BrokerSecurityConfigs.java
+++ b/clients/src/main/java/org/apache/kafka/common/config/internals/BrokerSecurityConfigs.java
@@ -33,6 +33,7 @@
     public static final String SASL_KERBEROS_PRINCIPAL_TO_LOCAL_RULES_CONFIG = "sasl.kerberos.principal.to.local.rules";
     public static final String SSL_CLIENT_AUTH_CONFIG = "ssl.client.auth";
     public static final String SASL_ENABLED_MECHANISMS_CONFIG = "sasl.enabled.mechanisms";
+    public static final String SASL_SERVER_CALLBACK_HANDLER_CLASS = "sasl.server.callback.handler.class";
 
     public static final String PRINCIPAL_BUILDER_CLASS_DOC = "The fully qualified name of a class that implements the " +
             "KafkaPrincipalBuilder interface, which is used to build the KafkaPrincipal object used during " +
@@ -67,4 +68,9 @@
             + "Only GSSAPI is enabled by default.";
     public static final List<String> DEFAULT_SASL_ENABLED_MECHANISMS = Collections.singletonList(SaslConfigs.GSSAPI_MECHANISM);
 
+    public static final String SASL_SERVER_CALLBACK_HANDLER_CLASS_DOC = "The fully qualified name of a SASL server callback handler "
+            + "class that implements the AuthenticateCallbackHandler interface. Server callback handlers must be prefixed with "
+            + "listener prefix and SASL mechanism name in lower-case. For example, "
+            + "listener.name.sasl_ssl.plain.sasl.server.callback.handler.class=com.example.CustomPlainCallbackHandler.";
+
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/network/ListenerName.java b/clients/src/main/java/org/apache/kafka/common/network/ListenerName.java
index fc0cb14925b..2decccbc506 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/ListenerName.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/ListenerName.java
@@ -73,6 +73,10 @@ public String configPrefix() {
     }
 
     public String saslMechanismConfigPrefix(String saslMechanism) {
-        return configPrefix() + saslMechanism.toLowerCase(Locale.ROOT) + ".";
+        return configPrefix() + saslMechanismPrefix(saslMechanism);
+    }
+
+    public static String saslMechanismPrefix(String saslMechanism) {
+        return saslMechanism.toLowerCase(Locale.ROOT) + ".";
     }
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/network/SaslChannelBuilder.java b/clients/src/main/java/org/apache/kafka/common/network/SaslChannelBuilder.java
index 095f826ccba..5502164563b 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/SaslChannelBuilder.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/SaslChannelBuilder.java
@@ -21,22 +21,35 @@
 import org.apache.kafka.common.config.SslConfigs;
 import org.apache.kafka.common.config.internals.BrokerSecurityConfigs;
 import org.apache.kafka.common.memory.MemoryPool;
-import org.apache.kafka.common.security.auth.SecurityProtocol;
 import org.apache.kafka.common.security.JaasContext;
-import org.apache.kafka.common.security.token.delegation.DelegationTokenCache;
-import org.apache.kafka.common.security.kerberos.KerberosShortNamer;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.Login;
+import org.apache.kafka.common.security.auth.SecurityProtocol;
 import org.apache.kafka.common.security.authenticator.CredentialCache;
+import org.apache.kafka.common.security.authenticator.DefaultLogin;
 import org.apache.kafka.common.security.authenticator.LoginManager;
 import org.apache.kafka.common.security.authenticator.SaslClientAuthenticator;
+import org.apache.kafka.common.security.authenticator.SaslClientCallbackHandler;
 import org.apache.kafka.common.security.authenticator.SaslServerAuthenticator;
+import org.apache.kafka.common.security.authenticator.SaslServerCallbackHandler;
+import org.apache.kafka.common.security.kerberos.KerberosClientCallbackHandler;
+import org.apache.kafka.common.security.kerberos.KerberosLogin;
+import org.apache.kafka.common.security.kerberos.KerberosShortNamer;
+import org.apache.kafka.common.security.plain.internal.PlainSaslServer;
+import org.apache.kafka.common.security.plain.internal.PlainServerCallbackHandler;
+import org.apache.kafka.common.security.scram.ScramCredential;
+import org.apache.kafka.common.security.scram.internal.ScramMechanism;
+import org.apache.kafka.common.security.scram.internal.ScramServerCallbackHandler;
 import org.apache.kafka.common.security.ssl.SslFactory;
+import org.apache.kafka.common.security.token.delegation.DelegationTokenCache;
 import org.apache.kafka.common.utils.Java;
+import org.apache.kafka.common.utils.Utils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.IOException;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
+import java.io.IOException;
 import java.net.Socket;
 import java.nio.channels.SelectionKey;
 import java.nio.channels.SocketChannel;
@@ -66,6 +79,7 @@
     private SslFactory sslFactory;
     private Map<String, ?> configs;
     private KerberosShortNamer kerberosShortNamer;
+    private Map<String, AuthenticateCallbackHandler> saslCallbackHandlers;
 
     public SaslChannelBuilder(Mode mode,
                               Map<String, JaasContext> jaasContexts,
@@ -87,15 +101,25 @@ public SaslChannelBuilder(Mode mode,
         this.clientSaslMechanism = clientSaslMechanism;
         this.credentialCache = credentialCache;
         this.tokenCache = tokenCache;
+        this.saslCallbackHandlers = new HashMap<>();
     }
 
+    @SuppressWarnings("unchecked")
     @Override
     public void configure(Map<String, ?> configs) throws KafkaException {
         try {
             this.configs = configs;
-            boolean hasKerberos = jaasContexts.containsKey(SaslConfigs.GSSAPI_MECHANISM);
+            if (mode == Mode.SERVER)
+                createServerCallbackHandlers(configs);
+            else
+                createClientCallbackHandler(configs);
+            for (Map.Entry<String, AuthenticateCallbackHandler> entry : saslCallbackHandlers.entrySet()) {
+                String mechanism = entry.getKey();
+                entry.getValue().configure(configs, mechanism, jaasContexts.get(mechanism).configurationEntries());
+            }
 
-            if (hasKerberos) {
+            Class<? extends Login> defaultLoginClass = DefaultLogin.class;
+            if (jaasContexts.containsKey(SaslConfigs.GSSAPI_MECHANISM)) {
                 String defaultRealm;
                 try {
                     defaultRealm = defaultKerberosRealm();
@@ -106,12 +130,13 @@ public void configure(Map<String, ?> configs) throws KafkaException {
                 List<String> principalToLocalRules = (List<String>) configs.get(BrokerSecurityConfigs.SASL_KERBEROS_PRINCIPAL_TO_LOCAL_RULES_CONFIG);
                 if (principalToLocalRules != null)
                     kerberosShortNamer = KerberosShortNamer.fromUnparsedRules(defaultRealm, principalToLocalRules);
+                defaultLoginClass = KerberosLogin.class;
             }
             for (Map.Entry<String, JaasContext> entry : jaasContexts.entrySet()) {
                 String mechanism = entry.getKey();
                 // With static JAAS configuration, use KerberosLogin if Kerberos is enabled. With dynamic JAAS configuration,
                 // use KerberosLogin only for the LoginContext corresponding to GSSAPI
-                LoginManager loginManager = LoginManager.acquireLoginManager(entry.getValue(), mechanism, hasKerberos, configs);
+                LoginManager loginManager = LoginManager.acquireLoginManager(entry.getValue(), mechanism, defaultLoginClass, configs);
                 loginManagers.put(mechanism, loginManager);
                 subjects.put(mechanism, loginManager.subject());
             }
@@ -120,7 +145,7 @@ public void configure(Map<String, ?> configs) throws KafkaException {
                 this.sslFactory = new SslFactory(mode, "none", isInterBrokerListener);
                 this.sslFactory.configure(configs);
             }
-        } catch (Exception e) {
+        } catch (Throwable e) {
             close();
             throw new KafkaException(e);
         }
@@ -156,11 +181,20 @@ public KafkaChannel buildChannel(String id, SelectionKey key, int maxReceiveSize
             TransportLayer transportLayer = buildTransportLayer(id, key, socketChannel);
             Authenticator authenticator;
             if (mode == Mode.SERVER) {
-                authenticator = buildServerAuthenticator(configs, id, transportLayer, subjects);
+                authenticator = buildServerAuthenticator(configs,
+                        saslCallbackHandlers,
+                        id,
+                        transportLayer,
+                        subjects);
             } else {
                 LoginManager loginManager = loginManagers.get(clientSaslMechanism);
-                authenticator = buildClientAuthenticator(configs, id, socket.getInetAddress().getHostName(),
-                        loginManager.serviceName(), transportLayer, loginManager.subject());
+                authenticator = buildClientAuthenticator(configs,
+                        saslCallbackHandlers.get(clientSaslMechanism),
+                        id,
+                        socket.getInetAddress().getHostName(),
+                        loginManager.serviceName(),
+                        transportLayer,
+                        subjects.get(clientSaslMechanism));
             }
             return new KafkaChannel(id, transportLayer, authenticator, maxReceiveSize, memoryPool != null ? memoryPool : MemoryPool.NONE);
         } catch (Exception e) {
@@ -174,6 +208,8 @@ public void close()  {
         for (LoginManager loginManager : loginManagers.values())
             loginManager.release();
         loginManagers.clear();
+        for (AuthenticateCallbackHandler handler : saslCallbackHandlers.values())
+            handler.close();
     }
 
     private TransportLayer buildTransportLayer(String id, SelectionKey key, SocketChannel socketChannel) throws IOException {
@@ -186,16 +222,23 @@ private TransportLayer buildTransportLayer(String id, SelectionKey key, SocketCh
     }
 
     // Visible to override for testing
-    protected SaslServerAuthenticator buildServerAuthenticator(Map<String, ?> configs, String id,
-            TransportLayer transportLayer, Map<String, Subject> subjects) throws IOException {
-        return new SaslServerAuthenticator(configs, id, jaasContexts, subjects,
-                kerberosShortNamer, credentialCache, listenerName, securityProtocol, transportLayer, tokenCache);
+    protected SaslServerAuthenticator buildServerAuthenticator(Map<String, ?> configs,
+                                                               Map<String, AuthenticateCallbackHandler> callbackHandlers,
+                                                               String id,
+                                                               TransportLayer transportLayer,
+                                                               Map<String, Subject> subjects) throws IOException {
+        return new SaslServerAuthenticator(configs, callbackHandlers, id, subjects,
+                kerberosShortNamer, listenerName, securityProtocol, transportLayer);
     }
 
     // Visible to override for testing
-    protected SaslClientAuthenticator buildClientAuthenticator(Map<String, ?> configs, String id,
-            String serverHost, String servicePrincipal, TransportLayer transportLayer, Subject subject) throws IOException {
-        return new SaslClientAuthenticator(configs, id, subject, servicePrincipal,
+    protected SaslClientAuthenticator buildClientAuthenticator(Map<String, ?> configs,
+                                                               AuthenticateCallbackHandler callbackHandler,
+                                                               String id,
+                                                               String serverHost,
+                                                               String servicePrincipal,
+                                                               TransportLayer transportLayer, Subject subject) throws IOException {
+        return new SaslClientAuthenticator(configs, callbackHandler, id, subject, servicePrincipal,
                 serverHost, clientSaslMechanism, handshakeRequestEnable, transportLayer);
     }
 
@@ -224,4 +267,30 @@ private static String defaultKerberosRealm() throws ClassNotFoundException, NoSu
         getDefaultRealmMethod = classRef.getDeclaredMethod("getDefaultRealm", new Class[0]);
         return (String) getDefaultRealmMethod.invoke(kerbConf, new Object[0]);
     }
+
+    private void createClientCallbackHandler(Map<String, ?> configs) {
+        Class<? extends AuthenticateCallbackHandler> clazz = (Class<? extends AuthenticateCallbackHandler>) configs.get(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS);
+        if (clazz == null)
+            clazz = clientSaslMechanism.equals(SaslConfigs.GSSAPI_MECHANISM) ? KerberosClientCallbackHandler.class : SaslClientCallbackHandler.class;
+        AuthenticateCallbackHandler callbackHandler = Utils.newInstance(clazz);
+        saslCallbackHandlers.put(clientSaslMechanism, callbackHandler);
+    }
+
+    private void createServerCallbackHandlers(Map<String, ?> configs) throws ClassNotFoundException {
+        for (String mechanism : jaasContexts.keySet()) {
+            AuthenticateCallbackHandler callbackHandler;
+            String prefix = ListenerName.saslMechanismPrefix(mechanism);
+            Class<? extends AuthenticateCallbackHandler> clazz =
+                    (Class<? extends AuthenticateCallbackHandler>) configs.get(prefix + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS);
+            if (clazz != null)
+                callbackHandler = Utils.newInstance(clazz);
+            else if (mechanism.equals(PlainSaslServer.PLAIN_MECHANISM))
+                callbackHandler = new PlainServerCallbackHandler();
+            else if (ScramMechanism.isScram(mechanism))
+                callbackHandler = new ScramServerCallbackHandler(credentialCache.cache(mechanism, ScramCredential.class), tokenCache);
+            else
+                callbackHandler = new SaslServerCallbackHandler();
+            saslCallbackHandlers.put(mechanism, callbackHandler);
+        }
+    }
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/security/JaasContext.java b/clients/src/main/java/org/apache/kafka/common/security/JaasContext.java
index d72f00dc590..58f770e8014 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/JaasContext.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/JaasContext.java
@@ -183,7 +183,7 @@ public Password dynamicJaasConfig() {
      * Returns the configuration option for <code>key</code> from this context.
      * If login module name is specified, return option value only from that module.
      */
-    public String configEntryOption(String key, String loginModuleName) {
+    public static String configEntryOption(List<AppConfigurationEntry> configurationEntries, String key, String loginModuleName) {
         for (AppConfigurationEntry entry : configurationEntries) {
             if (loginModuleName != null && !loginModuleName.equals(entry.getLoginModuleName()))
                 continue;
diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/AuthenticateCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/auth/AuthenticateCallbackHandler.java
new file mode 100644
index 00000000000..8951d3a5893
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/security/auth/AuthenticateCallbackHandler.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.security.auth;
+
+import java.util.List;
+import java.util.Map;
+
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.login.AppConfigurationEntry;
+
+/*
+ * Callback handler for SASL-based authentication
+ */
+public interface AuthenticateCallbackHandler extends CallbackHandler {
+
+    /**
+     * Configures this callback handler for the specified SASL mechanism.
+     *
+     * @param configs Key-value pairs containing the parsed configuration options of
+     *        the client or broker. Note that these are the Kafka configuration options
+     *        and not the JAAS configuration options. JAAS config options may be obtained
+     *        from `jaasConfigEntries` for callbacks which obtain some configs from the
+     *        JAAS configuration. For configs that may be specified as both Kafka config
+     *        as well as JAAS config (e.g. sasl.kerberos.service.name), the configuration
+     *        is treated as invalid if conflicting values are provided.
+     * @param saslMechanism Negotiated SASL mechanism. For clients, this is the SASL
+     *        mechanism configured for the client. For brokers, this is the mechanism
+     *        negotiated with the client and is one of the mechanisms enabled on the broker.
+     * @param jaasConfigEntries JAAS configuration entries from the JAAS login context.
+     *        This list contains a single entry for clients and may contain more than
+     *        one entry for brokers if multiple mechanisms are enabled on a listener using
+     *        static JAAS configuration where there is no mapping between mechanisms and
+     *        login module entries. In this case, callback handlers can use the login module in
+     *        `jaasConfigEntries` to identify the entry corresponding to `saslMechanism`.
+     *        Alternatively, dynamic JAAS configuration option
+     *        {@link org.apache.kafka.common.config.SaslConfigs#SASL_JAAS_CONFIG} may be
+     *        configured on brokers with listener and mechanism prefix, in which case
+     *        only the configuration entry corresponding to `saslMechanism` will be provided
+     *        in `jaasConfigEntries`.
+     */
+    void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries);
+
+    /**
+     * Closes this instance.
+     */
+    void close();
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/Login.java b/clients/src/main/java/org/apache/kafka/common/security/auth/Login.java
similarity index 55%
rename from clients/src/main/java/org/apache/kafka/common/security/authenticator/Login.java
rename to clients/src/main/java/org/apache/kafka/common/security/auth/Login.java
index b41d1b2572f..eda5e7a225a 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/Login.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/auth/Login.java
@@ -14,13 +14,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.authenticator;
-
-import org.apache.kafka.common.security.JaasContext;
+package org.apache.kafka.common.security.auth;
 
 import java.util.Map;
 
 import javax.security.auth.Subject;
+import javax.security.auth.login.Configuration;
 import javax.security.auth.login.LoginContext;
 import javax.security.auth.login.LoginException;
 
@@ -31,8 +30,21 @@
 
     /**
      * Configures this login instance.
+     * @param configs Key-value pairs containing the parsed configuration options of
+     *        the client or broker. Note that these are the Kafka configuration options
+     *        and not the JAAS configuration options. The JAAS options may be obtained
+     *        from `jaasConfiguration`.
+     * @param contextName JAAS context name for this login which may be used to obtain
+     *        the login context from `jaasConfiguration`.
+     * @param jaasConfiguration JAAS configuration containing the login context named
+     *        `contextName`. If static JAAS configuration is used, this `Configuration`
+     *         may also contain other login contexts.
+     * @param loginCallbackHandler Login callback handler instance to use for this Login.
+     *        Login callback handler class may be configured using
+     *        {@link org.apache.kafka.common.config.SaslConfigs#SASL_LOGIN_CALLBACK_HANDLER_CLASS}.
      */
-    void configure(Map<String, ?> configs, JaasContext jaasContext);
+    void configure(Map<String, ?> configs, String contextName, Configuration jaasConfiguration,
+                   AuthenticateCallbackHandler loginCallbackHandler);
 
     /**
      * Performs login for each login module specified for the login context of this instance.
@@ -54,4 +66,3 @@
      */
     void close();
 }
-
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/AbstractLogin.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/AbstractLogin.java
index 643f859e829..7e1350864e4 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/AbstractLogin.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/AbstractLogin.java
@@ -16,20 +16,23 @@
  */
 package org.apache.kafka.common.security.authenticator;
 
+import javax.security.auth.login.AppConfigurationEntry;
+import javax.security.auth.login.Configuration;
 import javax.security.auth.login.LoginContext;
 import javax.security.auth.login.LoginException;
 import javax.security.sasl.RealmCallback;
 import javax.security.auth.callback.Callback;
-import javax.security.auth.callback.CallbackHandler;
 import javax.security.auth.callback.NameCallback;
 import javax.security.auth.callback.PasswordCallback;
 import javax.security.auth.callback.UnsupportedCallbackException;
 import javax.security.auth.Subject;
 
-import org.apache.kafka.common.security.JaasContext;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.Login;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.List;
 import java.util.Map;
 
 /**
@@ -38,17 +41,22 @@
 public abstract class AbstractLogin implements Login {
     private static final Logger log = LoggerFactory.getLogger(AbstractLogin.class);
 
-    private JaasContext jaasContext;
+    private String contextName;
+    private Configuration configuration;
     private LoginContext loginContext;
+    private AuthenticateCallbackHandler loginCallbackHandler;
 
     @Override
-    public void configure(Map<String, ?> configs, JaasContext jaasContext) {
-        this.jaasContext = jaasContext;
+    public void configure(Map<String, ?> configs, String contextName, Configuration configuration,
+                          AuthenticateCallbackHandler loginCallbackHandler) {
+        this.contextName = contextName;
+        this.configuration = configuration;
+        this.loginCallbackHandler = loginCallbackHandler;
     }
 
     @Override
     public LoginContext login() throws LoginException {
-        loginContext = new LoginContext(jaasContext.name(), null, new LoginCallbackHandler(), jaasContext.configuration());
+        loginContext = new LoginContext(contextName, null, loginCallbackHandler, configuration);
         loginContext.login();
         log.info("Successfully logged in.");
         return loginContext;
@@ -59,8 +67,12 @@ public Subject subject() {
         return loginContext.getSubject();
     }
 
-    protected JaasContext jaasContext() {
-        return jaasContext;
+    protected String contextName() {
+        return contextName;
+    }
+
+    protected Configuration configuration() {
+        return configuration;
     }
 
     /**
@@ -70,7 +82,11 @@ protected JaasContext jaasContext() {
      * callback handlers which require additional user input.
      *
      */
-    public static class LoginCallbackHandler implements CallbackHandler {
+    public static class DefaultLoginCallbackHandler implements AuthenticateCallbackHandler {
+
+        @Override
+        public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+        }
 
         @Override
         public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
@@ -90,6 +106,10 @@ public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
                 }
             }
         }
+
+        @Override
+        public void close() {
+        }
     }
 }
 
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/AuthCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/AuthCallbackHandler.java
deleted file mode 100644
index d517162a176..00000000000
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/AuthCallbackHandler.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.kafka.common.security.authenticator;
-
-import java.util.Map;
-
-import org.apache.kafka.common.network.Mode;
-
-import javax.security.auth.Subject;
-import javax.security.auth.callback.CallbackHandler;
-
-/*
- * Callback handler for SASL-based authentication
- */
-public interface AuthCallbackHandler extends CallbackHandler {
-
-    /**
-     * Configures this callback handler.
-     *
-     * @param configs Configuration
-     * @param mode The mode that indicates if this is a client or server connection
-     * @param subject Subject from login context
-     * @param saslMechanism Negotiated SASL mechanism
-     */
-    void configure(Map<String, ?> configs, Mode mode, Subject subject, String saslMechanism);
-
-    /**
-     * Closes this instance.
-     */
-    void close();
-}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/LoginManager.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/LoginManager.java
index 81dc063c743..4ae798d7d29 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/LoginManager.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/LoginManager.java
@@ -23,11 +23,17 @@
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.Objects;
 
+
+import org.apache.kafka.common.config.ConfigException;
 import org.apache.kafka.common.config.SaslConfigs;
 import org.apache.kafka.common.config.types.Password;
+import org.apache.kafka.common.network.ListenerName;
 import org.apache.kafka.common.security.JaasContext;
-import org.apache.kafka.common.security.kerberos.KerberosLogin;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.auth.Login;
+import org.apache.kafka.common.utils.Utils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -36,20 +42,23 @@
     private static final Logger LOGGER = LoggerFactory.getLogger(LoginManager.class);
 
     // static configs (broker or client)
-    private static final Map<String, LoginManager> STATIC_INSTANCES = new HashMap<>();
+    private static final Map<LoginMetadata<String>, LoginManager> STATIC_INSTANCES = new HashMap<>();
 
-    // dynamic configs (client-only)
-    private static final Map<Password, LoginManager> DYNAMIC_INSTANCES = new HashMap<>();
+    // dynamic configs (broker or client)
+    private static final Map<LoginMetadata<Password>, LoginManager> DYNAMIC_INSTANCES = new HashMap<>();
 
     private final Login login;
-    private final Object cacheKey;
+    private final LoginMetadata<?> loginMetadata;
+    private final AuthenticateCallbackHandler loginCallbackHandler;
     private int refCount;
 
-    private LoginManager(JaasContext jaasContext, boolean hasKerberos, Map<String, ?> configs,
-                         Object cacheKey) throws IOException, LoginException {
-        this.cacheKey = cacheKey;
-        login = hasKerberos ? new KerberosLogin() : new DefaultLogin();
-        login.configure(configs, jaasContext);
+    private LoginManager(JaasContext jaasContext, String saslMechanism, Map<String, ?> configs,
+                         LoginMetadata<?> loginMetadata) throws IOException, LoginException {
+        this.loginMetadata = loginMetadata;
+        this.login = Utils.newInstance(loginMetadata.loginClass);
+        loginCallbackHandler = Utils.newInstance(loginMetadata.loginCallbackClass);
+        loginCallbackHandler.configure(configs, saslMechanism, jaasContext.configurationEntries());
+        login.configure(configs, jaasContext.name(), jaasContext.configuration(), loginCallbackHandler);
         login.login();
     }
 
@@ -72,28 +81,34 @@ private LoginManager(JaasContext jaasContext, boolean hasKerberos, Map<String, ?
      * @param saslMechanism SASL mechanism for which login manager is being acquired. For dynamic contexts, the single
      *                      login module in `jaasContext` corresponds to this SASL mechanism. Hence `Login` class is
      *                      chosen based on this mechanism.
-     * @param hasKerberos Boolean flag that indicates if Kerberos is enabled for the server listener or client. Since
-     *                    static broker configuration may contain multiple login modules in a login context, KerberosLogin
-     *                    must be used if Kerberos is enabled on the listener, even if `saslMechanism` is not GSSAPI.
+     * @param defaultLoginClass Default login class to use if an override is not specified in `configs`
      * @param configs Config options used to configure `Login` if a new login manager is created.
      *
      */
-    public static LoginManager acquireLoginManager(JaasContext jaasContext, String saslMechanism, boolean hasKerberos,
+    public static LoginManager acquireLoginManager(JaasContext jaasContext, String saslMechanism,
+                                                   Class<? extends Login> defaultLoginClass,
                                                    Map<String, ?> configs) throws IOException, LoginException {
+        Class<? extends Login> loginClass = configuredClassOrDefault(configs, jaasContext,
+                saslMechanism, SaslConfigs.SASL_LOGIN_CLASS, defaultLoginClass);
+        Class<? extends AuthenticateCallbackHandler> loginCallbackClass = configuredClassOrDefault(configs,
+                jaasContext, saslMechanism, SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS,
+                AbstractLogin.DefaultLoginCallbackHandler.class);
         synchronized (LoginManager.class) {
             LoginManager loginManager;
             Password jaasConfigValue = jaasContext.dynamicJaasConfig();
             if (jaasConfigValue != null) {
-                loginManager = DYNAMIC_INSTANCES.get(jaasConfigValue);
+                LoginMetadata<Password> loginMetadata = new LoginMetadata<>(jaasConfigValue, loginClass, loginCallbackClass);
+                loginManager = DYNAMIC_INSTANCES.get(loginMetadata);
                 if (loginManager == null) {
-                    loginManager = new LoginManager(jaasContext, saslMechanism.equals(SaslConfigs.GSSAPI_MECHANISM), configs, jaasConfigValue);
-                    DYNAMIC_INSTANCES.put(jaasConfigValue, loginManager);
+                    loginManager = new LoginManager(jaasContext, saslMechanism, configs, loginMetadata);
+                    DYNAMIC_INSTANCES.put(loginMetadata, loginManager);
                 }
             } else {
-                loginManager = STATIC_INSTANCES.get(jaasContext.name());
+                LoginMetadata<String> loginMetadata = new LoginMetadata<>(jaasContext.name(), loginClass, loginCallbackClass);
+                loginManager = STATIC_INSTANCES.get(loginMetadata);
                 if (loginManager == null) {
-                    loginManager = new LoginManager(jaasContext, hasKerberos, configs, jaasContext.name());
-                    STATIC_INSTANCES.put(jaasContext.name(), loginManager);
+                    loginManager = new LoginManager(jaasContext, saslMechanism, configs, loginMetadata);
+                    STATIC_INSTANCES.put(loginMetadata, loginManager);
                 }
             }
             return loginManager.acquire();
@@ -110,7 +125,7 @@ public String serviceName() {
 
     // Only for testing
     Object cacheKey() {
-        return cacheKey;
+        return loginMetadata.configInfo;
     }
 
     private LoginManager acquire() {
@@ -127,12 +142,13 @@ public void release() {
             if (refCount == 0)
                 throw new IllegalStateException("release() called on disposed " + this);
             else if (refCount == 1) {
-                if (cacheKey instanceof Password) {
-                    DYNAMIC_INSTANCES.remove(cacheKey);
+                if (loginMetadata.configInfo instanceof Password) {
+                    DYNAMIC_INSTANCES.remove(loginMetadata);
                 } else {
-                    STATIC_INSTANCES.remove(cacheKey);
+                    STATIC_INSTANCES.remove(loginMetadata);
                 }
                 login.close();
+                loginCallbackHandler.close();
             }
             --refCount;
             LOGGER.trace("{} released", this);
@@ -150,10 +166,56 @@ public String toString() {
     /* Should only be used in tests. */
     public static void closeAll() {
         synchronized (LoginManager.class) {
-            for (String key : new ArrayList<>(STATIC_INSTANCES.keySet()))
+            for (LoginMetadata<String> key : new ArrayList<>(STATIC_INSTANCES.keySet()))
                 STATIC_INSTANCES.remove(key).login.close();
-            for (Password key : new ArrayList<>(DYNAMIC_INSTANCES.keySet()))
+            for (LoginMetadata<Password> key : new ArrayList<>(DYNAMIC_INSTANCES.keySet()))
                 DYNAMIC_INSTANCES.remove(key).login.close();
         }
     }
+
+    private static <T> Class<? extends T> configuredClassOrDefault(Map<String, ?> configs,
+                                                     JaasContext jaasContext,
+                                                     String saslMechanism,
+                                                     String configName,
+                                                     Class<? extends T> defaultClass) {
+        String prefix  = jaasContext.type() == JaasContext.Type.SERVER ? ListenerName.saslMechanismPrefix(saslMechanism) : "";
+        Class<? extends T> clazz = (Class<? extends T>) configs.get(prefix + configName);
+        if (clazz != null && jaasContext.configurationEntries().size() != 1) {
+            String errorMessage = configName + " cannot be specified with multiple login modules in the JAAS context. " +
+                    SaslConfigs.SASL_JAAS_CONFIG + " must be configured to override mechanism-specific configs.";
+            throw new ConfigException(errorMessage);
+        }
+        if (clazz == null)
+            clazz = defaultClass;
+        return clazz;
+    }
+
+    private static class LoginMetadata<T> {
+        final T configInfo;
+        final Class<? extends Login> loginClass;
+        final Class<? extends AuthenticateCallbackHandler> loginCallbackClass;
+
+        LoginMetadata(T configInfo, Class<? extends Login> loginClass,
+                      Class<? extends AuthenticateCallbackHandler> loginCallbackClass) {
+            this.configInfo = configInfo;
+            this.loginClass = loginClass;
+            this.loginCallbackClass = loginCallbackClass;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(configInfo, loginClass, loginCallbackClass);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+
+            LoginMetadata<?> loginMetadata = (LoginMetadata<?>) o;
+            return Objects.equals(configInfo, loginMetadata.configInfo) &&
+                   Objects.equals(loginClass, loginMetadata.loginClass) &&
+                   Objects.equals(loginCallbackClass, loginMetadata.loginCallbackClass);
+        }
+    }
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
index 8b0116563d8..2ef6d77f13f 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
@@ -24,9 +24,8 @@
 import org.apache.kafka.common.errors.SaslAuthenticationException;
 import org.apache.kafka.common.errors.UnsupportedSaslMechanismException;
 import org.apache.kafka.common.network.Authenticator;
-import org.apache.kafka.common.network.Mode;
-import org.apache.kafka.common.network.NetworkReceive;
 import org.apache.kafka.common.network.NetworkSend;
+import org.apache.kafka.common.network.NetworkReceive;
 import org.apache.kafka.common.network.Send;
 import org.apache.kafka.common.network.TransportLayer;
 import org.apache.kafka.common.protocol.ApiKeys;
@@ -41,6 +40,7 @@
 import org.apache.kafka.common.requests.SaslAuthenticateResponse;
 import org.apache.kafka.common.requests.SaslHandshakeRequest;
 import org.apache.kafka.common.requests.SaslHandshakeResponse;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.auth.KafkaPrincipal;
 import org.apache.kafka.common.utils.Utils;
 import org.slf4j.Logger;
@@ -87,7 +87,7 @@
     private final SaslClient saslClient;
     private final Map<String, ?> configs;
     private final String clientPrincipalName;
-    private final AuthCallbackHandler callbackHandler;
+    private final AuthenticateCallbackHandler callbackHandler;
 
     // buffers used in `authenticate`
     private NetworkReceive netInBuffer;
@@ -105,6 +105,7 @@
     private short saslAuthenticateVersion;
 
     public SaslClientAuthenticator(Map<String, ?> configs,
+                                   AuthenticateCallbackHandler callbackHandler,
                                    String node,
                                    Subject subject,
                                    String servicePrincipal,
@@ -114,6 +115,7 @@ public SaslClientAuthenticator(Map<String, ?> configs,
                                    TransportLayer transportLayer) throws IOException {
         this.node = node;
         this.subject = subject;
+        this.callbackHandler = callbackHandler;
         this.host = host;
         this.servicePrincipal = servicePrincipal;
         this.mechanism = mechanism;
@@ -133,9 +135,6 @@ public SaslClientAuthenticator(Map<String, ?> configs,
             else
                 this.clientPrincipalName = null;
 
-            callbackHandler = new SaslClientCallbackHandler();
-            callbackHandler.configure(configs, Mode.CLIENT, subject, mechanism);
-
             saslClient = createSaslClient();
         } catch (Exception e) {
             throw new SaslAuthenticationException("Failed to configure SaslClientAuthenticator", e);
@@ -325,8 +324,6 @@ public boolean complete() {
     public void close() throws IOException {
         if (saslClient != null)
             saslClient.dispose();
-        if (callbackHandler != null)
-            callbackHandler.close();
     }
 
     private byte[] receiveToken() throws IOException {
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java
index 31c51c22ca3..5b2a28181cd 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java
@@ -16,6 +16,8 @@
  */
 package org.apache.kafka.common.security.authenticator;
 
+import java.security.AccessController;
+import java.util.List;
 import java.util.Map;
 
 import javax.security.auth.Subject;
@@ -23,52 +25,46 @@
 import javax.security.auth.callback.NameCallback;
 import javax.security.auth.callback.PasswordCallback;
 import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.auth.login.AppConfigurationEntry;
 import javax.security.sasl.AuthorizeCallback;
 import javax.security.sasl.RealmCallback;
 
 import org.apache.kafka.common.config.SaslConfigs;
-import org.apache.kafka.common.network.Mode;
 import org.apache.kafka.common.security.scram.ScramExtensionsCallback;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 
 /**
- * Callback handler for Sasl clients. The callbacks required for the SASL mechanism
+ * Default callback handler for Sasl clients. The callbacks required for the SASL mechanism
  * configured for the client should be supported by this callback handler. See
  * <a href="https://docs.oracle.com/javase/8/docs/technotes/guides/security/sasl/sasl-refguide.html">Java SASL API</a>
  * for the list of SASL callback handlers required for each SASL mechanism.
  */
-public class SaslClientCallbackHandler implements AuthCallbackHandler {
+public class SaslClientCallbackHandler implements AuthenticateCallbackHandler {
 
-    private boolean isKerberos;
-    private Subject subject;
+    private String mechanism;
 
     @Override
-    public void configure(Map<String, ?> configs, Mode mode, Subject subject, String mechanism) {
-        this.isKerberos = mechanism.equals(SaslConfigs.GSSAPI_MECHANISM);
-        this.subject = subject;
+    public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+        this.mechanism  = saslMechanism;
     }
 
     @Override
     public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
+        Subject subject = Subject.getSubject(AccessController.getContext());
         for (Callback callback : callbacks) {
             if (callback instanceof NameCallback) {
                 NameCallback nc = (NameCallback) callback;
-                if (!isKerberos && subject != null && !subject.getPublicCredentials(String.class).isEmpty()) {
+                if (subject != null && !subject.getPublicCredentials(String.class).isEmpty()) {
                     nc.setName(subject.getPublicCredentials(String.class).iterator().next());
                 } else
                     nc.setName(nc.getDefaultName());
             } else if (callback instanceof PasswordCallback) {
-                if (!isKerberos && subject != null && !subject.getPrivateCredentials(String.class).isEmpty()) {
+                if (subject != null && !subject.getPrivateCredentials(String.class).isEmpty()) {
                     char[] password = subject.getPrivateCredentials(String.class).iterator().next().toCharArray();
                     ((PasswordCallback) callback).setPassword(password);
                 } else {
                     String errorMessage = "Could not login: the client is being asked for a password, but the Kafka" +
                              " client code does not currently support obtaining a password from the user.";
-                    if (isKerberos) {
-                        errorMessage += " Make sure -Djava.security.auth.login.config property passed to JVM and" +
-                             " the client is configured to use a ticket cache (using" +
-                             " the JAAS configuration setting 'useTicketCache=true)'. Make sure you are using" +
-                             " FQDN of the Kafka broker you are trying to connect to.";
-                    }
                     throw new UnsupportedCallbackException(callback, errorMessage);
                 }
             } else if (callback instanceof RealmCallback) {
@@ -83,7 +79,7 @@ public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
                     ac.setAuthorizedID(authzId);
             } else if (callback instanceof ScramExtensionsCallback) {
                 ScramExtensionsCallback sc = (ScramExtensionsCallback) callback;
-                if (!isKerberos && subject != null && !subject.getPublicCredentials(Map.class).isEmpty()) {
+                if (!SaslConfigs.GSSAPI_MECHANISM.equals(mechanism) && subject != null && !subject.getPublicCredentials(Map.class).isEmpty()) {
                     sc.extensions((Map<String, String>) subject.getPublicCredentials(Map.class).iterator().next());
                 }
             }  else {
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
index 2a80e5bc0e4..5140afb196d 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
@@ -28,7 +28,6 @@
 import org.apache.kafka.common.network.Authenticator;
 import org.apache.kafka.common.network.ChannelBuilders;
 import org.apache.kafka.common.network.ListenerName;
-import org.apache.kafka.common.network.Mode;
 import org.apache.kafka.common.network.NetworkReceive;
 import org.apache.kafka.common.network.NetworkSend;
 import org.apache.kafka.common.network.Send;
@@ -46,18 +45,15 @@
 import org.apache.kafka.common.requests.SaslAuthenticateResponse;
 import org.apache.kafka.common.requests.SaslHandshakeRequest;
 import org.apache.kafka.common.requests.SaslHandshakeResponse;
-import org.apache.kafka.common.security.JaasContext;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.auth.KafkaPrincipal;
 import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder;
 import org.apache.kafka.common.security.auth.SaslAuthenticationContext;
 import org.apache.kafka.common.security.kerberos.KerberosName;
 import org.apache.kafka.common.security.kerberos.KerberosShortNamer;
-import org.apache.kafka.common.security.scram.ScramCredential;
 import org.apache.kafka.common.security.scram.ScramLoginModule;
-import org.apache.kafka.common.security.scram.ScramMechanism;
-import org.apache.kafka.common.security.scram.ScramServerCallbackHandler;
+import org.apache.kafka.common.security.scram.internal.ScramMechanism;
 import org.apache.kafka.common.utils.Utils;
-import org.apache.kafka.common.security.token.delegation.DelegationTokenCache;
 import org.ietf.jgss.GSSContext;
 import org.ietf.jgss.GSSCredential;
 import org.ietf.jgss.GSSException;
@@ -102,14 +98,12 @@
     private final SecurityProtocol securityProtocol;
     private final ListenerName listenerName;
     private final String connectionId;
-    private final Map<String, JaasContext> jaasContexts;
     private final Map<String, Subject> subjects;
-    private final CredentialCache credentialCache;
     private final TransportLayer transportLayer;
     private final Set<String> enabledMechanisms;
     private final Map<String, ?> configs;
     private final KafkaPrincipalBuilder principalBuilder;
-    private final DelegationTokenCache tokenCache;
+    private final Map<String, AuthenticateCallbackHandler> callbackHandlers;
 
     // Current SASL state
     private SaslState saslState = SaslState.INITIAL_REQUEST;
@@ -119,7 +113,6 @@
     private AuthenticationException pendingException = null;
     private SaslServer saslServer;
     private String saslMechanism;
-    private AuthCallbackHandler callbackHandler;
 
     // buffers used in `authenticate`
     private NetworkReceive netInBuffer;
@@ -128,23 +121,19 @@
     private boolean enableKafkaSaslAuthenticateHeaders;
 
     public SaslServerAuthenticator(Map<String, ?> configs,
+                                   Map<String, AuthenticateCallbackHandler> callbackHandlers,
                                    String connectionId,
-                                   Map<String, JaasContext> jaasContexts,
                                    Map<String, Subject> subjects,
                                    KerberosShortNamer kerberosNameParser,
-                                   CredentialCache credentialCache,
                                    ListenerName listenerName,
                                    SecurityProtocol securityProtocol,
-                                   TransportLayer transportLayer,
-                                   DelegationTokenCache tokenCache) throws IOException {
+                                   TransportLayer transportLayer) throws IOException {
+        this.callbackHandlers = callbackHandlers;
         this.connectionId = connectionId;
-        this.jaasContexts = jaasContexts;
         this.subjects = subjects;
-        this.credentialCache = credentialCache;
         this.listenerName = listenerName;
         this.securityProtocol = securityProtocol;
         this.enableKafkaSaslAuthenticateHeaders = false;
-        this.tokenCache = tokenCache;
         this.transportLayer = transportLayer;
 
         this.configs = configs;
@@ -154,8 +143,8 @@ public SaslServerAuthenticator(Map<String, ?> configs,
             throw new IllegalArgumentException("No SASL mechanisms are enabled");
         this.enabledMechanisms = new HashSet<>(enabledMechanisms);
         for (String mechanism : enabledMechanisms) {
-            if (!jaasContexts.containsKey(mechanism))
-                throw new IllegalArgumentException("Jaas context not specified for SASL mechanism " + mechanism);
+            if (!callbackHandlers.containsKey(mechanism))
+                throw new IllegalArgumentException("Callback handler not specified for SASL mechanism " + mechanism);
             if (!subjects.containsKey(mechanism))
                 throw new IllegalArgumentException("Subject cannot be null for SASL mechanism " + mechanism);
         }
@@ -168,11 +157,7 @@ public SaslServerAuthenticator(Map<String, ?> configs,
     private void createSaslServer(String mechanism) throws IOException {
         this.saslMechanism = mechanism;
         Subject subject = subjects.get(mechanism);
-        if (!ScramMechanism.isScram(mechanism))
-            callbackHandler = new SaslServerCallbackHandler(jaasContexts.get(mechanism));
-        else
-            callbackHandler = new ScramServerCallbackHandler(credentialCache.cache(mechanism, ScramCredential.class), tokenCache);
-        callbackHandler.configure(configs, Mode.SERVER, subject, saslMechanism);
+        final AuthenticateCallbackHandler callbackHandler = callbackHandlers.get(mechanism);
         if (mechanism.equals(SaslConfigs.GSSAPI_MECHANISM)) {
             saslServer = createSaslKerberosServer(callbackHandler, configs, subject);
         } else {
@@ -189,7 +174,7 @@ public SaslServer run() throws SaslException {
         }
     }
 
-    private SaslServer createSaslKerberosServer(final AuthCallbackHandler saslServerCallbackHandler, final Map<String, ?> configs, Subject subject) throws IOException {
+    private SaslServer createSaslKerberosServer(final AuthenticateCallbackHandler saslServerCallbackHandler, final Map<String, ?> configs, Subject subject) throws IOException {
         // server is using a JAAS-authenticated subject: determine service principal name and hostname from kafka server's subject.
         final String servicePrincipal = SaslClientAuthenticator.firstPrincipal(subject);
         KerberosName kerberosName;
@@ -316,8 +301,6 @@ public void close() throws IOException {
             Utils.closeQuietly((Closeable) principalBuilder, "principal builder");
         if (saslServer != null)
             saslServer.dispose();
-        if (callbackHandler != null)
-            callbackHandler.close();
     }
 
     private void setSaslState(SaslState saslState) throws IOException {
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerCallbackHandler.java
index 7d5372db977..d3d43cbfa26 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerCallbackHandler.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerCallbackHandler.java
@@ -16,51 +16,46 @@
  */
 package org.apache.kafka.common.security.authenticator;
 
-import java.io.IOException;
+import java.util.List;
 import java.util.Map;
 
-import org.apache.kafka.common.security.JaasContext;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import javax.security.auth.Subject;
 import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.auth.login.AppConfigurationEntry;
 import javax.security.sasl.AuthorizeCallback;
 import javax.security.sasl.RealmCallback;
 
-import org.apache.kafka.common.network.Mode;
+import org.apache.kafka.common.config.SaslConfigs;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
- * Callback handler for Sasl servers. The callbacks required for all the SASL
+ * Default callback handler for Sasl servers. The callbacks required for all the SASL
  * mechanisms enabled in the server should be supported by this callback handler. See
  * <a href="https://docs.oracle.com/javase/8/docs/technotes/guides/security/sasl/sasl-refguide.html">Java SASL API</a>
  * for the list of SASL callback handlers required for each SASL mechanism.
  */
-public class SaslServerCallbackHandler implements AuthCallbackHandler {
+public class SaslServerCallbackHandler implements AuthenticateCallbackHandler {
     private static final Logger LOG = LoggerFactory.getLogger(SaslServerCallbackHandler.class);
-    private final JaasContext jaasContext;
 
-    public SaslServerCallbackHandler(JaasContext jaasContext) throws IOException {
-        this.jaasContext = jaasContext;
-    }
+    private String mechanism;
 
     @Override
-    public void configure(Map<String, ?> configs, Mode mode, Subject subject, String saslMechanism) {
-    }
-
-    public JaasContext jaasContext() {
-        return jaasContext;
+    public void configure(Map<String, ?> configs, String mechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+        this.mechanism = mechanism;
     }
 
     @Override
     public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
         for (Callback callback : callbacks) {
-            if (callback instanceof RealmCallback) {
+            if (callback instanceof RealmCallback)
                 handleRealmCallback((RealmCallback) callback);
-            } else if (callback instanceof AuthorizeCallback) {
+            else if (callback instanceof AuthorizeCallback && mechanism.equals(SaslConfigs.GSSAPI_MECHANISM))
                 handleAuthorizeCallback((AuthorizeCallback) callback);
-            }
+            else
+                throw new UnsupportedCallbackException(callback);
         }
     }
 
diff --git a/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosClientCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosClientCallbackHandler.java
new file mode 100644
index 00000000000..fa9cad261c5
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosClientCallbackHandler.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.security.kerberos;
+
+import org.apache.kafka.common.config.SaslConfigs;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.auth.login.AppConfigurationEntry;
+import javax.security.sasl.AuthorizeCallback;
+import javax.security.sasl.RealmCallback;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Callback handler for SASL/GSSAPI clients.
+ * <p>
+ * Name and realm callbacks are answered with their default values. Password callbacks are
+ * rejected because the Kafka client cannot prompt the user for a password; credentials must
+ * come from the JAAS login configuration instead (for example, a ticket cache).
+ */
+public class KerberosClientCallbackHandler implements AuthenticateCallbackHandler {
+
+    /**
+     * Verifies that this handler is configured for the GSSAPI mechanism.
+     *
+     * @throws IllegalStateException if {@code saslMechanism} is not GSSAPI (including {@code null})
+     */
+    @Override
+    public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+        // Compare against the constant so a null mechanism fails with this message rather than an NPE
+        if (!SaslConfigs.GSSAPI_MECHANISM.equals(saslMechanism))
+            throw new IllegalStateException("Kerberos callback handler should only be used with GSSAPI");
+    }
+
+    /**
+     * Handles SASL/GSSAPI client callbacks.
+     *
+     * @throws UnsupportedCallbackException if a password is requested or the callback type is not recognized
+     */
+    @Override
+    public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
+        for (Callback callback : callbacks) {
+            if (callback instanceof NameCallback) {
+                NameCallback nc = (NameCallback) callback;
+                nc.setName(nc.getDefaultName());
+            } else if (callback instanceof PasswordCallback) {
+                // Passwords cannot be obtained interactively; point the user at the JAAS/ticket-cache setup instead
+                String errorMessage = "Could not login: the client is being asked for a password, but the Kafka" +
+                             " client code does not currently support obtaining a password from the user.";
+                errorMessage += " Make sure -Djava.security.auth.login.config property is passed to JVM and" +
+                             " the client is configured to use a ticket cache (using" +
+                             " the JAAS configuration setting 'useTicketCache=true'). Make sure you are using" +
+                             " FQDN of the Kafka broker you are trying to connect to.";
+                throw new UnsupportedCallbackException(callback, errorMessage);
+            } else if (callback instanceof RealmCallback) {
+                RealmCallback rc = (RealmCallback) callback;
+                rc.setText(rc.getDefaultText());
+            } else if (callback instanceof AuthorizeCallback) {
+                AuthorizeCallback ac = (AuthorizeCallback) callback;
+                String authId = ac.getAuthenticationID();
+                String authzId = ac.getAuthorizationID();
+                // Authorize only when the requested authorization id is the authenticated id itself
+                ac.setAuthorized(authId.equals(authzId));
+                if (ac.isAuthorized())
+                    ac.setAuthorizedID(authzId);
+            } else {
+                throw new UnsupportedCallbackException(callback, "Unrecognized SASL ClientCallback");
+            }
+        }
+    }
+
+    /**
+     * No resources are held, so there is nothing to release.
+     */
+    @Override
+    public void close() {
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosLogin.java b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosLogin.java
index 65c3b1cbe44..ec996a8ccc6 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosLogin.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosLogin.java
@@ -18,6 +18,7 @@
 
 import javax.security.auth.kerberos.KerberosPrincipal;
 import javax.security.auth.login.AppConfigurationEntry;
+import javax.security.auth.login.Configuration;
 import javax.security.auth.login.LoginContext;
 import javax.security.auth.login.LoginException;
 import javax.security.auth.kerberos.KerberosTicket;
@@ -25,6 +26,7 @@
 
 import org.apache.kafka.common.security.JaasContext;
 import org.apache.kafka.common.security.JaasUtils;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.authenticator.AbstractLogin;
 import org.apache.kafka.common.config.SaslConfigs;
 import org.apache.kafka.common.utils.KafkaThread;
@@ -33,11 +35,12 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.Arrays;
 import java.util.Date;
 import java.util.List;
+import java.util.Map;
 import java.util.Random;
 import java.util.Set;
-import java.util.Map;
 
 /**
  * This class is responsible for refreshing Kerberos credentials for
@@ -78,13 +81,15 @@
     private String serviceName;
     private long lastLogin;
 
-    public void configure(Map<String, ?> configs, JaasContext jaasContext) {
-        super.configure(configs, jaasContext);
+    @Override
+    public void configure(Map<String, ?> configs, String contextName, Configuration configuration,
+                          AuthenticateCallbackHandler callbackHandler) {
+        super.configure(configs, contextName, configuration, callbackHandler);
         this.ticketRenewWindowFactor = (Double) configs.get(SaslConfigs.SASL_KERBEROS_TICKET_RENEW_WINDOW_FACTOR);
         this.ticketRenewJitter = (Double) configs.get(SaslConfigs.SASL_KERBEROS_TICKET_RENEW_JITTER);
         this.minTimeBeforeRelogin = (Long) configs.get(SaslConfigs.SASL_KERBEROS_MIN_TIME_BEFORE_RELOGIN);
         this.kinitCmd = (String) configs.get(SaslConfigs.SASL_KERBEROS_KINIT_CMD);
-        this.serviceName = getServiceName(configs, jaasContext);
+        this.serviceName = getServiceName(configs, contextName, configuration);
     }
 
     /**
@@ -99,13 +104,13 @@ public LoginContext login() throws LoginException {
         subject = loginContext.getSubject();
         isKrbTicket = !subject.getPrivateCredentials(KerberosTicket.class).isEmpty();
 
-        List<AppConfigurationEntry> entries = jaasContext().configurationEntries();
-        if (entries.isEmpty()) {
+        AppConfigurationEntry[] entries = configuration().getAppConfigurationEntry(contextName());
+        if (entries.length == 0) {
             isUsingTicketCache = false;
             principal = null;
         } else {
             // there will only be a single entry
-            AppConfigurationEntry entry = entries.get(0);
+            AppConfigurationEntry entry = entries[0];
             if (entry.getOptions().get("useTicketCache") != null) {
                 String val = (String) entry.getOptions().get("useTicketCache");
                 isUsingTicketCache = val.equals("true");
@@ -280,8 +285,9 @@ public String serviceName() {
         return serviceName;
     }
 
-    private static String getServiceName(Map<String, ?> configs, JaasContext jaasContext) {
-        String jaasServiceName = jaasContext.configEntryOption(JaasUtils.SERVICE_NAME, null);
+    private static String getServiceName(Map<String, ?> configs, String contextName, Configuration configuration) {
+        List<AppConfigurationEntry> configEntries = Arrays.asList(configuration.getAppConfigurationEntry(contextName));
+        String jaasServiceName = JaasContext.configEntryOption(configEntries, JaasUtils.SERVICE_NAME, null);
         String configServiceName = (String) configs.get(SaslConfigs.SASL_KERBEROS_SERVICE_NAME);
         if (jaasServiceName != null && configServiceName != null && !jaasServiceName.equals(configServiceName)) {
             String message = String.format("Conflicting serviceName values found in JAAS and Kafka configs " +
@@ -360,7 +366,7 @@ private void reLogin() throws LoginException {
             loginContext.logout();
             //login and also update the subject field of this instance to
             //have the new credentials (pass it to the LoginContext constructor)
-            loginContext = new LoginContext(jaasContext().name(), subject, null, jaasContext().configuration());
+            loginContext = new LoginContext(contextName(), subject, null, configuration());
             log.info("Initiating re-login for {}", principal);
             loginContext.login();
         }
diff --git a/clients/src/main/java/org/apache/kafka/common/security/plain/PlainAuthenticateCallback.java b/clients/src/main/java/org/apache/kafka/common/security/plain/PlainAuthenticateCallback.java
new file mode 100644
index 00000000000..7f42645e487
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/security/plain/PlainAuthenticateCallback.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.security.plain;
+
+import javax.security.auth.callback.Callback;
+
+/**
+ * Authentication callback for SASL/PLAIN authentication. The callback handler must
+ * set the authenticated flag to true if the client-provided password in the callback
+ * matches the expected password.
+ */
+public class PlainAuthenticateCallback implements Callback {
+    private final char[] password;
+    private boolean authenticated;
+
+    /**
+     * Creates a callback with the password provided by the client.
+     * @param password The password provided by the client during SASL/PLAIN authentication
+     */
+    public PlainAuthenticateCallback(char[] password) {
+        this.password = password;
+    }
+
+    /**
+     * Returns the password provided by the client during SASL/PLAIN authentication (the array itself, not a copy).
+     */
+    public char[] password() {
+        return password;
+    }
+
+    /**
+     * Returns true if client password matches expected password, false otherwise.
+     * This state is set by the server-side callback handler; it is false until then.
+     */
+    public boolean authenticated() {
+        return this.authenticated;
+    }
+
+    /**
+     * Sets the authenticated state. This is set by the server-side callback handler
+     * by matching the client-provided password with the expected password.
+     *
+     * @param authenticated true indicates successful authentication
+     */
+    public void authenticated(boolean authenticated) {
+        this.authenticated = authenticated;
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/plain/PlainLoginModule.java b/clients/src/main/java/org/apache/kafka/common/security/plain/PlainLoginModule.java
index c8b29fc69a9..f0a5971ebeb 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/plain/PlainLoginModule.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/plain/PlainLoginModule.java
@@ -16,6 +16,8 @@
  */
 package org.apache.kafka.common.security.plain;
 
+import org.apache.kafka.common.security.plain.internal.PlainSaslServerProvider;
+
 import java.util.Map;
 
 import javax.security.auth.Subject;
diff --git a/clients/src/main/java/org/apache/kafka/common/security/plain/PlainSaslServer.java b/clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainSaslServer.java
similarity index 87%
rename from clients/src/main/java/org/apache/kafka/common/security/plain/PlainSaslServer.java
rename to clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainSaslServer.java
index e54887f652f..811d9e94aca 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/plain/PlainSaslServer.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainSaslServer.java
@@ -14,21 +14,22 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.plain;
+package org.apache.kafka.common.security.plain.internal;
 
 import java.io.UnsupportedEncodingException;
 import java.util.Arrays;
 import java.util.Map;
 
+import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
 import javax.security.sasl.Sasl;
 import javax.security.sasl.SaslException;
 import javax.security.sasl.SaslServer;
 import javax.security.sasl.SaslServerFactory;
 
 import org.apache.kafka.common.errors.SaslAuthenticationException;
-import org.apache.kafka.common.security.JaasContext;
-import org.apache.kafka.common.security.authenticator.SaslServerCallbackHandler;
+import org.apache.kafka.common.security.plain.PlainAuthenticateCallback;
 
 /**
  * Simple SaslServer implementation for SASL/PLAIN. In order to make this implementation
@@ -46,15 +47,13 @@
 public class PlainSaslServer implements SaslServer {
 
     public static final String PLAIN_MECHANISM = "PLAIN";
-    private static final String JAAS_USER_PREFIX = "user_";
-
-    private final JaasContext jaasContext;
 
+    private final CallbackHandler callbackHandler;
     private boolean complete;
     private String authorizationId;
 
-    public PlainSaslServer(JaasContext jaasContext) {
-        this.jaasContext = jaasContext;
+    public PlainSaslServer(CallbackHandler callbackHandler) {
+        this.callbackHandler = callbackHandler;
     }
 
     /**
@@ -101,12 +100,15 @@ public PlainSaslServer(JaasContext jaasContext) {
             throw new SaslException("Authentication failed: password not specified");
         }
 
-        String expectedPassword = jaasContext.configEntryOption(JAAS_USER_PREFIX + username,
-                PlainLoginModule.class.getName());
-        if (!password.equals(expectedPassword)) {
-            throw new SaslAuthenticationException("Authentication failed: Invalid username or password");
+        NameCallback nameCallback = new NameCallback("username", username);
+        PlainAuthenticateCallback authenticateCallback = new PlainAuthenticateCallback(password.toCharArray());
+        try {
+            callbackHandler.handle(new Callback[]{nameCallback, authenticateCallback});
+        } catch (Throwable e) {
+            throw new SaslAuthenticationException("Authentication failed: credentials for user could not be verified", e);
         }
-
+        if (!authenticateCallback.authenticated())
+            throw new SaslAuthenticationException("Authentication failed: Invalid username or password");
         if (!authorizationIdFromClient.isEmpty() && !authorizationIdFromClient.equals(username))
             throw new SaslAuthenticationException("Authentication failed: Client requested an authorization id that is different from username");
 
@@ -167,10 +169,7 @@ public SaslServer createSaslServer(String mechanism, String protocol, String ser
             if (!PLAIN_MECHANISM.equals(mechanism))
                 throw new SaslException(String.format("Mechanism \'%s\' is not supported. Only PLAIN is supported.", mechanism));
 
-            if (!(cbh instanceof SaslServerCallbackHandler))
-                throw new SaslException("CallbackHandler must be of type SaslServerCallbackHandler, but it is: " + cbh.getClass());
-
-            return new PlainSaslServer(((SaslServerCallbackHandler) cbh).jaasContext());
+            return new PlainSaslServer(cbh);
         }
 
         @Override
diff --git a/clients/src/main/java/org/apache/kafka/common/security/plain/PlainSaslServerProvider.java b/clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainSaslServerProvider.java
similarity index 90%
rename from clients/src/main/java/org/apache/kafka/common/security/plain/PlainSaslServerProvider.java
rename to clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainSaslServerProvider.java
index ae1424417f8..c2229532449 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/plain/PlainSaslServerProvider.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainSaslServerProvider.java
@@ -14,12 +14,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.plain;
+package org.apache.kafka.common.security.plain.internal;
 
 import java.security.Provider;
 import java.security.Security;
 
-import org.apache.kafka.common.security.plain.PlainSaslServer.PlainSaslServerFactory;
+import org.apache.kafka.common.security.plain.internal.PlainSaslServer.PlainSaslServerFactory;
 
 public class PlainSaslServerProvider extends Provider {
 
diff --git a/clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainServerCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainServerCallbackHandler.java
new file mode 100644
index 00000000000..84fbdfd4c62
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/security/plain/internal/PlainServerCallbackHandler.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.security.plain.internal;
+
+import org.apache.kafka.common.security.JaasContext;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.security.plain.PlainAuthenticateCallback;
+import org.apache.kafka.common.security.plain.PlainLoginModule;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.auth.login.AppConfigurationEntry;
+
+/**
+ * Server-side callback handler for SASL/PLAIN.
+ * <p>
+ * The password supplied by the client is checked against the {@code user_<username>} option
+ * configured for {@link PlainLoginModule} in the JAAS entries passed to {@link #configure}.
+ */
+public class PlainServerCallbackHandler implements AuthenticateCallbackHandler {
+
+    private static final String JAAS_USER_PREFIX = "user_";
+    private List<AppConfigurationEntry> jaasConfigEntries;
+
+    @Override
+    public void configure(Map<String, ?> configs, String mechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+        this.jaasConfigEntries = jaasConfigEntries;
+    }
+
+    /**
+     * Records the user name from a {@link NameCallback} and sets the authenticated state of each
+     * following {@link PlainAuthenticateCallback}. Callbacks are processed in array order, so the
+     * name callback must precede the authenticate callback for authentication to succeed.
+     *
+     * @throws UnsupportedCallbackException for any other callback type
+     */
+    @Override
+    public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+        String username = null;
+        for (Callback callback : callbacks) {
+            if (callback instanceof NameCallback)
+                username = ((NameCallback) callback).getDefaultName();
+            else if (callback instanceof PlainAuthenticateCallback) {
+                PlainAuthenticateCallback plainCallback = (PlainAuthenticateCallback) callback;
+                boolean authenticated = authenticate(username, plainCallback.password());
+                plainCallback.authenticated(authenticated);
+            } else
+                throw new UnsupportedCallbackException(callback);
+        }
+    }
+
+    /**
+     * Returns true if {@code password} matches the password configured for {@code username}.
+     * Returns false if the user name is null or no password is configured for the user.
+     */
+    protected boolean authenticate(String username, char[] password) throws IOException {
+        if (username == null)
+            return false;
+        else {
+            String expectedPassword = JaasContext.configEntryOption(jaasConfigEntries,
+                    JAAS_USER_PREFIX + username,
+                    PlainLoginModule.class.getName());
+            return expectedPassword != null && isEqualConstantTime(password, expectedPassword.toCharArray());
+        }
+    }
+
+    /**
+     * Compares two passwords in time that depends only on the length of {@code expected}, so that
+     * response timing does not reveal how many leading characters of a guess were correct
+     * (a plain {@link Arrays#equals(char[], char[])} returns at the first mismatch).
+     */
+    private static boolean isEqualConstantTime(char[] supplied, char[] expected) {
+        if (supplied == null)
+            return false;
+        int diff = supplied.length ^ expected.length;
+        for (int i = 0; i < expected.length; i++) {
+            char suppliedChar = i < supplied.length ? supplied[i] : 0;
+            diff |= suppliedChar ^ expected[i];
+        }
+        return diff == 0;
+    }
+
+    @Override
+    public void close() throws KafkaException {
+    }
+
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredential.java b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredential.java
index 09ff0aacd3c..dfbfef15b4f 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredential.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredential.java
@@ -16,6 +16,11 @@
  */
 package org.apache.kafka.common.security.scram;
 
+/**
+ * SCRAM credential class that encapsulates the credential data persisted for each user that is
+ * accessible to the server. See <a href="https://tools.ietf.org/html/rfc5802#section-5">RFC 5802</a>
+ * for details.
+ */
 public class ScramCredential {
 
     private final byte[] salt;
@@ -23,6 +28,9 @@
     private final byte[] storedKey;
     private final int iterations;
 
+    /**
+     * Constructs a new credential.
+     */
     public ScramCredential(byte[] salt, byte[] storedKey, byte[] serverKey, int iterations) {
         this.salt = salt;
         this.serverKey = serverKey;
@@ -30,18 +38,30 @@ public ScramCredential(byte[] salt, byte[] storedKey, byte[] serverKey, int iter
         this.iterations = iterations;
     }
 
+    /**
+     * Returns the salt used to process this credential using the SCRAM algorithm.
+     */
     public byte[] salt() {
         return salt;
     }
 
+    /**
+     * Server key computed from the client password using the SCRAM algorithm.
+     */
     public byte[] serverKey() {
         return serverKey;
     }
 
+    /**
+     * Stored key computed from the client password using the SCRAM algorithm.
+     */
     public byte[] storedKey() {
         return storedKey;
     }
 
+    /**
+     * Number of iterations used to process this credential using the SCRAM algorithm.
+     */
     public int iterations() {
         return iterations;
     }
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredentialCallback.java b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredentialCallback.java
index 931210a0e07..d5988cbeb89 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredentialCallback.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredentialCallback.java
@@ -18,14 +18,23 @@
 
 import javax.security.auth.callback.Callback;
 
+/**
+ * Callback used for SCRAM mechanisms to carry a {@link ScramCredential}.
+ */
 public class ScramCredentialCallback implements Callback {
     private ScramCredential scramCredential;
 
-    public ScramCredential scramCredential() {
-        return scramCredential;
-    }
-
+    /**
+     * Sets the SCRAM credential for this instance, replacing any previously set value.
+     */
     public void scramCredential(ScramCredential scramCredential) {
         this.scramCredential = scramCredential;
     }
-}
\ No newline at end of file
+
+    /**
+     * Returns the SCRAM credential set on this instance, or null if none has been set.
+     */
+    public ScramCredential scramCredential() {
+        return scramCredential;
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java
index b40468bd650..debe163e36b 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java
@@ -21,13 +21,25 @@
 import java.util.Collections;
 import java.util.Map;
 
+/**
+ * Optional callback used for SCRAM mechanisms if any extensions need to be set
+ * in the SASL/SCRAM exchange.
+ */
 public class ScramExtensionsCallback implements Callback {
     private Map<String, String> extensions = Collections.emptyMap();
 
+    /**
+     * Returns the extension names and values that are sent by the client to
+     * the server in the initial client SCRAM authentication message.
+     * Default is an empty map.
+     */
     public Map<String, String> extensions() {
         return extensions;
     }
 
+    /**
+     * Sets the SCRAM extensions on this callback.
+     */
     public void extensions(Map<String, String> extensions) {
         this.extensions = extensions;
     }
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramLoginModule.java b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramLoginModule.java
index 43df515256f..20d1f221b1c 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramLoginModule.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramLoginModule.java
@@ -16,6 +16,9 @@
  */
 package org.apache.kafka.common.security.scram;
 
+import org.apache.kafka.common.security.scram.internal.ScramSaslClientProvider;
+import org.apache.kafka.common.security.scram.internal.ScramSaslServerProvider;
+
 import java.util.Collections;
 import java.util.Map;
 
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredentialUtils.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramCredentialUtils.java
similarity index 96%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredentialUtils.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramCredentialUtils.java
index b4875d6f57e..91e28a62e1d 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredentialUtils.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramCredentialUtils.java
@@ -14,12 +14,13 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
 import java.util.Collection;
 import java.util.Properties;
 
 import org.apache.kafka.common.security.authenticator.CredentialCache;
+import org.apache.kafka.common.security.scram.ScramCredential;
 import org.apache.kafka.common.utils.Base64;
 
 /**
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensions.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramExtensions.java
similarity index 95%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensions.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramExtensions.java
index 0f461c0e3c2..66d9362a3e6 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensions.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramExtensions.java
@@ -14,7 +14,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
+
+import org.apache.kafka.common.security.scram.ScramLoginModule;
 
 import java.util.Collections;
 import java.util.HashMap;
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramFormatter.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramFormatter.java
similarity index 94%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramFormatter.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramFormatter.java
index 406c28599b3..6fcb7a152db 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramFormatter.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramFormatter.java
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
 import java.math.BigInteger;
 import java.nio.charset.StandardCharsets;
@@ -27,9 +27,10 @@
 import javax.crypto.spec.SecretKeySpec;
 
 import org.apache.kafka.common.KafkaException;
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFirstMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFirstMessage;
+import org.apache.kafka.common.security.scram.ScramCredential;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFirstMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFirstMessage;
 
 /**
  * Scram message salt and hash functions defined in <a href="https://tools.ietf.org/html/rfc5802">RFC 5802</a>.
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramMechanism.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramMechanism.java
similarity index 97%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramMechanism.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramMechanism.java
index d8c0c6d3d60..73be4cf24fb 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramMechanism.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramMechanism.java
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
 import java.util.Collection;
 import java.util.Collections;
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramMessages.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramMessages.java
similarity index 99%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramMessages.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramMessages.java
index 05b3d775540..439b274f561 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramMessages.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramMessages.java
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
 import org.apache.kafka.common.utils.Base64;
 
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslClient.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslClient.java
similarity index 90%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslClient.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslClient.java
index 71109df2ddd..a98a86d61cb 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslClient.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslClient.java
@@ -14,9 +14,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
-import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 import java.security.InvalidKeyException;
 import java.security.NoSuchAlgorithmException;
@@ -34,9 +33,10 @@
 import javax.security.sasl.SaslException;
 
 import org.apache.kafka.common.errors.IllegalSaslStateException;
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFirstMessage;
+import org.apache.kafka.common.security.scram.ScramExtensionsCallback;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFirstMessage;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -100,9 +100,15 @@ public boolean hasInitialResponse() {
                     ScramExtensionsCallback extensionsCallback = new ScramExtensionsCallback();
 
                     try {
-                        callbackHandler.handle(new Callback[]{nameCallback, extensionsCallback});
-                    } catch (IOException | UnsupportedCallbackException e) {
-                        throw new SaslException("User name could not be obtained", e);
+                        callbackHandler.handle(new Callback[]{nameCallback});
+                        try {
+                            callbackHandler.handle(new Callback[]{extensionsCallback});
+                        } catch (UnsupportedCallbackException e) {
+                            log.debug("Extensions callback is not supported by client callback handler {}, no extensions will be added",
+                                    callbackHandler);
+                        }
+                    } catch (Throwable e) {
+                        throw new SaslException("User name or extensions could not be obtained", e);
                     }
 
                     String username = nameCallback.getName();
@@ -121,7 +127,7 @@ public boolean hasInitialResponse() {
                     PasswordCallback passwordCallback = new PasswordCallback("Password:", false);
                     try {
                         callbackHandler.handle(new Callback[]{passwordCallback});
-                    } catch (IOException | UnsupportedCallbackException e) {
-                        throw new SaslException("User name could not be obtained", e);
+                    } catch (Throwable e) {
+                        throw new SaslException("Password could not be obtained", e);
                     }
                     this.clientFinalMessage = handleServerFirstMessage(passwordCallback.getPassword());
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslClientProvider.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslClientProvider.java
similarity index 90%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslClientProvider.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslClientProvider.java
index d389f044a99..4d9ff81309b 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslClientProvider.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslClientProvider.java
@@ -14,12 +14,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
 import java.security.Provider;
 import java.security.Security;
 
-import org.apache.kafka.common.security.scram.ScramSaslClient.ScramSaslClientFactory;
+import org.apache.kafka.common.security.scram.internal.ScramSaslClient.ScramSaslClientFactory;
 
 public class ScramSaslClientProvider extends Provider {
 
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslServer.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslServer.java
similarity index 92%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslServer.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslServer.java
index 314c1d413ba..deee0b8fb33 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslServer.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslServer.java
@@ -14,9 +14,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
-import java.io.IOException;
 import java.security.InvalidKeyException;
 import java.security.NoSuchAlgorithmException;
 import java.util.Arrays;
@@ -27,17 +26,20 @@
 import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.CallbackHandler;
 import javax.security.auth.callback.NameCallback;
-import javax.security.auth.callback.UnsupportedCallbackException;
 import javax.security.sasl.SaslException;
 import javax.security.sasl.SaslServer;
 import javax.security.sasl.SaslServerFactory;
 
+import org.apache.kafka.common.errors.AuthenticationException;
 import org.apache.kafka.common.errors.IllegalSaslStateException;
 import org.apache.kafka.common.errors.SaslAuthenticationException;
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFirstMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFirstMessage;
+import org.apache.kafka.common.security.scram.ScramCredential;
+import org.apache.kafka.common.security.scram.ScramCredentialCallback;
+import org.apache.kafka.common.security.scram.ScramLoginModule;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFirstMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFirstMessage;
 import org.apache.kafka.common.security.token.delegation.DelegationTokenCredentialCallback;
 import org.apache.kafka.common.utils.Utils;
 import org.slf4j.Logger;
@@ -133,7 +135,9 @@ public ScramSaslServer(ScramMechanism mechanism, Map<String, ?> props, CallbackH
                                 scramCredential.iterations());
                         setState(State.RECEIVE_CLIENT_FINAL_MESSAGE);
                         return serverFirstMessage.toBytes();
-                    } catch (IOException | NumberFormatException | UnsupportedCallbackException e) {
+                    } catch (SaslException | AuthenticationException e) {
+                        throw e;
+                    } catch (Throwable e) {
                         throw new SaslException("Authentication failed: Credentials could not be obtained", e);
                     }
 
@@ -154,7 +158,7 @@ public ScramSaslServer(ScramMechanism mechanism, Map<String, ?> props, CallbackH
                 default:
                     throw new IllegalSaslStateException("Unexpected challenge in Sasl server state " + state);
             }
-        } catch (SaslException e) {
+        } catch (SaslException | AuthenticationException e) {
             clearCredentials();
             setState(State.FAILED);
             throw e;
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslServerProvider.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslServerProvider.java
similarity index 90%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslServerProvider.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslServerProvider.java
index 9f2a6b3d256..099e50e19dd 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramSaslServerProvider.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramSaslServerProvider.java
@@ -14,12 +14,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
 import java.security.Provider;
 import java.security.Security;
 
-import org.apache.kafka.common.security.scram.ScramSaslServer.ScramSaslServerFactory;
+import org.apache.kafka.common.security.scram.internal.ScramSaslServer.ScramSaslServerFactory;
 
 public class ScramSaslServerProvider extends Provider {
 
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramServerCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramServerCallbackHandler.java
similarity index 82%
rename from clients/src/main/java/org/apache/kafka/common/security/scram/ScramServerCallbackHandler.java
rename to clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramServerCallbackHandler.java
index 5e37eae9d4f..377aa3d3df5 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramServerCallbackHandler.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internal/ScramServerCallbackHandler.java
@@ -14,23 +14,25 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
 import java.io.IOException;
+import java.util.List;
 import java.util.Map;
 
-import javax.security.auth.Subject;
 import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.NameCallback;
 import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.auth.login.AppConfigurationEntry;
 
-import org.apache.kafka.common.network.Mode;
-import org.apache.kafka.common.security.authenticator.AuthCallbackHandler;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.authenticator.CredentialCache;
+import org.apache.kafka.common.security.scram.ScramCredential;
+import org.apache.kafka.common.security.scram.ScramCredentialCallback;
 import org.apache.kafka.common.security.token.delegation.DelegationTokenCache;
 import org.apache.kafka.common.security.token.delegation.DelegationTokenCredentialCallback;
 
-public class ScramServerCallbackHandler implements AuthCallbackHandler {
+public class ScramServerCallbackHandler implements AuthenticateCallbackHandler {
 
     private final CredentialCache.Cache<ScramCredential> credentialCache;
     private final DelegationTokenCache tokenCache;
@@ -42,6 +44,11 @@ public ScramServerCallbackHandler(CredentialCache.Cache<ScramCredential> credent
         this.tokenCache = tokenCache;
     }
 
+    @Override
+    public void configure(Map<String, ?> configs, String mechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+        this.saslMechanism = mechanism;
+    }
+
     @Override
     public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
         String username = null;
@@ -60,11 +67,6 @@ else if (callback instanceof DelegationTokenCredentialCallback) {
         }
     }
 
-    @Override
-    public void configure(Map<String, ?> configs, Mode mode, Subject subject, String saslMechanism) {
-        this.saslMechanism = saslMechanism;
-    }
-
     @Override
     public void close() {
     }
diff --git a/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationTokenCache.java b/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationTokenCache.java
index 78575b81bb6..adea210e678 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationTokenCache.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationTokenCache.java
@@ -19,8 +19,8 @@
 
 import org.apache.kafka.common.security.authenticator.CredentialCache;
 import org.apache.kafka.common.security.scram.ScramCredential;
-import org.apache.kafka.common.security.scram.ScramCredentialUtils;
-import org.apache.kafka.common.security.scram.ScramMechanism;
+import org.apache.kafka.common.security.scram.internal.ScramCredentialUtils;
+import org.apache.kafka.common.security.scram.internal.ScramMechanism;
 
 import java.util.Collection;
 import java.util.HashMap;
diff --git a/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java
index 0352adec72d..fab8e934d8e 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java
@@ -24,8 +24,8 @@
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.security.auth.SecurityProtocol;
 import org.apache.kafka.common.security.authenticator.CredentialCache;
-import org.apache.kafka.common.security.scram.ScramCredentialUtils;
-import org.apache.kafka.common.security.scram.ScramMechanism;
+import org.apache.kafka.common.security.scram.ScramCredential;
+import org.apache.kafka.common.security.scram.internal.ScramMechanism;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.test.TestCondition;
@@ -79,8 +79,12 @@ public NioEchoServer(ListenerName listenerName, SecurityProtocol securityProtoco
         this.newChannels = Collections.synchronizedList(new ArrayList<SocketChannel>());
         this.credentialCache = credentialCache;
         this.tokenCache = new DelegationTokenCache(ScramMechanism.mechanismNames());
-        if (securityProtocol == SecurityProtocol.SASL_PLAINTEXT || securityProtocol == SecurityProtocol.SASL_SSL)
-            ScramCredentialUtils.createCache(credentialCache, ScramMechanism.mechanismNames());
+        if (securityProtocol == SecurityProtocol.SASL_PLAINTEXT || securityProtocol == SecurityProtocol.SASL_SSL) {
+            for (String mechanism : ScramMechanism.mechanismNames()) {
+                if (credentialCache.cache(mechanism, ScramCredential.class) == null)
+                    credentialCache.createCache(mechanism, ScramCredential.class);
+            }
+        }
         if (channelBuilder == null)
             channelBuilder = ChannelBuilders.serverChannelBuilder(listenerName, false, securityProtocol, config, credentialCache, tokenCache);
         this.metrics = new Metrics();
diff --git a/clients/src/test/java/org/apache/kafka/common/security/TestSecurityConfig.java b/clients/src/test/java/org/apache/kafka/common/security/TestSecurityConfig.java
index 05294cfc5c2..81a883cebf9 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/TestSecurityConfig.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/TestSecurityConfig.java
@@ -31,6 +31,9 @@
             .define(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, Type.LIST,
                     BrokerSecurityConfigs.DEFAULT_SASL_ENABLED_MECHANISMS,
                     Importance.MEDIUM, BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_DOC)
+            .define(BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS, Type.CLASS,
+                    null,
+                    Importance.MEDIUM, BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS_DOC)
             .define(BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, Type.CLASS,
                     null, Importance.MEDIUM, BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_DOC)
             .withClientSslSupport()
diff --git a/clients/src/test/java/org/apache/kafka/common/security/auth/DefaultKafkaPrincipalBuilderTest.java b/clients/src/test/java/org/apache/kafka/common/security/auth/DefaultKafkaPrincipalBuilderTest.java
index a30c09ff316..fdf368788c0 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/auth/DefaultKafkaPrincipalBuilderTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/auth/DefaultKafkaPrincipalBuilderTest.java
@@ -22,7 +22,7 @@
 import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder;
 import org.apache.kafka.common.security.kerberos.KerberosName;
 import org.apache.kafka.common.security.kerberos.KerberosShortNamer;
-import org.apache.kafka.common.security.scram.ScramMechanism;
+import org.apache.kafka.common.security.scram.internal.ScramMechanism;
 import org.easymock.EasyMock;
 import org.easymock.EasyMockSupport;
 import org.junit.Test;
diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/LoginManagerTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/LoginManagerTest.java
index 8be72fb5b8c..5436b2a1dfc 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/LoginManagerTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/LoginManagerTest.java
@@ -59,17 +59,17 @@ public void testClientLoginManager() throws Exception {
         JaasContext staticContext = JaasContext.loadClientContext(Collections.<String, Object>emptyMap());
 
         LoginManager dynamicLogin = LoginManager.acquireLoginManager(dynamicContext, "PLAIN",
-                false, configs);
+                DefaultLogin.class, configs);
         assertEquals(dynamicPlainContext, dynamicLogin.cacheKey());
         LoginManager staticLogin = LoginManager.acquireLoginManager(staticContext, "SCRAM-SHA-256",
-                false, configs);
+                DefaultLogin.class, configs);
         assertNotSame(dynamicLogin, staticLogin);
         assertEquals("KafkaClient", staticLogin.cacheKey());
 
         assertSame(dynamicLogin, LoginManager.acquireLoginManager(dynamicContext, "PLAIN",
-                false, configs));
+                DefaultLogin.class, configs));
         assertSame(staticLogin, LoginManager.acquireLoginManager(staticContext, "SCRAM-SHA-256",
-                false, configs));
+                DefaultLogin.class, configs));
 
         verifyLoginManagerRelease(dynamicLogin, 2, dynamicContext, configs);
         verifyLoginManagerRelease(staticLogin, 2, staticContext, configs);
@@ -86,23 +86,23 @@ public void testServerLoginManager() throws Exception {
         JaasContext scramJaasContext = JaasContext.loadServerContext(listenerName, "SCRAM-SHA-256", configs);
 
         LoginManager dynamicPlainLogin = LoginManager.acquireLoginManager(plainJaasContext, "PLAIN",
-                false, configs);
+                DefaultLogin.class, configs);
         assertEquals(dynamicPlainContext, dynamicPlainLogin.cacheKey());
         LoginManager dynamicDigestLogin = LoginManager.acquireLoginManager(digestJaasContext, "DIGEST-MD5",
-                false, configs);
+                DefaultLogin.class, configs);
         assertNotSame(dynamicPlainLogin, dynamicDigestLogin);
         assertEquals(dynamicDigestContext, dynamicDigestLogin.cacheKey());
         LoginManager staticScramLogin = LoginManager.acquireLoginManager(scramJaasContext, "SCRAM-SHA-256",
-                false, configs);
+                DefaultLogin.class, configs);
         assertNotSame(dynamicPlainLogin, staticScramLogin);
         assertEquals("KafkaServer", staticScramLogin.cacheKey());
 
         assertSame(dynamicPlainLogin, LoginManager.acquireLoginManager(plainJaasContext, "PLAIN",
-                false, configs));
+                DefaultLogin.class, configs));
         assertSame(dynamicDigestLogin, LoginManager.acquireLoginManager(digestJaasContext, "DIGEST-MD5",
-                false, configs));
+                DefaultLogin.class, configs));
         assertSame(staticScramLogin, LoginManager.acquireLoginManager(scramJaasContext, "SCRAM-SHA-256",
-                false, configs));
+                DefaultLogin.class, configs));
 
         verifyLoginManagerRelease(dynamicPlainLogin, 2, plainJaasContext, configs);
         verifyLoginManagerRelease(dynamicDigestLogin, 2, digestJaasContext, configs);
@@ -116,13 +116,13 @@ private void verifyLoginManagerRelease(LoginManager loginManager, int acquireCou
         for (int i = 0; i < acquireCount - 1; i++)
             loginManager.release();
         assertSame(loginManager, LoginManager.acquireLoginManager(jaasContext, "PLAIN",
-                false, configs));
+                DefaultLogin.class, configs));
 
         // Release all references and verify that new LoginManager is created on next acquire
         for (int i = 0; i < 2; i++) // release all references
             loginManager.release();
         LoginManager newLoginManager = LoginManager.acquireLoginManager(jaasContext, "PLAIN",
-                false, configs);
+                DefaultLogin.class, configs);
         assertNotSame(loginManager, newLoginManager);
         newLoginManager.release();
     }
diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java
index b8edc612f69..bfd1d976d22 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java
@@ -16,6 +16,32 @@
  */
 package org.apache.kafka.common.security.authenticator;
 
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.nio.ByteBuffer;
+import java.nio.channels.SelectionKey;
+import java.security.NoSuchAlgorithmException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import javax.security.auth.Subject;
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.auth.login.Configuration;
+import javax.security.auth.login.AppConfigurationEntry;
+import javax.security.auth.login.LoginContext;
+import javax.security.auth.login.LoginException;
+
 import org.apache.kafka.clients.NetworkClient;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.config.SaslConfigs;
@@ -37,6 +63,7 @@
 import org.apache.kafka.common.network.TransportLayer;
 import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.security.auth.Login;
 import org.apache.kafka.common.security.auth.SecurityProtocol;
 import org.apache.kafka.common.requests.AbstractRequest;
 import org.apache.kafka.common.requests.AbstractResponse;
@@ -54,32 +81,20 @@
 import org.apache.kafka.common.security.auth.KafkaPrincipal;
 import org.apache.kafka.common.security.plain.PlainLoginModule;
 import org.apache.kafka.common.security.scram.ScramCredential;
-import org.apache.kafka.common.security.scram.ScramCredentialUtils;
-import org.apache.kafka.common.security.scram.ScramFormatter;
+import org.apache.kafka.common.security.scram.internal.ScramCredentialUtils;
+import org.apache.kafka.common.security.scram.internal.ScramFormatter;
 import org.apache.kafka.common.security.scram.ScramLoginModule;
-import org.apache.kafka.common.security.scram.ScramMechanism;
+import org.apache.kafka.common.security.scram.internal.ScramMechanism;
 import org.apache.kafka.common.security.token.delegation.TokenInformation;
 import org.apache.kafka.common.utils.SecurityUtils;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
+import org.apache.kafka.common.security.authenticator.TestDigestLoginModule.DigestServerCallbackHandler;
+import org.apache.kafka.common.security.plain.internal.PlainServerCallbackHandler;
+
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
-import javax.security.auth.Subject;
-import javax.security.auth.login.Configuration;
-import java.io.IOException;
-import java.net.InetSocketAddress;
-import java.nio.ByteBuffer;
-import java.nio.channels.SelectionKey;
-import java.security.NoSuchAlgorithmException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
-
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
@@ -110,6 +125,7 @@ public void setup() throws Exception {
         saslServerConfigs = serverCertStores.getTrustingConfig(clientCertStores);
         saslClientConfigs = clientCertStores.getTrustingConfig(serverCertStores);
         credentialCache = new CredentialCache();
+        TestLogin.loginCount.set(0);
     }
 
     @After
@@ -234,6 +250,7 @@ public void testMechanismPluggability() throws Exception {
         String node = "0";
         SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
         configureMechanisms("DIGEST-MD5", Arrays.asList("DIGEST-MD5"));
+        configureDigestMd5ServerCallback(securityProtocol);
 
         server = createEchoServer(securityProtocol);
         createAndCheckClientConnection(securityProtocol, node);
@@ -247,6 +264,7 @@ public void testMechanismPluggability() throws Exception {
     public void testMultipleServerMechanisms() throws Exception {
         SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
         configureMechanisms("DIGEST-MD5", Arrays.asList("DIGEST-MD5", "PLAIN", "SCRAM-SHA-256"));
+        configureDigestMd5ServerCallback(securityProtocol);
         server = createEchoServer(securityProtocol);
         updateScramCredentialCache(TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD);
 
@@ -680,6 +698,213 @@ public void testInvalidLoginModule() throws Exception {
         }
     }
 
+    /**
+     * Tests SASL client authentication callback handler override.
+     */
+    @Test
+    public void testClientAuthenticateCallbackHandler() throws Exception {
+        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
+        TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Collections.singletonList("PLAIN"));
+        saslClientConfigs.put(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS, TestClientCallbackHandler.class.getName());
+        jaasConfig.setClientOptions("PLAIN", "", ""); // remove username, password in login context
+
+        Map<String, Object> options = new HashMap<>();
+        options.put("user_" + TestClientCallbackHandler.USERNAME, TestClientCallbackHandler.PASSWORD);
+        jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_SERVER, PlainLoginModule.class.getName(), options);
+        server = createEchoServer(securityProtocol);
+        createAndCheckClientConnection(securityProtocol, "good");
+
+        // NOTE(review): the server is not recreated here, so the failure below relies on the running server
+        // observing the updated server JAAS options (the same `options` map is mutated in place) - confirm.
+        options.clear();
+        options.put("user_" + TestClientCallbackHandler.USERNAME, "invalid-password");
+        jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_SERVER, PlainLoginModule.class.getName(), options);
+        createAndCheckClientConnectionFailure(securityProtocol, "invalid");
+    }
+
+    /**
+     * Tests SASL server authentication callback handler override.
+     */
+    @Test
+    public void testServerAuthenticateCallbackHandler() throws Exception {
+        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
+        TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Collections.singletonList("PLAIN"));
+        jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_SERVER, PlainLoginModule.class.getName(), new HashMap<String, Object>());
+        String callbackPrefix = ListenerName.forSecurityProtocol(securityProtocol).saslMechanismConfigPrefix("PLAIN");
+        saslServerConfigs.put(callbackPrefix + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS,
+                TestServerCallbackHandler.class.getName());
+        server = createEchoServer(securityProtocol);
+
+        // Set client username/password to the values used by `TestServerCallbackHandler`
+        jaasConfig.setClientOptions("PLAIN", TestServerCallbackHandler.USERNAME, TestServerCallbackHandler.PASSWORD);
+        createAndCheckClientConnection(securityProtocol, "good");
+
+        // Set client username/password to the invalid values
+        jaasConfig.setClientOptions("PLAIN", TestJaasConfig.USERNAME, "invalid-password");
+        createAndCheckClientConnectionFailure(securityProtocol, "invalid");
+    }
+
+    /**
+     * Test that callback handlers are only applied to connections for the mechanisms
+     * configured for the handler. Test enables two mechanisms `PLAIN` and `DIGEST-MD5`
+     * on the server with different callback handlers for the two mechanisms. Verifies
+     * that handlers configured without the listener prefix are not applied, and that clients
+     * using both mechanisms authenticate successfully with listener-prefixed handlers.
+     */
+    @Test
+    public void testAuthenticateCallbackHandlerMechanisms() throws Exception {
+        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
+        TestJaasConfig jaasConfig = configureMechanisms("DIGEST-MD5", Arrays.asList("DIGEST-MD5", "PLAIN"));
+
+        // Connections should fail using the digest callback handler if listener.mechanism prefix not specified
+        saslServerConfigs.put("plain." + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS,
+                TestServerCallbackHandler.class);
+        saslServerConfigs.put("digest-md5." + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS,
+                DigestServerCallbackHandler.class);
+        server = createEchoServer(securityProtocol);
+        createAndCheckClientConnectionFailure(securityProtocol, "invalid");
+
+        // Connections should succeed using the server callback handler associated with the listener
+        ListenerName listener = ListenerName.forSecurityProtocol(securityProtocol);
+        saslServerConfigs.remove("plain." + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS);
+        saslServerConfigs.remove("digest-md5." + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS);
+        saslServerConfigs.put(listener.saslMechanismConfigPrefix("plain") + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS,
+                TestServerCallbackHandler.class);
+        saslServerConfigs.put(listener.saslMechanismConfigPrefix("digest-md5") + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS,
+                DigestServerCallbackHandler.class);
+        server.close(); // close the first server before replacing it so it is not leaked
+        server = createEchoServer(securityProtocol);
+
+        // Verify that DIGEST-MD5 (currently configured for client) works with `DigestServerCallbackHandler`
+        createAndCheckClientConnection(securityProtocol, "good-digest-md5");
+
+        // Verify that PLAIN works with `TestServerCallbackHandler`
+        jaasConfig.setClientOptions("PLAIN", TestServerCallbackHandler.USERNAME, TestServerCallbackHandler.PASSWORD);
+        saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "PLAIN");
+        createAndCheckClientConnection(securityProtocol, "good-plain");
+    }
+
+    /**
+     * Tests SASL login class override.
+     */
+    @Test
+    public void testClientLoginOverride() throws Exception {
+        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
+        TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Collections.singletonList("PLAIN"));
+        jaasConfig.setClientOptions("PLAIN", "invaliduser", "invalidpassword");
+        server = createEchoServer(securityProtocol);
+
+        // Connection should succeed using login override that sets correct username/password in Subject
+        saslClientConfigs.put(SaslConfigs.SASL_LOGIN_CLASS, TestLogin.class.getName());
+        createAndCheckClientConnection(securityProtocol, "1");
+        assertEquals(1, TestLogin.loginCount.get());
+
+        // Connection should fail without login override since username/password in jaas config is invalid
+        saslClientConfigs.remove(SaslConfigs.SASL_LOGIN_CLASS);
+        createAndCheckClientConnectionFailure(securityProtocol, "invalid");
+        assertEquals(1, TestLogin.loginCount.get());
+    }
+
+    /**
+     * Tests SASL server login class override.
+     */
+    @Test
+    public void testServerLoginOverride() throws Exception {
+        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
+        configureMechanisms("PLAIN", Collections.singletonList("PLAIN"));
+        String prefix = ListenerName.forSecurityProtocol(securityProtocol).saslMechanismConfigPrefix("PLAIN");
+        saslServerConfigs.put(prefix + SaslConfigs.SASL_LOGIN_CLASS, TestLogin.class.getName());
+        server = createEchoServer(securityProtocol);
+
+        // Login is performed when server channel builder is created (before any connections are made on the server)
+        assertEquals(1, TestLogin.loginCount.get());
+
+        createAndCheckClientConnection(securityProtocol, "1");
+        assertEquals(1, TestLogin.loginCount.get());
+    }
+
+    /**
+     * Tests SASL login callback class override.
+     */
+    @Test
+    public void testClientLoginCallbackOverride() throws Exception {
+        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
+        TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Collections.singletonList("PLAIN"));
+        jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_CLIENT, TestPlainLoginModule.class.getName(),
+                Collections.<String, Object>emptyMap());
+        server = createEchoServer(securityProtocol);
+
+        // Connection should succeed using login callback override that sets correct username/password
+        saslClientConfigs.put(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, TestLoginCallbackHandler.class.getName());
+        createAndCheckClientConnection(securityProtocol, "1");
+
+        // Login should fail without the login callback override: the client JAAS entry has no options and
+        // TestPlainLoginModule obtains the username/password only from the login callback handler
+        saslClientConfigs.remove(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS);
+        try {
+            createClientConnection(securityProtocol, "invalid");
+            fail("Should have failed to create client connection without login callback override");
+        } catch (Exception e) {
+            assertTrue("Unexpected exception " + e.getCause(), e.getCause() instanceof LoginException);
+        }
+    }
+
+    /**
+     * Tests SASL server login callback class override.
+     */
+    @Test
+    public void testServerLoginCallbackOverride() throws Exception {
+        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
+        TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Collections.singletonList("PLAIN"));
+        jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_SERVER, TestPlainLoginModule.class.getName(),
+                Collections.<String, Object>emptyMap());
+        jaasConfig.setClientOptions("PLAIN", TestServerCallbackHandler.USERNAME, TestServerCallbackHandler.PASSWORD);
+        ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol);
+        String prefix = listenerName.saslMechanismConfigPrefix("PLAIN");
+        saslServerConfigs.put(prefix + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS,
+                TestServerCallbackHandler.class);
+        Class<?> loginCallback = TestLoginCallbackHandler.class;
+
+        try {
+            createEchoServer(securityProtocol);
+            fail("Should have failed to create server with default login handler");
+        } catch (KafkaException e) {
+            // Expected exception
+        }
+
+        try {
+            saslServerConfigs.put(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, loginCallback);
+            createEchoServer(securityProtocol);
+            fail("Should have failed to create server with login handler config without listener+mechanism prefix");
+        } catch (KafkaException e) {
+            // Expected exception
+            saslServerConfigs.remove(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS);
+        }
+
+        try {
+            saslServerConfigs.put("plain." + SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, loginCallback);
+            createEchoServer(securityProtocol);
+            fail("Should have failed to create server with login handler config without listener prefix");
+        } catch (KafkaException e) {
+            // Expected exception
+            saslServerConfigs.remove("plain." + SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS);
+        }
+
+        try {
+            saslServerConfigs.put(listenerName.configPrefix() + SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, loginCallback);
+            createEchoServer(securityProtocol);
+            fail("Should have failed to create server with login handler config without mechanism prefix");
+        } catch (KafkaException e) {
+            // Expected exception; remove the same listener-prefixed key that was added above
+            saslServerConfigs.remove(listenerName.configPrefix() + SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS);
+        }
+
+        // Connection should succeed using login callback override for mechanism
+        saslServerConfigs.put(prefix + SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, loginCallback);
+        server = createEchoServer(securityProtocol);
+        createAndCheckClientConnection(securityProtocol, "1");
+    }
+
     /**
      * Tests that mechanisms with default implementation in Kafka may be disabled in
      * the Kafka server by removing from the enabled mechanism list.
@@ -1028,10 +1247,12 @@ private NioEchoServer startServerWithoutSaslAuthenticateHeader(final SecurityPro
                 securityProtocol, listenerName, false, saslMechanism, true, credentialCache, null) {
 
             @Override
-            protected SaslServerAuthenticator buildServerAuthenticator(Map<String, ?> configs, String id,
-                            TransportLayer transportLayer, Map<String, Subject> subjects) throws IOException {
-                return new SaslServerAuthenticator(configs, id, jaasContexts, subjects, null,
-                                credentialCache, listenerName, securityProtocol, transportLayer, null) {
+            protected SaslServerAuthenticator buildServerAuthenticator(Map<String, ?> configs,
+                                                                       Map<String, AuthenticateCallbackHandler> callbackHandlers,
+                                                                       String id,
+                                                                       TransportLayer transportLayer,
+                                                                       Map<String, Subject> subjects) throws IOException {
+                return new SaslServerAuthenticator(configs, callbackHandlers, id, subjects, null, listenerName, securityProtocol, transportLayer) {
 
                     @Override
                     protected ApiVersionsResponse apiVersionsResponse() {
@@ -1072,11 +1293,15 @@ private void createClientConnectionWithoutSaslAuthenticateHeader(final SecurityP
                 securityProtocol, listenerName, false, saslMechanism, true, null, null) {
 
             @Override
-            protected SaslClientAuthenticator buildClientAuthenticator(Map<String, ?> configs, String id,
-                    String serverHost, String servicePrincipal,
-                    TransportLayer transportLayer, Subject subject) throws IOException {
-
-                return new SaslClientAuthenticator(configs, id, subject,
+            protected SaslClientAuthenticator buildClientAuthenticator(Map<String, ?> configs,
+                                                                       AuthenticateCallbackHandler callbackHandler,
+                                                                       String id,
+                                                                       String serverHost,
+                                                                       String servicePrincipal,
+                                                                       TransportLayer transportLayer,
+                                                                       Subject subject) throws IOException {
+
+                return new SaslClientAuthenticator(configs, callbackHandler, id, subject,
                         servicePrincipal, serverHost, saslMechanism, true, transportLayer) {
                     @Override
                     protected SaslHandshakeRequest createSaslHandshakeRequest(short version) {
@@ -1173,9 +1398,27 @@ private void authenticateUsingSaslPlainAndCheckConnection(String node, boolean e
     private TestJaasConfig configureMechanisms(String clientMechanism, List<String> serverMechanisms) {
         saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, clientMechanism);
         saslServerConfigs.put(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, serverMechanisms);
+        // NOTE(review): this key carries only the mechanism prefix. testAuthenticateCallbackHandlerMechanisms
+        // expects createEchoServer to ignore such keys, so tests using it with DIGEST-MD5 also call
+        // configureDigestMd5ServerCallback; presumably this entry serves channel builders configured
+        // directly from saslServerConfigs - confirm.
+        if (serverMechanisms.contains("DIGEST-MD5")) {
+            saslServerConfigs.put("digest-md5." + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS,
+                    TestDigestLoginModule.DigestServerCallbackHandler.class.getName());
+        }
         return TestJaasConfig.createConfiguration(clientMechanism, serverMechanisms);
     }
 
+    /**
+     * Registers {@link TestDigestLoginModule.DigestServerCallbackHandler} as the DIGEST-MD5 server
+     * callback handler, using the listener-and-mechanism config prefix of the listener name that
+     * {@link ListenerName#forSecurityProtocol} derives from {@code securityProtocol}.
+     */
+    private void configureDigestMd5ServerCallback(SecurityProtocol securityProtocol) {
+        String callbackPrefix = ListenerName.forSecurityProtocol(securityProtocol).saslMechanismConfigPrefix("DIGEST-MD5");
+        saslServerConfigs.put(callbackPrefix + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS,
+                TestDigestLoginModule.DigestServerCallbackHandler.class);
+    }
+
     private void createSelector(SecurityProtocol securityProtocol, Map<String, Object> clientConfigs) {
         if (selector != null) {
             selector.close();
@@ -1261,6 +1496,34 @@ private ByteBuffer waitForResponse() throws IOException {
         return selector.completedReceives().get(0).payload();
     }
 
+    /**
+     * PLAIN server callback handler whose {@code authenticate} accepts only {@link #USERNAME} and
+     * {@link #PASSWORD}, without consulting the server JAAS configuration. It fails fast if it is
+     * configured more than once or asked to authenticate before {@link #configure} is invoked.
+     */
+    public static class TestServerCallbackHandler extends PlainServerCallbackHandler {
+
+        static final String USERNAME = "TestServerCallbackHandler-user";
+        static final String PASSWORD = "TestServerCallbackHandler-password";
+        private volatile boolean configured;
+
+        @Override
+        public void configure(Map<String, ?> configs, String mechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+            if (configured)
+                throw new IllegalStateException("Server callback handler configured twice");
+            configured = true;
+            super.configure(configs, mechanism, jaasConfigEntries);
+        }
+
+        @Override
+        protected boolean authenticate(String username, char[] password) throws IOException {
+            if (!configured)
+                throw new IllegalStateException("Server callback handler not configured");
+            // Compare as char[] so the supplied password is never copied into an immutable String
+            return USERNAME.equals(username) && Arrays.equals(password, PASSWORD.toCharArray());
+        }
+    }
+
     @SuppressWarnings("unchecked")
     private void updateScramCredentialCache(String username, String password) throws NoSuchAlgorithmException {
         for (String mechanism : (List<String>) saslServerConfigs.get(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG)) {
@@ -1289,4 +1546,145 @@ private void updateTokenCredentialCache(String username, String password) throws
             }
         }
     }
+
+    /**
+     * Client callback handler that supplies {@link #USERNAME} and {@link #PASSWORD} for
+     * {@link NameCallback} and {@link PasswordCallback}, and rejects all other callbacks. It fails
+     * fast if configured more than once or invoked before {@link #configure}.
+     */
+    public static class TestClientCallbackHandler implements AuthenticateCallbackHandler {
+
+        static final String USERNAME = "TestClientCallbackHandler-user";
+        static final String PASSWORD = "TestClientCallbackHandler-password";
+        private volatile boolean configured;
+
+        @Override
+        public void configure(Map<String, ?> configs, String mechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+            if (configured)
+                throw new IllegalStateException("Client callback handler configured twice");
+            configured = true;
+        }
+
+        @Override
+        public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+            if (!configured)
+                throw new IllegalStateException("Client callback handler not configured");
+            for (Callback callback : callbacks) {
+                if (callback instanceof NameCallback)
+                    ((NameCallback) callback).setName(USERNAME);
+                else if (callback instanceof PasswordCallback)
+                    ((PasswordCallback) callback).setPassword(PASSWORD.toCharArray());
+                else
+                    throw new UnsupportedCallbackException(callback);
+            }
+        }
+
+        @Override
+        public void close() {
+        }
+    }
+
+    /**
+     * Login that performs a JAAS login for the configured context and then replaces the Subject's
+     * public and private credentials with {@link TestJaasConfig#USERNAME} and
+     * {@link TestJaasConfig#PASSWORD}. Successful logins are counted in {@link #loginCount}.
+     */
+    public static class TestLogin implements Login {
+
+        static final AtomicInteger loginCount = new AtomicInteger();
+
+        private String contextName;
+        private Configuration configuration;
+        private Subject subject;
+
+        @Override
+        public void configure(Map<String, ?> configs, String contextName, Configuration configuration,
+                              AuthenticateCallbackHandler callbackHandler) {
+            assertEquals(1, configuration.getAppConfigurationEntry(contextName).length);
+            this.contextName = contextName;
+            this.configuration = configuration;
+        }
+
+        @Override
+        public LoginContext login() throws LoginException {
+            LoginContext context = new LoginContext(contextName, null, new AbstractLogin.DefaultLoginCallbackHandler(), configuration);
+            context.login();
+            subject = context.getSubject();
+            subject.getPublicCredentials().clear();
+            subject.getPrivateCredentials().clear();
+            subject.getPublicCredentials().add(TestJaasConfig.USERNAME);
+            subject.getPrivateCredentials().add(TestJaasConfig.PASSWORD);
+            loginCount.incrementAndGet();
+            return context;
+        }
+
+        @Override
+        public Subject subject() {
+            return subject;
+        }
+
+        @Override
+        public String serviceName() {
+            return "kafka";
+        }
+
+        @Override
+        public void close() {
+        }
+    }
+
+    /**
+     * Login callback handler that supplies {@link TestJaasConfig#USERNAME} and
+     * {@link TestJaasConfig#PASSWORD} for {@link NameCallback} and {@link PasswordCallback}, and
+     * rejects all other callbacks with {@link UnsupportedCallbackException}.
+     */
+    public static class TestLoginCallbackHandler implements AuthenticateCallbackHandler {
+        private volatile boolean configured;
+
+        @Override
+        public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+            if (configured)
+                throw new IllegalStateException("Login callback handler configured twice");
+            configured = true;
+        }
+
+        @Override
+        public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+            if (!configured)
+                throw new IllegalStateException("Login callback handler not configured");
+
+            for (Callback callback : callbacks) {
+                if (callback instanceof NameCallback)
+                    ((NameCallback) callback).setName(TestJaasConfig.USERNAME);
+                else if (callback instanceof PasswordCallback)
+                    ((PasswordCallback) callback).setPassword(TestJaasConfig.PASSWORD.toCharArray());
+                else
+                    throw new UnsupportedCallbackException(callback);
+            }
+        }
+
+        @Override
+        public void close() {
+        }
+    }
+
+    /**
+     * PLAIN login module whose {@code initialize} obtains the username and password from the supplied
+     * callback handler and adds them to the Subject's public and private credentials. Any failure is
+     * rethrown as a {@link SaslAuthenticationException}.
+     */
+    public static final class TestPlainLoginModule extends PlainLoginModule {
+        @Override
+        public void initialize(Subject subject, CallbackHandler callbackHandler, Map<String, ?> sharedState, Map<String, ?> options) {
+            try {
+                NameCallback nameCallback = new NameCallback("name:");
+                PasswordCallback passwordCallback = new PasswordCallback("password:", false);
+                callbackHandler.handle(new Callback[]{nameCallback, passwordCallback});
+                subject.getPublicCredentials().add(nameCallback.getName());
+                subject.getPrivateCredentials().add(new String(passwordCallback.getPassword()));
+            } catch (Exception e) {
+                throw new SaslAuthenticationException("Login initialization failed", e);
+            }
+        }
+    }
 }
diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java
index 17d31bda72c..3ec30317f5e 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java
@@ -22,13 +22,12 @@
 import org.apache.kafka.common.network.ListenerName;
 import org.apache.kafka.common.network.TransportLayer;
 import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.auth.SecurityProtocol;
 import org.apache.kafka.common.protocol.types.Struct;
 import org.apache.kafka.common.requests.RequestHeader;
 import org.apache.kafka.common.security.JaasContext;
 import org.apache.kafka.common.security.plain.PlainLoginModule;
-import org.apache.kafka.common.security.scram.ScramMechanism;
-import org.apache.kafka.common.security.token.delegation.DelegationTokenCache;
 import org.easymock.Capture;
 import org.easymock.EasyMock;
 import org.easymock.IAnswer;
@@ -41,7 +40,7 @@
 import java.util.HashMap;
 import java.util.Map;
 
-import static org.apache.kafka.common.security.scram.ScramMechanism.SCRAM_SHA_256;
+import static org.apache.kafka.common.security.scram.internal.ScramMechanism.SCRAM_SHA_256;
 import static org.junit.Assert.fail;
 
 public class SaslServerAuthenticatorTest {
@@ -112,8 +111,10 @@ private SaslServerAuthenticator setupAuthenticator(Map<String, ?> configs, Trans
         Map<String, JaasContext> jaasContexts = Collections.singletonMap(mechanism,
                 new JaasContext("jaasContext", JaasContext.Type.SERVER, jaasConfig, null));
         Map<String, Subject> subjects = Collections.singletonMap(mechanism, new Subject());
-        return new SaslServerAuthenticator(configs, "node", jaasContexts, subjects, null, new CredentialCache(),
-                new ListenerName("ssl"), SecurityProtocol.SASL_SSL, transportLayer, new DelegationTokenCache(ScramMechanism.mechanismNames()));
+        Map<String, AuthenticateCallbackHandler> callbackHandlers = Collections.<String, AuthenticateCallbackHandler>singletonMap(
+                mechanism, new SaslServerCallbackHandler());
+        return new SaslServerAuthenticator(configs, callbackHandlers, "node", subjects, null,
+                new ListenerName("ssl"), SecurityProtocol.SASL_SSL, transportLayer);
     }
 
 }
diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestDigestLoginModule.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestDigestLoginModule.java
index f1ef740e5db..97b0b2715b2 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestDigestLoginModule.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestDigestLoginModule.java
@@ -17,62 +17,47 @@
 package org.apache.kafka.common.security.authenticator;
 
 import java.io.IOException;
-import java.security.Provider;
-import java.security.Security;
-import java.util.Arrays;
-import java.util.Enumeration;
-import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 import javax.security.auth.callback.Callback;
-import javax.security.auth.callback.CallbackHandler;
 import javax.security.auth.callback.NameCallback;
 import javax.security.auth.callback.PasswordCallback;
 import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.auth.login.AppConfigurationEntry;
 import javax.security.sasl.AuthorizeCallback;
 import javax.security.sasl.RealmCallback;
-import javax.security.sasl.Sasl;
-import javax.security.sasl.SaslException;
-import javax.security.sasl.SaslServer;
-import javax.security.sasl.SaslServerFactory;
 
+import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
 import org.apache.kafka.common.security.plain.PlainLoginModule;
 
 /**
- * Digest-MD5 login module for multi-mechanism tests. Since callback handlers are not configurable in Kafka
- * yet, this replaces the standard Digest-MD5 SASL server provider with one that invokes the test callback handler.
+ * Digest-MD5 login module for multi-mechanism tests.
  * This login module uses the same format as PlainLoginModule and hence simply reuses the same methods.
  *
  */
 public class TestDigestLoginModule extends PlainLoginModule {
 
-    private static final SaslServerFactory STANDARD_DIGEST_SASL_SERVER_FACTORY;
-    static {
-        SaslServerFactory digestSaslServerFactory = null;
-        Enumeration<SaslServerFactory> factories = Sasl.getSaslServerFactories();
-        Map<String, Object> emptyProps = new HashMap<>();
-        while (factories.hasMoreElements()) {
-            SaslServerFactory factory = factories.nextElement();
-            if (Arrays.asList(factory.getMechanismNames(emptyProps)).contains("DIGEST-MD5")) {
-                digestSaslServerFactory = factory;
-                break;
-            }
-        }
-        STANDARD_DIGEST_SASL_SERVER_FACTORY = digestSaslServerFactory;
-        Security.insertProviderAt(new DigestSaslServerProvider(), 1);
-    }
+    public static class DigestServerCallbackHandler implements AuthenticateCallbackHandler {
 
-    public static class DigestServerCallbackHandler implements CallbackHandler {
+        @Override
+        public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+        }
 
         @Override
         public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+            String username = null;
             for (Callback callback : callbacks) {
                 if (callback instanceof NameCallback) {
                     NameCallback nameCallback = (NameCallback) callback;
-                    nameCallback.setName(nameCallback.getDefaultName());
+                    if (TestJaasConfig.USERNAME.equals(nameCallback.getDefaultName())) {
+                        nameCallback.setName(nameCallback.getDefaultName());
+                        username = TestJaasConfig.USERNAME;
+                    }
                 } else if (callback instanceof PasswordCallback) {
                     PasswordCallback passwordCallback = (PasswordCallback) callback;
-                    passwordCallback.setPassword(TestJaasConfig.PASSWORD.toCharArray());
+                    if (TestJaasConfig.USERNAME.equals(username))
+                        passwordCallback.setPassword(TestJaasConfig.PASSWORD.toCharArray());
                 } else if (callback instanceof RealmCallback) {
                     RealmCallback realmCallback = (RealmCallback) callback;
                     realmCallback.setText(realmCallback.getDefaultText());
@@ -85,30 +70,9 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback
                 }
             }
         }
-    }
-
-    public static class DigestSaslServerFactory implements SaslServerFactory {
-
-        @Override
-        public SaslServer createSaslServer(String mechanism, String protocol, String serverName, Map<String, ?> props, CallbackHandler cbh)
-                throws SaslException {
-            return STANDARD_DIGEST_SASL_SERVER_FACTORY.createSaslServer(mechanism, protocol, serverName, props, new DigestServerCallbackHandler());
-        }
 
         @Override
-        public String[] getMechanismNames(Map<String, ?> props) {
-            return new String[] {"DIGEST-MD5"};
-        }
-    }
-
-    public static class DigestSaslServerProvider extends Provider {
-
-        private static final long serialVersionUID = 1L;
-
-        @SuppressWarnings("deprecation")
-        protected DigestSaslServerProvider() {
-            super("Test SASL/Digest-MD5 Server Provider", 1.0, "Test SASL/Digest-MD5 Server Provider for Kafka");
-            put("SaslServerFactory.DIGEST-MD5", TestDigestLoginModule.DigestSaslServerFactory.class.getName());
+        public void close() {
         }
     }
 }
diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestJaasConfig.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestJaasConfig.java
index dafa79db98f..3ee7c2ca4d7 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestJaasConfig.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestJaasConfig.java
@@ -28,7 +28,7 @@
 import org.apache.kafka.common.config.types.Password;
 import org.apache.kafka.common.security.plain.PlainLoginModule;
 import org.apache.kafka.common.security.scram.ScramLoginModule;
-import org.apache.kafka.common.security.scram.ScramMechanism;
+import org.apache.kafka.common.security.scram.internal.ScramMechanism;
 
 public class TestJaasConfig extends Configuration {
 
diff --git a/clients/src/test/java/org/apache/kafka/common/security/plain/PlainSaslServerTest.java b/clients/src/test/java/org/apache/kafka/common/security/plain/internal/PlainSaslServerTest.java
similarity index 89%
rename from clients/src/test/java/org/apache/kafka/common/security/plain/PlainSaslServerTest.java
rename to clients/src/test/java/org/apache/kafka/common/security/plain/internal/PlainSaslServerTest.java
index 86baf3e07af..1410c8accbc 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/plain/PlainSaslServerTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/plain/internal/PlainSaslServerTest.java
@@ -14,8 +14,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.plain;
+package org.apache.kafka.common.security.plain.internal;
 
+import org.apache.kafka.common.security.plain.PlainLoginModule;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -46,7 +47,9 @@ public void setUp() throws Exception {
         options.put("user_" + USER_B, PASSWORD_B);
         jaasConfig.addEntry("jaasContext", PlainLoginModule.class.getName(), options);
         JaasContext jaasContext = new JaasContext("jaasContext", JaasContext.Type.SERVER, jaasConfig, null);
-        saslServer = new PlainSaslServer(jaasContext);
+        PlainServerCallbackHandler callbackHandler = new PlainServerCallbackHandler();
+        callbackHandler.configure(null, "PLAIN", jaasContext.configurationEntries());
+        saslServer = new PlainSaslServer(callbackHandler);
     }
 
     @Test
diff --git a/clients/src/test/java/org/apache/kafka/common/security/scram/ScramCredentialUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramCredentialUtilsTest.java
similarity index 97%
rename from clients/src/test/java/org/apache/kafka/common/security/scram/ScramCredentialUtilsTest.java
rename to clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramCredentialUtilsTest.java
index e9dd285f54e..a1a1d20d63e 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/scram/ScramCredentialUtilsTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramCredentialUtilsTest.java
@@ -14,23 +14,23 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
-import org.junit.Test;
+import java.security.NoSuchAlgorithmException;
+import java.util.Arrays;
 
+import org.apache.kafka.common.security.authenticator.CredentialCache;
+import org.apache.kafka.common.security.scram.ScramCredential;
+
+import org.junit.Before;
+import org.junit.Test;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
-
-import java.security.NoSuchAlgorithmException;
-import java.util.Arrays;
-
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertTrue;
 
-import org.apache.kafka.common.security.authenticator.CredentialCache;
-import org.junit.Before;
 
 public class ScramCredentialUtilsTest {
 
diff --git a/clients/src/test/java/org/apache/kafka/common/security/scram/ScramFormatterTest.java b/clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramFormatterTest.java
similarity index 91%
rename from clients/src/test/java/org/apache/kafka/common/security/scram/ScramFormatterTest.java
rename to clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramFormatterTest.java
index a86e0ddb800..b06b039768a 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/scram/ScramFormatterTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramFormatterTest.java
@@ -14,19 +14,18 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
 import org.apache.kafka.common.utils.Base64;
-import org.junit.Test;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFirstMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFirstMessage;
 
+import org.junit.Test;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFirstMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFirstMessage;
-
 public class ScramFormatterTest {
 
     /**
diff --git a/clients/src/test/java/org/apache/kafka/common/security/scram/ScramMessagesTest.java b/clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramMessagesTest.java
similarity index 96%
rename from clients/src/test/java/org/apache/kafka/common/security/scram/ScramMessagesTest.java
rename to clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramMessagesTest.java
index 7b04ede6e78..d856f373a78 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/scram/ScramMessagesTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramMessagesTest.java
@@ -14,29 +14,28 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
-
-import org.apache.kafka.common.utils.Base64;
-import org.junit.Before;
-import org.junit.Test;
+package org.apache.kafka.common.security.scram.internal;
 
 import java.nio.charset.StandardCharsets;
 import java.util.Collections;
 
 import javax.security.sasl.SaslException;
 
+import org.apache.kafka.common.security.scram.internal.ScramMessages.AbstractScramMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ClientFirstMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFinalMessage;
+import org.apache.kafka.common.security.scram.internal.ScramMessages.ServerFirstMessage;
+import org.apache.kafka.common.utils.Base64;
+
+import org.junit.Before;
+import org.junit.Test;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
-import org.apache.kafka.common.security.scram.ScramMessages.AbstractScramMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ClientFirstMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFinalMessage;
-import org.apache.kafka.common.security.scram.ScramMessages.ServerFirstMessage;
-
 public class ScramMessagesTest {
 
     private static final String[] VALID_EXTENSIONS = {
diff --git a/clients/src/test/java/org/apache/kafka/common/security/scram/ScramSaslServerTest.java b/clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramSaslServerTest.java
similarity index 96%
rename from clients/src/test/java/org/apache/kafka/common/security/scram/ScramSaslServerTest.java
rename to clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramSaslServerTest.java
index 82ad91415be..3c4b82d7921 100644
--- a/clients/src/test/java/org/apache/kafka/common/security/scram/ScramSaslServerTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/security/scram/internal/ScramSaslServerTest.java
@@ -14,19 +14,20 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.kafka.common.security.scram;
+package org.apache.kafka.common.security.scram.internal;
 
-import org.apache.kafka.common.security.token.delegation.DelegationTokenCache;
-import org.junit.Before;
-import org.junit.Test;
 
 import java.nio.charset.StandardCharsets;
 import java.util.HashMap;
 
-import static org.junit.Assert.assertTrue;
-
 import org.apache.kafka.common.errors.SaslAuthenticationException;
 import org.apache.kafka.common.security.authenticator.CredentialCache;
+import org.apache.kafka.common.security.scram.ScramCredential;
+import org.apache.kafka.common.security.token.delegation.DelegationTokenCache;
+
+import org.junit.Before;
+import org.junit.Test;
+import static org.junit.Assert.assertTrue;
 
 public class ScramSaslServerTest {
 
diff --git a/core/src/main/scala/kafka/admin/ConfigCommand.scala b/core/src/main/scala/kafka/admin/ConfigCommand.scala
index 3563448feb6..c19599d3cc0 100644
--- a/core/src/main/scala/kafka/admin/ConfigCommand.scala
+++ b/core/src/main/scala/kafka/admin/ConfigCommand.scala
@@ -32,7 +32,7 @@ import org.apache.kafka.clients.CommonClientConfigs
 import org.apache.kafka.clients.admin.{AlterConfigsOptions, ConfigEntry, DescribeConfigsOptions, AdminClient => JAdminClient, Config => JConfig}
 import org.apache.kafka.common.config.ConfigResource
 import org.apache.kafka.common.security.JaasUtils
-import org.apache.kafka.common.security.scram._
+import org.apache.kafka.common.security.scram.internal.{ScramCredentialUtils, ScramFormatter, ScramMechanism}
 import org.apache.kafka.common.utils.{Sanitizer, Time, Utils}
 
 import scala.collection._
diff --git a/core/src/main/scala/kafka/security/CredentialProvider.scala b/core/src/main/scala/kafka/security/CredentialProvider.scala
index 0e7ebb672b1..6f9c2527735 100644
--- a/core/src/main/scala/kafka/security/CredentialProvider.scala
+++ b/core/src/main/scala/kafka/security/CredentialProvider.scala
@@ -20,9 +20,10 @@ package kafka.security
 import java.util.{Collection, Properties}
 
 import org.apache.kafka.common.security.authenticator.CredentialCache
-import org.apache.kafka.common.security.scram.{ScramCredential, ScramCredentialUtils, ScramMechanism}
+import org.apache.kafka.common.security.scram.ScramCredential
 import org.apache.kafka.common.config.ConfigDef
 import org.apache.kafka.common.config.ConfigDef._
+import org.apache.kafka.common.security.scram.internal.{ScramCredentialUtils, ScramMechanism}
 import org.apache.kafka.common.security.token.delegation.DelegationTokenCache
 
 class CredentialProvider(scramMechanisms: Collection[String], val tokenCache: DelegationTokenCache) {
diff --git a/core/src/main/scala/kafka/server/DelegationTokenManager.scala b/core/src/main/scala/kafka/server/DelegationTokenManager.scala
index 008dc323d7a..4a947a17fcf 100644
--- a/core/src/main/scala/kafka/server/DelegationTokenManager.scala
+++ b/core/src/main/scala/kafka/server/DelegationTokenManager.scala
@@ -29,7 +29,8 @@ import kafka.utils.{CoreUtils, Json, Logging}
 import kafka.zk.{DelegationTokenChangeNotificationSequenceZNode, DelegationTokenChangeNotificationZNode, DelegationTokensZNode, KafkaZkClient}
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.security.auth.KafkaPrincipal
-import org.apache.kafka.common.security.scram.{ScramCredential, ScramFormatter, ScramMechanism}
+import org.apache.kafka.common.security.scram.internal.{ScramFormatter, ScramMechanism}
+import org.apache.kafka.common.security.scram.ScramCredential
 import org.apache.kafka.common.security.token.delegation.{DelegationToken, DelegationTokenCache, TokenInformation}
 import org.apache.kafka.common.utils.{Base64, Sanitizer, SecurityUtils, Time}
 
diff --git a/core/src/main/scala/kafka/server/DynamicConfigManager.scala b/core/src/main/scala/kafka/server/DynamicConfigManager.scala
index 728f88ca8f7..d56b46fe992 100644
--- a/core/src/main/scala/kafka/server/DynamicConfigManager.scala
+++ b/core/src/main/scala/kafka/server/DynamicConfigManager.scala
@@ -22,9 +22,9 @@ import java.nio.charset.StandardCharsets
 import kafka.common.{NotificationHandler, ZkNodeChangeNotificationListener}
 import kafka.utils.{Json, Logging}
 import kafka.utils.json.JsonObject
-import kafka.zk.{KafkaZkClient, AdminZkClient, ConfigEntityChangeNotificationZNode, ConfigEntityChangeNotificationSequenceZNode}
+import kafka.zk.{AdminZkClient, ConfigEntityChangeNotificationSequenceZNode, ConfigEntityChangeNotificationZNode, KafkaZkClient}
 import org.apache.kafka.common.config.types.Password
-import org.apache.kafka.common.security.scram.ScramMechanism
+import org.apache.kafka.common.security.scram.internal.ScramMechanism
 import org.apache.kafka.common.utils.Time
 
 import scala.collection.JavaConverters._
diff --git a/core/src/main/scala/kafka/server/KafkaConfig.scala b/core/src/main/scala/kafka/server/KafkaConfig.scala
index cf22305caf7..9d7e64ce027 100755
--- a/core/src/main/scala/kafka/server/KafkaConfig.scala
+++ b/core/src/main/scala/kafka/server/KafkaConfig.scala
@@ -425,6 +425,10 @@ object KafkaConfig {
   val SaslMechanismInterBrokerProtocolProp = "sasl.mechanism.inter.broker.protocol"
   val SaslJaasConfigProp = SaslConfigs.SASL_JAAS_CONFIG
   val SaslEnabledMechanismsProp = BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG
+  val SaslServerCallbackHandlerClassProp = BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS
+  val SaslClientCallbackHandlerClassProp = SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS
+  val SaslLoginClassProp = SaslConfigs.SASL_LOGIN_CLASS
+  val SaslLoginCallbackHandlerClassProp = SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS
   val SaslKerberosServiceNameProp = SaslConfigs.SASL_KERBEROS_SERVICE_NAME
   val SaslKerberosKinitCmdProp = SaslConfigs.SASL_KERBEROS_KINIT_CMD
   val SaslKerberosTicketRenewWindowFactorProp = SaslConfigs.SASL_KERBEROS_TICKET_RENEW_WINDOW_FACTOR
@@ -713,7 +717,11 @@ object KafkaConfig {
   /** ********* Sasl Configuration ****************/
   val SaslMechanismInterBrokerProtocolDoc = "SASL mechanism used for inter-broker communication. Default is GSSAPI."
   val SaslJaasConfigDoc = SaslConfigs.SASL_JAAS_CONFIG_DOC
-  val SaslEnabledMechanismsDoc = SaslConfigs.SASL_ENABLED_MECHANISMS_DOC
+  val SaslEnabledMechanismsDoc = BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_DOC
+  val SaslServerCallbackHandlerClassDoc = BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS_DOC
+  val SaslClientCallbackHandlerClassDoc = SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS_DOC
+  val SaslLoginClassDoc = SaslConfigs.SASL_LOGIN_CLASS_DOC
+  val SaslLoginCallbackHandlerClassDoc = SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS_DOC
   val SaslKerberosServiceNameDoc = SaslConfigs.SASL_KERBEROS_SERVICE_NAME_DOC
   val SaslKerberosKinitCmdDoc = SaslConfigs.SASL_KERBEROS_KINIT_CMD_DOC
   val SaslKerberosTicketRenewWindowFactorDoc = SaslConfigs.SASL_KERBEROS_TICKET_RENEW_WINDOW_FACTOR_DOC
@@ -937,6 +945,10 @@ object KafkaConfig {
       .define(SaslMechanismInterBrokerProtocolProp, STRING, Defaults.SaslMechanismInterBrokerProtocol, MEDIUM, SaslMechanismInterBrokerProtocolDoc)
       .define(SaslJaasConfigProp, PASSWORD, null, MEDIUM, SaslJaasConfigDoc)
       .define(SaslEnabledMechanismsProp, LIST, Defaults.SaslEnabledMechanisms, MEDIUM, SaslEnabledMechanismsDoc)
+      .define(SaslServerCallbackHandlerClassProp, CLASS, null, MEDIUM, SaslServerCallbackHandlerClassDoc)
+      .define(SaslClientCallbackHandlerClassProp, CLASS, null, MEDIUM, SaslClientCallbackHandlerClassDoc)
+      .define(SaslLoginClassProp, CLASS, null, MEDIUM, SaslLoginClassDoc)
+      .define(SaslLoginCallbackHandlerClassProp, CLASS, null, MEDIUM, SaslLoginCallbackHandlerClassDoc)
       .define(SaslKerberosServiceNameProp, STRING, null, MEDIUM, SaslKerberosServiceNameDoc)
       .define(SaslKerberosKinitCmdProp, STRING, Defaults.SaslKerberosKinitCmd, MEDIUM, SaslKerberosKinitCmdDoc)
       .define(SaslKerberosTicketRenewWindowFactorProp, DOUBLE, Defaults.SaslKerberosTicketRenewWindowFactor, MEDIUM, SaslKerberosTicketRenewWindowFactorDoc)
diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala
index d7ca65658f6..53632cd6c75 100755
--- a/core/src/main/scala/kafka/server/KafkaServer.scala
+++ b/core/src/main/scala/kafka/server/KafkaServer.scala
@@ -44,7 +44,7 @@ import org.apache.kafka.common.network._
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.requests.{ControlledShutdownRequest, ControlledShutdownResponse}
 import org.apache.kafka.common.security.auth.SecurityProtocol
-import org.apache.kafka.common.security.scram.ScramMechanism
+import org.apache.kafka.common.security.scram.internal.ScramMechanism
 import org.apache.kafka.common.security.token.delegation.DelegationTokenCache
 import org.apache.kafka.common.security.{JaasContext, JaasUtils}
 import org.apache.kafka.common.utils.{AppInfoParser, LogContext, Time}
diff --git a/core/src/main/scala/kafka/utils/VerifiableProperties.scala b/core/src/main/scala/kafka/utils/VerifiableProperties.scala
index de4f654c1ee..5d70db5e11d 100755
--- a/core/src/main/scala/kafka/utils/VerifiableProperties.scala
+++ b/core/src/main/scala/kafka/utils/VerifiableProperties.scala
@@ -122,7 +122,7 @@ class VerifiableProperties(val props: Properties) extends Logging {
     require(v >= range._1 && v <= range._2, name + " has value " + v + " which is not in the range " + range + ".")
     v
   }
-  
+
   /**
    * Get a required argument as a double
    * @param name The property name
@@ -130,7 +130,7 @@ class VerifiableProperties(val props: Properties) extends Logging {
    * @throws IllegalArgumentException If the given property is not present
    */
   def getDouble(name: String): Double = getString(name).toDouble
-  
+
   /**
    * Get an optional argument as a double
    * @param name The property name
@@ -141,7 +141,7 @@ class VerifiableProperties(val props: Properties) extends Logging {
       getDouble(name)
     else
       default
-  } 
+  }
 
   /**
    * Read a boolean value from the properties instance
@@ -158,7 +158,7 @@ class VerifiableProperties(val props: Properties) extends Logging {
       v.toBoolean
     }
   }
-  
+
   def getBoolean(name: String) = getString(name).toBoolean
 
   /**
@@ -178,7 +178,7 @@ class VerifiableProperties(val props: Properties) extends Logging {
     require(containsKey(name), "Missing required property '" + name + "'")
     getProperty(name)
   }
-  
+
   /**
    * Get a Map[String, String] from a property list in the form k1:v2, k2:v2, ...
    */
diff --git a/core/src/test/scala/integration/kafka/api/DelegationTokenEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/DelegationTokenEndToEndAuthorizationTest.scala
index 24f435e9fb1..27a6d1127d5 100644
--- a/core/src/test/scala/integration/kafka/api/DelegationTokenEndToEndAuthorizationTest.scala
+++ b/core/src/test/scala/integration/kafka/api/DelegationTokenEndToEndAuthorizationTest.scala
@@ -24,7 +24,7 @@ import kafka.utils.{JaasTestUtils, TestUtils, ZkUtils}
 import org.apache.kafka.clients.admin.AdminClientConfig
 import org.apache.kafka.common.config.SaslConfigs
 import org.apache.kafka.common.security.auth.SecurityProtocol
-import org.apache.kafka.common.security.scram.ScramMechanism
+import org.apache.kafka.common.security.scram.internal.ScramMechanism
 import org.apache.kafka.common.security.token.delegation.DelegationToken
 import org.junit.Before
 
diff --git a/core/src/test/scala/integration/kafka/api/SaslEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/SaslEndToEndAuthorizationTest.scala
index f5409b1e47b..a5bf33171a4 100644
--- a/core/src/test/scala/integration/kafka/api/SaslEndToEndAuthorizationTest.scala
+++ b/core/src/test/scala/integration/kafka/api/SaslEndToEndAuthorizationTest.scala
@@ -63,6 +63,7 @@ abstract class SaslEndToEndAuthorizationTest extends EndToEndAuthorizationTest {
     consumer2Config ++= consumerConfig
     // consumer2 retrieves its credentials from the static JAAS configuration, so we test also this path
     consumer2Config.remove(SaslConfigs.SASL_JAAS_CONFIG)
+    consumer2Config.remove(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS)
 
     val consumer2 = TestUtils.createNewConsumer(brokerList,
                                                 securityProtocol = securityProtocol,
diff --git a/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala
index 08351aa0c7a..efb8c48c7c4 100644
--- a/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala
+++ b/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala
@@ -16,22 +16,33 @@
   */
 package kafka.api
 
-import kafka.utils.{CoreUtils, JaasTestUtils, TestUtils, ZkUtils}
+import java.security.AccessController
+import javax.security.auth.callback._
+import javax.security.auth.Subject
+import javax.security.auth.login.AppConfigurationEntry
+
+import kafka.server.KafkaConfig
+import kafka.utils.{CoreUtils, TestUtils, ZkUtils}
+import kafka.utils.JaasTestUtils._
+import org.apache.kafka.common.config.SaslConfigs
 import org.apache.kafka.common.config.internals.BrokerSecurityConfigs
+import org.apache.kafka.common.network.ListenerName
 import org.apache.kafka.common.security.JaasUtils
-import org.apache.kafka.common.security.auth.{AuthenticationContext, KafkaPrincipal, KafkaPrincipalBuilder, SaslAuthenticationContext}
+import org.apache.kafka.common.security.auth._
+import org.apache.kafka.common.security.plain.PlainAuthenticateCallback
 import org.junit.Test
 
 object SaslPlainSslEndToEndAuthorizationTest {
+
   class TestPrincipalBuilder extends KafkaPrincipalBuilder {
 
     override def build(context: AuthenticationContext): KafkaPrincipal = {
       context match {
         case ctx: SaslAuthenticationContext =>
           ctx.server.getAuthorizationID match {
-            case JaasTestUtils.KafkaPlainAdmin =>
+            case KafkaPlainAdmin =>
               new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "admin")
-            case JaasTestUtils.KafkaPlainUser =>
+            case KafkaPlainUser =>
               new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "user")
             case _ =>
               KafkaPrincipal.ANONYMOUS
@@ -39,18 +50,84 @@ object SaslPlainSslEndToEndAuthorizationTest {
       }
     }
   }
+
+  object Credentials {
+    val allUsers = Map(KafkaPlainUser -> "user1-password",
+      KafkaPlainUser2 -> KafkaPlainPassword2,
+      KafkaPlainAdmin -> "broker-password")
+  }
+
+  class TestServerCallbackHandler extends AuthenticateCallbackHandler {
+    def configure(configs: java.util.Map[String, _], saslMechanism: String, jaasConfigEntries: java.util.List[AppConfigurationEntry]) {}
+    def handle(callbacks: Array[Callback]) {
+      var username: String = null
+      for (callback <- callbacks) {
+        if (callback.isInstanceOf[NameCallback])
+          username = callback.asInstanceOf[NameCallback].getDefaultName
+        else if (callback.isInstanceOf[PlainAuthenticateCallback]) {
+          val plainCallback = callback.asInstanceOf[PlainAuthenticateCallback]
+          plainCallback.authenticated(Credentials.allUsers(username) == new String(plainCallback.password))
+        } else
+          throw new UnsupportedCallbackException(callback)
+      }
+    }
+    def close() {}
+  }
+
+  class TestClientCallbackHandler extends AuthenticateCallbackHandler {
+    def configure(configs: java.util.Map[String, _], saslMechanism: String, jaasConfigEntries: java.util.List[AppConfigurationEntry]) {}
+    def handle(callbacks: Array[Callback]) {
+      val subject = Subject.getSubject(AccessController.getContext())
+      val username = subject.getPublicCredentials(classOf[String]).iterator().next()
+      for (callback <- callbacks) {
+        if (callback.isInstanceOf[NameCallback])
+          callback.asInstanceOf[NameCallback].setName(username)
+        else if (callback.isInstanceOf[PasswordCallback]) {
+          if (username == KafkaPlainUser || username == KafkaPlainAdmin)
+            callback.asInstanceOf[PasswordCallback].setPassword(Credentials.allUsers(username).toCharArray)
+        } else
+          throw new UnsupportedCallbackException(callback)
+      }
+    }
+    def close() {}
+  }
 }
 
+
+// This test uses SASL callback handler overrides for server connections of Kafka broker
+// and client connections of Kafka producers and consumers. Client connections from Kafka brokers
+// used for inter-broker communication also use custom callback handlers. The second client used in
+// the multi-user test SaslEndToEndAuthorizationTest#testTwoConsumersWithDifferentSaslCredentials uses
+// static JAAS configuration with default callback handlers to test those code paths as well.
 class SaslPlainSslEndToEndAuthorizationTest extends SaslEndToEndAuthorizationTest {
-  import SaslPlainSslEndToEndAuthorizationTest.TestPrincipalBuilder
+  import SaslPlainSslEndToEndAuthorizationTest._
 
   this.serverConfig.setProperty(BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, classOf[TestPrincipalBuilder].getName)
+  this.serverConfig.put(KafkaConfig.SaslClientCallbackHandlerClassProp, classOf[TestClientCallbackHandler].getName)
+  val mechanismPrefix = ListenerName.forSecurityProtocol(SecurityProtocol.SASL_SSL).saslMechanismConfigPrefix("PLAIN")
+  this.serverConfig.put(s"$mechanismPrefix${KafkaConfig.SaslServerCallbackHandlerClassProp}", classOf[TestServerCallbackHandler].getName)
+  this.producerConfig.put(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS, classOf[TestClientCallbackHandler].getName)
+  this.consumerConfig.put(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS, classOf[TestClientCallbackHandler].getName)
+  private val plainLogin = s"org.apache.kafka.common.security.plain.PlainLoginModule username=$KafkaPlainUser required;"
+  this.producerConfig.put(SaslConfigs.SASL_JAAS_CONFIG, plainLogin)
+  this.consumerConfig.put(SaslConfigs.SASL_JAAS_CONFIG, plainLogin)
 
   override protected def kafkaClientSaslMechanism = "PLAIN"
   override protected def kafkaServerSaslMechanisms = List("PLAIN")
+
   override val clientPrincipal = "user"
   override val kafkaPrincipal = "admin"
 
+  override def jaasSections(kafkaServerSaslMechanisms: Seq[String],
+                            kafkaClientSaslMechanism: Option[String],
+                            mode: SaslSetupMode,
+                            kafkaServerEntryName: String): Seq[JaasSection] = {
+    val brokerLogin = new PlainLoginModule(KafkaPlainAdmin, "") // Password provided by callback handler
+    val clientLogin = new PlainLoginModule(KafkaPlainUser2, KafkaPlainPassword2)
+    Seq(JaasSection(kafkaServerEntryName, Seq(brokerLogin)),
+      JaasSection(KafkaClientContextName, Seq(clientLogin))) ++ zkSections
+  }
+
   /**
    * Checks that secure paths created by broker and acl paths created by AclCommand
    * have expected ACLs.
diff --git a/core/src/test/scala/integration/kafka/api/SaslScramSslEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/SaslScramSslEndToEndAuthorizationTest.scala
index 000fc219b49..d304ffc1cae 100644
--- a/core/src/test/scala/integration/kafka/api/SaslScramSslEndToEndAuthorizationTest.scala
+++ b/core/src/test/scala/integration/kafka/api/SaslScramSslEndToEndAuthorizationTest.scala
@@ -16,9 +16,9 @@
   */
 package kafka.api
 
-import org.apache.kafka.common.security.scram.ScramMechanism
 import kafka.utils.JaasTestUtils
 import kafka.zk.ConfigEntityChangeNotificationZNode
+import org.apache.kafka.common.security.scram.internal.ScramMechanism
 
 import scala.collection.JavaConverters._
 import org.junit.Before
diff --git a/core/src/test/scala/integration/kafka/api/SaslSetup.scala b/core/src/test/scala/integration/kafka/api/SaslSetup.scala
index 273b24759e9..ab2819e0ed6 100644
--- a/core/src/test/scala/integration/kafka/api/SaslSetup.scala
+++ b/core/src/test/scala/integration/kafka/api/SaslSetup.scala
@@ -30,7 +30,7 @@ import org.apache.kafka.common.config.SaslConfigs
 import org.apache.kafka.common.config.internals.BrokerSecurityConfigs
 import org.apache.kafka.common.security.JaasUtils
 import org.apache.kafka.common.security.authenticator.LoginManager
-import org.apache.kafka.common.security.scram.ScramMechanism
+import org.apache.kafka.common.security.scram.internal.ScramMechanism
 
 /*
  * Implements an enumeration for the modes enabled here:
diff --git a/core/src/test/scala/unit/kafka/admin/ConfigCommandTest.scala b/core/src/test/scala/unit/kafka/admin/ConfigCommandTest.scala
index a17f060996b..66e98f5b1ce 100644
--- a/core/src/test/scala/unit/kafka/admin/ConfigCommandTest.scala
+++ b/core/src/test/scala/unit/kafka/admin/ConfigCommandTest.scala
@@ -28,7 +28,7 @@ import org.apache.kafka.clients.admin._
 import org.apache.kafka.common.config.ConfigResource
 import org.apache.kafka.common.internals.KafkaFutureImpl
 import org.apache.kafka.common.Node
-import org.apache.kafka.common.security.scram.ScramCredentialUtils
+import org.apache.kafka.common.security.scram.internal.ScramCredentialUtils
 import org.apache.kafka.common.utils.Sanitizer
 import org.easymock.EasyMock
 import org.junit.Assert._
diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
index a057e54bd7e..28e2ec09cee 100644
--- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
+++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
@@ -38,7 +38,7 @@ import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.record.MemoryRecords
 import org.apache.kafka.common.requests.{AbstractRequest, ProduceRequest, RequestHeader}
 import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol}
-import org.apache.kafka.common.security.scram.ScramMechanism
+import org.apache.kafka.common.security.scram.internal.ScramMechanism
 import org.apache.kafka.common.utils.{LogContext, MockTime, Time}
 import org.apache.log4j.Level
 import org.junit.Assert._
diff --git a/core/src/test/scala/unit/kafka/security/token/delegation/DelegationTokenManagerTest.scala b/core/src/test/scala/unit/kafka/security/token/delegation/DelegationTokenManagerTest.scala
index 8c035482ee2..b8388b4cb49 100644
--- a/core/src/test/scala/unit/kafka/security/token/delegation/DelegationTokenManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/security/token/delegation/DelegationTokenManagerTest.scala
@@ -29,7 +29,7 @@ import kafka.utils.TestUtils
 import kafka.zk.ZooKeeperTestHarness
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.security.auth.KafkaPrincipal
-import org.apache.kafka.common.security.scram.ScramMechanism
+import org.apache.kafka.common.security.scram.internal.ScramMechanism
 import org.apache.kafka.common.security.token.delegation.{DelegationToken, DelegationTokenCache, TokenInformation}
 import org.apache.kafka.common.utils.{MockTime, SecurityUtils}
 import org.junit.Assert._
diff --git a/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala b/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala
index 0213c12814e..81470c0b412 100755
--- a/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala
@@ -692,6 +692,10 @@ class KafkaConfigTest {
         //Sasl Configs
         case KafkaConfig.SaslMechanismInterBrokerProtocolProp => // ignore
         case KafkaConfig.SaslEnabledMechanismsProp =>
+        case KafkaConfig.SaslClientCallbackHandlerClassProp =>
+        case KafkaConfig.SaslServerCallbackHandlerClassProp =>
+        case KafkaConfig.SaslLoginClassProp =>
+        case KafkaConfig.SaslLoginCallbackHandlerClassProp =>
         case KafkaConfig.SaslKerberosServiceNameProp => // ignore string
         case KafkaConfig.SaslKerberosKinitCmdProp =>
         case KafkaConfig.SaslKerberosTicketRenewWindowFactorProp =>
diff --git a/core/src/test/scala/unit/kafka/utils/JaasTestUtils.scala b/core/src/test/scala/unit/kafka/utils/JaasTestUtils.scala
index 9ce3b01025a..10c73452a3c 100644
--- a/core/src/test/scala/unit/kafka/utils/JaasTestUtils.scala
+++ b/core/src/test/scala/unit/kafka/utils/JaasTestUtils.scala
@@ -114,7 +114,7 @@ object JaasTestUtils {
   val KafkaServerContextName = "KafkaServer"
   val KafkaServerPrincipalUnqualifiedName = "kafka"
   private val KafkaServerPrincipal = KafkaServerPrincipalUnqualifiedName + "/localhost@EXAMPLE.COM"
-  private val KafkaClientContextName = "KafkaClient"
+  val KafkaClientContextName = "KafkaClient"
   val KafkaClientPrincipalUnqualifiedName = "client"
   private val KafkaClientPrincipal = KafkaClientPrincipalUnqualifiedName + "@EXAMPLE.COM"
   val KafkaClientPrincipalUnqualifiedName2 = "client2"


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


> KIP-86: Configurable SASL callback handlers
> -------------------------------------------
>
>                 Key: KAFKA-4292
>                 URL: https://issues.apache.org/jira/browse/KAFKA-4292
>             Project: Kafka
>          Issue Type: Improvement
>          Components: security
>    Affects Versions: 0.10.1.0
>            Reporter: Rajini Sivaram
>            Assignee: Rajini Sivaram
>            Priority: Major
>
> Implementation of KIP-86: https://cwiki.apache.org/confluence/display/KAFKA/KIP-86%3A+Configurable+SASL+callback+handlers



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)