You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by ij...@apache.org on 2017/09/08 17:42:53 UTC
kafka git commit: KAFKA-5790, KAFKA-5607; Improve error handling in SocketServer to avoid issues later
Repository: kafka
Updated Branches:
refs/heads/trunk 4769e3d92 -> 6ebd9d35f
KAFKA-5790, KAFKA-5607; Improve error handling in SocketServer to avoid issues later
Changes:
1. When an exception is encountered in any of the methods in `Processor` while processing a channel, log the exception and close the connection. Continue to process other channels.
2. Fixes KAFKA-5790: `SocketServer.processNewResponses` should not skip a response if an exception is thrown.
3. For `IllegalStateException` and `IOException` in `poll()`, don't close the `Selector`. Log the exception and continue.
4. Close channel on any failed send in `Selector`.
5. When a send fails on a channel that is already closing, or when such a channel is closed, leave the channel state as-is so that it continues to indicate the state in which the channel was moved to closing.
6. Add tests for various failure scenarios.
7. Fix timing issue in `SocketServerTest.testConnectionIdReuse` by waiting for new connections to be processed by the server.
Author: Rajini Sivaram <ra...@googlemail.com>
Reviewers: Jun Rao <ju...@gmail.com>, Ismael Juma <is...@juma.me.uk>
Closes #3548 from rajinisivaram/KAFKA-5607
Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/6ebd9d35
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/6ebd9d35
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/6ebd9d35
Branch: refs/heads/trunk
Commit: 6ebd9d35fc081bca3f1efeba8b70ec71cde85270
Parents: 4769e3d
Author: Rajini Sivaram <ra...@googlemail.com>
Authored: Fri Sep 8 18:42:42 2017 +0100
Committer: Ismael Juma <is...@juma.me.uk>
Committed: Fri Sep 8 18:42:42 2017 +0100
----------------------------------------------------------------------
.../apache/kafka/common/network/Selector.java | 73 +--
.../kafka/common/network/SelectorTest.java | 92 +++-
.../kafka/common/network/SslSelectorTest.java | 5 +
.../main/scala/kafka/network/SocketServer.scala | 208 +++++---
.../unit/kafka/network/SocketServerTest.scala | 516 ++++++++++++++++++-
5 files changed, 762 insertions(+), 132 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/kafka/blob/6ebd9d35/clients/src/main/java/org/apache/kafka/common/network/Selector.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java
index e873252..d321ba5 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java
@@ -20,7 +20,6 @@ import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.channels.CancelledKeyException;
-import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.nio.channels.UnresolvedAddressException;
@@ -216,19 +215,7 @@ public class Selector implements Selectable, AutoCloseable {
throw e;
}
SelectionKey key = socketChannel.register(nioSelector, SelectionKey.OP_CONNECT);
- KafkaChannel channel;
- try {
- channel = channelBuilder.buildChannel(id, key, maxReceiveSize, memoryPool);
- } catch (Exception e) {
- try {
- socketChannel.close();
- } finally {
- key.cancel();
- }
- throw new IOException("Channel could not be created for socket " + socketChannel, e);
- }
- key.attach(channel);
- this.channels.put(id, channel);
+ KafkaChannel channel = buildChannel(socketChannel, id, key);
if (connected) {
// OP_CONNECT won't trigger for immediately connected channels
@@ -247,18 +234,36 @@ public class Selector implements Selectable, AutoCloseable {
* Kafka brokers add an incrementing index to the connection id to avoid reuse in the timing window
* where an existing connection may not yet have been closed by the broker when a new connection with
* the same remote host:port is processed.
+ * </p><p>
+ * If a `KafkaChannel` cannot be created for this connection, the `socketChannel` is closed
+ * and its selection key cancelled.
* </p>
*/
- public void register(String id, SocketChannel socketChannel) throws ClosedChannelException {
+ public void register(String id, SocketChannel socketChannel) throws IOException {
if (this.channels.containsKey(id))
throw new IllegalStateException("There is already a connection for id " + id);
if (this.closingChannels.containsKey(id))
throw new IllegalStateException("There is already a connection for id " + id + " that is still being closed");
SelectionKey key = socketChannel.register(nioSelector, SelectionKey.OP_READ);
- KafkaChannel channel = channelBuilder.buildChannel(id, key, maxReceiveSize, memoryPool);
+ buildChannel(socketChannel, id, key);
+ }
+
+ private KafkaChannel buildChannel(SocketChannel socketChannel, String id, SelectionKey key) throws IOException {
+ KafkaChannel channel;
+ try {
+ channel = channelBuilder.buildChannel(id, key, maxReceiveSize, memoryPool);
+ } catch (Exception e) {
+ try {
+ socketChannel.close();
+ } finally {
+ key.cancel();
+ }
+ throw new IOException("Channel could not be created for socket " + socketChannel, e);
+ }
key.attach(channel);
this.channels.put(id, channel);
+ return channel;
}
/**
@@ -292,15 +297,24 @@ public class Selector implements Selectable, AutoCloseable {
*/
public void send(Send send) {
String connectionId = send.destination();
- if (closingChannels.containsKey(connectionId))
+ KafkaChannel channel = openOrClosingChannelOrFail(connectionId);
+ if (closingChannels.containsKey(connectionId)) {
+ // ensure notification via `disconnected`, leave channel in the state in which closing was triggered
this.failedSends.add(connectionId);
- else {
- KafkaChannel channel = channelOrFail(connectionId, false);
+ } else {
try {
channel.setSend(send);
- } catch (CancelledKeyException e) {
+ } catch (Exception e) {
+ // update the state for consistency, the channel will be discarded after `close`
+ channel.state(ChannelState.FAILED_SEND);
+ // ensure notification via `disconnected`
this.failedSends.add(connectionId);
close(channel, false);
+ if (!(e instanceof CancelledKeyException)) {
+ log.error("Unexpected exception during send, closing connection {} and rethrowing exception {}",
+ connectionId, e);
+ throw e;
+ }
}
}
}
@@ -535,7 +549,7 @@ public class Selector implements Selectable, AutoCloseable {
@Override
public void mute(String id) {
- KafkaChannel channel = channelOrFail(id, true);
+ KafkaChannel channel = openOrClosingChannelOrFail(id);
mute(channel);
}
@@ -546,7 +560,7 @@ public class Selector implements Selectable, AutoCloseable {
@Override
public void unmute(String id) {
- KafkaChannel channel = channelOrFail(id, true);
+ KafkaChannel channel = openOrClosingChannelOrFail(id);
unmute(channel);
}
@@ -603,12 +617,8 @@ public class Selector implements Selectable, AutoCloseable {
it.remove();
}
}
- for (String channel : this.failedSends) {
- KafkaChannel failedChannel = closingChannels.get(channel);
- if (failedChannel != null)
- failedChannel.state(ChannelState.FAILED_SEND);
+ for (String channel : this.failedSends)
this.disconnected.put(channel, ChannelState.FAILED_SEND);
- }
this.failedSends.clear();
this.madeReadProgressLastPoll = false;
}
@@ -641,6 +651,11 @@ public class Selector implements Selectable, AutoCloseable {
// channel state here anyway to avoid confusion.
channel.state(ChannelState.LOCAL_CLOSE);
close(channel, false);
+ } else {
+ KafkaChannel closingChannel = this.closingChannels.remove(id);
+ // Close any closing channel, leave the channel in the state in which closing was triggered
+ if (closingChannel != null)
+ doClose(closingChannel, false);
}
}
@@ -702,9 +717,9 @@ public class Selector implements Selectable, AutoCloseable {
return channel != null && channel.ready();
}
- private KafkaChannel channelOrFail(String id, boolean maybeClosing) {
+ private KafkaChannel openOrClosingChannelOrFail(String id) {
KafkaChannel channel = this.channels.get(id);
- if (channel == null && maybeClosing)
+ if (channel == null)
channel = this.closingChannels.get(id);
if (channel == null)
throw new IllegalStateException("Attempt to retrieve channel for which there is no connection. Connection id " + id + " existing connections " + channels.keySet());
http://git-wip-us.apache.org/repos/asf/kafka/blob/6ebd9d35/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
index 9671b36..e88a4ee 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
@@ -16,6 +16,7 @@
*/
package org.apache.kafka.common.network;
+import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.memory.MemoryPool;
import org.apache.kafka.common.memory.SimpleMemoryPool;
import org.apache.kafka.common.metrics.Metrics;
@@ -43,6 +44,8 @@ import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@@ -87,6 +90,10 @@ public class SelectorTest {
this.metrics.close();
}
+ public SecurityProtocol securityProtocol() {
+ return SecurityProtocol.PLAINTEXT;
+ }
+
/**
* Validate that when the server disconnects, a client send ends up with that node in the disconnected list.
*/
@@ -111,13 +118,20 @@ public class SelectorTest {
/**
* Sending a request with one already in flight should result in an exception
*/
- @Test(expected = IllegalStateException.class)
+ @Test
public void testCantSendWithInProgress() throws Exception {
String node = "0";
blockingConnect(node);
selector.send(createSend(node, "test1"));
- selector.send(createSend(node, "test2"));
- selector.poll(1000L);
+ try {
+ selector.send(createSend(node, "test2"));
+ fail("IllegalStateException not thrown when sending a request with one in flight");
+ } catch (IllegalStateException e) {
+ // Expected exception
+ }
+ selector.poll(0);
+ assertTrue("Channel not closed", selector.disconnected().containsKey(node));
+ assertEquals(ChannelState.FAILED_SEND, selector.disconnected().get(node));
}
/**
@@ -274,6 +288,48 @@ public class SelectorTest {
assertEquals("The response should be from the previously muted node", "1", selector.completedReceives().get(0).source());
}
+ @Test
+ public void registerFailure() throws Exception {
+ ChannelBuilder channelBuilder = new PlaintextChannelBuilder() {
+ @Override
+ public KafkaChannel buildChannel(String id, SelectionKey key, int maxReceiveSize,
+ MemoryPool memoryPool) throws KafkaException {
+ throw new RuntimeException("Test exception");
+ }
+ @Override
+ public void close() {
+ }
+ };
+ Selector selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", channelBuilder);
+ SocketChannel socketChannel = SocketChannel.open();
+ socketChannel.configureBlocking(false);
+ try {
+ selector.register("1", socketChannel);
+ fail("Register did not fail");
+ } catch (IOException e) {
+ assertTrue("Unexpected exception: " + e, e.getCause().getMessage().contains("Test exception"));
+ assertFalse("Socket not closed", socketChannel.isOpen());
+ }
+ selector.close();
+ }
+
+ @Test
+ public void testCloseConnectionInClosingState() throws Exception {
+ KafkaChannel channel = createConnectionWithStagedReceives(5);
+ String id = channel.id();
+ time.sleep(6000); // The max idle time is 5000ms
+ selector.poll(0);
+ assertEquals(channel, selector.closingChannel(id));
+ assertNull("Channel not expired", selector.channel(id));
+ assertEquals(ChannelState.EXPIRED, channel.state());
+ selector.close(id);
+ assertNull("Channel not removed from channels", selector.channel(id));
+ assertNull("Channel not removed from closingChannels", selector.closingChannel(id));
+ assertTrue("Unexpected disconnect notification", selector.disconnected().isEmpty());
+ assertEquals(ChannelState.EXPIRED, channel.state());
+ selector.poll(0);
+ assertTrue("Unexpected disconnect notification", selector.disconnected().isEmpty());
+ }
@Test
public void testCloseOldestConnection() throws Exception {
@@ -297,22 +353,32 @@ public class SelectorTest {
verifyCloseOldestConnectionWithStagedReceives(5);
}
- private void verifyCloseOldestConnectionWithStagedReceives(int maxStagedReceives) throws Exception {
+ private KafkaChannel createConnectionWithStagedReceives(int maxStagedReceives) throws Exception {
String id = "0";
blockingConnect(id);
KafkaChannel channel = selector.channel(id);
+ int retries = 100;
- selector.mute(id);
- for (int i = 0; i <= maxStagedReceives; i++) {
- selector.send(createSend(id, String.valueOf(i)));
- selector.poll(1000);
- }
-
- selector.unmute(id);
do {
- selector.poll(1000);
- } while (selector.completedReceives().isEmpty());
+ selector.mute(id);
+ for (int i = 0; i <= maxStagedReceives; i++) {
+ selector.send(createSend(id, String.valueOf(i)));
+ selector.poll(1000);
+ }
+
+ selector.unmute(id);
+ do {
+ selector.poll(1000);
+ } while (selector.completedReceives().isEmpty());
+ } while (selector.numStagedReceives(channel) == 0 && --retries > 0);
+ assertTrue("No staged receives after 100 attempts", selector.numStagedReceives(channel) > 0);
+ return channel;
+ }
+
+ private void verifyCloseOldestConnectionWithStagedReceives(int maxStagedReceives) throws Exception {
+ KafkaChannel channel = createConnectionWithStagedReceives(maxStagedReceives);
+ String id = channel.id();
int stagedReceives = selector.numStagedReceives(channel);
int completedReceives = 0;
while (selector.disconnected().isEmpty()) {
http://git-wip-us.apache.org/repos/asf/kafka/blob/6ebd9d35/clients/src/test/java/org/apache/kafka/common/network/SslSelectorTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/common/network/SslSelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SslSelectorTest.java
index 46d3b79..bada4d4 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/SslSelectorTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/SslSelectorTest.java
@@ -75,6 +75,11 @@ public class SslSelectorTest extends SelectorTest {
this.metrics.close();
}
+ @Override
+ public SecurityProtocol securityProtocol() {
+ return SecurityProtocol.PLAINTEXT;
+ }
+
/**
* Tests that SSL renegotiation initiated by the server are handled correctly by the client
* @throws Exception
http://git-wip-us.apache.org/repos/asf/kafka/blob/6ebd9d35/core/src/main/scala/kafka/network/SocketServer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala
index 3d54c8a..c93fb55 100644
--- a/core/src/main/scala/kafka/network/SocketServer.scala
+++ b/core/src/main/scala/kafka/network/SocketServer.scala
@@ -35,7 +35,7 @@ import org.apache.kafka.common.errors.InvalidRequestException
import org.apache.kafka.common.memory.{MemoryPool, SimpleMemoryPool}
import org.apache.kafka.common.metrics._
import org.apache.kafka.common.metrics.stats.Rate
-import org.apache.kafka.common.network.{ChannelBuilders, KafkaChannel, ListenerName, Selectable, Send, Selector => KSelector}
+import org.apache.kafka.common.network.{ChannelBuilder, ChannelBuilders, KafkaChannel, ListenerName, Selectable, Send, Selector => KSelector}
import org.apache.kafka.common.protocol.SecurityProtocol
import org.apache.kafka.common.security.auth.KafkaPrincipal
import org.apache.kafka.common.protocol.types.SchemaException
@@ -44,7 +44,7 @@ import org.apache.kafka.common.utils.{KafkaThread, Time}
import scala.collection._
import JavaConverters._
-import scala.util.control.{ControlThrowable, NonFatal}
+import scala.util.control.ControlThrowable
/**
* An NIO socket server. The threading model is
@@ -229,20 +229,6 @@ private[kafka] abstract class AbstractServerThread(connectionQuotas: ConnectionQ
protected def isRunning: Boolean = alive.get
/**
- * Close the connection identified by `connectionId` and decrement the connection count.
- */
- def close(selector: KSelector, connectionId: String): Unit = {
- val channel = selector.channel(connectionId)
- if (channel != null) {
- debug(s"Closing selector connection $connectionId")
- val address = channel.socketAddress
- if (address != null)
- connectionQuotas.dec(address)
- selector.close(connectionId)
- }
- }
-
- /**
* Close `channel` and decrement the connection count.
*/
def close(channel: SocketChannel): Unit = {
@@ -431,7 +417,10 @@ private[kafka] class Processor(val id: Int,
Map("networkProcessor" -> id.toString)
)
- private val selector = new KSelector(
+ private val selector = createSelector(
+ ChannelBuilders.serverChannelBuilder(listenerName, securityProtocol, config, credentialProvider.credentialCache))
+ // Visible to override for testing
+ protected[network] def createSelector(channelBuilder: ChannelBuilder): KSelector = new KSelector(
maxRequestSize,
connectionsMaxIdleMs,
metrics,
@@ -440,7 +429,7 @@ private[kafka] class Processor(val id: Int,
metricTags,
false,
true,
- ChannelBuilders.serverChannelBuilder(listenerName, securityProtocol, config, credentialProvider.credentialCache),
+ channelBuilder,
memoryPool)
// Connection ids have the format `localAddr:localPort-remoteAddr:remotePort-index`. The index is a
@@ -450,35 +439,53 @@ private[kafka] class Processor(val id: Int,
override def run() {
startupComplete()
- while (isRunning) {
- try {
- // setup any new connections that have been queued up
- configureNewConnections()
- // register any new responses for writing
- processNewResponses()
- poll()
- processCompletedReceives()
- processCompletedSends()
- processDisconnected()
- } catch {
- // We catch all the throwables here to prevent the processor thread from exiting. We do this because
- // letting a processor exit might cause a bigger impact on the broker. Usually the exceptions thrown would
- // be either associated with a specific socket channel or a bad request. We just ignore the bad socket channel
- // or request. This behavior might need to be reviewed if we see an exception that need the entire broker to stop.
- case e: ControlThrowable => throw e
- case e: Throwable =>
- error("Processor got uncaught exception.", e)
+ try {
+ while (isRunning) {
+ try {
+ // setup any new connections that have been queued up
+ configureNewConnections()
+ // register any new responses for writing
+ processNewResponses()
+ poll()
+ processCompletedReceives()
+ processCompletedSends()
+ processDisconnected()
+ } catch {
+ // We catch all the throwables here to prevent the processor thread from exiting. We do this because
+ // letting a processor exit might cause a bigger impact on the broker. This behavior might need to be
+ // reviewed if we see an exception that needs the entire broker to stop. Usually the exceptions thrown would
+ // be either associated with a specific socket channel or a bad request. These exceptions are caught and
+ // processed by the individual methods above which close the failing channel and continue processing other
+ // channels. So this catch block should only ever see ControlThrowables.
+ case e: Throwable => processException("Processor got uncaught exception.", e)
+ }
}
+ } finally {
+ debug("Closing selector - processor " + id)
+ swallowError(closeAll())
+ shutdownComplete()
+ }
+ }
+
+ private def processException(errorMessage: String, throwable: Throwable) {
+ throwable match {
+ case e: ControlThrowable => throw e
+ case e => error(errorMessage, e)
}
+ }
- debug("Closing selector - processor " + id)
- swallowError(closeAll())
- shutdownComplete()
+ private def processChannelException(channelId: String, errorMessage: String, throwable: Throwable) {
+ if (openOrClosingChannel(channelId).isDefined) {
+ error(s"Closing socket for ${channelId} because of error", throwable)
+ close(channelId)
+ }
+ processException(errorMessage, throwable)
}
private def processNewResponses() {
- var curr = requestChannel.receiveResponse(id)
- while (curr != null) {
+ var curr: RequestChannel.Response = null
+ while ({curr = requestChannel.receiveResponse(id); curr != null}) {
+ val channelId = curr.request.context.connectionId
try {
curr.responseAction match {
case RequestChannel.NoOpAction =>
@@ -486,9 +493,7 @@ private[kafka] class Processor(val id: Int,
// that are sitting in the server's socket buffer
updateRequestMetrics(curr)
trace("Socket server received empty response to send, registering for read: " + curr)
- val channelId = curr.request.context.connectionId
- if (selector.channel(channelId) != null || selector.closingChannel(channelId) != null)
- selector.unmute(channelId)
+ openOrClosingChannel(channelId).foreach(c => selector.unmute(c.id))
case RequestChannel.SendAction =>
val responseSend = curr.responseSend.getOrElse(
throw new IllegalStateException(s"responseSend must be defined for SendAction, response: $curr"))
@@ -496,10 +501,11 @@ private[kafka] class Processor(val id: Int,
case RequestChannel.CloseConnectionAction =>
updateRequestMetrics(curr)
trace("Closing socket connection actively according to the response code.")
- close(selector, curr.request.context.connectionId)
+ close(channelId)
}
- } finally {
- curr = requestChannel.receiveResponse(id)
+ } catch {
+ case e: Throwable =>
+ processChannelException(channelId, s"Exception while processing response for $channelId", e)
}
}
}
@@ -526,43 +532,50 @@ private[kafka] class Processor(val id: Int,
try selector.poll(300)
catch {
case e @ (_: IllegalStateException | _: IOException) =>
- error(s"Closing processor $id due to illegal state or IO exception")
- swallow(closeAll())
- shutdownComplete()
- throw e
+ // The exception is not re-thrown and any completed sends/receives/connections/disconnections
+ // from this poll will be processed.
+ error(s"Processor $id poll failed due to illegal state or IO exception")
}
}
private def processCompletedReceives() {
selector.completedReceives.asScala.foreach { receive =>
try {
- val openChannel = selector.channel(receive.source)
- // Only methods that are safe to call on a disconnected channel should be invoked on 'openOrClosingChannel'.
- val openOrClosingChannel = if (openChannel != null) openChannel else selector.closingChannel(receive.source)
- val principal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, openOrClosingChannel.principal.getName)
- val header = RequestHeader.parse(receive.payload)
- val context = new RequestContext(header, receive.source, openOrClosingChannel.socketAddress,
- principal, listenerName, securityProtocol)
- val req = new RequestChannel.Request(processor = id, context = context,
- startTimeNanos = time.nanoseconds, memoryPool, receive.payload)
- requestChannel.sendRequest(req)
- selector.mute(receive.source)
+ openOrClosingChannel(receive.source) match {
+ case Some(channel) =>
+ val principal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, channel.principal.getName)
+ val header = RequestHeader.parse(receive.payload)
+ val context = new RequestContext(header, receive.source, channel.socketAddress,
+ principal, listenerName, securityProtocol)
+ val req = new RequestChannel.Request(processor = id, context = context,
+ startTimeNanos = time.nanoseconds, memoryPool, receive.payload)
+ requestChannel.sendRequest(req)
+ selector.mute(receive.source)
+ case None =>
+ // This should never happen since completed receives are processed immediately after `poll()`
+ throw new IllegalStateException(s"Channel ${receive.source} removed from selector before processing completed receive")
+ }
} catch {
- case e @ (_: InvalidRequestException | _: SchemaException) =>
- // note that even though we got an exception, we can assume that receive.source is valid. Issues with constructing a valid receive object were handled earlier
- error(s"Closing socket for ${receive.source} because of error", e)
- close(selector, receive.source)
+ // note that even though we got an exception, we can assume that receive.source is valid.
+ // Issues with constructing a valid receive object were handled earlier
+ case e: Throwable =>
+ processChannelException(receive.source, s"Exception while processing request from ${receive.source}", e)
}
}
}
private def processCompletedSends() {
selector.completedSends.asScala.foreach { send =>
- val resp = inflightResponses.remove(send.destination).getOrElse {
- throw new IllegalStateException(s"Send for ${send.destination} completed, but not in `inflightResponses`")
+ try {
+ val resp = inflightResponses.remove(send.destination).getOrElse {
+ throw new IllegalStateException(s"Send for ${send.destination} completed, but not in `inflightResponses`")
+ }
+ updateRequestMetrics(resp)
+ selector.unmute(send.destination)
+ } catch {
+ case e: Throwable => processChannelException(send.destination,
+ s"Exception while processing completed send to ${send.destination}", e)
}
- updateRequestMetrics(resp)
- selector.unmute(send.destination)
}
}
@@ -574,12 +587,35 @@ private[kafka] class Processor(val id: Int,
private def processDisconnected() {
selector.disconnected.keySet.asScala.foreach { connectionId =>
- val remoteHost = ConnectionId.fromString(connectionId).getOrElse {
- throw new IllegalStateException(s"connectionId has unexpected format: $connectionId")
- }.remoteHost
- inflightResponses.remove(connectionId).foreach(updateRequestMetrics)
- // the channel has been closed by the selector but the quotas still need to be updated
- connectionQuotas.dec(InetAddress.getByName(remoteHost))
+ try {
+ val remoteHost = ConnectionId.fromString(connectionId).getOrElse {
+ throw new IllegalStateException(s"connectionId has unexpected format: $connectionId")
+ }.remoteHost
+ inflightResponses.remove(connectionId).foreach(updateRequestMetrics)
+ // the channel has been closed by the selector but the quotas still need to be updated
+ connectionQuotas.dec(InetAddress.getByName(remoteHost))
+ } catch {
+ case e: Throwable => processException(s"Exception while processing disconnection of $connectionId", e)
+ }
+ }
+ }
+
+ /**
+ * Close the connection identified by `connectionId` and decrement the connection count.
+ * The channel will be immediately removed from the selector's `channels` or `closingChannels`
+ * and no further disconnect notifications will be sent for this channel by the selector.
+ * If responses are pending for the channel, they are dropped and metrics is updated.
+ * If the channel has already been removed from selector, no action is taken.
+ */
+ private def close(connectionId: String): Unit = {
+ openOrClosingChannel(connectionId).foreach { channel =>
+ debug(s"Closing selector connection $connectionId")
+ val address = channel.socketAddress
+ if (address != null)
+ connectionQuotas.dec(address)
+ selector.close(connectionId)
+
+ inflightResponses.remove(connectionId).foreach(response => updateRequestMetrics(response))
}
}
@@ -601,13 +637,12 @@ private[kafka] class Processor(val id: Int,
debug(s"Processor $id listening to new connection from ${channel.socket.getRemoteSocketAddress}")
selector.register(connectionId(channel.socket), channel)
} catch {
- // We explicitly catch all non fatal exceptions and close the socket to avoid a socket leak. The other
- // throwables will be caught in processor and logged as uncaught exceptions.
- case NonFatal(e) =>
- val remoteAddress = channel.getRemoteAddress
+ // We explicitly catch all exceptions and close the socket to avoid a socket leak.
+ case e: Throwable =>
+ val remoteAddress = channel.socket.getRemoteSocketAddress
// need to close the channel here to avoid a socket leak.
close(channel)
- error(s"Processor $id closed connection from $remoteAddress", e)
+ processException(s"Processor $id closed connection from $remoteAddress", e)
}
}
}
@@ -617,7 +652,7 @@ private[kafka] class Processor(val id: Int,
*/
private def closeAll() {
selector.channels.asScala.foreach { channel =>
- close(selector, channel.id)
+ close(channel.id)
}
selector.close()
}
@@ -633,10 +668,13 @@ private[kafka] class Processor(val id: Int,
connId
}
+ // Only for testing
+ private[network] def inflightResponseCount: Int = inflightResponses.size
+
// Visible for testing
- private[network] def openOrClosingChannel(connectionId: String): Option[KafkaChannel] = {
- Option(selector.channel(connectionId)).orElse(Option(selector.closingChannel(connectionId)))
- }
+ // Only methods that are safe to call on a disconnected channel should be invoked on 'openOrClosingChannel'.
+ private[network] def openOrClosingChannel(connectionId: String): Option[KafkaChannel] =
+ Option(selector.channel(connectionId)).orElse(Option(selector.closingChannel(connectionId)))
/* For test usage */
private[network] def channel(connectionId: String): Option[KafkaChannel] =
http://git-wip-us.apache.org/repos/asf/kafka/blob/6ebd9d35/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
index 9b32635..227714b 100644
--- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
+++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
@@ -20,6 +20,7 @@ package kafka.network
import java.io._
import java.net._
import java.nio.ByteBuffer
+import java.nio.channels.SocketChannel
import java.util.{HashMap, Random}
import javax.net.ssl._
@@ -32,7 +33,7 @@ import kafka.utils.TestUtils
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.memory.MemoryPool
import org.apache.kafka.common.metrics.Metrics
-import org.apache.kafka.common.network.{KafkaChannel, ListenerName, NetworkSend, Send}
+import org.apache.kafka.common.network.{ChannelBuilder, ChannelState, KafkaChannel, ListenerName, NetworkReceive, NetworkSend, Selector, Send}
import org.apache.kafka.common.protocol.{ApiKeys, SecurityProtocol}
import org.apache.kafka.common.record.{MemoryRecords, RecordBatch}
import org.apache.kafka.common.requests.{AbstractRequest, ProduceRequest, RequestHeader}
@@ -43,7 +44,9 @@ import org.junit._
import org.scalatest.junit.JUnitSuite
import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
+import scala.util.control.ControlThrowable
class SocketServerTest extends JUnitSuite {
val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0)
@@ -58,6 +61,7 @@ class SocketServerTest extends JUnitSuite {
val config = KafkaConfig.fromProps(props)
val metrics = new Metrics
val credentialProvider = new CredentialProvider(config.saslEnabledMechanisms)
+ val localAddress = InetAddress.getLoopbackAddress
// Clean-up any metrics left around by previous tests
for (metricName <- YammerMetrics.defaultRegistry.allMetrics.keySet.asScala)
@@ -116,6 +120,19 @@ class SocketServerTest extends JUnitSuite {
socket
}
+ // Create a client connection, process one request and return (client socket, connectionId)
+ def connectAndProcessRequest(s: SocketServer): (Socket, String) = {
+ val socket = connect(s)
+ val request = sendAndReceiveRequest(socket, s)
+ processRequest(s.requestChannel, request)
+ (socket, request.context.connectionId)
+ }
+
+ def sendAndReceiveRequest(socket: Socket, server: SocketServer): RequestChannel.Request = {
+ sendRequest(socket, producerRequestBytes)
+ receiveRequest(server.requestChannel)
+ }
+
@After
def tearDown() {
metrics.close()
@@ -259,7 +276,9 @@ class SocketServerTest extends JUnitSuite {
val idleTimeMs = 60000
val time = new MockTime()
props.put(KafkaConfig.ConnectionsMaxIdleMsProp, idleTimeMs.toString)
+ props.put("listeners", "PLAINTEXT://localhost:0")
val serverMetrics = new Metrics
+ @volatile var selector: TestableSelector = null
val overrideConnectionId = "127.0.0.1:1-127.0.0.1:2-0"
val overrideServer = new SocketServer(KafkaConfig.fromProps(props), serverMetrics, time, credentialProvider) {
override def newProcessor(id: Int, connectionQuotas: ConnectionQuotas, listenerName: ListenerName,
@@ -267,6 +286,11 @@ class SocketServerTest extends JUnitSuite {
new Processor(id, time, config.socketRequestMaxBytes, requestChannel, connectionQuotas,
config.connectionsMaxIdleMs, listenerName, protocol, config, metrics, credentialProvider, memoryPool) {
override protected[network] def connectionId(socket: Socket): String = overrideConnectionId
+ override protected[network] def createSelector(channelBuilder: ChannelBuilder): Selector = {
+ val testableSelector = new TestableSelector(config, channelBuilder, time, metrics)
+ selector = testableSelector
+ testableSelector
+ }
}
}
}
@@ -275,15 +299,26 @@ class SocketServerTest extends JUnitSuite {
def openOrClosingChannel: Option[KafkaChannel] = overrideServer.processor(0).openOrClosingChannel(overrideConnectionId)
def connectionCount = overrideServer.connectionCount(InetAddress.getByName("127.0.0.1"))
+ // Create a client connection and wait for server to register the connection with the selector. For
+ // test scenarios below where `Selector.register` fails, the wait ensures that checks are performed
+ // only after `register` is processed by the server.
+ def connectAndWaitForConnectionRegister(): Socket = {
+   // Re-read the selector's `Register` counter each time it is needed
+   def registerCount = selector.operationCounts(SelectorOperation.Register)
+   // Snapshot the target before connecting so that the new registration is the one awaited
+   val expectedCount = registerCount + 1
+   val socket = connect(overrideServer)
+   TestUtils.waitUntilTrue(() => registerCount == expectedCount, "Connection not registered")
+   socket
+ }
+
try {
overrideServer.startup()
- val socket1 = connect(overrideServer)
+ val socket1 = connectAndWaitForConnectionRegister()
TestUtils.waitUntilTrue(() => connectionCount == 1 && openChannel.isDefined, "Failed to create channel")
val channel1 = openChannel.getOrElse(throw new RuntimeException("Channel not found"))
// Create new connection with same id when `channel1` is still open and in Selector.channels
// Check that new connection is closed and openChannel still contains `channel1`
- connect(overrideServer)
+ connectAndWaitForConnectionRegister()
TestUtils.waitUntilTrue(() => connectionCount == 1, "Failed to close channel")
assertSame(channel1, openChannel.getOrElse(throw new RuntimeException("Channel not found")))
@@ -297,7 +332,7 @@ class SocketServerTest extends JUnitSuite {
// Create new connection with same id when `channel1` is in Selector.closingChannels
// Check that new connection is closed and openOrClosingChannel still contains `channel1`
- connect(overrideServer)
+ connectAndWaitForConnectionRegister()
TestUtils.waitUntilTrue(() => connectionCount == 1, "Failed to close channel")
assertSame(channel1, openOrClosingChannel.getOrElse(throw new RuntimeException("Channel not found")))
@@ -306,7 +341,7 @@ class SocketServerTest extends JUnitSuite {
TestUtils.waitUntilTrue(() => connectionCount == 0 && openOrClosingChannel.isEmpty, "Failed to remove channel with failed send")
// Check that new connections can be created with the same id since `channel1` is no longer in Selector
- connect(overrideServer)
+ connectAndWaitForConnectionRegister()
TestUtils.waitUntilTrue(() => connectionCount == 1 && openChannel.isDefined, "Failed to open new channel")
val newChannel = openChannel.getOrElse(throw new RuntimeException("Channel not found"))
assertNotSame(channel1, newChannel)
@@ -593,4 +628,475 @@ class SocketServerTest extends JUnitSuite {
}
}
+ /**
+ * Tests exception handling in [[Processor.configureNewConnections]]. Exception is
+ * injected into [[Selector.register]] which is used to register each new connection.
+ * Test creates two connections in a single iteration by waking up the selector only
+ * when two connections are ready.
+ * Verifies that
+ * - first failed connection is closed
+ * - second connection is processed successfully after the first fails with an exception
+ * - processor is healthy after the exception
+ */
+ @Test
+ def configureNewConnectionException(): Unit = {
+ withTestableServer { testableServer =>
+ val testableSelector = testableServer.testableSelector
+
+ // `updateMinWakeup(2)` makes the selector swallow the first of the next two `wakeup()` calls
+ // and overrides the poll timeout; `addFailure` arms a one-shot exception for the next `register`
+ testableSelector.updateMinWakeup(2)
+ testableSelector.addFailure(SelectorOperation.Register)
+ val sockets = (1 to 2).map(_ => connect(testableServer))
+ // Both registrations must have been attempted before the connection count is checked
+ testableSelector.waitForOperations(SelectorOperation.Register, 2)
+ // Exactly one connection is expected to remain once the failed one has been removed
+ TestUtils.waitUntilTrue(() => testableServer.connectionCount(localAddress) == 1, "Failed channel not removed")
+
+ // `notFailed` drops the socket whose connection id was recorded as failed
+ assertProcessorHealthy(testableServer, testableSelector.notFailed(sockets))
+ }
+ }
+
+ /**
+ * Tests exception handling in [[Processor.processNewResponses]]. Exception is
+ * injected into [[Selector.send]] which is used to send the new response.
+ * Test creates two responses in a single iteration by waking up the selector only
+ * when two responses are ready.
+ * Verifies that
+ * - first failed channel is closed
+ * - second response is processed successfully after the first fails with an exception
+ * - processor is healthy after the exception
+ */
+ @Test
+ def processNewResponseException(): Unit = {
+ withTestableServer { testableServer =>
+ val testableSelector = testableServer.testableSelector
+ // Swallow the first of the next two `wakeup()` calls and override the poll timeout
+ testableSelector.updateMinWakeup(2)
+
+ val sockets = (1 to 2).map(_ => connect(testableServer))
+ sockets.foreach(sendRequest(_, producerRequestBytes))
+
+ // One-shot failure: only the first `send` performed after this point throws
+ testableServer.testableSelector.addFailure(SelectorOperation.Send)
+ // Receive and process one request per socket, producing one response each
+ sockets.foreach(_ => processRequest(testableServer.requestChannel))
+ testableSelector.waitForOperations(SelectorOperation.Send, 2)
+ // `locallyClosed = true`: wait for the selector to close the failed channel itself and
+ // assert that no disconnect notification was generated
+ testableServer.waitForChannelClose(testableSelector.allFailedChannels.head, locallyClosed = true)
+
+ assertProcessorHealthy(testableServer, testableSelector.notFailed(sockets))
+ }
+ }
+
+ /**
+ * Tests exception handling in [[Processor.processNewResponses]] when [[Selector.send]]
+ * fails with `CancelledKeyException`, which is handled by the selector using a different
+ * code path. Test scenario is similar to [[SocketServerTest.processNewResponseException]].
+ */
+ @Test
+ def sendCancelledKeyException(): Unit = {
+ withTestableServer { testableServer =>
+ val testableSelector = testableServer.testableSelector
+ // Swallow the first of the next two `wakeup()` calls and override the poll timeout
+ testableSelector.updateMinWakeup(2)
+
+ val sockets = (1 to 2).map(_ => connect(testableServer))
+ sockets.foreach(sendRequest(_, producerRequestBytes))
+ val requestChannel = testableServer.requestChannel
+
+ // Take both requests off the request channel first; their responses are only produced
+ // after the disconnect below
+ val requests = sockets.map(_ => receiveRequest(requestChannel))
+ val failedConnectionId = requests(0).context.connectionId
+ // `KafkaChannel.disconnect()` cancels the selection key, triggering CancelledKeyException during send
+ testableSelector.channel(failedConnectionId).disconnect()
+ requests.foreach(processRequest(requestChannel, _))
+ testableSelector.waitForOperations(SelectorOperation.Send, 2)
+ // `locallyClosed = false`: wait for a disconnect notification and assert that the channel
+ // was not closed through `Selector.close(id)`
+ testableServer.waitForChannelClose(failedConnectionId, locallyClosed = false)
+
+ // No failure was injected via `addFailure`, so `notFailed` cannot be used here; pick the
+ // socket whose local port does not appear in the failed connection id
+ val successfulSocket = if (isSocketConnectionId(failedConnectionId, sockets(0))) sockets(1) else sockets(0)
+ assertProcessorHealthy(testableServer, Seq(successfulSocket))
+ }
+ }
+
+ /**
+ * Tests exception handling in [[Processor.processNewResponses]] when [[Selector.send]]
+ * to a channel in closing state throws an exception. Test scenario is similar to
+ * [[SocketServerTest.processNewResponseException]].
+ */
+ @Test
+ def closingChannelException(): Unit = {
+ withTestableServer { testableServer =>
+ val testableSelector = testableServer.testableSelector
+ // Swallow the first of the next two `wakeup()` calls and override the poll timeout
+ testableSelector.updateMinWakeup(2)
+
+ val sockets = (1 to 2).map(_ => connect(testableServer))
+ val serializedBytes = producerRequestBytes
+ // NOTE(review): `sendRequestsUntilStagedReceive` is defined outside this view; by its name it
+ // presumably keeps sending until the channel has a staged receive - confirm
+ val request = sendRequestsUntilStagedReceive(testableServer, sockets(0), serializedBytes)
+ sendRequest(sockets(1), serializedBytes)
+
+ // Arm a one-shot `send` failure, then close the first client socket before its response is sent
+ testableSelector.addFailure(SelectorOperation.Send)
+ sockets(0).close()
+ processRequest(testableServer.requestChannel, request)
+ processRequest(testableServer.requestChannel) // Also process request from other channel
+ testableSelector.waitForOperations(SelectorOperation.Send, 2)
+ // Expect the channel to be closed locally with no disconnect notification
+ testableServer.waitForChannelClose(request.context.connectionId, locallyClosed = true)
+
+ assertProcessorHealthy(testableServer, Seq(sockets(1)))
+ }
+ }
+
+ /**
+ * Tests exception handling in [[Processor.processCompletedReceives]]. Exception is
+ * injected into [[Selector.mute]] which is used to mute the channel when a receive is complete.
+ * Test creates two receives in a single iteration by caching completed receives until two receives
+ * are complete.
+ * Verifies that
+ * - first failed channel is closed
+ * - second receive is processed successfully after the first fails with an exception
+ * - processor is healthy after the exception
+ */
+ @Test
+ def processCompletedReceiveException(): Unit = {
+ withTestableServer { testableServer =>
+ val sockets = (1 to 2).map(_ => connect(testableServer))
+ val testableSelector = testableServer.testableSelector
+ val requestChannel = testableServer.requestChannel
+
+ // Hold completed receives back until two have accumulated, so that both are exposed by the
+ // same `poll()`; the one-shot failure then hits the first `mute` call
+ testableSelector.cachedCompletedReceives.minPerPoll = 2
+ testableSelector.addFailure(SelectorOperation.Mute)
+ sockets.foreach(sendRequest(_, producerRequestBytes))
+ val requests = sockets.map(_ => receiveRequest(requestChannel))
+ testableSelector.waitForOperations(SelectorOperation.Mute, 2)
+ // Expect the channel whose `mute` failed to be closed locally with no disconnect notification
+ testableServer.waitForChannelClose(testableSelector.allFailedChannels.head, locallyClosed = true)
+ // Both requests were received, so both are processed here, including the one whose channel
+ // has just been closed
+ requests.foreach(processRequest(requestChannel, _))
+
+ assertProcessorHealthy(testableServer, testableSelector.notFailed(sockets))
+ }
+ }
+
+ /**
+ * Tests exception handling in [[Processor.processCompletedSends]]. Exception is
+ * injected into [[Selector.unmute]] which is used to unmute the channel after send is complete.
+ * Test creates two completed sends in a single iteration by caching completed sends until two
+ * sends are complete.
+ * Verifies that
+ * - first failed channel is closed
+ * - second send is processed successfully after the first fails with an exception
+ * - processor is healthy after the exception
+ */
+ @Test
+ def processCompletedSendException(): Unit = {
+ withTestableServer { testableServer =>
+ val testableSelector = testableServer.testableSelector
+ val sockets = (1 to 2).map(_ => connect(testableServer))
+ val requests = sockets.map(sendAndReceiveRequest(_, testableServer))
+
+ // NOTE(review): the Scaladoc of this test says completed sends are cached until two are
+ // complete, but `cachedCompletedSends.minPerPoll` is left at its default of 1 here - confirm
+ // whether both sends are really guaranteed to complete in the same iteration
+ // One-shot failure: only the first `unmute` performed after this point throws
+ testableSelector.addFailure(SelectorOperation.Unmute)
+ requests.foreach(processRequest(testableServer.requestChannel, _))
+ testableSelector.waitForOperations(SelectorOperation.Unmute, 2)
+ // Expect the channel whose `unmute` failed to be closed locally with no disconnect notification
+ testableServer.waitForChannelClose(testableSelector.allFailedChannels.head, locallyClosed = true)
+
+ assertProcessorHealthy(testableServer, testableSelector.notFailed(sockets))
+ }
+ }
+
+ /**
+ * Tests exception handling in [[Processor.processDisconnected]]. An invalid connectionId
+ * is inserted to the disconnected list just before the actual valid one.
+ * Verifies that
+ * - first invalid connectionId is ignored
+ * - second disconnected channel is processed successfully after the first fails with an exception
+ * - processor is healthy after the exception
+ */
+ @Test
+ def processDisconnectedException(): Unit = {
+ withTestableServer { testableServer =>
+ val (socket, connectionId) = connectAndProcessRequest(testableServer)
+ val testableSelector = testableServer.testableSelector
+
+ // Add an invalid connectionId to `Selector.disconnected` list before the actual disconnected channel
+ // and check that the actual id is processed and the invalid one ignored.
+ // With `minPerPoll = 2` the deferred fake entry is only released together with the real
+ // disconnect, and it is placed ahead of it in the released values
+ testableSelector.cachedDisconnected.minPerPoll = 2
+ testableSelector.cachedDisconnected.deferredValues += "notAValidConnectionId" -> ChannelState.EXPIRED
+ socket.close()
+ // Reset the counters so that the wait below only observes a poll counted after this point
+ testableSelector.operationCounts.clear()
+ testableSelector.waitForOperations(SelectorOperation.Poll, 1)
+ // The client closed the socket, so a disconnect notification (not a local close) is expected
+ testableServer.waitForChannelClose(connectionId, locallyClosed = false)
+
+ assertProcessorHealthy(testableServer)
+ }
+ }
+
+ /**
+ * Tests that `Processor` continues to function correctly after a failed [[Selector.poll]].
+ */
+ @Test
+ def pollException(): Unit = {
+ withTestableServer { testableServer =>
+ val (socket, _) = connectAndProcessRequest(testableServer)
+ val testableSelector = testableServer.testableSelector
+
+ // Arm a one-shot failure for `poll` (an `IllegalStateException` by default), reset the
+ // counters and wait for two counted polls: since the failure fires only once, reaching two
+ // means the processor kept polling after the failed call
+ testableSelector.addFailure(SelectorOperation.Poll)
+ testableSelector.operationCounts.clear()
+ testableSelector.waitForOperations(SelectorOperation.Poll, 2)
+
+ // The connection opened before the failure must still be usable
+ assertProcessorHealthy(testableServer, Seq(socket))
+ }
+ }
+
+ /**
+ * Tests handling of `ControlThrowable`. Verifies that the selector is closed.
+ */
+ @Test
+ def controlThrowable(): Unit = {
+   withTestableServer { testableServer =>
+     // Only the side effect matters here (one connection that has processed a request);
+     // the returned socket and connection id are not used by this test.
+     connectAndProcessRequest(testableServer)
+     val testableSelector = testableServer.testableSelector
+
+     // Reset the counters, then arm a one-shot `ControlThrowable` for `poll`
+     testableSelector.operationCounts.clear()
+     testableSelector.addFailure(SelectorOperation.Poll,
+       Some(new RuntimeException("ControlThrowable exception during poll()") with ControlThrowable))
+     testableSelector.waitForOperations(SelectorOperation.Poll, 1)
+
+     // Unlike ordinary exceptions, a `ControlThrowable` is expected to end with the selector
+     // being closed, which `TestableSelector.close()` records as `CloseSelector`
+     testableSelector.waitForOperations(SelectorOperation.CloseSelector, 1)
+   }
+ }
+
+ private def withTestableServer(testWithServer: TestableSocketServer => Unit): Unit = {
+   // Runs `testWithServer` against a freshly started `TestableSocketServer` listening on
+   // localhost port 0, and always shuts the server down and closes its metrics afterwards.
+   props.put("listeners", "PLAINTEXT://localhost:0")
+   val server = new TestableSocketServer
+   server.startup()
+   try testWithServer(server)
+   finally {
+     server.shutdown()
+     server.metrics.close()
+   }
+ }
+
+ private def assertProcessorHealthy(testableServer: TestableSocketServer, healthySockets: Seq[Socket] = Seq.empty): Unit = {
+   // Clear any armed failures and deferral thresholds before probing the processor
+   val testSelector = testableServer.testableSelector
+   testSelector.reset()
+   val channel = testableServer.requestChannel
+
+   def awaitNoConnections(): Unit =
+     TestUtils.waitUntilTrue(() => testableServer.connectionCount(localAddress) == 0, "Channels not removed")
+
+   // Existing channels: each must still complete a request round trip, then is closed
+   for (healthy <- healthySockets) {
+     processRequest(channel, sendAndReceiveRequest(healthy, testableServer))
+     healthy.close()
+   }
+   awaitNoConnections()
+
+   // New channel: must be accepted, echo the request back and be neither closed nor closing
+   val (freshSocket, freshConnectionId) = connectAndProcessRequest(testableServer)
+   assertArrayEquals(producerRequestBytes, receiveResponse(freshSocket))
+   assertNotNull("Channel should not have been closed", testSelector.channel(freshConnectionId))
+   assertNull("Channel should not be closing", testSelector.closingChannel(freshConnectionId))
+   freshSocket.close()
+   awaitNoConnections()
+ }
+
+ // Since all sockets use the same local host, it is sufficient to check the local port
+ def isSocketConnectionId(connectionId: String, socket: Socket): Boolean = {
+   // The client's local port appears in the connection id as ":<port>-"
+   val portMarker = ":" + socket.getLocalPort + "-"
+   connectionId.contains(portMarker)
+ }
+
+ /**
+  * `SocketServer` whose processor creates a [[TestableSelector]], so that tests can count
+  * selector operations and inject failures into them.
+  */
+ class TestableSocketServer extends SocketServer(KafkaConfig.fromProps(props),
+     new Metrics, Time.SYSTEM, credentialProvider) {
+
+   // Populated when the processor creates its selector; `createSelector` asserts it is set only once
+   @volatile var selector: Option[TestableSelector] = None
+
+   override def newProcessor(id: Int, connectionQuotas: ConnectionQuotas, listenerName: ListenerName,
+                             protocol: SecurityProtocol, memoryPool: MemoryPool): Processor = {
+     new Processor(id, time, config.socketRequestMaxBytes, requestChannel, connectionQuotas,
+       config.connectionsMaxIdleMs, listenerName, protocol, config, metrics, credentialProvider, memoryPool) {
+
+       override protected[network] def createSelector(channelBuilder: ChannelBuilder): Selector = {
+         val created = new TestableSelector(config, channelBuilder, time, metrics)
+         assertEquals(None, selector)
+         selector = Some(created)
+         created
+       }
+     }
+   }
+
+   /** The selector created by the processor; fails if no selector has been created yet. */
+   def testableSelector: TestableSelector =
+     selector.getOrElse(throw new IllegalStateException("Selector not created"))
+
+   /**
+    * Waits until `connectionId` has been closed and its state has been cleaned up.
+    *
+    * @param connectionId  id of the channel expected to close
+    * @param locallyClosed `true` if the channel is expected to be closed through `Selector.close(id)`
+    *                      with no disconnect notification; `false` if a disconnect notification is
+    *                      expected and no channel may have been closed locally
+    */
+   def waitForChannelClose(connectionId: String, locallyClosed: Boolean): Unit = {
+     val testSelector = testableSelector
+     if (!locallyClosed) {
+       TestUtils.waitUntilTrue(() => testSelector.allDisconnectedChannels.contains(connectionId),
+         s"Disconnect notification not received: $connectionId")
+       assertTrue("Channel closed locally", testSelector.allLocallyClosedChannels.isEmpty)
+     } else {
+       TestUtils.waitUntilTrue(() => testSelector.allLocallyClosedChannels.contains(connectionId),
+         s"Channel not closed: $connectionId")
+       assertTrue("Unexpected disconnect notification", testSelector.allDisconnectedChannels.isEmpty)
+     }
+     val openCount = testSelector.allChannels.size - 1 // minus one for the channel just closed above
+     TestUtils.waitUntilTrue(() => connectionCount(localAddress) == openCount, "Connection count not decremented")
+     TestUtils.waitUntilTrue(() => processor(0).inflightResponseCount == 0, "Inflight responses not cleared")
+     assertNull("Channel not removed", testSelector.channel(connectionId))
+     assertNull("Closing channel not removed", testSelector.closingChannel(connectionId))
+   }
+ }
+
+ // Identifies a `TestableSelector` operation; used as the key for operation counts and for injected failures
+ sealed trait SelectorOperation
+ object SelectorOperation {
+ case object Register extends SelectorOperation // recorded by `TestableSelector.register`
+ case object Poll extends SelectorOperation // recorded by `TestableSelector.poll`
+ case object Send extends SelectorOperation // recorded by `TestableSelector.send`
+ case object Mute extends SelectorOperation // recorded by `TestableSelector.mute`
+ case object Unmute extends SelectorOperation // recorded by `TestableSelector.unmute`
+ case object Wakeup extends SelectorOperation // recorded by `TestableSelector.wakeup`
+ case object Close extends SelectorOperation // recorded by `TestableSelector.close(id)`
+ case object CloseSelector extends SelectorOperation // recorded by `TestableSelector.close()`
+ }
+
+ /**
+  * `Selector` subclass that lets tests observe and perturb what a `Processor` does.
+  *
+  * Every overridden operation goes through `runOp`, which counts the call in `operationCounts`
+  * and, if a failure was armed for that operation with `addFailure`, throws it once. The results
+  * of `poll()` are exposed through `PollData` buffers so that a test can hold them back until a
+  * minimum number of elements has accumulated.
+  *
+  * NOTE(review): the collections below are plain `scala.collection.mutable` instances and this
+  * class adds no synchronization; the tests read them while the server is running - confirm that
+  * this is acceptable for the threads involved.
+  */
+ class TestableSelector(config: KafkaConfig, channelBuilder: ChannelBuilder, time: Time, metrics: Metrics)
+     extends Selector(config.socketRequestMaxBytes, config.connectionsMaxIdleMs,
+       metrics, time, "socket-server", new HashMap, false, true, channelBuilder, MemoryPool.NONE) {
+
+   // Armed failures; an entry is removed when it is thrown, so each one fires at most once
+   val failures = mutable.Map[SelectorOperation, Exception]()
+   // Number of invocations per operation (0 for operations never invoked)
+   val operationCounts = mutable.Map[SelectorOperation, Int]().withDefaultValue(0)
+   // Ids of every channel seen in `channels` at the end of any `poll()`
+   val allChannels = mutable.Set[String]()
+   // Ids for which `close(id)` completed
+   val allLocallyClosedChannels = mutable.Set[String]()
+   // Ids reported by the underlying selector's `disconnected` at the end of any `poll()`
+   val allDisconnectedChannels = mutable.Set[String]()
+   // Connection ids of the operations on which an armed failure was thrown
+   val allFailedChannels = mutable.Set[String]()
+
+   // Enable data from `Selector.poll()` to be deferred to a subsequent poll() until
+   // the number of elements of that type reaches `minPerPoll`. This enables tests to verify
+   // that failed processing doesn't impact subsequent processing within the same iteration.
+   class PollData[T] {
+     var minPerPoll = 1
+     val deferredValues = mutable.Buffer[T]()
+     val currentPollValues = mutable.Buffer[T]()
+
+     // Exposes `newValues` (preceded by anything deferred earlier) if values are already exposed
+     // for this poll or the threshold is reached; otherwise defers them
+     def update(newValues: mutable.Buffer[T]): Unit = {
+       if (currentPollValues.nonEmpty || deferredValues.size + newValues.size >= minPerPoll) {
+         if (deferredValues.nonEmpty) {
+           currentPollValues ++= deferredValues
+           deferredValues.clear()
+         }
+         currentPollValues ++= newValues
+       } else
+         deferredValues ++= newValues
+     }
+
+     // Discards the values exposed by the previous poll; deferred values are kept
+     def reset(): Unit = {
+       currentPollValues.clear()
+     }
+   }
+   val cachedCompletedReceives = new PollData[NetworkReceive]()
+   val cachedCompletedSends = new PollData[Send]()
+   val cachedDisconnected = new PollData[(String, ChannelState)]()
+   val allCachedPollData = Seq(cachedCompletedReceives, cachedCompletedSends, cachedDisconnected)
+   // While positive, `wakeup()` decrements this and only wakes the selector once it reaches zero
+   @volatile var minWakeupCount = 0
+   // When set, replaces the timeout passed to `poll`
+   @volatile var pollTimeoutOverride: Option[Long] = None
+
+   /**
+    * Arms a one-shot failure for `operation`.
+    *
+    * @param operation operation whose next invocation should throw
+    * @param exception exception to throw; defaults to an `IllegalStateException`
+    */
+   def addFailure(operation: SelectorOperation, exception: Option[Exception] = None): Unit = {
+     failures += operation ->
+       exception.getOrElse(new IllegalStateException(s"Test exception during $operation"))
+   }
+
+   // Counts the operation and, if a failure is armed for it, records the affected connection,
+   // runs `onFailure` and throws the failure
+   private def onOperation(operation: SelectorOperation,
+       connectionId: Option[String] = None, onFailure: => Unit = {}): Unit = {
+     operationCounts(operation) += 1
+     failures.remove(operation).foreach { e =>
+       connectionId.foreach(allFailedChannels.add)
+       onFailure
+       throw e
+     }
+   }
+
+   /** Blocks until `operation` has been counted at least `minExpectedTotal` times. */
+   def waitForOperations(operation: SelectorOperation, minExpectedTotal: Int): Unit = {
+     TestUtils.waitUntilTrue(() =>
+       operationCounts.getOrElse(operation, 0) >= minExpectedTotal, "Operations not performed within timeout")
+   }
+
+   /**
+    * Runs `code` and then records `operation`. The operation is counted, and an armed failure
+    * thrown, only after `code` has run.
+    */
+   def runOp[T](operation: SelectorOperation, connectionId: Option[String],
+       onFailure: => Unit = {})(code: => T): T = {
+     // If a failure is set on `operation`, throw that exception even if `code` fails
+     try code
+     finally onOperation(operation, connectionId, onFailure)
+   }
+
+   override def register(id: String, socketChannel: SocketChannel): Unit = {
+     // On an injected failure the channel has already been registered, so close it before throwing
+     runOp(SelectorOperation.Register, Some(id), onFailure = close(id)) {
+       super.register(id, socketChannel)
+     }
+   }
+
+   override def send(s: Send): Unit = {
+     runOp(SelectorOperation.Send, Some(s.destination)) {
+       super.send(s)
+     }
+   }
+
+   override def poll(timeout: Long): Unit = {
+     try {
+       // Drop what the previous poll exposed before polling again
+       allCachedPollData.foreach(_.reset())
+       runOp(SelectorOperation.Poll, None) {
+         super.poll(pollTimeoutOverride.getOrElse(timeout))
+       }
+     } finally {
+       // Capture the underlying selector's results even when the poll (or an injected failure) threw
+       super.channels.asScala.foreach(allChannels += _.id)
+       allDisconnectedChannels ++= super.disconnected.asScala.keys
+       cachedCompletedReceives.update(super.completedReceives.asScala)
+       cachedCompletedSends.update(super.completedSends.asScala)
+       cachedDisconnected.update(super.disconnected.asScala.toBuffer)
+     }
+   }
+
+   override def mute(id: String): Unit = {
+     runOp(SelectorOperation.Mute, Some(id)) {
+       super.mute(id)
+     }
+   }
+
+   override def unmute(id: String): Unit = {
+     runOp(SelectorOperation.Unmute, Some(id)) {
+       super.unmute(id)
+     }
+   }
+
+   override def wakeup(): Unit = {
+     runOp(SelectorOperation.Wakeup, None) {
+       // Swallow wakeups until the configured number of calls has been seen
+       if (minWakeupCount > 0)
+         minWakeupCount -= 1
+       if (minWakeupCount <= 0)
+         super.wakeup()
+     }
+   }
+
+   // The three accessors below expose the possibly deferred values rather than the underlying selector's
+   override def disconnected: java.util.Map[String, ChannelState] = cachedDisconnected.currentPollValues.toMap.asJava
+
+   override def completedSends: java.util.List[Send] = cachedCompletedSends.currentPollValues.asJava
+
+   override def completedReceives: java.util.List[NetworkReceive] = cachedCompletedReceives.currentPollValues.asJava
+
+   override def close(id: String): Unit = {
+     runOp(SelectorOperation.Close, Some(id)) {
+       super.close(id)
+       allLocallyClosedChannels += id
+     }
+   }
+
+   override def close(): Unit = {
+     runOp(SelectorOperation.CloseSelector, None) {
+       super.close()
+     }
+   }
+
+   /** Makes `wakeup()` take effect only on the `count`-th call from now and lengthens the poll timeout. */
+   def updateMinWakeup(count: Int): Unit = {
+     minWakeupCount = count
+     // For tests that ignore wakeup to process responses together, increase poll timeout
+     // to ensure that poll doesn't complete before the responses are ready
+     pollTimeoutOverride = Some(1000L)
+     // Wakeup current poll to force new poll timeout to take effect
+     super.wakeup()
+   }
+
+   /** Clears armed failures and restores `minPerPoll` to 1; counters and recorded channel ids are kept. */
+   def reset(): Unit = {
+     failures.clear()
+     allCachedPollData.foreach(_.minPerPoll = 1)
+   }
+
+   /** Returns `sockets` without the one belonging to the single channel that failed. */
+   def notFailed(sockets: Seq[Socket]): Seq[Socket] = {
+     // Each test generates failure for exactly one failed channel
+     assertEquals(1, allFailedChannels.size)
+     val failedConnectionId = allFailedChannels.head
+     sockets.filterNot(socket => isSocketConnectionId(failedConnectionId, socket))
+   }
+ }
}