You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by jg...@apache.org on 2019/12/04 17:48:46 UTC

[kafka] branch 2.2 updated: KAFKA-9190; Close connections with expired authentication sessions (#7723)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch 2.2
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.2 by this push:
     new 61530da  KAFKA-9190; Close connections with expired authentication sessions (#7723)
61530da is described below

commit 61530da45bd97c18786024f43c999981ca2d6479
Author: Ron Dagostino <rd...@confluent.io>
AuthorDate: Wed Dec 4 08:39:39 2019 -0800

    KAFKA-9190; Close connections with expired authentication sessions (#7723)
    
    This patch fixes a bug in how `SocketServer` expires connections that have not re-authenticated quickly enough. Previously these connections were left hanging; now they are properly closed and cleaned up. This was one cause of the flaky test failures in `EndToEndAuthorizationTest.testNoDescribeProduceOrConsumeWithoutTopicDescribeAcl`.
    
    Reviewers: Jason Gustafson <ja...@confluent.io>, Rajini Sivaram <ra...@googlemail.com>
---
 .../apache/kafka/common/network/KafkaChannel.java  |  6 +-
 .../authenticator/SaslServerAuthenticator.java     | 17 ++---
 .../main/scala/kafka/network/SocketServer.scala    |  5 +-
 .../unit/kafka/network/SocketServerTest.scala      | 85 ++++++++++++++++++++--
 4 files changed, 94 insertions(+), 19 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java
index 3bca276..df85e97 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java
@@ -571,8 +571,8 @@ public class KafkaChannel implements AutoCloseable {
          * We've delayed getting the time as long as possible in case we don't need it,
          * but at this point we need it -- so get it now.
          */
-        long nowNanos = nowNanosSupplier.get().longValue();
-        if (nowNanos < authenticator.clientSessionReauthenticationTimeNanos().longValue())
+        long nowNanos = nowNanosSupplier.get();
+        if (nowNanos < authenticator.clientSessionReauthenticationTimeNanos())
             return false;
         swapAuthenticatorsAndBeginReauthentication(new ReauthenticationContext(authenticator, receive, nowNanos));
         receive = null;
@@ -605,7 +605,7 @@ public class KafkaChannel implements AutoCloseable {
      */
     public boolean serverAuthenticationSessionExpired(long nowNanos) {
         Long serverSessionExpirationTimeNanos = authenticator.serverSessionExpirationTimeNanos();
-        return serverSessionExpirationTimeNanos != null && nowNanos - serverSessionExpirationTimeNanos.longValue() > 0;
+        return serverSessionExpirationTimeNanos != null && nowNanos - serverSessionExpirationTimeNanos > 0;
     }
     
     /**
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
index 1e62e7f..a22edae 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
@@ -663,30 +663,29 @@ public class SaslServerAuthenticator implements Authenticator {
             Long connectionsMaxReauthMs = connectionsMaxReauthMsByMechanism.get(saslMechanism);
             if (credentialExpirationMs != null || connectionsMaxReauthMs != null) {
                 if (credentialExpirationMs == null)
-                    retvalSessionLifetimeMs = zeroIfNegative(connectionsMaxReauthMs.longValue());
+                    retvalSessionLifetimeMs = zeroIfNegative(connectionsMaxReauthMs);
                 else if (connectionsMaxReauthMs == null)
-                    retvalSessionLifetimeMs = zeroIfNegative(credentialExpirationMs.longValue() - authenticationEndMs);
+                    retvalSessionLifetimeMs = zeroIfNegative(credentialExpirationMs - authenticationEndMs);
                 else
                     retvalSessionLifetimeMs = zeroIfNegative(
-                            Math.min(credentialExpirationMs.longValue() - authenticationEndMs,
-                                    connectionsMaxReauthMs.longValue()));
+                            Math.min(credentialExpirationMs - authenticationEndMs,
+                                    connectionsMaxReauthMs));
                 if (retvalSessionLifetimeMs > 0L)
-                    sessionExpirationTimeNanos = Long
-                            .valueOf(authenticationEndNanos + 1000 * 1000 * retvalSessionLifetimeMs);
+                    sessionExpirationTimeNanos = authenticationEndNanos + 1000 * 1000 * retvalSessionLifetimeMs;
             }
             if (credentialExpirationMs != null) {
                 if (sessionExpirationTimeNanos != null)
                     LOG.debug(
                             "Authentication complete; session max lifetime from broker config={} ms, credential expiration={} ({} ms); session expiration = {} ({} ms), sending {} ms to client",
                             connectionsMaxReauthMs, new Date(credentialExpirationMs),
-                            Long.valueOf(credentialExpirationMs.longValue() - authenticationEndMs),
+                            credentialExpirationMs - authenticationEndMs,
                             new Date(authenticationEndMs + retvalSessionLifetimeMs), retvalSessionLifetimeMs,
                             retvalSessionLifetimeMs);
                 else
                     LOG.debug(
                             "Authentication complete; session max lifetime from broker config={} ms, credential expiration={} ({} ms); no session expiration, sending 0 ms to client",
                             connectionsMaxReauthMs, new Date(credentialExpirationMs),
-                            Long.valueOf(credentialExpirationMs.longValue() - authenticationEndMs));
+                            credentialExpirationMs - authenticationEndMs);
             } else {
                 if (sessionExpirationTimeNanos != null)
                     LOG.debug(
@@ -706,7 +705,7 @@ public class SaslServerAuthenticator implements Authenticator {
                 return null;
             // record at least 1 ms if there is some latency
             long latencyNanos = authenticationEndNanos - reauthenticationBeginNanos;
-            return latencyNanos == 0L ? 0L : Math.max(1L, Long.valueOf(Math.round(latencyNanos / 1000.0 / 1000.0)));
+            return latencyNanos == 0L ? 0L : Math.max(1L, Math.round(latencyNanos / 1000.0 / 1000.0));
         }
 
         private long zeroIfNegative(long value) {
diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala
index d8c4a74..193d484 100644
--- a/core/src/main/scala/kafka/network/SocketServer.scala
+++ b/core/src/main/scala/kafka/network/SocketServer.scala
@@ -847,8 +847,9 @@ private[kafka] class Processor(val id: Int,
             else {
               val nowNanos = time.nanoseconds()
               if (channel.serverAuthenticationSessionExpired(nowNanos)) {
-                channel.disconnect()
-                debug(s"Disconnected expired channel: $channel : $header")
+                // be sure to decrease connection count and drop any in-flight responses
+                debug(s"Disconnecting expired channel: $channel : $header")
+                close(channel.id)
                 expiredConnectionsKilledCount.record(null, 1, 0)
               } else {
                 val connectionId = receive.source
diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
index 9ccb3f5..b0dc052 100644
--- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
+++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
@@ -26,19 +26,20 @@ import java.util.{HashMap, Properties, Random}
 import com.yammer.metrics.core.{Gauge, Meter}
 import com.yammer.metrics.{Metrics => YammerMetrics}
 import javax.net.ssl._
-
 import kafka.security.CredentialProvider
 import kafka.server.{KafkaConfig, ThrottledChannel}
 import kafka.utils.Implicits._
 import kafka.utils.TestUtils
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.memory.MemoryPool
+import org.apache.kafka.common.message.SaslAuthenticateRequestData
+import org.apache.kafka.common.message.SaslHandshakeRequestData
 import org.apache.kafka.common.metrics.Metrics
 import org.apache.kafka.common.network.KafkaChannel.ChannelMuteState
 import org.apache.kafka.common.network.{ChannelBuilder, ChannelState, KafkaChannel, ListenerName, NetworkReceive, NetworkSend, Selector, Send}
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.record.MemoryRecords
-import org.apache.kafka.common.requests.{AbstractRequest, ProduceRequest, RequestHeader}
+import org.apache.kafka.common.requests.{AbstractRequest, ProduceRequest, RequestHeader, SaslAuthenticateRequest, SaslHandshakeRequest}
 import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol}
 import org.apache.kafka.common.security.scram.internals.ScramMechanism
 import org.apache.kafka.common.utils.{LogContext, MockTime, Time}
@@ -107,6 +108,14 @@ class SocketServerTest extends JUnitSuite {
       outgoing.flush()
   }
 
+  def sendApiRequest(socket: Socket, request: AbstractRequest, header: RequestHeader) = {
+    val byteBuffer = request.serialize(header)
+    byteBuffer.rewind()
+    val serializedBytes = new Array[Byte](byteBuffer.remaining)
+    byteBuffer.get(serializedBytes)
+    sendRequest(socket, serializedBytes)
+  }
+
   def receiveResponse(socket: Socket): Array[Byte] = {
     val incoming = new DataInputStream(socket.getInputStream)
     val len = incoming.readInt()
@@ -654,7 +663,72 @@ class SocketServerTest extends JUnitSuite {
   }
 
   @Test
-  def testSessionPrincipal() {
+  def testSaslReauthenticationFailure(): Unit = {
+    shutdownServerAndMetrics(server) // we will use our own instance because we require custom configs
+    val username = "admin"
+    val password = "admin-secret"
+    val reauthMs = 1500
+    val brokerProps = new Properties
+    brokerProps.setProperty("listeners", "SASL_PLAINTEXT://localhost:0")
+    brokerProps.setProperty("security.inter.broker.protocol", "SASL_PLAINTEXT")
+    brokerProps.setProperty("listener.name.sasl_plaintext.plain.sasl.jaas.config",
+      "org.apache.kafka.common.security.plain.PlainLoginModule required " +
+        "username=\"%s\" password=\"%s\" user_%s=\"%s\";".format(username, password, username, password))
+    brokerProps.setProperty("sasl.mechanism.inter.broker.protocol", "PLAIN")
+    brokerProps.setProperty("listener.name.sasl_plaintext.sasl.enabled.mechanisms", "PLAIN")
+    brokerProps.setProperty("num.network.threads", "1")
+    brokerProps.setProperty("connections.max.reauth.ms", reauthMs.toString)
+    val overrideProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect,
+      saslProperties = Some(brokerProps), enableSaslPlaintext = true)
+    val time = new MockTime()
+    val overrideServer = new TestableSocketServer(KafkaConfig.fromProps(overrideProps), time = time)
+    try {
+      overrideServer.startup()
+      val socket = connect(overrideServer, ListenerName.forSecurityProtocol(SecurityProtocol.SASL_PLAINTEXT))
+
+      val correlationId = -1
+      val clientId = ""
+      // send a SASL handshake request
+      val saslHandshakeRequest = new SaslHandshakeRequest.Builder(new SaslHandshakeRequestData().setMechanism("PLAIN"))
+        .build()
+      val saslHandshakeHeader = new RequestHeader(ApiKeys.SASL_HANDSHAKE, saslHandshakeRequest.version, clientId,
+        correlationId)
+      sendApiRequest(socket, saslHandshakeRequest, saslHandshakeHeader)
+      receiveResponse(socket)
+
+      // now send credentials within a SaslAuthenticateRequest
+      val authBytes = "admin\u0000admin\u0000admin-secret".getBytes("UTF-8")
+      val saslAuthenticateRequest = new SaslAuthenticateRequest.Builder(new SaslAuthenticateRequestData()
+        .setAuthBytes(authBytes)).build()
+      val saslAuthenticateHeader = new RequestHeader(ApiKeys.SASL_AUTHENTICATE, saslAuthenticateRequest.version,
+        clientId, correlationId)
+      sendApiRequest(socket, saslAuthenticateRequest, saslAuthenticateHeader)
+      receiveResponse(socket)
+      assertEquals(1, overrideServer.testableSelector.channels.size)
+
+      // advance the clock long enough to cause server-side disconnection upon next send...
+      time.sleep(reauthMs * 2)
+      // ...and now send something to trigger the disconnection
+      val ackTimeoutMs = 10000
+      val ack = 0: Short
+      val emptyRequest = ProduceRequest.Builder.forCurrentMagic(ack, ackTimeoutMs,
+        new HashMap[TopicPartition, MemoryRecords]()).build()
+      val emptyHeader = new RequestHeader(ApiKeys.PRODUCE, emptyRequest.version, clientId, correlationId)
+      sendApiRequest(socket, emptyRequest, emptyHeader)
+      // wait a little bit for the server-side disconnection to occur since it happens asynchronously
+      try {
+        TestUtils.waitUntilTrue(() => overrideServer.testableSelector.channels.isEmpty,
+          "Expired connection was not closed", 1000, 100)
+      } finally {
+        socket.close()
+      }
+    } finally {
+      shutdownServerAndMetrics(overrideServer)
+    }
+  }
+
+  @Test
+  def testSessionPrincipal(): Unit = {
     val socket = connect()
     val bytes = new Array[Byte](40)
     sendRequest(socket, bytes, Some(0))
@@ -1191,8 +1265,9 @@ class SocketServerTest extends JUnitSuite {
     }
   }
 
-  class TestableSocketServer(config : KafkaConfig = config, val connectionQueueSize: Int = 20) extends SocketServer(config,
-      new Metrics, Time.SYSTEM, credentialProvider) {
+  class TestableSocketServer(config : KafkaConfig = config, val connectionQueueSize: Int = 20,
+                             override val time: Time = Time.SYSTEM) extends SocketServer(config,
+      new Metrics, time, credentialProvider) {
 
     @volatile var selector: Option[TestableSelector] = None