You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by rs...@apache.org on 2020/05/29 08:44:50 UTC

[kafka] branch 2.5 updated: KAFKA-10029; Don't update completedReceives when channels are closed to avoid ConcurrentModificationException (#8705)

This is an automated email from the ASF dual-hosted git repository.

rsivaram pushed a commit to branch 2.5
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.5 by this push:
     new 9ee465f  KAFKA-10029; Don't update completedReceives when channels are closed to avoid ConcurrentModificationException (#8705)
9ee465f is described below

commit 9ee465f8ccc5933ea10d6cb211c5045010840fb5
Author: Rajini Sivaram <ra...@googlemail.com>
AuthorDate: Fri May 29 09:32:57 2020 +0100

    KAFKA-10029; Don't update completedReceives when channels are closed to avoid ConcurrentModificationException (#8705)
    
    Reviewers: Ismael Juma <is...@juma.me.uk>, Chia-Ping Tsai <ch...@gmail.com>
---
 .../apache/kafka/common/network/KafkaChannel.java  |  5 +-
 .../org/apache/kafka/common/network/Selector.java  | 35 +++++++--
 .../apache/kafka/common/network/SelectorTest.java  | 85 +++++++++++++++++++++
 .../main/scala/kafka/network/SocketServer.scala    |  4 +-
 .../unit/kafka/network/SocketServerTest.scala      | 87 +++++++++++++++++-----
 5 files changed, 188 insertions(+), 28 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java
index 4e4edd4..0ed9ee0 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java
@@ -28,7 +28,6 @@ import java.net.Socket;
 import java.net.SocketAddress;
 import java.nio.channels.SelectionKey;
 import java.nio.channels.SocketChannel;
-import java.util.Objects;
 import java.util.Optional;
 import java.util.function.Supplier;
 
@@ -471,12 +470,12 @@ public class KafkaChannel implements AutoCloseable {
             return false;
         }
         KafkaChannel that = (KafkaChannel) o;
-        return Objects.equals(id, that.id);
+        return id.equals(that.id);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(id);
+        return id.hashCode();
     }
 
     @Override
diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java
index cb91cad..06f7048 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java
@@ -107,7 +107,7 @@ public class Selector implements Selectable, AutoCloseable {
     private final Set<KafkaChannel> explicitlyMutedChannels;
     private boolean outOfMemory;
     private final List<Send> completedSends;
-    private final LinkedHashMap<KafkaChannel, NetworkReceive> completedReceives;
+    private final LinkedHashMap<String, NetworkReceive> completedReceives;
     private final Set<SelectionKey> immediatelyConnectedKeys;
     private final Map<String, KafkaChannel> closingChannels;
     private Set<SelectionKey> keysWithBufferedRead;
@@ -804,7 +804,33 @@ public class Selector implements Selectable, AutoCloseable {
     }
 
     /**
-     * Clear the results from the prior poll
+     * Clears completed receives. This is used by SocketServer to remove references to
+     * receive buffers after processing completed receives, without waiting for the next
+     * poll().
+     */
+    public void clearCompletedReceives() {
+        this.completedReceives.clear();
+    }
+
+    /**
+     * Clears completed sends. This is used by SocketServer to remove references to
+     * send buffers after processing completed sends, without waiting for the next
+     * poll().
+     */
+    public void clearCompletedSends() {
+        this.completedSends.clear();
+    }
+
+    /**
+     * Clears all the results from the previous poll. This is invoked by Selector at the start of
+     * a poll() when all the results from the previous poll are expected to have been handled.
+     * <p>
+     * SocketServer uses {@link #clearCompletedSends()} and {@link #clearCompletedReceives()} to
+     * clear `completedSends` and `completedReceives` as soon as they are processed to avoid
+     * holding onto large request/response buffers from multiple connections longer than necessary.
+     * Clients rely on Selector invoking {@link #clear()} at the start of each poll() since memory usage
+     * is less critical and clearing once-per-poll provides the flexibility to process these results in
+     * any order before the next poll.
      */
     private void clear() {
         this.completedSends.clear();
@@ -935,7 +961,6 @@ public class Selector implements Selectable, AutoCloseable {
         }
 
         this.sensors.connectionClosed.record();
-        this.completedReceives.remove(channel);
         this.explicitlyMutedChannels.remove(channel);
         if (notifyDisconnect)
             this.disconnected.put(channel.id(), channel.state());
@@ -1015,7 +1040,7 @@ public class Selector implements Selectable, AutoCloseable {
      * Check if given channel has a completed receive
      */
     private boolean hasCompletedReceive(KafkaChannel channel) {
-        return completedReceives.containsKey(channel);
+        return completedReceives.containsKey(channel.id());
     }
 
     /**
@@ -1025,7 +1050,7 @@ public class Selector implements Selectable, AutoCloseable {
         if (hasCompletedReceive(channel))
             throw new IllegalStateException("Attempting to add second completed receive to channel " + channel.id());
 
-        this.completedReceives.put(channel, networkReceive);
+        this.completedReceives.put(channel.id(), networkReceive);
         sensors.recordCompletedReceive(channel.id(), networkReceive.size(), currentTimeMs);
     }
 
diff --git a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
index 57b0153..ac773ee 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
@@ -48,6 +48,7 @@ import java.nio.channels.SocketChannel;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -341,6 +342,36 @@ public class SelectorTest {
         assertEquals("", blockingRequest(node, ""));
     }
 
+    @Test
+    public void testClearCompletedSendsAndReceives() throws Exception {
+        int bufferSize = 1024;
+        String node = "0";
+        InetSocketAddress addr = new InetSocketAddress("localhost", server.port);
+        connect(node, addr);
+        String request = TestUtils.randomString(bufferSize);
+        selector.send(createSend(node, request));
+        boolean sent = false;
+        boolean received = false;
+        while (!sent || !received) {
+            selector.poll(1000L);
+            assertEquals("No disconnects should have occurred.", 0, selector.disconnected().size());
+            if (!selector.completedSends().isEmpty()) {
+                assertEquals(1, selector.completedSends().size());
+                selector.clearCompletedSends();
+                assertEquals(0, selector.completedSends().size());
+                sent = true;
+            }
+
+            if (!selector.completedReceives().isEmpty()) {
+                assertEquals(1, selector.completedReceives().size());
+                assertEquals(request, asString(selector.completedReceives().iterator().next()));
+                selector.clearCompletedReceives();
+                assertEquals(0, selector.completedReceives().size());
+                received = true;
+            }
+        }
+    }
+
     @Test(expected = IllegalStateException.class)
     public void testExistingConnectionId() throws IOException {
         blockingConnect("0");
@@ -904,6 +935,60 @@ public class SelectorTest {
         assertEquals(asList(send), selector.completedSends());
     }
 
+    /**
+     * Ensure that no errors are thrown if channels are closed while processing multiple completed receives
+     */
+    @Test
+    public void testChannelCloseWhileProcessingReceives() throws Exception {
+        int numChannels = 4;
+        Map<String, KafkaChannel> channels = TestUtils.fieldValue(selector, Selector.class, "channels");
+        Set<SelectionKey> selectionKeys = new HashSet<>();
+        for (int i = 0; i < numChannels; i++) {
+            String id = String.valueOf(i);
+            KafkaChannel channel = mock(KafkaChannel.class);
+            channels.put(id, channel);
+            when(channel.id()).thenReturn(id);
+            when(channel.state()).thenReturn(ChannelState.READY);
+            when(channel.isConnected()).thenReturn(true);
+            when(channel.ready()).thenReturn(true);
+            when(channel.read()).thenReturn(1L);
+
+            SelectionKey selectionKey = mock(SelectionKey.class);
+            when(channel.selectionKey()).thenReturn(selectionKey);
+            when(selectionKey.isValid()).thenReturn(true);
+            when(selectionKey.readyOps()).thenReturn(SelectionKey.OP_READ);
+            selectionKey.attach(channel);
+            selectionKeys.add(selectionKey);
+
+            NetworkReceive receive = mock(NetworkReceive.class);
+            when(receive.source()).thenReturn(id);
+            when(receive.size()).thenReturn(10);
+            when(receive.bytesRead()).thenReturn(1);
+            when(receive.payload()).thenReturn(ByteBuffer.allocate(10));
+            when(channel.maybeCompleteReceive()).thenReturn(receive);
+        }
+
+        selector.pollSelectionKeys(selectionKeys, false, System.nanoTime());
+        assertEquals(numChannels, selector.completedReceives().size());
+        Set<KafkaChannel> closed = new HashSet<>();
+        Set<KafkaChannel> notClosed = new HashSet<>();
+        for (NetworkReceive receive : selector.completedReceives()) {
+            KafkaChannel channel = selector.channel(receive.source());
+            assertNotNull(channel);
+            if (closed.size() < 2) {
+                selector.close(channel.id());
+                closed.add(channel);
+            } else
+                notClosed.add(channel);
+        }
+        assertEquals(notClosed, new HashSet<>(selector.channels()));
+        closed.forEach(channel -> assertNull(selector.channel(channel.id())));
+
+        selector.poll(0);
+        assertEquals(0, selector.completedReceives().size());
+    }
+
+
     private String blockingRequest(String node, String s) throws IOException {
         selector.send(createSend(node, s));
         selector.poll(1000L);
diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala
index 35d9d7c..fb69ab9 100644
--- a/core/src/main/scala/kafka/network/SocketServer.scala
+++ b/core/src/main/scala/kafka/network/SocketServer.scala
@@ -779,7 +779,7 @@ private[kafka] class Processor(val id: Int,
     }
   }
 
-  private def processException(errorMessage: String, throwable: Throwable): Unit = {
+  private[network] def processException(errorMessage: String, throwable: Throwable): Unit = {
     throwable match {
       case e: ControlThrowable => throw e
       case e => error(errorMessage, e)
@@ -915,6 +915,7 @@ private[kafka] class Processor(val id: Int,
           processChannelException(receive.source, s"Exception while processing request from ${receive.source}", e)
       }
     }
+    selector.clearCompletedReceives()
   }
 
   private def processCompletedSends(): Unit = {
@@ -938,6 +939,7 @@ private[kafka] class Processor(val id: Int,
           s"Exception while processing completed send to ${send.destination}", e)
       }
     }
+    selector.clearCompletedSends()
   }
 
   private def updateRequestMetrics(response: RequestChannel.Response): Unit = {
diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
index 91fce5b..b40c763 100644
--- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
+++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
@@ -1570,6 +1570,8 @@ class SocketServerTest {
       testableSelector.waitForOperations(SelectorOperation.Poll, 1)
 
       testableSelector.waitForOperations(SelectorOperation.CloseSelector, 1)
+      assertEquals(1, testableServer.uncaughtExceptions)
+      testableServer.uncaughtExceptions = 0
     })
   }
 
@@ -1648,6 +1650,7 @@ class SocketServerTest {
         testWithServer(testableServer)
     } finally {
       shutdownServerAndMetrics(testableServer)
+      assertEquals(0, testableServer.uncaughtExceptions)
     }
   }
 
@@ -1702,6 +1705,7 @@ class SocketServerTest {
       new Metrics, time, credentialProvider) {
 
     @volatile var selector: Option[TestableSelector] = None
+    @volatile var uncaughtExceptions = 0
 
     override def newProcessor(id: Int, requestChannel: RequestChannel, connectionQuotas: ConnectionQuotas, listenerName: ListenerName,
                                 protocol: SecurityProtocol, memoryPool: MemoryPool): Processor = {
@@ -1714,6 +1718,12 @@ class SocketServerTest {
            selector = Some(testableSelector)
            testableSelector
         }
+
+        override private[network] def processException(errorMessage: String, throwable: Throwable): Unit = {
+          if (errorMessage.contains("uncaught exception"))
+            uncaughtExceptions += 1
+          super.processException(errorMessage, throwable)
+        }
       }
     }
 
@@ -1766,12 +1776,19 @@ class SocketServerTest {
     // Enable data from `Selector.poll()` to be deferred to a subsequent poll() until
     // the number of elements of that type reaches `minPerPoll`. This enables tests to verify
     // that failed processing doesn't impact subsequent processing within the same iteration.
-    class PollData[T] {
+    abstract class PollData[T] {
       var minPerPoll = 1
       val deferredValues = mutable.Buffer[T]()
-      val currentPollValues = mutable.Buffer[T]()
-      def update(newValues: mutable.Buffer[T]): Unit = {
-        if (currentPollValues.nonEmpty || deferredValues.size + newValues.size >= minPerPoll) {
+
+      /**
+       * Process new results and return the results for the current poll if at least
+       * `minPerPoll` results are available including any deferred results. Otherwise
+       * add the provided values to the deferred set and return an empty buffer. This allows
+       * tests to process `minPerPoll` elements as the results of a single poll iteration.
+       */
+      protected def update(newValues: mutable.Buffer[T]): mutable.Buffer[T] = {
+        val currentPollValues = mutable.Buffer[T]()
+        if (deferredValues.size + newValues.size >= minPerPoll) {
           if (deferredValues.nonEmpty) {
             currentPollValues ++= deferredValues
             deferredValues.clear()
@@ -1779,14 +1796,49 @@ class SocketServerTest {
           currentPollValues ++= newValues
         } else
           deferredValues ++= newValues
+
+        currentPollValues
       }
-      def reset(): Unit = {
-        currentPollValues.clear()
+
+      /**
+       * Process results from the appropriate buffer in Selector and update the buffer to either
+       * defer and return nothing or return all results including previously deferred values.
+       */
+      def updateResults(): Unit
+    }
+
+    class CompletedReceivesPollData(selector: TestableSelector) extends PollData[NetworkReceive] {
+      val completedReceivesMap: util.Map[String, NetworkReceive] = JTestUtils.fieldValue(selector, classOf[Selector], "completedReceives")
+
+      override def updateResults(): Unit = {
+        val currentReceives = update(selector.completedReceives.asScala.toBuffer)
+        completedReceivesMap.clear()
+        currentReceives.foreach { receive =>
+          val channelOpt = Option(selector.channel(receive.source)).orElse(Option(selector.closingChannel(receive.source)))
+          channelOpt.foreach { channel => completedReceivesMap.put(channel.id, receive) }
+        }
       }
     }
-    val cachedCompletedReceives = new PollData[NetworkReceive]()
-    val cachedCompletedSends = new PollData[Send]()
-    val cachedDisconnected = new PollData[(String, ChannelState)]()
+
+    class CompletedSendsPollData(selector: TestableSelector) extends PollData[Send] {
+      override def updateResults(): Unit = {
+        val currentSends = update(selector.completedSends.asScala)
+        selector.completedSends.clear()
+        currentSends.foreach { selector.completedSends.add }
+      }
+    }
+
+    class DisconnectedPollData(selector: TestableSelector) extends PollData[(String, ChannelState)] {
+      override def updateResults(): Unit = {
+        val currentDisconnected = update(selector.disconnected.asScala.toBuffer)
+        selector.disconnected.clear()
+        currentDisconnected.foreach { case (channelId, state) => selector.disconnected.put(channelId, state) }
+      }
+    }
+
+    val cachedCompletedReceives = new CompletedReceivesPollData(this)
+    val cachedCompletedSends = new CompletedSendsPollData(this)
+    val cachedDisconnected = new DisconnectedPollData(this)
     val allCachedPollData = Seq(cachedCompletedReceives, cachedCompletedSends, cachedDisconnected)
     val pendingClosingChannels = new ConcurrentLinkedQueue[KafkaChannel]()
     @volatile var minWakeupCount = 0
@@ -1833,20 +1885,23 @@ class SocketServerTest {
 
     override def poll(timeout: Long): Unit = {
       try {
+        assertEquals(0, super.completedReceives().size)
+        assertEquals(0, super.completedSends().size)
+
         pollCallback.apply()
         while (!pendingClosingChannels.isEmpty) {
           makeClosing(pendingClosingChannels.poll())
         }
-        allCachedPollData.foreach(_.reset)
         runOp(SelectorOperation.Poll, None) {
           super.poll(pollTimeoutOverride.getOrElse(timeout))
         }
       } finally {
         super.channels.asScala.foreach(allChannels += _.id)
         allDisconnectedChannels ++= super.disconnected.asScala.keys
-        cachedCompletedReceives.update(super.completedReceives.asScala.toBuffer)
-        cachedCompletedSends.update(super.completedSends.asScala)
-        cachedDisconnected.update(super.disconnected.asScala.toBuffer)
+
+        cachedCompletedReceives.updateResults()
+        cachedCompletedSends.updateResults()
+        cachedDisconnected.updateResults()
       }
     }
 
@@ -1871,12 +1926,6 @@ class SocketServerTest {
       }
     }
 
-    override def disconnected: java.util.Map[String, ChannelState] = cachedDisconnected.currentPollValues.toMap.asJava
-
-    override def completedSends: java.util.List[Send] = cachedCompletedSends.currentPollValues.asJava
-
-    override def completedReceives: java.util.List[NetworkReceive] = cachedCompletedReceives.currentPollValues.asJava
-
     override def close(id: String): Unit = {
       runOp(SelectorOperation.Close, Some(id)) {
         super.close(id)